hsai-predictor's Introduction

How Safe Am I Given What I See?

Calibrated Prediction of Safety Chances for Image-Controlled Autonomy

This is the code for the paper "How Safe Am I Given What I See? Calibrated Prediction of Safety Chances for Image-Controlled Autonomy". This repository contains the following:

  • The DQN controller training code
  • Trained controllers for the racing car and the cart pole
  • Data generation code
  • Training code for all of our evaluators, autoencoders, forecasters, and predictors
  • Conformal calibration code

Prerequisites

pip install -r requirements.txt

The code uses gym 0.21.0 and box2d-py 2.3.8.
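Since the versions are pinned, it can help to confirm what is actually installed before running anything. The following is a minimal sketch, not part of the repository: the helper name find_mismatches is hypothetical, and it assumes Python 3.8+ for importlib.metadata.

```python
from importlib import metadata
from typing import Callable, Dict

# Versions stated in this README.
REQUIRED = {"gym": "0.21.0", "box2d-py": "2.3.8"}


def find_mismatches(
    required: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Return {package: problem} for packages that are missing or at another version."""
    problems = {}
    for name, wanted in required.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = f"not installed (need {wanted})"
            continue
        if found != wanted:
            problems[name] = f"found {found}, need {wanted}"
    return problems


if __name__ == "__main__":
    # Prints nothing when both packages match the stated versions.
    for name, problem in find_mismatches(REQUIRED).items():
        print(f"{name}: {problem}")
```

The get_version parameter exists only so the check can be exercised without the packages installed.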

Train the controllers

The code to train the racing car controller is in the directory RacingCarController (https://github.com/andywu0913/OpenAI-GYM-CarRacing-DQN).

The code to train the cart pole controller is in the directory CartPoleController (https://github.com/fedebotu/vision-cartpole-dqn).

Data generation

Before collecting data, we need to edit the car_racing.py file at PYTHON_PATH/site-packages/gym/envs/box2d/car_racing.py to get the position of the car at each moment.

In def step(self, action), change the return value to:

return self.state, step_reward, done, [self.car.hull.position,self.road_poly,self.observation_space]

Alternatively, replace the file directly with EditedGym/car_racing.py.
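With this edit, the fourth value returned by step() is a three-element list rather than the usual info dict. A minimal sketch of how a caller might give those fields names is shown here; StepInfo and unpack_step_info are hypothetical helpers, not part of the repository, and the values are stand-ins so the snippet runs without gym.

```python
from typing import Any, NamedTuple, Sequence


class StepInfo(NamedTuple):
    """Fields packed into the fourth return value of the edited step()."""
    car_position: Any       # self.car.hull.position
    road_poly: Any          # self.road_poly
    observation_space: Any  # self.observation_space


def unpack_step_info(info: Sequence[Any]) -> StepInfo:
    """Give names to the three-element list returned by the edited step()."""
    if len(info) != 3:
        raise ValueError(f"expected 3 info fields, got {len(info)}")
    return StepInfo(*info)


# Stand-in values. With the edited environment this list would come from:
#   state, reward, done, info = env.step(action)
fake_info = [(12.5, -3.0), ["road polygon placeholder"], "observation space placeholder"]
info = unpack_step_info(fake_info)
x, y = info.car_position
print(f"car at x={x}, y={y}")
```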

In the experiments on the racing car, we used 6 different models and collected 80K training samples, 20K test samples, 20K calibration samples, and 20K validation samples for each controller.

For example:

python RacingCar/gene_data.py --n=80000 -d="RacingCar/data/train/controller_6/" -c="RacingCar/models/trial_600.h5" -s=0

Here --n is the number of samples, -d is the output directory, -c is the path to the DQN model, and -s is the random seed.
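Collecting all four splits for one controller means repeating the command with different sample counts and output directories. The sketch below only builds and prints the command lines. The train and test directories follow the paths used in this README, while the calibration and validation directory names and the choice of a different seed per split are assumptions.

```python
from typing import Dict, List

# Sample counts per split for the racing car, as stated above.
RACING_CAR_SPLITS: Dict[str, int] = {
    "train": 80000,
    "test": 20000,
    "calibration": 20000,  # directory name is an assumption
    "validation": 20000,   # directory name is an assumption
}


def build_command(split: str, n_samples: int, controller: str,
                  model_path: str, seed: int) -> List[str]:
    """Build one gene_data.py invocation using the flags shown in the example."""
    out_dir = f"RacingCar/data/{split}/{controller}/"
    return [
        "python", "RacingCar/gene_data.py",
        f"--n={n_samples}",
        f"-d={out_dir}",
        f"-c={model_path}",
        f"-s={seed}",
    ]


# Assumption: one seed per split (0, 1, 2, 3) so the splits are not identical.
commands = [
    build_command(split, n, "controller_6", "RacingCar/models/trial_600.h5", seed)
    for seed, (split, n) in enumerate(RACING_CAR_SPLITS.items())
]

for cmd in commands:
    print(" ".join(cmd))
```

Each printed line can be run as is, or the lists can be passed to subprocess.run.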

For the cart pole experiments, we used 3 different models and collected 30K training samples, 30K test samples, 30K calibration samples, and 30K validation samples for each controller.

    return np.array(self.state, dtype=np.float32), reward, done, theta

python CartPole/gene_data.py --n=30000 -d="CartPole/data/train/controller_1/" -c="CartPole/models/policy_net_best1.pt" -s=0

Evaluator training

python RacingCar/evaluator/train.py --log="RacingCar/evaluator/logs/" --train="RacingCar/data/train/" --test="RacingCar/data/test/"
python CartPole/evaluator/train.py  --log="CartPole/evaluator/logs/" --train="CartPole/data/train/" --test="CartPole/data/train/"

VAE training

For training a VAE without the safety loss:

python RacingCar/vae/train_unsafe_vae.py --log="RacingCar/vae/unsafe/" --train="RacingCar/data/train/" --test="RacingCar/data/test/" --eva="RacingCar/models/eva.tar"

For training a VAE with the safety loss:

python RacingCar/vae/train_safe_vae.py --log="RacingCar/vae/safe/" --train="RacingCar/data/train/" --test="RacingCar/data/test/" --eva="RacingCar/models/eva.tar"

To train VAEs for the cart pole, just replace the paths of the training and test data.

Monolithic predictor training

For training the predictor using a CNN architecture (controller-independent):

python MonoCnn/monoInd.py  --train="RacingCar/data/train/" --test="RacingCar/data/test/" --save="RacingCar/models/" --epochs=10 --steps=9 --task=1

Here --epochs is the maximum number of training epochs, --steps sets the prediction horizon range [0, steps], and --task selects the environment: 1 for the racing car and 2 for the cart pole. For the racing car, steps is ten times the real value.

To train a controller-specific one:

python MonoCnn/monoCsp.py  --train="RacingCar/data/train/controller_1/" --test="RacingCar/data/test/controller_1/" --save="RacingCar/models/" --epochs=10 --steps=9 --task=1

To train an LSTM predictor:

python MonoLstm/monoInd.py  --train="RacingCar/data/train/" --test="RacingCar/data/test/" --save="RacingCar/models/" --epochs=10 --steps=9 --task=1 --vae="MonoLstm/safe_vae_best.tar"
python MonoLstm/monoCsp.py  --train="RacingCar/data/train/controller_1/" --test="RacingCar/data/test/controller_1/" --save="RacingCar/models/" --epochs=10 --steps=9 --task=1 --vae="MonoLstm/safe_vae_best.tar"

Composite predictors

To train an image (conv-lstm) predictor:

python CompImg/train.py  --train=TRAIN_PATH --test=TEST_PATH

To test an image (conv-lstm) predictor:

python CompImg/test.py  --test=TEST_PATH --eva=EVALUATOR_PATH

To train a latent predictor (controller-independent):

python CompLat/trainInd.py  --train=TRAIN_PATH --test=TEST_PATH --vae=VAE_PATH 

To train a latent predictor (controller-specific):

python CompLat/trainCsp.py  --train=TRAIN_PATH --test=TEST_PATH  --vae=VAE_PATH 

To test a latent predictor (controller-independent):

python CompLat/testInd.py  --test=TEST_PATH --eva=EVALUATOR_PATH --vae=VAE_PATH --rnn=SAVED_MODEL_PATH

To test a latent predictor (controller-specific):

python CompLat/testCsp.py  --test=TEST_PATH --eva=EVALUATOR_PATH --vae=VAE_PATH --rnn=SAVED_MODEL_PATH

Conformal calibration

After testing, the test script saves the prediction results, in particular the softmax scores and the safety labels, to an npz file.

python ConformalCali/Cali-with-brier-score.py  --m=200 --n=1000 --data="sft.npz" --save="save.npz"
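To make the inputs of this step concrete, here is a minimal sketch of reading softmax scores and safety labels from an npz file and computing a Brier score, followed by a generic split-conformal quantile. It is not the repository's calibration procedure: the array keys "softmax" and "labels", the column order (column 1 taken as the "safe" class), and the use of absolute residuals as conformal scores are all assumptions, and the data is made up.

```python
import io
import math

import numpy as np


def brier_score(safe_prob, labels) -> float:
    """Mean squared difference between predicted safety chance and 0/1 label."""
    safe_prob = np.asarray(safe_prob, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return float(np.mean((safe_prob - labels) ** 2))


def conformal_quantile(scores, alpha: float) -> float:
    """Split-conformal quantile: the ceil((n + 1) * (1 - alpha))-th smallest score."""
    scores = np.sort(np.asarray(scores, dtype=float))
    n = scores.size
    k = math.ceil((n + 1) * (1.0 - alpha))
    if k > n:
        # Too few calibration points for this alpha; no finite bound.
        return float("inf")
    return float(scores[k - 1])


# Toy stand-in for the saved results, written to memory instead of "sft.npz".
buffer = io.BytesIO()
np.savez(
    buffer,
    softmax=np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]),  # hypothetical key
    labels=np.array([1, 0, 1, 1]),                                        # hypothetical key
)
buffer.seek(0)

data = np.load(buffer)
safe_prob = data["softmax"][:, 1]  # assumption: column 1 is the "safe" class
labels = data["labels"]

score = brier_score(safe_prob, labels)
bound = conformal_quantile(np.abs(safe_prob - labels), alpha=0.5)
print("Brier score:", score)
print("Conformal residual bound:", bound)
```

A lower Brier score indicates predicted safety chances that sit closer to the observed labels. The conformal quantile gives a residual bound that holds with probability at least 1 - alpha on exchangeable data.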

Contributors

maozj6, bisc
