
trajax's Introduction


A Python library for differentiable optimal control on accelerators.


Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically and efficiently differentiable via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, written with NumPy.

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.
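Trajax's own differentiation rules for its solvers are not shown here. As a hedged illustration of the registration mechanism only, the sketch below uses jax.custom_vjp on a much smaller "solver", a linear solve. The backward pass does one adjoint solve instead of differentiating through the solver's internals. The names `linear_solve` and `loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def linear_solve(A, b):
    """Returns x solving A @ x = b."""
    return jnp.linalg.solve(A, b)


def _linear_solve_fwd(A, b):
    x = jnp.linalg.solve(A, b)
    # Keep only what the backward pass needs, not the solver's internals.
    return x, (A, x)


def _linear_solve_bwd(residuals, x_bar):
    A, x = residuals
    # A single adjoint solve A^T lam = x_bar replaces backpropagating
    # through every step of the forward solve.
    lam = jnp.linalg.solve(A.T, x_bar)
    # From dx = A^{-1} (db - dA x): grad_b = lam and grad_A = -lam x^T.
    return -jnp.outer(lam, x), lam


# Register the custom rule; jax.grad now uses it automatically.
linear_solve.defvjp(_linear_solve_fwd, _linear_solve_bwd)


def loss(A, b):
    return jnp.sum(linear_solve(A, b) ** 2)


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 2.0])
grad_A, grad_b = jax.grad(loss, argnums=(0, 1))(A, b)
```

Any function that calls `linear_solve` picks up the custom rule when differentiated. This matches the behavior described above for functions that call a Trajax solver.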

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.


To install directly from github using pip:

$ pip install git+

Alternatively, to install from source:

$ python install

Trajectory optimization and optimal control

We consider classical optimal control tasks that optimize trajectories of a given discrete-time dynamical system by solving the following problem. Given a cost function c, dynamics function f, and initial state x0, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))
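X is fully determined by x0 and U through the dynamics. The equality constraints can therefore be eliminated by simulating ("rolling out") the system, which leaves an unconstrained problem over U alone. This is consistent with the solvers below taking initial actions but no state trajectory. The following is a minimal NumPy sketch. `rollout`, `total_cost`, and the toy system are illustrative names, not Trajax API.

```python
import numpy as np


def rollout(dynamics, U, x0):
    """Returns X of shape (T + 1, n) with X[0] = x0 and
    X[t + 1] = dynamics(X[t], U[t], t), so the constraints hold by construction."""
    X = [np.asarray(x0, dtype=float)]
    for t in range(U.shape[0]):
        X.append(dynamics(X[t], U[t], t))
    return np.stack(X)


def total_cost(cost, final_cost, dynamics, U, x0):
    """The objective from the problem statement, evaluated after a rollout."""
    X = rollout(dynamics, U, x0)
    T = U.shape[0]
    return sum(cost(X[t], U[t], t) for t in range(T)) + final_cost(X[T])


# Hypothetical scalar-state example: x[t + 1] = x[t] + u[t].
def f(x, u, t):
    return x + u


def c(x, u, t):
    return float(x @ x + u @ u)


def c_final(x):
    return float(10.0 * x @ x)


U = np.zeros((3, 1))
X = rollout(f, U, np.array([1.0]))
print(total_cost(c, c_final, f, U, np.array([1.0])))  # 13.0
```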

There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.


In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.
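The following sketch shows a pair of functions that follow this convention. It models a hypothetical pendulum with state (angle, angular velocity) and one torque input, advanced by an explicit Euler step. The step size dt = 0.05 is an arbitrary choice. Both functions are pure and use jax.numpy, so JAX can trace and differentiate them. The action stays a length-1 vector even though d = 1.

```python
import jax
import jax.numpy as jnp

n, d = 2, 1  # state: (angle, angular velocity); action: (torque,)
dt = 0.05    # assumed integration step


def dynamics(state, action, time_step):
    """Explicit-Euler pendulum step: F[n] x F[d] x int -> F[n]."""
    theta, omega = state
    omega_dot = -9.81 * jnp.sin(theta) + action[0]
    return jnp.array([theta + dt * omega, omega + dt * omega_dot])


def cost(state, action, time_step):
    """Quadratic cost: F[n] x F[d] x int -> scalar."""
    return jnp.sum(state ** 2) + 0.1 * jnp.sum(action ** 2)


x = jnp.array([jnp.pi / 4, 0.0])
u = jnp.zeros(d)  # rank-1 vector even though d == 1
next_x = dynamics(x, u, 0)
grad_x = jax.grad(cost)(x, u, 0)  # differentiable because it is pure JAX code
```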


If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.
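The sampling-based solvers rest on a simple idea. The following NumPy sketch shows the cross-entropy method over whole action sequences. It is not Trajax's cem, whose signature and options are documented in its docstring. Here `cost_fn` is a hypothetical objective that maps a (T, d) action sequence to a scalar (for example, a rolled-out total cost), and all hyperparameter defaults are arbitrary. With one iteration and no refitting, the method reduces to random shooting.

```python
import numpy as np


def cem_sketch(cost_fn, U0, num_samples=200, num_elites=20, num_iters=30,
               init_std=1.0, seed=0):
    """Cross-entropy method over action sequences (illustrative only).

    Returns the best action sequence found and its cost.
    """
    rng = np.random.default_rng(seed)
    mean = np.array(U0, dtype=float)
    std = np.full_like(mean, init_std)
    best_U, best_J = mean.copy(), cost_fn(mean)
    for _ in range(num_iters):
        # Sample candidate action sequences around the current mean.
        samples = mean + std * rng.standard_normal((num_samples,) + mean.shape)
        costs = np.array([cost_fn(U) for U in samples])
        elites = samples[np.argsort(costs)[:num_elites]]
        # Refit the Gaussian sampling distribution to the lowest-cost samples.
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
        if costs.min() < best_J:
            best_U, best_J = samples[np.argmin(costs)], costs.min()
    return best_U, best_J


def toy_cost(U):
    return float(np.sum((U - 0.5) ** 2))


U_best, J_best = cem_sketch(toy_cost, np.zeros((5, 1)))
```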

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr.
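The following NumPy sketch shows the backward Riccati recursion that such a routine performs. For brevity it assumes 0.5-scaled quadratic costs with no cross or linear terms. The function name and argument order are illustrative, and Trajax's tvlqr has its own signature. The backward pass builds affine feedback laws u_t = K_t x_t + k_t, and the forward pass applies them from x0.

```python
import numpy as np


def tvlqr_sketch(Q, R, A, B, c, Q_final, x0):
    """Finite-horizon LQR with affine dynamics (illustrative only).

    Minimizes sum_t 0.5 x_t'Q[t]x_t + 0.5 u_t'R[t]u_t + 0.5 x_T'Q_final x_T
    subject to x_{t+1} = A[t] x_t + B[t] u_t + c[t] and x_0 = x0.
    Shapes: Q (T,n,n), R (T,d,d), A (T,n,n), B (T,n,d), c (T,n).
    """
    T, n, d = B.shape
    P, p = Q_final, np.zeros(n)  # value function V(x) = 0.5 x'Px + p'x
    K, k = np.zeros((T, d, n)), np.zeros((T, d))
    for t in reversed(range(T)):
        # Quadratic model of the cost-to-go as a function of (x_t, u_t).
        Quu = R[t] + B[t].T @ P @ B[t]
        Qux = B[t].T @ P @ A[t]
        qu = B[t].T @ (P @ c[t] + p)
        Qxx = Q[t] + A[t].T @ P @ A[t]
        qx = A[t].T @ (P @ c[t] + p)
        # Minimizing over u_t gives u_t = K[t] x_t + k[t].
        K[t] = -np.linalg.solve(Quu, Qux)
        k[t] = -np.linalg.solve(Quu, qu)
        P = Qxx + Qux.T @ K[t]       # = Qxx - Qux' Quu^{-1} Qux
        P = 0.5 * (P + P.T)          # keep P numerically symmetric
        p = qx + Qux.T @ k[t]
    # Forward pass: apply the feedback law starting from x0.
    X, U = [x0], []
    for t in range(T):
        U.append(K[t] @ X[t] + k[t])
        X.append(A[t] @ X[t] + B[t] @ U[t] + c[t])
    return np.stack(X), np.stack(U)
```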


One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, actions: F[T, d], initial_state: F[n]): float

Combining this function with JAX's autodiff capabilities offers a starting point for writing a custom first-order solver. For example:

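import jax
from trajax import optimizers

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  # Gradient w.r.t. U (argument 2); an int argnums returns an array, not a tuple.
  grad_fn = jax.grad(optimizers.objective, argnums=2)
  for _ in range(num_iters):
    U = U - eta * grad_fn(cost, dynamics, U, x0)  # gradient-descent step
  return U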
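For experimenting without Trajax installed, the following sketch writes a comparable objective directly with jax.lax.scan and runs the same gradient-descent loop on it. `objective_sketch` and the toy system are hypothetical. The sketch omits a terminal cost and is not Trajax's implementation. Because cost and dynamics are Python functions, they are marked static when jitting. The step size 0.2 is chosen below the inverse of this toy problem's gradient Lipschitz constant, so each step decreases the objective.

```python
import jax
import jax.numpy as jnp


def objective_sketch(cost, dynamics, U, x0):
    """Sum of cost(x_t, u_t, t) along the rollout from x0 (no terminal term)."""
    def step(x, inputs):
        u, t = inputs
        return dynamics(x, u, t), cost(x, u, t)

    T = U.shape[0]
    _, stage_costs = jax.lax.scan(step, x0, (U, jnp.arange(T)))
    return jnp.sum(stage_costs)


def dynamics(x, u, t):
    return x + 0.1 * u  # hypothetical single integrator


def cost(x, u, t):
    return jnp.sum(x ** 2) + 0.1 * jnp.sum(u ** 2)


# Differentiate w.r.t. U (argument 2); cost and dynamics are static under jit.
grad_fn = jax.jit(jax.grad(objective_sketch, argnums=2), static_argnums=(0, 1))

x0 = jnp.array([1.0, -0.5])
U = jnp.zeros((20, 2))
J0 = objective_sketch(cost, dynamics, U, x0)
for _ in range(100):
    U = U - 0.2 * grad_fn(cost, dynamics, U, x0)
J1 = objective_sketch(cost, dynamics, U, x0)
```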

The solvers provided by Trajax are actually built around this objective function. For instance, the scipy_minimize solver simply calls scipy.optimize.minimize with the gradient and Hessian-vector product functions derived from objective using jax.grad and jax.hessian.
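The following sketch shows the general pattern, in which JAX supplies the derivatives that scipy.optimize.minimize consumes through its `jac` and `hessp` arguments. It is not Trajax's scipy_minimize. The Hessian-vector product uses forward-over-reverse differentiation (jax.jvp of jax.grad), which avoids materializing the full Hessian that jax.hessian would build. The toy objective is hypothetical, and float64 is enabled because SciPy works in double precision.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # match SciPy's float64


def objective(u):
    # Hypothetical problem: x_{t+1} = x_t + u_t from x_0 = 1, ten steps.
    x = 1.0 + jnp.cumsum(u)
    return jnp.sum(x ** 2) + jnp.sum(u ** 2)


grad = jax.jit(jax.grad(objective))


@jax.jit
def hvp(u, v):
    # Directional derivative of the gradient along v, i.e. H(u) @ v.
    return jax.jvp(grad, (u,), (v,))[1]


result = minimize(
    lambda u: float(objective(u)),
    np.zeros(10),
    jac=lambda u: np.asarray(grad(u)),
    hessp=lambda u, v: np.asarray(hvp(u, v)),
    method="trust-ncg",
)
```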


Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.
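The following sketch shows what the jax.numpy requirement means in practice. A function that calls plain NumPy on a traced value fails once JAX traces it (here under jax.jit), raising jax.errors.TracerArrayConversionError, which is the same error class quoted in the BRAX issue below. The jax.numpy version of the same function works. Both dynamics functions are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def dynamics_numpy(x, u, t):
    return x + 0.1 * np.sin(u)   # plain NumPy call on a traced value


def dynamics_jnp(x, u, t):
    return x + 0.1 * jnp.sin(u)  # jax.numpy call: traceable


x, u = jnp.ones(3), jnp.ones(3)
next_x = jax.jit(dynamics_jnp)(x, u, 0)  # compiles and runs

try:
    jax.jit(dynamics_numpy)(x, u, 0)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True  # NumPy tried to convert the tracer to an ndarray
```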

trajax's People


froystig, ssingh19, stephentu, vikas-sindhwani


trajax's Issues

Does not work w/ BRAX

Has anyone tried the solvers on BRAX environments? Here's what I have:

import trajax
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
import brax
from brax import envs

def get_f_and_c(env):
    key = jax.random.PRNGKey(0)
    state = env.reset(key)
    _, x2qp = ravel_pytree(state.qp)
    def f(x, u, t):
        qp = x2qp(x)
        nqp, _ = env.sys.step(qp, u)
        return ravel_pytree(nqp)[0]
    def c(x, u, t):
        qp = x2qp(x)
        dstate = state.replace(qp=qp)
        nstate = env.step(dstate, u)
        return -nstate.reward
    return f, c

env = envs.create('inverted_pendulum')
key = jax.random.PRNGKey(0)
state = env.reset(key)
x_init, x2qp = ravel_pytree(state.qp)

f, c = get_f_and_c(env)

x, u, cost, *outputs = trajax.optimizers.ilqr(c, f, x_init, jnp.zeros([1, env.action_size]))

which gives:

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Hi, I'd like to contribute iCEM extensions to the current CEM implementation. I'd like to ask:

  • whether it is already on your roadmap, and
  • whether it is OK if I extend the current cem() function with additional (optional) features and hyperparameters.


Use of deprecated Jax APIs/behavior

Trajax uses deprecated Jax APIs/behavior which result in warnings being emitted in two locations.

  1. The first instance is at

    @partial(jit, static_argnums=(0, 1, 9))

    where argnum 9 is specified although the function only has 8 arguments. This results in the warning:

    jax/_src/ SyntaxWarning: Jitted function has static_argnums=(0, 1, 9), but only accepts 8 positional arguments. This warning will be replaced by an error after 2022-08-20 at the earliest.

  2. The second instance is at

    K = -sp.linalg.solve(G_, H, sym_pos=True)

    which raises the following warning:

    trajax/ FutureWarning: The sym_pos argument to solve() is deprecated and will be removed in a future JAX release. Use assume_a='pos' instead.
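Both warnings point at small fixes. The following minimal sketch assumes a recent JAX in which jax.scipy.linalg.solve accepts assume_a, as the warning text itself recommends. It keeps static_argnums within the function's positional arguments and passes assume_a='pos' instead of sym_pos=True. The function `feedback_gain` and its `assume_pos` flag are hypothetical and are not the Trajax code quoted above.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg


@partial(jax.jit, static_argnums=(2,))  # index 2 exists: 3 positional args
def feedback_gain(G, H, assume_pos):
    # assume_a="pos" (symmetric positive definite) replaces sym_pos=True.
    assume_a = "pos" if assume_pos else "gen"
    return -jsp_linalg.solve(G, H, assume_a=assume_a)


G = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
H = jnp.array([[1.0, 0.0], [0.0, 2.0]])
K = feedback_gain(G, H, True)
```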

Bugfixes and improvements suggestions

I recently implemented a multiple-shooting variant of iLQR here:

Much of my implementation (apart from core algorithmic changes) is inspired by trajax. While implementing it, I found a few issues and possible improvements related to trajax itself, which I list below.

  1. In

    Q = lax.cond(make_psd, Q, psd, Q, lambda x: x)

    you are projecting the Q matrices to become positive semi-definite. When the M matrices (i.e. cross-state-and-control quadratic terms) are non-zero, this is not sufficient. You should do this instead:

  2. Perhaps because of the issue above, you resort to doing least-squares solves in

    K_k, *_ = np.linalg.lstsq(

    when you should be able to use Cholesky solves instead (I didn't benchmark this part in JAX, but they should generally be faster); see

  3. In

    def lqr_step(P, p, Q, q, R, r, M, A, B, c, delta=1e-8):

    some of the terms you have in your code are mathematically guaranteed to be zero and can be removed; this is solved in

  4. In places like

    return lax.fori_loop(0, T, body, (X, U))

    you use fori_loop, when using a scan results in significant speed-ups. You can find almost-drop-in replacements here:

  5. It would be interesting to add support for a GPU-accelerated implementation of LQR; see
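The links in this issue were lost, so the suggested fix itself is not shown. The following NumPy/SciPy sketch is one reading of points 1 and 2. It projects the joint block [[Q, M], [M^T, R]] onto the positive semi-definite cone by clipping eigenvalues, so the control block stays positive definite and the feedback gain can use a Cholesky solve instead of least squares. The function names and the eps floor are hypothetical. The example has Q and R each positive, but the joint block is indefinite (eigenvalues 3 and -1), so projecting Q alone would not fix it.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


def project_joint_psd(Q, M, R, eps=1e-8):
    """Clips eigenvalues of [[Q, M], [M^T, R]] to >= eps, then splits it back."""
    n = Q.shape[0]
    H = np.block([[Q, M], [M.T, R]])
    H = 0.5 * (H + H.T)
    w, V = np.linalg.eigh(H)
    H_psd = (V * np.maximum(w, eps)) @ V.T  # V diag(clipped w) V^T
    return H_psd[:n, :n], H_psd[:n, n:], H_psd[n:, n:]


def feedback_gain(Quu, Qux):
    """K = -Quu^{-1} Qux via a Cholesky solve (Quu must be positive definite)."""
    return -cho_solve(cho_factor(Quu), Qux)


Q, M, R = np.array([[1.0]]), np.array([[2.0]]), np.array([[1.0]])
Qp, Mp, Rp = project_joint_psd(Q, M, R)
K = feedback_gain(Rp, Mp.T)  # R stands in for Quu in this toy example
```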

I'd be happy to try to merge my code into this repository (either just these improvements, or actually adding the new algorithm itself), if you find that interesting. Let me know!


Hi, is there an overall performance benefit from passing JAX-computed derivatives as arguments to scipy.optimize.minimize()? Are there any plans to extend jax.scipy.optimize.minimize() to constrained problems? Regards

ILQR optimizer doesn't support 1D scalar dynamical systems

When I try to run a 1D quadratic, control-affine nonlinear system of the form shown below, the iLQR implementation is unable to handle scalar-valued systems and fails with a dimension mismatch error. Please find the error and code below.


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

import jax
from jax import device_put
from jax import vmap
from jax.config import config
import jax.numpy as np
import numpy as onp

from trajax import optimizers
from trajax.integrators import euler
from trajax.integrators import rk4
import matplotlib.pyplot as plt

def quadratic_nonlinear(x, u, t, params=(5, 10)):
    """Simple quadratic nonlinear system where we introduce reference trajectory as input
    :param x: 1D scalar state
    :param u: 1D input
    :param params: Kp, Kd gains for PD control law
    :return xdot: 1D array of shape 1
    """
    del t
    Kp, Kd = params
    r = np.squeeze(u)
    # xdot = (x ** 2 + Kp * (x - r) - Kd * rdot)/(1 - Kd)
    xdot = x ** 2 + Kp * (x - r)
    return np.array([xdot])

class ILQR_test():
    """Testing ILQR implementation in trajax for simple nonlinear systems"""
    def __init__(self):
        pass

    def discretize(self, type='euler', dynamics=None):
        if dynamics is not None:
            self.dynamics = dynamics

        self.dynamics = euler(self.dynamics, dt=0.01)
        if type != 'euler':
            self.dynamics = rk4(self.dynamics, dt=0.01)

    def testQuadNonLinear(self, maxiter):
        """Calling ilqr on quadratic nonlinear system with input as reference trajectory
        :param maxiter: maximum number of iterations to take in ilqr
        :return: list of ilqr fn output
        """
        horizon = 100
        dynamics = rk4(quadratic_nonlinear, dt=0.01)

        true_params = (100.0, 10.0, 1.0)

        def cost(params, state, action, t):
            final_weight, stage_weight, action_weight = params

            state_err = state - action
            state_cost = stage_weight * (state_err ** 2 + action ** 2)
            # action_cost = action_weight * np.squeeze(action) ** 2
            # (else-branch assumed: the posted line was cut off here)
            return np.where(t == horizon, final_weight * state_cost,
                            state_cost)

        x0 = np.array([-0.9])
        U0 = np.zeros((horizon, 1))
        X, U, obj, grad, adj, lqr_val, total_iter = optimizers.ilqr(
            functools.partial(cost, true_params), dynamics, x0, U0, maxiter)
        return [X, U, obj, grad, adj, lqr_val, total_iter]

test = ILQR_test()
traj_cost = []
num_iter = [2, 30, 40, 50]

for i in num_iter:
    # traj = test.apply_ilqr(x0=onp.random.randn(2), U=onp.random.randn(2), maxiter=i, dynamics=rk4(quadratic_nonlinear, dt=0.01))
    # traj = test.testPendulumReadmeExample(maxiter=i)
    traj = test.testQuadNonLinear(maxiter=i)

X = traj[0]
U = traj[1]
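The thread does not state the cause. One plausible culprit, offered only as an assumption, follows from the convention in the README above. A 1-dimensional state is still a rank-1 vector, so dynamics must return shape (1,) and cost must return a scalar. In the posted code, however, `x ** 2 + Kp * (x - r)` already has shape (1,), `np.array([xdot])` wraps it again into shape (1, 1), and the posted cost returns shape (1,) rather than a scalar. The following NumPy check shows the shapes. The function names are hypothetical variants of the posted code.

```python
import numpy as np


def quadratic_nonlinear_as_posted(x, u, t, params=(5, 10)):
    Kp, _ = params
    r = np.squeeze(u)
    xdot = x ** 2 + Kp * (x - r)   # x has shape (1,), so xdot has shape (1,)
    return np.array([xdot])        # wrapped again: shape (1, 1)


def quadratic_nonlinear_rank1(x, u, t, params=(5, 10)):
    Kp, _ = params
    r = np.squeeze(u)
    return x ** 2 + Kp * (x - r)   # already a length-1 vector: shape (1,)


def cost_scalar(state, action, t, stage_weight=10.0):
    state_err = state - action
    # Summing reduces the shape-(1,) expression to a scalar.
    return stage_weight * np.sum(state_err ** 2 + action ** 2)


x0, u0 = np.array([-0.9]), np.zeros(1)
print(quadratic_nonlinear_as_posted(x0, u0, 0).shape)  # (1, 1)
print(quadratic_nonlinear_rank1(x0, u0, 0).shape)      # (1,)
```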


Solving a batch of trajectory optimization problems on a GPU


I am looking for a library to solve batches of trajectory optimization problems (same problem, different initial states) on accelerators, and I've found this library, which looks great! I have some experience with JAX, so this library would be perfect for me, but I am struggling to understand how much it was designed for my use case. As far as I know, GPUs are not suitable for matrix decompositions, and I saw that your iLQR solver relies on solving linear systems. Would you agree that this could be a bottleneck for my use case? And out of curiosity: what is the use case for which this library was designed?
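The following sketch shows the batching pattern the question describes. It writes a solver for one initial state and uses jax.vmap to vectorize it over a batch. The solver is plain gradient descent on a hypothetical double-integrator problem, standing in for Trajax's ilqr. It says nothing about how linear solves inside iLQR perform on GPUs. The step size 0.01 is chosen small enough for this toy problem that each step decreases the objective.

```python
import jax
import jax.numpy as jnp


def dynamics(x, u, t):
    # Hypothetical double integrator (position, velocity), dt = 0.1.
    return jnp.array([x[0] + 0.1 * x[1], x[1] + 0.1 * u[0]])


def toy_objective(U, x0):
    def step(x, u):
        return dynamics(x, u, 0), jnp.sum(x ** 2) + 0.01 * jnp.sum(u ** 2)

    x_final, stage_costs = jax.lax.scan(step, x0, U)
    return jnp.sum(stage_costs) + 10.0 * jnp.sum(x_final ** 2)


def solve_one(x0, U0, eta=0.01, num_iters=200):
    """Gradient descent on the controls for a single initial state."""
    grad_U = jax.grad(toy_objective)

    def body(_, U):
        return U - eta * grad_U(U, x0)

    return jax.lax.fori_loop(0, num_iters, body, U0)


# vmap maps the single-problem solver over a batch of initial states
# (U0 is shared); jit compiles the batched computation for the device.
solve_batch = jax.jit(jax.vmap(solve_one, in_axes=(0, None)))

x0s = jax.random.normal(jax.random.PRNGKey(0), (128, 2))
U0 = jnp.zeros((30, 1))
U_batch = solve_batch(x0s, U0)  # one control sequence per initial state
```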
