
pymarl_transformers's Introduction


‼️ TransfQMix is now officially supported in JaxMARL

TransfQMix

Official repository of the AAMAS 2023 paper: TransfQMix: Transformers for Leveraging the Graph Structure of Multi-Agent Reinforcement Learning Problems. The codebase is built on top of PyMARL.

Usage

With docker

The repository provides a Dockerfile to containerize the execution of the code with GPU support (recommended for transformer models). To build the image, run the standard command:

sudo docker build . -t pymarl

You can then run any of the available models with the run.sh script: bash run.sh. Change the last line of the script to choose your configuration files. For example:

# run StarCraft2 experiment
python3 src/main.py --config=transf_qmix_smac --env-config=sc2
# run Spread experiment
python3 src/main.py --config=transf_qmix --env-config=mpe/spread

Remember that you need the NVIDIA Container Toolkit (nvidia-container-toolkit) to use your GPU with Docker.

With python

If you want to run the codebase without Docker, you can install the dependencies in requirements.txt in a Python 3.8 virtual environment (conda, pipenv).

You will also need to install StarCraft II on your computer: on Linux, you can use the bash install_sc2.sh script. You will also need SMAC: pip install git+https://github.com/oxwhirl/smac.git.

Finally, install the PyTorch version that is most suitable for your system. For example (for GPU support with CUDA 11.6): pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116.

Then run the python commands as above.


Use transformers in new environments

In order to use TransfQMix (or just the transformer agent or mixer) in other environments, at every time step you will need to return observation and state vectors that can be reshaped as matrices. In particular, the shape of the matrices will be $(k, z)$, where $k$ is the number of entities and $z$ the number of entity features. The observations and states are therefore flattened vectors of dimension $k \times z$. For simplicity we can assume that $k$ and $z$ are the same for agents and mixer, but this might not be the case (use n_entities_obs and n_entities_state to differentiate $k$ for the agent and the mixer; to differentiate $z$, use obs_entity_feats and state_entity_feats). A short sketch of this reshaping convention follows the list below.

  • The definition of an entity depends on the environment. In SC2, the entities are the allies and enemies; in Spread, the agents and landmarks. But an entity can in principle be any information channel (e.g., different sensors, communication channels, and so on).

  • The entity features $z$ are the features that describe every entity. Sparse matrices are not allowed, i.e., all the features should be used to describe every entity. If it doesn't make sense to use some of them for some entities, pad them with 0. Check the paper for additional information about $k$ and $z$.
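As a concrete illustration of this convention, here is a minimal sketch of the reshaping (the variable names are hypothetical; TransfQMix performs this reshape internally):

import torch

k, z = 5, 8                          # number of entities, features per entity
flat_obs = torch.randn(k * z)        # flattened observation returned by the env
entity_matrix = flat_obs.view(k, z)  # the (k, z) matrix consumed by the transformer

# the same convention with a batch dimension: (batch, k*z) -> (batch, k, z)
batch_obs = torch.randn(32, k * z)
batch_matrix = batch_obs.view(32, k, z)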

In the case of SC2 and Spread, we included two new environment parameters, obs_entity_mode and state_entity_mode, that let you choose how to return the observation and state vectors at every time step. If they are set to True, the environment is expected to return flattened matrices that will be reshaped back into matrices internally by TransfQMix. If obs_entity_mode or state_entity_mode is set to False, the original observation or state vector is returned.

We encourage you to follow the same approach, i.e., include a parameter in your environment that lets you choose whether to use the entity mode or not, as in the sketch below.
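For instance, a sketch of how such a parameter might be used inside an environment wrapper's observation method (the helper names are hypothetical, not taken from this codebase):

def get_obs_agent(self, agent_id):
    """Return the observation of agent_id as a flat vector."""
    if self.obs_entity_mode:
        # build the (k, z) entity matrix and flatten it;
        # TransfQMix reshapes it back into a matrix internally
        entity_matrix = self._build_entity_matrix(agent_id)  # hypothetical helper
        return entity_matrix.flatten()
    # entity mode disabled: return the original observation vector
    return self._standard_obs(agent_id)  # hypothetical helper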

Keep in mind:

  1. In the init method of your environment wrapper, you should define additional attributes with respect to traditional PyMARL (please check MultiAgentEnv); see the sketch after this list:
    • self.obs_entity_feats: number of features that define the entities in the obs matrix
    • self.state_entity_feats: number of features that define the entities in the state matrix
    • self.n_entities: number of fixed entities observed by agents and mixer
    • (optional) self.n_entities_obs: number of entities observed by agents, if different from n_entities
    • (optional) self.n_entities_state: number of entities observed by the mixer, if different from n_entities
  2. You can define different features for the entity observation matrix and for the entity state matrix. Change obs_entity_feats and state_entity_feats accordingly.
  3. The number of entities is assumed to be invariant during an episode. If an entity dies or is not observable, set all of its features to 0 (this can be improved).
  4. The order of the entities in the flattened vectors is important:
    • For the agent, only if you're using policy decoupling: in this case you need to keep track of the positions of the entities that have entity-based actions (for example, the enemies in SC2, which are the targets of the attack actions). This is because you will need to extract their specific embeddings in order to sample the entity-based actions from them. See the SC2 agent for how this is done in StarCraft 2.
    • For the mixer: the first entity features must correspond to the agents and must follow the same order as the agents' Q-values. The codebase always puts the agents in the same order, so this should not be a problem.
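Putting the attributes of point 1 together, here is a minimal sketch of an environment wrapper's init (the attribute names come from this section; the class name, values, and import path are illustrative and follow PyMARL's layout):

from envs.multiagentenv import MultiAgentEnv  # PyMARL's base environment class

class MyEntityEnv(MultiAgentEnv):
    def __init__(self, obs_entity_mode=True, state_entity_mode=True, **kwargs):
        self.n_agents = 3
        self.n_entities = 6           # e.g., 3 agents + 3 landmarks
        self.obs_entity_feats = 8     # z for the observation matrix
        self.state_entity_feats = 10  # z for the state matrix
        # optional, only if k differs between agents and mixer:
        # self.n_entities_obs = 5
        # self.n_entities_state = 6
        self.obs_entity_mode = obs_entity_mode
        self.state_entity_mode = state_entity_mode

    def get_obs_size(self):
        return self.n_entities * self.obs_entity_feats  # flattened k * z

    def get_state_size(self):
        return self.n_entities * self.state_entity_feats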

Extra

  1. The repository supports parallelization with the parallel_runner for most of the models (i.e., multiple environment instances running in parallel for each experiment), and TransfQMix is also ready to be used in parallel environments. The performance of models trained with different numbers of processes is not comparable. By default, an experiment runs a single environment in a single process.
  2. If you're using policy decoupling in a new environment, it is recommended to add a new environment-specific transformer agent in order to manage the output layers (see the sketch after this list). You can use the SC2 agent as a base. You only need to change the output layers and the way in which the entity-based-action embeddings are extracted from the output of the transformer, and add your agent to the agent registry. In a future version this could be parametrized, but for now this is the easiest way to go. If you're not using policy decoupling, you can use the standard transformer agent.
  3. This codebase includes a matplotlib-based and a plotly-based animation class for the MPE environment, which allow you to generate customized GIFs at the end of an episode. You can take inspiration from them to generate animations of your environment in a simpler way than using gym. Here is an example for 6v6 Spread:

[spread_5v5 animation]
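For the agent registry mentioned in point 2 above, registering a custom agent follows PyMARL's standard registry pattern; here is a minimal sketch (the agent name and module are hypothetical):

# in src/modules/agents/__init__.py (PyMARL's agent registry)
from .my_transformer_agent import MyTransformerAgent  # hypothetical module

REGISTRY["my_transformer"] = MyTransformerAgent

# then select it in your algorithm config, e.g.: agent: "my_transformer"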

Citation

If you use this codebase please cite:

@inproceedings{10.5555/3545946.3598825,
author = {Gallici, Matteo and Martin, Mario and Masmitja, Ivan},
title = {TransfQMix: Transformers for Leveraging the Graph Structure of Multi-Agent Reinforcement Learning Problems},
year = {2023},
publisher = {International Foundation for Autonomous Agents and Multiagent Systems},
address = {Richland, SC},
booktitle = {Proceedings of the 2023 International Conference on Autonomous Agents and Multiagent Systems},
pages = {1679–1687},
location = {London, United Kingdom},
series = {AAMAS '23}
}


pymarl_transformers's Issues

Questions about Graph Observations

Hello, I successfully combined TransfQMix with my own environment and achieved very good results. But I have a question: in my environment, graph-based observations require the agents to obtain global information. Doesn't this conflict with the CTDE training architecture? Also, the observation dimension grows linearly with the number of agents, making training difficult.

Transfer learning in practice: Size mismatch error

Hi!

Thanks a lot for this very interesting line of work. I am particularly curious about the zero-shot transfer learning performance shown in your paper.
However, I don't understand how to "apply the networks trained in a particular task to the others" (as described in the paper) when the number of agents varies between tasks.
For instance, I naïvely tried to load the checkpoint of a model trained on the "3m" SMAC map and apply it to the "8m" one, but as I expected I got a size mismatch error when loading the torch weights and biases:

RuntimeError: Error(s) in loading state_dict for TransformerAgent:
    size mismatch for q_basic.weight: copying a param with shape torch.Size([9, 32]) from checkpoint, the shape in current model is torch.Size([14, 32]).
    size mismatch for q_basic.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([14]).

Could you help me with this problem?
Many thanks for considering my request.

Queries about the transferability of the algorithm

Hi, the code in the paper adds a Transformer and implements transferability in both the agent network and the mixing network. I have an idea: use only the transformer mixing network while keeping rnn_agent for the agent network, with each agent observing only its own information rather than a graph-structured observation. Would this still achieve transferability? I would still use the graph-structured state, which contains information about all vertices.
I hope to get your answer if you have time, thank you very much!
