
captiongan's Introduction

Source code for the paper "Speaking the Same Language: Matching Machine to Human Captions by Adversarial Training"

The starting point for training the adversarial caption generator is "train_adversarial_caption_gen_v2.py". To train regular captioning models (using the maximum-likelihood (ML) loss), start with "driver_theano.py".

This is built with Python, numpy and theano. It's a large codebase containing the code to implement the captioning frameworks used in the following papers:

Image captioning:

  1. "Speaking the Same Language: Matching Machine to Human Captions by Adversarial Training" (https://arxiv.org/abs/1703.10476)
  2. "Paying Attention to Descriptions Generated by Image Captioning Models" (https://arxiv.org/abs/1704.07434)
  3. "Exploiting scene context for image captioning" (https://dl.acm.org/citation.cfm?id=2983571)

Video captioning:

  1. "Frame-and segment-level features and candidate pool evaluation for video caption generation" (https://arxiv.org/abs/1608.04959)
  2. "Video captioning with recurrent networks based on frame-and video-level features and visual content classification" (https://arxiv.org/abs/1512.02949)

Instructions on using the code

  1. Make sure you have theano installed and working. As a quick check, "import theano" should run without any errors in a Python shell.

  2. The code expects the data files to be in the "data/<dataset_name>" directory. It needs a .json file containing all the training/validation/test samples, and .npy/.mat/.bin feature files containing the CNN features for each of the samples. The actual images are needed only for visualising results, not for training.

  3. The data and some pre-trained models can be downloaded from the links below. The download doesn't include image features; you can use any extracted features for this purpose. The pre-trained models use ResNet features extracted as in (https://github.com/akirafukui/vqa-mcb/tree/master/preprocess). Since the feature files are large, I have not uploaded them here.

    Data: https://drive.google.com/open?id=0B76QzqVJdOJ5VjlaR294SVV6Z00
    Pre-Trained: https://drive.google.com/open?id=0B76QzqVJdOJ5TV9FMjhpVmlsTFE
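
The setup in steps 1 and 2 can be sanity-checked before launching a long training run. The following is only a minimal sketch and is not part of this repository: the helper names "check_dataset_dir" and "theano_available" are hypothetical, and the check assumes nothing beyond what step 2 states, namely that "data/<dataset_name>" holds a .json file and one or more .npy/.mat/.bin feature files.

```python
import importlib.util
from pathlib import Path

# Feature file formats named in step 2 of the instructions.
FEATURE_EXTENSIONS = {".npy", ".mat", ".bin"}


def theano_available():
    """Return True if a 'theano' package can be located (it is not imported)."""
    return importlib.util.find_spec("theano") is not None


def check_dataset_dir(data_root, dataset_name):
    """List the .json and feature files found in data_root/dataset_name.

    Hypothetical helper: raises FileNotFoundError if the directory is
    missing, otherwise returns (json_files, feature_files) as sorted
    lists of file names. It does not inspect the file contents.
    """
    dataset_dir = Path(data_root) / dataset_name
    if not dataset_dir.is_dir():
        raise FileNotFoundError(f"missing dataset directory: {dataset_dir}")
    files = sorted(p for p in dataset_dir.iterdir() if p.is_file())
    json_files = [p.name for p in files if p.suffix == ".json"]
    feature_files = [p.name for p in files if p.suffix in FEATURE_EXTENSIONS]
    return json_files, feature_files
```

For example, check_dataset_dir("data", "coco") would return the names of the .json and feature files present in "data/coco"; an empty list in either position suggests a file is still missing.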

Example Usage

  1. Training the adversarial model

THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda3' python train_adversarial_caption_gen_v2.py --maxlen 20 -o cvCoco/advDummy --fappend r-dep3-frc80-resnet-1samp-pretrainBOTH-miniBatchDiscr-GUMBHard0p5-smooth3-noMle-featMatchPsentEmbMatch-randInp50d --batch_size 10 --eval_period 0.5 --max_epochs 50 --feature_file fasterRcnn_clasDetFEat80.npy --eval_feature aux_inp --aux_inp_file resnet150_2048-mean.npy -ld 1e-5 -cb 50 --word_encoding_size 512 --sent_encoding_size 400 --solver rmsprop --train_evaluator_only 0 --use_gumbel_mse 1 -lg 1e-6 --eval_model lstm_eval --eval_init trainedModels/advers/evaluators/advmodel_checkpoint_coco_wks-12-46_r-reg-res150mean-5samp-lstmevalonly_318_94.22_EVOnly.p --disk_feature 0 --metrics_to_track meteor cider len lcldiv_1 lcldiv_2 --gumbel_temp_init 0.5 --use_gumbel_hard 1 --hidden_depth 3 --en_residual_conn 1 --n_gen_samples 5 --merge_dim 50 --softmax_smooth_factor 3.0 --use_mle_train 0 --rev_eval 1 --gen_input_noise 1 --gen_feature_matching 1 --continue_training trainedModels/coco/mpi/model_checkpoint_coco_wks-12-46_r-dep3-frc80-resnet150mean_per9.32.p

  2. Generating captions

THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda1' python predict_on_images.py cvCoco/advDummy/advmodel_checkpoint_coco_wks-12-46_r-dep3-frc80-resnet-5samp-pretrainBOTH-miniBatchDiscr-GUMBHard0p5-smooth3-noMle-featMatch-randInp50d_55999_15.20_genacc.p --aux_inp_file data/coco/resnet150_2048-mean.npy -f data/coco/fasterRcnn_clasDetFEat80.npy -i imgLists/imgListCOCO_MiniTestSet_ranzato.txt --fname_append ranzatotest_MLE-20Wrd-Smth3-randInpFeatMatch-ResnetMean-56k-beamsearch5 --softmax_smooth_factor 3.0 --labels data/coco/labels.txt --greedy 0 --computelogprob 1 --dobeamsearch 1 -b 5

Example image list file is here: https://drive.google.com/open?id=0B76QzqVJdOJ5NUtEMkx4ZzNKRWM

  3. Pre-Training the caption generator

THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda0' python driver_theano.py -d coco -l 1e-4 --maxlen 20 --decay_rate 0.999 --grad_clip 10.0 --image_encoding_size 512 --word_encoding_size 512 --hidden_size 512 -o cvCoco/salLclExpts --fappend r-dep3-frc80-resnet150mean --worker_status_output_directory statusCoco/c1 --write_checkpoint_ppl_threshold 14 --regc 2.66e-07 --batch_size 256 --eval_period 0.5 --max_epochs 60 --eval_batch_size 256 --aux_inp_file resnet150_2048-14-14.npzl --feature_file fasterRcnn_clasDetFEat80.npy --data_file dataset.json --sample_by_len 1 --lr_decay_st_epoch 1 --lr_decay 0.99 --disk_feature 2 --hidden_depth 3 --en_residual_conn 1 --poolmethod "none mean"
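
All three commands above pass Theano settings inline through the THEANO_FLAGS environment variable, a comma-separated list of key=value pairs. When launching from a Python script or notebook instead of a shell, the same string can be built and exported before theano is imported. This is a minimal sketch, not code from this repository; "build_theano_flags" and "set_theano_flags" are hypothetical helper names, and it assumes theano has not yet been imported in the process, since Theano reads this variable at import time.

```python
import os


def build_theano_flags(**flags):
    """Join keyword arguments into a THEANO_FLAGS string.

    Keyword order is preserved, so the result matches the order
    in which the flags were passed.
    """
    return ",".join(f"{key}={value}" for key, value in flags.items())


def set_theano_flags(**flags):
    """Export THEANO_FLAGS for this process and return the string.

    Call this before 'import theano'; setting it afterwards is assumed
    to have no effect on an already-imported theano.
    """
    value = build_theano_flags(**flags)
    os.environ["THEANO_FLAGS"] = value
    return value
```

For instance, set_theano_flags(mode="FAST_RUN", floatX="float32", device="cuda0") exports the same flags as the pre-training command, and changing only the device argument selects a different GPU.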

Some of the code and structure is based on the original neuraltalk code released by Andrej Karpathy at https://github.com/karpathy/neuraltalk
