
java-reinforcement-learning's Introduction

java-reinforcement-learning

This package provides Java implementations of reinforcement learning algorithms as described in the book "Reinforcement Learning: An Introduction" by Sutton and Barto.


Features

The following reinforcement learning algorithms are implemented:

  • R-Learn
  • Q-Learn
  • Q-Learn with eligibility trace
  • SARSA
  • SARSA with eligibility trace
  • Actor-Critic
  • Actor-Critic with eligibility trace

The package also supports a number of action-selection strategies:

  • soft-max
  • epsilon-greedy
  • greedy
  • Gibbs-soft-max


Install

Add the following dependency to your POM file:

<dependency>
  <groupId>com.github.chen0040</groupId>
  <artifactId>java-reinforcement-learning</artifactId>
  <version>1.0.5</version>
</dependency>

Application Samples

Application samples of this library can be found in the following repositories:

Usage

Create Agent

A reinforcement learning agent, say a Q-Learn agent, can be created with the following Java code:

import com.github.chen0040.rl.learning.qlearn.QAgent;

int stateCount = 100;
int actionCount = 10;
QAgent agent = new QAgent(stateCount, actionCount);

The created agent has a state space of 100 states and can choose among 10 different actions.

For Q-Learn and SARSA, eligibility traces with decay factor lambda can be enabled by calling:

agent.enableEligibilityTrace(lambda);
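
For example, assuming a QAgent has already been created as above, a trace decay of 0.9 could be enabled right after construction (the value 0.9 is only an illustrative choice, not a library default):

QAgent agent = new QAgent(stateCount, actionCount);
agent.enableEligibilityTrace(0.9); // lambda = 0.9, an illustrative trace-decay value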

Select Action

At each time step, the agent can select an action by calling:

int actionId = agent.selectAction().getIndex();

If you want to limit the number of possible actions at each state (say the problem restricts the actions available in different states), then call:

Set<Integer> actionsAvailableAtCurrentState = world.getActionsAvailable(agent);
int actionTaken = agent.selectAction(actionsAvailableAtCurrentState).getIndex();
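
Here "world" stands for your own environment object; it is not part of this library. If you do not have such a helper, the restricted action set can be built directly. A minimal sketch (the specific action ids below are purely illustrative):

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical restriction: only actions 0, 3 and 7 are legal in the current state.
Set<Integer> actionsAvailableAtCurrentState = new HashSet<>(Arrays.asList(0, 3, 7));
int actionTaken = agent.selectAction(actionsAvailableAtCurrentState).getIndex();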

The agent can also switch to a different action-selection policy available in the com.github.chen0040.rl.actionselection package; for example, the following code switches the action-selection policy to soft-max:

agent.getLearner().setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName());
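
The other strategies listed in the Features section live in the same package. Assuming they follow the same naming pattern as SoftMaxActionSelectionStrategy (the class name below is an assumption, not confirmed here), switching to epsilon-greedy would look like:

// Assumed class name, by analogy with SoftMaxActionSelectionStrategy; verify against the package contents.
agent.getLearner().setActionSelection(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName());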

State-Action Update

Once the world state has been updated as a result of the agent's selected action, the agent's internal state-action Q matrix is updated by calling:

int newStateId = world.update(agent, actionTaken);
double reward = world.reward(agent);

agent.update(actionTaken, newStateId, reward);
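
Note that "world" in the code snippets of this document is not part of this library; it is a placeholder for your own environment, which must report the next state and the reward for the agent's last action. A minimal sketch of such an environment, purely hypothetical and for illustration only:

import java.util.Random;

// Hypothetical environment; the library itself does not define a World class.
// The agent parameter is typed as Object so the same sketch fits any of the agents below.
class World {
    private final Random random = new Random();
    private final int stateCount = 100;

    // Apply the chosen action and return the id of the resulting state.
    int update(Object agent, int actionId) {
        return random.nextInt(stateCount); // replace with your transition logic
    }

    // Return the reward for the agent's most recent transition.
    double reward(Object agent) {
        return random.nextDouble(); // replace with your reward logic
    }
}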

Sample code

Sample code for R-Learn

import com.github.chen0040.rl.learning.rlearn.RAgent;

int stateCount = 100;
int actionCount = 10;
RAgent agent = new RAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction().getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);

 agent.update(actionId, newStateId, reward);
}

Alternatively, you can use RLearner if you want to learn after the episode has completed:

class Move {
    int oldState;
    int newState;
    int action;
    double reward;
    
    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
RLearner agent = new RLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction(currentState).getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);
 int oldStateId = currentState;
 moves.add(new Move(oldStateId, actionId, newStateId, reward));
  currentState = newStateId;
}

for(int i=moves.size()-1; i >= 0; --i){
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, world.getActionsAvailableAtState(move.newState), move.reward);
}

Sample code for Q-Learn

import com.github.chen0040.rl.learning.qlearn.QAgent;

int stateCount = 100;
int actionCount = 10;
QAgent agent = new QAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction().getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);

 agent.update(actionId, newStateId, reward);
}

Alternatively, you can use QLearner if you want to learn after the episode has completed:

class Move {
    int oldState;
    int newState;
    int action;
    double reward;
    
    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
QLearner agent = new QLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction(currentState).getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);
 int oldStateId = currentState;
 moves.add(new Move(oldStateId, actionId, newStateId, reward));
  currentState = newStateId;
}

for(int i=moves.size()-1; i >= 0; --i){
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, move.reward);
}

Sample code for SARSA

import com.github.chen0040.rl.learning.sarsa.SarsaAgent;

int stateCount = 100;
int actionCount = 10;
SarsaAgent agent = new SarsaAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction().getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);

 agent.update(actionId, newStateId, reward);
}

Alternatively, you can use SarsaLearner if you want to learn after the episode has completed:

class Move {
    int oldState;
    int newState;
    int action;
    double reward;
    
    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
SarsaLearner agent = new SarsaLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction(currentState).getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);
 int oldStateId = currentState;
 moves.add(new Move(oldStateId, actionId, newStateId, reward));
  currentState = newStateId;
}

for(int i=moves.size()-1; i >= 0; --i){
    Move next_move = moves.get(i);
    if(i != moves.size()-1) {
        next_move = moves.get(i+1);
    }
    Move current_move = moves.get(i);
    agent.update(current_move.oldState, current_move.action, current_move.newState, next_move.action, current_move.reward);
}

Sample code for Actor Critic Model

import com.github.chen0040.rl.learning.actorcritic.ActorCriticAgent;
import com.github.chen0040.rl.utils.Vec;

int stateCount = 100;
int actionCount = 10;
ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount);
Vec stateValues = new Vec(stateCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction().getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);

 
 System.out.println("World state values changed ...");
 for(int stateId = 0; stateId < stateCount; ++stateId){
    stateValues.set(stateId, random.nextDouble());
 }
    
 agent.update(actionId, newStateId, reward, stateValues);
}

Alternatively, you can use ActorCriticLearner if you want to learn after the episode has completed:

class Move {
    int oldState;
    int newState;
    int action;
    double reward;
    
    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
ActorCriticLearner agent = new ActorCriticLearner(stateCount, actionCount);
Vec stateValues = new Vec(stateCount); // state-value estimates used by the critic

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for(int time=0; time < 1000; ++time){

 int actionId = agent.selectAction(currentState).getIndex();
 System.out.println("Agent does action-"+actionId);
 
 int newStateId = world.update(agent, actionId);
 double reward = world.reward(agent);

 System.out.println("Now the new state is " + newStateId);
 System.out.println("Agent receives Reward = "+reward);
 int oldStateId = currentState;
 moves.add(new Move(oldStateId, actionId, newStateId, reward));
  currentState = newStateId;
}

for(int i=moves.size()-1; i >= 0; --i){
    Move move = moves.get(i);
    // The exact ActorCriticLearner.update signature is assumed here by analogy with the
    // online ActorCriticAgent.update call shown above; verify against the library source.
    agent.update(move.oldState, move.action, move.newState, move.reward, stateValues);
}

Save and Load RL models

To save a trained RL model (say, a QLearner):

QLearner learner = new QLearner(stateCount, actionCount);
train(learner); // your own training routine, e.g. one of the episode loops above
String json = learner.toJson();

To load a trained RL model from JSON:

QLearner learner = QLearner.fromJson(json);
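
The JSON string can be persisted however you prefer. A minimal sketch using standard java.nio to write the model to a file and read it back (the file name is just an example):

import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Save the serialized model to disk (may throw IOException).
Path modelPath = Paths.get("qlearner-model.json");
Files.write(modelPath, json.getBytes(StandardCharsets.UTF_8));

// Later: read the JSON back and rebuild the learner.
String loadedJson = new String(Files.readAllBytes(modelPath), StandardCharsets.UTF_8);
QLearner restored = QLearner.fromJson(loadedJson);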

java-reinforcement-learning's People

Contributors

chen0040


java-reinforcement-learning's Issues

Multi agents? With various properties

I love your library. I'm looking into using it, for example, for car racing: the goal is to learn on each turn while taking account of the other racers. So the input would be X racers with their properties (speed on each turn, etc.), and after 100 turns the model should predict each racer's probability of winning. Hope this is clear.

Can anybody tell me "world" in code and Readme?

Can anybody tell me what "world" refers to in the code and the Readme? I cannot find where "world" is defined.

int actionId = agent.selectAction().getIndex();
System.out.println("Agent does action-" + actionId);

int newStateId = world.update(agent, actionId);
double reward = world.reward(agent);

Thanks a lot!

Does R-Learn equal the Roth-Erev algorithm?

Sorry to bother you, but I can't find what R-Learn stands for. In fact, I found the Roth-Erev algorithm in reinforcement learning. Can anyone give me some guidance? Thanks very much!

Question about Android support

Your library is really interesting, congrats and thanks for sharing!

I am considering using it to train over an existing environment that runs on a regular (desktop) JVM, then using that model for predictions in an Android app. I wonder if there is any previous experience with, or obstacle to, running the library on Android (I'm going to try it, but decided to ask beforehand just in case).

An alternative would be to convert the saved trained models to a TensorFlow-compatible format (so it could be used with TensorFlow Lite). Is there any way to do that conversion?

Any suggestions are appreciated! 🙇
