
Reacher

Introduction

This is a solution for the second project of the Udacity deep reinforcement learning course. It includes scripts for training agents using any of the A2C, PPO and TD3 algorithms and for testing them in a simulation environment. The models were trained using the Stable Baselines3 project.

Example agents

The GIF below shows the behavior of multiple agents driven by a model trained with PPO in this codebase. The agent parameters can be found under experiments/ppo_multi_agent_lr_0_00003/model.zip.

(Agent test run)

Problem description

The agent consists of an arm with two joints, and the environment contains a sphere rotating around the agent. The goal is to keep the arm in contact with the sphere for as long as possible during an episode of 1000 timesteps.

  • Rewards:

    • +0.04 for each timestep the agent touches the sphere
  • Input state:

    • 33 continuous variables corresponding to position, rotation, velocity, and angular velocities of the arm
  • Actions:

    • 4 continuous variables, corresponding to torque applicable to two joints with values in [-1.0, 1.0]
  • Goal:

    • Get an average score of at least +30 over 100 consecutive episodes (since each touching timestep yields +0.04 over a 1000-timestep episode, this means touching the sphere for at least 750 timesteps on average)
  • Environments: Two environments are available, one with a single agent and one with 20 agents. The evaluation for the 20-agent environment differs in that the reward of each episode is the average of all agent rewards (see the sketch after this list). In training, the only practical difference is that the 20-agent environment simulates 20 copies of the task at once, which speeds up exploration.
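
To make the evaluation concrete, here is a small sketch, assumed from the description above rather than taken from this repository, of how a 20-agent episode score is computed and what solving the environment requires:

import numpy as np

TIMESTEPS = 1000
REWARD_PER_TOUCH = 0.04  # reward for each timestep an agent touches the sphere

# Hypothetical contact counts: how many timesteps each of the 20 agents touched the sphere.
touch_steps = np.random.randint(0, TIMESTEPS + 1, size=20)
agent_returns = touch_steps * REWARD_PER_TOUCH  # at most 40.0 per agent per episode
episode_score = agent_returns.mean()            # 20-agent score = average over all agents

# The environment counts as solved once this score averages at least +30
# over 100 consecutive episodes, i.e. roughly 750 touching timesteps per agent.
print(f"Episode score: {episode_score:.2f}")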

Solution

The problem is solved with each of A2C, PPO and TD3 using the Stable Baselines3 framework. For more details and a comparison of the algorithms' behavior, see the corresponding report.

Setup project

To set up the project, follow these steps:

  • Provide an environment with Python 3.6.x installed; ideally create a new one with e.g. pyenv or conda
  • Clone and install the project:
git clone git@github.com:koulakis/reacher-deep-reinforcement-learning.git
cd reacher-deep-reinforcement-learning
pip install .
  • Create a directory called udacity_reacher_environment_single_agent or udacity_reacher_environment_multi_agent (to use with the single or 20 agent environments respectively) under the root of the project and download and extract there the environment compatible with your architecture. You can find the download links here.
  • Install a version of PyTorch compatible with your architecture. The version used to develop the project was 1.5.0, e.g. pip install torch==1.5.0

To check that everything is set up properly, run the following test, which loads an environment and runs a random agent: python scripts/run_agents_in_environment.py --random-agent

or

python scripts/run_agents_in_environment.py --random-agent --agent-type multi

which runs the 20-agent environment.
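
For reference, the random-agent check does little more than feed random torques to the environment. Below is a minimal sketch of such a loop, assuming the environment is driven through the unityagents package as in the Udacity starter code; the file path and details will differ from the actual script.

import numpy as np
from unityagents import UnityEnvironment

# Path is an assumption; point it to the extracted environment binary for your OS.
env = UnityEnvironment(file_name="udacity_reacher_environment_multi_agent/Reacher.x86_64")
brain_name = env.brain_names[0]

env_info = env.reset(train_mode=False)[brain_name]
num_agents = len(env_info.agents)
scores = np.zeros(num_agents)

while True:
    # 4 torque values per agent, clipped to the valid range [-1, 1]
    actions = np.clip(np.random.randn(num_agents, 4), -1.0, 1.0)
    env_info = env.step(actions)[brain_name]
    scores += env_info.rewards
    if np.any(env_info.local_done):
        break

print(f"Average random-agent score: {scores.mean():.2f}")
env.close()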

Training and testing the agent

The project comes with pre-trained agents, scripts to test them, and scripts to train your own.

Scripts

  • train_agent.py: This script is used to train an agent. The experiment-name parameter names your agent, and by default the script creates a directory with that name under experiments. The trained agent parameters are saved there at the end of training, and during training several metrics are logged to a tfevents file in the same directory (see the sketch after this list for what a run roughly maps to in Stable Baselines3). Here is an example call: python scripts/train_agent.py --experiment-name td3_rl_0_001 --agent-type single --learning-rate 0.001 --rl-algorithm td3 --total-timesteps 500000 --environment-port 5005

    To monitor the metrics, one can launch a TensorBoard server with: tensorboard --logdir experiments. This reads the metrics of all experiments and makes them available at localhost:6006

    One can run multiple trainings in parallel by using different ports per environment with the environment-port flag.

  • run_agents_in_environment.py: This script can be used to test an agent in a given environment. As mentioned above, the saved agent models are available inside the sub-folders of experiments. An example usage: python scripts/run_agents_in_environment.py --agent-type multi --agent-parameters-path experiments/ppo_multi_agent_lr_0_00003/model.zip --environment-port 5007
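
For context, the training script is essentially a thin wrapper around Stable Baselines3. The snippet below is a rough sketch of what the td3_rl_0_001 example call above boils down to; make_reacher_env is a hypothetical gym-style wrapper around the Unity environment, and the real script adds argument parsing, logging and its own environment handling.

from pathlib import Path
from stable_baselines3 import TD3

experiment_dir = Path("experiments/td3_rl_0_001")
env = make_reacher_env(agent_type="single", port=5005)  # hypothetical helper, not part of SB3

model = TD3(
    "MlpPolicy",
    env,
    learning_rate=0.001,
    tensorboard_log=str(experiment_dir / "tensorboard_logs"),
)
model.learn(total_timesteps=500_000)
model.save(experiment_dir / "model")  # ends up as experiments/td3_rl_0_001/model.zip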

Pre-trained models

Under the experiments directory there are several pre-trained agents one can run in the environment. Some examples of models which have solved the environment:

  • Best A2C model: a2c_lr_0_0001/tensorboard_logs/A2C_1
  • Best PPO model: ppo_multi_agent_lr_0_00003/tensorboard_logs/PPO_1
  • Best PPO model trained with a single agent: ppo_large_128_128_128_lr_0_00005_3M_steps/tensorboard_logs/PPO_1
  • Best TD3 model trained with a single agent: td3_0_001/tensorboard_logs/TD3_1
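
To load one of these models programmatically rather than through the script, Stable Baselines3 can restore it straight from the saved zip file. The sketch below is an assumed illustration, again using a hypothetical gym-style environment wrapper, not the repository's test code.

from stable_baselines3 import PPO

model = PPO.load("experiments/ppo_multi_agent_lr_0_00003/model.zip")
env = make_reacher_env(agent_type="multi", port=5007)  # hypothetical helper

obs = env.reset()
done, score = False, 0.0
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    score += reward

print(f"Episode score: {score:.2f}")
env.close()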

References

Given that this project is an assignment of an online course, it has been heavily influenced by code provided by Udacity and by several mainstream publications. Below are some links that provide broader context.

Frameworks & codebases

  1. All three algorithms were trained using the Stable Baselines3 project
  2. Most of the simulation setup comes from this notebook
  3. The Unity environment created by Udacity is copied directly from here

Publications

The following publications were used:

  1. Asynchronous Methods for Deep Reinforcement Learning. Volodymyr Mnih, Adrià Puigdomènech Badia, Mehdi Mirza, Alex Graves, Timothy P. Lillicrap, Tim Harley, David Silver, Koray Kavukcuoglu. arXiv:1602.01783. 2016.
  2. Proximal Policy Optimization Algorithms. John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov. arXiv:1707.06347. 2017.
  3. High-Dimensional Continuous Control Using Generalized Advantage Estimation. John Schulman, Philipp Moritz, Sergey Levine, Michael Jordan, Pieter Abbeel. arXiv:1506.02438. 2015.
  4. Continuous control with deep reinforcement learning. Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, Daan Wierstra. arXiv:1509.02971. 2015.
  5. Addressing Function Approximation Error in Actor-Critic Methods. Scott Fujimoto, Herke van Hoof, David Meger. arXiv:1802.09477. 2018.
