
Prediction, Consistency and Curvature

This is a PyTorch implementation of the paper "Prediction, Consistency, Curvature: Representation Learning for Locally-Linear Control". The work was done during a residency at VinAI Research, Hanoi, Vietnam.

Installing

First, clone the repository:

git clone https://github.com/VinAIResearch/PCC-pytorch.git

Then install the dependencies as listed in pcc.yml and activate the environment:

conda env create -f pcc.yml

conda activate pcc

Training

The code currently supports training on the planar, pendulum, cartpole and threepole environments. Run train_pcc.py with your own settings. For example:

python train_pcc.py \
    --env=planar \
    --armotized=False \
    --log_dir=planar_1 \
    --seed=1 \
    --data_size=5000 \
    --noise=0 \
    --batch_size=128 \
    --lam_p=1.0 \
    --lam_c=8.0 \
    --lam_cur=8.0 \
    --vae_coeff=0.01 \
    --determ_coeff=0.3 \
    --lr=0.0005 \
    --decay=0.001 \
    --num_iter=5000 \
    --iter_save=1000 \
    --save_map=True

First, data is sampled according to the given data size and noise level; then the PCC model is trained with the specified settings.

If the argument save_map is set to True, the latent map is drawn every 10 epochs (planar only) and the resulting gif file is saved in the same directory as the trained model.

You can also visualize the training process by running tensorboard --logdir={path_to_log_file}, where path_to_log_file has the form logs/{env}/{log_dir}. The trained model will be saved at result/{env}/{log_dir}.
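For example, with the training command shown above (env planar, log_dir planar_1), the corresponding TensorBoard call would be:

tensorboard --logdir=logs/planar/planar_1

and the trained model would be found under result/planar/planar_1.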

Visualizing latent maps

You can visualize the latent map for both the planar and pendulum environments. To do so, simply run:

python latent_map_planar.py --log_path={log_to_trained_model} --epoch={epoch}
or 
python latent_map_pendulum.py --log_path={log_to_trained_model} --epoch={epoch}
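For example, assuming the planar model from the training example above was saved under result/planar/planar_1 (see the Training section), and picking the final checkpoint of that run as the epoch value (both path and epoch here are illustrative assumptions, not prescribed by the repo):

python latent_map_planar.py --log_path=result/planar/planar_1 --epoch=5000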

Sampling data

You can generate the training images for visualization by simply running:

cd data

python sample_{env_name}_data.py --sample_size={sample_size} --noise={noise}

Currently the code supports simulating 3 environments: planar, pendulum and cartpole.

The raw data (images) is saved in data/{env_name}/raw_{noise}_noise.
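For example, to generate a planar dataset with the same size and noise level as the training example above (these values are just illustrative):

cd data

python sample_planar_data.py --sample_size=5000 --noise=0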

Running iLQR on latent space

The configuration file for running iLQR for each task is in the ilqr_config folder; you can modify it with your own settings. Run python ilqr.py --task={task}, where task is one of {plane, swing, balance, cartpole}.

The code will run iLQR for all models trained for that specific task and compute some statistics. The result is saved in iLQR/result.
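For example, to run iLQR for the planar task:

python ilqr.py --task=plane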

Result

We evaluate the PCC model in 2 ways: the quality of the latent map and the percentage of time the agent spends in the goal region.

Planar system

Latent map

Below is a random latent map produced by PCC. You can watch a video clip comparing how the latent maps produced by E2C and PCC evolve at this link: https://www.youtube.com/watch?v=pBmzFvvE2bo.

Latent space learned by PCC

Control result

On average the agent spent around 48% of the time in the goal region, and around 76% for the best model. Below are 2 sample trajectories of the agent.

Sample planar trajectory 1

Sample planar trajectory 2

Inverted pendulum

Latent map

Below is a random latent map produced by PCC.

Latent space learned by PCC

Control result

On average the agent spent around 60.7% of the time in the goal region, and around 80.65% for the best model. Below are 2 sample trajectories of the inverted pendulum.

Sample inverted pendulum trajectory 1

Sample inverted pendulum trajectory 2

Cartpole

Sample cartpole trajectory 1

Sample cartpole trajectory 2

Acknowledgment

Many thanks to Nir Levine and Yinlam Chow for their help in answering questions related to the PCC paper.

Citation

If you find this implementation useful for your work, please consider starring this repository.


pcc-pytorch's Issues

Curvature Loss is Dimensionally Incorrect

Hey, I first wanted to say that this is great work. I also wanted to point out that in the non-amortized curvature loss, if you look at the shapes:

grad_z is [batch_size x latent_dim]
grad_u is [batch_size x action_dim]

whereas I think that, per the paper, grad_u should be [batch_size x latent_dim].

So, the way the curvature loss is currently written, it is dimensionally inconsistent: unless action_dim = latent_dim or action_dim = 1, you get a dimension mismatch. Let me know if you have any questions.
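To make the shape argument concrete, here is a minimal, self-contained PyTorch sketch (not the repository's code; the dynamics network and all names are illustrative) showing that a plain gradient w.r.t. u has shape [batch_size, action_dim], while the paper's linearization term J_u(z, u) @ eps_u lives in [batch_size, latent_dim]:

import torch

# Illustrative sizes only (not taken from the repo); a mismatch needs
# action_dim != latent_dim and action_dim != 1 (broadcasting hides the dim-1 case).
batch_size, latent_dim, action_dim = 32, 3, 2

z = torch.randn(batch_size, latent_dim, requires_grad=True)
u = torch.randn(batch_size, action_dim, requires_grad=True)

# Stand-in latent dynamics f(z, u) -> z_next.
f = torch.nn.Sequential(
    torch.nn.Linear(latent_dim + action_dim, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, latent_dim),
)
z_next = f(torch.cat([z, u], dim=-1))

# Back-propagating the summed output gives gradients of a scalar:
grad_z, grad_u = torch.autograd.grad(z_next.sum(), (z, u))
print(grad_z.shape)  # torch.Size([32, 3]) -> [batch_size, latent_dim]
print(grad_u.shape)  # torch.Size([32, 2]) -> [batch_size, action_dim]; cannot be combined
                     # element-wise with a [batch_size, latent_dim] term when action_dim != latent_dim

# The Taylor-expansion term J_u(z, u) @ eps_u is a Jacobian-vector product and
# therefore lives in latent space:
eps_u = torch.randn(batch_size, action_dim)
_, jvp_u = torch.autograd.functional.jvp(
    lambda u_in: f(torch.cat([z.detach(), u_in], dim=-1)), u.detach(), eps_u
)
print(jvp_u.shape)   # torch.Size([32, 3]) -> [batch_size, latent_dim], as expected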
