This is a PyTorch implementation of the paper Generative Adversarial Text-to-Image Synthesis. We train a conditional generative adversarial network, conditioned on text descriptions, to generate images that correspond to the description. The network architecture, which is based on DCGAN, is shown below (image from [1]).
We used the Caltech-UCSD Birds 200 and Flowers datasets and converted each dataset (images and text embeddings) to HDF5 format.
HDF5 file taxonomy:
- split (train | valid | test)
  - example_name
    - 'name'
    - 'img'
    - 'embeddings'
    - 'class'
    - 'txt'
  - example_name
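Because the taxonomy is a tree of named groups, a file laid out this way can be walked like a nested dictionary. The sketch below is only an illustration, not the repository's data loader: `EXPECTED_FIELDS`, `iter_examples`, and `missing_fields` are hypothetical names, and the in-memory dictionary is a stand-in for an opened file with placeholder values.

```python
# Field names taken from the taxonomy above.
EXPECTED_FIELDS = ("name", "img", "embeddings", "class", "txt")


def iter_examples(root, split="train"):
    """Yield (example_name, example) pairs stored under one split.

    `root` is any nested mapping. An opened h5py.File should behave the
    same way, since HDF5 groups support `[]`, `in`, and `.items()`.
    """
    for example_name, example in root[split].items():
        yield example_name, example


def missing_fields(example):
    """Return the expected fields that an example does not contain."""
    return [field for field in EXPECTED_FIELDS if field not in example]


# Stand-in for an opened file; every value here is a placeholder.
fake_file = {
    "train": {
        "example_0": {
            "name": "example_0",
            "img": b"<encoded image bytes>",
            "embeddings": [[0.0, 0.1], [0.2, 0.3]],
            "class": "some_class",
            "txt": "a text description",
        },
        # Deliberately incomplete, to show the check reporting gaps.
        "example_1": {"name": "example_1", "img": b""},
    },
    "valid": {},
    "test": {},
}

for example_name, example in iter_examples(fake_file, "train"):
    print(example_name, "missing:", missing_fields(example))
```

With a real converted file, an `h5py.File(path, "r")` object could be passed as `root` instead of the dictionary, and a stored value read with `example["img"][()]`. A check like this can be handy right after conversion to confirm that every example carries all five fields.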
Python 3.10
pip install -r requirements.txt
- Download and extract the birds, flowers, and COCO caption data in Torch format.
- Download and extract the birds, flowers, and COCO text encodings.
- Download and extract the birds, flowers, and COCO image data.
- Use the `convert_cub_to_hd5_script` or `convert_flowers_to_hd5_script` script to convert the dataset to HDF5.
`python runtime.py --ds`
Arguments:
- `type`: GAN architecture to use (`gan` | `wgan` | `vanilla_gan` | `vanilla_wgan`). Default: `gan`. Vanilla means not conditional.
- `dataset`: Dataset to use (`birds` | `flowers`). Default: `flowers`.
- `split`: An integer indicating which split to use (0: train | 1: valid | 2: test). Default: 0.
- `lr`: The learning rate. Default: 0.0002.
- `diter`: WGAN only. Number of discriminator iterations for each generator iteration. Default: 5.
- `save_path`: Path for saving the models.
- `l1_coef`: L1 loss coefficient in the generator loss function for `gan` and `vanilla_gan`. Default: 50.
- `l2_coef`: Feature matching coefficient in the generator loss function for `gan` and `vanilla_gan`. Default: 100.
- `pre_trained_disc`: Path to a pre-trained discriminator model used for initializing training.
- `pre_trained_gen`: Path to a pre-trained generator model used for initializing training.
- `batch_size`: Batch size. Default: 64.
- `num_workers`: Number of dataloader workers used for fetching data. Default: 8.
- `epochs`: Number of training epochs. Default: 200.
- `eval_batch_size`: Evaluation batch size. Default: 512.
- `eval_interval`: Interval for evaluating the GAN. Default: 10.
- `ds`: Enable efficient data selection. Default: False.
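To make the documented defaults concrete, the arguments above can be expressed as an `argparse` parser. This is a minimal sketch under stated assumptions, not the contents of `runtime.py`: the `--` flag spellings (other than `--ds`, which appears in the command above), the argument types, the `None` defaults for the path arguments, and the names `build_parser` and `SPLIT_NAMES` are all assumptions made for illustration.

```python
import argparse

# Order matches the integer codes documented for `split`.
SPLIT_NAMES = ("train", "valid", "test")


def build_parser():
    """Build a parser mirroring the documented arguments and defaults."""
    p = argparse.ArgumentParser(description="Text-to-image GAN training (sketch)")
    p.add_argument("--type", default="gan",
                   choices=["gan", "wgan", "vanilla_gan", "vanilla_wgan"])
    p.add_argument("--dataset", default="flowers", choices=["birds", "flowers"])
    p.add_argument("--split", type=int, default=0, choices=[0, 1, 2])
    p.add_argument("--lr", type=float, default=0.0002)
    p.add_argument("--diter", type=int, default=5)  # only meaningful for WGAN
    # No default is documented for the path arguments; None is an assumption.
    p.add_argument("--save_path", default=None)
    p.add_argument("--pre_trained_disc", default=None)
    p.add_argument("--pre_trained_gen", default=None)
    p.add_argument("--l1_coef", type=float, default=50.0)
    p.add_argument("--l2_coef", type=float, default=100.0)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--num_workers", type=int, default=8)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--eval_batch_size", type=int, default=512)
    p.add_argument("--eval_interval", type=int, default=10)
    p.add_argument("--ds", action="store_true")  # False unless the flag is given
    return p


# Parse an explicit argument list so the sketch runs without a command line.
args = build_parser().parse_args(["--type", "wgan", "--dataset", "birds", "--ds"])
print(args.type, args.dataset, SPLIT_NAMES[args.split], args.ds)
```

Modelling `ds` as a `store_true` flag matches the documented default of False and the bare `--ds` in the example command. Restricting `type`, `dataset`, and `split` with `choices` makes a mistyped value fail at parse time rather than partway through training.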