

Fully Convolutional Visual Question Answering

This is an attention-based model for VQA that uses a dilated convolutional neural network to model the question and a ResNet for the visual features. The text model is based on the recent convolutional architecture ByteNet. Stacked attention distributions over the image are used to compute weighted image features, which are concatenated with the text features to predict the answer. A rough diagram of the model follows.

[Figure: Model architecture]
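The attention step described above can be sketched in plain Python. This is a minimal illustration assuming a stacked-attention-style scoring function (a tanh over each projected image region plus the projected question, followed by a softmax over regions); the projection names w_img, w_q and w_att, the two-glimpse setup and the toy sizes are hypothetical and not taken from this repository's TensorFlow code.

```python
import math
import random


def matvec(m, v):
    # Multiply matrix m (a list of rows) by vector v.
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(xs):
    # Numerically stable softmax over a list of floats.
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_glimpse(regions, question, w_img, w_q, w_att):
    """One attention glimpse (hypothetical parameterization).

    regions:  R image-region vectors of length D
    question: text feature vector of length Dq
    w_img:    H x D projection, w_q: H x Dq projection, w_att: length-H scoring vector
    Returns the attention distribution over regions and the weighted image feature.
    """
    q_proj = matvec(w_q, question)
    logits = []
    for v in regions:
        hidden = [math.tanh(a + b) for a, b in zip(matvec(w_img, v), q_proj)]
        logits.append(sum(w * h for w, h in zip(w_att, hidden)))
    probs = softmax(logits)
    dim = len(regions[0])
    # Weighted image feature: convex combination of the region vectors.
    weighted = [sum(p * v[d] for p, v in zip(probs, regions)) for d in range(dim)]
    return probs, weighted


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    R, D, Dq, H = 4, 6, 5, 3  # toy sizes; block4 features give R = 196, D = 2048
    regions = rand_matrix(R, D, rng)
    question = [rng.uniform(-1.0, 1.0) for _ in range(Dq)]
    features = list(question)
    for _ in range(2):  # two glimpses, each with its own randomly initialised parameters
        probs, weighted = attention_glimpse(
            regions, question,
            rand_matrix(H, D, rng), rand_matrix(H, Dq, rng),
            [rng.uniform(-1.0, 1.0) for _ in range(H)],
        )
        features.extend(weighted)  # probs is the per-region attention map
    print(len(features))  # Dq + 2 * D = 17, the input to the answer classifier
```

Concatenating the text vector with one weighted image vector per glimpse mirrors the description above; in the real model the parameters are learned rather than random.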

Requirements

  • Python 2.7.6
  • TensorFlow 1.3.0
  • nltk

Datasets and Paths

  • The model can be trained on either VQA 1.0 or VQA 2.0. Download the dataset by running sh download.sh in the Data directory. Unzip the downloaded files and create the directory Data/CNNModels. Download the pretrained ResNet-152 from here into Data/CNNModels.
  • Create two empty directories, Data/Models1 and Data/Models2, for saving checkpoints while training on VQA 1.0 and VQA 2.0, respectively.
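The directory setup above can be scripted. This is a small hypothetical helper (not part of the repository) that creates Data/CNNModels, Data/Models1 and Data/Models2; it avoids the exist_ok argument of os.makedirs so that it also runs on Python 2.7.

```python
import os


def make_data_dirs(root="Data"):
    """Create the CNN-model and checkpoint directories if they are missing.

    Returns the paths that were actually created.
    """
    created = []
    for name in ("CNNModels", "Models1", "Models2"):
        path = os.path.join(root, name)
        if not os.path.isdir(path):
            os.makedirs(path)  # also creates root itself if needed
            created.append(path)
    return created
```

Run make_data_dirs() from the repository root; calling it again is harmless and returns an empty list.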

Usage

Extract the Image features

  • Extract the image features with one of the following commands:
    • Default: ResNet block4 features, shape (14, 14, 2048), for the attention model - python extract_conv_features.py --feature_layer="block4"
    • VGG pool5 features, shape (14, 14, 512), for the attention model - python extract_conv_features.py --feature_layer="pool5"
    • VGG fc7 features, shape (4096,) - python extract_conv_features.py --feature_layer="fc7"
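The two conv layers produce a 14 x 14 spatial grid, that is 196 image regions the attention model can weight, whereas fc7 is a single global vector. The sketch below only records the shapes listed above and shows a row-major flattening of a grid into region vectors; the helper names are hypothetical, and the on-disk format written by extract_conv_features.py is not described here.

```python
# Output shape for each --feature_layer value, as listed above.
FEATURE_SHAPES = {
    "block4": (14, 14, 2048),  # ResNet-152, the default attention features
    "pool5": (14, 14, 512),    # VGG, attention features
    "fc7": (4096,),            # VGG, one global vector with no spatial grid
}


def num_regions(feature_layer):
    """Number of image regions an attention distribution is computed over."""
    shape = FEATURE_SHAPES[feature_layer]
    return shape[0] * shape[1] if len(shape) == 3 else 1


def grid_to_regions(grid):
    """Flatten an H x W x C nested list into H*W region vectors of length C (row-major)."""
    return [cell for row in grid for cell in row]
```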

Preprocess Questions/Answers

  • Tokenize the questions and answers with python data_loader.py --version=VQA_VERSION (1 or 2)
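data_loader.py relies on nltk (see Requirements), and its exact preprocessing is not documented here. As a dependency-free illustration only, the hypothetical helpers below tokenize a question with a simple regex, build a word index (0 reserved for padding, 1 for unknown words) and pad each question to a fixed length. Note that nltk's word_tokenize splits contractions such as "'s" differently from this regex.

```python
import re
from collections import Counter

# A run of lowercase letters/digits, or any single non-space punctuation character.
TOKEN_RE = re.compile(r"[a-z0-9]+|[^\sa-z0-9]")
PAD, UNK = 0, 1


def tokenize(text):
    return TOKEN_RE.findall(text.lower())


def build_vocab(questions, min_count=1):
    """Map every token seen at least min_count times to an id >= 2."""
    counts = Counter(tok for q in questions for tok in tokenize(q))
    words = sorted(w for w, c in counts.items() if c >= min_count)
    return {w: i + 2 for i, w in enumerate(words)}


def encode(text, vocab, max_len):
    """Token ids, truncated or right-padded with PAD to max_len."""
    ids = [vocab.get(tok, UNK) for tok in tokenize(text)][:max_len]
    return ids + [PAD] * (max_len - len(ids))
```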

Training the attention model

  • Train with python train_evaluate.py --version=VQA_VERSION
  • The following model options can be customized:
    • residual_channels : Number of channels in the residual blocks of ByteNet, or the state size of the LSTM. Default 512.
    • batch_size : Default 64.
    • learning_rate : Default 0.001.
    • epochs : Default 25.
    • version : VQA dataset version, 1 or 2.
    • sample_every : Sample attention distributions/answers every x steps. Default 200.
    • evaluate_every : Evaluate on the validation set every x steps. Default 6000.
    • resume_model : Resume training from a checkpoint file.
    • training_log_file : File path where accuracy/steps are logged. Default 'Data/training_log.json'.
    • feature_layer : Which conv features to use. Default block4 of ResNet.
    • text_model : Text model to use: LSTM or bytenet. Default bytenet.
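The option list above can be mirrored in a small helper that builds a train_evaluate.py command line from overrides and rejects misspelled option names before a long run starts. This is a hypothetical convenience wrapper, not part of the repository; it assumes every option is passed as --name=value, like --version in the training command, and it copies the defaults from the list.

```python
# Defaults as documented in the option list (version and resume_model have none).
DEFAULTS = {
    "residual_channels": 512,
    "batch_size": 64,
    "learning_rate": 0.001,
    "epochs": 25,
    "sample_every": 200,
    "evaluate_every": 6000,
    "training_log_file": "Data/training_log.json",
    "feature_layer": "block4",
    "text_model": "bytenet",
}
# max_steps appears in the README's evaluation command.
KNOWN = set(DEFAULTS) | {"resume_model", "max_steps"}


def train_command(version, **overrides):
    """Return an argv list for train_evaluate.py (hypothetical helper)."""
    if version not in (1, 2):
        raise ValueError("version must be 1 or 2")
    unknown = sorted(set(overrides) - KNOWN)
    if unknown:
        raise ValueError("unknown options: %s" % ", ".join(unknown))
    args = ["python", "train_evaluate.py", "--version=%d" % version]
    # Only values that differ from the documented defaults are passed; sorted for a stable order.
    for name in sorted(overrides):
        if DEFAULTS.get(name) != overrides[name]:
            args.append("--%s=%s" % (name, overrides[name]))
    return args
```

For example, subprocess.check_call(train_command(1, batch_size=32)) would start training on VQA 1.0 with a smaller batch.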

Evaluating a trained model

  • Validation accuracy is logged to Data/training_log.json every evaluate_every steps during training.
  • To evaluate a checkpoint, run python train_evaluate.py --evaluate_every=1 --max_steps=1 --resume_model="Trained Model Path (Data/Models<vqa_version>/model<epoch>.ckpt)"
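Checkpoint paths follow the pattern Data/Models<version>/model<epoch>.ckpt (the pretrained example uses Data/Models1/model10.ckpt). The hypothetical helpers below build that path and the evaluation command as an argument list. Adding --version so that the validation set matches the checkpoint is an assumption, since the documented evaluation command omits it.

```python
def checkpoint_path(version, epoch, root="Data"):
    """Data/Models<version>/model<epoch>.ckpt, the naming used in this README."""
    return "%s/Models%d/model%d.ckpt" % (root, version, epoch)


def evaluate_command(version, epoch):
    """Return an argv list that evaluates one saved checkpoint (hypothetical helper)."""
    return [
        "python", "train_evaluate.py",
        "--version=%d" % version,  # assumption: pick the dataset matching the checkpoint
        "--evaluate_every=1",
        "--max_steps=1",
        "--resume_model=%s" % checkpoint_path(version, epoch),
    ]
```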

Generating Answers/Attention Distributions

Pretrained Model

You may download the pretrained model from here. Save the files in Data/Models1, the VQA 1.0 checkpoint directory.

  • Run python generate.py --question="<QUESTION ABOUT THE IMAGE>" --image_file="<IMAGE FILE PATH>" --model_path="<PATH_TO_CHECKPOINT = Data/Models1/model10.ckpt>" to generate the predicted answer and attention distributions, which are saved in Data/gen_samples.
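To ask several questions in one go, generate.py can be called from a short Python loop. Passing an argument list to subprocess sidesteps shell quoting, which matters for questions containing apostrophes such as "person's". The helper names are hypothetical; the flags are the ones in the command above.

```python
import subprocess

DEFAULT_MODEL = "Data/Models1/model10.ckpt"


def generate_command(question, image_file, model_path=DEFAULT_MODEL):
    # Each argument is its own list item, so no shell quoting is needed.
    return [
        "python", "generate.py",
        "--question=%s" % question,
        "--image_file=%s" % image_file,
        "--model_path=%s" % model_path,
    ]


def generate_all(pairs, model_path=DEFAULT_MODEL):
    """Run generate.py once per (question, image_file) pair; outputs land in Data/gen_samples."""
    for question, image_file in pairs:
        subprocess.check_call(generate_command(question, image_file, model_path))
```

Each call reloads the checkpoint, so this is convenient rather than fast.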

Sample Results

Question | Predicted Answer
is she going to eat both pizza | No
What color is the traffic light | green
is the persons hair short | Yes
what musical instrument is beside the laptop | keyboard
what color hat is the boy wearing | blue
what are the men doing | eating
what type of drink is in the glass | orange juice
is there a house | yes



Contributors

paarthneekhara, parasgandhi
