
la-mbda's Introduction

LAMBDA

A repository for the implementation of the paper Constrained Policy Optimization via Bayesian World Models (Yarden As, Ilnura Usmanova, Sebastian Curi, Andreas Krause, ICLR 2022). Please see our paper (arXiv) for further details. To cite, please use:

@article{as2022constrained,
  title={Constrained Policy Optimization via Bayesian World Models},
  author={As, Yarden and Usmanova, Ilnura and Curi, Sebastian and Krause, Andreas},
  journal={arXiv preprint arXiv:2201.09802},
  year={2022}
}

Idea

By taking a Bayesian perspective, LAMBDA learns a world model and generates sequences using different posterior samples of the world model parameters. It then selects the optimistic sequences to learn how to solve the task and the pessimistic sequences to learn how to adhere to safety restrictions.
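
As a rough illustration of the selection step described above, the following sketch picks, among imagined sequences generated under different posterior samples, the one with the highest task return (optimistic) and the one with the highest accumulated cost (pessimistic). The `Rollout` type, its fields, the reading of "pessimistic" as "highest cost", and the example numbers are assumptions made for this sketch; they are not taken from the repository, whose implementation works on model-generated tensors.

```python
from typing import List, NamedTuple, Tuple


class Rollout(NamedTuple):
    """Summary of one imagined sequence (hypothetical structure)."""
    posterior_sample: int  # index of the sampled world model parameters
    task_return: float     # accumulated reward along the sequence
    cost_return: float     # accumulated safety cost along the sequence


def select_bounds(rollouts: List[Rollout]) -> Tuple[Rollout, Rollout]:
    """Return the (optimistic, pessimistic) rollouts.

    Optimistic: highest task return, used to learn how to solve the task.
    Pessimistic: highest cost (an assumption of this sketch), used to
    learn how to respect the safety restrictions.
    """
    if not rollouts:
        raise ValueError("need at least one rollout")
    optimistic = max(rollouts, key=lambda r: r.task_return)
    pessimistic = max(rollouts, key=lambda r: r.cost_return)
    return optimistic, pessimistic


# Made-up numbers, one rollout per posterior sample.
rollouts = [
    Rollout(posterior_sample=0, task_return=12.0, cost_return=3.0),
    Rollout(posterior_sample=1, task_return=15.5, cost_return=1.0),
    Rollout(posterior_sample=2, task_return=9.0, cost_return=7.5),
]
optimistic, pessimistic = select_bounds(rollouts)
print("optimistic sample:", optimistic.posterior_sample)
print("pessimistic sample:", pessimistic.posterior_sample)
```

The two selections are independent, so the task objective and the safety constraint can be driven by different posterior samples.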

Running and plotting

Install dependencies (this may take more than an hour):

conda create -n lambda python=3.6
conda activate lambda
pip3 install .

Run experiments:

python3 experiments/train.py --log_dir results/la_mbda/point_goal2/314 --environment sgym_Safexp-PointGoal2-v0 --total_training_steps 1000000 --safety

Plot:

python3 experiments/plot.py --data_path results/

where the script expects the following directory tree structure:

results
├── algo1
│   ├── environment1
│   │   ├── experiment1
│   │   └── ...
│   └── ...
└── algo2
    └── environment1
        ├── experiment1
        └── experiment2
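
To check a results directory against this layout before plotting, a small sketch such as the following can list every algorithm/environment/experiment triple it finds. `list_experiments` is a hypothetical helper and not part of the repository; it only walks directories and makes no assumption about the files inside each experiment folder.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def list_experiments(data_path: str) -> List[Tuple[str, str, str]]:
    """Collect (algorithm, environment, experiment) directory triples."""
    root = Path(data_path)
    triples = []
    for algo in sorted(p for p in root.iterdir() if p.is_dir()):
        for env in sorted(p for p in algo.iterdir() if p.is_dir()):
            for exp in sorted(p for p in env.iterdir() if p.is_dir()):
                triples.append((algo.name, env.name, exp.name))
    return triples


# Build a throwaway tree that mirrors the structure shown above.
with tempfile.TemporaryDirectory() as tmp:
    for parts in [
        ("algo1", "environment1", "experiment1"),
        ("algo2", "environment1", "experiment1"),
        ("algo2", "environment1", "experiment2"),
    ]:
        Path(tmp).joinpath(*parts).mkdir(parents=True)
    found = list_experiments(tmp)

for algo, env, exp in found:
    print(algo, env, exp)
```

Directories that sit fewer than three levels below the data path produce no triple, which makes a misplaced log directory easy to spot.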

Acknowledgement

This repository builds on the Dreamer codebase, which served as its starting point.

la-mbda's People

Contributors

yardenas

la-mbda's Issues

IndexError in plot.py

Dear Mr. As,
I would like to use your safe RL agent for my seminar project, but I ran into an issue with the example provided in your README.md.
This could very well be my fault, but I would be very happy if you could help me out!

Setup

I installed the necessary dependencies in an anaconda environment as described, but I had to install mujoco200 and dm_hack (https://github.com/yardenas/safety-gym/tarball/dm_hack) manually.

Problem

I get the following error when executing python3 experiments/plot.py --data_path results/:

Traceback (most recent call last):
  File "experiments/plot.py", line 287, in <module>
    summarize_experiments(args)
  File "experiments/plot.py", line 234, in summarize_experiments
    experiment_statistics['objectives_median'][-1],
IndexError: index -1 is out of bounds for axis 0 with size 0

I also get the following warning:

Not all metrics are available!

Steps to reproduce

  • Run the given example:
python3 experiments/train.py --log_dir results/point_goal2/314 --environment sgym_Safexp-PointGoal2-v0 --total_training_steps 1000000 --safety
  • Run plot.py as provided:
python3 experiments/plot.py --data_path results/
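
One thing that may be worth checking: the --log_dir in the steps above (results/point_goal2/314) has one directory level fewer than the README example (results/la_mbda/point_goal2/314) and than the algorithm/environment/experiment tree that plot.py is documented to expect. Whether this causes the IndexError is not confirmed here. The sketch below only compares the two paths; `levels_below` is a hypothetical helper, not part of the repository.

```python
from pathlib import PurePosixPath


def levels_below(log_dir: str, data_path: str = "results") -> int:
    """Count the directory levels between data_path and log_dir."""
    return len(PurePosixPath(log_dir).relative_to(data_path).parts)


readme_log_dir = "results/la_mbda/point_goal2/314"
issue_log_dir = "results/point_goal2/314"

for label, log_dir in [("README", readme_log_dir), ("issue", issue_log_dir)]:
    print(label, log_dir, "levels below results:", levels_below(log_dir))
```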

Running Error!

Hi,
I'm trying to reproduce your results. However, I got the following error when I ran the exact same command as yours.


Traceback (most recent call last):
  File "experiments/train.py", line 7, in <module>
    train_utils.train(config, LAMBDA)
  File "/root/lambda/la-mbda/experiments/train_utils.py", line 139, in train
    on_episode_end=lambda episode_summary, steps_count: on_episode_end(
  File "/root/lambda/la-mbda/la_mbda/utils.py", line 101, in interact
    pbar, len(episodes) < config.render_episodes and not training)
  File "/root/lambda/la-mbda/la_mbda/utils.py", line 68, in do_episode
    action = agent(observation, training)
  File "/root/lambda/la-mbda/la_mbda/la_mbda.py", line 39, in __call__
    self.pretrain_model()
  File "/root/lambda/la-mbda/la_mbda/la_mbda.py", line 85, in pretrain_model
    self.model.train(batch)
  File "/root/lambda/la-mbda/la_mbda/bayesian_world_model.py", line 26, in train
    posterior_beliefs = self._training_step(batch)
  File "/opt/conda/envs/lambda/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/opt/conda/envs/lambda/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/opt/conda/envs/lambda/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
  File "/opt/conda/envs/lambda/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/opt/conda/envs/lambda/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/opt/conda/envs/lambda/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.NotFoundError: 2 root error(s) found.
  (0) Not found: No algorithm worked!
     [[node sequential_1/conv2d/Conv2D (defined at /root/lambda/la-mbda/la_mbda/building_blocks.py:62) ]]
     [[SWAG/SWAG/update_7/StatefulPartitionedCall/cond_1/pivot_t/_2025/_295]]
  (1) Not found: No algorithm worked!
     [[node sequential_1/conv2d/Conv2D (defined at /root/lambda/la-mbda/la_mbda/building_blocks.py:62) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__training_step_35978]

Errors may have originated from an input operation.
Input Source operations connected to node sequential_1/conv2d/Conv2D:
Reshape (defined at /root/lambda/la-mbda/la_mbda/building_blocks.py:61)

Input Source operations connected to node sequential_1/conv2d/Conv2D:
Reshape (defined at /root/lambda/la-mbda/la_mbda/building_blocks.py:61)

Function call stack:
_training_step -> _training_step

20%|################################## | 5000/25000 [01:25<05:42, 58.36it/s]


Any ideas? Thanks a lot!!!

About the observation.

Congratulations! I think LAMBDA is pretty good work!
I wonder if you have tested LAMBDA in tasks with sensor-based observations, instead of the pixel-based ones used in the paper.

An error about GLFW

Hello, dear authors of LAMBDA!
We ran into a strange problem after finishing pip3 install . and running the command on our server as you suggested:
python3 experiments/train.py --log_dir results/la_mbda/point_goal2/314 --environment sgym_Safexp-PointGoal2-v0 --total_training_steps 1000000 --safety

A GLFWError is raised, as follows:
GLFWError: (65544) b'X11: Failed to open display :1'

Then the tqdm progress bar started and stopped abruptly at 0%. After that, another GLFW error appeared:
GLFWError: (65537) b'The GLFW library is not initialized'
The whole process then shut down.

We are running the code on a server without any display device, so we wonder whether the problem comes from the lack of a display, since GLFW is closely tied to 3D graphics. We would also appreciate it if you could tell us how to run the code on a server without a display device.
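
For a machine without a display, one commonly used approach with dm_control (which the dm_hack Safety Gym fork appears to use for simulation) is to select an off-screen rendering backend through the MUJOCO_GL environment variable before the simulator is imported. The sketch below sets it explicitly instead of inspecting DISPLAY, because DISPLAY can be set (as with :1 in the error above) while no X server is reachable. `set_mujoco_gl` is a hypothetical helper, and whether this resolves the error in this repository has not been verified here.

```python
import os

# Backends that dm_control documents for the MUJOCO_GL variable.
VALID_BACKENDS = ("glfw", "egl", "osmesa")


def set_mujoco_gl(backend: str) -> str:
    """Select the rendering backend for later dm_control imports."""
    if backend not in VALID_BACKENDS:
        raise ValueError(
            "unknown backend {!r}, expected one of {}".format(backend, VALID_BACKENDS)
        )
    os.environ["MUJOCO_GL"] = backend
    return os.environ["MUJOCO_GL"]


# Call this at the very top of the entry script, before any module that
# imports dm_control is loaded. "egl" needs a GPU driver with EGL support;
# "osmesa" is a software-rendering alternative.
selected = set_mujoco_gl("egl")
print("MUJOCO_GL =", selected)
```

The same effect can be had from the shell by prefixing the training command with MUJOCO_GL=egl, which avoids editing any code.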

Issue with mujoco210 build

Hi! Since MuJoCo200 is no longer available, I'm trying to build the environment using MuJoCo210 and dm-control 1.0.1. However, I ran into the following problem. Is there any way to get past this? I have attached my environment packages and error report below.

Traceback (most recent call last):
  File "experiments/train.py", line 7, in <module>
    train_utils.train(config, LAMBDA)
  File "/home/simonzhan/Desktop/Projects/la-mbda/experiments/train_utils.py", line 137, in train
    training_steps, training_episodes_summaries = utils.interact(
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/la_mbda/utils.py", line 101, in interact
    do_episode(agent, training,
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/la_mbda/utils.py", line 63, in do_episode
    observation = environment.reset() if reset_function is None else reset_function()
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/gym/core.py", line 311, in reset
    return self.observation(self.env.reset(**kwargs))
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/gym/core.py", line 337, in reset
    return self.env.reset(**kwargs)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/gym/core.py", line 283, in reset
    return self.env.reset(**kwargs)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/gym/wrappers/time_limit.py", line 26, in reset
    return self.env.reset(**kwargs)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/gym/wrappers/order_enforcing.py", line 18, in reset
    return self.env.reset(**kwargs)
  File "/home/simonzhan/Desktop/Projects/la-mbda/safety-gym-dm_hack/safety_gym/envs/engine.py", line 897, in reset
    self.build()
  File "/home/simonzhan/Desktop/Projects/la-mbda/safety-gym-dm_hack/safety_gym/envs/engine.py", line 871, in build
    self.world.reset()
  File "/home/simonzhan/Desktop/Projects/la-mbda/safety-gym-dm_hack/safety_gym/envs/world.py", line 317, in reset
    self.build()
  File "/home/simonzhan/Desktop/Projects/la-mbda/safety-gym-dm_hack/safety_gym/envs/world.py", line 292, in build
    self.sim = dm_control.mujoco.Physics.from_xml_string(self.xml_string)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/dm_control/mujoco/engine.py", line 424, in from_xml_string
    return cls.from_model(model)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/dm_control/mujoco/engine.py", line 407, in from_model
    return cls(data)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/dm_control/mujoco/engine.py", line 122, in __init__
    self._reload_from_data(data)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/dm_control/mujoco/engine.py", line 389, in _reload_from_data
    data=index.struct_indexer(self.data, 'mjdata', axis_indexers),)
  File "/home/simonzhan/anaconda3/envs/lambda-advanced/lib/python3.8/site-packages/dm_control/mujoco/index.py", line 623, in struct_indexer
    attr = getattr(struct, field_name)
AttributeError: 'MjData' object has no attribute 'qacc_unc'
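
An attribute missing from MjData, as in the error above, looks consistent with a mismatch between the installed dm_control release and the MuJoCo version it is paired with, though that is an assumption here and not a confirmed diagnosis. A small sketch for collecting installed versions to attach to such a report follows. The package names queried are guesses and should be adjusted to the environment, `installed_versions` is a hypothetical helper, and `importlib.metadata` requires Python 3.8 or newer (which the paths in the traceback suggest is in use).

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Package names below are assumptions; edit them to match the environment.
report = installed_versions(["dm_control", "mujoco", "gym", "tensorflow"])
for name, version in report.items():
    print(name, version if version is not None else "not installed")
```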
