drqv2's Introduction

DrQ-v2: Improved Data-Augmented RL Agent

This is an original PyTorch implementation of DrQ-v2 from

Mastering Visual Continuous Control: Improved Data-Augmented Reinforcement Learning by

Denis Yarats, Rob Fergus, Alessandro Lazaric, and Lerrel Pinto.

Method

DrQ-v2 is a model-free off-policy algorithm for image-based continuous control. DrQ-v2 builds on DrQ, an actor-critic approach that uses data augmentation to learn directly from pixels. We introduce several improvements including:

  • Switch the base RL learner from SAC to DDPG.
  • Incorporate n-step returns to estimate TD error.
  • Introduce a decaying schedule for exploration noise.
  • Make implementation 3.5 times faster.
  • Find better hyper-parameters.
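As a rough illustration of the n-step returns in the second bullet, the TD target sums n discounted rewards and then bootstraps from the critic's value n steps later. This is a minimal sketch. It assumes a scalar bootstrap value Q(s_{t+n}, a') is already available. The function name nstep_target is hypothetical, and the repo's actual critic update involves more machinery that is not reproduced here.

```python
from typing import Sequence


def nstep_target(rewards: Sequence[float], bootstrap_q: float, gamma: float) -> float:
    """n-step TD target: sum_{k<n} gamma^k * r_{t+k} + gamma^n * Q(s_{t+n}, a')."""
    n = len(rewards)
    discounted = sum((gamma ** k) * r for k, r in enumerate(rewards))
    return discounted + (gamma ** n) * bootstrap_q


# Three rewards of 1.0, gamma = 0.5, bootstrap value 10.0:
# 1 + 0.5 + 0.25 + 0.125 * 10 = 3.0
print(nstep_target([1.0, 1.0, 1.0], bootstrap_q=10.0, gamma=0.5))
```

With n = 1 this reduces to the usual one-step target r_t + gamma * Q(s_{t+1}, a').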

These changes allow us to significantly improve sample efficiency and wall-clock training time on a set of challenging tasks from the DeepMind Control Suite compared to prior methods. Furthermore, DrQ-v2 is able to solve complex humanoid locomotion tasks directly from pixel observations, previously unattained by model-free RL.

Citation

If you use this repo in your research, please consider citing the paper as follows:

@article{yarats2021drqv2,
  title={Mastering Visual Continuous Control: Improved Data-Augmented Reinforcement Learning},
  author={Denis Yarats and Rob Fergus and Alessandro Lazaric and Lerrel Pinto},
  journal={arXiv preprint arXiv:2107.09645},
  year={2021}
}

Please also cite our original paper:

@inproceedings{yarats2021image,
  title={Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels},
  author={Denis Yarats and Ilya Kostrikov and Rob Fergus},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=GY6-6sTvGaf}
}

Instructions

Install MuJoCo if it is not already installed:

  • Obtain a license on the MuJoCo website.
  • Download the MuJoCo binaries.
  • Unzip the downloaded archive into ~/.mujoco/mujoco200 and place your license key file mjkey.txt at ~/.mujoco.
  • Use the env variables MUJOCO_PY_MJKEY_PATH and MUJOCO_PY_MUJOCO_PATH to specify the MuJoCo license key path and the MuJoCo directory path.
  • Append the path of the MuJoCo bin subdirectory to the LD_LIBRARY_PATH environment variable.

Install the following libraries:

sudo apt update
sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3

Install dependencies:

conda env create -f conda_env.yml
conda activate drqv2

Train the agent:

python train.py task=quadruped_walk

Monitor results:

tensorboard --logdir exp_local

License

The majority of DrQ-v2 is licensed under the MIT license; however, portions of the project are available under separate license terms: DeepMind's code is licensed under the Apache 2.0 license.

drqv2's People

Contributors

aladoro, denisyarats, desaixie, medric49

drqv2's Issues

Why using 'F.grid_sample()' in class RandomShiftsAug?

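Actually, for the random shift operation, we can simply draw two random integers as the shift offsets, like this:

import random
import torch
import torch.nn.functional as F

origin_image = torch.rand(512, 3, 84, 84)
pad_image = F.pad(origin_image, (4, 4, 4, 4), mode='replicate')  # pad 4 pixels on each side
assert pad_image.shape == (512, 3, 92, 92)
shift_x = random.randint(0, 92 - 84)  # randint is inclusive, so offsets 0..8
shift_y = random.randint(0, 92 - 84)
aug_image = pad_image[:, :, shift_y:shift_y + 84, shift_x:shift_x + 84]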

Then we can get the augmented image.

I guess interpolation is the reason you guys chose F.grid_sample()?
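For comparison, the integer shift from the question can also be given an independent offset per sample using plain advanced indexing. The snippet in the question shares a single offset across the whole batch. This is a minimal sketch, not the repo's RandomShiftsAug. The name random_shift_int is hypothetical, and replicate padding with pad=4 is assumed, matching the 84 to 92 pixel sizes above. One plausible reason to prefer F.grid_sample is that it expresses all per-sample shifts as a single batched sampling call. That is a guess, not something the authors state.

```python
import torch
import torch.nn.functional as F


def random_shift_int(x: torch.Tensor, pad: int = 4) -> torch.Tensor:
    """Shift each image in the batch by its own random integer offset (hypothetical helper)."""
    n, c, h, w = x.shape
    padded = F.pad(x, (pad, pad, pad, pad), mode='replicate')
    # Independent offsets per sample, each in [0, 2 * pad].
    sx = torch.randint(0, 2 * pad + 1, (n,), device=x.device)
    sy = torch.randint(0, 2 * pad + 1, (n,), device=x.device)
    rows = sy.view(n, 1) + torch.arange(h, device=x.device).view(1, h)  # (n, h)
    cols = sx.view(n, 1) + torch.arange(w, device=x.device).view(1, w)  # (n, w)
    batch = torch.arange(n, device=x.device).view(n, 1, 1)
    # Index (batch, row, col) on a channels-last view; the result is (n, h, w, c).
    crops = padded.permute(0, 2, 3, 1)[batch, rows.view(n, h, 1), cols.view(n, 1, w)]
    return crops.permute(0, 3, 1, 2)  # back to (n, c, h, w)


aug = random_shift_int(torch.rand(512, 3, 84, 84))
print(aug.shape)
```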

Replayloader doesn't work for Atari

Have you tried using this replay loader with Atari? I keep getting this error unless I set the number of replay workers to 1:

  File "/u/slerman/miniconda3/envs/agi/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/u/slerman/miniconda3/envs/agi/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
    data.append(next(self.dataset_iter))
  File "/home/cxu-serve/u1/slerman/drqv2/replay_buffer.py", line 176, in __iter__
    yield self._sample()
  File "/home/cxu-serve/u1/slerman/drqv2/replay_buffer.py", line 159, in _sample
    episode = self._sample_episode()
  File "/home/cxu-serve/u1/slerman/drqv2/replay_buffer.py", line 99, in _sample_episode
    eps_fn = random.choice(self._episode_fns)
  File "/u/slerman/miniconda3/envs/agi/lib/python3.8/random.py", line 290, in choice
    raise IndexError('Cannot choose from an empty sequence') from None
IndexError: Cannot choose from an empty sequence


Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Edit: Sorry, originally posted the wrong trace.
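One plausible cause is offered here as an assumption, not a confirmed diagnosis. Suppose each DataLoader worker only loads the episodes assigned to its own worker id, for example by episode index modulo the number of workers. With long episodes such as Atari's, there may be fewer stored episodes than workers early in training. Some workers would then pass an empty list to random.choice. Below is a minimal sketch of that failure mode; shard_episodes and the file names are hypothetical.

```python
import random
from typing import List


def shard_episodes(episode_fns: List[str], worker_id: int, num_workers: int) -> List[str]:
    """Hypothetical sharding rule: worker k keeps episodes whose index is k modulo num_workers."""
    return [fn for idx, fn in enumerate(episode_fns) if idx % num_workers == worker_id]


stored = ['episode_0.npz', 'episode_1.npz']  # only two long episodes saved so far
num_workers = 4
for worker_id in range(num_workers):
    mine = shard_episodes(stored, worker_id, num_workers)
    try:
        random.choice(mine)
        print(worker_id, 'sampled from', mine)
    except IndexError as err:
        print(worker_id, 'failed:', err)
```

With num_workers=1, every stored episode lands in worker 0's shard. That would be consistent with the report that a single worker avoids the error once at least one episode exists.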

Get to 96 frame throughput?

Hi! First of all, thank you very much for the nice paper and open sourcing this great codebase, I think it's really nice and clean.

I have been experimenting with the codebase recently. However, I found that on a V100 GPU the maximum FPS I can get is around 60, not the 96 reported in the paper. I see this on multiple environments, and on walker, where a larger batch size is used, it is even a bit slower. So I'm wondering whether there is any detail I should pay attention to in order to reach 96 FPS. For example, do we need to set the number of replay workers to a larger number, or do we need to disable TensorBoard?

I understand that something may be wrong with my own environment or hardware, but I just want to check with you to see if I missed anything important.

Thank you so much!

Can't run on school compute...

Hi, I've been going back and forth with a system admin trying to get this to run, but we can't get past an OpenGL error. This is how we installed everything:

module load mesa  
module load glfw/3.3.2/b2   
module load mujoco/200 

git clone git@github.com:facebookresearch/drqv2.git
cd drqv2 

conda env create -f conda_env.yml
conda activate drqv2

python train.py task=quadruped_walk

Then this returns the following error no matter what we do:

OpenGL.raw.EGL._errors.EGLError: EGLError(
        err = EGL_BAD_PARAMETER,
        baseOperation = eglGetPlatformDisplayEXT,
        cArguments = (
                12607,
                <OpenGL._opaque.EGLDeviceEXT_pointer object at 0x2ab4078bd240>,
                None,
        ),
        result = <OpenGL._opaque.EGLDisplay_pointer object at 0x2ab4078bd5c0>
)

Here is what the system admin sent me:

hi Sam, unfortunately I've had no luck getting this to work despite spending quite a bit of time on it today. I get the same EGL error unless I'm on a visual node (bhx nodes), but if I do everything (including the build/install) on a visual node, I get a core dump with Illegal Instruction. It's possible that's due to one of the dependencies from pip/conda, and if it were compiled from source on a bhx node that might go away, but I have no idea which one...

Any ideas?

How can I start with Windows 10?

I'm using Windows 10

and I'm stuck with the following error:
Traceback (most recent call last):
  File "train.py", line 19, in <module>
    import dmc
  File "C:\Users\KANG\Desktop\RLstudy\drqv2\dmc.py", line 10, in <module>
    from dm_control import manipulation, suite
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\manipulation\__init__.py", line 20, in <module>
    from dm_control import composer as _composer
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\composer\__init__.py", line 18, in <module>
    from dm_control.composer.arena import Arena
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\composer\arena.py", line 20, in <module>
    from dm_control import mjcf
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\mjcf\__init__.py", line 18, in <module>
    from dm_control.mjcf.attribute import Asset
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\mjcf\attribute.py", line 28, in <module>
    from dm_control.mujoco.wrapper import util
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\mujoco\__init__.py", line 18, in <module>
    from dm_control.mujoco.engine import action_spec
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\mujoco\engine.py", line 41, in <module>
    from dm_control import _render
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\_render\__init__.py", line 67, in <module>
    Renderer = import_func()  # pylint: disable=invalid-name
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\_render\__init__.py", line 36, in _import_egl
    from dm_control._render.pyopengl.egl_renderer import EGLContext
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\_render\pyopengl\egl_renderer.py", line 39, in <module>
    from dm_control._render.pyopengl import egl_ext as EGL
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\dm_control\_render\pyopengl\egl_ext.py", line 19, in <module>
    from OpenGL.platform import ctypesloader  # pylint: disable=g-bad-import-order
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\__init__.py", line 36, in <module>
    _load()
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\__init__.py", line 33, in _load
    plugin.install(globals())
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\baseplatform.py", line 97, in install
    namespace[ name ] = getattr(self,name,None)
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\baseplatform.py", line 15, in __get__
    value = self.fget( obj )
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\egl.py", line 106, in GetCurrentContext
    return self.EGL.eglGetCurrentContext
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\baseplatform.py", line 15, in __get__
    value = self.fget( obj )
  File "C:\Users\KANG\anaconda3\envs\drqv2\lib\site-packages\OpenGL\platform\egl.py", line 86, in EGL
    raise ImportError("Unable to load EGL library", *err.args)
ImportError: ('Unable to load EGL library', "Could not find module 'EGL' (or one of its dependencies). Try using the full path with constructor syntax.", 'EGL', None)

I know something went wrong with installing
sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3

but I don't know how to install libosmesa6-dev, libgl1-mesa-glx and libglfw3 on Windows 10...
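Both of the last two reports fail inside dm_control's rendering setup. The Windows traceback goes through _import_egl, meaning the EGL backend was selected. dm_control reads its backend from the MUJOCO_GL environment variable (glfw, egl or osmesa). Below is a minimal sketch of choosing it per platform. pick_mujoco_gl is a hypothetical helper. Whether glfw or osmesa actually works depends on the installed libraries; the README's apt command installs the OSMesa and GLFW packages on Ubuntu.

```python
import os
import platform
from typing import Optional


def pick_mujoco_gl(system: Optional[str] = None) -> str:
    """Heuristic choice of a dm_control rendering backend (hypothetical helper)."""
    system = system or platform.system()
    if system == 'Windows':
        return 'glfw'  # EGL and OSMesa are typically not available on Windows
    return 'egl'  # headless GPU rendering; 'osmesa' is the software (CPU) fallback


# Must run before dm_control (or dmc.py, which imports it) is imported.
# If the training script assigns MUJOCO_GL itself, change the value there instead.
os.environ.setdefault('MUJOCO_GL', pick_mujoco_gl())
print(os.environ['MUJOCO_GL'])
```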

reproducing results for humanoid run

Thanks for open-sourcing this! I wanted to replicate the results on the humanoid run/stand tasks. Using the default parameters, I'm able to get a similar result on the humanoid stand task, i.e. ~400 reward after 15 million steps. However, for the humanoid run task no learning seems to happen; see the attached figure, which is averaged across 5 runs. The only change I've made to the parameters is setting stddev_schedule: 'linear(1.0,0.1,2000000)'. Are there any other modifications necessary for the humanoid run task? Thanks!!
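The stddev_schedule value 'linear(1.0,0.1,2000000)' reads as a linear decay of the exploration standard deviation from 1.0 to 0.1 over 2,000,000 steps. This is the "decaying schedule for exploration noise" from the method list. Below is a minimal sketch of evaluating such a spec under that interpretation. The parser linear_schedule is hypothetical, not the repo's own helper, and treating the step count as the decay duration is an assumption.

```python
import re


def linear_schedule(spec: str, step: int) -> float:
    """Hypothetical parser for specs like 'linear(init,final,duration)'."""
    match = re.fullmatch(r'linear\(([^,]+),([^,]+),([^)]+)\)', spec.strip())
    if match is None:
        return float(spec)  # fall back to a constant value such as '0.2'
    init, final, duration = (float(g) for g in match.groups())
    mix = min(max(step / duration, 0.0), 1.0)  # fraction of the decay completed
    return (1.0 - mix) * init + mix * final


spec = 'linear(1.0,0.1,2000000)'
for step in (0, 1_000_000, 2_000_000, 5_000_000):
    print(step, linear_schedule(spec, step))
```

Past the decay duration the value stays clamped at the final stddev.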


a certain chance to fail

Hi, @denisyarats , Thanks for your great work!

I reran your code and found that training sometimes fails to produce a good agent. For example, as the picture below shows, I ran Cartpole Swing with 4 seeds using your config file, except that I shortened the total number of frames to 200K. According to the curves in the paper, 200K frames are enough for agents to perform well on Cartpole Swing. However, one of the 4 runs failed. I have checked the data in the curves folder, and it seems that you didn't encounter this problem.

Similar problems occur in other tasks such as Walker Run. So far I have only rerun these two tasks, but I suspect this problem also happens in other tasks.
I suspect the PyTorch or MuJoCo version caused this, since I used my own Docker image. I will try the same versions as you and report the new results later. If you have encountered similar problems or have any idea why this happened, please let me know.


Question regarding indices for replay buffer

I am a little confused about how the indices work for the replay buffer. Specifically, a "dummy" transition is repeatedly referenced in replay_buffer.py, and some +/- 1 offsets are applied to the indices:

idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
obs = episode['observation'][idx - 1]
action = episode['action'][idx]
next_obs = episode['observation'][idx + self._nstep - 1]
reward = np.zeros_like(episode['reward'][idx])

From reading the code, it seems like the storage layout is

----------------------------------------------------------
rewards           |   None   |  reward_0  |  reward_1  |   ........
----------------------------------------------------------
observations      |   obs_0  |   obs_1    |   obs_2    |   ........
----------------------------------------------------------
actions           |    None  |  action_0  |  action_1  |   .......
----------------------------------------------------------

Is that the correct interpretation of the memory layout? And if so, why is the offset used?
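If the table is the right reading, the +1 in idx only skips the dummy slot. observation[idx - 1] is the state in which action[idx] was taken, and reward[idx] is the reward that followed it, so one index lines up a full (s, a, r) triple. Below is a minimal sketch under that assumption. sample_transition is hypothetical, and the n-step accumulation loop is an illustrative reconstruction rather than code copied from replay_buffer.py.

```python
import numpy as np


def sample_transition(episode, idx, nstep, gamma):
    """n-step transition, assuming slot 0 of 'action', 'reward' and 'discount' is a dummy."""
    obs = episode['observation'][idx - 1]  # state in which action[idx] was taken
    action = episode['action'][idx]
    next_obs = episode['observation'][idx + nstep - 1]  # state n steps later
    reward = np.zeros_like(episode['reward'][idx])
    discount = np.ones_like(episode['discount'][idx])
    for i in range(nstep):  # illustrative n-step loop (reconstruction, not a quote)
        reward += discount * episode['reward'][idx + i]
        discount *= episode['discount'][idx + i] * gamma
    return obs, action, reward, discount, next_obs


# Toy episode with 5 real transitions; index 0 of action/reward/discount is the dummy.
episode = {
    'observation': np.arange(6, dtype=np.float32),  # obs_0 .. obs_5
    'action': np.array([0, 10, 11, 12, 13, 14], dtype=np.float32),
    'reward': np.array([0, 1, 2, 3, 4, 5], dtype=np.float32),
    'discount': np.ones(6, dtype=np.float32),
}
print(sample_transition(episode, idx=1, nstep=3, gamma=0.5))
```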

Thanks for open-sourcing this!!

custom env

Hello, I would like to train your algorithm on the fish task of the dm_control suite; however, it's not in the list. Could you please tell me where to start? I also have a custom visual env based on pyglet. What steps should I take to couple it with your algorithm?
Thanks in advance!

Error when Running Training Command

Hello,

When I ran the training command given in the readme python train.py task=quadruped_walk I got this error:

File "/home/anavani/anaconda3/lib/python3.9/site-packages/hydra/_internal/defaults_list.py", line 168, in ensure_overrides_used raise ConfigCompositionException(msg) hydra.errors.ConfigCompositionException: Could not override 'task'. Did you mean to override task@_global_? To append to your default list use +task=quadruped_walk

I changed the command to python train.py +task=quadruped_walk and this seemed to fix the issue. However, after I let it train for a bit, I got this error

It seems as if +task=quadruped_walk is causing an EOF error, but I'm not sure what is causing the second error. I would really appreciate any help. @denisyarats @Aladoro @desaixie @medric49

Truncation for exploration?

I'm reading the paper and code, and can't follow the truncation process. Table 2 sets exploration stddev. clip equal to 0.3, so I assume that the exploration noise is clipped. However, the action seems to be selected by action = dist.sample(clip=None) which does no clipping. Instead clipping is seemingly applied during training with dist.sample(clip=self.stddev_clip). Am I misunderstanding something here? Thanks!!
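To make the distinction in the question concrete, here is a minimal sketch of a clipped-noise sampler. The Gaussian noise is clamped to [-clip, clip] only when a clip value is passed, and the final action is always clamped to the action bounds. This is an illustrative assumption, not the repo's distribution class. clipped_noise_action is a hypothetical name, and whether dist.sample(clip=None) behaves exactly like this is not confirmed here. Clipping only at update time resembles TD3-style target policy smoothing, but whether that is the intent is for the authors to confirm.

```python
from typing import Optional

import torch


def clipped_noise_action(mu: torch.Tensor, stddev: float, clip: Optional[float] = None,
                         low: float = -1.0, high: float = 1.0) -> torch.Tensor:
    """Return mu + Gaussian noise; the noise is clamped to [-clip, clip] only if clip is given."""
    eps = torch.randn_like(mu) * stddev
    if clip is not None:
        eps = torch.clamp(eps, -clip, clip)  # truncate the noise itself
    return torch.clamp(mu + eps, low, high)  # always keep the action inside the bounds


torch.manual_seed(0)
mu = torch.zeros(1000, 6)
acting = clipped_noise_action(mu, stddev=1.0)              # like sample(clip=None)
training = clipped_noise_action(mu, stddev=1.0, clip=0.3)  # like sample(clip=stddev_clip)
print(acting.abs().max().item(), training.abs().max().item())
```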

Dreamerv2 learning curve

Hi, @denisyarats,
In the paper, you said that you ran DreamerV2 to get the learning curves on 12 DeepMind Control Suite tasks. Could you kindly share the learning curves of DreamerV2?

BR,

Great work! Exploration noise question...

Great work! Just a question about exploration noise. Is there a benefit/reason for using exploration noise over a stochastic policy with a learned mean and standard deviation?

One more question, sorry...

Hi, is the implemented random augmentation faster or better than the torchvision transforms.RandomAffine function?

Training Manipulation Tasks

Hello,

Could someone kindly tell me what command I should run to train manipulation tasks?
Is it by creating .yaml files for them?

Thank you!

Redundant n-step trick in this repo?

Hi,

There are two places in this code base that use the n-step return trick for the TD error.

  1. drqv2/dmc.py, line 40 in ccb9a4d: def step(self, action):
  2. The replay buffer: for i in range(self._nstep):
In the first place, the ActionRepeatWrapper of the dmc env wrapper uses num_repeats=2.
In the second place, the replay buffer uses _nstep=3.
Both of them use the discounted reward instead of the original reward.

I think it is inconsistent with the parameters described in the paper.
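To check whether the two mechanisms actually double-discount, it helps to separate the environment's per-step discount from the agent's gamma. This is a minimal sketch under two assumptions. The action-repeat wrapper is assumed to weight each reward by the running product of the environment's time_step.discount values. The n-step sum is assumed to apply the agent's gamma. Both function names are hypothetical. In dm_control tasks the environment discount is typically 1.0 until an episode terminates, in which case the action-repeat term reduces to a plain sum and only the n-step window is discounted by gamma.

```python
from typing import Sequence


def action_repeat_reward(rewards: Sequence[float], env_discounts: Sequence[float]) -> float:
    """Accumulate rewards over repeated actions, weighting each by the env discounts seen so far."""
    total, discount = 0.0, 1.0
    for r, d in zip(rewards, env_discounts):
        total += r * discount
        discount *= d
    return total


def nstep_reward(rewards: Sequence[float], gamma: float) -> float:
    """Agent-level n-step sum: sum_k gamma^k * r_k."""
    return sum((gamma ** k) * r for k, r in enumerate(rewards))


# Env discounts of 1.0 (no termination inside the repeat) give a plain sum.
print(action_repeat_reward([1.0, 2.0], [1.0, 1.0]))  # 3.0
print(nstep_reward([1.0, 2.0, 3.0], gamma=0.5))       # 1 + 1 + 0.75 = 2.75
```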

Hyperparameters optimization

🚀 Feature Request

Hyperparameter optimization for the fish environment of dm_control.

Motivation

I tried the easy, medium and hard configs for the upright and swim tasks of the fish environment, but none of them seemed to work for the swim task.

Could you share how (e.g. the script) you found the hyperparameters for the other envs? Otherwise, I can open a pull request if I succeed with the fish env.

Length of replay_loader?

Hi, I tried adding a length check to the replay loader class:

def __len__(self):
        return self._size

But then when I call len(self.replay_loader) I get a warning about the size being 0 even though items have been fetched. You wouldn't know off the top of your head why this is, would you? Thanks!

Memory Leak

When I run multiple seeds of this code, I run into a memory leak. The problem may be caused by the DataLoader: when I set the number of workers to 0, the memory leak goes away, but training becomes too slow.

Evaluation video overwritten

I have noticed that in this line, where evaluation episodes are saved, you are actually overwriting earlier evaluation videos of the same global frame when num_eval_episodes > 1.

            self.video_recorder.save(f'{self.global_frame}.mp4')

The attribute self.global_frame doesn't change during the self.cfg.num_eval_episodes evaluations in the loop.

I would propose this line instead

            self.video_recorder.save(f'{self.global_frame}_{episode}.mp4')

or something similar.

A question about num_train_frames

Hi,

Great work, thanks for making it open-source!

Could you tell me why in the configs you use more num_train_frames than reported in the paper? In the paper, the numbers of frames for easy/medium/hard tasks are 1/3/30 * 10^6; however, in the config files (easy.yaml, medium.yaml, hard.yaml) the numbers are 1.1/3.1/30.1 * 10^6
