
Hierarchical U-Net Vision Transformers with Residual Cross Attention for Latent Diffusion

Authors

Arman Ommid, Mayank Jain

Setup

Download the Repository

git clone https://github.com/ArmanOmmid/XSwinDiffusion.git

Install Requirements

pip install timm
pip install diffusers
pip install accelerate
pip install torchinfo
pip install --user scipy==1.11.1

How to Run Our Code

Diffusion Training

Diffusion Training Notebook

Loss Comparison

Loss Comparison Notebook

Inference Sampling

Inference Sampling Notebook

Evaluation Metrics

Evaluation Metrics Notebook

Custom Models

  • XSwin

    • Location: /models/xswin.py
    • Relevant Custom Modules: /models/modules/normal
    • Description: Our Segmentation Backbone
    • Implementation: As an isolated segmentation network, XSwin is largely based on SwinV2 blocks, supported by outer convolutional blocks and inner global-attention ViT blocks at the bottleneck. This promotes hierarchical, multiscale learning with appropriate inductive biases. The architecture also features localized residual cross attention that dynamically aggregates shallow encoder features for refinement before they are combined with deep decoder features for further processing. The ViT bottleneck receives positional embeddings along with the features, while the convolutional skip connection uses traditional concatenation.
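The residual cross-attention skip described above can be sketched as follows. This is a minimal, hypothetical illustration (the class and parameter names are not from the repository): deep decoder features act as queries that attend to shallow encoder features, and the attended result is added back residually.

```python
import torch
import torch.nn as nn

class ResidualCrossAttentionSkip(nn.Module):
    """Illustrative sketch (not the repo's actual module): deep decoder
    features query shallow encoder features via cross attention, and the
    result is combined residually before further decoder processing."""

    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, decoder_feats: torch.Tensor, encoder_feats: torch.Tensor) -> torch.Tensor:
        # Queries come from the deep decoder path; keys/values come from the
        # shallow encoder path, so the decoder dynamically aggregates encoder
        # detail rather than receiving it by plain concatenation.
        q = self.norm_q(decoder_feats)
        kv = self.norm_kv(encoder_feats)
        attended, _ = self.attn(q, kv, kv)
        return decoder_feats + attended  # residual combination
```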
  • XSwinDiffusion

    • Location: /models/xswin_diffusion.py
    • Relevant Custom Modules: /models/modules/modulated
    • Run Script: run_diffusion.py
    • Description: Our Conditioning Modulated Denoising Backbone
    • Implementation: We take our isolated XSwin segmentation backbone and make the following modifications. First, we create frozen parameters for time-step and class-label conditioning, based on the number of diffusion steps and the number of classes, using Fourier-based embeddings. We then augment all parameterized layers with conditioning modulation layers that use adaptive layer normalization, as in DiT, to efficiently encode time-step and class-label conditioning. We also make the input and output hidden dimensions configurable so that the network can output both the predicted image and the predicted noise, as in the DiT diffusion pipeline. Finally, we implement an additional forward function for classifier-free guidance, also as in DiT.
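The adaptive layer normalization modulation mentioned above can be sketched as follows. This is a hedged, illustrative example in the spirit of DiT's adaLN (the class name and dimensions are assumptions, not the repo's actual code): a conditioning vector regresses a per-channel shift and scale, applied after a parameter-free LayerNorm, with the modulation zero-initialized so training starts as an identity mapping.

```python
import torch
import torch.nn as nn

class AdaLNModulation(nn.Module):
    """Illustrative sketch of DiT-style adaptive layer norm (adaLN):
    a conditioning embedding (e.g. time step + class label) produces a
    shift and scale applied after a parameter-free LayerNorm."""

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_shift_scale = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 2 * dim),
        )
        # Zero-initialize so the modulation is initially an identity,
        # a stabilizing trick used by DiT.
        nn.init.zeros_(self.to_shift_scale[1].weight)
        nn.init.zeros_(self.to_shift_scale[1].bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, dim); cond: (batch, cond_dim)
        shift, scale = self.to_shift_scale(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```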

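The classifier-free guidance forward pass mentioned above combines conditional and unconditional predictions at sampling time. A minimal sketch of the combination step, assuming the standard formulation (the function name is illustrative, not from the repository):

```python
import torch

def classifier_free_guidance(cond_out: torch.Tensor,
                             uncond_out: torch.Tensor,
                             guidance_scale: float) -> torch.Tensor:
    """Illustrative sketch: extrapolate from the unconditional prediction
    toward the conditional one. A scale of 1.0 recovers the conditional
    output; larger scales strengthen the conditioning signal."""
    return uncond_out + guidance_scale * (cond_out - uncond_out)
```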
Baseline Models

  • DiT

    • Location: /models/dit.py
    • Relevant Modules: None
    • Run Script: run_diffusion_dit.py
    • Description: The original DiT Implementation to compare with
    • Implementation: Identical implementation to that of DiT
  • UViT

    • Location: /models/uvit.py
    • Relevant Modules: None
    • Run Script: run_diffusion_uvit.py
    • Description: A DiT with UViT-style skip connections; there are subtle differences from the original UViT.
    • Implementation: DiT with a U-Net structure, storing shallow "encoder" features and concatenating them with deep "decoder" features. After concatenation, the features are passed through a linear layer that downsamples back to the original hidden dimension, as per the UViT design. The main difference between our UDiT and UViT is the conditioning: UViT conditions with additional sequence tokens, while we use adaptive layer normalization modulation as with DiT.
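The UViT-style skip connection described above (concatenate, then project back down) can be sketched in a few lines. This is an illustrative example, not the repository's code; the class name is an assumption.

```python
import torch
import torch.nn as nn

class UViTSkip(nn.Module):
    """Illustrative sketch of a UViT-style skip: concatenate stored shallow
    "encoder" features with deep "decoder" features along the channel axis,
    then project back to the original hidden dimension with a linear layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, deep: torch.Tensor, shallow: torch.Tensor) -> torch.Tensor:
        # (batch, tokens, 2 * dim) -> (batch, tokens, dim)
        return self.proj(torch.cat([deep, shallow], dim=-1))
```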

Auxiliary Code

Diffusion Code

  • Diffusion Pipeline
    • Location: /diffusion
    • Description: Diffusion pipelining code from OpenAI
  • Miscellaneous Modules
    • Location: /models/modules
      • conditioned_sequential.py : Implementation of nn.Sequential with Conditioning Information
      • embeddings.py : Implementation of time step and class label embeddings from DiT as well as our custom Modulator layer
      • initialize.py : Weight initializers for various and specific layers
      • positional.py : Positional Embeddings from FAIR
  • Validation Code
    • Location: /runners, /data
    • Description : Validation code to validate the isolated XSwin backbone
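The time-step embeddings mentioned for embeddings.py follow the DiT convention of sinusoidal (Fourier) features. A minimal sketch of that style of embedding, with illustrative names and defaults:

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Illustrative sketch of sinusoidal time-step embeddings in the DiT
    style: each scalar time step is mapped to cos/sin features at
    geometrically spaced frequencies."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t[:, None].float() * freqs[None]  # (batch, half)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (batch, dim)
```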

