
relational-networks's Introduction

PyTorch implementation of Relational Networks - "A simple neural network module for relational reasoning"

Implemented and tested on the Sort-of-CLEVR task.

Sort-of-CLEVR

Sort-of-CLEVR is a simplified version of CLEVR. It is composed of 10,000 images and 20 questions (10 relational and 10 non-relational) per image. Each of 6 colors (red, green, blue, orange, gray, yellow) is assigned to a randomly chosen shape (square or circle), and the objects are placed in the image.

Non-relational questions are composed of 3 subtypes:

  1. Shape of a certain colored object
  2. Horizontal location of a certain colored object: whether it is on the left or right side of the image
  3. Vertical location of a certain colored object: whether it is on the upper or lower side of the image

These questions are "non-relational" because the agent only needs to focus on a single object.

Relational questions are composed of 3 subtypes:

  1. Shape of the object that is closest to a certain colored object
  2. Shape of the object that is furthest from a certain colored object
  3. Number of objects that have the same shape as a certain colored object

These questions are "relational" because the agent has to consider the relations between objects.

Questions are encoded into a vector of size 11: 6 dimensions for a one-hot vector of the color (among 6 colors), 2 for a one-hot vector marking the question as relational or non-relational, and 3 for a one-hot vector of the 3 subtypes.
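
As a minimal illustration of this encoding (the slot ordering below is an assumption for the sketch; the actual generator may arrange the fields differently):

import numpy as np

COLORS = ['red', 'green', 'blue', 'orange', 'gray', 'yellow']

def encode_question(color, relational, subtype):
    """color: 0-5, relational: bool, subtype: 0-2 (slot ordering assumed)."""
    q = np.zeros(11)
    q[color] = 1                       # 6 slots: one-hot color
    q[6 + int(not relational)] = 1     # 2 slots: relational / non-relational
    q[8 + subtype] = 1                 # 3 slots: question subtype
    return q

# e.g. "What is the shape of the object closest to the red object?"
print(encode_question(color=COLORS.index('red'), relational=True, subtype=0))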

For example, with a sample image, we can generate non-relational questions like:

  1. What is the shape of the red object? => Circle (even though it does not really look like a "circle"...)
  2. Is the green object placed on the left side of the image? => Yes
  3. Is the orange object placed on the upper side of the image? => No

And relational questions:

  1. What is the shape of the object closest to the red object? => Square
  2. What is the shape of the object furthest from the orange object? => Circle
  3. How many objects have the same shape as the blue object? => 3

Setup

Create a conda environment from the environment.yml file:

$ conda env create -f environment.yml

Activate the environment:

$ conda activate RN3

If you don't use conda, install Python 3 normally and use pip to install the remaining dependencies. The list of dependencies can be found in the environment.yml file.

Usage

$ ./run.sh

or

$ python sort_of_clevr_generator.py

to generate the Sort-of-CLEVR dataset, and

$ python main.py

to train the binary RN model. Alternatively, use

$ python main.py --relation-type=ternary

to train the ternary RN model.

Modifications

In the original paper, the Sort-of-CLEVR task used a different model from the CLEVR task. However, because the model used for CLEVR requires much less time to compute (the network is much smaller), that model is used here for the Sort-of-CLEVR task as well.

Result

                          Relational Networks (20th epoch)   CNN + MLP (without RN, 100th epoch)
Non-relational question   99%                                66%
Relational question       89%                                66%

The CNN + MLP model overfit the training data.

Relational Networks show far better results on both relational and non-relational questions.

Contributions

@gngdb sped the model up by a factor of 10.

relational-networks's People

Contributors

gngdb, justinbuzzni, kimhc6028, mdda, saduras, thiviyant


relational-networks's Issues

Suggestion

Very nice work!
I suggest saving the pickle file for the dataset with protocol=2. This is faster for loading and saving, and it uses 1.4 GB instead of 5 GB.
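
For reference, a minimal sketch of dumping with an explicit pickle protocol (the payload and filename are placeholders, not the generator's actual output):

import pickle

datasets = {'train': [], 'test': []}        # placeholder payload
with open('sort-of-clevr.pickle', 'wb') as f:
    pickle.dump(datasets, f, protocol=2)    # explicit, older protocol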

Question about input to relational network

I do not understand well how you pair the objects. My question is whether each pair of objects is fed as input once or twice. If it is fed once, then how do you choose the order? To be independent of the order, each pair should be fed twice, right? Otherwise there is some breaking of symmetry in the input pair. I could not understand this from the original article, and the code is still too obscure for me to understand this. Thanks!
UPDATE: I think I got it now: you input all pairs twice (in a different order each time), and also each object with itself, thus n^2 instead of n(n-1)/2. Is this correct?
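
The counting in the update can be checked with a quick, purely illustrative sketch:

# All ordered pairs, including each object paired with itself,
# give n**2 combinations rather than n*(n-1)/2 unordered pairs.
n = 5
ordered = [(i, j) for i in range(n) for j in range(n)]
unordered = [(i, j) for i in range(n) for j in range(i + 1, n)]
print(len(ordered), len(unordered))   # 25 10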

AMI to run the code

Hi, is there any public AMI on Amazon so that we can try the code? I could not run the code on a couple of different PyTorch AMIs, since I got a couple of errors inside the main.py file.

'CNN_MLP' object has no attribute 'train'

I am implementing something using the CLEVR dataset and am using your repository for reference purposes.

But every time I execute the script, I get the error 'CNN_MLP' object has no attribute 'train'. In main.py, right after the function 'train' starts, it executes model.train(). Somehow I cannot get beyond this point.

Could somebody help me to understand if I am doing something wrong?
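
For context, a general PyTorch fact (not this repo's actual code): any nn.Module subclass inherits .train(), so this error usually means the class in question does not inherit from nn.Module. A hypothetical minimal version:

import torch.nn as nn

class CNN_MLP(nn.Module):            # hypothetical sketch, not the repo's class
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

model = CNN_MLP()
model.train()                        # works: .train() is inherited from nn.Module
print(model.training)                # True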

bug in question generation

EDIT: Sorry - this is my mistake. I was confused because in numpy.random.randint the "high" bound is exclusive, whereas in Python's random.randint it isn't. You can delete this comment.

Encoding the question.

In the model, there is no LSTM. How is the question then encoded? @gngdb Do you have any idea about this?

Theoretical question: can RN generalize to shape-color combinations not previously seen in training data?

Say, if in the training data the RN saw many green circles and many yellow triangles, but never green triangles or yellow circles, would the RN perform well if asked a question about those shape-color combinations it never saw in training? In other words, can the RN learn the abstract concepts of "color" and "shape" and generalize them to understand new questions involving novel color-shape combinations?

Text relational preprocessing

How would one preprocess the bAbI tasks text into supposed objects and relations for loading into this training model?

The original paper does not seem entirely clear to me:

Dealing with natural language For the bAbI suite of tasks the natural language inputs must be transformed into a set of objects. This is a distinctly different requirement from visual QA, where objects were defined as spatially distinct regions in convolved feature maps. So, we first identified up to 20 sentences in the support set that were immediately prior to the probe question. Then, we tagged these sentences with labels indicating their relative position in the support set, and processed each sentence word-by-word with an LSTM (with the same LSTM acting on each sentence independently).
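
One way this could look in PyTorch (a rough sketch of the paper's description, not this repo's code; all dimensions and the position-tagging scheme are placeholders):

import torch
import torch.nn as nn

vocab_size, embed_dim, hidden_dim, num_sentences, max_len = 100, 32, 64, 20, 12

embed = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

sentences = torch.randint(vocab_size, (num_sentences, max_len))  # token ids
_, (h, _) = lstm(embed(sentences))        # the same LSTM runs on each sentence
objects = h.squeeze(0)                    # (num_sentences, hidden_dim)

# one simple way to tag relative position: append a position index
positions = torch.arange(num_sentences, dtype=torch.float).unsqueeze(1)
objects = torch.cat([objects, positions], dim=1)
print(objects.shape)                      # torch.Size([20, 65])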

Dealing with pixels misunderstanding ?

Hello,
I am reading the paper and your code. However, there is one thing I don't understand from your implementation.
In the paper, in Section 4 (Dealing with pixels), the authors say: "So after convolving the image, each of the d2 k-dimensional cells in the d x d feature maps was tagged with an arbitrary coordinate ...". My question is: which part of your code refers to this? Would you please explain it in more detail, since it is somewhat difficult for me to understand? Thanks!
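
For context, a minimal sketch of the mechanism the paper describes (the layout and normalization here are illustrative assumptions, not the repo's exact code): each cell of a d x d feature map gets an (x, y) tag, and with d = 5 this produces the 25 cells that other issues below ask about.

import numpy as np
import torch

d = 5                                      # spatial size of the feature maps
coords = np.array([[(i // d) / (d - 1), (i % d) / (d - 1)]
                   for i in range(d * d)], dtype=np.float32)
coord_tensor = torch.from_numpy(coords)    # shape (25, 2): one (x, y) tag per cell
print(coord_tensor.shape)                  # torch.Size([25, 2])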

dead code reported by vulture

We used vulture to search for unused code in your project. You can find the report below. It would be great if you could give us feedback about which items are actually used or unused. This would allow us to improve vulture, and ideally it will also help you remove obsolete code or even find typos and bugs.

Command

vulture relational-networks

Raw results

relational-networks/main.py:43: Unused variable 'kwargs'
relational-networks/main.py:138: Unused variable 's_datasets'
relational-networks/model.py:34: Unused function 'forward'
relational-networks/sort_of_clevr_generator.py:11: Unused variable 'question_size'
relational-networks/sort_of_clevr_generator.py:13: Unused variable 'answer_size'
relational-networks/translator.py:2: Unused function 'translate'

There might be false positives, which can be prevented by adding them to a whitelist file. You may find more info here.

Regards,
vulture team

About coord_tensor and np_coord_tensor part in model.py

Hi, the code is almost completely self-explanatory. However, I couldn't understand this part. Could you explain it? Why are you creating coord_tensor and np_coord_tensor, and what is the number 25 there? I would also like to hear about lines 48-49-50.
Edit:
I also have another question about the implementation. Even though it is completely different, I don't want to open one more issue :)
I guess the translator.py file and the function inside it aren't used anywhere in the code. What was the aim of that file?

Train Relational Networks for 10 epochs

I have trained the RN for 10 epochs. The final test set accuracy is 73% for relational questions and 72% for non-relational questions. It seems that there is no significant improvement on relational questions.

Variables and their requires_grad flags

Why are the requires_grad arguments of your variables here and here set to False? You set the requires_grad parameter of the coord_tensor variable to False and then concatenate it with the output of the CNN (defined as x), whose requires_grad argument is True by default, I guess, at line 78. In this case, what is the requires_grad parameter of the concatenated variable (the output of line 78, which is also x_flat)?
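
For reference, a general PyTorch fact (not this repo's code): concatenating a tensor that does not require gradients with one that does produces a tensor that does require gradients, so gradients still flow to the CNN output:

import torch

a = torch.zeros(2, 3, requires_grad=False)   # stands in for coord_tensor
b = torch.ones(2, 3, requires_grad=True)     # stands in for the CNN output
c = torch.cat((a, b), dim=1)
print(c.requires_grad)                       # True: gradients still reach b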

Failure to replicate results

I wanted to report that I didn't manage to replicate the results from the paper or from this repo:
Relation accuracy: 80%
Non-Relation accuracy: 93%

Trained for 20 epochs with all default arguments.

Training details

Regarding the training procedure for the entire CLEVR task: how did you manage to train the pixel and state-description stages? I.e., did you train the whole system (LSTMs, ConvInputModel, and RN) end-to-end, or was it done in stages?

Another question, off topic: what is the purpose of coord_oi and coord_oj?

Thank you! Great implementation by the way.

Using Dataset

Hi, I'm trying to find a source for the Sort-of-CLEVR dataset. The provided code in this repository seems to be what I'm looking for, but I need help understanding how to set it up for training, validation, and testing. Could you provide a brief example of how the included code could be used to generate training, validation, and test sets, and then iterate through these datasets in batches of a chosen size?
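
A hypothetical sketch of one way to do this (the tuple structure, image size, and split fractions below are assumptions, not this repository's actual data format):

import numpy as np

rng = np.random.default_rng(0)
# placeholder dataset: (image, question, answer) tuples
data = [(rng.random((75, 75, 3)), rng.random(11), rng.integers(10))
        for _ in range(100)]

n = len(data)
train = data[:int(0.8 * n)]
val = data[int(0.8 * n):int(0.9 * n)]
test = data[int(0.9 * n):]

def batches(dataset, batch_size):
    for start in range(0, len(dataset), batch_size):
        imgs, qsts, ans = map(np.array, zip(*dataset[start:start + batch_size]))
        yield imgs, qsts, ans

for imgs, qsts, ans in batches(train, batch_size=16):
    pass   # feed each batch to the model here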

Question on sort_of_clevr_generator "count+4"

Hi, thanks for your work and for sharing the code!
I have one question about the data generation part.
I know the questions and answers are represented as one-hot vectors, where

questions = 2 x (6 for a one-hot vector of color), 3 for question type (binary, ternary, norel), 3 for question subtype
answers = yes, no, rectangle, circle, r, g, b, o, k, y

My question is why you use count+4 here in binary question subtype 3, which is as follows:

  elif subtype == 2:
      """count -> 1~6"""
      my_obj = objects[color][2]
      count = -1
      for obj in objects:
          if obj[2] == my_obj:
              count += 1
      answer = count + 4

As I understand it, count is already the number of objects that have the same shape as the certain colored object.
Does the +4 in [yes, no, rectangle, circle, r, g, b, o, k, y] mean the colors?

Any help would be appreciated, and thanks for your time.
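
Reading the snippet above, one plausible explanation (an inference from the code, not an authoritative answer): count starts at -1 and is incremented once for every object with the matching shape, including the queried object itself, so it ends up in the range 0..5. count + 4 then lands in indices 4..9 of the 10-way answer vector, so the last six slots double as count answers 1..6 rather than literally meaning colors:

answers = ['yes', 'no', 'rectangle', 'circle', 'r', 'g', 'b', 'o', 'k', 'y']
for count in range(6):                        # 0..5 after the loop above
    print(count + 1, 'objects ->', count + 4) # answer indices 4..9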

what is output in BasicModel?

Hello, this might be a stupid question, but I have not seen a use of self like
output = self(input_img, input_qst)
in the definition of BasicModel before. From later lines I know output must be a tensor, but I cannot understand how it processes input_img and input_qst. If this is simply a Python or PyTorch question, not specific to your code, could you please let me know where I can find relevant answers? Thanks!
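
This is general PyTorch behavior rather than anything specific to this repo: calling a module instance, e.g. self(x) or model(x), invokes nn.Module.__call__, which runs any registered hooks and then dispatches to the module's forward() method. A minimal illustration:

import torch
import torch.nn as nn

class Tiny(nn.Module):                 # illustrative module, not BasicModel
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

m = Tiny()
out = m(torch.randn(1, 4))   # same dispatch as self(input_img, input_qst)
print(out.shape)             # torch.Size([1, 2])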

Magic Number Question

Hi, in line 21 I had a hard time understanding the calculation
(24+2)*2+11
The input to the linear layer should be two objects, each encoded as a vector of size 24.
Where did the +2 and the +11 come from?

Thanks!
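
One plausible reading, inferred from this README's encoding description and the paper excerpt quoted in the "Object coordinates missing" issue below (an inference, not an authoritative answer):

#   24 -> CNN feature channels per object
#   +2 -> the (x, y) coordinate tag appended to each object
#   *2 -> the two objects in each pair
#  +11 -> the question vector (6 color + 2 type + 3 subtype, as above)
g_input_size = (24 + 2) * 2 + 11
print(g_input_size)   # 63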

Thanks. I have replicated your result, but I wonder about the result in the paper

Train Epoch: 18 [193280/196000 (99%)] Relations accuracy: 95% | Non-relations accuracy: 100%
Train Epoch: 18 [194560/196000 (99%)] Relations accuracy: 86% | Non-relations accuracy: 100%
Train Epoch: 18 [195840/196000 (100%)] Relations accuracy: 89% | Non-relations accuracy: 100%

Test set: Relation accuracy: 89% | Non-relation accuracy: 100%
Train Epoch: 19 [192000/196000 (98%)] Relations accuracy: 94% | Non-relations accuracy: 100%
Train Epoch: 19 [193280/196000 (99%)] Relations accuracy: 80% | Non-relations accuracy: 100%
Train Epoch: 19 [194560/196000 (99%)] Relations accuracy: 89% | Non-relations accuracy: 100%
Train Epoch: 19 [195840/196000 (100%)] Relations accuracy: 91% | Non-relations accuracy: 100%

Test set: Relation accuracy: 90% | Non-relation accuracy: 99%
Train Epoch: 20 [193280/196000 (99%)] Relations accuracy: 91% | Non-relations accuracy: 100%
Train Epoch: 20 [194560/196000 (99%)] Relations accuracy: 97% | Non-relations accuracy: 100%
Train Epoch: 20 [195840/196000 (100%)] Relations accuracy: 95% | Non-relations accuracy: 100%

Test set: Relation accuracy: 89% | Non-relation accuracy: 99%

Object coordinates missing

From the article, in the "Dealing with pixels" section:

So, after convolving the image, each of the d^2 k-dimensional cells in the d × d feature maps was tagged with an arbitrary coordinate indicating its relative spatial position, and was treated as an object for the RN.

Also, the author (/u/asantoro) confirmed on reddit that objects were of the form:
[x, y, v_1, v_2, ..., v_k] where k is the number of kernels and the range of the coordinates x,y doesn't matter.
(Reddit link)

So I think in the model, object coordinates should be added to oi and oj.
https://github.com/kimhc6028/relational-networks/blob/master/model.py#L53

for i in range(25):
    oi = x[:, :, i // 5, i % 5]       # integer division keeps the indices ints under Python 3
    for j in range(25):
        oj = x[:, :, j // 5, j % 5]
        x_ = torch.cat((oi, oj, qst), 1)
        x_ = self.g_fc1(x_)

I believe this should improve performance on questions where the spatial relationship between objects is important (closest, furthest, ...).
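
A sketch of the fix this issue proposes (an assumed implementation, not the repo's merged code): build a (25, 2) coordinate table and append each cell's tag to oi and oj before pairing. This also lines up with the (24+2)*2+11 input size discussed in the "Magic Number Question" issue above.

import torch

coords = torch.tensor([[i // 5, i % 5] for i in range(25)], dtype=torch.float)

batch = 4
x = torch.randn(batch, 24, 5, 5)           # stand-in CNN feature maps
qst = torch.randn(batch, 11)               # stand-in question vector

i, j = 7, 13                               # one example pair of cells
oi = torch.cat((x[:, :, i // 5, i % 5], coords[i].expand(batch, 2)), 1)
oj = torch.cat((x[:, :, j // 5, j % 5], coords[j].expand(batch, 2)), 1)
pair = torch.cat((oi, oj, qst), 1)         # size (24+2)*2+11 = 63
print(pair.shape)                          # torch.Size([4, 63])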
