Recurrent-Affine-Transformation-for-Text-to-image-Synthesis

Official PyTorch implementation of our paper "Recurrent Affine Transformation for Text-to-image Synthesis".

Examples

[Image: example results]


Requirements

  • Python 3.8
  • PyTorch 1.11.0+cu113
  • easydict
  • nltk
  • scikit-image
  • An RTX 2080 Ti (set nf=32 in *.yml) or an RTX 3090 32GB (set nf=64 in *.yml)

Note that nf=32 yields an IS of around 5.0 on CUB. To reproduce the final results, please use a GPU with more than 32 GB of memory.
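The nf value lives in the config files. As a minimal sketch, assuming the configs contain a line of the form `nf: <integer>` (the exact key spelling and nesting in the shipped cfg/*.yml files is an assumption here), the hypothetical helper below switches it using only the standard library:

```python
import re

# Matches a line such as "nf: 64" or "  nf: 32" (any indentation) and keeps
# everything before the number in group 1.
_NF_LINE = re.compile(r"^([ \t]*nf[ \t]*:[ \t]*)\d+", flags=re.MULTILINE)


def set_nf(yml_text: str, nf: int) -> str:
    """Return yml_text with the integer value of every `nf:` entry set to nf.

    Hypothetical helper, not part of RAT-GAN; assumes the key is spelled `nf`.
    """
    new_text, count = _NF_LINE.subn(lambda m: m.group(1) + str(nf), yml_text)
    if count == 0:
        raise KeyError("no `nf:` entry found in the config text")
    return new_text


if __name__ == "__main__":
    # Illustrative config text only; the real cfg/*.yml files contain more keys.
    example = "CONFIG_NAME: bird\nnf: 64\n"
    # 32 for an RTX 2080 Ti, 64 for the larger GPU described above.
    print(set_nf(example, 32), end="")
```

To apply it, read a config such as cfg/bird.yml, pass its text through set_nf, and write the result back. Editing the text with a regex rather than re-serializing the YAML leaves comments and key order untouched.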

Installation

Clone this repo.

git clone https://github.com/senmaoy/RAT-GAN.git
cd RAT-GAN/code/

Datasets Preparation

  1. Download the preprocessed metadata for birds and coco and save them to data/
  2. Download the birds image data and extract it to data/birds/. Raw text data of the CUB dataset is available here
  3. Download the coco dataset and extract the images to data/coco/
  4. Download the flower dataset and extract the images to data/flower/. Raw text data of the flower dataset is available here

Note that the flower dataset differs slightly from CUB and COCO: it is handled by a standalone dataset processing script.
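Before training, it can help to confirm that the directories from the steps above exist. A minimal standard-library sketch: the helper name is hypothetical, and it assumes data/ sits under the directory passed as root (the steps above do not say whether that is RAT-GAN/ or RAT-GAN/code/). You only need the directories for the datasets you plan to use.

```python
from pathlib import Path
from typing import List

# Dataset directories named in the preparation steps above.
EXPECTED_DIRS = ["data/birds", "data/coco", "data/flower"]


def missing_dataset_dirs(root: str = ".") -> List[str]:
    """Return the expected dataset directories that do not exist under root.

    Hypothetical helper; it only checks that the directories exist, not
    what is inside them.
    """
    base = Path(root)
    return [rel for rel in EXPECTED_DIRS if not (base / rel).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs()
    if missing:
        print("Missing dataset directories:", ", ".join(missing))
    else:
        print("All dataset directories found.")
```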

Training on your own dataset is easy (the processing is similar to that of the flower dataset):

  1. Prepare a captions.pickle file containing all the image paths. You need to create captions.pickle yourself.
  2. Save captions.pickle under data_dir.
  3. Put all the captions of each image in a standalone .txt file (one caption per line). This .txt file is later read by dataset_flower.py at line 149: cap_path = '%s/%s.txt' % ('/home/yesenmao/dataset/flower/jpg_text/', filenames['img'][i]). The caption directory is hard-coded there, so you will likely need to change it to your own path.
  4. Run main.py as usual; dataset_flower.py will then process your dataset automatically.
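The exact layout of captions.pickle is not documented here; the only hint is the lookup filenames['img'][i] in the line quoted above. The sketch below therefore assumes captions.pickle holds a dict {'img': [...]} of image identifiers (paths without extension), plus one <identifier>.txt caption file per image. prepare_custom_dataset, read_captions and the directory arguments are hypothetical; check dataset_flower.py for the structure it actually expects, since it may require other keys.

```python
import os
import pickle
from typing import Dict, List


def prepare_custom_dataset(data_dir: str, text_dir: str,
                           captions: Dict[str, List[str]]) -> str:
    """Write captions.pickle and one caption .txt file per image.

    `captions` maps an image identifier (e.g. a path without extension)
    to its list of captions. Hypothetical helper; the pickle layout
    {'img': [...]} is inferred from `filenames['img'][i]`.
    """
    os.makedirs(data_dir, exist_ok=True)
    names = sorted(captions)
    for name in names:
        # Same path pattern as the cap_path line in dataset_flower.py.
        cap_path = '%s/%s.txt' % (text_dir, name)
        os.makedirs(os.path.dirname(cap_path), exist_ok=True)
        with open(cap_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(captions[name]) + '\n')  # one caption per line
    pickle_path = os.path.join(data_dir, 'captions.pickle')
    with open(pickle_path, 'wb') as f:
        pickle.dump({'img': names}, f)
    return pickle_path


def read_captions(text_dir: str, filenames: Dict[str, List[str]],
                  i: int) -> List[str]:
    """Read the captions of image i, building cap_path with the same pattern.

    Skipping blank lines is an assumption, not taken from dataset_flower.py.
    """
    cap_path = '%s/%s.txt' % (text_dir, filenames['img'][i])
    with open(cap_path, encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]
```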

Pre-trained text encoder

  1. Download the pre-trained text encoder for CUB and save it to ../bird/
  2. Download the pre-trained text encoder for coco and save it to ../bird/
  3. Download the pre-trained text encoder for flower and save it to ../bird/

Training

Train RAT-GAN models:

  • For the bird dataset: python main.py --cfg cfg/bird.yml

  • For the coco dataset: python main.py --cfg cfg/coco.yml

  • For the flower dataset: python main.py --cfg cfg/flower.yml

  • The *.yml files are example configuration files for training and evaluating our models.
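If you script several runs, the three commands above map directly onto a small launcher. This is a hypothetical sketch, not part of the repo; it assumes it is run from RAT-GAN/code/ so that main.py and cfg/ resolve.

```python
import subprocess
import sys
from typing import List

# Config files named in the training instructions above.
CFG_FILES = {
    "bird": "cfg/bird.yml",
    "coco": "cfg/coco.yml",
    "flower": "cfg/flower.yml",
}


def train_command(dataset: str) -> List[str]:
    """Build the `python main.py --cfg ...` command for one dataset."""
    if dataset not in CFG_FILES:
        raise ValueError(f"unknown dataset {dataset!r}; choose from {sorted(CFG_FILES)}")
    return [sys.executable, "main.py", "--cfg", CFG_FILES[dataset]]


def train(dataset: str) -> int:
    """Run training in a subprocess and return its exit code."""
    return subprocess.call(train_command(dataset))


if __name__ == "__main__":
    # Only prints the command; call train("bird") to actually start training.
    print(" ".join(train_command("bird")))
```

Using sys.executable instead of a bare python runs main.py with the same interpreter, and therefore the same environment, as the launcher.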

Evaluating

Download Pretrained Model

Evaluate RAT-GAN models:


Citing RAT-GAN

If you find RAT-GAN useful in your research, please consider citing our paper:

@article{ye2022recurrent,
  title={Recurrent Affine Transformation for Text-to-image Synthesis},
  author={Ye, Senmao and Liu, Fei and Tan, Minkui},
  journal={arXiv preprint arXiv:2204.10482},
  year={2022}
}

If you are interested, join our WeChat group, where a dozen T2I (text-to-image) partners are waiting for you! If the QR code has expired, you can add this WeChat ID: Unsupervised2020

[Image: WeChat group QR code]

The code is released for academic research use only. Please contact me at [email protected]

rat-gan's Issues

Error when loading the COCO netG model

RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NetG:
Missing key(s) in state_dict: "lstm.W", "lstm.U", "lstm.bias", "lstm.noise2h.weight", "lstm.noise2h.bias", "lstm.noise2c.weight", "lstm.noise2c.bias", "fc.weight", "fc.bias", "block0.gamma", "block0.lstm.W", "block0.lstm.U", "block0.lstm.bias", "block0.lstm.noise2h.weight", "block0.lstm.noise2h.bias",

How to train with multiple GPUs

When I modify your code to use nn.DataParallel and change netD.COND_DNET to netD.module.COND_DNET, training fails with RuntimeError: The size of tensor a (24) must match the size of tensor b (48) at non-singleton dimension 0. Could you tell me how to solve it?

Reproducing FID and IS on the flower dataset

Hi, I loaded the pretrained model you provided and synthesized images. However, the FID is 241 and the IS is 2.0, and the visualized results are very blurry. Could you give me any advice about these results? Thanks very much!

Inference

Hello, is there an inference script for generating an image from a single sentence of text?

About attention

Hello, I would like to ask how this attention is handled, and which part of the code implements it?

Pretrain code for

Hi, I am following your work and wondering whether you have the code for the pre-training described in Section 3.1 of the paper, "Contrastive Text Embedding Pre-training". In the DAMSM.py file, I saw the CustomLSTM class and the CNN_ENCODER class; I assume they are used to obtain the sentence-level feature s and the image-level feature f. However, I did not find the contrastive pre-training part. If you have the code, would you mind sharing it? Thanks a lot!

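Number of epochs

Hi, your paper says CUB is trained for 600 epochs, but the yml file sets 6010 epochs. What is going on? That gave me a scare: I have heavily modified your model, and at this rate it looks like training would take several months to finish.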

Hello

Hello, what exactly is the format of the .pkl file you mentioned for training on one's own dataset?

ModuleNotFoundError when running pretrain_DAMSM.py

Traceback (most recent call last):
  File "/home/kumaizeo/文件/NTHU/Homework/11/DL/cup3/refer/RAT-GAN/code/pretrain_DAMSM.py", line 4, in <module>
    from miscc.losses import sent_loss, words_loss
  File "/home/kumaizeo/文件/NTHU/Homework/11/DL/cup3/refer/RAT-GAN/code/miscc/losses.py", line 7, in <module>
    from GlobalAttention import func_attention
ModuleNotFoundError: No module named 'GlobalAttention'

Running pretrain_DAMSM.py produces the error above. Which module is this GlobalAttention?

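About the "flower" dataset

Great work. Hello, I have only recently started working on this topic. I have looked at many codebases, but none of them include a program for running on the flower dataset. Could you release the cfg file for the flower dataset, or explain how to train the flower model? Thank you very much.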
