PolicySearch-bb-alpha

Python implementation of policy search and model training using $\alpha$-divergence minimization for Bayesian neural networks with latent variables. See:

Depeweg, Stefan, et al. "Learning and policy search in stochastic dynamical systems with bayesian neural networks." arXiv preprint arXiv:1605.07127 (2016).

Prerequisites

Requires the standard libraries for Theano-based models, plus Lasagne (I use 0.2.dev).

Usage

  1. Insert industrialbenchmark_python in environment/:

    • Download the Python version of industrialbenchmark.

    • Move it to environment/industrialbenchmark

  2. Generate batch of state transitions:

     cd  environment/  
     python make_data.py 
    

    This generates a training set and a test set, stored in environment/out

    X: Setpoint, A(t-14), ..., A(t+1), R(t-15), ..., R(t)

    Y: R(t+1)

  3. Model Training:

    cd experiments/  
    python train_model.py 0.5 
    

    This trains a BNN using bb-alpha with alpha=0.5.

    After training, the model is stored in experiments/models as a pickle file.

    The code runs on GPU or CPU. Parameters are chosen conservatively for GPU use. Consider decreasing the sample size to 25 for CPU use.

    Expected training time (i5-6600K CPU @ 4.0GHz, GTX 1060):

    CPU:
    50 samples: 21.5 hours
    25 samples: 10.5 hours

    GPU:
    50 samples: 3.5 hours
    25 samples: 2.0 hours

  4. Policy Training:

    cd experiments/  
    python train_controller.py 0.5 
    

    This trains a policy using the model from step 3 (a trained model must exist in models/).
    After training, the policy is stored in experiments/controller as a pickle file.

    The code runs on CPU. For GPU use, one should pass only indexes to train_func using givens. In our experiments no speedup was obtained from GPU use.

  5. Policy Evaluation:
    An example policy evaluation script is given in environment/eval_pol.py

    cd environment/  
    ipython  
    from eval_pol import evaluate   
    results = evaluate('../experiments/controller/AD_1.0.p')  
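
The X/Y layout listed in step 2 (X: Setpoint, A(t-14), ..., A(t+1), R(t-15), ..., R(t); Y: R(t+1)) can be illustrated with a small sliding-window sketch. This is not the code in make_data.py: the function name build_windows, its arguments, and the toy trajectory are hypothetical, and it assumes one setpoint value per time step, an action array of shape (T, a_dim), and a scalar reward per step.

```python
import numpy as np

def build_windows(setpoint, actions, rewards, history=15):
    """Hypothetical helper: build (X, Y) rows with the layout
    X = [Setpoint, A(t-14..t+1), R(t-15..t)], Y = R(t+1)."""
    setpoint = np.asarray(setpoint, dtype=float)
    rewards = np.asarray(rewards, dtype=float)
    actions = np.asarray(actions, dtype=float)
    if actions.ndim == 1:
        # Treat a 1-D action series as a_dim = 1 (assumption for the sketch).
        actions = actions[:, None]

    n_steps = len(rewards)
    X, Y = [], []
    # t needs `history` past rewards and one future action/reward.
    for t in range(history, n_steps - 1):
        a_win = actions[t - history + 1:t + 2].ravel()  # A(t-14) .. A(t+1)
        r_win = rewards[t - history:t + 1]              # R(t-15) .. R(t)
        X.append(np.concatenate(([setpoint[t]], a_win, r_win)))
        Y.append(rewards[t + 1])
    return np.array(X), np.array(Y)

# Toy trajectory, purely illustrative (not industrialbenchmark data).
T = 20
toy_setpoint = np.full(T, 50.0)
toy_actions = np.arange(T, dtype=float)
toy_rewards = 10.0 * np.arange(T)
X_toy, Y_toy = build_windows(toy_setpoint, toy_actions, toy_rewards)
```

With history=15 each row holds 16 action steps and 16 reward steps, so a scalar action gives 1 + 16 + 16 = 33 input columns.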
    

Some helpful tips:

model.bb_alpha.network.update_randomness(n_samples)   

draws n_samples weight samples from q(W) and resamples the input noise.

For prediction use:

m,v = model.predict(np.tile(X,[n_samples,1,1]))  

where X is n x d
m is n_samples x n x d
v is n_samples x n x d (constant output noise variance)
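
One possible way to collapse the per-sample outputs into a single predictive mean and variance per input is moment matching over the samples (law of total variance). The sketch below is an assumption about how one might post-process m and v, not something the repository prescribes. aggregate_predictions and fake_predict are hypothetical names, and fake_predict is only a stand-in that returns arrays of the documented shapes so the example runs without a trained model.

```python
import numpy as np

def aggregate_predictions(m, v):
    """Collapse per-sample means/variances (n_samples x n x d) into one
    mean and variance per input, treating samples as an equal-weight mixture."""
    m = np.asarray(m, dtype=float)
    v = np.asarray(v, dtype=float)
    mean = m.mean(axis=0)
    # Total variance = average noise variance + spread of the sample means.
    var = v.mean(axis=0) + m.var(axis=0)
    return mean, var

def fake_predict(X_tiled, seed=0):
    """Stand-in for model.predict (hypothetical): returns arrays with the
    shapes described above, filled with synthetic values."""
    rng = np.random.default_rng(seed)
    m = X_tiled + rng.normal(scale=0.1, size=X_tiled.shape)
    v = np.full(X_tiled.shape, 0.01)  # constant output noise variance
    return m, v

n_samples = 50
X = np.arange(6, dtype=float).reshape(3, 2)            # n x d
m, v = fake_predict(np.tile(X, [n_samples, 1, 1]))     # n_samples x n x d
mean, var = aggregate_predictions(m, v)                # each n x d
```

With a trained model, the fake_predict call would be replaced by the update_randomness and predict calls shown above.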
