
tractjs's Introduction

tractjs


Run ONNX and TensorFlow inference in the browser. A thin wrapper on top of tract.

The Open Neural Network Exchange (ONNX) is a format that many popular libraries like PyTorch, TensorFlow and MXNet can export to, which allows tractjs to run neural networks from (almost) any library.

Website | API Docs

Why tractjs instead of ONNX.js?

There is currently one other usable ONNX runner for the browser, ONNX.js. There are a couple of things tractjs does better:

  • tractjs supports more operators:
    • LSTMs (even bidirectional) are supported while ONNX.js does not support any recurrent networks.
    • Some ONNX-ML models like decision tree classifiers are also supported.
  • tractjs is more convenient to use. It can build to a single file, tractjs.min.js, which contains the inlined WASM and WebWorker. The WASM backend of ONNX.js cannot be used as easily without a build system.

There are however also some downsides to tractjs. See the FAQ.

Getting started

Without a bundler

<html>
  <head>
    <meta charset="utf-8" />
    <script src="https://unpkg.com/tractjs/dist/tractjs.min.js"></script>
    <script>
      tractjs.load("path/to/your/model").then((model) => {
        model
          .predict([new tractjs.Tensor(new Float32Array([1, 2, 3, 4]), [2, 2])])
          .then((preds) => {
            console.log(preds);
          });
      });
    </script>
  </head>
</html>

With a bundler

npm install tractjs

import * as tractjs from "tractjs";

tractjs.load("path/to/your/model").then((model) => {
  model
    .predict([new tractjs.Tensor(new Float32Array([1, 2, 3, 4]), [2, 2])])
    .then((preds) => {
      console.log(preds);
    });
});

With Node.js

tractjs now runs in Node.js! Models are fetched from the file system.

const tractjs = require("tractjs");

tractjs.load("./path/to/your/model").then((model) => {
  model
    .predict([new tractjs.Tensor(new Float32Array([1, 2, 3, 4]), [2, 2])])
    .then((preds) => {
      console.log(preds);
    });
});
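All three examples pass input data as a flat Float32Array plus a shape. The helpers below are a hypothetical sketch (not part of tractjs) that checks the element count against the shape and computes where an element lives in the flat buffer, assuming the usual row-major (C-order) layout in which the last dimension varies fastest. Under that assumption, the values [1, 2, 3, 4] with shape [2, 2] form the rows [1, 2] and [3, 4].

```javascript
// Hypothetical helpers, not part of tractjs. Assumes row-major (C-order) layout.

// Number of elements a shape describes (a scalar shape [] has 1 element).
function numElements(shape) {
  return shape.reduce((acc, dim) => acc * dim, 1);
}

// Distance in the flat buffer between neighbours along each axis.
function rowMajorStrides(shape) {
  const strides = new Array(shape.length).fill(1);
  for (let i = shape.length - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}

// Flat offset of a multi-dimensional index such as [row, column].
function offsetOf(index, shape) {
  const strides = rowMajorStrides(shape);
  return index.reduce((acc, idx, axis) => acc + idx * strides[axis], 0);
}

// Throw early if the buffer length does not match the declared shape.
function checkTensorData(data, shape) {
  const expected = numElements(shape);
  if (data.length !== expected) {
    throw new Error(`data has ${data.length} elements but shape [${shape}] needs ${expected}`);
  }
}

const data = new Float32Array([1, 2, 3, 4]);
const shape = [2, 2];
checkTensorData(data, shape);
console.log(data[offsetOf([1, 0], shape)]); // 3 (row 1, column 0)
```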

FAQ

Why does my model with dynamic input dimensions not work?

Currently, tract has some restrictions on dynamic dimensions. If your model has a dynamic dimension, there are multiple solutions:

  1. Declare a dynamic dimension via an input fact. Input facts are a way to provide additional information about input type and shape that cannot be inferred from the model data:
const model = await tractjs.load("path/to/your/model", {
  inputFacts: {
    0: ["float32", [1, "s", 224, 224]],
  },
});
  2. Set fixed input dimensions via input facts. This is of course not ideal, because the model can subsequently only be passed inputs of this exact shape:
const model = await tractjs.load("path/to/your/model", {
  inputFacts: {
    // be careful with image model input facts! here I use ONNX's NCHW format
    // if you are using TF you will probably need to use NHWC (`[1, 224, 224, 3]`).
    0: ["float32", [1, 3, 224, 224]],
  },
});
  3. Turn optimize off. This is the nuclear option. It will turn off all optimizations relying on information about input shape. This makes sure your model works (even with multiple dynamic dimensions) but significantly impacts performance:
const model = await tractjs.load("path/to/your/model", {
  optimize: false,
});
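With a symbolic dimension like "s" in an input fact, inputs whose size varies along that axis are accepted while the other axes stay fixed. The function below is a hypothetical pre-flight check (not a tractjs API) that tests a concrete input shape against such a fact before calling predict. It treats numbers as fixed sizes and strings as symbols that must take the same value wherever the same symbol appears; that repeated symbols must agree is an assumption made for this sketch, not something the README states.

```javascript
// Hypothetical helper, not a tractjs API: does `shape` fit `factShape`?
// Numbers in factShape must match exactly. Strings are symbolic dimensions;
// in this sketch, every occurrence of the same symbol must bind to one value.
function matchesFact(shape, factShape) {
  if (shape.length !== factShape.length) return false;
  const bindings = new Map();
  for (let i = 0; i < factShape.length; i++) {
    const expected = factShape[i];
    const actual = shape[i];
    if (typeof expected === "number") {
      if (expected !== actual) return false;
    } else if (bindings.has(expected)) {
      if (bindings.get(expected) !== actual) return false;
    } else {
      bindings.set(expected, actual);
    }
  }
  return true;
}

const fact = [1, "s", 224, 224];
console.log(matchesFact([1, 3, 224, 224], fact)); // true
console.log(matchesFact([1, 3, 224, 225], fact)); // false
```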

What about size?

At the time of writing, tractjs is very large by web standards (6.2MB raw, 2.1MB gzipped). This is due to tract being quite large, and to some overhead from inlining the WASM. But it's not as bad as it sounds: you can load tractjs lazily alongside your demo, which will likely have to load large model weights anyway.

If you are working on a very size-sensitive application, get in touch and we can work on decreasing the size. There are some more optimizations to be done (e.g. an option not to inline the WASM, and removing panics from the build). There is also ongoing work in tract to decrease its size.

What about WebGL / WebNN support?

tractjs is a set of bindings to the tract Rust library, which was originally not intended to run on the web. WebGL / WebNN support would be great, but it would require lots of web-specific changes in tract, so it is currently not under consideration.

License

Apache 2.0/MIT

All original work is licensed under either of the Apache License 2.0 or the MIT license.

Contribution

Contributions are very welcome! See CONTRIBUTING.md.


tractjs's Issues

Unit tests

As mentioned in #7 unit tests would be very nice. At the moment this is blocked by #8 but I'll see if I can fix that soon.

I'm open to the choice of framework, I've used Mocha in the past but Jest also looks good to me.

cc @danielbank if you're still interested.

Multithreading across batch dimension

Multithreading across the batch dimension should be quite easy to implement (ref wasm-bindgen rayon example) without any need for changes in tract (like this). Would only need a new option batch_dimension which is zero by default. This would act as a good filler until (if) internal multithreading (sonos/tract#342) is implemented in tract.

On a benchmark I did in the past tractjs vastly outperformed ONNX.js when pinned to a single core so this feature would likely make tractjs consistently faster than ONNX.js (at least on input with batch sizes > 1).
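Splitting along the batch dimension works because, in a row-major buffer, each item of dimension 0 occupies one contiguous slice. The functions below are a hypothetical sketch (not tractjs code) of the split and re-join steps around such a feature; dispatching the chunks to workers (e.g. via wasm-bindgen and rayon, as the issue suggests) is left out, and batch dimension 0 is assumed, matching the proposed default.

```javascript
// Hypothetical sketch, not tractjs code. Assumes row-major data and batch dim 0.

// Split a batch into at most `numChunks` contiguous chunks of whole items.
function splitBatch(data, shape, numChunks) {
  const batch = shape[0];
  const perItem = shape.slice(1).reduce((acc, dim) => acc * dim, 1);
  const chunkSize = Math.ceil(batch / numChunks);
  const chunks = [];
  for (let start = 0; start < batch; start += chunkSize) {
    const end = Math.min(start + chunkSize, batch);
    chunks.push({
      // subarray shares memory with `data`; copy with .slice() before transferring
      data: data.subarray(start * perItem, end * perItem),
      shape: [end - start, ...shape.slice(1)],
    });
  }
  return chunks;
}

// Re-join per-chunk outputs (in order) into one batch tensor.
function concatBatch(chunks) {
  const total = chunks.reduce((acc, c) => acc + c.data.length, 0);
  const out = new Float32Array(total);
  let offset = 0;
  for (const c of chunks) {
    out.set(c.data, offset);
    offset += c.data.length;
  }
  const batch = chunks.reduce((acc, c) => acc + c.shape[0], 0);
  return { data: out, shape: [batch, ...chunks[0].shape.slice(1)] };
}

const input = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
const parts = splitBatch(input, [5, 2], 2); // chunk shapes [3, 2] and [2, 2]
console.log(parts.map((p) => p.shape));
```

Using `Math.ceil` keeps every chunk except possibly the last the same size, so work is spread roughly evenly across workers.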

Should the input facts in the example be [1, 3, 224, 224]?

I don't know much about input facts, but it seems like the standard shape referenced for the TensorFlow models is [1, 224, 224, 3] (see #5). However, the example in the README uses [1, 3, 224, 224]. This is more of a question, but if it is a typo, this issue could be used to update it. In my own tinkering, I needed to use [1, 224, 224, 3].
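The two shapes describe different memory layouts: NCHW (channels before height and width, as in the README's ONNX example) and NHWC (channels last, common for TensorFlow models). If your image data is in one layout and the model expects the other, both the input fact and the data have to change. The function below is a hypothetical sketch (not part of tractjs) of the data transpose, assuming row-major Float32Array buffers.

```javascript
// Hypothetical helper, not part of tractjs: transpose a row-major NCHW buffer to NHWC.
function nchwToNhwc(data, [n, c, h, w]) {
  const out = new Float32Array(n * h * w * c);
  for (let ni = 0; ni < n; ni++) {
    for (let ci = 0; ci < c; ci++) {
      for (let hi = 0; hi < h; hi++) {
        for (let wi = 0; wi < w; wi++) {
          const src = ((ni * c + ci) * h + hi) * w + wi; // offset in NCHW
          const dst = ((ni * h + hi) * w + wi) * c + ci; // offset in NHWC
          out[dst] = data[src];
        }
      }
    }
  }
  return { data: out, shape: [n, h, w, c] };
}

// One image, 2 channels, 1x2 pixels: channel 0 = [1, 2], channel 1 = [3, 4].
const { data, shape } = nchwToNhwc(new Float32Array([1, 2, 3, 4]), [1, 2, 1, 2]);
console.log(Array.from(data), shape); // pixels interleaved: [1, 3, 2, 4], shape [1, 1, 2, 2]
```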

Can't resolve 'worker_threads'

Bundling tools emit a warning:

WARNING in /Users/bminixhofer/Documents/Projects/tractjs/wrapper/dist/tractjs.esm.js
Module not found: Error: Can't resolve 'worker_threads' in '/Users/bminixhofer/Documents/Projects/tractjs/wrapper/dist'
 @ /Users/bminixhofer/Documents/Projects/tractjs/wrapper/dist/tractjs.esm.js
 @ ./bench.js

Doesn't seem to break anything but should be investigated.

Loading models outside the project

Hi Benjamin,

I'm a big fan of your tractjs project and am excited to use it to do ML on the web. For my particular use case, I want to load an arbitrary model which I don't want bundled into my code. In my experimenting, I have tried the following two approaches:

Load the model from the file system doing something like:

const model = new tractjs.Model('file:///Users/Bankster/Downloads/mobilenet_v2_1.4_224_frozen.pb');

This fails with an error:

Fetch API cannot load file:///Users/Bankster/Downloads/mobilenet_v2_1.4_224_frozen.pb. URL scheme "file" is not supported.

It makes sense that I wouldn't be able to use the fetch API to access local files, that would be a security issue for Chrome.

Serve the model as a static file from a web server and try to provide the endpoint to the constructor:

const model = new tractjs.Model('http://localhost:8080/models/mobilenet_v2_1.4_224_frozen.pb');

If I do this, I get the following error:

Error: TractError(
  Msg(
      "Translating node #0 \"input\" Source ToTypedTranslator",
  ),
  ...
)

My guess is that the model is being chunked and Tract doesn't expect that. I will keep exploring this second approach but was curious if you had any ideas on this matter.

Edit: I checked the response header and see Content-Length: 24508794 so it isn't a chunking issue. Maybe a MIME type or something...

What is the proper flow control for loading a model?

When trying to detect when a model is loaded (or failed to load), I have written my code as follows:

const model: Model = new Model('./tests/plus3.pb');
await model.modelId;

In the course of writing the unit test PR, I am seeing that this raises a TypeScript error:

Property 'modelId' is private and only accessible within class 'Model'.

This raises the question for me, is there a more appropriate way for detecting when the model loads? On the surface of it, this feels like a layer of indirection and it would be easier to do:

const model: Model = await new Model('./tests/plus3.pb');

If it is the most appropriate way, the private should be dropped from the modelId property.
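One common way to avoid exposing an internal readiness promise is an async factory: the constructor only runs once initialization has finished, and a static async function resolves to a ready instance or rejects with the load error. The `tractjs.load(...)` function in the current README has this promise-of-a-model shape. The class below is a hypothetical stand-in, not tractjs's implementation; `initialize` is a placeholder for the real loading work.

```javascript
// Hypothetical stand-in for a model wrapper; not tractjs's actual implementation.
class DemoModel {
  constructor(id) {
    this.id = id; // only ever constructed after initialization succeeded
  }

  // Placeholder for the real work (fetching bytes, handing them to a worker, ...).
  static async initialize(url) {
    if (!url.endsWith(".pb") && !url.endsWith(".onnx")) {
      throw new Error(`unsupported model file: ${url}`);
    }
    return 42; // pretend model id
  }

  // Async factory: resolves to a ready model or rejects with the load error.
  static async load(url) {
    const id = await DemoModel.initialize(url);
    return new DemoModel(id);
  }
}

async function main() {
  const model = await DemoModel.load("./tests/plus3.pb");
  console.log("loaded model", model.id); // loaded model 42
}

main();
```

Because callers only ever see a fully initialized object, there is nothing private left to await.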

CPU vs WASM

In tfjs and ONNX.js, you can set the backend to webgl, wasm or cpu. I saw that GPU is not supported by tractjs, so webgl is not an option, but I don't see a place where we can choose between wasm and cpu. Is there such an option, or is the backend always wasm?
Thanks

NNEF integration

tract now supports NNEF in sonos/tract#340.

  • Check size reduction by replacing ONNX and TF with NNEF and building the package.
  • If the size reduction is significant (@kali sees 25%!) we should create a separate package @tractjs/nnef for NNEF support. It follows that ONNX and TF should also be split this way into @tractjs/onnx and @tractjs/tf. See the way TF.js does splitting into packages.

This and #12 should get us below 1MB!

Move cypress to devDependencies

On a cursory look, it seems to be used for tests only, but it's getting pulled into production builds (yarn install --production) because it's in package.json's dependencies instead of devDependencies.

(The package by itself isn't particularly heavy, but upon installation it downloads a zip that expands to 700+ MB in ~/.cache/Cypress and inflates Docker images.)

Loading the plus3.pb from the Tract Repo

I'm trying to load the plus3 model from the Tract repo for my tests. Whereas it doesn't fail with an error, the model never successfully loads. Are you able to load this model in your example code?

https://github.com/snipsco/tract/blob/04b03d75801e221400bc6249b306a200517e35d8/tensorflow/tests/models/plus3.pb

I've tried:

optimize false

    const model: Model = new Model('./tests/plus3.pb', {
      optimize: false,
    });

input facts that I'm pretty sure are wrong:

    const model: Model = new Model('./tests/plus3.pb', {
      inputFacts: {
        0: ['float32', [3]],
      }
    });

and no input facts:

    const model: Model = new Model('./tests/plus3.pb');

Unable to catch model loading errors?

I am trying to handle the possibility of my model failing to load (for instance, if I supply the wrong input facts like last time). I noticed that the tractjs.Model() constructor returns an object like U {modelId: Promise}, so I figured I could catch this Promise failing like so:

const model = await new tractjs.Model('http://localhost:8080/models/mobilenet_v2_1.4_224_frozen.pb', {
  inputFacts: {
    0: ['float32', ['bogus', 'input', 'facts']],
  },
});
model.modelId
  .then(() => {
    console.log('we are good, hurray!')
  })
  .catch((e) => {
    console.log('boo... model didnt load');
  });

However I still get an uncaught exception and my catch block is not executed:

Uncaught (in promise) Error: TractError(
    Msg(
        "Failed analyse for node #177 \"MobilenetV2/Conv/Conv2D\" Conv",
    ),
    ...
)
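When the load function itself returns a promise (as `tractjs.load` does in the Getting started section), a failed load can be handled with an ordinary try/catch around the `await`, with no separate internal promise to watch. The sketch below uses a hypothetical `loadModel` that pretends the model's only input is rank 4 and rejects input facts of any other rank; it illustrates the control flow only and does not reproduce tract's shape analysis.

```javascript
// Hypothetical loader standing in for a real one; it pretends the model's only
// input is rank 4 and rejects input facts with any other rank.
async function loadModel(url, { inputFacts = {} } = {}) {
  for (const [index, [dtype, shape]] of Object.entries(inputFacts)) {
    if (shape.length !== 4) {
      throw new Error(`input ${index}: expected rank 4, got rank ${shape.length} (${dtype})`);
    }
  }
  return { url };
}

async function tryLoad(url, options) {
  try {
    const model = await loadModel(url, options);
    console.log("we are good, hurray!");
    return model;
  } catch (e) {
    // The rejection lands here because the awaited promise is the one that failed.
    console.log("boo... model didn't load:", e.message);
    return null;
  }
}

tryLoad("http://localhost:8080/models/mobilenet_v2_1.4_224_frozen.pb", {
  inputFacts: { 0: ["float32", ["bogus", "input", "facts"]] },
});
```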

Node.js support

There are currently 2 reasons tractjs doesn't work in Node.js:

  1. fetch doesn't work in Node, need something like cross-fetch to fix that (but I've had some problems setting it up with Rollup)
  2. The WebWorker API is not consistent between Node.js and browsers, apparently https://github.com/darionco/rollup-plugin-web-worker-loader has an option for Node.js as target platform but that doesn't seem to be working (?) for me.
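For the first point, one possible approach (distinct from cross-fetch) is to bypass fetch for local paths: read them from disk with Node's `fs/promises` `readFile`, and only use `fetch` for http(s) URLs, since, as the earlier issue shows, the browser's Fetch API rejects `file://` URLs. `loadModelBytes` is a hypothetical name, and this is a sketch, not how tractjs implements its Node.js support.

```javascript
// Hypothetical helper, not tractjs's implementation: get model bytes as a Uint8Array.
// http(s) URLs go through fetch (global in browsers and in Node.js 18+);
// anything else is treated as a file system path, which only works in Node.js.
async function loadModelBytes(location) {
  if (/^https?:\/\//i.test(location)) {
    const response = await fetch(location);
    if (!response.ok) {
      throw new Error(`HTTP ${response.status} while fetching ${location}`);
    }
    return new Uint8Array(await response.arrayBuffer());
  }
  // Dynamic import: the fs module is only loaded when this branch actually runs.
  const { readFile } = await import("fs/promises");
  return new Uint8Array(await readFile(location));
}
```

Note that bundlers may still try to resolve the dynamically imported module, which can produce warnings similar to the `worker_threads` one above.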

Don't inline Worker and WASM when building ESM

As mentioned in https://github.com/bminixhofer/tractjs/blob/master/README.md#what-about-size not inlining the Worker and WASM would reduce size. With the recent changes in the build system this should be much easier than before.

At the moment the WASM is actually doubly inlined, once in the worker.js, then the worker.js is again inlined in tractjs.js. The WASM itself is only 4.7MB (1.0MB gzipped) so fixing this would reduce the size by roughly 50%! But it is of course only possible for the ESM build, not for tractjs.min.js.

The CJS build for Node.js will also still be inlined (for now) since size is not as big of an issue for Node applications and I don't want to get into any file system issues.
