mila-iqia / atari-representation-learning
Code for "Unsupervised State Representation Learning in Atari"
Home Page: https://arxiv.org/abs/1906.08226
License: MIT License
We should have a design document for the benchmark itself.
Did you find out how the positions of the sprites are encoded? For example, in Asteroids I want to find the asteroid or the player given its x/y coordinates, but I can't find a simple mapping between the RAM values and the pixel coordinates.
Hi,
I've tried using the provided AtariARIWrapper to get the 'labels' for MsPacmanNoFrameskip-v4.
However, when I draw the locations of player, enemy_sue, enemy_inky, etc. using ('player_x', 'player_y'), ('enemy_sue_x', 'enemy_sue_y'), there always seems to be a certain offset from the real location.
How can I get the aligned locations?
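For reference, here is a minimal sketch of how I draw the labels (the OFFSET_X/OFFSET_Y knobs are mine to tune, not values from the repo):

# Sketch: overlay label positions on the rendered frame to inspect the offset.
import gym
import matplotlib.pyplot as plt
from atariari.benchmark.wrapper import AtariARIWrapper

OFFSET_X, OFFSET_Y = 0, 0  # adjust until the marker lands on the sprite

env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
labels = info['labels']

plt.imshow(env.render(mode='rgb_array'))
plt.scatter([labels['player_x'] + OFFSET_X],
            [labels['player_y'] + OFFSET_Y], c='r', label='player')
plt.legend()
plt.show()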
Thanks
I've been trying to generate the same number of frames described in the paper for probing evaluations (35,000 train; 5,000 validation; 10,000 test), but for Tennis I am unable to do so because of the large number of duplicates in the collected episodes. May I ask how you were able to do this for the paper?
I've tried collecting episodes with different random seeds and up to 400,000 steps (to try and account for duplicates), but so far to no avail. I've succeeded in generating episodes for other games, just not Tennis. I believe this might be because the agent spends the majority of episodes refusing to serve the ball (an issue I've come across before). The call I'm using:
tr_episodes, val_episodes, \
tr_labels, val_labels, \
test_episodes, test_labels = get_episodes(env_name='TennisNoFrameskip-v4',
                                          steps=50000,
                                          collect_mode="pretrained_ppo",
                                          seed=seed)
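For what it's worth, this is roughly how I count the duplicates (a sketch; I'm assuming get_episodes returns each episode as a list of frame tensors):

# Sketch: count exact-duplicate frames across the collected episodes.
seen = set()
dupes = 0
for episode in tr_episodes:
    for frame in episode:
        key = frame.numpy().tobytes()  # hashable fingerprint of the frame
        if key in seen:
            dupes += 1
        else:
            seen.add(key)
print(f"{dupes} duplicate frames out of {dupes + len(seen)}")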
I feel the features extracted for Seaquest are incorrect/ambiguous. For the frame in the attached image, the returned labels are:
{'labels': {'player_y': 13, 'oxygen_meter_value': 64, 'num_lives': 3, 'missile_direction': 0, 'diver_x_1': 0, 'player_direction': 0, 'diver_x_0': 45, 'player_x': 76, 'enemy_obstacle_x_3': 96, 'diver_x_3': 0, 'missile_x': 0, 'score_0': 0, 'diver_x_2': 0, 'enemy_obstacle_x_1': 96, 'enemy_obstacle_x_0': 96, 'score_1': 0, 'enemy_obstacle_x_2': 96, 'divers_collected_count': 0}, 'ale.lives': 4}
Could you please look into this, or let us know if we are misinterpreting anything here?
Thanks,
Vaibhav
python -m scripts.run_probe --method infonce-stdim --env-name Pong-v0
Traceback (most recent call last):
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/runpy.py", line 183, in _run_module_as_main
    mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/runpy.py", line 109, in _get_module_details
    __import__(pkg_name)
  File "/home/duane/PycharmProjects/atari-representation-learning/scripts/__init__.py", line 1, in <module>
    from .run_contrastive import train_encoder
  File "/home/duane/PycharmProjects/atari-representation-learning/scripts/run_contrastive.py", line 8, in <module>
    from atariari.methods.dim_baseline import DIMTrainer
  File "/home/duane/PycharmProjects/atari-representation-learning/atariari/methods/dim_baseline.py", line 13, in <module>
    from torchvision import transforms
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/__init__.py", line 4, in <module>
    from torchvision import datasets
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/datasets/__init__.py", line 9, in <module>
    from .fakedata import FakeData
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/datasets/fakedata.py", line 3, in <module>
    from .. import transforms
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/transforms/__init__.py", line 1, in <module>
    from .transforms import *
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 17, in <module>
    from . import functional as F
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 5, in <module>
    from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
ImportError: cannot import name 'PILLOW_VERSION' from 'PIL' (/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/PIL/__init__.py)
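For what it's worth, PILLOW_VERSION was removed in Pillow 7.0, and older torchvision releases (up to 0.4.x) still import it. Assuming you don't want to upgrade torchvision, pinning Pillow below 7 should resolve the import:
pip install "pillow<7"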
python -m scripts.run_probe --method infonce-stdim --env-name PongNoFrameskip-v4
Requires a Weights & Biases account. OK, account created.
After setup, it seems that the account name (entity) is hardcoded to "curl-atari":
wandb.init(project=args.wandb_proj, entity="curl-atari", tags=tags)
perhaps it should be something like
wandb.init(project=args.wandb_proj, entity=args.wandb_entity, tags=tags)
Or am I mistaken?
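A minimal sketch of the change I have in mind (the --wandb-entity flag is hypothetical, not an existing option in the repo):

import argparse
import wandb

parser = argparse.ArgumentParser()
parser.add_argument('--wandb-proj', type=str, default='atari-reps')  # placeholder project name
# Hypothetical flag; defaults to the current hardcoded entity for compatibility.
parser.add_argument('--wandb-entity', type=str, default='curl-atari',
                    help='W&B account (user or team) to log runs under')
args = parser.parse_args()

wandb.init(project=args.wandb_proj, entity=args.wandb_entity, tags=[])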
I understand that global-global infomax can lead to some salient features being neglected.
I can see global-local infomax as a way to get around this, but additionally having local-local infomax feels redundant. Is the LL loss necessary to get the method working, or does it simply improve results?
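For context, here is a minimal sketch of the two ST-DIM terms as the paper describes them; the shapes and the bilinear scoring are assumptions matching common implementations, not the repo's exact code:

import torch
import torch.nn.functional as F

def infonce(scores):
    # positives sit on the diagonal: the target for row i is class i
    return F.cross_entropy(scores, torch.arange(scores.size(0)))

N, d_g, d_l, H, W = 8, 256, 128, 6, 6
g_t = torch.randn(N, d_g)          # global features at time t
l_t = torch.randn(N, d_l, H, W)    # local feature maps at time t
l_tp1 = torch.randn(N, d_l, H, W)  # local feature maps at time t+1
W_gl = torch.randn(d_g, d_l)       # bilinear weights, global-local term
W_ll = torch.randn(d_l, d_l)       # bilinear weights, local-local term

# global-local: global vector at t scored against each location at t+1
loss_gl = sum(infonce((g_t @ W_gl) @ l_tp1[:, :, i, j].t())
              for i in range(H) for j in range(W)) / (H * W)
# local-local: local features at t scored against the same location at t+1
loss_ll = sum(infonce((l_t[:, :, i, j] @ W_ll) @ l_tp1[:, :, i, j].t())
              for i in range(H) for j in range(W)) / (H * W)
loss = loss_gl + loss_ll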
When gathering states using the pretrained_ppo option, the following error appears:
Traceback (most recent call last):
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/retry.py", line 95, in __call__
result = self._call_fn(*args, **kwargs)
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/apis/public.py", line 79, in execute
return self._client.execute(*args, **kwargs)
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/gql/client.py", line 50, in execute
result = self._get_result(document, *args, **kwargs)
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/gql/client.py", line 58, in _get_result
return self.transport.execute(document, *args, **kwargs)
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/gql/transport/requests.py", line 38, in execute
request.raise_for_status()
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/requests/models.py", line 940, in raise_for_status
raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://api.wandb.ai/graphql
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/apis/public.py", line 242, in __len__
self._load_page()
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/apis/public.py", line 270, in _load_page
self.QUERY, variable_values=self.variables)
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/retry.py", line 130, in wrapped_fn
return retrier(*args, **kargs)
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/retry.py", line 102, in __call__
if not check_retry_fn(e):
File "/data/anaconda/envs/rllib/lib/python3.6/site-packages/wandb/util.py", line 485, in no_retry_auth
raise CommError("Permission denied, ask the project owner to grant you access")
wandb.apis.CommError: Permission denied, ask the project owner to grant you access
This is for the SeaquestNoFrameskip-v4 environment. Are the PPO weights openly accessible?
We need a better method than a random policy to collect samples for training the encoder, since a random policy might not give enough diversity of frames.
Using an expert policy isn't ideal either; it makes our agent seem dependent on an expert.
Initial ideas:
Right now, different labels can take on different ranges of values. These values (if the labels are positions) may not make sense after downsampling of frames, so we need to fix this.
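A minimal sketch of one fix (hypothetical helpers, not repo code): since labels come from RAM bytes, dividing by 255 maps every label into [0, 1], and position labels can additionally be rescaled to the downsampled frame size:

# Hypothetical helpers, assuming labels are raw RAM bytes (0-255) and that
# position labels live in the original 160x210 Atari pixel space.
def normalize_label(value, max_value=255):
    return value / max_value  # map any RAM byte into [0, 1]

def rescale_position(value, raw_size=210, target_size=84):
    return value * target_size / raw_size  # align with downsampled frames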
I ran into an issue where I did pip install -e . for this repo and then worked on another repo that had a module called src. pip install -e . seemed to put the src from this repo on the path, so any time I referred to src in my other project it would use this one.
In commit 23c1ede, I changed the instructions to python setup.py install, which will only put the aari module on the path.
Not sure how other repos, like gym, that recommend pip install -e . seem to avoid this issue.
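One way repos scope this (a sketch of a setuptools-based setup.py, not necessarily what gym does) is to whitelist their own packages so stray top-level directories like src never get installed:

# setup.py sketch: only ship the project's own packages.
from setuptools import setup, find_packages

setup(
    name='atari-representation-learning',
    packages=find_packages(include=['atariari', 'atariari.*']),
)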
Hello, I was checking the RAM annotations for Pong, and it seems like only the score and the positions of the players and the ball are known from the RAM; my thinking was that the ball's direction must also be stored in order to move it frame by frame.
Are those annotations just missing from that page, or is it done in some other way that does not use the RAM?
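If the direction really isn't stored as a separate RAM byte, one workaround (a sketch under that assumption, using the wrapper from the README) is to recover the ball's velocity by differencing its position labels across steps:

# Sketch: estimate the ball's direction by differencing position labels.
import gym
from atariari.benchmark.wrapper import AtariARIWrapper

env = AtariARIWrapper(gym.make('PongNoFrameskip-v4'))
env.reset()
_, _, _, info = env.step(0)   # NOOP
prev = info['labels']
_, _, _, info = env.step(0)
cur = info['labels']
dx = cur['ball_x'] - prev['ball_x']  # sign gives horizontal direction
dy = cur['ball_y'] - prev['ball_y']  # sign gives vertical direction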
This line is a little confusing to me. How does the range over N samples correspond to our cross-entropy targets? All help is highly appreciated!
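If I follow the usual InfoNCE setup (an assumption about what that line does, not a quote of the repo's code): the score matrix is built so that the matching positive for anchor i sits at column i, i.e. on the diagonal, so the cross-entropy target for row i is simply i, which is exactly the range over N samples:

import torch
import torch.nn.functional as F

N, d = 8, 64
anchors, positives = torch.randn(N, d), torch.randn(N, d)
scores = anchors @ positives.t()  # scores[i, j]: anchor i vs. candidate j
targets = torch.arange(N)         # the true match for anchor i is column i
loss = F.cross_entropy(scores, targets)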
In https://github.com/mila-iqia/atari-representation-learning/blob/master/atariari/benchmark/envs.py, line 85, the docstring of the GrayscaleWrapper class says """Warp frames to 84x84 as done in the Nature paper and later work.""", although the paper says you do not scale to 84x84. Does this function scale the observations?
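A quick way to check is to inspect the shape of the observations the wrapper emits (a sketch; I'm assuming GrayscaleWrapper follows the standard gym wrapper constructor):

import gym
from atariari.benchmark.envs import GrayscaleWrapper

env = GrayscaleWrapper(gym.make('PongNoFrameskip-v4'))
print(env.reset().shape)  # 84x84 would match the docstring; 210x160 the paper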
Thank you
Thank you very much for sharing this implementation!
There is a minor issue: as far as I understand, it should be steps=args.pretraining_steps here.
I tried running python -m scripts.run_probe --method infonce-stdim --env-name MsPacmanNoFrameskip-v4 --pretraining-steps 100000 a couple of times; the mean F1 is ≈0.65, while the paper reports 0.7. Other methods also show slightly lower scores, although the "supervised" one matches exactly. Am I missing some hyperparameter, or is it just typical score fluctuation?
I tried to run probing tasks for different Atari environments, using the following command:
python -m scripts.run_probe --method infonce-stdim --env-name {env_name}
I did not change any code, just tried different games, including PongNoFrameskip-v4, BowlingNoFrameskip-v4, BreakoutNoFrameskip-v4, and HeroNoFrameskip-v4.
However, only the F1 score for Pong matches the score reported in the paper. The F1 scores for the other three games are far worse than those shown in the paper (for Bowling, I got 0.22).
I checked the training loss logged in wandb, and it seems that training has not converged at all. See the figure below.
How do I get the F1 scores reported in the paper? Am I missing something?
I'd like to see another game or two in the list. Do you have any pointers to source code for the games, or tips based on previous games about how to find the RAM locations for useful information? (I was initially worried, given the limited memory of the Atari 2600, that multiple pieces of information might be stored in separate bits of the same byte, but from your list that doesn't seem to be the case, so perhaps it isn't too complicated?)
I tried to reproduce the result in Table 10, with
python -m scripts.run_probe --method pretrained-rl-agent
but got
-------Collecting samples----------
Deleting room_number for being too low in entropy! Sorry, dood!
Deleting enemy_skull_y for being too low in entropy! Sorry, dood!
Deleting key_monster_x for being too low in entropy! Sorry, dood!
Deleting key_monster_y for being too low in entropy! Sorry, dood!
Deleting level for being too low in entropy! Sorry, dood!
Deleting score_0 for being too low in entropy! Sorry, dood!
Deleting score_1 for being too low in entropy! Sorry, dood!
Deleting score_2 for being too low in entropy! Sorry, dood!
Duplicates: 98, Test Len: 8011
got episodes!
Total Steps: 27411
Traceback (most recent call last):
  File "/home/liuyuezhangadam/anaconda3/envs/pytorch/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/liuyuezhangadam/anaconda3/envs/pytorch/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/liuyuezhangadam/Git/atari-representation-learning/scripts/run_probe.py", line 87, in <module>
    run_probe(args)
  File "/home/liuyuezhangadam/Git/atari-representation-learning/scripts/run_probe.py", line 71, in run_probe
    trainer.train(tr_eps, val_eps, tr_labels, val_labels)
  File "/home/liuyuezhangadam/Git/atari-representation-learning/atariari/benchmark/probe.py", line 197, in train
    epoch_loss, accuracy = self.do_one_epoch(tr_eps, tr_labels)
  File "/home/liuyuezhangadam/Git/atari-representation-learning/atariari/benchmark/probe.py", line 143, in do_one_epoch
    preds = self.probe(x, k)
  File "/home/liuyuezhangadam/Git/atari-representation-learning/atariari/benchmark/probe.py", line 117, in probe
    assert len(f.squeeze().shape) == 2, "if input is a batch of vectors you must specify an encoder!"
AssertionError: if input is a batch of vectors you must specify an encoder!
wandb: Waiting for W&B process to finish, PID 30113
It seems the encoder is simply defined as None in the code (atari-representation-learning/scripts/run_probe.py, lines 36 to 37 at 017f926).
Any help with fixing this to reproduce the result in Table 10? Or did I make a mistake somewhere?
Thanks, @ankeshanand
We should re-evaluate the downstream RL performance of representations on some control tasks, now that the bilinear classifier is working.
In the Berzerk game dictionary there are 8 indices corresponding to robot_x coordinates and 9 indices corresponding to robot_y coordinates.
enemy_robots_x=range(65, 73),
enemy_robots_y=range(56, 65),
Is this a bug/typo?
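A quick check of the two ranges confirms the mismatch:

print(len(range(65, 73)))  # 8 x-coordinates
print(len(range(56, 65)))  # 9 y-coordinates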
Here I have found several conflicts in the requirements:
Awesome work guys. Can't wait to take this for a spin.
Recently, Mnih's group released a paper on representation learning using Atari:
Unsupervised Learning of Object Keypoints for Perception and Control
https://arxiv.org/abs/1906.11883
Are you guys intending to benchmark their transporter network?
I'd be interested to know your thoughts. (I also have a working PyTorch implementation of the transporter network.)
I'm trying to test this implementation, and I installed the packages according to the README. I have also installed gym[accept-rom-license] correctly.
I get this error:
  File "path-to-package/atariari/benchmark/wrapper.py", line 7, in step
    observation, reward, done, info = self.env.step(action)
This is probably because gym was updated and step now returns 5 values instead of 4.
So which version of gym was used for this project? I need to know it to install the correct version.
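For reference, the 5-tuple step API arrived in gym 0.26 (obs, reward, terminated, truncated, info), so any gym release before 0.26 should match the wrapper's 4-tuple unpacking. Alternatively, a small compatibility shim (a sketch, not repo code) works with both APIs:

# Sketch: normalize env.step() to the old 4-tuple the wrapper expects.
def step_compat(env, action):
    result = env.step(action)
    if len(result) == 5:  # gym >= 0.26: obs, reward, terminated, truncated, info
        obs, reward, terminated, truncated, info = result
        return obs, reward, terminated or truncated, info
    return result         # older gym: obs, reward, done, info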