js-torch

PyTorch in JavaScript

  • JS-Torch is a Deep Learning JavaScript library built from scratch to closely follow PyTorch's syntax.
  • It contains a fully functional Tensor object that can track gradients, Deep Learning layers and functions, and an Automatic Differentiation engine.
  • Feel free to try out the Web Demo!
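In general terms, an Automatic Differentiation engine has each operation record how to pass gradients back to its inputs. When backward() is called, it replays those rules in reverse topological order. The following is a minimal scalar sketch of that idea in plain JavaScript. The Value class and its add, mul and backward methods are hypothetical names chosen for illustration; they are not js-torch's Tensor API.

```javascript
// Toy scalar reverse-mode autodiff (illustrative only, NOT js-torch's Tensor class).
class Value {
  constructor(data, parents = [], backwardFn = () => {}) {
    this.data = data;
    this.grad = 0;
    this.parents = parents;      // nodes this value was computed from
    this._backward = backwardFn; // pushes this.grad into the parents' grads
  }

  add(other) {
    const out = new Value(this.data + other.data, [this, other]);
    out._backward = () => {
      // d(a + b)/da = 1 and d(a + b)/db = 1
      this.grad += out.grad;
      other.grad += out.grad;
    };
    return out;
  }

  mul(other) {
    const out = new Value(this.data * other.data, [this, other]);
    out._backward = () => {
      // d(a * b)/da = b and d(a * b)/db = a
      this.grad += other.data * out.grad;
      other.grad += this.data * out.grad;
    };
    return out;
  }

  backward() {
    // Topological sort: every node is pushed after all of its parents.
    const order = [];
    const visited = new Set();
    const visit = (node) => {
      if (visited.has(node)) return;
      visited.add(node);
      node.parents.forEach(visit);
      order.push(node);
    };
    visit(this);
    this.grad = 1; // seed: d(out)/d(out) = 1
    // Walk the graph from the output back towards the leaves.
    for (let i = order.length - 1; i >= 0; i--) order[i]._backward();
  }
}

// Usage: out = x * w + b
const x = new Value(3);
const w = new Value(-2);
const b = new Value(0.5);
const out = x.mul(w).add(b);
out.backward();
console.log(out.data, w.grad, b.grad); // -5.5 3 1
```

Gradients are accumulated with += rather than assigned, so a value used in several places receives the sum of its contributions.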

Implemented Tensor Operations:
Implemented Deep Learning Layers:

1. Project Structure

  • assets/: Folder that stores images and the Demo.
  • src/: Framework with JavaScript files.
    • src/tensor.js: File with the Tensor class and all of the tensor operations.
    • src/utils.js: File with operations and helper functions.
    • src/layers.js: Submodule of the framework. Contains full Deep Learning layers.
    • src/optim.js: Submodule of the framework. Contains the Adam optimizer.
  • tests/: Folder with unit tests. Contains test.js.

2. Running it Yourself

Simple Autograd Example:

const torch = require("js-pytorch");

// Instantiate Tensors. JavaScript has no keyword arguments,
// so requires_grad is passed positionally as the second argument:
let x = torch.randn([8, 4, 5]);
let w = torch.randn([8, 5, 4], true);
let b = torch.tensor([0.2, 0.5, 0.1, 0.0], true);

// Make calculations:
let out = torch.matmul(x, w);
out = torch.add(out, b);

// Compute gradients on whole graph:
out.backward();

// Get gradients from specific Tensors:
console.log(w.grad);
console.log(b.grad);
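The contents of w.grad and b.grad can be worked out by hand. This assumes backward() on a non-scalar output seeds the upstream gradient with ones and reduces broadcast gradients the way PyTorch does; both are assumptions about js-torch's defaults. Under those assumptions, w.grad for each batch is x transposed times the upstream gradient. b.grad sums the upstream gradient over every dimension that b was broadcast across, so with the shapes above each entry of b.grad would be 8 × 4 = 32. The plain-array sketch below checks this on a tiny case. The names bmm and transposeLast are hypothetical helpers, not js-torch functions.

```javascript
// Hypothetical helpers on plain nested arrays (NOT js-torch functions).
// Batched matmul: A is [B][N][K], W is [B][K][M], result is [B][N][M].
function bmm(A, W) {
  return A.map((a, bi) =>
    a.map((row) =>
      W[bi][0].map((_, j) => row.reduce((s, v, k) => s + v * W[bi][k][j], 0))
    )
  );
}

// Transpose the last two dimensions of a [B][R][C] array.
function transposeLast(T) {
  return T.map((m) => m[0].map((_, j) => m.map((row) => row[j])));
}

// Forward: out = bmm(x, w) + b, with b (shape [M]) broadcast over B and N.
const x = [[[1, 2], [3, 4]]];    // shape [1, 2, 2]
const w = [[[0.5, -1], [2, 0]]]; // shape [1, 2, 2]
const b = [0.1, 0.2];            // shape [2]
const out = bmm(x, w).map((m) => m.map((row) => row.map((v, j) => v + b[j])));

// Assumed default seed for backward() on a non-scalar output: all ones.
const gradOut = out.map((m) => m.map((row) => row.map(() => 1)));

// dL/dw[i] = x[i]^T @ gradOut[i] for each batch i
const wGrad = bmm(transposeLast(x), gradOut);

// dL/db sums gradOut over every broadcast dimension (batch and rows).
const bGrad = b.map((_, j) =>
  gradOut.reduce((s, m) => s + m.reduce((t, row) => t + row[j], 0), 0)
);

console.log(wGrad, bGrad);
```

Here wGrad comes out as [[[4, 4], [6, 6]]] and bGrad as [2, 2]. Each entry of bGrad equals the number of positions (1 batch × 2 rows) that the bias was added to.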

Complex Autograd Example (Transformer):

const torch = require("js-pytorch");
const nn = torch.nn;
const optim = torch.optim;

class Transformer extends nn.Module {
    constructor(vocab_size, hidden_size, n_timesteps, n_heads, p) {
        super();
        // Instantiate Transformer's Layers:
        this.embed = new nn.Embedding(vocab_size, hidden_size);
        this.pos_embed = new nn.PositionalEmbedding(n_timesteps, hidden_size);
        // The last argument is the dropout probability (passed positionally):
        this.b1 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps, p);
        this.b2 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps, p);
        this.ln = new nn.LayerNorm(hidden_size);
        this.linear = new nn.Linear(hidden_size, vocab_size);
    };

    forward(x) {
        let z;
        z = torch.add(this.embed.forward(x), this.pos_embed.forward(x));
        z = this.b1.forward(z);
        z = this.b2.forward(z);
        z = this.ln.forward(z);
        z = this.linear.forward(z);
        return z;
    };
};

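// Define hyperparameters (example values; adjust them to your data):
const vocab_size = 52;
const hidden_size = 32;
const n_timesteps = 16;
const n_heads = 4;
const dropout_p = 0;
const batch_size = 8;

// Instantiate your custom nn.Module:
const model = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p);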

// Define loss function and optimizer:
// Arguments are positional: learning rate, then regularization.
const loss_func = new nn.CrossEntropyLoss();
const optimizer = new optim.Adam(model.parameters(), 5e-3, 0);

// Instantiate sample input and output:
let x = torch.randint(0,vocab_size,[batch_size,n_timesteps,1]);
let y = torch.randint(0,vocab_size,[batch_size,n_timesteps]);
let loss;

// Training Loop:
for(let i=0 ; i < 40 ; i++) {
    // Forward pass through the Transformer:
    let z = model.forward(x);

    // Get loss:
    loss = loss_func.forward(z, y);

    // Backpropagate the loss using torch.tensor's backward() method:
    loss.backward();

    // Update the weights:
    optimizer.step();
    
    // Reset the gradients to zero after each training step:
    optimizer.zero_grad();
    
};
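In the training loop, nn.CrossEntropyLoss compares the model's logits with integer token targets. The sketch below implements the standard softmax cross-entropy on plain arrays, along with its gradient with respect to the logits (softmax minus the one-hot target). Two things here are assumptions: treating the [batch_size, n_timesteps, vocab_size] logits as flattened rows, and averaging over those rows. js-torch's actual reduction may differ, and crossEntropy is a hypothetical helper.

```javascript
// Softmax cross-entropy over logits [N][C] with integer targets [N].
// Averaging over the N rows is an ASSUMPTION; js-torch's reduction may differ.
function crossEntropy(logits, targets) {
  const n = logits.length;
  let total = 0;
  const grad = logits.map((row, i) => {
    // Subtract the row max before exponentiating for numerical stability.
    const max = Math.max(...row);
    const exps = row.map((v) => Math.exp(v - max));
    const sum = exps.reduce((s, e) => s + e, 0);
    const probs = exps.map((e) => e / sum);
    total += -Math.log(probs[targets[i]]);
    // d(loss)/d(logits) = (softmax - one_hot(target)) / N
    return probs.map((p, c) => (p - (c === targets[i] ? 1 : 0)) / n);
  });
  return { loss: total / n, grad };
}

// Usage: two flattened timesteps over a vocabulary of 3 tokens.
const { loss, grad } = crossEntropy([[2, 0, 0], [0, 0, 0]], [0, 2]);
console.log(loss.toFixed(4));
```

A confident correct prediction (first row) contributes a small loss, while the uniform second row contributes ln 3. Each gradient row sums to zero because the softmax probabilities and the one-hot target both sum to one.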

Note: You can install the package locally by running npm install js-pytorch
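src/optim.js provides the Adam optimizer, and the training loop calls optimizer.step() and optimizer.zero_grad() on it. The sketch below shows the standard Adam update on plain arrays. The AdamSketch class, its { data, grad } parameter layout and its default hyperparameters (the commonly used beta1 = 0.9, beta2 = 0.999, eps = 1e-8) are assumptions for illustration, not js-torch's optim.Adam. The reg argument from the example above is not modeled.

```javascript
// Minimal Adam over flat parameter arrays (illustrative, NOT js-torch's optim.Adam).
class AdamSketch {
  constructor(params, lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8) {
    // Each param is { data: number[], grad: number[] } (a hypothetical layout).
    this.params = params;
    this.lr = lr;
    this.beta1 = beta1;
    this.beta2 = beta2;
    this.eps = eps;
    this.t = 0;
    this.m = params.map((p) => p.data.map(() => 0)); // first-moment estimates
    this.v = params.map((p) => p.data.map(() => 0)); // second-moment estimates
  }

  step() {
    this.t += 1;
    this.params.forEach((p, i) => {
      p.data.forEach((_, j) => {
        const g = p.grad[j];
        this.m[i][j] = this.beta1 * this.m[i][j] + (1 - this.beta1) * g;
        this.v[i][j] = this.beta2 * this.v[i][j] + (1 - this.beta2) * g * g;
        // Bias correction compensates for m and v starting at zero.
        const mHat = this.m[i][j] / (1 - this.beta1 ** this.t);
        const vHat = this.v[i][j] / (1 - this.beta2 ** this.t);
        p.data[j] -= (this.lr * mHat) / (Math.sqrt(vHat) + this.eps);
      });
    });
  }

  zeroGrad() {
    // Gradients accumulate across backward passes, so clear them between steps.
    this.params.forEach((p) => p.grad.fill(0));
  }
}

// Usage: minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
const param = { data: [0], grad: [0] };
const opt = new AdamSketch([param], 0.1);
for (let i = 0; i < 500; i++) {
  param.grad[0] += 2 * (param.data[0] - 3); // stand-in for loss.backward()
  opt.step();
  opt.zeroGrad();
}
console.log(param.data[0]);
```

The loop mirrors the order used in the training example: compute gradients, take a step, then reset the gradients so the next iteration starts from zero.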


3. Results

  • The models implemented in the unit tests all converged to near-zero losses.
  • This package is not yet as optimized as PyTorch, but I tried to make it more interpretable. Efficiency improvements are coming!
  • Hope you enjoy!

Contributors

eduardoleao052