
sbx's Introduction


Stable Baselines Jax (SB3 + Jax = SBX)

Proof of concept version of Stable-Baselines3 in Jax.

Implemented algorithms:

  • SAC
  • TQC
  • DroQ (as a special configuration of SAC, see the note below)
  • CrossQ
  • DDPG
  • TD3
  • PPO
  • DQN

Install using pip

For the latest master version:

pip install git+https://github.com/araffin/sbx

or:

pip install sbx-rl

Example

import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
    vec_env.render()
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)

vec_env.close()
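Since SBX follows the SB3 API, saving and reloading a model should work the same way as in SB3. A minimal sketch continuing the example above (the file name is arbitrary):

model.save("tqc_pendulum")  # writes tqc_pendulum.zip
loaded_model = TQC.load("tqc_pendulum", env=env)
loaded_model.learn(total_timesteps=1_000, progress_bar=True)  # continue training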

Using SBX with the RL Zoo

Since SBX shares the SB3 API, it is compatible with the RL Zoo; you just need to override the algorithm mapping:

import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    train()

Then you can run this script as you would with the RL Zoo:

python train.py --algo sac --env HalfCheetah-v4 -params train_freq:4 gradient_steps:4 -P

The same goes for the enjoy script:

import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    enjoy()

Note about DroQ

DroQ is a special configuration of SAC.

To get the algorithm with the hyperparameters from the paper, use the following RL Zoo config:

HalfCheetah-v4:
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_starts: 10000
  gradient_steps: 20
  policy_delay: 20
  policy_kwargs: "dict(dropout_rate=0.01, layer_norm=True)"

and then launch training with the RL Zoo script defined above: `python train.py --algo sac --env HalfCheetah-v4 -c droq.yml -P`.

We recommend tuning the policy_delay and gradient_steps parameters to trade off speed and sample efficiency. Using a higher learning rate for the Q-value function also helps: qf_learning_rate: !!float 1e-3.

Note: when using the DroQ configuration with CrossQ, you should set layer_norm=False, as CrossQ already uses batch normalization.
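For reference, here is a rough sketch of the same DroQ-style setup written directly in Python with SBX, using the values from the config above (treat it as an illustration rather than the canonical settings):

import gymnasium as gym

from sbx import SAC

env = gym.make("HalfCheetah-v4")

# DroQ = SAC with dropout + layer norm on the Q-network and many gradient steps.
model = SAC(
    "MlpPolicy",
    env,
    learning_starts=10_000,
    gradient_steps=20,
    policy_delay=20,
    qf_learning_rate=1e-3,  # higher learning rate for the Q-value function
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
    verbose=1,
)
model.learn(total_timesteps=1_000_000, progress_bar=True)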

Benchmark

A partial benchmark can be found on OpenRL Benchmark, where you can also find several reports.

Citing the Project

To cite this repository in publications:

@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}

Maintainers

Stable-Baselines3 is currently maintained by Ashley Hill (aka @hill-a), Antonin Raffin (aka @araffin), Maximilian Ernestus (aka @ernestum), Adam Gleave (@AdamGleave), Anssi Kanervisto (@Miffyli) and Quentin Gallouédec (@qgallouedec).

Important Note: We do not provide technical support or consulting, and we do not answer personal questions by email. Please post your question on the RL Discord, Reddit, or Stack Overflow instead.

How To Contribute

For anyone interested in making the baselines better: there is still some documentation that needs to be written. If you want to contribute, please read the CONTRIBUTING.md guide first.

Contributors

We would like to thank our contributors: @jan1854.

sbx's People

Contributors

araffin, jan1854, paolodelia99, theovincent


sbx's Issues

Mujoco XLA - MJX Integration

Since the environment is the biggest bottleneck for SB3 training performance, I am considering integrating SB3 with MuJoCo XLA (MJX), which is MuJoCo written in JAX. Would this integration increase performance? MJX was released with huge performance improvements together with Brax, which includes RL algorithms in JAX. Is SBX fully written in JAX?

[Bug] TQC Hyperparameter optimization: Results do not match the reference. This is likely a bug/unexpected loss of precision.

🐛 Bug

Hi,

When I try to run TQC hyperparameter optimization with multiple jobs (n-jobs>1) with a GPU (this also happens with multiple CPU cores and n-jobs=1), it gives me this error:

2024-04-07 14:35:59.992779: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 0: -inf, expected -0.000287323
2024-04-07 14:35:59.992804: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 1: -inf, expected -0.000267224
2024-04-07 14:35:59.992808: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 2: -inf, expected -0.000226477
2024-04-07 14:35:59.992811: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 3: -inf, expected -0.000281823
2024-04-07 14:35:59.992813: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 4: -inf, expected -0.000262532
2024-04-07 14:35:59.992815: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 5: -inf, expected -0.000252724
2024-04-07 14:35:59.992818: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 6: -inf, expected -0.000250007
2024-04-07 14:35:59.992820: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 7: -inf, expected -0.000265674
2024-04-07 14:35:59.992823: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 8: -inf, expected -0.00021464
2024-04-07 14:35:59.992825: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 9: -inf, expected -0.000204733
E0407 14:35:59.992828  798907 triton_autotuner.cc:766] Results do not match the reference. This is likely a bug/unexpected loss of precision.

To Reproduce

python rl-baselines3-zoo/train_sbx.py --algo tqc --env Pendulum-v1 -n 5000 --n-trials 50 --num-threads 1 --n-jobs 4 --log-interval 4900 --eval-episodes 16 --n-eval-envs 8 --seed 8 --vec-env "dummy" -optimize --sampler tpe --pruner median --n-startup-trials 10
[W 2024-04-07 14:36:00,208] Trial 16 failed with parameters: {'gamma': 0.995, 'learning_rate': 0.23149128592335125, 'batch_size': 1024, 'buffer_size': 10000, 'learning_starts': 1000, 'train_freq': 16, 'tau': 0.08, 'log_std_init': -0.3684256821552643, 'net_arch': 'medium', 'n_quantiles': 32, 'top_quantiles_to_drop_per_net': 30} because of the following error: XlaRuntimeError('INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.  Per-algorithm errors:\n  Results do not match the reference. This is likely a bug/unexpected loss of precision.

Traceback (most recent call last):
  File "/home/.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
  File "/home/.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.conda/envs/.../lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.


Traceback (most recent call last):
  File "/scratch/network/.../.../rl-baselines3-zoo/train_sbx.py", line 19, in <module>
    train()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/train.py", line 275, in train
    exp_manager.hyperparameters_optimization()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 874, in hyperparameters_optimization
    study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/study.py", line 451, in optimize
    _optimize(
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 99, in _optimize
    f.result()
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 159, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 247, in _run_trial
    raise func_err
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.

### System Info

Describe the characteristic of your environment:

  • Library installed through pip

  • GPU models and configuration
    +---------------------------------------------------------------------------------------+
    | NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
    |-----------------------------------------+----------------------+----------------------+
    | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
    | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
    | | | MIG M. |
    |=========================================+======================+======================|
    | 0 NVIDIA A100 80GB PCIe On | 00000000:0D:00.0 Off | 0 |
    | N/A 40C P0 67W / 300W | 3508MiB / 81920MiB | 0% Default |
    | | | Disabled |
    +-----------------------------------------+----------------------+----------------------+
    | 1 NVIDIA A100 80GB PCIe On | 00000000:B5:00.0 Off | 0 |
    | N/A 38C P0 49W / 300W | 5MiB / 81920MiB | 0% Default |
    | | | Disabled |
    +-----------------------------------------+----------------------+----------------------+

  • Python 3.10.14

  • pytorch 2.2.2 py3.10_cuda12.1_cudnn8.9.2_0
    pytorch-cuda 12.1 ha16c6d3_5 pytorch
    pytorch-mutex 1.0 cuda pytorch
    torchtriton 2.2.0 py310 pytorch

  • Gym version
    gymnasium 0.29.1

  • Versions of any other relevant libraries
    jax 0.4.25 pyhd8ed1ab_0 conda-forge
    jax-jumpy 1.0.0 pyhd8ed1ab_0 conda-forge
    jaxlib 0.4.23 cuda118py310h8c47008_200 conda-forge

Additional context

I've noticed there's no bug when n-jobs=1, only when running multiple jobs. Maybe something with the way Optuna runs multiple jobs?

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Question] TypeError when exporting a model to PyTorch in SBX

🐛 Bug

When using PyTorch JIT to trace and save a trained model with SBX, an exception occurs.

To Reproduce

The following code works fine for a model trained with TD3 in SB3. However, a TypeError occurs when trying to trace a model trained with SBX.

import torch as th
from stable_baselines3.common.policies import BasePolicy
from sbx import TD3
from typing import Tuple

class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        return self.policy(observation, deterministic=True)
    
jit_path = "model.pt"

cuda_id = th.cuda.current_device()
model = TD3.load("model", device=cuda_id)
onnxable_model = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size).to(device=cuda_id)

# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)

Traceback (most recent call last):
  File "/home/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 584, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^
KeyError: <class 'torch.Tensor'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/export_to_pt.py", line 33, in <module>
    traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/jit/_trace.py", line 806, in trace
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/jit/_trace.py", line 1074, in trace_module
    module._c._create_method_from_trace(
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/export_to_pt.py", line 21, in forward
    return self.policy(observation, deterministic=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/sbx/td3/policies.py", line 178, in forward
    return self._predict(obs, deterministic=deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/sbx/td3/policies.py", line 187, in _predict
    return TD3Policy.select_action(self.actor_state, observation)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot interpret 'torch.float32' as a data type
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
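For reference, a minimal inference-only sketch that avoids the error above: SBX policies are Flax modules, so they expect NumPy/JAX arrays rather than torch tensors. The model file name mirrors the snippet above; this does not produce a TorchScript export, it only shows plain prediction:

import torch as th

from sbx import TD3

model = TD3.load("model")
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)

# Convert the torch tensor to a NumPy array before calling the Flax-based policy.
action, _ = model.predict(dummy_input.detach().cpu().numpy(), deterministic=True)
print(action)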

### System Info

- OS: Linux-4.18.0-513.11.1.el8_9.0.1.x86_64-x86_64-with-glibc2.28 # 1 SMP Sun Feb 11 10:42:18 UTC 2024
- Python: 3.12.1
- Stable-Baselines3: 2.3.0a2
- PyTorch: 2.2.1+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
- OpenAI Gym: 0.26.2

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

self.key is never updated

Thank you for your work on this cool repo! It is really useful for my research :)

🐛 Bug

Why is self.key always the same after each self._train call? More precisely, why is this part of the code written like this:

sbx/sbx/sac/sac.py

Lines 446 to 449 in fcd647e

update_carry["actor_state"],
update_carry["ent_coef_state"],
key,
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),

and not like this

update_carry["actor_state"], 
update_carry["ent_coef_state"], 
update_carry["key"],  # Return the new updated key
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]), 

?
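For context, a minimal standalone JAX example (not SBX code) of why the updated key has to be returned and stored: if the jitted function returns the original key instead of the split one, the caller keeps reusing the same randomness.

import jax

@jax.jit
def update(key):
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, ())
    return key, noise  # returning the *new* key makes the next call differ

key = jax.random.PRNGKey(0)
for _ in range(3):
    key, noise = update(key)  # store the returned key, as suggested above
    print(noise)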

To Reproduce

Edit the method train in the file sbx/sac/sac.py to add the following line of code after the function _train has been called:

print("self.key", self.key)

Example: https://github.com/theovincent/sbx/blob/8327b98463c89b68f17ec0431d0cf3069cb7d7a7/sbx/sac/sac.py#L236

Create the following file, called train.py at the top level of the project:

import gymnasium as gym

from sbx import SAC

env = gym.make("Pendulum-v1")

model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=110, progress_bar=True)

Running python train.py in the terminal yields

>>> python train.py
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]
self.key [2110677572 2465855137]

Expected behavior

Changing the code, as suggested earlier, fixes the problem. Here are the logs when the change is implemented:

>>> python train.py
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
self.key [3440514203 2996688322]
self.key [ 507603733 1743734701]
self.key [1106737823 3095002064]
self.key [ 372788615 2111558586]
self.key [1808065049 3808616220]
self.key [1837019053 2754803453]
self.key [1740029140 3438719296]
self.key [1088489055 1273990256]
self.key [3718340890 2050508589]
self.key [1872112782 1422931421]

### System Info

  • Describe how the library was installed (pip, docker, source, ...)
    Fork the repo, clone it, create a python virtual env, install the dependencies
python3 -m venv env
source env/bin/activate
pip install -e .
pip install gymnasium[classic-control]
  • GPU models and configuration
    The GPU is not used
  • pip version
    23.2.1
>>> import stable_baselines3 as sb3
>>> sb3.get_system_info()
- OS: Linux-6.5.0-27-generic-x86_64-with-glibc2.35 # 28~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 15 10:51:06 UTC 2
- Python: 3.11.5
- Stable-Baselines3: 2.3.0
- PyTorch: 2.2.2+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

({'OS': 'Linux-6.5.0-27-generic-x86_64-with-glibc2.35 # 28~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 15 10:51:06 UTC 2', 'Python': '3.11.5', 'Stable-Baselines3': '2.3.0', 'PyTorch': '2.2.2+cu121', 'GPU Enabled': 'True', 'Numpy': '1.26.4', 'Cloudpickle': '3.0.0', 'Gymnasium': '0.29.1'}, '- OS: Linux-6.5.0-27-generic-x86_64-with-glibc2.35 # 28~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 15 10:51:06 UTC 2\n- Python: 3.11.5\n- Stable-Baselines3: 2.3.0\n- PyTorch: 2.2.2+cu121\n- GPU Enabled: True\n- Numpy: 1.26.4\n- Cloudpickle: 3.0.0\n- Gymnasium: 0.29.1\n')

Additional context

Before commit e564074, the key was updated each time the function self._train was called as you can see here:

sbx/sbx/sac/sac.py

Lines 389 to 392 in 0f9163d

actor_state,
ent_coef_state,
key,
(actor_loss_value, qf_loss_value, ent_coef_value),

This bug seems to be present for:

  • CrossQ
  • SAC
  • TD3
  • TQC

Checklist

  • [ X] I have checked that there is no similar issue in the repo (required)
  • [ X] I have read the documentation (required)
  • [ X] I have provided a minimal working example to reproduce the bug (required)

crash when using a custom network architecture


🤖 Custom Gym Environment

Please check your environment first using:

from stable_baselines3.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)

The environment passes check_env.

### Describe the bug

When using a custom network architecture, dict(net_arch=[1000, 500]), SBX fails.

### Code example

net = dict(net_arch=[1000, 500])
PPOmodel = PPO('MlpPolicy', env, policy_kwargs=net)


Traceback (most recent call last):
  File "/Users/eric/Documents/development/deepLearning/deepMind/sparky/train.py", line 12, in <module>
    t.train()
  File "/Users/eric/Documents/development/deepLearning/deepMind/sparky/trainEnviornment.py", line 280, in train
    PPOmodel = PPO('MlpPolicy', env,
  File "/Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/ppo/ppo.py", line 165, in __init__
    self._setup_model()
  File "/Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/ppo/ppo.py", line 171, in _setup_model
    self.policy = self.policy_class(  # pytype:disable=not-instantiable
  File "/Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/ppo/policies.py", line 102, in __init__
    self.n_units = net_arch[0]["pi"][0]
TypeError: 'int' object is not subscriptable
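A possible workaround, assuming (from the traceback above, where the policy reads net_arch[0]["pi"][0]) that SBX's PPO policy expects the older list-of-dicts net_arch format; this is only a sketch, not a confirmed fix, and it uses Pendulum-v1 as a stand-in for the custom environment:

import gymnasium as gym

from sbx import PPO

env = gym.make("Pendulum-v1")  # stand-in for the custom env from the report

# Pass net_arch as a list containing a dict so that net_arch[0]["pi"][0] is valid.
policy_kwargs = dict(net_arch=[dict(pi=[1000, 500], vf=[1000, 500])])
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=1_000)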




### Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
  • I have checked my env using the env checker (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Enhancement] Support for large gradient_steps in SAC

Description:
Using the Jax implementation of SAC with larger values of gradient_steps, e.g. 1000, is very slow to compile. Consider

sbx/sbx/sac/sac.py

Lines 333 to 352 in b8dbac1

@classmethod
@partial(jax.jit, static_argnames=["cls", "gradient_steps"])
def _train(
    cls,
    gamma: float,
    tau: float,
    target_entropy: np.ndarray,
    gradient_steps: int,
    data: ReplayBufferSamplesNp,
    policy_delay_indices: flax.core.FrozenDict,
    qf_state: RLTrainState,
    actor_state: TrainState,
    ent_coef_state: TrainState,
    key,
):
    actor_loss_value = jnp.array(0)
    for i in range(gradient_steps):
        def slice(x, step=i):

I think the problem lies in unrolling the loop over too many gradient steps. Removing line 334 (i.e. not jitting the function) avoids the problem.

To Reproduce

from sbx import SAC
import gymnasium as gym

env = gym.make('Pendulum-v1')
model = SAC('MlpPolicy', env, verbose=1, gradient_steps=1000)

model.learn(100000)

Expected behavior

It should compile fast.

Potential Fix

I adjusted the implementation by moving all computations in the loop body of SAC._train to a new jitted function gradient_step. Using this function in a JAX fori_loop solves the issue and compiles almost instantly. If you agree with this, I would propose a PR with my solution.
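For illustration, a self-contained sketch of that idea with a toy loss and optimizer (assumed names, not the actual SBX code): the per-step computation lives in a single function that jax.lax.fori_loop calls, so the compiled program no longer grows with gradient_steps.

from functools import partial

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)

def loss_fn(params, x):
    return jnp.sum((params * x) ** 2)

@partial(jax.jit, static_argnames=["gradient_steps"])
def train(params, opt_state, data, gradient_steps):
    def gradient_step(i, carry):
        params, opt_state = carry
        grads = jax.grad(loss_fn)(params, data[i])
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    # The loop body is compiled once and iterated, instead of being unrolled.
    return jax.lax.fori_loop(0, gradient_steps, gradient_step, (params, opt_state))

params = jnp.ones(3)
opt_state = optimizer.init(params)
data = jnp.ones((1000, 3))
params, opt_state = train(params, opt_state, data, gradient_steps=1000)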

### System Info

  • Describe how the library was installed (pip, docker, source, ...): pip
  • sbx-rl version: 0.7.0
  • Python version: 3.11
  • Jax version: 0.4.14
  • Gymnasium version: 0.29

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

SBX becomes super slow when number of cpus are limited

🐛 Bug

SBX becomes much slower than SB3 when the number of CPUs is limited.

To Reproduce

Steps to reproduce the behavior.

'''
For installation please do -
pip install gym
pip install sbx
pip install mujoco
pip install shimmy
'''
import gym
import psutil
import random
import os, subprocess as sp


def train():
    pid = os.getpid()
    num_of_cpus = 4
    process = psutil.Process(pid)
    print("Process = ", pid)
    affinity = process.cpu_affinity()
    cpus_selected = random.sample(affinity, num_of_cpus)
    print("cpus_selected = ", cpus_selected)
    # print("iteration = ", iteration)
    process.cpu_affinity(cpus_selected)
    env = gym.make("Humanoid-v4")

    model = SAC("MlpPolicy", env, verbose=1)

    model.learn(total_timesteps=7e3, progress_bar=True)

# from stable_baselines3 import SAC

from sbx import SAC


train()

Expected behavior

If you want to compare SB3 vs SBX, you can uncomment from stable_baselines3 import SAC and comment out from sbx import SAC. I am noticing that SB3 is much faster than SBX in such situations.

### System Info

Describe the characteristic of your environment:


You can use sb3.get_system_info() to print relevant packages info:

import stable_baselines3 as sb3
sb3.get_system_info()
{'OS': 'Linux-6.5.0-15-generic-x86_64-with-glibc2.17 # 15~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Jan 12 18:54:30 UTC 2', 'Python': '3.8.18', 'Stable-Baselines3': '2.3.0a1', 'PyTorch': '2.1.2+cu121', 'GPU Enabled': 'True', 'Numpy': '1.24.3', 'Cloudpickle': '3.0.0', 'Gymnasium': '0.29.1', 'OpenAI Gym': '0.26.2'}, '- OS: Linux-6.5.0-15-generic-x86_64-with-glibc2.17 # 15~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Jan 12 18:54:30 UTC 2\n- Python: 3.8.18\n- Stable-Baselines3: 2.3.0a1\n- PyTorch: 2.1.2+cu121\n- GPU Enabled: True\n- Numpy: 1.24.3\n- Cloudpickle: 3.0.0\n- Gymnasium: 0.29.1\n- OpenAI Gym: 0.26.2\n')

{'Cloudpickle': '3.0.0',
 'GPU Enabled': 'True',
 'Gymnasium': '0.29.1',
 'Numpy': '1.24.3',
 'OS': 'Linux-6.5.0-15-generic-x86_64-with-glibc2.17 # 15~22.04.1-Ubuntu SMP '
       'PREEMPT_DYNAMIC Fri Jan 12 18:54:30 UTC 2',
 'OpenAI Gym': '0.26.2',
 'PyTorch': '2.1.2+cu121',
 'Python': '3.8.18',
 'Stable-Baselines3': '2.3.0a1'}

Checklist

  • [ X] I have checked that there is no similar issue in the repo (required)
  • [ X] I have read the documentation (required)
  • [ X] I have provided a minimal working example to reproduce the bug (required)

[Feature Request] Passing custom activation functon in policy_kwargs

🚀 Feature

Possibility to pass a Flax activation function (from the flax.linen.activation module) when creating an SBX model, through the policy_kwargs argument.

Motivation

In the current implementation of sbx, users are unable to pass custom activation functions when creating a model. This limitation restricts flexibility and may not suit all users' needs.

Pitch

Example:

policy_kwargs = dict(activation_fn=my_custom_activation_fn, net_arch=dict(pi=[64, 64], qf=[64, 64]))

model = TD3("MlpPolicy",
                       env,
                      policy_kwargs=policy_kwargs,
                      verbose=1)

Idea on how to implement it

Add an activation_fn attribute to the underlying classes that compose the policy (like Critic and Actor in td3/policies.py).
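A rough sketch of that idea (assumed names and shapes, not the actual sbx classes): give the policy sub-modules an activation_fn attribute with a sensible default, which policy_kwargs could then forward.

from typing import Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp

class Critic(nn.Module):
    net_arch: Sequence[int] = (256, 256)
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu  # configurable

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        x = jnp.concatenate([obs, action], axis=-1)
        for n_units in self.net_arch:
            x = self.activation_fn(nn.Dense(n_units)(x))
        return nn.Dense(1)(x)

# Example: build the critic with tanh instead of the default relu.
critic = Critic(activation_fn=nn.tanh)
params = critic.init(jax.random.PRNGKey(0), jnp.ones((1, 3)), jnp.ones((1, 1)))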

[Bug] TQC Entropy Coefficient


🐛 Bug

When running train with hyperparameter optimization with TQC (python train.py --algo tqc --optimize), it gives TypeError: TQC.__init__() got an unexpected keyword argument 'target_entropy'.

To Reproduce


import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    train()
[W 2024-04-04 15:33:50,250] Trial 2 failed with parameters: {'gamma': 0.99, 'learning_rate': 0.0003250530792956964, 'batch_size': 256, 'buffer_size': 100000, 'learning_starts': 0, 'train_freq': 32, 'tau': 0.005, 'log_std_init': -1.275286479641816, 'net_arch': 'medium', 'n_quantiles': 32, 'top_quantiles_to_drop_per_net': 24} because of the following error: TypeError("TQC.__init__() got an unexpected keyword argument 'target_entropy'").
Traceback (most recent call last):
  File ".../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File ".../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 753, in objective
    model = ALGOS[self.algo](
TypeError: TQC.__init__() got an unexpected keyword argument 'target_entropy'

Expected behavior

It should run hyperparameter optimization without the error.

### System Info

Describe the characteristic of your environment:

  • installed from source


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Feature Request] Update type annotations

Many type annotations are deprecated

/home/antonin/miniconda3/lib/python3.11/site-packages/chex/_src/pytypes.py:54: DeprecationWarning: jax.random.KeyArray is deprecated. Use jax.Array for annotations, and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of typed prng keys (i.e. keys created with jax.random.key).
/home/antonin/miniconda3/lib/python3.11/site-packages/flax/linen/activation.py:37: DeprecationWarning: jax.nn.normalize is deprecated. Use jax.nn.standardize instead.
  from jax.nn import normalize
/home/antonin/miniconda3/lib/python3.11/site-packages/chex/_src/pytypes.py:53: DeprecationWarning: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].
/home/antonin/miniconda3/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.
  return jax_config.define_bool_state('flax_' + name, default, help)

[Bug] example supplied in readme crashing


🐛 Bug

If I run the example in your readme, it crashes.

To Reproduce

Steps to reproduce the behavior.

import gym

from sbx import TQC, DroQ, SAC, PPO, DQN

env = gym.make("Pendulum-v1")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()

vec_env.close()


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [34], in <cell line: 7>()
      3 from sbx import TQC, DroQ, SAC, PPO, DQN
      5 env = gym.make("Pendulum-v1")
----> 7 model = TQC("MlpPolicy", env, verbose=1)
      8 model.learn(total_timesteps=10_000, progress_bar=True)
     10 vec_env = model.get_env()

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/tqc.py:102, in TQC.__init__(self, policy, env, learning_rate, qf_learning_rate, buffer_size, learning_starts, batch_size, tau, gamma, train_freq, gradient_steps, policy_delay, top_quantiles_to_drop_per_net, action_noise, ent_coef, use_sde, sde_sample_freq, use_sde_at_warmup, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
     99 self.policy_kwargs["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net
    101 if _init_setup_model:
--> 102     self._setup_model()

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/tqc.py:115, in TQC._setup_model(self)
    107 if self.policy is None:
    108     self.policy = self.policy_class(  # pytype:disable=not-instantiable
    109         self.observation_space,
    110         self.action_space,
    111         self.lr_schedule,
    112         **self.policy_kwargs,  # pytype:disable=not-instantiable
    113     )
--> 115     self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
    117     self.key, ent_key = jax.random.split(self.key, 2)
    119     self.actor = self.policy.actor

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/policies.py:142, in TQCPolicy.build(self, key, lr_schedule, qf_learning_rate)
    137 # Hack to make gSDE work without modifying internal SB3 code
    138 self.actor.reset_noise = self.reset_noise
    140 self.actor_state = TrainState.create(
    141     apply_fn=self.actor.apply,
--> 142     params=self.actor.init(actor_key, obs),
    143     tx=self.optimizer_class(learning_rate=lr_schedule(1), **self.optimizer_kwargs),
    144 )
    146 self.qf = Critic(
    147     dropout_rate=self.dropout_rate,
    148     use_layer_norm=self.layer_norm,
    149     n_units=self.n_units,
    150     n_quantiles=self.n_quantiles,
    151 )
    153 self.qf1_state = RLTrainState.create(
    154     apply_fn=self.qf.apply,
    155     params=self.qf.init(
   (...)
    165     tx=optax.adam(learning_rate=qf_learning_rate),
    166 )

    [... skipping hidden 9 frame]

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/policies.py:66, in Actor.__call__(self, x)
     63 log_std = nn.Dense(self.action_dim)(x)
     64 log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
     65 dist = TanhTransformedDistribution(
---> 66     tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
     67 )
     68 return dist

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py:235, in MultivariateNormalDiag.__init__(self, loc, scale_diag, scale_identity_multiplier, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    232 if scale_diag is not None:
    233   diag_cls = (KahanLogDetLinOpDiag if experimental_use_kahan_sum else
    234               tf.linalg.LinearOperatorDiag)
--> 235   scale = diag_cls(
    236       diag=scale_diag,
    237       is_non_singular=True,
    238       is_self_adjoint=True,
    239       is_positive_definite=False)
    240 else:
    241   # Deprecated behavior; breaks variable-safety rules by calling
    242   # `tf.shape(loc)`.
    243   num_rows = tf.compat.dimension_value(loc.shape[-1])

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py:191, in LinearOperatorDiag.__init__(self, diag, is_non_singular, is_self_adjoint, is_positive_definite, is_square, name)
    182 super(LinearOperatorDiag, self).__init__(
    183     dtype=self._diag.dtype,
    184     is_non_singular=is_non_singular,
   (...)
    188     parameters=parameters,
    189     name=name)
    190 # TODO(b/143910018) Remove graph_parents in V3.
--> 191 self._set_graph_parents([self._diag])

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py:1177, in LinearOperator._set_graph_parents(self, graph_parents)
   1174 for i, t in enumerate(graph_parents):
   1175   if t is None or not (linear_operator_util.is_ref(t) or
   1176                        ops.is_tensor(t)):
-> 1177     raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
   1178 self._graph_parents = graph_parents

ValueError: Graph parent item 0 is not a Tensor; [[0.48654944]].
### Expected behavior

If I import Stable-Baselines3's PPO or other algorithms, they train the example perfectly. I expect SBX to do the same.


### System Info

Describe the characteristic of your environment:

Operating system: MacOS Monterey

  • Installed via: pip
  • Python version: 3.9.15
  • PyTorch version: torch 1.13.1
  • Gym version: gym 0.21.0



Checklist

  • [x ] I have checked that there is no similar issue in the repo (required)
  • [x ] I have read the documentation (required)
  • [x ] I have provided a minimal working example to reproduce the bug (required)

[Bug] AttributeError: module 'tensorflow.python.util.tf_inspect' has no attribute 'Parameter'

🐛 Bug

To Reproduce

import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ

env = gym.make("Pendulum-v1")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()

vec_env.close()

error:

Traceback (most recent call last):
  File "check_sbx.py", line 3, in <module>
    from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/__init__.py", line 5, in <module>
    from sbx.droq import DroQ
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/droq/__init__.py", line 1, in <module>
    from sbx.droq.droq import DroQ
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/droq/droq.py", line 7, in <module>
    from sbx.tqc.policies import TQCPolicy
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/tqc/__init__.py", line 1, in <module>
    from sbx.tqc.tqc import TQC
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/tqc/tqc.py", line 19, in <module>
    from sbx.tqc.policies import TQCPolicy
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/tqc/policies.py", line 13, in <module>
    from sbx.common.distributions import TanhTransformedDistribution
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/sbx/common/distributions.py", line 7, in <module>
    tfd = tfp.distributions
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 53, in __getattr__
    module = self._load()
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 40, in _load
    module = importlib.import_module(self.__name__)
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 41, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 19, in <module>
    from tensorflow_probability.python.internal.backend.jax import compat
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py", line 18, in <module>
    from tensorflow_probability.python.internal.backend.jax import v2
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/v2.py", line 27, in <module>
    from tensorflow_probability.python.internal.backend.jax import linalg
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/linalg.py", line 28, in <module>
    from tensorflow_probability.python.internal.backend.jax.gen import adjoint_registrations as _adjoint_registrations
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/adjoint_registrations.py", line 37, in <module>
    from tensorflow_probability.python.internal.backend.jax.gen import linear_operator
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py", line 58, in <module>
    from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_algebra
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_algebra.py", line 40, in <module>
    from tensorflow_probability.python.internal.backend.jax import tf_inspect
  File "/home/user/miniconda3/envs/decision-transformer-gym/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/tf_inspect.py", line 26, in <module>
    Parameter = inspect.Parameter
AttributeError: module 'tensorflow.python.util.tf_inspect' has no attribute 'Parameter'

Expected behavior

Should have started training.

### System Info

- OS: Linux-6.5.0-17-generic-x86_64-with-glibc2.10 # 17~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jan 16 14:32:32 UTC 2
- Python: 3.8.5
- Stable-Baselines3: 2.3.0a2
- PyTorch: 2.2.0+cu121
- GPU Enabled: True
- Numpy: 1.23.0
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
- OpenAI Gym: 0.18.3


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Question] MaskablePPO support

Hello,

I have a question about SBX. I am using MaskablePPO from SB3-Contrib because action masking is really important in my RL problem. I found out about the SBX library, which seems very promising for speeding up computation. My question is: does SBX support MaskablePPO from SB3-Contrib?

Thank you very much in advance for your help,
Best,
G.

[Feature Request] Dict Obs Spaces Support

🚀 Feature

As far as I understood from some preliminary tests, SBX currently does not support Dict observation spaces. Do you have this feature on your roadmap? If yes, when is it expected to be added?

Motivation

It would extend the list of supported use cases to those featuring more complex/structured observation spaces.

### Checklist

  • I have checked that there is no similar issue in the repo (required)

[Question] Why is fps much lower than CPU if using GPU

Question

I encountered a problem where RL training runs at about 5000 fps with SBX on CPU, but after jaxlib-cuda11 was installed it only runs at about 2000 fps.

Platform: Ubuntu 20.04, x86_64
Python version: 3.9.12
GPU: NVIDIA RTX 4090
CPU: i9-13900KS

nvidia-smi:

Thu Mar 21 14:41:29 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  Off |
|  0%   48C    P2    68W / 500W |  20952MiB / 24564MiB |     25%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1281      G   /usr/lib/xorg/Xorg                941MiB |
|    0   N/A  N/A      2137      G   /usr/bin/gnome-shell              183MiB |
|    0   N/A  N/A      4972      G   ...RendererForSitePerProcess        7MiB |
|    0   N/A  N/A      5435      G   ...2gtk-4.0/WebKitWebProcess        6MiB |
|    0   N/A  N/A      6950      G   ...b2020b/bin/glnxa64/MATLAB        6MiB |
|    0   N/A  N/A      8317      G   ...17D222A6D1FB8847155E9F895       19MiB |
|    0   N/A  N/A     89295      G   ...on=20240315-130113.878000      249MiB |
|    0   N/A  N/A     96608      C   ...3/envs/sb3-jax/bin/python    19530MiB |
+-----------------------------------------------------------------------------+

pip list | grep nvidia

nvidia-cublas-cu11            11.11.3.6
nvidia-cublas-cu12            12.4.2.65
nvidia-cuda-cupti-cu11        11.8.87
nvidia-cuda-cupti-cu12        12.4.99
nvidia-cuda-nvcc-cu11         11.8.89
nvidia-cuda-nvcc-cu12         12.4.99
nvidia-cuda-nvrtc-cu11        11.8.89
nvidia-cuda-nvrtc-cu12        12.1.105
nvidia-cuda-runtime-cu11      11.8.89
nvidia-cuda-runtime-cu12      12.4.99
nvidia-cudnn-cu11             8.9.6.50
nvidia-cudnn-cu12             8.9.2.26
nvidia-cufft-cu11             10.9.0.58
nvidia-cufft-cu12             11.2.0.44
nvidia-curand-cu12            10.3.2.106
nvidia-cusolver-cu11          11.4.1.48
nvidia-cusolver-cu12          11.6.0.99
nvidia-cusparse-cu11          11.7.5.86
nvidia-cusparse-cu12          12.3.0.142
nvidia-nccl-cu11              2.20.5
nvidia-nccl-cu12              2.19.3
nvidia-nvjitlink-cu12         12.4.99
nvidia-nvtx-cu12              12.1.105

and training script is:

from stable_baselines3.common.envs import MyCustomEnv
from sbx import PPO

env = MyCustomEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=int(1e8), progress_bar=True)

then I got:

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

...

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 11.7     |
|    ep_rew_mean        | -34.5    |
| time/                 |          |
|    fps                | 2027     |
|    iterations         | 927      |
|    time_elapsed       | 936      |
|    total_timesteps    | 1898496  |
| train/                |          |
|    clip_range         | 0.2      |
|    explained_variance | 0.34     |
|    n_updates          | 9260     |
|    pg_loss            | -0.0721  |
|    value_loss         | 5.89     |
------------------------------------

However, if using cpu,

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

...


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 11.6     |
|    ep_rew_mean        | -35.2    |
| time/                 |          |
|    fps                | 4773     |
|    iterations         | 139      |
|    time_elapsed       | 59       |
|    total_timesteps    | 284672   |
| train/                |          |
|    clip_range         | 0.2      |
|    explained_variance | 0.281    |
|    n_updates          | 1380     |
|    pg_loss            | -0.0598  |
|    value_loss         | 4.52     |

JAX (GPU) was installed by following the official JAX installation instructions; the CPU version was installed with pip install jax.
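For what it's worth, one way to compare both backends without uninstalling the CUDA jaxlib is to force JAX onto the CPU for a run. A sketch (the environment variable must be set before JAX is imported; older JAX versions use JAX_PLATFORM_NAME instead of JAX_PLATFORMS):

import os

os.environ["JAX_PLATFORMS"] = "cpu"  # remove this line to use the GPU again

from sbx import PPO  # importing sbx pulls in jax, so set the variable first

# ... then build the env and model exactly as in the training script above.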

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)

[Question] Extending sbx algorithms (e.g via a callback)

Hi there,

I'm trying to experiment with "RL while learning Minmax penalty" (paper, code), and I thought I'd try adding it to an SBX DroQ setup. From the paper, the implementation looks quite straightforward, essentially:

for each step:
    penalty = minmaxpenalty.update(reward, Q[state])
    if info["unsafe"]:
        reward = penalty

hence I need to obtain the Q-value. I've been looking into the Droq code and I believe the Q-value is computed at (?)

next_target_quantiles = next_quantiles[:, :n_target_quantiles]
I've also been looking into implementing this via a Stable-Baselines3 callback, but can't seem to get it to work (I'm not sure whether this is a suitable use case?).
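
For reference, here is a rough sketch of the callback direction I've been trying. The `penalty_tracker` object stands in for the MinMaxPenalty class from the paper's code, the Q-value lookup is stubbed out, and whether rewards can be modified in place like this before they reach the replay buffer in SBX is exactly what I haven't been able to verify:

from stable_baselines3.common.callbacks import BaseCallback


class MinMaxPenaltyCallback(BaseCallback):
    """Replace the reward with a learned penalty on steps flagged as unsafe."""

    def __init__(self, penalty_tracker, verbose: int = 0):
        super().__init__(verbose)
        # Placeholder for the MinMaxPenalty object from the paper's code
        self.penalty_tracker = penalty_tracker

    def _on_step(self) -> bool:
        # collect_rollouts exposes the latest transition through self.locals;
        # the hope is that modifying `rewards` in place here takes effect
        # before the transition is stored (unverified for SBX).
        rewards = self.locals["rewards"]
        infos = self.locals["infos"]
        for i, info in enumerate(infos):
            q_value = 0.0  # TODO: query the DroQ/SAC Q-network for the current state
            penalty = self.penalty_tracker.update(rewards[i], q_value)
            if info.get("unsafe", False):
                rewards[i] = penalty
        return True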

Many thanks for any help, and for this fantastic lib! :)

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)

[Feature Request] Support Optax Optimizer Schedules

🚀 Feature

Support optax optimizer schedules (as argument for learning_rate).

Motivation

Learning rate scheduling can be essential to achieving good results when training agents.

Pitch

Stable-Baselines3 already supports learning rate scheduling, and it appears this could be achieved in sbx by allowing users to pass an optax optimizer schedule when specifying a model.

Alternatives

Unsure (None?)

Additional context

An example of what happens when trying to pass an optax optimizer schedule during construction of TQC.

Running

import optax
import gymnasium as gym

from sbx import TQC

env = gym.make("Pendulum-v1")

lr_schedule = optax.piecewise_constant_schedule(1e-3, boundaries_and_scales={5000: 0.1})

model = TQC("MlpPolicy", env, learning_rate=lr_schedule, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

Yields the following assertion error:

Traceback (most recent call last):
  File "/mnt/sb3/sbx_reprex.py", line 10, in <module>
    model = TQC("MlpPolicy", env, learning_rate=lr_schedule, verbose=1)
  File "/opt/conda/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 111, in __init__
    self._setup_model()
  File "/opt/conda/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 125, in _setup_model
    assert isinstance(self.qf_learning_rate, float)
AssertionError

As far as I can tell, this same behaviour is present in TD3, SAC, DroQ, DDPG.

Can the assertion that `qf_learning_rate` is a float (here for TQC) be removed/relaxed to support the use of optax schedules?
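
For what it's worth, a minimal sketch of why this looks feasible on the optax side: a schedule is just a callable mapping the optimizer step to a learning rate, and optax optimizers accept one wherever they accept a scalar. How sbx would thread it through `_setup_model` is of course up to the maintainers.

import optax

# An optax schedule is a callable from the step count to a learning rate.
lr_schedule = optax.piecewise_constant_schedule(1e-3, boundaries_and_scales={5000: 0.1})

print(lr_schedule(0))       # ~1e-3
print(lr_schedule(10_000))  # ~1e-4 (scaled by 0.1 after step 5000)

# Optax optimizers take a schedule wherever a scalar learning rate is accepted.
optimizer = optax.adam(learning_rate=lr_schedule)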

Checklist

  • I have checked that there is no similar issue in the repo (required)

Custom env with FrameStack wrapper causes invalid actions to be passed to `env.step`

🤖 Custom Gym Environment

Describe the bug

When using `gymnasium.wrappers.frame_stack.FrameStack` with a simple custom env, I get an exception when the chosen action is used to index the action list in `step`.

Code example

import itertools
from typing import Any, List, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers.frame_stack import FrameStack
from sbx import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv


class MyEnv(gym.Env):
    def __init__(self) -> None:
        self.actions, self.action_space = self.actionSpace()
        self.observation_space = Box(0, 1, shape=(1,))

        super().__init__()

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        chosenAction = self.actions[action]

        return self.obs(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.obs(), {}

    def obs(self):
        return np.array([0.5], dtype=np.float32)

    def render(self) -> Any | List[Any] | None:
        pass

    def actionSpace(self):
        baseActions = [0, 1, 2, 3, 4]

        totalActionsWithRepeats = list(itertools.permutations(baseActions, 2))
        withoutRepeats = []

        for combination in totalActionsWithRepeats:
            reversedCombination = combination[::-1]
            if reversedCombination not in withoutRepeats:
                withoutRepeats.append(combination)

        filteredActions = [[action] for action in baseActions] + withoutRepeats

        return filteredActions, Discrete(len(filteredActions))


if __name__ == "__main__":
    env = MyEnv()
    check_env(env)

    env = FrameStack(env, 4)
    env = DummyVecEnv([lambda: env])

    algo = PPO("MlpPolicy", env)
    algo.learn(total_timesteps=1000)

Traceback (most recent call last):
  File "/home/user/sbx_ppo_repro.py", line 61, in <module>
    algo.learn(total_timesteps=1000)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 259, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/common/on_policy_algorithm.py", line 152, in collect_rollouts
    new_obs, rewards, dones, infos = env.step(clipped_actions)
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 197, in step
    return self.step_wait()
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 58, in step_wait
    obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
  File "/home/user/jax-venv/lib/python3.10/site-packages/gymnasium/wrappers/frame_stack.py", line 179, in step
    observation, reward, terminated, truncated, info = self.env.step(action)
  File "/home/user/sbx_ppo_repro.py", line 21, in step
    chosenAction = self.actions[action]
TypeError: only integer scalar arrays can be converted to a scalar index
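
For what it's worth, the indexing fails because the wrapped pipeline delivers the action as a NumPy array rather than a plain Python int. A hypothetical workaround (not necessarily the intended fix) is to coerce the action inside `step` before indexing:

    # Drop-in replacement for MyEnv.step above (np is already imported in the repro).
    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        # The wrappers may deliver the action as a 0-d or shape-(1,) NumPy array,
        # which cannot index a Python list directly; convert it to an int first.
        chosenAction = self.actions[int(np.asarray(action).item())]

        return self.obs(), 0.0, False, False, {}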

System Info

  • OS: Linux-6.5.6-76060506-generic-x86_64-with-glibc2.35 # 202310061235169739694522.04~9283e32 SMP PREEMPT_DYNAMIC Sun O
  • Python: 3.10.12
  • Stable-Baselines3: 2.1.0
  • PyTorch: 2.1.0+cu121
  • GPU Enabled: True
  • GPU Model: Nvidia RTX 3080ti
  • Numpy: 1.26.1
  • Cloudpickle: 3.0.0
  • Gymnasium: 0.29.1

sbx at the latest commit was installed using pip: pip install git+https://github.com/araffin/sbx

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
  • I have checked my env using the env checker (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Feature Request] Multi-Discrete action spaces for PPO

🚀 Feature

Currently, PPO only supports Box and Discrete (`gymnasium.spaces.box.Box`, `gymnasium.spaces.discrete.Discrete`) action spaces. It would be awesome if it also supported MultiDiscrete action spaces.

Motivation

For many applications (Atari), one has to choose multiple discrete actions at each time step. Stable-Baselines3 already supports MultiDiscrete action spaces, and it would be great if sbx supported them as well.
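
For illustration, the space in question is just a fixed-length vector of independent discrete sub-actions; this snippet only shows the gymnasium side, nothing sbx-specific:

from gymnasium import spaces

# Two sub-actions are chosen per step: the first with 3 options, the second with 2.
action_space = spaces.MultiDiscrete([3, 2])

print(action_space.sample())  # e.g. array([1, 0])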

Checklist

  • I have checked that there is no similar issue in the repo (required)

[Question] Speedup compared to SB3

Question

One of the main features of JAX compared to PyTorch or TensorFlow is speed. Would it be possible to showcase the speedup obtained with SBX compared to SB3 on environments that are fully jitted and not fully jitted? It would give insight into the speedup and whether it is worth switching from one library to the other.
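
To be concrete, a rough sketch of the kind of comparison I mean (just wall-clock model.learn time with identical defaults on a single non-jitted env; nothing as rigorous as the OpenRL Benchmark reports):

import time

import gymnasium as gym

from sbx import PPO as SBXPPO
from stable_baselines3 import PPO as SB3PPO


def time_learn(algo_cls, steps: int = 50_000) -> float:
    env = gym.make("Pendulum-v1")
    model = algo_cls("MlpPolicy", env, verbose=0)
    start = time.perf_counter()
    model.learn(total_timesteps=steps)
    return time.perf_counter() - start


print(f"SBX PPO: {time_learn(SBXPPO):.1f}s")
print(f"SB3 PPO: {time_learn(SB3PPO):.1f}s")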

Thanks!

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
