GigaGAN - Pytorch

Implementation of GigaGAN (project page), a new SOTA GAN out of Adobe.

I will also add a few findings from lightweight GAN: skip-layer excitation for faster convergence, and an auxiliary reconstruction loss in the discriminator for better stability.
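As a rough illustration of skip-layer excitation: in lightweight GAN, a low-resolution feature map is pooled and passed through a small gating network ending in a sigmoid, and the resulting per-channel gates rescale a higher-resolution feature map. The sketch below is a simplification in plain Python on nested lists. It assumes global average pooling plus a single linear map in place of the paper's 4×4 pooling and convolution stack, and the function names are hypothetical, not part of gigagan-pytorch.

```python
import math
from typing import List

Grid = List[List[float]]   # one channel: an H x W grid of activations
FeatureMap = List[Grid]    # a feature map: a list of C channels

def global_avg_pool(fmap: FeatureMap) -> List[float]:
    """Average each channel over all of its spatial positions."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in fmap]

def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def skip_layer_excite(low_res: FeatureMap, high_res: FeatureMap, weight: Grid) -> FeatureMap:
    """Rescale every channel of high_res by a gate in (0, 1) computed from low_res.

    weight has shape (C_high, C_low). This single linear map is a simplifying
    stand-in (assumption) for the small conv stack used by the real module.
    """
    pooled = global_avg_pool(low_res)
    gates = [sigmoid(sum(w * p for w, p in zip(w_row, pooled))) for w_row in weight]
    # channel-wise multiplication: every position of channel c is scaled by gates[c]
    return [[[v * g for v in row] for row in ch] for ch, g in zip(high_res, gates)]

# usage: a 1-channel 2x2 low-res map gating a 2-channel 4x4 high-res map
low = [[[1.0, 1.0], [1.0, 1.0]]]
high = [[[2.0] * 4 for _ in range(4)], [[3.0] * 4 for _ in range(4)]]
out = skip_layer_excite(low, high, weight=[[0.0], [100.0]])
```

The gate only rescales channels, so the high-resolution map keeps its shape; the low-resolution branch decides how strongly each channel passes through.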

It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper.

Please join us on Discord if you are interested in helping out with the replication alongside the LAION community.

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • 🤗 Huggingface for their accelerate library

  • All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models

  • Xavier for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built!

  • @CerebralSeed for pull requesting the initial sampling code for both the generator and upsampler!

  • Keerth for the code review and pointing out some discrepancies with the paper!

Install

$ pip install gigagan-pytorch

Usage

Simple unconditional GAN, for starters

import torch

from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    generator = dict(
        dim_capacity = 8,
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        image_size = 256,
        dim_max = 512,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    amp = True
).cuda()

# dataset

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

# you must then set the dataloader for the GAN before training

gan.set_dataloader(dataloader)

# train the discriminator and generator alternately
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

images = gan.generate(batch_size = 4) # (4, 3, 256, 256)
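With batch size 1 and grad_accum_every = 8, each optimizer step sees gradients from 8 samples, an effective batch size of 8 per process (in data-parallel multi-GPU training this is typically multiplied further by the number of processes). The toy sketch below shows why this works, assuming a mean-reduced loss and equal-sized micro-batches: scaling each micro-batch gradient by 1 / grad_accum_every and summing reproduces the full-batch gradient. The one-parameter model and helper names are hypothetical and unrelated to the GigaGAN training loop itself.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs

def grad_mse(w: float, batch: Batch) -> float:
    """Gradient w.r.t. w of mean((w * x - y) ** 2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

def accumulated_grad(w: float, data: Batch, grad_accum_every: int) -> float:
    """Split data into equal micro-batches and accumulate their scaled gradients."""
    if len(data) % grad_accum_every != 0:
        raise ValueError("exact equivalence needs equal-sized micro-batches")
    size = len(data) // grad_accum_every
    total = 0.0
    for i in range(grad_accum_every):
        micro = data[i * size:(i + 1) * size]
        # divide by the number of micro-batches so the sum equals a full-batch mean
        total += grad_mse(w, micro) / grad_accum_every
    return total

# 8 samples, micro-batch size 1, accumulated 8 times (mirrors the usage example)
data = [(float(x), 3.0 * x + 1.0) for x in range(8)]
full = grad_mse(0.5, data)
accum = accumulated_grad(0.5, data, grad_accum_every=8)
print(full, accum)  # the two values agree up to floating-point rounding
```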

For unconditional Unet Upsampler

import torch
from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    train_upsampler = True,     # set this to True
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

gan.set_dataloader(dataloader)

# train the discriminator and generator alternately
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

lowres = torch.randn(1, 3, 64, 64).cuda()

images = gan.generate(lowres) # (1, 3, 256, 256)
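The upsampler above maps a 64×64 input to a 256×256 output, a 4× increase. One todo item is pixel shuffle upsampling for the U-Net: a convolution produces C·r² channels, which are then rearranged into C channels at r× the height and width. Below is a minimal plain-Python sketch of that rearrangement, assuming the same index order as torch.nn.PixelShuffle (in real code you would use that module directly). The helper name is hypothetical, and this is not confirmed to be how the repository's upsampler is built.

```python
from typing import List

Grid = List[List[float]]

def pixel_shuffle(x: List[Grid], r: int) -> List[Grid]:
    """Rearrange a (C * r * r, H, W) feature map into (C, H * r, W * r).

    Input channel c * r * r + i * r + j lands at output rows h * r + i and
    columns w * r + j of output channel c, matching torch.nn.PixelShuffle.
    """
    c_in = len(x)
    if c_in % (r * r) != 0:
        raise ValueError("channel count must be divisible by r ** 2")
    c_out = c_in // (r * r)
    h, w = len(x[0]), len(x[0][0])
    out = [[[0.0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):          # row offset inside each r x r block
            for j in range(r):      # column offset inside each r x r block
                src = x[c * r * r + i * r + j]
                for y in range(h):
                    for z in range(w):
                        out[c][y * r + i][z * r + j] = src[y][z]
    return out

# usage: four 1x1 channels become one 2x2 channel
print(pixel_shuffle([[[0.0]], [[1.0]], [[2.0]], [[3.0]]], r=2))
```

A 4× step could be done as one shuffle with r = 4 or as two r = 2 stages; which the upsampler should use is left open here.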

Losses

  • G - Generator
  • MSG - Multiscale Generator
  • D - Discriminator
  • MSD - Multiscale Discriminator
  • GP - Gradient Penalty
  • SSL - Auxiliary Reconstruction in Discriminator (from Lightweight GAN)
  • VD - Vision-aided Discriminator
  • VG - Vision-aided Generator
  • CL - Generator Contrastive Loss
  • MAL - Matching Aware Loss

A healthy run would have G, MSG, D and MSD hovering between 0 and 10, usually staying fairly constant. If, at any point after 1k training steps, these values persist in the triple digits, something is wrong. It is OK for the generator and discriminator losses to occasionally dip negative, but they should swing back up into that range.

GP and SSL should be pushed towards 0. GP can occasionally spike; I like to imagine it as the networks undergoing some epiphany.
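The rule of thumb for G, MSG, D and MSD can be turned into a simple check over logged loss values. This is a hypothetical helper, not something gigagan-pytorch provides. It assumes you keep a per-step history of each loss, treats "persist" as every value in the last 100 post-warmup steps, and (as an assumption) also flags large negative values by comparing magnitudes. Occasional spikes are deliberately ignored.

```python
from typing import Dict, List, Sequence

def flag_unhealthy_losses(
    history: Dict[str, Sequence[float]],
    names: Sequence[str] = ("G", "MSG", "D", "MSD"),
    warmup_steps: int = 1000,
    window: int = 100,
    ceiling: float = 100.0,
) -> List[str]:
    """Return the loss names whose last `window` post-warmup values all sit at |loss| >= ceiling."""
    flagged = []
    for name in names:
        post_warmup = list(history.get(name, []))[warmup_steps:]
        recent = post_warmup[-window:]
        # only judge once a full window exists after warmup, and only flag
        # values that persist at triple digits, not a single spike
        if len(recent) == window and all(abs(v) >= ceiling for v in recent):
            flagged.append(name)
    return flagged

history = {"G": [2.0] * 1500, "D": [1.0] * 1000 + [250.0] * 500}
print(flag_unhealthy_losses(history))  # ['D']
```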

Multi-GPU Training

The GigaGAN class is now equipped with 🤗 Accelerator. You can easily do multi-GPU training in two steps using their accelerate CLI.

At the project root directory, where the training script is, run

$ accelerate config

Then, in the same directory

$ accelerate launch train.py

Todo

  • make sure it can be trained unconditionally

  • read the relevant papers and knock out all 3 auxiliary losses

    • matching aware loss
    • clip loss
    • vision-aided discriminator loss
    • add reconstruction losses on arbitrary stages in the discriminator (lightweight gan)
    • figure out how the random projections are used from projected-gan
    • vision aided discriminator needs to extract N layers from the vision model in CLIP
    • figure out whether to discard CLS token and reshape into image dimensions for convolution, or stick with attention and condition with adaptive layernorm - also turn off vision aided gan in unconditional case
  • unet upsampler

    • add adaptive conv
    • modify latter stage of unet to also output rgb residuals, and pass the rgb into discriminator. make discriminator agnostic to rgb being passed in
    • do pixel shuffle upsamples for unet
  • get a code review for the multi-scale inputs and outputs, as the paper was a bit vague

  • add upsampling network architecture

  • make unconditional work for both base generator and upsampler

  • make text conditioned training work for both base and upsampler

  • make recon more efficient by random sampling patches

  • make sure generator and discriminator can also accept pre-encoded CLIP text encodings

  • do a review of the auxiliary losses

    • add contrastive loss for generator
    • add vision aided loss
    • add gradient penalty for vision aided discr - make optional
    • add matching awareness loss - figure out if rotating text conditions by one is good enough for mismatching (without drawing an additional batch from dataloader)
    • make sure gradient accumulation works with matching aware loss
    • matching awareness loss runs and is stable
    • vision aided trains
  • add some differentiable augmentations, proven technique from the old GAN days

    • remove any magic being done with automatic rgbs processing, and have it explicitly passed in - offer functions on the discriminator that can process real images into the right multi-scales
    • add horizontal flip for starters
  • move all modulation projections into the adaptive conv2d class

  • add accelerate

    • works single machine
    • works for mixed precision (make sure gradient penalty is scaled correctly), take care of manual scaler saving and reloading, borrow from imagen-pytorch
    • make sure it works multi-GPU for one machine
    • have someone else try multiple machines
  • clip should be optional for all modules, and managed by GigaGAN, with text -> text embeds processed once

  • add ability to select a random subset from multiscale dimension, for efficiency

  • port over CLI from lightweight|stylegan2-pytorch

  • hook up laion dataset for text-image
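For the matching-aware loss item above, rotating text conditions by one within a batch pairs each real image with another sample's text, producing mismatched pairs without drawing a second batch from the dataloader. Below is a minimal plain-Python sketch with hypothetical helper names; it only shows the pairing, not the loss itself.

```python
from typing import List, Sequence

def rotate_by_one(texts: Sequence[str]) -> List[str]:
    """Shift the batch of text conditions by one, so image i is paired with text i + 1 (wrapping)."""
    texts = list(texts)
    return texts[1:] + texts[:1]

def accidental_matches(texts: Sequence[str]) -> int:
    """Count positions where the rotated text equals the original, i.e. not a real mismatch."""
    return sum(a == b for a, b in zip(texts, rotate_by_one(texts)))

batch = ["a red fox", "a blue car", "a green tree"]
mismatched = rotate_by_one(batch)
print(mismatched)                 # ['a blue car', 'a green tree', 'a red fox']
print(accidental_matches(batch))  # 0
```

With batch size 1, as in the usage examples, the rotation is a no-op, and duplicate captions within a batch can still line up, so this trick only yields true mismatches for larger batches of distinct texts.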

Citations

@misc{https://doi.org/10.48550/arxiv.2303.05511,
    url     = {https://arxiv.org/abs/2303.05511},
    author  = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},  
    title   = {Scaling up GANs for Text-to-Image Synthesis},
    publisher = {arXiv},
    year    = {2023},
    copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{Liu2021TowardsFA,
    title   = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
    author  = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2101.04775}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Karras2020ada,
    title     = {Training Generative Adversarial Networks with Limited Data},
    author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    booktitle = {Proc. NeurIPS},
    year      = {2020}
}
