ISTA-Net: Interpretable Optimization-Inspired Deep Network for Image Compressive Sensing [PyTorch version]

Includes code for CS of natural images (CS-NI) and CS for magnetic resonance imaging (CS-MRI).

This repository is for ISTA-Net and ISTA-Net+, introduced in the following paper:

Jian Zhang, Bernard Ghanem, "ISTA-Net: Interpretable Optimization-Inspired Deep Network for Image Compressive Sensing", CVPR 2018, [pdf] [Supp]

The code is built on PyTorch and tested on Ubuntu 16.04/18.04 and Windows 10 (Python 3.x, PyTorch >= 0.4) with a 1080Ti GPU.

[Old Tensorflow Version]

Introduction

With the aim of developing a fast yet accurate algorithm for compressive sensing (CS) reconstruction of natural images, we combine in this paper the merits of two existing categories of CS methods: the structure insights of traditional optimization-based methods and the speed of recent network-based ones. Specifically, we propose a novel structured deep network, dubbed ISTA-Net, which is inspired by the Iterative Shrinkage-Thresholding Algorithm (ISTA) for optimizing a general L1 norm CS reconstruction model. To cast ISTA into deep network form, we develop an effective strategy to solve the proximal mapping associated with the sparsity-inducing regularizer using nonlinear transforms. All the parameters in ISTA-Net (e.g., nonlinear transforms, shrinkage thresholds, step sizes, etc.) are learned end-to-end, rather than being hand-crafted. Moreover, considering that the residuals of natural images are more compressible, an enhanced version of ISTA-Net in the residual domain, dubbed ISTA-Net+, is derived to further improve CS reconstruction. Extensive CS experiments demonstrate that the proposed ISTA-Nets outperform existing state-of-the-art optimization-based and network-based CS methods by large margins, while maintaining fast computational speed.

Figure 1. Illustration of the proposed ISTA-Net framework.
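For readers new to ISTA, the classic (non-learned) iteration that ISTA-Net unrolls can be sketched in a few lines of NumPy. This is only an illustration under simplifying assumptions: the sparsifying transform is the identity, the step size and threshold are fixed by hand, and the names `soft_threshold`, `ista`, `lam` and `num_iters` are chosen here for the example. It is not the network in this repository, where the transforms, thresholds and step sizes are learned.

```python
import numpy as np

def soft_threshold(v, theta):
    """Proximal operator of theta * ||.||_1 (element-wise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

def ista(Phi, y, lam=0.01, num_iters=500):
    """Minimize 0.5 * ||Phi x - y||^2 + lam * ||x||_1 with a fixed step size."""
    # Step size 1/L, where L is the largest eigenvalue of Phi^T Phi.
    step = 1.0 / np.linalg.norm(Phi, 2) ** 2
    x = np.zeros(Phi.shape[1])
    for _ in range(num_iters):
        r = x - step * Phi.T @ (Phi @ x - y)  # gradient step on the data term
        x = soft_threshold(r, lam * step)     # proximal (shrinkage) step
    return x

def objective(Phi, y, x, lam=0.01):
    """Value of the L1-regularized least-squares objective at x."""
    return 0.5 * np.sum((Phi @ x - y) ** 2) + lam * np.sum(np.abs(x))

# Toy problem: recover a 5-sparse vector from 40 random measurements.
rng = np.random.default_rng(0)
Phi = rng.standard_normal((40, 100)) / np.sqrt(40)
x_true = np.zeros(100)
x_true[rng.choice(100, size=5, replace=False)] = rng.standard_normal(5)
y = Phi @ x_true
x_hat = ista(Phi, y)
```

Each loop iteration corresponds to one "phase" in the unrolled view: a gradient step on the measurement residual followed by a shrinkage step.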

Contents

  1. Test-CS-NI
  2. Train-CS-NI
  3. Test-CS-MRI
  4. Train-CS-MRI
  5. Results
  6. Citation
  7. Acknowledgements

Test-CS-NI

Quick start

  1. All models for our paper have been put in './model'.

  2. Run the following scripts to test ISTA-Net models.

    You can use scripts in file 'TEST_ISTA_Net_scripts.sh' to produce results for our paper.

    # test scripts
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 1 --layer_num 9
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 4 --layer_num 9
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 10 --layer_num 9
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 25 --layer_num 9
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 30 --layer_num 9
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 40 --layer_num 9
    python TEST_CS_ISTA_Net.py --epoch_num 200 --cs_ratio 50 --layer_num 9
  3. Run the following scripts to test ISTA-Net+ models.

    You can use scripts in file 'TEST_ISTA_Net_plus_scripts.sh' to produce results for our paper.

    # test scripts
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 1 --layer_num 9
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 4 --layer_num 9
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 10 --layer_num 9
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 25 --layer_num 9
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 30 --layer_num 9
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 40 --layer_num 9
    python TEST_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 50 --layer_num 9
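The test commands above differ only in `--cs_ratio`, so they can also be generated in a loop. A minimal Python sketch, assuming the script names and flags listed above; the helper names `build_test_commands` and `run_all` are hypothetical and not part of the repository.

```python
import subprocess
import sys

# CS ratios used by the test scripts listed in this README.
CS_RATIOS = [1, 4, 10, 25, 30, 40, 50]

def build_test_commands(script="TEST_CS_ISTA_Net_plus.py", epoch_num=200, layer_num=9):
    """Build one command line per CS ratio, mirroring the listed test scripts."""
    return [
        [sys.executable, script,
         "--epoch_num", str(epoch_num),
         "--cs_ratio", str(ratio),
         "--layer_num", str(layer_num)]
        for ratio in CS_RATIOS
    ]

def run_all(commands):
    """Run the commands one after another, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)

commands = build_test_commands()
# Call run_all(commands) from the repository root to execute them.
```

Passing `script="TEST_CS_ISTA_Net.py"` produces the ISTA-Net variant of the same list.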

The whole test pipeline

  1. Prepare test data.

    The original test set (set11) is in './data'.

  2. Run the test scripts.

    See Quick start

  3. Check the results in './result'.
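When checking results, it can help to recompute PSNR independently for a reconstructed image. The sketch below uses the standard PSNR definition for 8-bit images; the function name `psnr` and the `peak` argument are choices made for this example, and the repository's own evaluation code may differ in details such as color handling.

```python
import numpy as np

def psnr(reference, reconstruction, peak=255.0):
    """Peak signal-to-noise ratio in dB between two same-sized images."""
    reference = np.asarray(reference, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    mse = np.mean((reference - reconstruction) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(peak ** 2 / mse)
```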

Train-CS-NI

Prepare training data

  1. Training data (Training_Data.mat, including 88912 image blocks) is in './data'. If it is missing, download it from GoogleDrive or BaiduPan [code: xy52].

  2. Place Training_Data.mat in the './data' directory.

Begin to train

  1. Run the following scripts to train ISTA-Net models.

    You can use scripts in file 'Train_ISTA_Net_scripts.sh' to train models for our paper.

    # CS ratio 1, 4, 10, 25, 30, 40, 50
    # train scripts
    python Train_CS_ISTA_Net.py --cs_ratio 10 --layer_num 9
    python Train_CS_ISTA_Net.py --cs_ratio 25 --layer_num 9
    python Train_CS_ISTA_Net.py --cs_ratio 50 --layer_num 9
    python Train_CS_ISTA_Net.py --cs_ratio 1 --layer_num 9
    python Train_CS_ISTA_Net.py --cs_ratio 4 --layer_num 9
    python Train_CS_ISTA_Net.py --cs_ratio 30 --layer_num 9
    python Train_CS_ISTA_Net.py --cs_ratio 40 --layer_num 9

    We found that re-trained ISTA-Net models may perform slightly better than the results reported in our paper.

  2. Run the following scripts to train ISTA-Net+ models.

    You can use scripts in file 'Train_ISTA_Net_plus_scripts.sh' to train models for our paper.

    # CS ratio 1, 4, 10, 25, 30, 40, 50
    # train scripts
    python Train_CS_ISTA_Net_plus.py --cs_ratio 10 --layer_num 9
    python Train_CS_ISTA_Net_plus.py --cs_ratio 25 --layer_num 9
    python Train_CS_ISTA_Net_plus.py --cs_ratio 50 --layer_num 9
    python Train_CS_ISTA_Net_plus.py --cs_ratio 1 --layer_num 9
    python Train_CS_ISTA_Net_plus.py --cs_ratio 4 --layer_num 9
    python Train_CS_ISTA_Net_plus.py --cs_ratio 30 --layer_num 9
    python Train_CS_ISTA_Net_plus.py --cs_ratio 40 --layer_num 9

Test-CS-MRI

Quick start

  1. All models for our paper have been put in './model'.

  2. Run the following scripts to test ISTA-Net+ models.

    # test scripts
    python TEST_MRI_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 20 --layer_num 9
    python TEST_MRI_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 30 --layer_num 9
    python TEST_MRI_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 40 --layer_num 9
    python TEST_MRI_CS_ISTA_Net_plus.py --epoch_num 200 --cs_ratio 50 --layer_num 9

The whole test pipeline

  1. Prepare test data.

    The original test set, BrainImages_test, is in './data'.

  2. Run the test scripts.

    See Quick start

  3. Check the results in './result'.

Train-CS-MRI

Prepare training data

  1. Training data (Training_BrainImages_256x256_100.mat including 88912 image blocks) is in './data'. If it is missing, download it from GoogleDrive.

  2. Place Training_BrainImages_256x256_100.mat in the './data' directory.

Begin to train

  1. Run the following scripts to train ISTA-Net+ models.

    You can use scripts in file 'Train_ISTA_Net_plus_scripts.sh' to train models for our paper.

    # train scripts
    python Train_MRI_CS_ISTA_Net_plus.py --cs_ratio 20 --layer_num 9
    python Train_MRI_CS_ISTA_Net_plus.py --cs_ratio 30 --layer_num 9
    python Train_MRI_CS_ISTA_Net_plus.py --cs_ratio 40 --layer_num 9
    python Train_MRI_CS_ISTA_Net_plus.py --cs_ratio 50 --layer_num 9

Results

Quantitative Results

Visual Results

Citation

If you find the code helpful in your research or work, please cite the following paper.

@inproceedings{zhang2018ista,
  title={ISTA-Net: Interpretable optimization-inspired deep network for image compressive sensing},
  author={Zhang, Jian and Ghanem, Bernard},
  booktitle={CVPR},
  pages={1828--1837},
  year={2018}
}

Acknowledgements

ista-net-pytorch's Issues

Source of MRI images

Hi,
The images used for the MRI case are not referenced in the paper. Can you specify where these images are from?

Thank you

Sampling matrix

How is the sampling matrix obtained? Which algorithm is used? Thanks a lot!
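The question is not answered in this thread. One common construction in block-based CS is a random Gaussian matrix whose rows are then orthogonalized; the sketch below illustrates that construction only and is not confirmed here as the way this repository's matrices were generated. The function name `make_sampling_matrix`, the `seed` argument and the rounding of the measurement count are assumptions made for the example.

```python
import numpy as np

def make_sampling_matrix(cs_ratio, block_size=33, seed=0):
    """Random Gaussian matrix with orthonormal rows for one image block.

    cs_ratio is a percentage, as in the --cs_ratio flag.
    """
    n = block_size * block_size            # signal length per block (1089 for 33x33)
    m = int(round(cs_ratio / 100.0 * n))   # number of measurements (assumed rounding)
    rng = np.random.default_rng(seed)
    gaussian = rng.standard_normal((n, m))
    # Reduced QR yields orthonormal columns; transposing gives orthonormal rows.
    q, _ = np.linalg.qr(gaussian)
    return q.T                             # shape (m, n), with Phi @ Phi.T = I

Phi = make_sampling_matrix(25)
```

With orthonormal rows, `Phi.T @ y` is a simple initial estimate of a block from its measurements `y`.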

Train-CS-MRI

Hello, when training the MRI_plus version I ran into a dimension mismatch. The error is as follows: x = x - self.lambda_step * fft_forback(x, mask)
RuntimeError: The size of tensor a (2) must match the size of tensor b (256) at non-singleton dimension 5. How can this be resolved?
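For context, a data-fidelity step of this form is usually built from a masked Fourier round trip: transform the image to k-space, keep the sampled frequencies, and transform back. The NumPy sketch below shows that operation on a single 2-D image. What `fft_forback` actually does is assumed here from its name and arguments, and `fft_forward_backward` is a hypothetical stand-in, not the repository's function. The error message itself is consistent with a trailing axis of length 2 (for example a real/imaginary pair) being combined with a 256-wide mask, but that cause is not confirmed in this thread.

```python
import numpy as np

def fft_forward_backward(x, mask):
    """Apply F^H (mask * F x) to a 2-D image x, where F is the orthonormal 2-D FFT."""
    if x.shape != mask.shape:
        raise ValueError(f"image shape {x.shape} and mask shape {mask.shape} differ")
    k_space = np.fft.fft2(x, norm="ortho")              # forward transform
    return np.fft.ifft2(mask * k_space, norm="ortho")   # masked inverse transform

# Toy check with a 256x256 image and a random binary mask.
rng = np.random.default_rng(0)
image = rng.standard_normal((256, 256))
mask = (rng.random((256, 256)) < 0.3).astype(np.float64)
masked_round_trip = fft_forward_backward(image, mask)
```

Checking shapes explicitly, as the sketch does, tends to surface this kind of mismatch before the arithmetic step.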

missing mask_%d.mat in sampling matrix directory

Hi, I'm trying to train a model for MRI reconstruction but I am not able to make it work since it seems that the sampling matrices are not in the directory.

Could you please provide these matrices? Thank you

Test-CS-MRI

Hi, when I run the test ISTA-Net+ for MRI_CS I get the following result:
MRI CS Reconstruction Start
CS ratio is 50, Avg Initial PSNR/SSIM for Brainimages_test is nan/nan
CS ratio is 50, Avg Proposed PSNR/SSIM for Brainimages_test is nan/nan, Epoch number of model is 200
MRI CS Reconstruction End
/home/amax/anaconda3/envs/pytorch/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3372: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/home/amax/anaconda3/envs/pytorch/lib/python3.8/site-packages/numpy/core/_methods.py:170: RuntimeWarning: invalid value encountered in true_divide
ret = ret.dtype.type(ret / rcount)

Any suggestions? Thank you in advance.
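The "Mean of empty slice" warning means NumPy was asked to average an empty array, which returns nan. That suggests no per-image scores were collected, for example because no test files were found; the log prints "Brainimages_test" while the README names the folder "BrainImages_test", so a name mismatch on a case-sensitive file system is one possibility, though this is a guess and not confirmed here. A small sketch of a guard that fails loudly instead of printing nan; the helper names `list_test_images` and `average_metric` are hypothetical.

```python
import glob
import os

import numpy as np

def list_test_images(data_dir, test_name, pattern="*"):
    """Return sorted test file paths, failing loudly if none are found."""
    search = os.path.join(data_dir, test_name, pattern)
    paths = sorted(glob.glob(search))
    if not paths:
        raise FileNotFoundError(f"No test images matched {search!r}")
    return paths

def average_metric(values):
    """Average per-image scores; refuse to average an empty list."""
    if len(values) == 0:
        raise ValueError("No scores to average; were any images processed?")
    return float(np.mean(values))
```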

Sampling kernel_size

The size of the convolution kernel in the sampling process is 33*33. Could this size be modified?
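The question is not answered in this thread. As background, in block-based CS the block size fixes the length of each sampled vector (33*33 = 1089 values), so changing it would presumably require a sampling matrix and trained model of matching dimensions. The sketch below only shows how an image can be split into non-overlapping blocks of a configurable size; the function name `image_to_blocks` and the zero padding are assumptions for this example, not the repository's preprocessing.

```python
import numpy as np

def image_to_blocks(image, block_size=33):
    """Split a 2-D image into non-overlapping, flattened block_size x block_size blocks.

    The image is zero-padded on the bottom and right so both sides are
    multiples of block_size.
    """
    h, w = image.shape
    pad_h = (-h) % block_size
    pad_w = (-w) % block_size
    padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode="constant")
    rows, cols = padded.shape[0] // block_size, padded.shape[1] // block_size
    blocks = (padded.reshape(rows, block_size, cols, block_size)
                    .swapaxes(1, 2)                       # (rows, cols, B, B)
                    .reshape(-1, block_size * block_size))
    return blocks

image = np.arange(256 * 256, dtype=np.float64).reshape(256, 256)
blocks = image_to_blocks(image)
```

Each row of `blocks` is one flattened block, so a sampling matrix with `block_size * block_size` columns can be applied as `blocks @ Phi.T`.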
