
clip2latent - Official PyTorch Code

Open Arxiv | Open In Colab | Open in Spaces

clip2latent: Text driven sampling of a pre-trained StyleGAN using denoising diffusion and CLIP

Justin N. M. Pinkney and Chuan Li @ Lambda Inc.

We introduce a new method to efficiently create text-to-image models from a pre-trained CLIP and StyleGAN. It enables text-driven sampling with an existing generative model without any external data or fine-tuning. This is achieved by training a diffusion model, conditioned on CLIP embeddings, to sample latent vectors of a pre-trained StyleGAN, which we call *clip2latent*. We leverage the alignment between CLIP's image and text embeddings to avoid the need for any text-labelled data when training the conditional diffusion model. We demonstrate that clip2latent allows us to generate high-resolution (1024x1024 pixels) images based on text prompts with fast sampling, high image quality, and low training compute and data requirements. We also show that using the well-studied StyleGAN architecture, without further fine-tuning, allows us to directly apply existing methods to control and modify the generated images, adding a further layer of control to our text-to-image pipeline.

Installation

git clone https://github.com/justinpinkney/clip2latent.git
cd clip2latent
python -m venv .venv --prompt clip2latent
. .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt

Usage

Inference

The simplest way to run the models for inference is to start the Gradio demo (or run it in Colab):

python scripts/demo.py

This will fetch the required models from the Hugging Face Hub and start a Gradio demo that can be accessed via a web browser.

To run a model via python:

from clip2latent import models

prompt = "a hairy man"
device = "cuda:0"
cfg_file = "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.yaml"
checkpoint = "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.ckpt"

model = models.Clip2StyleGAN(cfg_file, device, checkpoint)
images, clip_score = model(prompt)
# images are tensors of shape: bchw, range: -1..1
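
To save the outputs as image files you can rescale them from -1..1 to 0..1 and write them out; a minimal sketch using torchvision (assuming it is available in your environment alongside the other requirements):

from torchvision.utils import save_image

# model outputs are bchw in [-1, 1]; rescale to [0, 1] before saving
save_image(images * 0.5 + 0.5, "samples.png")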

Or take a look at the example notebook demo.ipynb.

Training

Generate data

To train a model of your own, you first need to generate some data. We provide a command-line interface which runs a StyleGAN model and passes the generated images to CLIP. The W latent vector and the CLIP image embedding are stored as npy files, packed into tar files ready for use as a webdataset. To generate the data used to train the FFHQ model in the paper, run:

python scripts/generate_dataset.py OUT_DIR

For more details of dataset generation options see the help for generate_dataset.py:

Usage: generate_dataset.py [OPTIONS] OUT_DIR

Arguments:
  OUT_DIR  Location to save dataset [required]

Options:
  --n-samples INTEGER             Number of samples to generate [default: 1000000]
  --generator-name TEXT           Name of predefined generator loader [default: sg2-ffhq-1024]
  --feature-extractor-name TEXT   CLIP model to use for image embedding [default: ViT-B/32]
  --n-gpus INTEGER                Number of GPUs to use [default: 2]
  --out-image-size INTEGER        If saving generated images, resize to this dimension [default: 256]
  --batch-size INTEGER            Batch size [default: 32]
  --n-save-workers INTEGER        Number of workers to use while saving [default: 16]
  --space TEXT                    Latent space to use [default: w]
  --samples-per-folder INTEGER    Number of samples per tar file [default: 10000]
  --save-im / --no-save-im        Save images? [default: no-save-im]
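
The resulting shards can be loaded directly with the webdataset package. Below is a minimal reading sketch; the key names "latent.npy" and "img_feat.npy" are hypothetical placeholders for illustration, so inspect a generated tar file to confirm what the script actually writes:

import webdataset as wds

# Key names below are hypothetical; check the contents of a generated shard.
# The brace pattern should match the shards written to OUT_DIR.
dataset = (
    wds.WebDataset("OUT_DIR/{00000..00099}.tar")
    .decode()
    .to_tuple("latent.npy", "img_feat.npy")
)
latent, clip_embed = next(iter(dataset))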

To use a different StyleGAN generator, add the required loading function to the generators dict in generate_dataset.py, then use that key as the generator_name (see the sketch below). Using non-StyleGAN generators should be possible, but would require additional modifications.
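
As a rough illustration of the kind of change involved (a hypothetical sketch; the exact loader signature expected by generate_dataset.py may differ, so check the existing entries in the generators dict):

import pickle

def load_my_generator(device):
    # Hypothetical loader for a StyleGAN pickle in stylegan2-ada-pytorch format
    # (unpickling requires the StyleGAN2-ADA code to be importable)
    with open("my-model.pkl", "rb") as f:
        G = pickle.load(f)["G_ema"].to(device)
    return G

# Register under a new key, then pass --generator-name my-model
generators["my-model"] = load_my_generator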

Train

We use Hydra to manage the model and training configuration. To train with the default configuration, simply run:

python scripts/train.py

This will run the model with the default configuration as follows:

model:
  network:
    dim: 512
    num_timesteps: 1000
    depth: 12
    dim_head: 64
    heads: 12
  diffusion:
    image_embed_dim: 512
    timesteps: 1000
    cond_drop_prob: 0.2
    image_embed_scale: 1.0
    text_embed_scale: 1.0
    beta_schedule: cosine
    predict_x_start: true
data:
  bs: 512
  format: webdataset
  path: data/webdataset/sg2-ffhq-1024-clip/{00000..99}.tar
  embed_noise_scale: 1.0
  sg_pkl: https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl
  clip_variant: ViT-B/32
  n_latents: 1
  latent_dim: 512
  latent_repeats:
  - 18
  val_im_samples: 64
  val_text_samples: text/face-val.txt
  val_samples_per_text: 4
logging: wandb
wandb_project: clip2latent
wandb_entity: null
name: null
device: cuda:0
resume: null
train:
  znorm_embed: false
  znorm_latent: true
  max_it: 1000000
  val_it: 10000
  lr: 0.0001
  weight_decay: 0.01
  ema_update_every: 10
  ema_beta: 0.9999
  ema_power: 0.75

To train with a different configuration, you can either change individual parameters using Hydra's command-line override syntax:

python scripts/train.py data.bs=128

which would set the batch size to 128.
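
Multiple overrides can be combined in a single command, using keys from the default configuration above, e.g.:

python scripts/train.py data.bs=128 train.lr=0.0003 device=cuda:1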

Alternatively, you can create your own yaml configuration files and switch between them. For example, we also provide a 'small' model configuration at config/model/small.yaml; to train using it, simply run:

python scripts/train.py model=small

For more details, please refer to the Hydra documentation.

Training is set up to run on a single GPU and does not currently support multi-GPU training. The default settings take around 18 hours to train on a single A100-80GB, although the best checkpoint is likely to occur within 10 hours of training.

Acknowledgements

Citation

@misc{https://doi.org/10.48550/arxiv.2210.02347,
  doi = {10.48550/ARXIV.2210.02347},
  url = {https://arxiv.org/abs/2210.02347},
  author = {Pinkney, Justin N. M. and Li, Chuan},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {clip2latent: Text driven sampling of a pre-trained StyleGAN using denoising diffusion and CLIP},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}


Issues

Exporting to ONNX raises a type issue in the padding function

Hi, I used clip2latent and its fast sampling performance is really impressive.

I wanted to export it via torch.onnx for some further experiments, but I ran into the following issue:

0INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":607, please report a bug to PyTorch. We don't have an op for aten::constant_pad_nd but it isn't a special case.  Argument types: Tensor, int[], bool, 

Candidates:
	aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)

This is actually an issue in TorchScript: PyTorch supports a bool value here in Python, but when it is traced by JIT it does not (as mentioned in pytorch/pytorch#77167).

However, I did some searching in the project but could not find any clues about this issue. I even tried to replace the following, but with no luck:

mask = F.pad(mask, (0, attend_padding), value = True) # replace True with 1.

Do you have any ideas about this?
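
One possible direction (an untested sketch, not from the original thread) is to avoid tracing a bool constant into aten::constant_pad_nd at all, by padding a float mask and casting back afterwards:

# Untested sketch: pad a float mask with a numeric value so the trace never
# sees a bool in constant_pad_nd, then cast back to bool for the attention mask.
mask = F.pad(mask.float(), (0, attend_padding), value=1.0).bool()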

Error downloading dalle2_pytorch during environment configuration

Collecting dalle2_pytorch==0.2.38
Using cached dalle2_pytorch-0.2.38-py3-none-any.whl (1.4 MB)
Collecting rotary-embedding-torch
Using cached rotary_embedding_torch-0.1.5-py3-none-any.whl (4.1 kB)
Collecting vector-quantize-pytorch
Using cached vector_quantize_pytorch-0.9.2-py3-none-any.whl (7.6 kB)
Collecting youtokentome
Using cached youtokentome-1.0.6.tar.gz (86 kB)
Preparing metadata (setup.py) ... error
error: subprocess-exited-with-error

× python setup.py egg_info did not run successfully.
│ exit code: 1
╰─> [6 lines of output]
Traceback (most recent call last):
  File "<string>", line 2, in <module>
  File "<pip-setuptools-caller>", line 34, in <module>
  File "D:\Temp\pip-install-ahp9flmk\youtokentome_0c252ff1db5a4b04a093d59dfe250b09\setup.py", line 5, in <module>
    from Cython.Build import cythonize
ModuleNotFoundError: No module named 'Cython'
[end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.
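
A common workaround for this class of failure (a suggestion, not part of the original report) is to install Cython before the requirements, so that youtokentome's setup.py can import it:

pip install Cython
pip install -r requirements.txt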

Suboptimal training results

Dear authors,

Thanks a lot for your excellent work!

However, when I run the training code, I see unexpectedly low memory consumption and suboptimal training results. I just ran the training stage with the following command:

python scripts/train.py

Only about 8 GB of GPU memory is used, rather than the 80 GB available on the A100. After training, all the checkpoints I trained myself generate noticeably worse results than "ffhq-sg2-510.ckpt".

Could you please give me some suggestions?

Thank you! :)

Normalization twice in Clipper

Hi,

Here in the Clipper, before encoding the image, the code seems to normalize the input twice. Is this correct?

def embed_image(self, image):
    """Expects images in -1 to 1 range"""
    clip_in = F.interpolate(image, self.clip_size, mode="area")
    clip_in = self.normalize(0.5*clip_in + 0.5).clamp(0,1)
    return self.clip.encode_image(self.normalize(clip_in))
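
If the double call is indeed unintended, the presumable intent was a single normalization after rescaling; a sketch of the likely fix (not confirmed by the authors):

def embed_image(self, image):
    """Expects images in -1 to 1 range"""
    clip_in = F.interpolate(image, self.clip_size, mode="area")
    clip_in = (0.5*clip_in + 0.5).clamp(0,1)  # rescale to [0, 1]
    return self.clip.encode_image(self.normalize(clip_in))  # normalize once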
