
CVPR 2022 - Official PyTorch implementation of "A Style-Aware Discriminator for Controllable Image Translation"

Home Page: https://arxiv.org/abs/2203.15375

License: MIT License

Python 86.14% Shell 3.57% C++ 1.57% Cuda 8.72%
cvpr2022 deep-learning pytorch unsupervised-learning image-to-image-translation image-translation

style-aware-discriminator's Introduction

A Style-aware Discriminator for Controllable Image Translation

Kunhee Kim, Sanghun Park, Eunyeong Jeon, Taehun Kim, Daijin Kim
POSTECH

Our model discovers various style prototypes from the dataset in a self-supervised manner. The style prototype consists of a combination of various attributes including (left) time, weather, season, and texture; and (right) age, gender, and accessories.

Paper: https://arxiv.org/abs/2203.15375

Abstract: Current image-to-image translations do not control the output domain beyond the classes used during training, nor do they interpolate between different domains well, leading to implausible results. This limitation largely arises because labels do not consider the semantic distance. To mitigate such problems, we propose a style-aware discriminator that acts as a critic as well as a style encoder to provide conditions. The style-aware discriminator learns a controllable style space using prototype-based self-supervised learning and simultaneously guides the generator. Experiments on multiple datasets verify that the proposed model outperforms current state-of-the-art image-to-image translation methods. In contrast with current methods, the proposed approach supports various applications, including style interpolation, content transplantation, and local image translation.

Installation / Requirements

  • CUDA 10.1 or newer is required for the StyleGAN2-based model, since it uses the custom CUDA kernels of StyleGAN2 ported by @rosinality.
  • We mainly tested on Python 3.8 and PyTorch 1.10.2 with cudatoolkit=11.3 (see environment.yml), using CUDA 11.2 to build the custom CUDA kernels. A quick environment check is shown below.
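
Optionally, you can run a quick sanity check that the installed PyTorch build reports the expected versions and sees the GPU (a minimal check, not part of the repository's scripts):

# Print the PyTorch version, the CUDA toolkit it was built against, and GPU visibility.
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"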

Clone this repository:

git clone https://github.com/kunheek/style-aware-discriminator.git
cd style-aware-discriminator

Then, install dependencies using anaconda or pip:

conda env create -f environment.yml
# or
pip install -r requirements.txt

Testing and Evaluation

We provide the following pre-trained networks.

Name                Dataset      Resolution  Method     #images  OneDrive link
afhq-adain          AFHQ         $256^2$     AdaIN      1.6 M    afhq-adain.pt
afhq-stylegan2      AFHQ         $256^2$     StyleGAN2  5 M      afhq-stylegan2-5M.pt
afhqv2              AFHQ v2      $512^2$     StyleGAN2  5 M      afhqv2-512x512-5M.pt
celebahq-adain      CelebA-HQ    $256^2$     AdaIN      1.6 M    celebahq-adain.pt
celebahq-stylegan2  CelebA-HQ    $256^2$     StyleGAN2  5 M      celebahq-stylegan2-5M.pt
church              LSUN church  $256^2$     StyleGAN2  25 M     church-25M.pt
ffhq                FFHQ         $256^2$     StyleGAN2  25 M     ffhq-25M.pt
flower              Oxford 102   $256^2$     AdaIN      1.6 M    flower-256x256-adain.pt

Here are links to all checkpoints (checkpoints.zip) and the MD5 checksum file (checkpoints.md5). If you have wget and unzip in your environment, you can also download the checkpoints using the following commands:

# download all checkpoints.
bash download.sh checkpoints
# download a specific checkpoint.
bash download.sh afhq-adain

See the table above or download.sh for available checkpoints.
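
If you download checkpoints.zip manually, you can verify it against the provided MD5 file, assuming checkpoints.md5 is in the standard md5sum format and the listed paths match your local layout:

# Check the downloaded files against the published MD5 sums.
md5sum -c checkpoints.md5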

Quantitative results

(Optional) Computing Inception statistics takes a long time. We provide pre-calculated stats for the AFHQ 256 and CelebA-HQ 256 datasets (link). You can download and register them using the following commands:

bash download.sh stats
# python -m tools.register_stats PATH/TO/STATS
python -m tools.register_stats assets/stats

To evaluate our model, run python -m metrics METRICS --checkpoint CKPT --train-dataset TRAINDIR --eval-dataset EVALDIR. By default, all metrics will be saved in runs/{run-dir}/metrics.txt. Available metrics include fid, mean_fid, and reconstruction; see metrics/{task}_evaluator.py for task-specific options. You can pass multiple tasks at the same time. Here are some examples:

python -m metrics fid reconstruction --seed 123 --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --train-dataset ./datasets/afhq/train --eval-dataset ./datasets/afhq/val

python -m metrics mean_fid --seed 777 --checkpoint ./checkpoints/celebahq-stylegan2-5M.pt --train-dataset ./datasets/celeba_hq/train --eval-dataset ./datasets/celeba_hq/val
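
If you want to evaluate several checkpoints in one go, a simple shell loop over the same command works; the checkpoint and dataset paths below are only examples taken from the table above:

# Evaluate FID for multiple AFHQ checkpoints sequentially.
for ckpt in ./checkpoints/afhq-adain.pt ./checkpoints/afhq-stylegan2-5M.pt; do
    python -m metrics fid --checkpoint "$ckpt" --train-dataset ./datasets/afhq/train --eval-dataset ./datasets/afhq/val
done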

Qualitative results

You can synthesize images in the same way as the quantitative evaluations (replace metrics with synthesis). By default, all images will be saved in the runs/{run-dir}/{task} folder.

# python -m synthesis [TASKS] --checkpoint PATH/TO/CKPT --folder PATH/TO/FOLDERS
python -m synthesis swap --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style

python -m synthesis interpolation --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style

Some tasks require multiple folders (e.g., content and style) or extra arguments. Available synthesis tasks include swap and interpolation, as shown above; see the synthesis module for the full list and task-specific options.

Additional tools

We provide additional tools for exploring the learned style space:

  • plot_tsne: visualize the learned style space and prototypes using t-SNE.

python -m tools.plot_tsne --checkpoint checkpoints/afhq-stylegan2-5M.pt --target-dataset datasets/afhq/val --seed 7 --title AFHQ --labels cat dog wild

python -m tools.plot_tsne --checkpoint checkpoints/celebahq-stylegan2-5M.pt --target-dataset datasets/celeba_hq/val --seed 7 --title CelebA-HQ --legends female male
  • similarity_search: find samples in the target dataset that are most similar to the query (in the style space and the content space). A concrete example follows below.

python -m tools.similarity_search --checkpoint CKPT --query QUERY_IMAGE --target-dataset TESTDIR
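
For example, with one of the pre-trained checkpoints and a placeholder query image (QUERY.jpg below is hypothetical; point it at any image file):

python -m tools.similarity_search --checkpoint checkpoints/afhq-stylegan2-5M.pt --query ./testphotos/afhq/content/QUERY.jpg --target-dataset datasets/afhq/val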

Training

Datasets

By default, all images in the folder will be used for training or evaluation (supported image formats can be found here). For example, if you pass --train-dataset=./datasets/afhq/train, all images in the ./datasets/afhq/train folder will be used for training.
For LSUN datasets, lsun must be included in the folder path.

datasets
└─ lsun
   ├─ church_outdoor_train_lmdb
   └─ church_outdoor_val_lmdb

To measure mean FID, a subdirectory corresponding to each class must exist (fewer than 5 classes). If you want to reproduce the experiments in the paper, we recommend using the following structure:

datasets
├─ afhq
│  ├─ train
│  │  ├─ cat
│  │  ├─ dog
│  │  └─ wild
│  └─ val (or test)
│     └─ (cat/dog/wild)
└─ celeba_hq
   ├─ train
   │  ├─ female
   │  └─ male
   └─ val
      └─ (female/male)
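
Before training, you can quickly confirm that the directory layout matches the structure above (optional; assumes a Unix-like shell):

# List the dataset directories down to the class subfolders.
find datasets -maxdepth 3 -type d | sort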

Training scripts

Notice: We recommend training networks on a single GPU with enough memory (e.g., an A100) to obtain the best results, since we observed performance degradation with the current implementation when using multiple GPUs (DDP). For example, a model trained on one A100 GPU (40 GB) is slightly better than a model trained on two TITAN Xp GPUs (12 GB each). We used a single NVIDIA A100 GPU for the AFHQ and CelebA-HQ experiments and four NVIDIA RTX 3090 GPUs for the AFHQ v2, LSUN churches, and FFHQ experiments. Note that we disabled TF32 for all experiments; a sketch of how to do this in PyTorch is shown below.
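
For reference, TF32 can be disabled in PyTorch with the following two flags; the training code may already take care of this, so this is only a minimal sketch:

import torch

# Disable TF32 for matrix multiplications and cuDNN convolutions
# (Ampere GPUs such as the A100 and RTX 3090 enable TF32 by default).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False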

We provide training scripts here. Use the following commands to train networks with custom arguments:

# Single GPU training.
python train.py --mod-type adain --total-nimg 1.6M --batch-size 16 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/afhq/train --eval-dataset datasets/afhq/val --out-dir runs --extra-desc some descriptions

# Multi-GPU training.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --total-nimg 25M --batch-size 64 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/ffhq/images1024x1024 --eval-dataset datasets/ffhq/images1024x1024 --nb-proto 128 --latent-dim 512 --latent-ratio 0.5 --jitter true --cutout true --out-dir runs --extra-desc some descriptions

Training options, code, checkpoints, and snapshots will be saved in {out-dir}/{run-id}-{dataset}-{resolution}-{extra-desc}. Please see train.py, model.py, and augmentation.py for available arguments.

To resume training, run python train.py --resume PATH/TO/RUNDIR. For example:

# Single GPU training.
python train.py --resume runs/000-afhq-256x256-some-descriptions

# Multi-GPU training.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --resume runs/001-ffhq-some-descriptions

Citation

If you find this repository useful for your research, please cite our paper:

@InProceedings{kim2022style,
  title={A Style-Aware Discriminator for Controllable Image Translation},
  author={Kim, Kunhee and Park, Sanghun and Jeon, Eunyeong and Kim, Taehun and Kim, Daijin},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022},
  pages={18239--18248}
}

Acknowledgements

Many of our implementations are adapted from previous works, including SwAV, DINO, StarGAN v2, Swapping Autoencoder, clean-fid, and stylegan2-pytorch.

Licenses

All materials except custom CUDA kernels in this repository are made available under the MIT License.

The custom CUDA kernels (fused_bias_act_kernel.cu and upfirdn2d_kernel.cu) are under the Nvidia Source Code License, and are for non-commercial use only.


style-aware-discriminator's Issues

About Interpolation.

I used python -m synthesis interpolation --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style to synthesize images, but the results are only interpolations between the content image and the style image; there are no style1 and style2 images. If I want to synthesize images using a style code interpolated between the two style codes obtained from two reference images, which script should I use? Thank you.

About stop-gradient operation

Hi, thank you for sharing your code. Can you please tell me how to stop the gradient when training the generator in your code? Looking forward to your reply, thanks!

Simple description of training data requirements

Please create a small section in the README that describes in detail the expected format/size etc. for training data, including the directory structure. I assume it's something like this:

A/
├─ trainA
└─ testA
B/
├─ trainB
└─ testB

Images should be of dimension X pixels by Y pixels.

Training problem

Hello, training stops at epoch 999 when "Evaluating k-NN accuracy." starts, with the error ValueError: range() arg 3 must not be zero. Training on my own AFHQ dataset raises the same error:

Traceback (most recent call last):
  File "train.py", line 258, in <module>
  File "train.py", line 254, in main
  File "train.py", line 190, in training_loop
  File "C:\Users\yuanx\.conda\envs\style2\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\yuanx\Desktop\style\style-aware-discriminator\metrics\knn_evaluator.py", line 69, in evaluate
    top1, top5 = knn_classifier(
  File "C:\Users\yuanx\.conda\envs\style2\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\yuanx\Desktop\style\style-aware-discriminator\metrics\knn_evaluator.py", line 106, in knn_classifier
    for idx in range(0, num_test_images, imgs_per_chunk):
ValueError: range() arg 3 must not be zero

Here is what I printed:
num_test_images, num_chunks = test_labels.shape[0], 100
num_test_images = 32

imgs_per_chunk = num_test_images // num_chunks
imgs_per_chunk = 0

Environment: torch 1.11.0+cu113, CUDA 11.3
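
A possible workaround (not an official fix) is to clamp the chunk size in metrics/knn_evaluator.py so that range() never receives a step of zero when the test set has fewer images than num_chunks. A minimal, self-contained illustration:

# num_test_images = 32 and num_chunks = 100 reproduce the reported failure.
num_test_images, num_chunks = 32, 100
imgs_per_chunk = max(1, num_test_images // num_chunks)  # clamp to at least 1 image per chunk
for idx in range(0, num_test_images, imgs_per_chunk):
    pass  # process one chunk of test images here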

Sample code for image translation

Hi,
Can you please share sample code to translate img_src.jpg using img_reference.jpg as the style reference? This seems like a very interesting project. Just want to have some fun with it :)

Number of training epochs

I ran the code. Where do I set the number of training epochs? It feels like training finishes after only one epoch.

Missing count_parameters() in torch_utils.py

I got an error here

if torch_utils.count_parameters(module) > 0:

No count_parameters() function is found.

I found a similar function on the web and added it to torch_utils.py, after which I could run the code successfully.

def count_parameters(model):
    counts = sum(p.numel() for p in model.parameters() if p.requires_grad)
    #print(f'The model has {counts:,} trainable parameters')
    return counts
