few-shot-diffusion-models's Introduction

Few-Shot Diffusion Models (FSDM)

Denoising diffusion probabilistic models (DDPM) are powerful hierarchical latent variable models with remarkable sample generation quality and training stability. These properties can be attributed to parameter sharing in the generative hierarchy, as well as a parameter-free diffusion-based inference procedure. In this paper, we present Few-Shot Diffusion Models (FSDM), a framework for few-shot generation leveraging conditional DDPMs. FSDMs are trained to adapt the generative process conditioned on a small set of images from a given class by aggregating image patch information using a set-based Vision Transformer (ViT). At test time, the model is able to generate samples from previously unseen classes conditioned on as few as 5 samples from that class. We empirically show that FSDM can perform few-shot generation and transfer to new datasets taking full advantage of the conditional DDPM.

Set up the environment

conda create -n fsdm python=3.6

git clone https://github.com/georgosgeorgos/few-shot-diffusion-models

cd few-shot-diffusion-models

pip install -r requirements.txt

Datasets

We train the models on small sets of 2 to 20 images. By default, the train/val/test splits use disjoint classes. Each entry below lists the image shape followed by the number of train/val/test classes.

Binary:

  • Omniglot (back_eval) - (1 x 28 x 28) - 964/97/659

RGB:

  • CIFAR100 - (3 x 32 x 32) - 60/20/20
  • CIFAR100mix - (3 x 32 x 32) - 60/20/20
  • MiniImageNet - (3 x 32 x 32) - 64/16/20
  • CelebA - (3 x 64 x 64) - 4444/635/1270
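
The set-based setup described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the repository's data loader: the helper names split_classes and sample_set, the toy item ids, and the seeds are all hypothetical. It partitions class ids into disjoint train/val/test groups (60/20/20, as for CIFAR100) and draws one set of 5 items that all share a single unseen class.

```python
import random


def split_classes(class_ids, n_train, n_val, n_test, seed=0):
    """Partition class ids into disjoint train/val/test groups."""
    ids = sorted(class_ids)
    if n_train + n_val + n_test > len(ids):
        raise ValueError("not enough classes for the requested split")
    rng = random.Random(seed)
    rng.shuffle(ids)
    train = ids[:n_train]
    val = ids[n_train:n_train + n_val]
    test = ids[n_train + n_val:n_train + n_val + n_test]
    return train, val, test


def sample_set(items_by_class, classes, set_size, rng):
    """Draw one set: `set_size` distinct items that share a single class."""
    cls = rng.choice(classes)
    return cls, rng.sample(items_by_class[cls], set_size)


# Toy stand-in for CIFAR100 (hypothetical data): 100 classes, 10 item ids each.
items_by_class = {c: [f"img_{c}_{i}" for i in range(10)] for c in range(100)}

train_cls, val_cls, test_cls = split_classes(items_by_class, 60, 20, 20)

# A few-shot conditioning set of 5 items from one class never seen in training.
rng = random.Random(1)
cls, few_shot_set = sample_set(items_by_class, test_cls, set_size=5, rng=rng)
```

Because the class groups are disjoint, any set drawn from test_cls comes from a class the model did not see during training.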

Training

Train a DDPM on CIFAR100

sh script/run.sh gpu_num ddpm_cifar100 

Train an FSDM model on CIFAR100 with a ViT encoder, FiLM conditioning, and mean aggregation

sh script/run.sh gpu_num vfsddpm_cifar100_vit_film_mean

Train a MODEL on DATASET with ENCODER, CONDITIONING and AGGREGATION

sh script/run.sh gpu_num {ddpm, cddpm, sddpm, addpm, vfsddpm}_{omniglot, cifar100, cifar100mix, minimagenet, cub, celeba}_{vit, unet}_{mean, lag, cls, sum_patch_mean}

Sampling

Draw 1000 samples for new classes from an FSDM model trained on CIFAR100 for 100K iterations

sh script/sample_conditional.sh gpu_num vfsddpm_cifar100_vit_film_mean_outdistro {date} 100000 1000

Metrics

Compute FID, IS, Precision, and Recall for an FSDM model on new CIFAR100 classes
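
For orientation, FID is usually defined as the Fréchet distance between two Gaussians fitted to Inception-v3 features of real and generated images. The sketch below shows only that distance computation with NumPy and takes the feature arrays as given. It is an illustrative assumption, not the repository's metrics code; the function name frechet_distance and the random toy features are hypothetical.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two feature sets.

    feats_a, feats_b: arrays of shape (n_samples, n_features), n_features >= 2.
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # tr((cov_a cov_b)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of cov_a @ cov_b, which are real and non-negative for
    # positive semi-definite inputs (up to numerical noise, hence the clip).
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


# Toy features (hypothetical): in practice these would come from a feature
# extractor applied to reference and generated images.
rng = np.random.default_rng(0)
feats_ref = rng.normal(size=(500, 4))
feats_shifted = rng.normal(loc=1.0, size=(500, 4))

d_same = frechet_distance(feats_ref, feats_ref)      # close to zero
d_shift = frechet_distance(feats_ref, feats_shifted)  # clearly larger
```

The distance is close to zero when both feature sets are identical and grows as the means or covariances drift apart, which is why the choice of reference images matters for the reported number.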

Acknowledgments

A lot of code and ideas borrowed from:

few-shot-diffusion-models's People

Contributors

georgosgeorgos, submission-conference24

few-shot-diffusion-models's Issues

Details about the Metrics Computation

Can you provide the details about the metrics computation?
I found the code in the following run file:

python metrics/metrics.py /scratch/gigi/fsddpm/cifar100_ddpm_sigma/sampling-conditional-out-distro-2022-04-11-01-22-01-001184/full_samples_conditional_10000x32x32x3_out-distro_5_transfer_minimagenet.npz /scratch/gigi/fsddpm/cifar100_ddpm_sigma/sampling-conditional-out-distro-2022-04-11-01-22-01-001184/full_samples_conditional_10000x32x32x3_out-distro_5_transfer_minimagenet.npz

It looks like both '.npz' files are full_samples_conditional_10000x32x32x3_in-distro_5.npz. How did you select the reference images for the new classes? Alternatively, could you provide the reference data?

there is no `run.sh` script

There is no run.sh script in the scripts folder, only a main.sh script. Is main.sh meant to be run.sh? I renamed main.sh to run.sh and then ran bash script/run.sh 2num ddpm_cifar100, but there was no output. Why?

About Sampling

Thank you for your excellent work. I would like to ask how to specify the category of the generated image during the sampling process. Is it by using the relevant seen or unseen image set as the input set of ViT to control the single-round image generation?

Pre-trained weights

Is there a plan to release pre-trained weights so that sampling/evaluation can be done without re-training everything from scratch?

Pickle files

Thank you for open-sourcing the project. I see a lot of issues with the repo, no explicit instruction to prepare data, no working training or evaluation scripts, etc. Do you plan to publish the dataset files at least so we can experiment with it?

Confusion in "mean_patch" pooling with ViT encoders

In the forward_set function in vit.py, the transformer module is applied twice in a single forward pass.

First, in the common forward path:

x = self.dropout(x)
x_set = self.transformer(x)

Second, during mean_patch pooling:

elif self.pool == "mean_patch":
    x = x_set[:, self.k:]
    # attention here what you average
    x = x.view(b, np//ns, ns, -1)
    x = x.mean(dim = 2)
    x = self.transformer(x)

Is this intended?
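
Separately from the double transformer call, the comment about what gets averaged can be illustrated with a toy NumPy example. This is a sketch under an assumption the snippet does not confirm: that tokens are ordered image by image (all patches of image 0, then all patches of image 1, and so on). Here np is the NumPy module, not the patch count used in the quoted code, and all variable names are hypothetical.

```python
import numpy as np

b, ns, p, d = 1, 2, 3, 1  # batch, set size, patches per image, feature dim

# Token value = 10 * image_index + patch_index, ordered image by image (assumed).
tokens = np.array([10 * i + j for i in range(ns) for j in range(p)], dtype=float)
x = tokens.reshape(b, ns * p, d)  # values: 0, 1, 2, 10, 11, 12

# Grouping as (b, p, ns, d) averages runs of ns consecutive tokens, which under
# this ordering mixes different patch positions and can straddle two images.
mixed = x.reshape(b, p, ns, d).mean(axis=2)

# Grouping as (b, ns, p, d) and averaging over the set axis averages the same
# patch position across the images of the set.
per_patch = x.reshape(b, ns, p, d).mean(axis=1)

print(mixed.ravel().tolist())      # [0.5, 6.0, 11.5]
print(per_patch.ravel().tolist())  # [5.0, 6.0, 7.0]
```

If the tokens are instead ordered patch by patch, the first grouping is the one that averages across the set, so which reshape is appropriate depends on how the patch sequence is built earlier in forward_set.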

How to prepare data?

It seems the datasets are expected to be packed in .pkl files, but instructions on how to prepare these files are missing. Alternatively, would it be possible to provide pre-processed .pkl files?

Here is the error I got due to the missing file.

Traceback (most recent call last):
  File "main.py", line 112, in <module>
    main()
  File "main.py", line 49, in main
    TrainLoop(
  File "/home/roger/reproduction/few-shot-diffusion-models/model/set_diffusion/train_util.py", line 389, in run_loop
    batch = next(self.data)
  File "/home/roger/reproduction/few-shot-diffusion-models/dataset/__init__.py", line 58, in create_loader
    dataset = select_dataset(args, split)
  File "/home/roger/reproduction/few-shot-diffusion-models/dataset/__init__.py", line 26, in select_dataset
    dataset = BaseSetsDataset(**kwargs)
  File "/home/roger/reproduction/few-shot-diffusion-models/dataset/base.py", line 56, in __init__
    self.images, self.labels, self.map_cls = self.get_data()
  File "/home/roger/reproduction/few-shot-diffusion-models/dataset/base.py", line 81, in get_data
    with open(path, 'rb') as f:
FileNotFoundError: [Errno 2] No such file or directory: '/home/gigi/ns_data/cifar100/train_cifar100.pkl'

t_emb error in ViTset code

In vfsddpm.py, lines 223-225, x_set_tmp = batch[:, ix] selects ns - 1 elements of the set, but the matching ns - 1 entries of t_emb are not selected. This causes an error at vitset.py line 176, t_emb = t_emb.view(b, ns, -1): t_emb has 49152 elements and still contains the embedding for the removed element, while x_set has ns = 5.
Is this the original code used to produce the results in your paper?
Thank you.

requirements.txt error

It seems that requirements.txt depends on some local files.
Could you please check requirements.txt and update its contents?

sh script/main.sh 0 ddpm_cifar100_sigma

few-shot-diffusion-models/dataset/base.py", line 95, in get_data
residual = self.img_cls - value.shape[0]
AttributeError: 'list' object has no attribute 'shape'

Could you please provide the data file?
