Comments (6)
You might want to take a look at the newest suggested hparams for humanoid--we've fixed a couple of lingering bugs that were hurting humanoid performance, and ppo now works for humanoid (quite well).
from brax.
Hello!
There are a couple of ways you could accomplish this. We've recently introduced the DefaultState message to our config, which you could use to orient your system at initialization (see here: https://github.com/google/brax/blob/main/brax/physics/config.proto#L137)
If that doesn't work for you, you can always manually rotate the entire body, but you do need to be a little bit careful to make sure everything is rotated correctly. Here's a little snippet function that should do it:
import functools
import brax
import jax
import jax.numpy as jnp

# vmap over the leading (body) axis of qp and mask; rot, ref_vec and
# offset_vec are shared by every body.
@functools.partial(jax.vmap, in_axes=(0, 0, None, None, None))
def transform_qp(qp, mask, rot, ref_vec, offset_vec):
  """Rotates a qp by some rot around some ref_vec, and translates by offset_vec.
  Args:
    qp: QPs to be rotated
    mask: whether to transform this qp or not
    rot: Quaternion to rotate by
    ref_vec: point around which to rotate.
    offset_vec: a vector to translate everything by
  """
  relative_pos = qp.pos - ref_vec
  new_pos = brax.physics.math.rotate(relative_pos, rot) + ref_vec + offset_vec
  new_rot = brax.physics.math.qmult(rot, qp.rot)
  # Masked-out bodies keep their original position and orientation.
  return brax.physics.base.QP(pos=jnp.where(mask, new_pos, qp.pos),
                              vel=qp.vel,
                              ang=qp.ang,
                              rot=jnp.where(mask, new_rot, qp.rot))
Take care with the batch indices on this function--it operates in parallel over an array of QPs as well as a mask array with the same leading batch dimension as those QPs. So say your system has a robot + the ground. You probably don't want to apply this transformation to the ground, so you can construct a mask like:
mask = jnp.array([1. if b.name != 'ground' else 0 for b in my_config.config.bodies])
You'll also need to pick a point around which you want to rotate. Maybe the robot's head?
head_index = 0
for i, b in enumerate(my_config.config.bodies):
  if "head" in b.name:
    head_index = i
    break
Now you're ready to transform:
new_qp = transform_qp(
    old_qp,
    mask,
    # some rotation quaternion--I like the transforms3d library for generating these
    jnp.array(transforms3d.euler.euler2quat(0., 0., jnp.pi / 2., axes='rxyz')),
    old_qp.pos[head_index],  # the point to rotate around
    jnp.array([0, 0, 0]),  # some additional translational offset
)
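For anyone who wants to check the math without installing brax, here is a minimal pure-Python sketch of what the transform does to a single body. The helper names (quat_mul, rotate, quat_about_z, transform_body) are hypothetical stand-ins written for this illustration, not brax functions, and the sketch assumes unit quaternions in (w, x, y, z) order, which is the order transforms3d returns.

```python
import math

def quat_mul(u, v):
    """Hamilton product of two (w, x, y, z) quaternions."""
    uw, ux, uy, uz = u
    vw, vx, vy, vz = v
    return (
        uw * vw - ux * vx - uy * vy - uz * vz,
        uw * vx + ux * vw + uy * vz - uz * vy,
        uw * vy - ux * vz + uy * vw + uz * vx,
        uw * vz + ux * vy - uy * vx + uz * vw,
    )

def rotate(vec, quat):
    """Rotates a 3-vector by a unit quaternion: q * (0, v) * conj(q)."""
    w, x, y, z = quat
    conj = (w, -x, -y, -z)
    _, rx, ry, rz = quat_mul(quat_mul(quat, (0.0, *vec)), conj)
    return (rx, ry, rz)

def quat_about_z(angle):
    """Unit quaternion for a rotation of `angle` radians about the z-axis."""
    return (math.cos(angle / 2), 0.0, 0.0, math.sin(angle / 2))

def transform_body(pos, rot, mask, q, ref_vec, offset_vec):
    """Rotates one body about ref_vec by q, then translates by offset_vec.

    Bodies with a falsy mask are returned unchanged.
    """
    if not mask:
        return pos, rot
    rel = tuple(p - r for p, r in zip(pos, ref_vec))
    rotated = rotate(rel, q)
    new_pos = tuple(a + r + o for a, r, o in zip(rotated, ref_vec, offset_vec))
    return new_pos, quat_mul(q, rot)

identity = (1.0, 0.0, 0.0, 0.0)
q = quat_about_z(math.pi / 2)   # quarter turn about z
ref = (0.0, 0.0, 1.0)           # made-up point to rotate around
no_offset = (0.0, 0.0, 0.0)

# A made-up robot body one unit along +x from the reference point.
robot_pos, robot_rot = transform_body((1.0, 0.0, 1.0), identity, 1.0, q, ref, no_offset)
# The ground is masked out, so it stays where it was.
ground_pos, ground_rot = transform_body((0.0, 0.0, 0.0), identity, 0.0, q, ref, no_offset)
```

After the quarter turn the robot body sits one unit along +y from the reference point and carries the rotation q, while the masked ground body is untouched. The vmapped version above does the same thing for every body at once.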
from brax.
Thank you very much for your kind reply. I tried the DefaultState method and it works well :)
I also found a strange bug: I tried to run the example provided in training.ipynb for training humanoid with sac, and it performed poorly. But when I rolled back the code to an earlier version (1c237e8), everything was fine, so I think some recent updates may have affected humanoid training.
By the way, why is only humanoid trained using sac? Is it possible for ppo to achieve good performance on humanoid? Can you share some parameters for training humanoid using ppo?
Thanks again for your prompt reply, it saves me a lot of time. Have a nice day:)
from brax.
Yep there was a sneaky bug we introduced that we actually fixed yesterday which hasn't quite yet percolated to github. Performance is still a little bit unstable on humanoid, so we're trying to hunt down if there's something else going on.
We haven't been able to get PPO to work for humanoid, which may or may not be related to the high performance variance for humanoid in SAC.
from brax.
Got it:)
I used transform_qp to rotate the humanoid 180 degrees around the z-axis in an earlier bug-free version, like this:
humanoid_body_name = ["torso", "lwaist", "pelvis", "right_shin", "right_thigh",
                      "left_thigh", "left_shin", "right_upper_arm", "left_upper_arm",
                      "right_lower_arm", "left_lower_arm"]
mask = jnp.array([1. if b_body.name in humanoid_body_name else 0 for b_body in self.sys.config.bodies])
ref_vec = jnp.array([0.0, 0.0, 0.0])
qp = transform_qp(
    qp,
    mask,
    jnp.array(transforms3d.euler.euler2quat(0., 0., jnp.pi, axes='rxyz')),
    ref_vec,  # the point to rotate around
    jnp.array([0, 0, 0]),  # some additional translational offset
)
qp, info = self.sys.step(qp,
                         jax.random.uniform(rng, (self.action_size,)) * .5)
The visualization shows that it works. To get the rotated humanoid to move forward, I added a negative sign to the original move reward:
lin_vel_cost = 1.25 * -(com_after[0] - com_before[0]) / self.sys.config.dt
But the training results were very poor. I'm not sure whether the rotation or the reward design is causing the problem. Can you give me some insight?
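As a quick sanity check on the sign flip alone, the sketch below rotates a +x heading by 180 degrees about z and evaluates the reward term from the line above for one step in the new facing direction. It assumes the unrotated humanoid is rewarded for motion along +x; the dt, step length and centre-of-mass values are made up for illustration, and rotate_about_z is a hypothetical helper, not a brax function.

```python
import math

def rotate_about_z(vec, angle):
    """Rotates a 3-vector by `angle` radians about the z-axis."""
    x, y, z = vec
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y, z)

forward = (1.0, 0.0, 0.0)                  # assumed original heading (+x)
flipped = rotate_about_z(forward, math.pi)  # heading after the 180 degree turn

dt = 0.015          # made-up timestep
step_length = 0.03  # made-up distance covered in one step
com_before = (0.0, 0.0, 1.0)
# Centre of mass after moving one step in the new facing direction.
com_after = tuple(c + step_length * f for c, f in zip(com_before, flipped))

# Reward term as in the original environment, and with the sign negated.
original_term = 1.25 * (com_after[0] - com_before[0]) / dt
negated_term = 1.25 * -(com_after[0] - com_before[0]) / dt
```

Here flipped comes out as (-1, 0, 0) up to floating-point error, original_term is negative and negated_term is positive, so the negated reward does favour motion in the rotated facing direction. This only checks the reward sign; it says nothing about whether other quantities the policy sees also need to be expressed relative to the new heading.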
Thanks again for your reply~
from brax.
What good news!!
Thank you for still remembering this issue, and a heartfelt thank you for all your hard work!
from brax.