

Dreamer-v2 PyTorch

PyTorch implementation of Mastering Atari with Discrete World Models

Installation

Dependencies:

I have added requirements.txt (generated with conda list -e > requirements.txt) and environment.yml (generated with conda env export > environment.yml) from my own conda environment.
It may be easier to create a new conda environment (or venv, etc.) and manually install the few dependencies listed above, one by one.
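If you prefer to reuse the exported files, the standard conda commands should work (the environment name below is just a placeholder):

conda env create -f environment.yml
conda create --name dreamerv2 --file requirements.txt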

Running experiments

  1. In the test folder, mdp.py and pomdp.py have been set up for experiments with MinAtar environments. All default hyper-parameters are stored in a dataclass in config.py. To run dreamerv2 with default HPs on POMDP breakout using CUDA:
python pomdp.py --env breakout --device cuda
  • Training curves are logged using wandb.
  • A results folder will be created locally to store models while training: test/results/env_name+'_'+env_id+'_'+pomdp/models
  2. Experiments on other environments (using the gym API) can be run by adding another hyper-parameter dataclass in config.py, as sketched below.
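A minimal sketch of what such a dataclass might look like, using field names taken from the hyper-parameter description further down; the exact fields and defaults expected by config.py and the trainer may differ:

from dataclasses import dataclass

@dataclass
class MyEnvConfig:
    # Illustrative values only; check config.py for the fields the trainer actually expects.
    env: str = 'CartPole-v1'        # gym environment id (hypothetical choice)
    seq_len: int = 50               # length of sequences sampled from the replay buffer
    train_every: int = 50           # number of frames to skip between training steps
    collect_intervals: int = 5      # batches sampled from the buffer per training step
    horizon: int = 10               # imagination horizon in latent space
    rssm_type: str = 'discrete'     # categorical ('discrete') or 'gaussian'; exact strings may differ
    deter_size: int = 200           # deterministic part of the recurrent state
    stoch_size: int = 400           # stochastic part of the recurrent state
    kl_balance_scale: float = 0.8
    actor_entropy_scale: float = 1e-3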

Evaluating saved models

Trained models for all 5 games (MDP and POMDP versions of each) are uploaded to the drive link: link (64 MB).
Download and unzip the models inside the /test directory.

Evaluate the saved model for the POMDP version of the breakout environment for 5 episodes, without rendering:

python eval.py --env breakout --eval_episode 5 --eval_render 0 --pomdp 1

Evaluation Results

Average evaluation score (over 50 evaluation episodes) of models saved every 0.1 million frames. Green curves correspond to agents that have access to complete information, while red curves correspond to agents trained with partial observability.

In Freeway, the agent gets stuck in a local maximum in which it learns to always move forward, since it is not penalised for crashing into cars. Probably due to policy entropy regularisation, its returns drop drastically around the 1 million frame mark and then gradually improve while maintaining policy entropy.

Training curves

All experiments were logged using wandb. Training runs for all MDP and POMDP variants of MinAtar environments can be found on the wandb project page.

Please create an issue if you find a bug or have any queries.

Code structure:

  • test
    • pomdp.py runs MinAtar experiments with partial observability.
    • mdp.py runs MinAtar experiments with complete observability.
    • eval.py evaluates saved agents.
  • dreamerv2 implementation of dreamerv2, plus dreamerv1 and their combinations.
    • models neural network models.
      • actor.py discrete action model.
      • dense.py fully connected neural networks.
      • pixel.py convolutional encoder and decoder.
      • rssm.py recurrent state space model.
    • training
      • config.py hyper-parameter dataclass.
      • trainer.py training class, loss calculation.
      • evaluator.py evaluation class.
    • utils
      • algorithm.py lambda return function (see the sketch after this list).
      • buffer.py replay buffers, batches of sequences.
      • module.py neural network parameters utils.
      • rssm.py recurrent state space model utils.
      • wrapper.py gym api and pomdp wrappers for MinAtar.
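For intuition, the lambda return in utils/algorithm.py corresponds to the standard TD(lambda) recursion over imagined trajectories. A minimal sketch with generic shapes and signature, not necessarily the exact code in this repo:

import torch

def lambda_return(rewards, values, discounts, lam=0.95):
    # rewards, discounts: (horizon, batch); values: (horizon + 1, batch),
    # where values[-1] bootstraps the return beyond the imagination horizon.
    # Recursion, computed backwards over the trajectory:
    #   R_t = r_t + d_t * ((1 - lam) * v_{t+1} + lam * R_{t+1})
    last = values[-1]
    returns = torch.zeros_like(rewards)
    for t in reversed(range(rewards.shape[0])):
        last = rewards[t] + discounts[t] * ((1 - lam) * values[t + 1] + lam * last)
        returns[t] = last
    return returns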

Hyper-Parameter description:

  • train_every: number of frames to skip while training.
  • collect_intervals: number of batches sampled from the buffer at every "train_every" iteration.
  • seq_len: length of trajectory sequence to be sampled from buffer.
  • embedding_size: size of embedding vector that is output by observation encoder.
  • rssm_type: categorical or gaussian random variables for stochastic states.
  • rssm_node_size: size of hidden layers of temporal posteriors and priors.
  • deter_size: size of deterministic part of recurrent state.
  • stoch_size: size of stochastic part of recurrent state.
  • class_size: number of classes for each categorical random variable.
  • category_size: number of categorical random variables.
  • horizon: horizon for imagination in future latent state space.
  • kl_balance_scale: scale for kl balancing.
  • actor_entropy_scale: scale for policy entropy regularization in latent state space.
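For intuition on the categorical case: the stochastic state is built from category_size independent categorical variables with class_size classes each, so its flattened size is category_size * class_size. A minimal sketch of sampling such a state with the straight-through gradient trick used in DreamerV2 (shapes and names here are illustrative, not this repo's exact API):

import torch

category_size, class_size = 20, 20
# Illustrative logits as produced by the RSSM prior/posterior networks.
logits = torch.randn(1, category_size, class_size, requires_grad=True)

dist = torch.distributions.OneHotCategorical(logits=logits)
sample = dist.sample()                              # one-hot sample, no gradient
sample = sample + dist.probs - dist.probs.detach()  # straight-through gradient estimator
stoch_state = sample.reshape(1, category_size * class_size)  # flattened stochastic state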

Acknowledgments

Awesome Environments used for testing:

This code is heavily inspired by the following works:

