NclustRL

NclustRL is a toolkit for training deep reinforcement learning (DRL) agents on n-clustering tasks. It builds on Ray's RLlib.

Ray is a general-purpose framework for distributed computing. It includes Tune, a well-known library for hyperparameter tuning, and RLlib, a DRL framework that supports distributed training and extensive customization.

NclustRL provides a trainer API for n-clustering that handles all training tasks for the user, a set of default models and metrics, and other helper functions. It also provides a set of default configurations for n-clustering tasks, available in "configs".

Diagram exemplifying NclustEnv's architecture

The trainer API aims to provide a simple way of training and testing DRL agents for n-clustering tasks. This class handles all of RLlib's logic and exposes only user-friendly methods.

Once initialized, the trainer exposes four primary methods:

  • Train: Runs the main training function. It receives the training parameters to pass on to Tune, starts the training process, manages multiple samples of the same trial, and parses the results, returning the best performance obtained;
  • Load: Imports an agent from a checkpoint for testing;
  • Test: Evaluates accuracy and mean reward over n episodes and returns the mean and standard deviation of each metric;
  • Test Dataset: Evaluates performance in the same way as Test, but receives as input a specific dataset from which episodes are sampled.
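
To make the Test summary concrete, here is a minimal sketch of how per-episode results could be reduced to a mean and standard deviation for each metric. It is not NclustRL's implementation: the function name `summarize_episodes`, the `(accuracy, reward)` tuple layout, and the choice of the population standard deviation are all assumptions made for illustration.

```python
from statistics import mean, pstdev


def summarize_episodes(episodes):
    """Reduce per-episode (accuracy, reward) pairs to (mean, std) per metric.

    Hypothetical helper: the input layout and the use of the population
    standard deviation are assumptions, not NclustRL's actual behavior.
    """
    if not episodes:
        raise ValueError("at least one episode is required")
    accuracies = [accuracy for accuracy, _ in episodes]
    rewards = [reward for _, reward in episodes]
    return {
        "accuracy": (mean(accuracies), pstdev(accuracies)),
        "reward": (mean(rewards), pstdev(rewards)),
    }


# Two made-up episodes, purely for illustration.
summary = summarize_episodes([(1.0, 10.0), (0.5, 20.0)])
print(summary)
```

Reporting the spread alongside the mean helps tell a consistently good agent from one that only does well on some episodes.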

Installation

This tool can be installed from PyPI:

pip install nclustRL

Getting started

These are the basics. For more information, check the experiments available in "Exp".

## Train basic agent

from nclustRL.trainer import Trainer
from nclustRL.configs.default_configs import PPO_PBT, DEFAULT_CONFIG
from ray.rllib.agents.ppo import PPOTrainer
from nclustenv.configs import biclustering

# Initialize Trainer

config = DEFAULT_CONFIG.copy()
config['env_config'] = biclustering.binary.basic_v2

trainer = Trainer(
    trainer=PPOTrainer,
    env='BiclusterEnv-v0',
    save_dir='nclustRL/Exp/test',
    name='test',
    config=config
)

## Tune agent

best_checkpoint = trainer.train(
    num_samples=8, 
    scheduler=PPO_PBT,
    stop_iters=500,
)
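
One detail worth keeping in mind when adapting the snippet above: `dict.copy()` makes a shallow copy. Assigning a new top-level key, as done with `env_config`, leaves the defaults untouched, but editing a nested entry in place would change the shared defaults as well. The sketch below shows the difference with a stand-in dictionary. `DEFAULTS` and its keys are hypothetical, and whether `DEFAULT_CONFIG` actually contains nested dictionaries is an assumption.

```python
import copy

# Hypothetical stand-in for a nested default configuration.
DEFAULTS = {"env_config": {}, "model": {"hidden_sizes": [256, 256]}}

# Shallow copy: top-level keys are independent, nested objects are shared.
shallow = DEFAULTS.copy()
shallow["env_config"] = {"n": 2}          # rebinds a key; DEFAULTS is untouched
shallow["model"]["hidden_sizes"] = [64]   # mutates the shared nested dict
leaked = DEFAULTS["model"]["hidden_sizes"] == [64]

# Restore the defaults, then repeat the nested edit on a deep copy.
DEFAULTS["model"]["hidden_sizes"] = [256, 256]
deep = copy.deepcopy(DEFAULTS)
deep["model"]["hidden_sizes"] = [64]
isolated = DEFAULTS["model"]["hidden_sizes"] == [256, 256]

print(leaked, isolated)
```

If several trainers are built from the same defaults in one process, `copy.deepcopy` is the safer choice whenever nested values are modified.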

Model

By default, this tool provides a model for the hybrid proximal policy optimization (PPO) algorithm, available in "models". This model can be customized, or other models can be implemented and passed in through the configs.

License

GPLv3

Contributors

pedrocotovio