iUNets - Fully invertible U-Nets in PyTorch

This library enables highly memory-efficient training of fully invertible U-Nets (iUNets) in 1D, 2D and 3D for use cases such as segmentation of medical images. It is based on the paper "iUNets: Fully invertible U-Nets with Learnable Up- and Downsampling" by Christian Etmann, Rihuan Ke & Carola-Bibiane Schönlieb.

The library can be installed via the following command:

pip install iunets

Official documentation: https://iunets.readthedocs.io.

By combining well-known reversible layers (such as additive coupling layers) with novel learnable invertible up- and downsampling operators and suitable channel splitting/concatenation, the iUNet is fully bijective. This allows for reconstructing activations instead of storing them. As such, the memory demand is (in theory) independent of the number of layers.
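
To make the reconstruction idea concrete, here is a minimal, framework-free sketch of an additive coupling layer on plain Python lists. It is not the library's implementation: `coupling_forward`, `coupling_inverse` and `residual_fn` are hypothetical names chosen for illustration, and `residual_fn` stands in for an arbitrary learned sub-network.

```python
import math


def residual_fn(v):
    # Hypothetical stand-in for a learned sub-network.
    # It does not need to be invertible itself.
    return [math.tanh(3.0 * a) + a * a for a in v]


def coupling_forward(x1, x2, f):
    """Additive coupling: y1 = x1, y2 = x2 + f(x1)."""
    y1 = list(x1)
    y2 = [b + r for b, r in zip(x2, f(x1))]
    return y1, y2


def coupling_inverse(y1, y2, f):
    """Exact inverse: x1 = y1, x2 = y2 - f(y1)."""
    x1 = list(y1)
    x2 = [b - r for b, r in zip(y2, f(y1))]
    return x1, x2


# The input is split into two halves along the channel axis.
x1, x2 = [0.5, -1.25, 2.0], [1.0, 0.0, -0.75]
y1, y2 = coupling_forward(x1, x2, residual_fn)

# The input is recomputed from the output alone, so it never has to be stored.
r1, r2 = coupling_inverse(y1, y2, residual_fn)
```

Because each layer's input can be recomputed from its output in this way, a backward pass can rebuild activations layer by layer instead of keeping all of them in memory.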

Figure (img/iunet_for_segmentation.png): An iUNet used as a memory-efficient sub-network for segmenting a 3-channel input into 10 classes.

The following table illustrates the memory savings achieved by our memory-efficient gradient calculation compared with conventional backpropagation. The ratio is the memory demand of our method relative to the conventional one. See the paper for details.

Depth   Conventional   Ours      Ratio
5       3.17 GB        0.85 GB   26.8%
10      5.90 GB        1.09 GB   18.4%
20      11.4 GB        1.57 GB   13.8%
30      16.8 GB        2.06 GB   12.2%

If you use this code in a publication, please cite it as:

@inproceedings{etmann2020iunets,
  title={iUNets: learnable invertible up-and downsampling for large-scale inverse problems},
  author={Etmann, Christian and Ke, Rihuan and Sch{\"o}nlieb, Carola-Bibiane},
  booktitle={2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP)},
  pages={1--6},
  year={2020},
  organization={IEEE}
}

Features

  • Easily set up memory-efficient iUNets in 1D, 2D or 3D that can be trained like any other PyTorch model and used, for example, for high-dimensional segmentation.
  • Highly customizable.
  • Learnable, invertible and possibly anisotropic up- and downsampling.
  • Orthogonal channel mixing.
  • Orthogonality for channel mixing and learnable invertible up- and downsampling is enforced by efficient Lie group methods, i.e. Cayley transforms and matrix exponentials of skew-symmetric matrices.
  • Quality-of-life features such as automatic, dynamic padding and unpadding (if required for invertibility), as well as model summaries.
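
As an illustration of the Lie group approach mentioned in the feature list, the following sketch applies the Cayley transform Q = (I - A)^(-1) (I + A) to a skew-symmetric matrix A and checks that Q is orthogonal. It uses only plain Python lists and is not the library's code: all function names are hypothetical, and building A as W - W^T from an unconstrained matrix W is an assumption made for this example.

```python
def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def inverse(m):
    """Matrix inverse via Gauss-Jordan elimination with partial pivoting."""
    n = len(m)
    eye = identity(n)
    aug = [list(row) + eye[i] for i, row in enumerate(m)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [v - factor * w for v, w in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def skew_from_params(w):
    """A = W - W^T is skew-symmetric for any square W (assumed parametrization)."""
    n = len(w)
    return [[w[i][j] - w[j][i] for j in range(n)] for i in range(n)]


def cayley(a):
    """Cayley transform Q = (I - A)^(-1) (I + A) of a skew-symmetric A."""
    n = len(a)
    eye = identity(n)
    i_minus_a = [[eye[i][j] - a[i][j] for j in range(n)] for i in range(n)]
    i_plus_a = [[eye[i][j] + a[i][j] for j in range(n)] for i in range(n)]
    return matmul(inverse(i_minus_a), i_plus_a)


# Unconstrained parameters; any values yield an orthogonal Q.
weights = [[0.3, -1.2, 0.5], [0.7, 0.1, 2.0], [-0.4, 0.9, -0.6]]
q = cayley(skew_from_params(weights))

# Orthogonality check: Q^T Q should equal the identity up to rounding.
gram = matmul(transpose(q), q)
max_err = max(abs(gram[i][j] - (1.0 if i == j else 0.0))
              for i in range(3) for j in range(3))
```

The point of such a parametrization is that the free parameters stay unconstrained during optimization while the resulting matrix is orthogonal, and therefore invertible, by construction.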

Requirements

iUNets are powered by the following two libraries:

Example usage

A simple 2D iUNet

A version of the iUNet depicted above can be created in a few lines. Let's say that we want 2 additive coupling layers per resolution, both in the downsampling branch (left) and the upsampling branch (right).

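from iunets import iUNet

model = iUNet(
    channels=(64, 128, 256, 384),  # channels at each resolution, finest to coarsest
    architecture=(2, 2, 2, 2),     # 2 coupling layers per resolution
    dim=2                          # 2D data
)
# Print a text diagram of the resulting architecture.
model.print_layout()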

Output:

64-64-(32/32)--------------------------------------------------------(32/32)-64-64
---------128-128-(64/64)---------------------------------(64/64)-128-128----------
---------------------256-256-(160/96)--------(160/96)-256-256---------------------
---------------------------------384-384--384-384---------------------------------

This model can now be integrated into the normal PyTorch workflow (and in particular be used as a sub-network) just like any other torch.nn.Module, and it automatically employs the memory-efficient backpropagation.

A fully-customized 3D iUNet

While the above example shows that a basic iUNet takes only a few lines to create, our library also allows for a high degree of customization. Refer to the API documentation for more information.

from iunets import iUNet
from iunets.layers import create_standard_module
model = iUNet(
    channels=(7,15,35,91),
    dim=3,
    architecture=(2,3,1,3),
    create_module_fn=create_standard_module,
    module_kwargs={'depth': 3},
    slice_mode='double',
    resampling_stride=[2,2,(1,2,2)],
    learnable_resampling=True,
    resampling_init='haar',
    resampling_method='cayley',
    disable_custom_gradient=False,
    revert_input_padding=True,
    padding_mode='reflect',
    verbose=1
)
model.print_layout()

Output:

Could not exactly create an iUNet with channels=(7, 15, 35, 91) and
resampling_stride=[(2, 2, 2), (2, 2, 2), (1, 2, 2)]. Instead using closest
achievable configuration: channels=(7, 16, 32, 92).
Average relative error: 0.0508

7-7-(5/2)-------------------------------------------------(5/2)-7-7
------16-16-16-(12/4)------------------------(12/4)-16-16-16-------
------------------32-(9/23)------------(9/23)-32-------------------
------------------------92-92-92--92-92-92-------------------------
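
Reading the two printed layouts, the channel count at each coarser resolution appears to equal the second number of the split (the channels passed down) multiplied by the product of the resampling strides. The sketch below only checks that arithmetic against the outputs above. `next_level_channels` is a hypothetical helper, not part of the library, and the stride of 2 per dimension for the 2D example is an assumption inferred from its output.

```python
from math import prod


def next_level_channels(passed_down, stride):
    """Invertible downsampling trades resolution for channels:
    each spatial dimension shrinks by its stride, and the channel
    count grows by the product of the strides."""
    return passed_down * prod(stride)


# (channels passed down, stride), read off the 2D layout:
# splits (32/32), (64/64), (160/96), assumed stride 2 per dimension.
layout_2d = [(32, (2, 2)), (64, (2, 2)), (96, (2, 2))]

# Read off the 3D layout: splits (5/2), (12/4), (9/23),
# with resampling_stride=[(2, 2, 2), (2, 2, 2), (1, 2, 2)].
layout_3d = [(2, (2, 2, 2)), (4, (2, 2, 2)), (23, (1, 2, 2))]

computed_2d = [next_level_channels(c, s) for c, s in layout_2d]
computed_3d = [next_level_channels(c, s) for c, s in layout_3d]

print(computed_2d)  # [128, 256, 384]
print(computed_3d)  # [16, 32, 92]
```

Under this reading, the requested channels=(7, 15, 35, 91) cannot be matched exactly because 15 and 35 are not multiples of 8 and 91 is not a multiple of 4, which is consistent with the library falling back to (7, 16, 32, 92).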
