Burn

This library aims to be a complete, extremely flexible deep learning framework written in Rust. Its goal is to serve researchers as well as practitioners by making it easier to experiment with, train and deploy your models.

Features

  • Flexible and intuitive custom neural network modules 🤖
  • Stateless and thread-safe forward pass 🚀
  • Fast training with full support for metrics, logging and checkpointing 🌟
  • Burn-Tensor: Tensor library with autodiff, CPU and GPU support 🔥
  • Burn-Dataset: Dataset library with multiple utilities and sources 📚

Get Started

The best way to get started with burn is to look at the examples. It may also be a good idea to check out the main components to get a quick overview of how to use burn.

Examples

For now there is only one example, but more are coming 💪.

MNIST

The MNIST example is not just a small script that shows how to train a basic model; it is also a quick tour showing you how to:

  • Define your own custom module (MLP).
  • Build a data pipeline that goes from a raw dataset to a fast, batched, multi-threaded DataLoader.
  • Configure a learner to display and log metrics as well as to keep training checkpoints.
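The batching and multi-threading mentioned in the second point can be pictured with the rough, std-only sketch below. It is not Burn's DataLoader: the names batch and load_multithreaded, the Item type, and the scale-to-[0, 1] preprocessing are all hypothetical, chosen only to show items being grouped into batches and prepared on several worker threads.

```rust
use std::sync::mpsc;
use std::thread;

// Hypothetical dataset item: a flat vector of raw pixel intensities in [0, 255].
type Item = Vec<f32>;

// Group items into batches of `batch_size`; the last batch may be smaller.
fn batch(items: &[Item], batch_size: usize) -> Vec<Vec<Item>> {
    items.chunks(batch_size).map(|c| c.to_vec()).collect()
}

// Prepare batches on `num_workers` threads. Worker k handles batches k, k + n, k + 2n, ...
// and sends (batch index, prepared batch) over a channel; results are re-ordered by index.
fn load_multithreaded(items: &[Item], batch_size: usize, num_workers: usize) -> Vec<Vec<Item>> {
    assert!(batch_size > 0 && num_workers > 0);
    let batches = batch(items, batch_size);
    let (tx, rx) = mpsc::channel();

    thread::scope(|s| {
        for worker in 0..num_workers {
            let tx = tx.clone();
            let all = &batches;
            s.spawn(move || {
                for (i, b) in all.iter().enumerate().skip(worker).step_by(num_workers) {
                    // Toy preprocessing: scale every pixel from [0, 255] to [0, 1].
                    let prepared: Vec<Item> = b
                        .iter()
                        .map(|x| x.iter().map(|v| *v / 255.0).collect::<Item>())
                        .collect();
                    tx.send((i, prepared)).unwrap();
                }
            });
        }
    });

    // All workers have finished; drop the last sender so the receiver iterator ends.
    drop(tx);
    let mut out: Vec<(usize, Vec<Item>)> = rx.into_iter().collect();
    out.sort_by_key(|(i, _)| *i);
    out.into_iter().map(|(_, b)| b).collect()
}

fn main() {
    // Ten toy "images" of four pixels each.
    let items: Vec<Item> = (0..10).map(|i| vec![i as f32; 4]).collect();
    let batches = load_multithreaded(&items, 4, 3);
    let sizes: Vec<usize> = batches.iter().map(|b| b.len()).collect();
    println!("{} batches with sizes {:?}", batches.len(), sizes); // 3 batches with sizes [4, 4, 2]
}
```

Re-ordering by batch index keeps the output deterministic even though workers finish in an arbitrary order; a streaming loader could instead yield batches as soon as they are ready.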

The example can be run like so:

$ git clone https://github.com/burn-rs/burn.git
$ cd burn
$ export TORCH_CUDA_VERSION=cu113               # Set the CUDA version for the GPU (Tch) backend
$ # Use the --release flag to really speed up training.
$ cargo run --example mnist --release           # CPU NdArray Backend
$ cargo run --example mnist_cuda_gpu --release  # GPU Tch Backend

Components

Knowing the main components will be of great help when you start experimenting with burn.

Backend

Almost everything is based on the Backend trait, which allows you to run tensor operations with different implementations without having to change your code. A backend does not necessarily have autodiff capabilities, so use ADBackend when you need them.
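To make the backend idea concrete, here is a toy, std-only sketch of the pattern: code written once against a trait and run on two implementations, with autodiff expressed as an extension trait that only some backends implement. ToyBackend, ToyAutodiffBackend, VecBackend and BoxBackend are hypothetical simplifications, not Burn's actual traits, and the only "gradient" computed is the gradient of a sum, which is a vector of ones.

```rust
// Hypothetical, heavily simplified backend abstraction.
trait ToyBackend {
    // Backend-specific storage for tensor data.
    type Storage;
    fn ones(len: usize) -> Self::Storage;
    fn sum(data: &Self::Storage) -> f32;
}

// Hypothetical extension trait: only backends that can compute gradients implement it.
trait ToyAutodiffBackend: ToyBackend {
    // d(sum(x))/dx is a vector of ones with the same length as x.
    fn grad_of_sum(data: &Self::Storage) -> Self::Storage;
}

// First implementation: data lives in a Vec<f32>; it also supports "autodiff".
struct VecBackend;
impl ToyBackend for VecBackend {
    type Storage = Vec<f32>;
    fn ones(len: usize) -> Self::Storage { vec![1.0; len] }
    fn sum(data: &Self::Storage) -> f32 { data.iter().sum() }
}
impl ToyAutodiffBackend for VecBackend {
    fn grad_of_sum(data: &Self::Storage) -> Self::Storage { vec![1.0; data.len()] }
}

// Second implementation with different storage and no autodiff support.
struct BoxBackend;
impl ToyBackend for BoxBackend {
    type Storage = Box<[f32]>;
    fn ones(len: usize) -> Self::Storage { vec![1.0; len].into_boxed_slice() }
    fn sum(data: &Self::Storage) -> f32 { data.iter().sum() }
}

// Written once, runs on any backend.
fn sum_of_ones<B: ToyBackend>(len: usize) -> f32 {
    B::sum(&B::ones(len))
}

// Requires autodiff, so only autodiff-capable backends are accepted.
fn grad_total<B: ToyAutodiffBackend>(len: usize) -> f32 {
    let x = B::ones(len);
    B::sum(&B::grad_of_sum(&x))
}

fn main() {
    println!("{}", sum_of_ones::<VecBackend>(9)); // 9
    println!("{}", sum_of_ones::<BoxBackend>(9)); // 9
    println!("{}", grad_total::<VecBackend>(3)); // 3
    // grad_total::<BoxBackend>(3) would not compile: BoxBackend has no autodiff.
}
```

Because grad_total requires ToyAutodiffBackend, passing BoxBackend is rejected at compile time, which is the kind of guarantee that separating Backend from ADBackend provides.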

Tensor

The Tensor struct is at the core of the burn framework. It takes two generic parameters: the Backend B and the number of dimensions D.

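use burn::tensor::{Shape, Tensor};
use burn::tensor::backend::{Backend, NdArrayBackend, TchBackend};

// Generic over any backend: the same code runs on every implementation.
fn my_func<B: Backend>() {
    let _my_tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
}

fn main() {
    // Explicit generic arguments on a function call require the turbofish `::<>`.
    my_func::<NdArrayBackend<f32>>();
    my_func::<TchBackend<f32>>();
}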
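Tensor<B, 2> encodes the number of dimensions in the type through a const generic. The toy sketch below (std-only; Tensor, MarkerA and MarkerB are hypothetical stand-ins, not Burn types) shows how that lets the compiler reject a tensor of the wrong rank, while the backend parameter is carried along only as a type marker.

```rust
use std::marker::PhantomData;

// Hypothetical marker types standing in for two backends.
struct MarkerA;
struct MarkerB;

// Toy tensor generic over a backend B and a compile-time rank D.
struct Tensor<B, const D: usize> {
    shape: [usize; D],
    data: Vec<f32>,
    _backend: PhantomData<B>,
}

impl<B, const D: usize> Tensor<B, D> {
    // A tensor filled with ones; its length is the product of the dimensions.
    fn ones(shape: [usize; D]) -> Self {
        let len = shape.iter().product();
        Tensor { shape, data: vec![1.0; len], _backend: PhantomData }
    }

    fn num_elements(&self) -> usize {
        self.data.len()
    }
}

// Only accepts rank-2 tensors; passing a rank-3 tensor is a compile-time error.
fn transpose_shape<B>(t: &Tensor<B, 2>) -> [usize; 2] {
    [t.shape[1], t.shape[0]]
}

fn main() {
    let a = Tensor::<MarkerA, 2>::ones([3, 4]);
    let b = Tensor::<MarkerB, 3>::ones([2, 3, 4]);
    println!("{:?} {}", transpose_shape(&a), a.num_elements()); // [4, 3] 12
    println!("{}", b.num_elements()); // 24
    // transpose_shape(&b) would not compile: b has rank 3.
}
```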

Module

The Module derive lets you create your own neural network modules, similar to PyTorch.

use burn::nn;
use burn::module::{Param, Module};
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
  my_param: Param<nn::Linear<B>>,
  repeat: usize,
}

Note that only the fields wrapped inside Param are updated during training; all other fields should implement Clone.
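One way to picture the Param rule is a derive that generates a visitor over trainable fields only, so an optimizer never touches the rest. The sketch below hand-writes such an impl; Param, ToyModule, the visit_params_mut method and the constant-gradient SGD step are hypothetical and only illustrate the idea, they are not Burn's generated code.

```rust
// Hypothetical wrapper marking a value as trainable.
#[derive(Debug, Clone)]
struct Param<T>(T);

// Hypothetical module trait: lets an optimizer reach every trainable value.
trait Module {
    fn visit_params_mut(&mut self, f: &mut dyn FnMut(&mut f32));
}

#[derive(Debug, Clone)]
struct ToyModule {
    weight: Param<Vec<f32>>, // trainable
    repeat: usize,           // plain field: cloned, never updated
}

// Roughly what a derive could generate: visit only the Param fields.
impl Module for ToyModule {
    fn visit_params_mut(&mut self, f: &mut dyn FnMut(&mut f32)) {
        for w in self.weight.0.iter_mut() {
            f(w);
        }
    }
}

// Toy SGD step: w <- w - lr * grad, with one constant gradient for every weight.
fn sgd_step<M: Module>(module: &mut M, lr: f32, grad: f32) {
    module.visit_params_mut(&mut |w: &mut f32| *w -= lr * grad);
}

fn main() {
    let mut m = ToyModule { weight: Param(vec![1.0, 2.0]), repeat: 3 };
    sgd_step(&mut m, 0.5, 1.0);
    println!("{:?}", m); // ToyModule { weight: Param([0.5, 1.5]), repeat: 3 }
}
```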

Forward

Your module can also implement the Forward trait.

use burn::module::Forward;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for MyModule<B> {
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let mut x = input;

        // Apply the same linear layer `repeat` times.
        for _ in 0..self.repeat {
            x = self.my_param.forward(x);
        }

        x
    }
}

Note that you can implement the Forward trait multiple times with different input and output types.
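As a minimal illustration of implementing one generic trait several times, the std-only sketch below defines a hypothetical Forward<In, Out> trait (shaped like the one above, but not Burn's) and implements it twice for a toy Scale module: once for a single value and once for a batch. The compiler picks the impl from the argument and expected output types.

```rust
// Hypothetical trait shaped like Forward<In, Out>.
trait Forward<In, Out> {
    fn forward(&self, input: In) -> Out;
}

// Toy module that multiplies its input by a constant.
struct Scale {
    factor: f32,
}

// Single value -> single value.
impl Forward<f32, f32> for Scale {
    fn forward(&self, input: f32) -> f32 {
        input * self.factor
    }
}

// Batch -> batch, reusing the single-value impl for each element.
impl Forward<Vec<f32>, Vec<f32>> for Scale {
    fn forward(&self, input: Vec<f32>) -> Vec<f32> {
        input
            .into_iter()
            .map(|x| <Self as Forward<f32, f32>>::forward(self, x))
            .collect()
    }
}

fn main() {
    let s = Scale { factor: 2.0 };
    let y: f32 = s.forward(3.0_f32);
    let ys: Vec<f32> = s.forward(vec![1.0_f32, 2.0]);
    println!("{} {:?}", y, ys); // 6 [2.0, 4.0]
}
```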

Config

The Config derive lets you define serializable and deserializable configurations or hyper-parameters for your modules or any other component.

use burn::config::Config;

#[derive(Config)]
struct MyConfig {
    #[config(default = 1.0e-6)]
    pub epsilon: f64, // a fractional default like 1.0e-6 needs a float type
    pub dim: usize,
}

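The derive also adds useful methods to your config, such as a new constructor taking the fields without a default and with_* setters for the fields that have one.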

fn my_func() {
    // `new` takes the fields that have no default value.
    let config = MyConfig::new(100);
    println!("{:e}", config.epsilon); // 1e-6
    println!("{}", config.dim); // 100

    // Fields with a default get a `with_*` setter.
    let config = MyConfig::new(100).with_epsilon(1.0e-8);
    println!("{:e}", config.epsilon); // 1e-8
}
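For intuition, the methods the derive provides for MyConfig behave like the hand-written code below. This is an approximation under the assumption that new takes the non-defaulted fields and with_* sets a defaulted one, as the example above suggests; Burn's generated code may differ and also covers serialization, which is omitted here.

```rust
// Hand-written approximation of the derive's output (hypothetical).
#[derive(Debug, Clone)]
struct MyConfig {
    pub epsilon: f64,
    pub dim: usize,
}

impl MyConfig {
    // Takes only the fields without a default; defaults are filled in here.
    fn new(dim: usize) -> Self {
        Self { epsilon: 1.0e-6, dim }
    }

    // Builder-style setter for a field that has a default.
    fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.epsilon = epsilon;
        self
    }
}

fn main() {
    let config = MyConfig::new(100).with_epsilon(1.0e-8);
    println!("{:e} {}", config.epsilon, config.dim); // 1e-8 100
}
```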

Learner

The Learner is the main struct for training a neural network, with support for logging, metrics, checkpointing and more. To create a learner, you must use the LearnerBuilder.

use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};

// `config`, `model` and `optim` are assumed to be defined earlier, as in the MNIST example.
let learner = LearnerBuilder::new("/tmp/artifact_dir")
    .metric_train_plot(AccuracyMetric::new())
    .metric_valid_plot(AccuracyMetric::new())
    .metric_train(LossMetric::new())
    .metric_valid(LossMetric::new())
    .with_file_checkpointer::<f32>(2)
    .num_epochs(config.num_epochs)
    .build(model, optim);

See the MNIST example for real usage.
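The LearnerBuilder calls above are chained in builder style. A common way to implement that in Rust is a consuming builder, where each setter takes self by value, records a setting and returns the builder. The std-only sketch below illustrates only that pattern; ToyLearnerBuilder, ToyLearner, the string-based metrics and the default of one epoch are hypothetical and do not reflect Burn's implementation.

```rust
// Hypothetical result of the builder: a plain record of the chosen settings.
#[derive(Debug)]
struct ToyLearner {
    artifact_dir: String,
    train_metrics: Vec<String>,
    valid_metrics: Vec<String>,
    num_epochs: usize,
}

// Hypothetical consuming builder: setters take `self` by value and return it,
// so calls can be chained and the builder cannot be reused after `build`.
struct ToyLearnerBuilder {
    learner: ToyLearner,
}

impl ToyLearnerBuilder {
    fn new(artifact_dir: &str) -> Self {
        Self {
            learner: ToyLearner {
                artifact_dir: artifact_dir.to_string(),
                train_metrics: Vec::new(),
                valid_metrics: Vec::new(),
                num_epochs: 1, // arbitrary default for this sketch
            },
        }
    }

    fn metric_train(mut self, name: &str) -> Self {
        self.learner.train_metrics.push(name.to_string());
        self
    }

    fn metric_valid(mut self, name: &str) -> Self {
        self.learner.valid_metrics.push(name.to_string());
        self
    }

    fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.learner.num_epochs = num_epochs;
        self
    }

    fn build(self) -> ToyLearner {
        self.learner
    }
}

fn main() {
    let learner = ToyLearnerBuilder::new("/tmp/artifact_dir")
        .metric_train("accuracy")
        .metric_valid("accuracy")
        .metric_train("loss")
        .metric_valid("loss")
        .num_epochs(10)
        .build();
    println!("{:?}", learner);
}
```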

License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.

Contributors

nathanielsimard, n8henrie, olgam4, kepae

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.