
mava's Introduction

Mava logo

Distributed Multi-Agent Reinforcement Learning in JAX

Welcome to Mava! 🦁

Mava provides simplified code for quickly iterating on ideas in multi-agent reinforcement learning (MARL), with useful implementations of MARL algorithms in JAX that allow for easy parallelisation across devices via JAX's pmap. Mava is a project originating in the Research Team at InstaDeep.

To join us in these efforts, please feel free to reach out, raise issues or read our contribution guidelines (or just star 🌟 to stay up to date with the latest developments)!

Overview 🦜

Mava currently offers the following building blocks for MARL research:

  • 🥑 Implementations of MARL algorithms: Implementations of multi-agent PPO systems that follow both the Centralised Training with Decentralised Execution (CTDE) and Decentralised Training with Decentralised Execution (DTDE) MARL paradigms.
  • 🍬 Environment Wrappers: Example wrappers for mapping Jumanji environments to an environment compatible with Mava. At the moment, we support Robotic Warehouse and Level-Based Foraging, with plans to support more environments soon. We have also recently added support for the SMAX environment from JaxMARL.
  • 🎓 Educational Material: A Quickstart notebook to demonstrate how Mava can be used and to highlight the added value of JAX-based MARL.
  • 🧪 Statistically robust evaluation: Mava natively supports logging to JSON files which adhere to the standard suggested by Gorsane et al. (2022). This enables easy downstream experiment plotting and aggregation using the tools found in the MARL-eval library (see the illustrative sketch below).
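
As a purely illustrative sketch of the kind of nested layout this standard suggests (environment → task → algorithm → run → logging step), hypothetical Python that writes such a file might look as follows; the exact key names should be checked against the MARL-eval documentation:

import json

# Hypothetical metrics dictionary following the nested layout used by
# MARL-eval: environment -> task -> algorithm -> run -> logging step.
logs = {
    "rware": {
        "tiny-2ag": {
            "ff_mappo": {
                "run_0": {
                    "step_0": {"step_count": 0, "episode_return": [0.0, 0.0]},
                    "step_1": {"step_count": 100_000, "episode_return": [1.5, 2.0]},
                }
            }
        }
    }
}

with open("metrics.json", "w") as f:
    json.dump(logs, f, indent=2)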

Performance and Speed 🚀

SMAX

To compare Mava's stability with that of other JAX-based baseline algorithms, we train Mava's recurrent IPPO and MAPPO systems on a broad range of SMAX tasks. In all cases we do not rerun the baselines, but instead take the final win rates from the JaxMARL technical report. For the full SMAX experiment results, please see the following page.

Mava recurrent IPPO and MAPPO performance on the 3s5z, 6h_vs_8z and 3s5z_vs_3s6z SMAX tasks.

Robotic Warehouse

All of the experiments below were performed using an NVIDIA Quadro RTX 4000 GPU with 8GB of memory.

In order to show the utility of end-to-end JAX-based MARL systems and JAX-based environments, we compare the speed of Mava against EPyMARL, measured in total training wallclock time on simple Robotic Warehouse (RWARE) tasks with 2 and 4 agents. Our aim is to illustrate the speed increases possible with end-to-end JAX-based systems; we do not necessarily make an effort to achieve optimal performance. For EPyMARL, we use the hyperparameters recommended by Papoudakis et al. (2020), and for Mava we performed a basic grid search. In both cases, systems were trained for up to 20 million total environment steps using 16 vectorised environments.

Mava feedforward MAPPO performance on the tiny-2ag, tiny-4ag and small-4ag RWARE tasks.

📌 An important note on the differences in converged performance

In order to benefit from the wallclock speed-ups afforded by JAX-based systems, the environments themselves must also be written in JAX. For this reason, Mava does not use the exact same version of the RWARE environment as EPyMARL, but instead uses the JAX-based implementation of RWARE found in Jumanji, under the name RobotWarehouse. One notable difference in the underlying environment logic is that RobotWarehouse does not attempt to resolve agent collisions, but instead terminates an episode when agents collide. In our experiments, this appeared to make the environment more challenging. For this reason, we show the performance of Mava on Jumanji both with and without termination upon collision, the latter indicated by w/o collision in the figure legends. For a more detailed discussion, please see the following page.

Level-Based Foraging

Mava also supports Jumanji's LBF. We evaluate Mava's recurrent MAPPO system on LBF against EPyMARL (which uses the original LBF environment) in 2- and 4-agent settings, for up to 20 million timesteps. Both systems were trained using 16 vectorised environments. For the EPyMARL systems we use an NVIDIA A100 GPU, and for the Mava systems we use a GeForce RTX 3050 laptop GPU with 4GB of memory. To show how Mava can generalise to different hardware, we also train the Mava systems on a TPU v3-8. We plan to publish comprehensive performance benchmarks for all of Mava's algorithms across various LBF scenarios soon.

Mava recurrent MAPPO performance on the 2s-8x8-2p-2f-coop and 15x15-4p-3f Level-Based Foraging tasks.

🧨 Steps per second experiments using vectorised environments

Furthermore, we illustrate the speed of Mava by showing how steps per second scale as the number of parallel environments is increased. These scaling plots were computed on a standard laptop GPU, specifically an RTX 3060 with 6GB of memory.

Mava steps per second scaling with increasing numbers of vectorised environments, and total training run time for 20M environment steps.

Code Philosophy 🧘

The current code in Mava is adapted from PureJaxRL, which provides high-quality single-file implementations with research-friendly features. In turn, PureJaxRL is inspired by the code philosophy of CleanRL. In this vein of easy-to-use and understandable RL codebases, Mava is not designed to be a modular library and is not meant to be imported. Our repository focuses on simplicity and clarity in its implementations while utilising the advantages offered by JAX, such as pmap and vmap, making it an excellent resource for researchers and practitioners to build upon.
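
To make the pmap and vmap point concrete, here is a minimal illustrative sketch (not Mava's actual training code) of vectorising a toy update over environments with vmap and then replicating it across devices with pmap:

import jax
import jax.numpy as jnp

# A toy per-environment "update"; in a real system this would be a full
# learner step. Shapes and logic here are purely illustrative.
def update(params, obs):
    return params + 0.01 * jnp.tanh(obs).mean()

# vmap vectorises the update over a batch of environments on one device...
batched_update = jax.vmap(update, in_axes=(None, 0))

# ...and pmap replicates that batched computation across all devices.
distributed_update = jax.pmap(batched_update, in_axes=(None, 0))

n_devices = jax.device_count()
params = jnp.zeros(())                        # shared parameters
obs = jnp.ones((n_devices, 16, 4))            # [devices, envs, obs_dim]
new_params = distributed_update(params, obs)  # shape: [devices, 16]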

Installation 🎬

At the moment Mava is not meant to be installed as a library, but rather to be used as a research tool.

You can use Mava by cloning the repo and pip installing as follows:

git clone https://github.com/instadeepai/mava.git
cd mava
pip install -e .

We have tested Mava on Python 3.9. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the official installation guide and the example below). For more in-depth installation guides, including Docker builds and virtual environments, please see our detailed installation guide.
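
For example, at the time of writing, the official guide suggests commands along the following lines; these change between JAX releases, so always defer to the guide itself:

# CPU-only installation
pip install -U jax

# NVIDIA GPU (CUDA 12) installation
pip install -U "jax[cuda12]"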

Quickstart ⚡

To get started with training your first Mava system, simply run one of the system files, e.g.:

python mava/systems/ff_ippo.py

Mava makes use of Hydra for config management. Our default system configs can be found in the mava/configs/ directory. A benefit of Hydra is that configs can either be set in config yaml files or overridden from the terminal on the fly. For an example of running a system on the LBF environment, the above command can simply be adapted as follows:

python mava/systems/ff_ippo.py env=lbf

Different scenarios can also be run by making the following config updates from the terminal:

python mava/systems/ff_ippo.py env=rware env/scenario=tiny-4ag
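
Because this is standard Hydra syntax, Hydra's multirun flag can also be used to sweep over several scenarios in a single command. A hypothetical example, assuming the same config groups shown above:

python mava/systems/ff_ippo.py -m env=rware env/scenario=tiny-2ag,tiny-4ag,small-4ag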

We also have a Quickstart notebook that can be used to quickly create and train your first multi-agent system.

Advanced Usage 👽

Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a Flashbax Vault. This vault can then easily be integrated into offline MARL systems, such as those found in OG-MARL. See the Advanced README for more information.
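
As a rough sketch of what recording into a vault might look like (the API names here are assumptions based on the Flashbax documentation at the time of writing, so please defer to the Advanced README and the Flashbax docs):

import jax.numpy as jnp
import flashbax as fbx
from flashbax.vault import Vault  # assumed import path

# A dummy transition pytree; a real system would store the full
# (observation, action, reward, done) structure for every agent.
transition = {"obs": jnp.zeros(4), "reward": jnp.zeros(())}

# A simple flat buffer to accumulate experience online.
buffer = fbx.make_flat_buffer(
    max_length=100_000, min_length=2, sample_batch_size=32
)
state = buffer.init(transition)
state = buffer.add(state, transition)

# The vault persists the buffer's experience structure to disk so that
# offline MARL systems (e.g. OG-MARL) can read it back later.
vault = Vault(
    vault_name="ff_ippo_experience",  # hypothetical name
    experience_structure=state.experience,
)
vault.write(state)  # periodically flush buffered experience to disk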

Contributing 🤝

Please read our contributing docs for details on how to submit pull requests, our Contributor License Agreement and community guidelines.

Roadmap 🛤️

We plan to iteratively expand Mava in the following increments:

  • 🌴 Support for more environments.
  • 🔁 More robust recurrent systems.
  • 🌳 Support for non-JAX-based environments.
  • 🦾 Support for off-policy algorithms.
  • 🎛️ Continuous action space environments and algorithms.

Please do follow along as we develop this next phase!

TensorFlow 2 Mava:

Originally, Mava was written in TensorFlow 2. Support for the TF2-based framework and systems has now been fully deprecated. If you would still like to use it, please install v0.1.3 of Mava (i.e. pip install id-mava==0.1.3).

See Also 🔎

The following projects form part of InstaDeep's MARL ecosystem in JAX. In particular, we suggest users check out the following sister repositories:

  • 🔌 OG-MARL: datasets with baselines for offline MARL in JAX.
  • 🌴 Jumanji: a diverse suite of scalable reinforcement learning environments in JAX.
  • 😎 Matrax: a collection of matrix games in JAX.
  • ⚡ Flashbax: accelerated replay buffers in JAX.
  • 📈 MARL-eval: standardised experiment data aggregation and visualisation for MARL.

Related: other libraries related to accelerated MARL in JAX.

  • 🦊 JaxMARL: accelerated MARL environments with baselines in JAX.
  • 🌀 DeepMind Anakin: the Anakin podracer architecture for training RL agents at scale.
  • ♟️ Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • 🔼 Minimax: JAX implementations of autocurricula baselines for RL.

Citing Mava 📚

If you use Mava in your work, please cite the accompanying technical report:

@article{dekock2023mava,
    title={Mava: a research library for distributed multi-agent reinforcement learning in JAX},
    author={Ruan de Kock and Omayma Mahjoub and Sasha Abramowitz and Wiem Khlifi and Callum Rhys Tilbury
    and Claude Formanek and Andries P. Smit and Arnu Pretorius},
    year={2023},
    journal={arXiv preprint arXiv:2107.01460},
    url={https://arxiv.org/pdf/2107.01460.pdf},
}

Acknowledgements 🙏

We would like to thank all the authors who contributed to the previous TF version of Mava: Kale-ab Tessera, St John Grimbly, Kevin Eloff, Siphelele Danisa, Lawrence Francis, Jonathan Shock, Herman Kamper, Willie Brink, Herman Engelbrecht, Alexandre Laterre, Karim Beguir. Their contributions can be found in our TF technical report.

The development of Mava was supported with Cloud TPUs from Google's TPU Research Cloud (TRC) 🌤️.


mava's Issues

Implement checkpointing

This will allow for periodic saving of the system networks and loading them again to resume training.
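
A minimal sketch of how such periodic checkpointing could be done with Orbax (not an existing Mava feature; the path and parameter names are illustrative):

import jax.numpy as jnp
import orbax.checkpoint as ocp

# Toy network parameters; a real system would checkpoint the full
# learner state (params, optimiser state, step count, etc.).
params = {"actor": jnp.zeros(4), "critic": jnp.zeros(4)}

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save("/tmp/mava_checkpoint", params)        # periodic save
restored = checkpointer.restore("/tmp/mava_checkpoint")  # resume training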

Fix _transform_observations being called per agent

Try to resolve the problem of _transform_observations being called for each agent even though the calculation is identical for all of them; it currently has its own loop over all agents. Also, try to do a batched update of all networks instead of the current sequential updates. This mostly concerns the networks shared between agents, which are updated sequentially; sequential updates might introduce a problem where agent order determines each agent's effect on the shared network weights, which we do not want.

Implement additional logging metrics

Metrics to track during training:

mean/std/min/max for each of the following:

  • cumulative rewards
  • episode length
  • value function estimates
  • losses for the objectives
  • exploration parameters (such as mean entropy for stochastic policy optimisation, or the current epsilon for epsilon-greedy exploration as in DQN)

General MARL env loop

This is connected to implementing additional logging metrics (#27). If we have one general MARL environment loop, we will only have to implement the metric-logging function once, and all the other environment loops can inherit it. A similar argument applies to other functions associated with the environment loop that can be shared across different environments.

Fix training error

The agents are not learning anymore. Investigate why that is and fix it.

Fix memory leak issue

The RAM usage keeps increasing as training progresses. This might be due to a memory leak.

Implement observation and reward scaling wrappers

Best practice advice:

  • Make sure everything is reasonably scaled.

Rule of thumb:

  • Observations: Make everything mean 0, standard deviation 1.
  • Reward: If you control it, then scale it to a reasonable value.
  • Do this across ALL of your data so far.
  • Look at all observations and rewards and make sure there aren't any crazy outliers.
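
As a minimal sketch of the observation side of this rule of thumb (standalone, hypothetical code rather than an existing Mava wrapper), a running-statistics normaliser could look like this:

import numpy as np

class RunningObsNormalizer:
    """Normalises observations to roughly mean 0, std 1 using running
    statistics over all data seen so far (Welford's algorithm)."""

    def __init__(self, obs_dim: int, eps: float = 1e-8):
        self.count = 0
        self.mean = np.zeros(obs_dim)
        self.m2 = np.zeros(obs_dim)  # running sum of squared deviations
        self.eps = eps

    def update(self, obs: np.ndarray) -> None:
        self.count += 1
        delta = obs - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (obs - self.mean)

    def normalize(self, obs: np.ndarray) -> np.ndarray:
        var = self.m2 / max(self.count - 1, 1)
        return (obs - self.mean) / np.sqrt(var + self.eps)

# Usage: update with each observation seen so far, then normalise.
norm = RunningObsNormalizer(obs_dim=4)
for obs in np.random.randn(1000, 4) * 5.0 + 3.0:
    norm.update(obs)
print(norm.normalize(np.array([3.0, 3.0, 3.0, 3.0])))  # roughly zero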
