
CORL's Introduction

CORL (Clean Offline Reinforcement Learning)

⚠️ This is an active and supported fork of the original CORL library. The original repository is frozen and will not be updated further.

🧵 CORL is an Offline Reinforcement Learning library that provides high-quality and easy-to-follow single-file implementations of SOTA ORL algorithms. Each implementation is backed by a research-friendly codebase, allowing you to run or tune thousands of experiments. It is heavily inspired by cleanrl for online RL; check it out too!

  • 📜 Single-file implementation
  • 📈 Benchmarked implementations (11+ offline algorithms, 5+ offline-to-online algorithms, 30+ datasets with detailed logs)
  • 🖼 Weights and Biases integration

You can read more about CORL design and main results in our technical paper.


  • ⭐ If you're interested in discrete control, make sure to check out our new library, Katakomba. It provides both discrete control algorithms augmented with recurrence and an offline RL benchmark for the NetHack Learning Environment.

⚠️ NOTE: CORL (similarly to CleanRL) is not a modular library and therefore is not meant to be imported. At the cost of some code duplication, we make all implementation details of an ORL algorithm variant easy to understand. You should consider using CORL if you want to 1) understand and control all implementation details of an algorithm or 2) rapidly prototype advanced features that other modular ORL libraries do not support.

Getting started

Please refer to the documentation for more details. TL;DR:

git clone https://github.com/corl-team/CORL.git && cd CORL
pip install -r requirements/requirements_dev.txt

# alternatively, you could use docker
docker build -t <image_name> .
docker run --gpus=all -it --rm --name <container_name> <image_name>
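
After installation, a quick sanity check (a minimal sketch, not part of the repo; the task id is just an example) is to load a D4RL environment and its offline dataset:

import gym
import d4rl  # noqa: F401 -- importing d4rl registers its environments with gym

env = gym.make("halfcheetah-medium-v2")  # any D4RL task id works here
dataset = d4rl.qlearning_dataset(env)    # dict of observations, actions, rewards, ...
print({key: value.shape for key, value in dataset.items()})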

Algorithms Implemented

Offline and Offline-to-Online:

  • Conservative Q-Learning for Offline Reinforcement Learning (CQL): offline/cql.py, finetune/cql.py (Wandb reports: Offline, Offline-to-online)
  • Accelerating Online Reinforcement Learning with Offline Datasets (AWAC): offline/awac.py, finetune/awac.py (Wandb reports: Offline, Offline-to-online)
  • Offline Reinforcement Learning with Implicit Q-Learning (IQL): offline/iql.py, finetune/iql.py (Wandb reports: Offline, Offline-to-online)
  • Revisiting the Minimalist Approach to Offline Reinforcement Learning (ReBRAC): offline/rebrac.py, finetune/rebrac.py (Wandb reports: Offline, Offline-to-online)

Offline-to-Online only:

  • Supported Policy Optimization for Offline Reinforcement Learning (SPOT): finetune/spot.py (Wandb report: Offline-to-online)
  • Cal-QL: Calibrated Offline RL Pre-Training for Efficient Online Fine-Tuning (Cal-QL): finetune/cal_ql.py (Wandb report: Offline-to-online)

Offline only:

  • Behavioral Cloning (BC): offline/any_percent_bc.py (Wandb report: Offline)
  • Behavioral Cloning-10% (BC-10%): offline/any_percent_bc.py (Wandb report: Offline)
  • A Minimalist Approach to Offline Reinforcement Learning (TD3+BC): offline/td3_bc.py (Wandb report: Offline)
  • Decision Transformer: Reinforcement Learning via Sequence Modeling (DT): offline/dt.py (Wandb report: Offline)
  • Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble (SAC-N): offline/sac_n.py (Wandb report: Offline)
  • Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble (EDAC): offline/edac.py (Wandb report: Offline)
  • Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size (LB-SAC): offline/lb_sac.py (Wandb report: Offline Gym-MuJoCo)

D4RL Benchmarks

You can check the links above for learning curves and details. Here, we report the reproduced final ("last") and best scores. Note that the two can differ by a significant margin, and papers do not always make explicit which reporting methodology they chose (the sketch below illustrates the difference). If you want to re-collect our results in a more structured/nuanced manner, see results.
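
For concreteness, here is a minimal sketch (a hypothetical helper, not part of the library) of how the two reporting styles differ given a history of per-evaluation scores:

import numpy as np

def last_and_best(eval_scores: np.ndarray) -> tuple:
    # "Last" reports the score at the end of training, "best" the maximum
    # over all evaluations; CORL's exact averaging over seeds may differ.
    return float(eval_scores[-1]), float(eval_scores.max())

scores = np.array([10.0, 45.0, 60.0, 55.0, 52.0])
print(last_and_best(scores))  # (52.0, 60.0): "best" can noticeably exceed "last"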

Offline

Last Scores

Gym-MuJoCo
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
halfcheetah-medium-v2 42.40 ± 0.19 42.46 ± 0.70 48.10 ± 0.18 50.02 ± 0.27 47.04 ± 0.22 48.31 ± 0.22 64.04 ± 0.68 68.20 ± 1.28 67.70 ± 1.04 42.20 ± 0.26
halfcheetah-medium-replay-v2 35.66 ± 2.33 23.59 ± 6.95 44.84 ± 0.59 45.13 ± 0.88 45.04 ± 0.27 44.46 ± 0.22 51.18 ± 0.31 60.70 ± 1.01 62.06 ± 1.10 38.91 ± 0.50
halfcheetah-medium-expert-v2 55.95 ± 7.35 90.10 ± 2.45 90.78 ± 6.04 95.00 ± 0.61 95.63 ± 0.42 94.74 ± 0.52 103.80 ± 2.95 98.96 ± 9.31 104.76 ± 0.64 91.55 ± 0.95
hopper-medium-v2 53.51 ± 1.76 55.48 ± 7.30 60.37 ± 3.49 63.02 ± 4.56 59.08 ± 3.77 67.53 ± 3.78 102.29 ± 0.17 40.82 ± 9.91 101.70 ± 0.28 65.10 ± 1.61
hopper-medium-replay-v2 29.81 ± 2.07 70.42 ± 8.66 64.42 ± 21.52 98.88 ± 2.07 95.11 ± 5.27 97.43 ± 6.39 94.98 ± 6.53 100.33 ± 0.78 99.66 ± 0.81 81.77 ± 6.87
hopper-medium-expert-v2 52.30 ± 4.01 111.16 ± 1.03 101.17 ± 9.07 101.90 ± 6.22 99.26 ± 10.91 107.42 ± 7.80 109.45 ± 2.34 101.31 ± 11.63 105.19 ± 10.08 110.44 ± 0.33
walker2d-medium-v2 63.23 ± 16.24 67.34 ± 5.17 82.71 ± 4.78 68.52 ± 27.19 80.75 ± 3.28 80.91 ± 3.17 85.82 ± 0.77 87.47 ± 0.66 93.36 ± 1.38 67.63 ± 2.54
walker2d-medium-replay-v2 21.80 ± 10.15 54.35 ± 6.34 85.62 ± 4.01 80.62 ± 3.58 73.09 ± 13.22 82.15 ± 3.03 84.25 ± 2.25 78.99 ± 0.50 87.10 ± 2.78 59.86 ± 2.73
walker2d-medium-expert-v2 98.96 ± 15.98 108.70 ± 0.25 110.03 ± 0.36 111.44 ± 1.62 109.56 ± 0.39 111.72 ± 0.86 111.86 ± 0.43 114.93 ± 0.41 114.75 ± 0.74 107.11 ± 0.96
locomotion average 50.40 69.29 76.45 79.39 78.28 81.63 89.74 83.52 92.92 73.84
Maze2d
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
maze2d-umaze-v1 0.36 ± 8.69 12.18 ± 4.29 29.41 ± 12.31 65.65 ± 5.34 -8.90 ± 6.11 42.11 ± 0.58 106.87 ± 22.16 130.59 ± 16.52 95.26 ± 6.39 18.08 ± 25.42
maze2d-medium-v1 0.79 ± 3.25 14.25 ± 2.33 59.45 ± 36.25 84.63 ± 35.54 86.11 ± 9.68 34.85 ± 2.72 105.11 ± 31.67 88.61 ± 18.72 57.04 ± 3.45 31.71 ± 26.33
maze2d-large-v1 2.26 ± 4.39 11.32 ± 5.10 97.10 ± 25.41 215.50 ± 3.11 23.75 ± 36.70 61.72 ± 3.50 78.33 ± 61.77 204.76 ± 1.19 95.60 ± 22.92 35.66 ± 28.20
maze2d average 1.13 12.58 61.99 121.92 33.65 46.23 96.77 141.32 82.64 28.48
Antmaze
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
antmaze-umaze-v2 55.25 ± 4.15 65.75 ± 5.26 70.75 ± 39.18 56.75 ± 9.09 92.75 ± 1.92 77.00 ± 5.52 97.75 ± 1.48 0.00 ± 0.00 0.00 ± 0.00 57.00 ± 9.82
antmaze-umaze-diverse-v2 47.25 ± 4.09 44.00 ± 1.00 44.75 ± 11.61 54.75 ± 8.01 37.25 ± 3.70 54.25 ± 5.54 83.50 ± 7.02 0.00 ± 0.00 0.00 ± 0.00 51.75 ± 0.43
antmaze-medium-play-v2 0.00 ± 0.00 2.00 ± 0.71 0.25 ± 0.43 0.00 ± 0.00 65.75 ± 11.61 65.75 ± 11.71 89.50 ± 3.35 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00
antmaze-medium-diverse-v2 0.75 ± 0.83 5.75 ± 9.39 0.25 ± 0.43 0.00 ± 0.00 67.25 ± 3.56 73.75 ± 5.45 83.50 ± 8.20 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00
antmaze-large-play-v2 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00 20.75 ± 7.26 42.00 ± 4.53 52.25 ± 29.01 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00
antmaze-large-diverse-v2 0.00 ± 0.00 0.75 ± 0.83 0.00 ± 0.00 0.00 ± 0.00 20.50 ± 13.24 30.25 ± 3.63 64.00 ± 5.43 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00
antmaze average 17.21 19.71 19.33 18.58 50.71 57.17 78.42 0.00 0.00 18.12
Adroit
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
pen-human-v1 71.03 ± 6.26 26.99 ± 9.60 -3.88 ± 0.21 76.65 ± 11.71 13.71 ± 16.98 78.49 ± 8.21 103.16 ± 8.49 6.86 ± 5.93 5.07 ± 6.16 67.68 ± 5.48
pen-cloned-v1 51.92 ± 15.15 46.67 ± 14.25 5.13 ± 5.28 85.72 ± 16.92 1.04 ± 6.62 83.42 ± 8.19 102.79 ± 7.84 31.35 ± 2.14 12.02 ± 1.75 64.43 ± 1.43
pen-expert-v1 109.65 ± 7.28 114.96 ± 2.96 122.53 ± 21.27 159.91 ± 1.87 -1.41 ± 2.34 128.05 ± 9.21 152.16 ± 6.33 87.11 ± 48.95 -1.55 ± 0.81 116.38 ± 1.27
door-human-v1 2.34 ± 4.00 -0.13 ± 0.07 -0.33 ± 0.01 2.39 ± 2.26 5.53 ± 1.31 3.26 ± 1.83 -0.10 ± 0.01 -0.38 ± 0.00 -0.12 ± 0.13 4.44 ± 0.87
door-cloned-v1 -0.09 ± 0.03 0.29 ± 0.59 -0.34 ± 0.01 -0.01 ± 0.01 -0.33 ± 0.01 3.07 ± 1.75 0.06 ± 0.05 -0.33 ± 0.00 2.66 ± 2.31 7.64 ± 3.26
door-expert-v1 105.35 ± 0.09 104.04 ± 1.46 -0.33 ± 0.01 104.57 ± 0.31 -0.32 ± 0.02 106.65 ± 0.25 106.37 ± 0.29 -0.33 ± 0.00 106.29 ± 1.73 104.87 ± 0.39
hammer-human-v1 3.03 ± 3.39 -0.19 ± 0.02 1.02 ± 0.24 1.01 ± 0.51 0.14 ± 0.11 1.79 ± 0.80 0.24 ± 0.24 0.24 ± 0.00 0.28 ± 0.18 1.28 ± 0.15
hammer-cloned-v1 0.55 ± 0.16 0.12 ± 0.08 0.25 ± 0.01 1.27 ± 2.11 0.30 ± 0.01 1.50 ± 0.69 5.00 ± 3.75 0.14 ± 0.09 0.19 ± 0.07 1.82 ± 0.55
hammer-expert-v1 126.78 ± 0.64 121.75 ± 7.67 3.11 ± 0.03 127.08 ± 0.13 0.26 ± 0.01 128.68 ± 0.33 133.62 ± 0.27 25.13 ± 43.25 28.52 ± 49.00 117.45 ± 6.65
relocate-human-v1 0.04 ± 0.03 -0.14 ± 0.08 -0.29 ± 0.01 0.45 ± 0.53 0.06 ± 0.03 0.12 ± 0.04 0.16 ± 0.30 -0.31 ± 0.01 -0.17 ± 0.17 0.05 ± 0.01
relocate-cloned-v1 -0.06 ± 0.01 -0.00 ± 0.02 -0.30 ± 0.01 -0.01 ± 0.03 -0.29 ± 0.01 0.04 ± 0.01 1.66 ± 2.59 -0.01 ± 0.10 0.17 ± 0.35 0.16 ± 0.09
relocate-expert-v1 107.58 ± 1.20 97.90 ± 5.21 -1.73 ± 0.96 109.52 ± 0.47 -0.30 ± 0.02 106.11 ± 4.02 107.52 ± 2.28 -0.36 ± 0.00 71.94 ± 18.37 104.28 ± 0.42
adroit average 48.18 42.69 10.40 55.71 1.53 53.43 59.39 12.43 18.78 49.21

Best Scores

Gym-MuJoCo
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
halfcheetah-medium-v2 43.60 ± 0.14 43.90 ± 0.13 48.93 ± 0.11 50.81 ± 0.15 47.62 ± 0.03 48.84 ± 0.07 65.62 ± 0.46 72.21 ± 0.31 69.72 ± 0.92 42.73 ± 0.10
halfcheetah-medium-replay-v2 40.52 ± 0.19 42.27 ± 0.46 45.84 ± 0.26 46.47 ± 0.26 46.43 ± 0.19 45.35 ± 0.08 52.22 ± 0.31 67.29 ± 0.34 66.55 ± 1.05 40.31 ± 0.28
halfcheetah-medium-expert-v2 79.69 ± 3.10 94.11 ± 0.22 96.59 ± 0.87 96.83 ± 0.23 97.04 ± 0.17 95.38 ± 0.17 108.89 ± 1.20 111.73 ± 0.47 110.62 ± 1.04 93.40 ± 0.21
hopper-medium-v2 69.04 ± 2.90 73.84 ± 0.37 70.44 ± 1.18 95.42 ± 3.67 70.80 ± 1.98 80.46 ± 3.09 103.19 ± 0.16 101.79 ± 0.20 103.26 ± 0.14 69.42 ± 3.64
hopper-medium-replay-v2 68.88 ± 10.33 90.57 ± 2.07 98.12 ± 1.16 101.47 ± 0.23 101.63 ± 0.55 102.69 ± 0.96 102.57 ± 0.45 103.83 ± 0.53 103.28 ± 0.49 88.74 ± 3.02
hopper-medium-expert-v2 90.63 ± 10.98 113.13 ± 0.16 113.22 ± 0.43 113.26 ± 0.49 112.84 ± 0.66 113.18 ± 0.38 113.16 ± 0.43 111.24 ± 0.15 111.80 ± 0.11 111.18 ± 0.21
walker2d-medium-v2 80.64 ± 0.91 82.05 ± 0.93 86.91 ± 0.28 85.86 ± 3.76 84.77 ± 0.20 87.58 ± 0.48 87.79 ± 0.19 90.17 ± 0.54 95.78 ± 1.07 74.70 ± 0.56
walker2d-medium-replay-v2 48.41 ± 7.61 76.09 ± 0.40 91.17 ± 0.72 86.70 ± 0.94 89.39 ± 0.88 89.94 ± 0.93 91.11 ± 0.63 85.18 ± 1.63 89.69 ± 1.39 68.22 ± 1.20
walker2d-medium-expert-v2 109.95 ± 0.62 109.90 ± 0.09 112.21 ± 0.06 113.40 ± 2.22 111.63 ± 0.38 113.06 ± 0.53 112.49 ± 0.18 116.93 ± 0.42 116.52 ± 0.75 108.71 ± 0.34
locomotion average 70.15 80.65 84.83 87.80 84.68 86.28 93.00 95.60 96.36 77.49
Maze2d
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
maze2d-umaze-v1 16.09 ± 0.87 22.49 ± 1.52 99.33 ± 16.16 136.96 ± 10.89 92.05 ± 13.66 50.92 ± 4.23 162.28 ± 1.79 153.12 ± 6.49 149.88 ± 1.97 63.83 ± 17.35
maze2d-medium-v1 19.16 ± 1.24 27.64 ± 1.87 150.93 ± 3.89 152.73 ± 20.78 128.66 ± 5.44 122.69 ± 30.00 150.12 ± 4.48 93.80 ± 14.66 154.41 ± 1.58 68.14 ± 12.25
maze2d-large-v1 20.75 ± 6.66 41.83 ± 3.64 197.64 ± 5.26 227.31 ± 1.47 157.51 ± 7.32 162.25 ± 44.18 197.55 ± 5.82 207.51 ± 0.96 182.52 ± 2.68 50.25 ± 19.34
maze2d average 18.67 30.65 149.30 172.33 126.07 111.95 169.98 151.48 162.27 60.74
Antmaze
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
antmaze-umaze-v2 68.50 ± 2.29 77.50 ± 1.50 98.50 ± 0.87 70.75 ± 8.84 94.75 ± 0.83 84.00 ± 4.06 100.00 ± 0.00 0.00 ± 0.00 42.50 ± 28.61 64.50 ± 2.06
antmaze-umaze-diverse-v2 64.75 ± 4.32 63.50 ± 2.18 71.25 ± 5.76 81.50 ± 4.27 53.75 ± 2.05 79.50 ± 3.35 96.75 ± 2.28 0.00 ± 0.00 0.00 ± 0.00 60.50 ± 2.29
antmaze-medium-play-v2 4.50 ± 1.12 6.25 ± 2.38 3.75 ± 1.30 25.00 ± 10.70 80.50 ± 3.35 78.50 ± 3.84 93.50 ± 2.60 0.00 ± 0.00 0.00 ± 0.00 0.75 ± 0.43
antmaze-medium-diverse-v2 4.75 ± 1.09 16.50 ± 5.59 5.50 ± 1.50 10.75 ± 5.31 71.00 ± 4.53 83.50 ± 1.80 91.75 ± 2.05 0.00 ± 0.00 0.00 ± 0.00 0.50 ± 0.50
antmaze-large-play-v2 0.50 ± 0.50 13.50 ± 9.76 1.25 ± 0.43 0.50 ± 0.50 34.75 ± 5.85 53.50 ± 2.50 68.75 ± 13.90 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00
antmaze-large-diverse-v2 0.75 ± 0.43 6.25 ± 1.79 0.25 ± 0.43 0.00 ± 0.00 36.25 ± 3.34 53.00 ± 3.00 69.50 ± 7.26 0.00 ± 0.00 0.00 ± 0.00 0.00 ± 0.00
antmaze average 23.96 30.58 30.08 31.42 61.83 72.00 86.71 0.00 7.08 21.04
Adroit
Task-Name BC BC-10% TD3+BC AWAC CQL IQL ReBRAC SAC-N EDAC DT
pen-human-v1 99.69 ± 7.45 59.89 ± 8.03 9.95 ± 8.19 119.03 ± 6.55 58.91 ± 1.81 106.15 ± 10.28 127.28 ± 3.22 56.48 ± 7.17 35.84 ± 10.57 77.83 ± 2.30
pen-cloned-v1 99.14 ± 12.27 83.62 ± 11.75 52.66 ± 6.33 125.78 ± 3.28 14.74 ± 2.31 114.05 ± 4.78 128.64 ± 7.15 52.69 ± 5.30 26.90 ± 7.85 71.17 ± 2.70
pen-expert-v1 128.77 ± 5.88 134.36 ± 3.16 142.83 ± 7.72 162.53 ± 0.30 14.86 ± 4.07 140.01 ± 6.36 157.62 ± 0.26 116.43 ± 40.26 36.04 ± 4.60 119.49 ± 2.31
door-human-v1 9.41 ± 4.55 7.00 ± 6.77 -0.11 ± 0.06 17.70 ± 2.55 13.28 ± 2.77 13.52 ± 1.22 0.27 ± 0.43 -0.10 ± 0.06 2.51 ± 2.26 7.36 ± 1.24
door-cloned-v1 3.40 ± 0.95 10.37 ± 4.09 -0.20 ± 0.11 10.53 ± 2.82 -0.08 ± 0.13 9.02 ± 1.47 7.73 ± 6.80 -0.21 ± 0.10 20.36 ± 1.11 11.18 ± 0.96
door-expert-v1 105.84 ± 0.23 105.92 ± 0.24 4.49 ± 7.39 106.60 ± 0.27 59.47 ± 25.04 107.29 ± 0.37 106.78 ± 0.04 0.05 ± 0.02 109.22 ± 0.24 105.49 ± 0.09
hammer-human-v1 12.61 ± 4.87 6.23 ± 4.79 2.38 ± 0.14 16.95 ± 3.61 0.30 ± 0.05 6.86 ± 2.38 1.18 ± 0.15 0.25 ± 0.00 3.49 ± 2.17 1.68 ± 0.11
hammer-cloned-v1 8.90 ± 4.04 8.72 ± 3.28 0.96 ± 0.30 10.74 ± 5.54 0.32 ± 0.03 11.63 ± 1.70 48.16 ± 6.20 12.67 ± 15.02 0.27 ± 0.01 2.74 ± 0.22
hammer-expert-v1 127.89 ± 0.57 128.15 ± 0.66 33.31 ± 47.65 129.08 ± 0.26 0.93 ± 1.12 129.76 ± 0.37 134.74 ± 0.30 91.74 ± 47.77 69.44 ± 47.00 127.39 ± 0.10
relocate-human-v1 0.59 ± 0.27 0.16 ± 0.14 -0.29 ± 0.01 1.77 ± 0.84 1.03 ± 0.20 1.22 ± 0.28 3.70 ± 2.34 -0.18 ± 0.14 0.05 ± 0.02 0.08 ± 0.02
relocate-cloned-v1 0.45 ± 0.31 0.74 ± 0.45 -0.02 ± 0.04 0.39 ± 0.13 -0.07 ± 0.02 1.78 ± 0.70 9.25 ± 2.56 0.10 ± 0.04 4.11 ± 1.39 0.34 ± 0.09
relocate-expert-v1 110.31 ± 0.36 109.77 ± 0.60 0.23 ± 0.27 111.21 ± 0.32 0.03 ± 0.10 110.12 ± 0.82 111.14 ± 0.23 -0.07 ± 0.08 98.32 ± 3.75 106.49 ± 0.30
adroit average 58.92 54.58 20.51 67.69 13.65 62.62 69.71 27.49 33.88 52.60

Offline-to-Online

Scores
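
In the tables below, each cell shows the score after offline pretraining → the score after subsequent online fine-tuning.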

Task-Name AWAC CQL IQL SPOT Cal-QL ReBRAC
antmaze-umaze-v2 52.75 ± 8.67 → 98.75 ± 1.09 94.00 ± 1.58 → 99.50 ± 0.87 77.00 ± 0.71 → 96.50 ± 1.12 91.00 ± 2.55 → 99.50 ± 0.50 76.75 ± 7.53 → 99.75 ± 0.43 98.00 ± 1.58 → 74.75 ± 42.59
antmaze-umaze-diverse-v2 56.00 ± 2.74 → 0.00 ± 0.00 9.50 ± 9.91 → 99.00 ± 1.22 59.50 ± 9.55 → 63.75 ± 25.02 36.25 ± 2.17 → 95.00 ± 3.67 32.00 ± 27.79 → 98.50 ± 1.12 73.75 ± 13.27 → 98.00 ± 2.92
antmaze-medium-play-v2 0.00 ± 0.00 → 0.00 ± 0.00 59.00 ± 11.18 → 97.75 ± 1.30 71.75 ± 2.95 → 89.75 ± 1.09 67.25 ± 10.47 → 97.25 ± 1.30 71.75 ± 3.27 → 98.75 ± 1.64 87.50 ± 3.77 → 98.00 ± 1.58
antmaze-medium-diverse-v2 0.00 ± 0.00 → 0.00 ± 0.00 63.50 ± 6.84 → 97.25 ± 1.92 64.25 ± 1.92 → 92.25 ± 2.86 73.75 ± 7.29 → 94.50 ± 1.66 62.00 ± 4.30 → 98.25 ± 1.48 85.25 ± 2.17 → 98.75 ± 0.43
antmaze-large-play-v2 0.00 ± 0.00 → 0.00 ± 0.00 28.75 ± 7.76 → 88.25 ± 2.28 38.50 ± 8.73 → 64.50 ± 17.04 31.50 ± 12.58 → 87.00 ± 3.24 31.75 ± 8.87 → 97.25 ± 1.79 68.50 ± 6.18 → 31.50 ± 33.56
antmaze-large-diverse-v2 0.00 ± 0.00 → 0.00 ± 0.00 35.50 ± 3.64 → 91.75 ± 3.96 26.75 ± 3.77 → 64.25 ± 4.15 17.50 ± 7.26 → 81.00 ± 14.14 44.00 ± 8.69 → 91.50 ± 3.91 67.00 ± 10.61 → 72.25 ± 41.73
antmaze average 18.12 → 16.46 48.38 → 95.58 56.29 → 78.50 52.88 → 92.38 53.04 → 97.33 80.00 → 78.88
pen-cloned-v1 88.66 ± 15.10 → 86.82 ± 11.12 -2.76 ± 0.08 → -1.28 ± 2.16 84.19 ± 3.96 → 102.02 ± 20.75 6.19 ± 5.21 → 43.63 ± 20.09 -2.66 ± 0.04 → -2.68 ± 0.12 74.04 ± 11.97 → 138.15 ± 3.22
door-cloned-v1 0.93 ± 1.66 → 0.01 ± 0.00 -0.33 ± 0.01 → -0.33 ± 0.01 1.19 ± 0.93 → 20.34 ± 9.32 -0.21 ± 0.14 → 0.02 ± 0.31 -0.33 ± 0.01 → -0.33 ± 0.01 0.07 ± 0.04 → 102.39 ± 8.27
hammer-cloned-v1 1.80 ± 3.01 → 0.24 ± 0.04 0.56 ± 0.55 → 2.85 ± 4.81 1.35 ± 0.32 → 57.27 ± 28.49 3.97 ± 6.39 → 3.73 ± 4.99 0.25 ± 0.04 → 0.17 ± 0.17 6.54 ± 3.35 → 124.65 ± 7.37
relocate-cloned-v1 -0.04 ± 0.04 → -0.04 ± 0.01 -0.33 ± 0.01 → -0.33 ± 0.01 0.04 ± 0.04 → 0.32 ± 0.38 -0.24 ± 0.01 → -0.15 ± 0.05 -0.31 ± 0.05 → -0.31 ± 0.04 0.70 ± 0.62 → 6.96 ± 4.59
adroit average 22.84 → 21.76 -0.72 → 0.22 21.69 → 44.99 2.43 → 11.81 -0.76 → -0.79 20.33 → 93.04

Regrets

Task-Name AWAC CQL IQL SPOT Cal-QL ReBRAC
antmaze-umaze-v2 0.04 ± 0.01 0.02 ± 0.00 0.07 ± 0.00 0.02 ± 0.00 0.01 ± 0.00 0.11 ± 0.18
antmaze-umaze-diverse-v2 0.88 ± 0.01 0.09 ± 0.01 0.43 ± 0.11 0.22 ± 0.07 0.05 ± 0.01 0.04 ± 0.02
antmaze-medium-play-v2 1.00 ± 0.00 0.08 ± 0.01 0.09 ± 0.01 0.06 ± 0.00 0.04 ± 0.01 0.03 ± 0.01
antmaze-medium-diverse-v2 1.00 ± 0.00 0.08 ± 0.00 0.10 ± 0.01 0.05 ± 0.01 0.04 ± 0.01 0.03 ± 0.00
antmaze-large-play-v2 1.00 ± 0.00 0.21 ± 0.02 0.34 ± 0.05 0.29 ± 0.07 0.13 ± 0.02 0.14 ± 0.05
antmaze-large-diverse-v2 1.00 ± 0.00 0.21 ± 0.03 0.41 ± 0.03 0.23 ± 0.08 0.13 ± 0.02 0.29 ± 0.39
antmaze average 0.82 0.11 0.24 0.15 0.07 0.11
pen-cloned-v1 0.46 ± 0.02 0.97 ± 0.00 0.37 ± 0.01 0.58 ± 0.02 0.98 ± 0.01 0.08 ± 0.01
door-cloned-v1 1.00 ± 0.00 1.00 ± 0.00 0.83 ± 0.03 0.99 ± 0.01 1.00 ± 0.00 0.19 ± 0.05
hammer-cloned-v1 1.00 ± 0.00 1.00 ± 0.00 0.65 ± 0.10 0.98 ± 0.01 1.00 ± 0.00 0.13 ± 0.03
relocate-cloned-v1 1.00 ± 0.00 1.00 ± 0.00 1.00 ± 0.00 1.00 ± 0.00 1.00 ± 0.00 0.90 ± 0.06
adroit average 0.86 0.99 0.71 0.89 0.99 0.33
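
The regret values above summarize how much performance was sacrificed during online fine-tuning; lower is better. As a rough illustration, assuming the common definition (used, e.g., in the Cal-QL paper) of the average normalized shortfall from the maximum score over all online evaluations, regret could be computed as in this sketch (illustrative only, not the exact evaluation script):

import numpy as np

def cumulative_regret(online_scores: np.ndarray, max_score: float = 100.0) -> float:
    # Average shortfall from the best achievable normalized score,
    # taken over every evaluation during the online phase.
    return float(np.mean(1.0 - online_scores / max_score))

print(cumulative_regret(np.array([50.0, 75.0, 90.0])))  # ~0.283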

Citing CORL

If you use CORL in your work, please cite it using the following BibTeX entry:

@inproceedings{tarasov2022corl,
  title={CORL: Research-oriented Deep Offline Reinforcement Learning Library},
  author={Denis Tarasov and Alexander Nikulin and Dmitry Akimov and Vladislav Kurenkov and Sergey Kolesnikov},
  booktitle={3rd Offline RL Workshop: Offline RL as a ``Launchpad''},
  year={2022},
  url={https://openreview.net/forum?id=SyAS49bBcv}
}

CORL's People

Contributors

adamjelley, cherrypiesexy, dt6a, howuhh, levilovearch, modanesh, nakamotoo, scitator, suessmann, typoverflow, vkurenkov

CORL's Issues

Not compatible with cython3

BUG
/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx:127:21: Cannot assign type 'void (const char *) except * nogil' to 'void (*)(const char *) noexcept nogil'. Exception values are incompatible. Suggest adding 'noexcept' to type 'void (const char *) except * nogil'.
Traceback (most recent call last):
  File "/home/fs01/mah4021/CORL/algorithms/offline/dt.py", line 11, in <module>
    import d4rl  # noqa
  File "/home/fs01/mah4021/d4rl/d4rl/__init__.py", line 14, in <module>
    import d4rl.locomotion
  File "/home/fs01/mah4021/d4rl/d4rl/locomotion/__init__.py", line 2, in <module>
    from d4rl.locomotion import ant
  File "/home/fs01/mah4021/d4rl/d4rl/locomotion/ant.py", line 20, in <module>
    import mujoco_py
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/__init__.py", line 2, in <module>
    from mujoco_py.builder import cymj, ignore_mujoco_warnings, functions, MujocoException
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/builder.py", line 504, in <module>
    cymj = load_cython_ext(mujoco_path)
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/builder.py", line 110, in load_cython_ext
    cext_so_path = builder.build()
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/builder.py", line 226, in build
    built_so_file_path = self._build_impl()
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/builder.py", line 296, in _build_impl
    so_file_path = super()._build_impl()
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/builder.py", line 239, in _build_impl
    dist.ext_modules = cythonize([self.extension])
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/Cython/Build/Dependencies.py", line 1154, in cythonize
    cythonize_one(*args)
  File "/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/Cython/Build/Dependencies.py", line 1321, in cythonize_one
    raise CompileError(None, pyx_file)
Cython.Compiler.Errors.CompileError: /home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx
(CORL) bash-4.4$ pip instal --upgrade cython
ERROR: unknown command "instal" - maybe you meant "install"
(CORL) bash-4.4$ pip install --upgrade cython
Requirement already satisfied: cython in /home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages (3.0.5)
(CORL) bash-4.4$ python dt.py
Compiling /home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx because it changed.
[1/1] Cythonizing /home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx
performance hint: /home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx:67:5: Exception check on 'c_warning_callback' will always require the GIL to be acquired.
Possible solutions:
1. Declare the function as 'noexcept' if you control the definition and you're sure you don't want the function to raise exceptions.
2. Use an 'int' return type on the function to allow an error code to be returned.
performance hint: /home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx:104:5: Exception check on 'c_error_callback' will always require the GIL to be acquired.
Possible solutions:
1. Declare the function as 'noexcept' if you control the definition and you're sure you don't want the function to raise exceptions.
2. Use an 'int' return type on the function to allow an error code to be returned.

Error compiling Cython file:

...
See c_warning_callback, which is the C wrapper to the user defined function
'''
global py_warning_callback
global mju_user_warning
py_warning_callback = warn
mju_user_warning = c_warning_callback
^

/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx:92:23: Cannot assign type 'void (const char *) except * nogil' to 'void (*)(const char *) noexcept nogil'. Exception values are incompatible. Suggest adding 'noexcept' to type 'void (const char *) except * nogil'.

Error compiling Cython file:

...
See c_warning_callback, which is the C wrapper to the user defined function
'''
global py_error_callback
global mju_user_error
py_error_callback = err_callback
mju_user_error = c_error_callback
^

/home/fs01/mah4021/.conda/envs/CORL/lib/python3.10/site-packages/mujoco_py/cymj.pyx:127:21: Cannot assign type 'void (const char *) except * nogil' to 'void (*)(const char *) noexcept nogil'. Exception values are incompatible. Suggest adding 'noexcept' to type 'void (const char *) except * nogil'.
(traceback identical to the first run above)

FIX
(CORL) bash-4.4$ pip install "cython<3"
Collecting cython<3
Downloading Cython-0.29.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (3.1 kB)
Downloading Cython-0.29.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (1.9 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.9/1.9 MB 11.4 MB/s eta 0:00:00
Installing collected packages: cython
Attempting uninstall: cython
Found existing installation: Cython 3.0.5
Uninstalling Cython-3.0.5:
Successfully uninstalled Cython-3.0.5
Successfully installed cython-0.29.36

Migration to Gymnasium

How exactly do I do rollouts with the trained policy with rendering? Is that easy to do out of the box, or does it require extra work?
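
For reference, a rendered rollout could look like the sketch below. It assumes the old gym API (which current D4RL builds on) and an actor exposing the act(state, device) interface used by CORL's single-file agents; rendering backends vary per environment, so this may not work on headless machines:

import gym
import d4rl  # noqa: F401

def render_rollout(actor, device: str, env_name: str = "halfcheetah-medium-v2") -> float:
    # Roll out one episode while rendering each step on screen.
    env = gym.make(env_name)
    state, done = env.reset(), False
    episode_reward = 0.0
    while not done:
        env.render()  # opens a mujoco_py viewer window
        action = actor.act(state, device)
        state, reward, done, _ = env.step(action)
        episode_reward += reward
    return episode_reward

# e.g. render_rollout(trainer.actor, config.device) after loading a checkpoint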

Minari Integration with CORL

Here we will track our progress on Minari integration with CORL. Minari is a standard format for offline RL datasets, with popular reference datasets and related utilities, which we believe will replace D4RL in the future by combining most of the existing benchmarks under a unified interface and storage format. And we want to be prepared! Eventually, this will become CORL v2.

The plan is to add an experimental separate directory with algorithms using Minari and to retrain all algorithms on the D4RL datasets that are currently (and will in the future be) re-created in Minari.
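
For reference, the core difference in dataset loading looks roughly like this (Minari API as of the time of writing; the dataset id is a placeholder and may differ from the final re-created names):

# D4RL-style loading, as CORL currently does it:
import gym
import d4rl  # noqa: F401
d4rl_data = d4rl.qlearning_dataset(gym.make("door-human-v1"))

# Minari-style loading:
import minari
minari_data = minari.load_dataset("door-human-v1")  # placeholder dataset id
for episode in minari_data.iterate_episodes():
    print(episode.observations.shape, episode.actions.shape)
    break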

Adapted algorithms:

  • BC
  • TD3 + BC
  • AWAC
  • CQL
  • IQL
  • SAC-N
  • EDAC
  • DT

Retrained algorithms (only Adroit for now):

  • BC
  • TD3 + BC
  • AWAC
  • CQL
  • IQL
  • SAC-N
  • EDAC
  • DT

CQL Hopper normalizing reward

In the CQL config files for Hopper, normalize_reward is set to false for every dataset, and only in medium_v2.yaml is it set to true (medium_v2.yaml). I'm wondering whether it's set to true by mistake or intentionally. If it's intentional, why are rewards normalized only for medium-v2 and not for the other datasets?

Thanks.

CQL: Unable to maintain performance

I have trained a CQL agent on HalfCheetah with good performance, reaching about 5000 return per episode. I modified the code to save the best model during training instead of saving a checkpoint after every evaluation:

            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )

to:

            if config.checkpoints_path and eval_score > best_eval_score:
                best_eval_score = eval_score
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "best.pt"),
                )

The issue happens when I load this model and run it in a newly created environment: the performance drops to about -700 per episode, which is basically that of a random agent. Here is the code used to test the agent:

import os

import gym
import numpy as np
import torch

# ContinuousCQL, FullyConnectedQFunction, TanhGaussianPolicy, TrainConfig,
# and set_seed are the components defined in algorithms/offline/cql.py.

@torch.no_grad()
def test_actor(config: TrainConfig) -> None:
    env = gym.make(config.env)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    critic_1 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_2 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,  # match critic_1 so the checkpoint loads correctly
    ).to(config.device)
    critic_1_optimizer = torch.optim.Adam(list(critic_1.parameters()), config.qf_lr)
    critic_2_optimizer = torch.optim.Adam(list(critic_2.parameters()), config.qf_lr)

    actor = TanhGaussianPolicy(
        state_dim,
        action_dim,
        max_action,
        log_std_multiplier=config.policy_log_std_multiplier,
        orthogonal_init=config.orthogonal_init,
    ).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), config.policy_lr)

    kwargs = {
        "critic_1": critic_1,
        "critic_2": critic_2,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2_optimizer": critic_2_optimizer,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "discount": config.discount,
        "soft_target_update_rate": config.soft_target_update_rate,
        "device": config.device,
        # CQL
        "target_entropy": -np.prod(env.action_space.shape).item(),
        "alpha_multiplier": config.alpha_multiplier,
        "use_automatic_entropy_tuning": config.use_automatic_entropy_tuning,
        "backup_entropy": config.backup_entropy,
        "policy_lr": config.policy_lr,
        "qf_lr": config.qf_lr,
        "bc_steps": config.bc_steps,
        "target_update_period": config.target_update_period,
        "cql_n_actions": config.cql_n_actions,
        "cql_importance_sample": config.cql_importance_sample,
        "cql_lagrange": config.cql_lagrange,
        "cql_target_action_gap": config.cql_target_action_gap,
        "cql_temp": config.cql_temp,
        "cql_alpha": config.cql_alpha,
        "cql_max_target_backup": config.cql_max_target_backup,
        "cql_clip_diff_min": config.cql_clip_diff_min,
        "cql_clip_diff_max": config.cql_clip_diff_max,
    }

    print("---------------------------------------")
    print(f"TESTING CQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = ContinuousCQL(**kwargs)

    model_path = os.path.join(config.checkpoints_path, "best.pt")
    trainer.load_state_dict(torch.load(model_path, map_location=config.device))
    actor = trainer.actor

    actor.eval()
    episode_rewards = []
    for _ in range(config.test_episodes):
        state, done = env.reset(), False
        episode_reward = 0.0
        while not done:
            action = actor.act(state, config.device)
            state, reward, done, _ = env.step(action)
            episode_reward += reward
        episode_rewards.append(episode_reward)

    print(f"TEST Mean reward: {np.mean(episode_rewards)}")

Katakomba link

The link to Katakomba in the README leads to an archived repo.

[Docs] Mention request for JAX-CORL

Dear CORL Team,

Firstly, I would like to express my appreciation for your work on the CORL codebase. The clean, single-file implementation coupled with a robust performance report has greatly impressed me, and I truly enjoy using it.

As someone who predominantly uses JAX in my research due to its speed and efficiency, I realized the potential value of a JAX-based counterpart to CORL. To address this gap, I recently launched JAX-CORL, which provides single-file implementations of Offline RL algorithms specifically in JAX.

Currently, JAX-CORL includes three algorithms: AWAC, IQL, and TD3+BC, and I have plans to expand it further to include other algorithms such as CQL, DT, and TD7 in the near future.

Understanding that my repository is still in the early stages compared to CORL, I would be immensely grateful if you could consider mentioning JAX-CORL as a JAX-based single-file codebase for offline RL in your README.md. This acknowledgment would not only help in reaching more researchers who could benefit from this resource but also potentially foster further collaborative enhancements.

Thank you for considering this request.
