Messenger-EMMA

Implementation of the Messenger environment and EMMA model from the ICML 2021 paper: Grounding Language to Entities and Dynamics for Generalization in Reinforcement Learning.

Installation

Currently, only local installations are supported. Clone the repository and run:

pip install -e messenger-emma

This will install only the Messenger gym environments and dependencies. If you want to use models, you must install additional dependencies such as torch. Run the following instead:

pip install -e 'messenger-emma[models]'

Usage

To instantiate a gym environment, use:

import gym
import messenger
env = gym.make('msgr-train-v2')
obs, manual = env.reset()
obs, reward, done, info = env.step(<some action>)

Here <some action> should be an integer between 0 and 4, corresponding to the actions up, down, left, right, and stay. Note that, in contrast to standard gym, env.reset() returns a tuple of an observation and the text manual sampled for the current episode.

If you have installed the model dependencies, you can use our model EMMA just as you would any torch model. A full example of using EMMA with our environment can be found in the run.py file. To download the model weights, run the following:

wget -O pretrained.zip https://www.dropbox.com/s/ne8yglb0765f111/pretrained.zip?raw=1
unzip pretrained.zip

This will put pretrained model weights in a folder called pretrained. You can run EMMA using these weights with:

python run.py --model_state pretrained/emma_s2_1.pth --env_id msgr-train-v2

Please make sure that you load the weights that match the environment stage. Environments ending in v1, v2, and v3 should use model states with s1, s2, and s3 in the filename, respectively.
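The stage check can be automated before loading weights. The following is a minimal sketch, not part of the repository: the helper names stage_of_env and weights_match_env are hypothetical, and the code assumes weight filenames are underscore-separated like the emma_s2_1.pth example above.

```python
import re
from pathlib import Path


def stage_of_env(env_id: str) -> int:
    """Return the stage number from an ID such as 'msgr-train-v2'."""
    match = re.fullmatch(r"msgr-[a-z-]+-v([123])", env_id)
    if match is None:
        raise ValueError(f"unrecognized environment id: {env_id!r}")
    return int(match.group(1))


def weights_match_env(model_state: str, env_id: str) -> bool:
    """True if the weights filename carries the s{stage} tag for env_id.

    Assumption: filenames look like 'emma_s2_1.pth', so the stage tag
    is one of the underscore-separated tokens of the file stem.
    """
    tag = f"s{stage_of_env(env_id)}"
    return tag in Path(model_state).stem.split("_")


if not weights_match_env("pretrained/emma_s2_1.pth", "msgr-train-v2"):
    raise SystemExit("weights and environment stage do not match")
```

Calling such a check at the top of an evaluation script fails fast instead of silently evaluating a model on the wrong stage.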

Training

Training scripts and usage information are provided in the training folder.

Environment IDs

Environment IDs follow the format msgr-{split}-v{stage}. There are three stages (1, 2, 3), and the splits include train, val, and test, as well as train-sc and train-mc for the single-combination and multi-combination subsets of the training games. The split test-se is the state estimation version of the test environment and is only available on stage 2. Please ignore any warnings from gym telling you to "upgrade to v3".
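The naming scheme can be checked programmatically. The sketch below is illustrative only: parse_env_id is a hypothetical helper, and it assumes every split other than test-se is registered on all three stages. The IDs actually registered by the installed package are authoritative.

```python
import re

STAGES = (1, 2, 3)
SPLITS = ("train", "val", "test", "train-sc", "train-mc", "test-se")
# Splits that exist only on particular stages (from the description above).
STAGE_RESTRICTED = {"test-se": (2,)}


def parse_env_id(env_id):
    """Split 'msgr-{split}-v{stage}' into (split, stage), or raise ValueError."""
    match = re.fullmatch(r"msgr-(.+)-v(\d+)", env_id)
    if match is None:
        raise ValueError(f"not a Messenger environment id: {env_id!r}")
    split, stage = match.group(1), int(match.group(2))
    if split not in SPLITS or stage not in STAGES:
        raise ValueError(f"unknown split or stage in {env_id!r}")
    if stage not in STAGE_RESTRICTED.get(split, STAGES):
        raise ValueError(f"split {split!r} is not available on stage {stage}")
    return split, stage


print(parse_env_id("msgr-train-mc-v3"))
```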

Human Play

To get a better sense of what Messenger is like, you can play it in the terminal, provided you have installed the environment. Set --env_id to the gym ID you want to play:

python play_msgr.py --env_id msgr-train-v1

Note that in this human-play version, the entity groundings are provided to you upfront by rendering each entity with its first two letters (e.g. airplane as AI). In the actual environment, the agent must learn this grounding from scratch by matching text symbols like "plane" to the symbol 2.

Environment Details

This section documents some of the nuances of the environment and its usage.

Additional Hyperparameters

On stage 1, the agent begins with or without the message and wins the episode if it interacts with the correct entity. By default, the agent begins with the message with probability 0.2 at the start of each episode. You can change this parameter as follows (note that this applies only to stage 1):

env = gym.make("msgr-train-v1", message_prob=0.5)

Training games are divided into single-combination and multi-combination games. Since there are not many single-combination game variants, we sample one of these games with probability 0.25. The prob_env_1 keyword sets the probability of sampling a multi-combination game (0.75 by default). You can change this with:

env = gym.make("msgr-train-v1", prob_env_1=0.6)

Note that the distinction between single-combination and multi-combination games does not exist on test or validation games.

Step Limits and Penalties

The gym environment does not implement any step limit or step penalty. This allows maximum flexibility for various training setups (for example, you might want to start with a higher limit and then anneal it over the course of training). Note that since entities are not always chasing, some episodes may never terminate if no limit is specified, depending on the quality of the agent, so we recommend including one in your training loop. During training, we also penalized the agent with a -1 reward if it did not complete the episode within our step limit.
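A step limit with a timeout penalty can be added in the rollout loop itself. The following is a minimal sketch under stated assumptions: run_episode, random_policy, and the default max_steps=64 are hypothetical and not taken from the training scripts, and adding the penalty to the episode return is one possible interpretation of the -1 penalty. NeverEndingEnv is a stand-in that mimics only the reset/step signatures described in Usage, so the loop can be exercised without installing Messenger.

```python
import random


def run_episode(env, policy, max_steps=64, timeout_penalty=-1.0):
    """Roll out one episode, stopping after max_steps environment steps.

    Returns (episode_return, steps_taken, timed_out). max_steps=64 is an
    arbitrary illustrative default, not a value from the paper.
    """
    obs, manual = env.reset()  # Messenger's reset returns (obs, manual)
    total = 0.0
    for step in range(1, max_steps + 1):
        action = policy(obs, manual)
        obs, reward, done, info = env.step(action)
        total += reward
        if done:
            return total, step, False
    # Limit reached without termination: apply the timeout penalty.
    return total + timeout_penalty, max_steps, True


def random_policy(obs, manual):
    """Pick one of the five actions (integers 0 to 4) uniformly at random."""
    return random.randrange(5)


class NeverEndingEnv:
    """Stand-in environment that never terminates (for illustration only)."""

    def reset(self):
        return None, ["placeholder manual"]

    def step(self, action):
        return None, 0.0, False, {}


episode_return, steps, timed_out = run_episode(
    NeverEndingEnv(), random_policy, max_steps=10
)
print(episode_return, steps, timed_out)
```

With a real environment, replace NeverEndingEnv() with the result of gym.make(...). Annealing the limit then amounts to passing a different max_steps as training progresses.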

Text Manual

Due to the noisy nature of data collected from human writers, sometimes the manual may contain a description that provides no useful information. In most cases, the correct course of action can still be deduced by reading the other descriptions.

Changes

  • July 07 2021: Added a script for playing Messenger in the terminal. Removed the pretrained weights from the repo and moved them to Dropbox.
  • June 15 2021: We have introduced a stage 3, and msgr-test-v2, which includes more movement combinations for a more comprehensive test. Other stages/splits should be identical. If you cloned before commit 8f6bd5c, we recommend getting the latest version.

Miscellaneous

If there are issues with the installation, try using Python 3.7. The model has been tested with transformers version 4.2.2. The license is MIT. If you get an error with gym, try downgrading gym to 0.22.0 or lower.

Please use the following citation from DBLP (note author list and name changes from early arxiv versions).

@inproceedings{hanjie21grounding,
  author    = {Austin W. Hanjie and
               Victor Zhong and
               Karthik Narasimhan},
  editor    = {Marina Meila and
               Tong Zhang},
  title     = {Grounding Language to Entities and Dynamics for Generalization in
               Reinforcement Learning},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning,
               {ICML} 2021, 18-24 July 2021, Virtual Event},
  series    = {Proceedings of Machine Learning Research},
  volume    = {139},
  pages     = {4051--4062},
  publisher = {{PMLR}},
  year      = {2021},
  url       = {http://proceedings.mlr.press/v139/hanjie21a.html},
  timestamp = {Wed, 14 Jul 2021 15:41:58 +0200},
  biburl    = {https://dblp.org/rec/conf/icml/HanjieZN21.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

messenger-emma's Issues

Training script

Cool paper! Does the repository include a script for training the model? (If not, could you please add it?) Thanks!

Ask about test-se

Hi authors,

What is the difference between test-se and test in S2? I can't find this in the paper or in the codebase.

"The split test-se is the state estimation version of the test environment, and is only available on stage 2."

Can you elaborate on what the "state estimation version" of S2 means?
Thank you,

The collision handling is wrong

Hi authors,

Thanks for the great work.

I've been using your framework and found an issue in the physics of Messenger: the collision handling in py-vgdl (https://github.com/ahjwang/py-vgdl) is wrong.

In other words, the agent dies when it is near the enemy, not when it touches the enemy. Because the entities move stochastically, the issue is tricky to reproduce. However, I took a screenshot of one game state where this happened:

[Screenshot of the game state]

Here 3 is the enemy and 15 is the player (agent). The function _event_handling in py_vgdl reports that these two entities collide, although they don't.

The error comes from the check sprite.rect.collidelistall(others), which tests whether the underlying rectangles of two entities overlap. However, overlapping rectangles do not necessarily mean that the entities touch each other in gridworld coordinates. Their underlying rectangle coordinates are (3, 11, 2, 2) and (2, 12, 2, 2), and their real gridworld coordinates are (1, 5) and (1, 6), respectively. They clearly haven't touched each other yet, because their real coordinates are different.

For now, I fixed it by adding the following to the loop that performs this check:

                    p1 = rect_to_pos(sprite.rect, self.block_size)
                    p2 = rect_to_pos(other.rect, self.block_size)
                    if p1 != p2:
                        continue
                    # print(p1, p2)

where

def rect_to_pos(r, block_size):
    return r.left // block_size, r.top // block_size
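The numbers quoted above can be verified without pygame. This is a minimal sketch under stated assumptions: the Rect namedtuple and rects_overlap are stand-ins for pygame's rectangle and its overlap test, rectangles are read as (left, top, width, height), and the block size of 2 is inferred from the reported coordinates rather than read from the code.

```python
from collections import namedtuple

# Stand-in for a pygame rectangle: (left, top, width, height).
Rect = namedtuple("Rect", "left top width height")


def rects_overlap(a, b):
    """Axis-aligned overlap test; touching edges do not count as overlap."""
    return (
        a.left < b.left + b.width
        and b.left < a.left + a.width
        and a.top < b.top + b.height
        and b.top < a.top + a.height
    )


def rect_to_pos(r, block_size):
    """Map a rectangle to its gridworld cell."""
    return r.left // block_size, r.top // block_size


BLOCK_SIZE = 2  # assumption: inferred from the coordinates in this report
rect_a = Rect(3, 11, 2, 2)
rect_b = Rect(2, 12, 2, 2)

# The rectangles overlap, yet they map to different grid cells.
print(rects_overlap(rect_a, rect_b))
print(rect_to_pos(rect_a, BLOCK_SIZE), rect_to_pos(rect_b, BLOCK_SIZE))
```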

Can you double check this?
Thank you very much!

Joe
