
batch_rl's Introduction

Google Research

This repository contains code released by Google Research.

All datasets in this repository are released under the CC BY 4.0 International license, which can be found here: https://creativecommons.org/licenses/by/4.0/legalcode. All source files in this repository are released under the Apache 2.0 license, the text of which can be found in the LICENSE file.


Because the repo is large, we recommend you download only the subdirectory of interest:

SUBDIR=foo
svn export https://github.com/google-research/google-research/trunk/$SUBDIR

If you'd like to submit a pull request, you'll need to clone the repository; we recommend making a shallow clone (without history).

git clone git@github.com:google-research/google-research.git --depth=1

Disclaimer: This is not an official Google product.

Updated in 2023.

batch_rl's People

Contributors

agarwl, tangbotony, thesparta, zhixuan-lin


batch_rl's Issues

configurable() got an unexpected keyword argument 'blacklist'

Hi, I am a newbie in RL, but I want to implement the code for the paper Conservative Q-Learning for Offline Reinforcement Learning, as described in its Atari experiments section.

I ran this command:

python -um batch_rl.fixed_replay.train \
  --base_dir=/tmp/batch_rl \
  --replay_dir=$DATA_DIR/Pong/1 \
  --agent_name=quantile \
  --gin_files='batch_rl/fixed_replay/configs/quantile_pong.gin' \
  --gin_bindings='FixedReplayRunner.num_iterations=1000' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Pong"' \
  --gin_bindings='FixedReplayQuantileAgent.minq_weight=1.0'

but it gives me this error:

File "/home/masooti/Projects/DQN/CQL/atari/batch_rl/fixed_replay/train.py", line 33, in <module>
    from batch_rl.fixed_replay.agents import dqn_agent
  File "/home/masooti/Projects/DQN/CQL/atari/batch_rl/fixed_replay/agents/dqn_agent.py", line 25, in <module>
    from batch_rl.fixed_replay.replay_memory import fixed_replay_buffer
  File "/home/masooti/Projects/DQN/CQL/atari/batch_rl/fixed_replay/replay_memory/fixed_replay_buffer.py", line 25, in <module>
    from batch_rl.baselines.replay_memory import logged_replay_buffer
  File "/home/masooti/Projects/DQN/CQL/atari/batch_rl/baselines/replay_memory/logged_replay_buffer.py", line 92, in <module>
    @gin.configurable(blacklist=['observation_shape', 'stack_size',
TypeError: configurable() got an unexpected keyword argument 'blacklist'

How can I fix this?

Please help.

Thank you in advance.

How to get expert data ?

I have read all the issues; thanks for responding to them. I have a question about extracting expert data from the replay buffer. As you mentioned in one of the issues, because of the size of the data, the 50M data points from each game have been split into 50 files of 1M data points each. Does this mean that the last 10 files in [GAME_NAME]/1/replay_logs/ (buffer files with suffixes 41, 42, ..., 50) represent expert behavior, while the first buffer files (suffixes 1, 2, ..., 10) represent beginner-level performance? While searching for an answer, I found this in the AI blog: "For example, the first k million frames from the DQN replay dataset emulate exploration data with suboptimal returns while the last k million frames are analogous to near-expert data with stochasticity." Going by that, is it safe to assume that by taking the buffer files with suffixes 40-50 I am extracting expert-level behavior?
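If the blog sentence is taken at face value and the checkpoint files are in logging order (an assumption, not something documented here), "the first/last k million frames" map onto the lowest/highest checkpoint suffixes. The hypothetical helpers below list the suffixes actually present for one stored element and split off the earliest and latest k. The filename pattern is taken from the `$store$_{elem}_ckpt.{suffix}.gz` names used in the reading example in this thread, and working from the listing means a missing suffix (such as 50) does not break the split.

```python
import re
from typing import Iterable, List, Tuple

# Names like '$store$_observation_ckpt.42.gz' (pattern from the reading example in this thread).
_CKPT_RE = re.compile(r'^\$store\$_(?P<elem>\w+?)_ckpt\.(?P<suffix>\d+)\.gz$')


def checkpoint_suffixes(filenames: Iterable[str], elem: str = 'observation') -> List[int]:
    """Returns the sorted checkpoint suffixes present for one stored element."""
    suffixes = []
    for name in filenames:
        match = _CKPT_RE.match(name)
        if match and match.group('elem') == elem:
            suffixes.append(int(match.group('suffix')))
    return sorted(suffixes)


def split_by_recency(suffixes: List[int], k: int) -> Tuple[List[int], List[int]]:
    """Splits sorted suffixes into the k earliest and the k latest checkpoints."""
    if k <= 0:
        return [], []
    return suffixes[:k], suffixes[-k:]
```

With a local copy, `checkpoint_suffixes(os.listdir(replay_logs_dir))` would feed it. Whether the late checkpoints count as "expert" is exactly the open question here, so this only selects late-training data; it does not guarantee expert quality.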

Asterix/1 dataset broken?

Hi,

I tried reproducing the offline REM results with Asterix/1 dataset by using the command below:

python -um batch_rl.fixed_replay.train \
  --base_dir=/tmp/batch_rl \
  --replay_dir=/data_large/readonly/atari/Asterix/1 \
  --agent_name=multi_head_dqn \
  --gin_files='batch_rl/fixed_replay/configs/rem.gin' \
  --gin_bindings='FixedReplayRunner.num_iterations=1000' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Asterix"'

However, I could not reproduce the results (I get an average return of about 50 at the 200th iteration).
Meanwhile, I can reproduce the results with the other Asterix datasets (e.g., Asterix/2, ...).
Could you check whether the Asterix/1 dataset has errors?

Thank you!

Can a customized env be added to the current framework?

Hi, I am new to this repo and to the offline RL setting. I guess it should be possible, but I would still like to hear some suggestions from the pros. More importantly, if it is possible to add a new gym env, how should the offline data be prepared?

Thank you very much!

Getting 7 as action for a game with 3 actions

I have been trying to train an online agent on the environment FreewayNoFrameskip-v4. Because this gym environment is not deterministic, I seeded the environment. Specifically, in atari_lib.py, I added

  • env.seed(0) after env = gym.make(full_game_name) in create_atari_environment
  • self.environment.seed(0) at the end of the AtariPreprocessing class's __init__ function
  • self.environment.seed(0) at the start of the reset function in the AtariPreprocessing class

No other changes were made. I then used this repo to train an online agent.

In all of training, there was one instance of a 7 stored as the action (specifically the last action in the very first action checkpoint stored in replay_logs), even though Freeway only has three actions. All other stored actions were {0, 1, 2}. Any ideas what could be the cause of this, or has anything similar been observed? Going in and changing this one 7 to the most common action isn't a problem, but if this problem arises repeatedly, and for other games, it could be difficult to deal with.
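To find out whether the stray 7 recurs in other checkpoints or games, one option is to scan each logged action array for out-of-range values. The helpers below are hypothetical; `actions` would be an array loaded from a `$store$_action_ckpt.*.gz` file, and `num_actions` is 3 for Freeway as stated above. The replacement function implements the "change it to the most common action" workaround from the issue, which hides the symptom rather than explaining its cause.

```python
import numpy as np


def find_invalid_actions(actions, num_actions):
    """Returns indices of logged actions outside [0, num_actions)."""
    actions = np.asarray(actions)
    return np.flatnonzero((actions < 0) | (actions >= num_actions))


def replace_with_most_common(actions, num_actions):
    """Returns a copy where out-of-range actions become the most frequent valid action."""
    actions = np.array(actions, copy=True)
    bad = find_invalid_actions(actions, num_actions)
    valid = np.delete(actions, bad)
    if bad.size and valid.size:
        actions[bad] = np.bincount(valid, minlength=num_actions).argmax()
    return actions
```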

a minor modification for download command

Hi,

Thanks for sharing this interesting project.

In README.md, the download command is gsutil -m cp -R gs://atari-replay-datasets/dqn. Running this command directly fails with CommandException: Wrong number of arguments for "cp" command. It works if you change the command to gsutil -m cp -R gs://atari-replay-datasets/dqn ./. Though it's a minor modification and obvious to users familiar with gsutil, it helps novices 😄.

Why use a small batch size?

Hi,
May I ask why you used such a small batch size?
You mention in the paper that a larger batch size would lead to a significant speedup, so why is it still 32 in the standard implementation?
I am trying to implement your work, but it takes over a week to train one agent with the hyperparameters in the paper.
So I am confused and curious.

Ambiguous selector

Hello Rishabh Agarwal,
When I try to run the step [Test for training an agent with fixed replay buffer], it always fails with this error:

(python27) ➜ batch_rl git:(master) ✗ python -um batch_rl.tests.fixed_replay_runner_test \
--replay_dir=$DATA_DIR/Pong/1
Running tests under Python 3.7.3: /Users/xiayong/anaconda3/envs/python27/bin/python
[ RUN ] FixedReplayRunnerIntegrationTest.testIntegrationFixedReplayREM
INFO:tensorflow:####### Training the REM agent #####
I0329 15:56:52.711899 4739141120 fixed_replay_runner_test.py:81] ####### Training the REM agent #####
INFO:tensorflow:####### REM base_dir: /tmp/batch_rltests/run_2021_03_29_13_56_52
I0329 15:56:52.712044 4739141120 fixed_replay_runner_test.py:82] ####### REM base_dir: /tmp/batch_rltests/run_2021_03_29_13_56_52
INFO:tensorflow:####### replay_dir: /Users/xiayong/Pong/1
I0329 15:56:52.712151 4739141120 fixed_replay_runner_test.py:83] ####### replay_dir: /Users/xiayong/Pong/1
INFO:tensorflow:time(__main__.FixedReplayRunnerIntegrationTest.testIntegrationFixedReplayREM): 0.01s
I0329 15:56:52.725455 4739141120 test_util.py:2076] time(__main__.FixedReplayRunnerIntegrationTest.testIntegrationFixedReplayREM): 0.01s
[ FAILED ] FixedReplayRunnerIntegrationTest.testIntegrationFixedReplayREM
[ RUN ] FixedReplayRunnerIntegrationTest.test_session
[ SKIPPED ] FixedReplayRunnerIntegrationTest.test_session

ERROR: testIntegrationFixedReplayREM (__main__.FixedReplayRunnerIntegrationTest)
FixedReplayRunnerIntegrationTest.testIntegrationFixedReplayREM
Test the FixedReplayMultiHeadDQN agent.

Traceback (most recent call last):
File "/Users/xiayong/batch_rl/batch_rl/tests/fixed_replay_runner_test.py", line 85, in testIntegrationFixedReplayREM
train.main([])
File "/Users/xiayong/batch_rl/batch_rl/fixed_replay/train.py", line 95, in main
base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/dopamine/discrete_domains/run_experiment.py", line 56, in load_gin_configs
skip_unknown=False)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/config.py", line 1810, in parse_config_files_and_bindings
includes_and_imports = parse_config_file(config_file, skip_unknown)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/config.py", line 1764, in parse_config_file
includes, imports = parse_config(f, skip_unknown=skip_unknown)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/config.py", line 1657, in parse_config
bind_parameter((scope, selector, arg_name), value)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/contextlib.py", line 130, in __exit__
self.gen.throw(type, value, traceback)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/utils.py", line 58, in try_with_location
augment_exception_message_and_reraise(exception, format_location(location))
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
raise proxy.with_traceback(exception.__traceback__) from None
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/utils.py", line 56, in try_with_location
yield
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/config.py", line 1657, in parse_config
bind_parameter((scope, selector, arg_name), value)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/config.py", line 645, in bind_parameter
pbk = ParsedBindingKey(binding_key)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/config.py", line 508, in __new__
configurable_ = _REGISTRY.get_match(selector)
File "/Users/xiayong/anaconda3/envs/python27/lib/python3.7/site-packages/gin/selector_map.py", line 169, in get_match
raise KeyError(err_str.format(partial_selector, matching_selectors))
KeyError: "Ambiguous selector 'FixedReplayMultiHeadDQNAgent', matches ['fixed_replay.agents.multi_head_dqn_agent.FixedReplayMultiHeadDQNAgent', 'batch_rl.fixed_replay.agents.multi_head_dqn_agent.FixedReplayMultiHeadDQNAgent']."
In file "batch_rl/fixed_replay/configs/rem.gin", line 10
FixedReplayMultiHeadDQNAgent.gamma = 0.99


Ran 2 tests in 0.021s

FAILED (errors=1, skipped=1)

Can I train with my own dataset?

Thanks for the great work first!
I have a bunch of data in (state, action, reward, next state) format. I tried to understand how you parse the $store$_action_ckpt file in the code, but I failed. I would be grateful if you could provide a way to train this model with my own dataset~

Thanks again
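Judging from the "Reading atari files directly" example in this thread, each checkpoint stores observation, action, reward, and terminal arrays as gzip-compressed .npy files, with the next state implied by the following row. Below is a minimal sketch of writing your own time-ordered data in that layout and reading it back the same way. The helper names are hypothetical and the dtypes are our assumption about what Dopamine's Atari buffer uses. If your data is a set of independent (s, a, r, s') tuples rather than trajectories, this layout does not fit directly, and the fixed-replay loader may also expect bookkeeping files that this sketch does not write.

```python
import gzip
import os

import numpy as np

STORE_FILENAME_PREFIX = '$store$_'


def _path(data_dir, elem, suffix):
    return os.path.join(data_dir, f'{STORE_FILENAME_PREFIX}{elem}_ckpt.{suffix}.gz')


def transitions_to_store(states, actions, rewards, terminals):
    """Packs a time-ordered (s, a, r, done) sequence; s' is implied by the next row."""
    return {
        'observation': np.asarray(states, dtype=np.uint8),
        'action': np.asarray(actions, dtype=np.int32),
        'reward': np.asarray(rewards, dtype=np.float32),
        'terminal': np.asarray(terminals, dtype=np.uint8),
    }


def save_store_arrays(data_dir, suffix, arrays):
    """Writes each array as a gzip-compressed .npy file, e.g. '$store$_action_ckpt.0.gz'."""
    os.makedirs(data_dir, exist_ok=True)
    for elem, array in arrays.items():
        with open(_path(data_dir, elem, suffix), 'wb') as f:
            with gzip.GzipFile(fileobj=f, mode='wb') as outfile:
                np.save(outfile, np.asarray(array), allow_pickle=False)


def load_store_arrays(data_dir, suffix, elems):
    """Reads the arrays back the same way as the reading example in this thread."""
    data = {}
    for elem in elems:
        with open(_path(data_dir, elem, suffix), 'rb') as f:
            with gzip.GzipFile(fileobj=f) as infile:
                data[elem] = np.load(infile, allow_pickle=False)
    return data
```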

TF Version must < 2.0?

The code doesn't work in my environment, where the TF version is the newest (2.2.0).
Is it because the tf.contrib module has been removed?

Windows: basic test failed

I am on Windows, and run the basic test python -um batch_rl.tests.atari_init_test. But it failed.

Here is the traceback:

Running tests under Python 3.9.12: C:\Users\cenyyang\Anaconda3\python.exe
[ RUN ] AtariInitTest.test_atari_init
INFO:tensorflow:Saving replay buffer data to C:\Users\cenyyang\OneDrive\ -\ City\ University\ of\ Hong Kong\batch_rl\replay_logs
I0131 22:32:19.065589 27504 train.py:75] Saving replay buffer data to C:\Users\cenyyang\OneDrive\ -\ City\ University\ of\ Hong Kong\batch_rl\replay_logs
W0131 22:32:19.066590 27504 run_experiment.py:267] DEPRECATION WARNING: Logger is being deprecated. Please switch to CollectorDispatcher!
INFO:tensorflow:time(__main__.AtariInitTest.test_atari_init): 0.02s
I0131 22:32:19.066590 27504 test_util.py:2462] time(__main__.AtariInitTest.test_atari_init): 0.02s
[ FAILED ] AtariInitTest.test_atari_init
[ RUN ] AtariInitTest.test_session
[ SKIPPED ] AtariInitTest.test_session

ERROR: test_atari_init (__main__.AtariInitTest)
AtariInitTest.test_atari_init
Tests that a DQN agent is initialized.

Traceback (most recent call last):
File "C:\Users\cenyyang\OneDrive - City University of Hong Kong\batch_rl\batch_rl\tests\atari_init_test.py", line 49, in test_atari_init
train.main([])
File "C:\Users\cenyyang\OneDrive - City University of Hong Kong\batch_rl\batch_rl\baselines\train.py", line 78, in main
runner = LoggedRunner(FLAGS.base_dir, create_agent_fn)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\config.py", line 1605, in gin_wrapper
utils.augment_exception_message_and_reraise(e, err_str)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\utils.py", line 41, in augment_exception_message_and_reraise
raise proxy.with_traceback(exception.__traceback__) from None
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\config.py", line 1582, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\config.py", line 1605, in gin_wrapper
utils.augment_exception_message_and_reraise(e, err_str)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\utils.py", line 41, in augment_exception_message_and_reraise
raise proxy.with_traceback(exception.__traceback__) from None
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\config.py", line 1582, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\dopamine\discrete_domains\run_experiment.py", line 222, in __init__
self._environment = create_environment_fn()
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\config.py", line 1605, in gin_wrapper
utils.augment_exception_message_and_reraise(e, err_str)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\utils.py", line 41, in augment_exception_message_and_reraise
raise proxy.with_traceback(exception.__traceback__) from None
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gin\config.py", line 1582, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\dopamine\discrete_domains\atari_lib.py", line 96, in create_atari_environment
env = gym.make(full_game_name)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gym\envs\registration.py", line 607, in make
_check_version_exists(ns, name, version)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gym\envs\registration.py", line 234, in _check_version_exists
_check_name_exists(ns, name)
File "C:\Users\cenyyang\Anaconda3\lib\site-packages\gym\envs\registration.py", line 212, in _check_name_exists
raise error.NameNotFound(
gym.error.NameNotFound: Environment PongNoFrameskip doesn't exist.
In call to configurable 'create_atari_environment' (<function create_atari_environment at 0x000001E26C2BB280>)
In call to configurable 'Runner' (<class 'dopamine.discrete_domains.run_experiment.Runner'>)
In call to configurable 'LoggedRunner' (<class 'batch_rl.baselines.run_experiment.LoggedRunner'>)

Ran 2 tests in 0.027s

FAILED (errors=1, skipped=1)

Is the error related to base_dir?

Python version?

Could you document the Python version you used? I'm trying to run this on Ubuntu 20, so I need to manually install an older version of Python.

It would also help to know the versions used for absl-py, atari-py, gin-config, opencv-python, gym, and numpy, as well as which virtual environment manager was used (e.g., conda, venv).

Save the trained model to hdf5 file format

Hi,

I'm trying to make the trained model made by the offline agents to work with my online environment, which is written in Golang and loads models from hdf5 files. But when I'm looking at the source code from this repo, I can't seem to find a way to do this easily.

Is there an "easy" way to save the trained model as a hdf5 file, not just checkpoints?

No module named 'dopamine.google'

The train.py file located in batch_rl/fixed_replay has on line 41 from dopamine.google import xm_utils. When I run the train.py file (e.g. by running the fixed_replay_runner_test test as shown on the README), I get the following error: ModuleNotFoundError: No module named 'dopamine.google'. I downloaded the dopamine library as instructed in the README file. I found that replacing this line with what was there from a previous version, from dopamine.discrete_domains import train as base_train allowed the fixed_replay_runner_test test to run successfully. Training offline agents also appears to work correctly, in that no errors are thrown and performance appears to improve over iterations according to the saved logs.

Is the dopamine.google module from a previous or experimental version of dopamine-rl, or have I done something incorrectly? Would changing line 41 to what was there before introduce any issues, such as computational slowdowns or inaccuracies?

The flag 'base_dir' is defined twice

MWE:
pip install dopamine-rl
git clone https://github.com/google-research/batch_rl.git
cd batch_rl
python -um batch_rl.tests.atari_init_test

This test raises:
absl.flags._exceptions.DuplicateFlagError: The flag 'base_dir' is defined twice. First from batch_rl.baselines.train, Second from dopamine.discrete_domains.train. Description from first occurrence: Base directory to host all required sub-directories.

And I can indeed find both definitions.

How to evaluate agents?

I was able to train DQN agents using offline data. How do I evaluate the agents, e.g., to reproduce the Pong plot in Figure 3?

Reading atari files directly.

Hi, contributing this example of how to read the atari files directly, in case anyone wants to do that.

Note that the data is stored in the same temporal sequence it was logged, as you can see by watching the replay.

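import gzip
import os

import cv2
import numpy as np

# Dopamine prefixes the stored replay arrays with '$store$_'.
STORE_FILENAME_PREFIX = '$store$_'

# Per-transition arrays saved in each checkpoint.
ELEMS = ['observation', 'action', 'reward', 'terminal']

if __name__ == '__main__':

    data = {}

    data_dir = '/home/duane/data/dqn/Breakout/1/replay_logs/'
    suffix = 0  # which checkpoint file to read
    for elem in ELEMS:
        filename = os.path.join(
            data_dir, f'{STORE_FILENAME_PREFIX}{elem}_ckpt.{suffix}.gz')
        # Each file is a gzip-compressed .npy array.
        with open(filename, 'rb') as f:
            with gzip.GzipFile(fileobj=f) as infile:
                data[elem] = np.load(infile)

    # Play back the observations in stored order; press 'q' to stop early.
    for obs in data['observation']:
        cv2.imshow('obs', obs)
        if cv2.waitKey(20) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()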
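Building on the reading example, here is a minimal sketch (hypothetical helpers, not Dopamine's code) of turning the time-ordered arrays into stacked-frame (s, a, r, s', done) transitions. It assumes index i holds the observation the action was taken from, together with that action's reward and terminal flag, and a 4-frame stack. A transition is skipped if its frame history crosses an episode boundary; when terminal[i] is set, s' contains frames from the next episode and should be masked out by the terminal flag in the Bellman target.

```python
import numpy as np


def valid_indices(terminal, stack_size=4):
    """Indices i whose stack (i-stack_size+1 .. i) stays in one episode and has a next row."""
    t = np.asarray(terminal).astype(np.int64)
    # csum[k] = number of terminals in t[:k].
    csum = np.concatenate([[0], np.cumsum(t)])
    i = np.arange(stack_size - 1, len(t) - 1)
    # Only the newest frame (index i) may be terminal; earlier frames in the stack may not.
    ok = (csum[i] - csum[i - stack_size + 1]) == 0
    return i[ok]


def stack_frames(observation, i, stack_size=4):
    """Stacks frames i-stack_size+1 .. i along a trailing axis."""
    if i < stack_size - 1:
        raise ValueError('not enough history for a full stack')
    return np.stack(observation[i - stack_size + 1:i + 1], axis=-1)


def transition(data, i, stack_size=4):
    """Returns (s, a, r, s', done) for a valid index i."""
    s = stack_frames(data['observation'], i, stack_size)
    s_next = stack_frames(data['observation'], i + 1, stack_size)
    return s, data['action'][i], data['reward'][i], s_next, data['terminal'][i]
```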

Difference b/w checkpoint 49 and checkpoint 50

Are the 50 checkpoints indexed 0...49 or 1...50?

The following games are missing gs://atari-replay-datasets/dqn/${g}/1/replay_logs/FILE.50.gz

$ ./check_c50.sh
Carnival missing ckpt50
Centipede missing ckpt50
IceHockey missing ckpt50
StarGunner missing ckpt50
VideoPinball missing ckpt50
YarsRevenge missing ckpt50

check_c50.sh:

#!/usr/bin/env bash
# Report games whose run 1 has no checkpoint file with suffix 50.
games=(AirRaid Alien Amidar Assault Asterix Asteroids Atlantis BankHeist BattleZone BeamRider Berzerk Bowling Boxing Breakout Carnival Centipede ChopperCommand CrazyClimber DemonAttack DoubleDunk ElevatorAction Enduro FishingDerby Freeway Frostbite Gopher Gravitar Hero IceHockey Jamesbond JourneyEscape Kangaroo Krull KungFuMaster MontezumaRevenge MsPacman NameThisGame Phoenix Pitfall Pong Pooyan PrivateEye Qbert Riverraid RoadRunner Robotank Seaquest Skiing Solaris SpaceInvaders StarGunner Tennis TimePilot Tutankham UpNDown Venture VideoPinball WizardOfWor YarsRevenge Zaxxon)

for g in "${games[@]}"; do
  output=$(gsutil ls "gs://atari-replay-datasets/dqn/${g}/1/replay_logs/")
  # Match '.50.gz' exactly rather than any '50' in the listing.
  if ! grep -q '\.50\.gz' <<< "${output}"; then
    echo "${g} missing ckpt50"
  fi
done

WARNING:absl:Unable to find episode_end_indices. This is expected for old checkpoints.

$ gsutil -m cp -R gs://atari-replay-datasets/dqn/Breakout .
I downloaded the atari-replay-datasets with this command, but I don't see any episode_end_indices data. When I run the Dopamine code, I get the following warning:

WARNING:absl:Unable to find episode_end_indices. This is expected for old checkpoints.

How can I get the episode_end_indices data?

I'd really appreciate your help.

gsutil error

The command fails to execute and produces this error.
command : "gsutil -m cp -R gs://batch-rl-datasets/dqn/Pong/1 $DATA_DIR/Pong"

error: " argument should be integer or bytes-like object, not 'str'
CommandException: 1 file/object could not be transferred. "

JAX code

Hi,

I would like to ask whether there is a JAX-based version of the code.

Also, do you have any recommendations for JAX-based offline RL algorithms?

Thanks!

About DQN replay dataset

Hi

Thank you so much for your contribution. This is a really great repo for students.

It would be very nice to be able to try offline Atari training with some recently proposed methods.

Could you please recommend some recent papers on offline RL training on Atari?

Thank you very much!

Best

Raw results

To facilitate comparison with a method we are developing, would it be possible to release the raw results (e.g., similar to Dopamine's JSON files)?

These data already "exist" as part of your figures in the appendix of your paper, so what we really want is to produce similar figures (comparing our method with your method) without having to rerun yours from scratch.

About terminal in Atari dataset

In the provided Atari dataset, I can tell when the game is over, but I can't tell when the agent has lost a life (e.g., Alien has 3 lives).
Is there any way to determine this?

Data generation is very slow

I am using the following command to try and generate data for one run of Freeway:

python -um batch_rl.baselines.train \
  --base_dir=/tmp/batch_rl_data \
  --gin_files='batch_rl/baselines/configs/dqn.gin'

I made no changes to the code, except for a workaround to the flags issue at #14. Each iteration is taking ~40 min, and since a single run has 200 iterations (out of five runs), at this rate one run will take over 5 days to generate (40 min × 200 = 8,000 min ≈ 5.6 days). This is significantly longer than expected from the response in #13, according to which it should take 3-5 days to generate all of the data.

Was this 3-5 day figure based on the exact config provided? I have verified that I am running on GPU, although with the default settings, I am only using 1 GB of memory. Is there anything else I can look into to see why my runs are so much slower? Thanks!

Retraining the online agent

I'd like to retrain the online DQN agent in order to log some additional data during online training. The README says

This data can be generated by running the online agents using batch_rl/baselines/train.py for 200 million frames (standard protocol)

However, this is not enough information to accurately replicate the setup. Could you share the gin file that was used? Is it the same as the one in dopamine?

Also, is there a faster way to accurately replicate the online training and logging via the RL unplugged project, or is it just the offline part that has become faster?

Data Generation KeyError

I am using the following command to try and generate data:

python -um batch_rl.baselines.train \
  --base_dir=/tmp/batch_rl_data \
  --gin_files='batch_rl/baselines/configs/dqn.gin'

However, I am getting the following error:

WARNING:root:Argument blacklist is deprecated. Please use denylist.
WARNING:root:Argument blacklist is deprecated. Please use denylist.
/data/venv_dopamine/lib/python3.8/site-packages/flax/nn/__init__.py:35: DeprecationWarning: The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/master/flax/linen/README.md
  warnings.warn("The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/master/flax/linen/README.md", DeprecationWarning)
Traceback (most recent call last):
  File "/home/anaconda3/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/anaconda3/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/batch_rl/batch_rl/baselines/train.py", line 83, in <module>
    flags.mark_flag_as_required('base_dir')
  File "/data/venv_dopamine/lib/python3.8/site-packages/absl/flags/_validators.py", line 352, in mark_flag_as_required
    if flag_values[flag_name].default is not None:
  File "/data/venv_dopamine/lib/python3.8/site-packages/absl/flags/_flagvalues.py", line 470, in __getitem__
    return self._flags()[name]
KeyError: 'base_dir'

base_dir doesn't seem to be defined as a flag anywhere, yet it is a required flag. Any ideas what I am doing wrong?

Transform matrix create only once?

The paper says that a categorical distribution is drawn at random for each mini-batch. But in the code I only find transform_matrix being generated once, in the _create_network function, and it does not change during training. Maybe I just missed it: the training op uses q_heads, and q_heads reads transform_matrix from the Network's kwargs, but I can't find where it is updated.
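A NumPy sketch of one reading of the paper's description, assuming the mixture weights are drawn by normalizing i.i.d. Uniform(0, 1) samples (the REM paper's choice, to the best of our knowledge) and that one draw per mini-batch is shared by the online and target Q-values. This is not the repository's TensorFlow code. One detail that may matter: in TF1 graph mode, a random op such as `tf.random.uniform` is built once but yields fresh values on every `session.run`, so a matrix created once in `_create_network` can still be resampled at every training step if it comes from such an op; check how `transform_matrix` is constructed.

```python
import numpy as np


def sample_mixture_weights(rng, num_heads):
    """Draws alpha on the simplex: alpha'_k ~ U(0, 1), alpha_k = alpha'_k / sum_j alpha'_j."""
    alpha = rng.uniform(0.0, 1.0, size=num_heads)
    return alpha / alpha.sum()


def mixed_q_values(head_q_values, alpha):
    """Convex combination of per-head Q-values.

    head_q_values: [batch, num_actions, num_heads]; alpha: [num_heads].
    Returns [batch, num_actions].
    """
    return head_q_values @ alpha


def rem_td_targets(rewards, terminals, next_head_q_values, alpha, gamma=0.99):
    """Bellman targets using the same alpha as the online Q-values of this mini-batch."""
    next_q = mixed_q_values(next_head_q_values, alpha).max(axis=1)
    return rewards + gamma * (1.0 - terminals) * next_q
```

In a training loop, `sample_mixture_weights` would be called once per mini-batch, before computing both the online mixture and the targets.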

Code for offline continuous control experiments

Hi, it is very generous of your team to open-source the code. However, I did not find the code for the offline continuous control experiments shown in Fig. 8 of the paper. Did I miss some part of the batch_rl folder? If not, will you release it someday?

THX.

How to train offline agent on the huge dataset (50 Million) ?

Hi, I have read your ICML 2020 paper, and I am now trying to do some research on offline image data. I noticed that when training an online agent such as DQN, the replay buffer capacity is usually set to 1 million; once more than 1 million transitions have been collected, new data overwrites the oldest data in the buffer. But when training DQN on offline data such as your dataset, which has 50M transitions, how do I train the agent on such a huge dataset? Since the machine's memory is limited, it's impossible to load all 50M transitions into memory at once. How did you solve this problem, and if it is implemented in this project, could you point me to it? Finally, I really appreciate your great work and your open-source code!
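A minimal, hypothetical sketch of one common way to bound memory: keep only a few checkpoint files resident and periodically swap in a different random subset. This is not necessarily how batch_rl's fixed-replay agents do it. `load_fn` is an assumed callback (for example, wrapping the gzip/np.load reader from the "Reading atari files directly" example) that returns an indexable sequence of transitions for one checkpoint suffix.

```python
import random
from typing import Callable, Dict, Hashable, List, Sequence


class RotatingBufferPool:
    """Keeps only a few checkpoints in memory and swaps in a new random subset on reload()."""

    def __init__(self, suffixes: Sequence[Hashable], load_fn: Callable[[Hashable], Sequence],
                 num_resident: int = 5, seed: int = 0):
        self._suffixes = list(suffixes)
        self._load_fn = load_fn  # hypothetical: returns the transitions of one checkpoint
        self._num_resident = min(num_resident, len(self._suffixes))
        self._rng = random.Random(seed)
        self.buffers: Dict[Hashable, Sequence] = {}
        self.reload()

    def reload(self) -> None:
        """Drops the current checkpoints, then loads a fresh random subset."""
        chosen = self._rng.sample(self._suffixes, self._num_resident)
        self.buffers = {}  # release references first so old arrays can be freed
        for suffix in chosen:
            self.buffers[suffix] = self._load_fn(suffix)

    def sample_batch(self, batch_size: int) -> List:
        """Samples transitions: pick a resident checkpoint, then an index inside it."""
        keys = list(self.buffers)
        batch = []
        for _ in range(batch_size):
            buffer = self.buffers[self._rng.choice(keys)]
            batch.append(buffer[self._rng.randrange(len(buffer))])
        return batch
```

Calling `reload()` once per training iteration would let the agent see all 50 checkpoints over time while holding only `num_resident` in memory. Picking a checkpoint uniformly and then an index uniformly is uniform over transitions only if the checkpoints are about the same size.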

Native Resolution Images

Hi,

I've been trying to use the data from RL Unplugged at its native resolution (210x160). I was hoping to replay the RL Unplugged actions from the sequential data release for Dopamine in the environment to regenerate the frames. However, I think sticky actions are making the behavior not align perfectly. Is the script that was used to generate the data publicly available somewhere? Or is there another way to use the RL Unplugged data at native resolution? Thanks!

Sequential Data

Hi!

I was wondering if there is a way to get sequential data, i.e. o_t in iteration 1 corresponds to o_tm1 in iteration 2?

This doesn't seem to be the case in the RL Unplugged codebase, even when all instances of .shuffle() are commented out. I also tried reading the data directly using tfrecord with the following code:

dataset = tf.data.TFRecordDataset(['rl_unplugged/tmp/atari/Gravitar/run_1-00000-of-00100'],
                                  compression_type='GZIP')
iterator = iter(dataset)

inputs_1 = next(iterator)
inputs_1 = _tf_example_to_reverb_sample_np(inputs_1)
o_t, a_t, r_t, d_t, o_tp1, a_tp1, extras = inputs_1.data

inputs_2 = next(iterator)
inputs_2 = _tf_example_to_reverb_sample_np(inputs_2)
o_t2, a_t2, r_t2, d_t2, o_tp12, a_tp12, extras2 = inputs_2.data

However, here too o_tp1 != o_t2, whereas I am looking for o_tp1 == o_t2. Is this possible?

Thanks!

A copy of the dataset?

Hi team!
First thanks a mil for the amazing work!
TorchRL is working on providing an easy, out-of-the-box API to get offline RL datasets such as this one.
We were wondering what the license was for your data?
Could we create a copy in a different format on our bucket to make it available to the users without the need to install gcloud?

Can i use the Tensorboard to check the intermediate results?

Hello, I am running the demo:

python -um batch_rl.fixed_replay.train \
  --base_dir=/tmp/batch_rl \
  --replay_dir=$DATA_DIR/Pong/1 \
  --gin_files='batch_rl/fixed_replay/configs/dqn.gin'

Training has been running for 2 days and has not finished yet!
I want to check the intermediate results with TensorBoard, but it doesn't seem to work.
