
mimo-unet's Introduction

MIMO-UNet - Official PyTorch Implementation


This repository provides the official PyTorch implementation of the following paper:

Rethinking Coarse-to-Fine Approach in Single Image Deblurring

Sung-Jin Cho *, Seo-Won Ji *, Jun-Pyo Hong, Seung-Won Jung, Sung-Jea Ko

In ICCV 2021. (* indicates equal contribution)

Paper: https://arxiv.org/abs/2108.05054

Abstract: Coarse-to-fine strategies have been extensively used for the architecture design of single image deblurring networks. Conventional methods typically stack sub-networks with multi-scale input images and gradually improve sharpness of images from the bottom sub-network to the top sub-network, yielding inevitably high computational costs. Toward a fast and accurate deblurring network design, we revisit the coarse-to-fine strategy and present a multi-input multi-output U-net (MIMO-UNet). The MIMO-UNet has three distinct features. First, the single encoder of the MIMO-UNet takes multi-scale input images to ease the difficulty of training. Second, the single decoder of the MIMO-UNet outputs multiple deblurred images with different scales to mimic multi-cascaded U-nets using a single U-shaped network. Last, asymmetric feature fusion is introduced to merge multi-scale features in an efficient manner. Extensive experiments on the GoPro and RealBlur datasets demonstrate that the proposed network outperforms the state-of-the-art methods in terms of both accuracy and computational complexity.
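As a rough illustration of the multi-input idea, the single encoder consumes an image pyramid built from the full-resolution blurry input. A minimal sketch, with illustrative names rather than the official implementation:

    import torch
    import torch.nn.functional as F

    def build_input_pyramid(blur: torch.Tensor):
        # blur: (N, 3, H, W) full-resolution blurry image
        blur_2 = F.interpolate(blur, scale_factor=0.5, mode='bilinear', align_corners=False)
        blur_4 = F.interpolate(blur_2, scale_factor=0.5, mode='bilinear', align_corners=False)
        return blur, blur_2, blur_4  # fed to the single encoder at three scales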


Contents

The contents of this repository are as follows:

  1. Dependencies
  2. Dataset
  3. Train
  4. Test
  5. Performance
  6. Model

Dependencies

  • Python
  • PyTorch (1.4)
    • Different versions may cause some errors.
  • scikit-image
  • opencv-python
  • Tensorboard

Dataset

  • Download the deblur dataset from the GoPro dataset page.

  • Unzip the files into the dataset folder.

  • Preprocess the dataset by running the command below:

    python data/preprocessing.py

After preparing the dataset, the data folder should have the following structure:

GOPRO
├─ train
│ ├─ blur    % 2103 image pairs
│ │ ├─ xxxx.png
│ │ ├─ ......
│ │
│ ├─ sharp
│ │ ├─ xxxx.png
│ │ ├─ ......
│
├─ test    % 1111 image pairs
│ ├─ ...... (same as train)
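For reference, a minimal sketch of how blur/sharp pairs in this layout can be enumerated; this is an illustrative assumption, not the repository's own dataset class:

    import os
    from glob import glob

    from PIL import Image
    from torch.utils.data import Dataset

    class GoProPairs(Dataset):
        """Pairs GOPRO/{split}/blur/xxxx.png with GOPRO/{split}/sharp/xxxx.png."""
        def __init__(self, root, split='train'):
            self.blur = sorted(glob(os.path.join(root, split, 'blur', '*.png')))
            self.sharp = sorted(glob(os.path.join(root, split, 'sharp', '*.png')))
            assert len(self.blur) == len(self.sharp), 'blur/sharp counts must match'

        def __len__(self):
            return len(self.blur)

        def __getitem__(self, idx):
            # blurry and sharp images share file names across the two folders
            return Image.open(self.blur[idx]), Image.open(self.sharp[idx])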


Train

To train MIMO-UNet+, run the command below:

python main.py --model_name "MIMO-UNetPlus" --mode "train" --data_dir "dataset/GOPRO"

or to train MIMO-UNet, run the command below:

python main.py --model_name "MIMO-UNet" --mode "train" --data_dir "dataset/GOPRO"

Model weights will be saved in the results/model_name/weights folder.


Test

To test MIMO-UNet+, run the command below:

python main.py --model_name "MIMO-UNetPlus" --mode "test" --data_dir "dataset/GOPRO" --test_model "MIMO-UNetPlus.pkl"

or to test MIMO-UNet, run the command below:

python main.py --model_name "MIMO-UNet" --mode "test" --data_dir "dataset/GOPRO" --test_model "MIMO-UNet.pkl"

Output images will be saved in the results/model_name/result_image folder.


Performance

Method        MIMO-UNet   MIMO-UNet+   MIMO-UNet++
PSNR (dB)     31.73       32.45        32.68
SSIM          0.951       0.957        0.959
Runtime (s)   0.008       0.017        0.040

GPU synchronization issue when measuring inference time

We recently found an issue with measuring the inference time of networks implemented in the PyTorch framework.

The official code of many papers (more than twenty at a glance) presented at top conferences measured inference time simply with timing functions such as time.time(), time.perf_counter(), or tqdm. However, since CUDA calls are asynchronous, the synchronized inference time must be measured using torch.cuda.synchronize().
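For reference, a minimal sketch of synchronized timing; the warm-up count and a CUDA-resident model and input are assumptions for illustration:

    import time
    import torch

    @torch.no_grad()
    def timed_forward(model, x, warmup=10, iters=100):
        for _ in range(warmup):      # exclude CUDA context setup and cuDNN autotuning
            model(x)
        torch.cuda.synchronize()     # drain queued kernels before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        torch.cuda.synchronize()     # wait until the last kernel actually finishes
        return (time.perf_counter() - start) / iters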

We thus present the table below, containing inference times re-measured in synchronized mode for various methods developed with the PyTorch framework.

The inference times presented below were all measured on an RTX 3090 due to a recent upgrade of our system. (VRAM usage was restricted to 12 GB, the same as the Titan Xp.)

Methods          Async-Time* (s)   Sync-Time** (s)   PSNR (dB)
DMPHN            0.308             0.588             31.20
MT-RNN           0.031             0.394             31.15
MPRNet           0.075             1.474             32.66
MIMO-UNet        0.012             0.130             31.73
MIMO-UNet+       0.025             0.282             32.45
MIMO-UNet++***   0.049             1.115             32.68

* indicates inference time measured without torch.cuda.synchronize().

** indicates inference time measured with torch.cuda.synchronize().

*** For MIMO-UNet++, we used batch inference (two forward passes with a batch size of 2) for the geometric ensemble. However, we noticed that the inference time measured with torch.cuda.synchronize() cannot benefit from batch inference, so MIMO-UNet++ runs about 4x slower than MIMO-UNet+ (still 32% faster than the conventional SOTA method).
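A minimal sketch of batched geometric self-ensembling, assuming the model returns a single full-resolution tensor and using only a horizontal flip for illustration (the exact transform set used for MIMO-UNet++ is not spelled out here):

    import torch

    @torch.no_grad()
    def ensemble_forward(model, blur):
        # run the original and its horizontal flip as one batch of 2
        batch = torch.cat([blur, torch.flip(blur, dims=[3])], dim=0)
        pred, pred_flip = model(batch).chunk(2, dim=0)
        # undo the flip and average the two predictions
        return 0.5 * (pred + torch.flip(pred_flip, dims=[3]))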

As the GPU was changed from the Titan Xp to the RTX 3090, the asynchronous inference times of the conventional methods were reduced, but those of MIMO-UNet and its variants stayed the same. We will conduct additional tests on this issue and update this page if there is any progress.

We hope this observation will be helpful to many researchers in this field.


Model

We provide our pre-trained models. You can test our network by following the instructions above.

* The test code for the RealBlur dataset will be released soon.

** We measured PSNR using the official RealBlur test code. You can reproduce the PSNR we achieved by cloning and following the RealBlur repository.

mimo-unet's People

Contributors

chosj95

mimo-unet's Issues

A question about experiments on the RealBlur dataset

Hello,

I am a master's student working on image quality enhancement. Thank you for your excellent research.
I evaluated the released MIMO-UNet model on the RealBlur-J dataset and measured a PSNR of 28.99.
Could you check whether there is a problem with my evaluation method?
I implemented the dataset function myself, loading the data through the Test_list file of the RealBlur-J dataset.
I computed PSNR the same way you did, zero-padded the images so their sizes are multiples of 8 before the forward pass, then cropped the output images and compared them against the ground truth.

Thank you.
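For readers with the same question, a minimal sketch of the zero-pad-then-crop evaluation described above (the multiple of 8 matches the network's downsampling factor):

    import torch.nn.functional as F

    def pad_to_multiple(x, multiple=8):
        # x: (N, C, H, W); zero-pad right/bottom so H and W are multiples of `multiple`
        h, w = x.shape[-2:]
        pad_h = (multiple - h % multiple) % multiple
        pad_w = (multiple - w % multiple) % multiple
        return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)

    # after the forward pass, crop back before comparing with the ground truth:
    # pred = pred[..., :h, :w]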

RealBlur data for training

Hello,

How do you train the model for RealBlur?
Do you fine-tune from the GoPro pre-trained weights, or train from scratch on the GoPro+BSD+RealBlur-R datasets like MPRNet?
And do you plan to release the deblurring results for RealBlur-J and RealBlur-R?

loss implementation different from paper

Hi,

Thanks for sharing this great work~ I found that your implementation of the loss here differs from the paper: the "1/t_k" normalization weight is not used in your code. Could you help clarify this?

Training Problems

I encountered some problems when training MIMO-UNet++ on the GoPro dataset. May I ask the author: during training, did unknown color blocks appear at certain positions in the generated images?

downsample function and loss function normalization at training stage

Hi @chosj95, MIMO-UNet is a fantastic work! Thanks for sharing!
However, some things confuse me a lot.

  • During the training stage, the different sizes of blurry images are generated by nearest-neighbor interpolation, but the corresponding supervision (the different sizes of sharp images) is generated by bilinear interpolation rather than nearest-neighbor again. Intuitively, the interpolation function should be the same.
  • In the paper, Equations 7 and 8 have denominators in the loss function for normalization, but the loss code lacks the denominators mentioned above. I suspect the model might perform better with the normalization in the loss function; see the sketch below.

I am eagerly waiting for your explanation. Thanks a lot!
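For context, a minimal sketch of a multi-scale content loss with the 1/t_k normalization of Eq. 7; note that F.l1_loss with the default reduction='mean' already divides by the element count t_k:

    import torch.nn.functional as F

    def multi_scale_content_loss(preds, targets):
        # preds, targets: lists of tensors at 1/4, 1/2, and full resolution
        return sum(F.l1_loss(p, t) for p, t in zip(preds, targets))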

PSNR performance

Has anyone reproduced the results given in the paper? I trained MIMO-UNet, but it only reaches 31.40 dB PSNR, 0.33 dB below the 31.73 given in the paper.

FLOPs?

Hi,

Thanks for the great work. I didn't find the FLOPs of MIMO-UNet and MIMO-UNet+ in your paper. Could you please provide these numbers?

A question about results on RealBlur dataset

Hello, thanks for your excellent work.
I have some confusion about the dataset you used for RealBlur.
Could you please tell me which dataset(s) you used to obtain the results in Table 2: both GoPro and RealBlur, or only RealBlur itself?
MPRNet only used RealBlur to get the RealBlur results.

Thanks.

Line 53 in eval.py

Thank you for your great work!

I have a question on evaluation code.
In line 53 of eval.py:
pred_clip += 0.5 / 255
0.5/255 is added. What does this mean? Also, the code calculates PSNR with the unmodified tensor (pred_numpy). Why is that?

Thank you.

What is the actual effect of the model?

Hello, author. Is your model suitable for mixed restoration of motion blur and defocus blur? After training with motion-blur data only, I tested it on normal and defocus-blurred images, and the results were extremely poor.

About the AFF module

Hello author,

I would like to ask about the dimensions of the three inputs to the AFF fusion module and how the reshape operation is performed. Looking forward to your reply.

training time

Your work has been incredibly helpful to me, and I am very grateful. I was wondering if you could provide some information on the expected training time and required computing resources for the model. I would greatly appreciate it!

The setting about training on realblur

Hello!
I'd like to know how to train MIMO-UNet on the RealBlur dataset.
Are any other modifications needed besides the parameters mentioned in the paper?
For example, is the patch size still 256×256?

Looking forward to your response. Thank you!

The inference time is slower than that reported in the paper

I have tested MIMO-UNet and MIMO-UNet+ on a single 2080 Ti card (whose theoretical performance is higher than the Titan Xp's); they take about 15 ms and 30 ms, respectively. I didn't make any changes to the open-source code; I just ran the test command (https://github.com/chosj95/MIMO-UNet#test) directly.

python main.py --model_name "MIMO-UNet" --mode "test" --data_dir "dataset/GOPRO" --test_model "MIMO-UNet.pkl"

Namespace(batch_size=4, data_dir='dataset/GOPRO', gamma=0.5, learning_rate=0.0001, lr_steps=[500, 1000, 1500, 2000, 2500, 3000], mode='test', model_name='MIMO-UNet', model_save_dir='results/MIMO-UNet/weights/', num_epoch=3000, num_worker=8, print_freq=100, result_dir='results/MIMO-UNet/result_image/', resume='', save_freq=100, save_image=False, test_model='MIMO-UNet.pkl', valid_freq=100, weight_decay=0)

For MIMO-UNet:

==========================================================
The average PSNR is 31.73 dB
Average time: 0.015028

And for MIMO-UNet+

==========================================================
The average PSNR is 32.45 dB
Average time: 0.030238
  1. Why can't the times of 8 ms/17 ms reported in the paper be reproduced?
  2. Why are the asynchronous inference times on the 2080 Ti and 3090 (https://github.com/chosj95/MIMO-UNet#gpu-syncronization-issue-on-measuring-inference-time) slower than on the Titan Xp (https://github.com/chosj95/MIMO-UNet#performance)?

In addition, I think the CUDA-synchronized time should be used when reporting timing performance. The unsynchronized time cannot correctly measure the speed and complexity of a model.

Testing

Could you please clarify your question? Are you asking whether the GoPro test set was resized to 256x256 during testing? I tried applying a 256x256 center crop and found that it improved the results by nearly 0.5 points.

Eval issue

The valid_freq parameter is 100 by default. When your epoch count exceeds 100, you need to create a folder under GOPRO/train and put some sharp images in it.

the image shown in the paper

On page 5 of the paper (Figure 5, "Several examples on the GoPro test dataset"),
I found the corresponding data in the GoPro dataset; the image is "test/blur/GOPR0854_11_00/001653_3.png", right?
The original image does not look as blurry as the image shown in the paper.
Were you testing with the original GoPro dataset?

SSIM

When you tested SSIM, did you use the function structural_similarity from skimage.metrics, called as below?
structural_similarity(p_numpy, label_numpy, data_range=1, multichannel=True)
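For reference, a hedged sketch of that call; note that recent scikit-image releases replaced multichannel=True with channel_axis:

    import numpy as np
    from skimage.metrics import structural_similarity

    def ssim_rgb(pred: np.ndarray, gt: np.ndarray) -> float:
        # pred, gt: HxWx3 float arrays scaled to [0, 1]
        return structural_similarity(pred, gt, data_range=1, channel_axis=-1)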

Prediction Questions

Hello, I ran the algorithm as you instructed, but no prediction images appeared.

save deblur image in eval phase

In line 53 of eval.py (shown below), the output tensor is converted to a PIL image after adding 0.5/255 and then saved. Why is this? The tensor has already been clipped to [0, 1], and I can't find the same processing in the official torchvision documentation.

if args.save_image:
    save_name = os.path.join(args.result_dir, name[0])
    pred_clip += 0.5 / 255  # half a quantization step, so the 8-bit cast below rounds instead of truncating
    pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
    pred.save(save_name)
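A small numeric check of the likely intent, assuming torchvision's float-to-PIL path multiplies by 255 and casts to uint8 (a cast that truncates):

    import torch

    x = torch.tensor([200.7 / 255])         # a float pixel value; 255 * x = 200.7
    print(x.mul(255).byte())                # tensor([200]): the plain cast truncates
    print((x + 0.5 / 255).mul(255).byte())  # tensor([201]): the shift makes the cast round to nearest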

License

Hi, Can i get license information?

Thanks for awesome work!

how to use multi-gpus for training

When I use nn.DataParallel(model), I get the following error:
Traceback (most recent call last):
File "main.py", line 67, in
main(args)
File "main.py", line 31, in main
_train(model, args)
File "MIMO-UNet/train.py", line 53, in _train
pred_img = model(input_img)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 154, in forward
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 159, in replicate
return replicate(module, device_ids, not torch.is_grad_enabled())
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/replicate.py", line 88, in replicate
param_copies = _broadcast_coalesced_reshape(params, devices, detach)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/replicate.py", line 71, in _broadcast_coalesced_reshape
tensor_copies = Broadcast.apply(devices, *tensors)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/_functions.py", line 21, in forward
outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
File "/usr/local/lib/python3.6/dist-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: inputs must be on unique devices
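"inputs must be on unique devices" is typically raised when the wrapped parameters are not all on device_ids[0] or when a device is listed twice. A hedged sketch of the usual fix; build_net is assumed to be this repository's model constructor:

    import torch.nn as nn
    from models.MIMOUNet import build_net  # assumed repo import

    model = build_net('MIMO-UNet').to('cuda:0')  # parameters must live on the first device
    model = nn.DataParallel(model, device_ids=[0, 1])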

about the training time

Hello. Thanks for your great work and code.
I trained it using a Tesla V100 16G, and it takes about 4 hours for 50 epochs.
Could you please tell me how long it took you to train for 3000 epochs?
Thanks in advance.
