
onnxscript's Introduction

ONNX Script

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python. ONNX Script is:

  • Expressive: enables the authoring of all ONNX functions.
  • Simple and concise: function code is natural and simple.
  • Debuggable: supports eager-mode evaluation for a more delightful ONNX model debugging experience.

Note however that ONNX Script does not intend to support the entirety of the Python language.

Website: https://onnxscript.ai/

Design Overview

ONNX Script provides a few major capabilities for authoring and debugging ONNX models and functions:

  • A converter which translates a Python ONNX Script function into an ONNX graph, accomplished by traversing the Python Abstract Syntax Tree to build an ONNX graph equivalent of the function.

  • A converter that operates inversely, translating ONNX models and functions into ONNX Script. This capability can be used to fully round-trip ONNX Script ↔ ONNX graph.

  • A runtime shim that allows such functions to be evaluated (in an "eager mode"). This functionality currently relies on ONNX Runtime for executing every ONNX Operator, and there is a Python-only reference runtime for ONNX underway that will also be supported.

    Note that the runtime is intended to help understand and debug function definitions. Performance is not a goal here.
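The first capability can be roughly illustrated with Python's standard `ast` module. The sketch below walks a function's syntax tree and turns each `name = op.OpType(...)` statement into an `(op_type, inputs, output)` record, which is the raw material for ONNX nodes. It is a toy that rests on stated assumptions: straight-line code, positional name arguments only, and a hypothetical `gelu_like` function parsed from a string. It is not onnxscript's actual converter.

```python
import ast
import textwrap

# Hypothetical straight-line script function, parsed from a string for brevity.
SOURCE = textwrap.dedent('''
    def gelu_like(X):
        T1 = op.Mul(X, X)
        T2 = op.Tanh(T1)
        return op.Add(X, T2)
''')


def ast_to_nodes(source):
    """Translate each `name = op.OpType(args)` statement into a node record."""
    func = ast.parse(source).body[0]
    nodes = []
    for stmt in func.body:
        if isinstance(stmt, ast.Assign) and isinstance(stmt.targets[0], ast.Name):
            output, call = stmt.targets[0].id, stmt.value
        elif isinstance(stmt, ast.Return):
            output, call = "return_val", stmt.value
        else:
            raise ValueError(f"line {stmt.lineno}: unsupported statement")
        if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
            raise ValueError(f"line {stmt.lineno}: expected a call like op.OpType(...)")
        # Positional Name arguments become the node's inputs.
        inputs = [arg.id for arg in call.args if isinstance(arg, ast.Name)]
        nodes.append((call.func.attr, inputs, output))
    return nodes


for node in ast_to_nodes(SOURCE):
    print(node)
```

A real converter must also handle attributes, constants, control flow, and type annotations. This sketch stops at the node list.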

Installing ONNX Script

pip install --upgrade onnxscript

Install for Development

git clone https://github.com/microsoft/onnxscript
cd onnxscript
pip install -r requirements-dev.txt
pip install -e .

Run Unit Tests

pytest .

Example

import onnx

# We use ONNX opset 15 to define the function below.
from onnxscript import FLOAT, script
from onnxscript import opset15 as op


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def onnx_hardmax(X, axis: int):
    """Hardmax is similar to ArgMax, with the result being encoded OneHot style."""

    # X has no type annotation, so it is treated as an ONNX input
    # tensor whose type and shape are left unspecified. The type
    # annotation on axis indicates that it will be treated as an int
    # attribute in ONNX.
    #
    # Invoke ONNX opset 15 op ArgMax.
    # Use unnamed arguments for ONNX input parameters, and named
    # arguments for ONNX attribute parameters.
    argmax = op.ArgMax(X, axis=axis, keepdims=False)
    xshape = op.Shape(X, start=axis)
    # use the Constant operator to create constant tensors
    zero = op.Constant(value_ints=[0])
    depth = op.GatherElements(xshape, zero)
    empty_shape = op.Constant(value_ints=[0])
    depth = op.Reshape(depth, empty_shape)
    values = op.Constant(value_ints=[0, 1])
    cast_values = op.CastLike(values, X)
    return op.OneHot(argmax, depth, cast_values, axis=axis)


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
    matmul = op.MatMul(X, Wt) + Bias
    return onnx_hardmax(matmul, axis=1)


# onnx_model is an in-memory ModelProto
onnx_model = sample_model.to_model_proto()

# Save the ONNX model at a given path
onnx.save(onnx_model, "sample_model.onnx")

# Check the model
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print(f"The model is invalid: {e}")
else:
    print("The model is valid!")

The decorator parses the code of the function, converting it into an intermediate representation. If it fails, it produces an error message indicating the line where the error was detected. If it succeeds, the intermediate representation can be converted into an ONNX graph structure of type FunctionProto:

  • onnx_hardmax.to_function_proto() returns a FunctionProto

Eager Mode Evaluation

Eager mode is mostly used to debug and validate that intermediate results are as expected. The function defined above can be called as below, executing in an eager-evaluation mode:

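import numpy as np

v = np.array([[0, 1], [2, 3]], dtype=np.float32)
# axis has no default value, so it must be supplied; pass it by keyword
# because it is an ONNX attribute.
result = onnx_hardmax(v, axis=0)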
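A NumPy reference for Hardmax can act as a test oracle for the eager result. The helper `hardmax_reference` below is a hypothetical test aid and is not part of onnxscript. NumPy's `argmax` returns the first maximum on ties, as ONNX ArgMax does with its default `select_last_index=0`. Assuming eager mode follows ONNX semantics, the two results should therefore agree.

```python
import numpy as np


def hardmax_reference(x, axis):
    """One-hot encode the argmax of x along `axis`, keeping x's shape and dtype."""
    idx = np.expand_dims(np.argmax(x, axis=axis), axis)  # argmax, axis kept as size 1
    positions_shape = [1] * x.ndim
    positions_shape[axis] = x.shape[axis]
    positions = np.arange(x.shape[axis]).reshape(positions_shape)
    # Broadcasting compares every position along `axis` with the argmax index.
    return (idx == positions).astype(x.dtype)


v = np.array([[0, 1], [2, 3]], dtype=np.float32)
expected = hardmax_reference(v, axis=0)  # values [[0, 0], [1, 1]]
```

Comparing `expected` with the eager `result`, for example with `np.testing.assert_array_equal`, then checks that eager mode agrees with the reference.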

More examples can be found in the docs/examples directory.

Development Guidelines

Every change affecting the converter or eager evaluation must be unit tested with the OnnxScriptTestCase class, to ensure that both return the same results for the same inputs.

Coding Style

We use ruff, black, isort, mypy, and other tools to check code formatting, and we use lintrunner to run all linters. You can install the dependencies and initialize lintrunner with

pip install lintrunner lintrunner-adapters
lintrunner init

This will install lintrunner on your system and download all the necessary dependencies to run linters locally. If you want to see what lintrunner init will install, run lintrunner init --dry-run.

To lint local changes:

lintrunner

To format files:

lintrunner f

To lint all files:

lintrunner --all-files

Use --output oneline to produce a compact list of lint errors, useful when there are many errors to fix.

See all available options with lintrunner -h.

To read more about lintrunner, see the wiki. To update an existing linting rule or create a new one, modify .lintrunner.toml or create a new adapter following the examples in https://github.com/justinchuby/lintrunner-adapters.

Contributing

We're always looking for your help to improve the product (bug fixes, new features, documentation, etc.). ONNX Script is currently in early and heavy development, so we encourage you to propose any major change by first filing an issue to discuss your idea with the team.

Report a Security Issue

Please do not report security vulnerabilities through public GitHub issues.

Please refer to our guidance on filing Security Issues.

Licensing Guidelines

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com.

When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repositories using our CLA.

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.

onnxscript's People

Contributors

8bit-pixies, abhibyte, abock, bowenbao, dependabot[bot], doloresgarcia, er3x3, fatcat-z, gramalingam, jcwchen, justinchuby, liqunfu, luisfmnunes, maanavd, microsoftopensource, sdpython, shubhambhokare1, take-cheeze, titaiwangms, wschin, xadupre, xiaowuhu


onnxscript's Issues

Set up linters for the project

It helps if we set up linters early in the development process (fewer big PRs for fixes in the future). We may consider: mypy, pylint, black, isort, pydocstyle, flake8, bandit and xdoctest.

bug: tensor types are not recognized as types

Seen in #223

  1. "INT64" expects no type arguments, but 1 given [MYPY]

    def aten_ones(size: INT64[...]) -> TensorType:
  2. Unexpected "..." [MYPY]

    def aten_row_stack(tensors: TensorType[...]) -> TensorType:
  3. INT64 does not accept a default value #229

TensorType is not a type

It looks like TensorType is not a type. When I did

def LeakyRelu(input, negative_slope: FLOAT | float = 0.01, inplace: BOOL | bool = False):
    ...

I got

TypeError: unsupported operand type(s) for |: 'TensorType' and 'type'

Type `script()`

Because mypy reports: Untyped decorator makes function "Elu" untyped [misc] mypy(error)

Union return type is not valid

FloatType = FLOAT16 | FLOAT | DOUBLE

def aten_relu6(self: FloatType) -> FloatType:
    zero = op.CastLike(op.Constant(value_float=0.0), self)
    return op.Max(self, zero)

raises

        if fn.returns:
            returntype = self.eval_constant_expr(fn.returns)
            if isinstance(returntype, tuple):
                assert all(ta.is_valid(t) for t in returntype)
                self.returntype = returntype
            else:
>               assert ta.is_valid(returntype)
E               AssertionError

cc @gramalingam

Tensor typing annotations

  1. Generics

    • torch
      • alpha_dropout(Tensor(a!) input, float p, bool train) -> Tensor(a!)
    • onnxscript
      • def alpha_dropout(input: TensorType[T, ...], p: float, train: bool) -> TensorType[T, ...]
  2. Dims

    • torch
      • reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor
    • onnxscript
      • def reflection_pad3d(self: TensorType[Any, ...], padding: TensorType[INT64, [6]]) -> TensorType[Any, ...]

Tracking: ATen lib exploration

@justinchuby is using this issue to track experiments

  • conditional logic involving strings
    • (What does the graph look like?)
  • tensor as kwargs (default for onnx-inputs?)
    • optional tensor inputs
  • return in if
  • generic tensor types: Aaron's PR
    • union type annotations
  • operation on lists and tuples
  • script vs non-script
  • is a tensor always backed by data, or can it be symbolic? Interoperability with numpy (__array__)
  • composite subgraphs, aten if
  • Function calls function. Are they both in the graph? How to reuse code?
    • Namespacing support for functions
    • Function calling non-script functions
    • Graph builder: consider evaluator
    • If I call to function proto in the caller, will I get callee as well?
  • input validation. What do we need?
  • Data type dependent logic
    • dtype reconciliation and promotion => need to be done before calling the function.
    • quantized ops
  • function inlining
  • IR builder
    • how does it handle shape type inference when building a model proto?

Maintain consistency on how args and kwargs are handled

Currently in eager mode, an int passed as a positional argument is treated as a tensor, while one passed as a keyword argument is treated as a raw scalar. This behavior can confuse users.

It can also make programmatic calls trickier, because tensors cannot be supplied by name.

[Parser] Support `del`

When an argument is not used, I would like to delete it, as good Python code does. The arguments cannot be removed if we want to keep a good correspondence between aten signatures and the lib.

I assume onnxscript can just do nothing for del?

@onnxscript.script()
def Elu(
    self,
    alpha: float = 1.0,
    scale: float = 1.0,
    input_scale: float = 1.0,
):
    del scale    # Right here, it is now ValueError: ERROR None:10    -- line:     del input_scale
    del input_scale
    return op.Elu(self, alpha=alpha)
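The converter could treat `del` as a no-op by running an AST pre-pass that drops it before conversion. The `DropDel` transformer below is a hypothetical sketch and is not part of onnxscript. It removes `del` statements whose targets are all plain names. Other forms, such as `del x[0]`, are kept so the converter can still reject them.

```python
import ast
import textwrap


class DropDel(ast.NodeTransformer):
    """Remove `del name, ...` statements; they have no ONNX counterpart."""

    def visit_Delete(self, node):
        # Deleting plain names only ends a Python binding, so it is safe to drop.
        # Anything else (e.g. `del x[0]`) is kept for the converter to reject.
        if all(isinstance(target, ast.Name) for target in node.targets):
            return None  # returning None removes the statement from the body
        return node


SOURCE = textwrap.dedent('''
    def Elu(self, alpha=1.0, scale=1.0, input_scale=1.0):
        del scale
        del input_scale
        return op.Elu(self, alpha=alpha)
''')

tree = ast.fix_missing_locations(DropDel().visit(ast.parse(SOURCE)))
print([type(stmt).__name__ for stmt in tree.body[0].body])  # ['Return']
```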

Avoid relative imports; import only modules

  1. Even though relative imports are prevalent in ort, they are confusing, harder to manage and refactor. We should prefer absolute imports for clarity and robustness.
    https://softwareengineering.stackexchange.com/questions/159503/whats-wrong-with-relative-imports-in-python
    https://google.github.io/styleguide/pyguide.html#22-imports
  2. Import only modules (in most cases that is not __init__).
    Import only the module instead of classes and functions to (1) keep the namespace clean and provide readers with more context on where its members come from, and (2) prevent circular import errors.
    • Prevent circular import errors: Programming FAQ (Python 3.10.4 documentation)
      Circular imports are fine where both modules use the “import <module>” form of import. They fail when the 2nd module wants to grab a name out of the first (“from module import name”) and the import is at the top level. That's because names in the 1st are not yet available, because the first module is busy importing the 2nd.

    • Clean namespace: For example, readers don't need to backtrack to see sleep is a function from time, as opposed to a function defined in the file. https://google.github.io/styleguide/pyguide.html#22-imports

Create eager-mode evaluation of onnxscript functions

(a) We would like to be able to run onnxscript functions by calling ORT to execute each primitive op. This will allow us to debug function definitions more easily.

As an example, consider the following onnxscript code fragment, which defines a gemm followed by gelu in terms of more primitive operations.

import oxs

def gemmgelu(
        A: FLOAT[2048, 16],
        W: FLOAT[16, 4096],
        Bias: FLOAT[4096]
) -> FLOAT[2048, 4096]:

    a = oxs.Constant(value_float=0.5)
    b = oxs.Constant(value_float=0.797885)
    c = oxs.Constant(value_float=0.035677)
    one = oxs.Constant(value_float=1.0)
    P1 = oxs.MatMul(A, W)
    X = oxs.Add(P1, Bias)
    T1 = oxs.Mul(X, X)
    T2 = oxs.Mul(c, T1)
    T3 = oxs.Add(b, T2)
    T4 = oxs.Mul(X, T3)
    T5 = oxs.Tanh(T4)
    T6 = oxs.Add(one, T5)
    T7 = oxs.Mul(X, T6)
    Y = oxs.Mul(a, T7)
    return Y

print(gemmgelu(a, w, b))

We would like to execute this in a standard python debugger, by calling on ORT kernels to execute each op.

Finally, we would also like to be able to use test-cases defined for the function-op (eg., from onnx) to check correctness of function definition.
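The eager-evaluation idea can be sketched without ORT. In the hypothetical stand-in below, `oxs` is a `types.SimpleNamespace` whose ops run directly on NumPy arrays. This lets `gemmgelu` execute line by line and be stepped through in a debugger. A NumPy implementation of the tanh approximation of GELU then serves as the test oracle. The constants 0.797885 and 0.035677 are sqrt(2/pi) and sqrt(2/pi) * 0.044715, rounded. The shapes are shrunk from the issue's 2048 x 16 x 4096 so the example runs quickly.

```python
import math
import types

import numpy as np

# Hypothetical stand-in for `oxs`: every op executes immediately on NumPy arrays
# (where the issue proposes calling ORT kernels), so each line can be debugged.
oxs = types.SimpleNamespace(
    Constant=lambda value_float: np.array(value_float, dtype=np.float32),
    MatMul=np.matmul,
    Add=np.add,
    Mul=np.multiply,
    Tanh=np.tanh,
)


def gemmgelu(A, W, Bias):
    # Same op sequence as the example in the issue.
    a = oxs.Constant(value_float=0.5)
    b = oxs.Constant(value_float=0.797885)
    c = oxs.Constant(value_float=0.035677)
    one = oxs.Constant(value_float=1.0)
    P1 = oxs.MatMul(A, W)
    X = oxs.Add(P1, Bias)
    T1 = oxs.Mul(X, X)
    T2 = oxs.Mul(c, T1)
    T3 = oxs.Add(b, T2)
    T4 = oxs.Mul(X, T3)  # b*X + c*X**3
    T5 = oxs.Tanh(T4)
    T6 = oxs.Add(one, T5)
    T7 = oxs.Mul(X, T6)
    Y = oxs.Mul(a, T7)
    return Y


def gelu_tanh_reference(x):
    """Tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))."""
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 16)).astype(np.float32)
W = rng.standard_normal((16, 8)).astype(np.float32)
Bias = rng.standard_normal(8).astype(np.float32)

Y = gemmgelu(A, W, Bias)
expected = gelu_tanh_reference(A @ W + Bias)
print(Y.shape, np.allclose(Y, expected, rtol=1e-4, atol=1e-5))
```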

[Feature Request] Return in if

I created this toy example

@script()
def FuncWithStr(input, negative_slope: float = 0.01, mode: str = "foo"):
    if mode == "foo":
        return input
    else:
        zero = op.CastLike(0, input)
        negative_slope = op.CastLike(negative_slope, input)
        return op.Where(input < zero, negative_slope * input, input)

which I think is common for pytorch functions.

I get

ValueError: ERROR
thenGraph_3:4    -- line:         return input
    Return statements are not permitted inside control-flow statements.

Should I break the logic out to a non-script function and conditionally call two different onnx functions for the two cases instead?

Create a "latest_opset"

It looks like the default opset is opset 14. It would be helpful to have a latest_opset.

Model-local functions when creating a ModelProto from a python function

When we export/convert a Python function to a ModelProto, we need to identify the set of functions that will be included in the generated ModelProto as model-local functions. Furthermore, users should be able to control this. For example, I might want to create a model that calls Relu, with or without including the function definition for Relu (even though we might have a function definition for Relu available).

(See PR: #41 )
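Here is a sketch of one possible policy. It walks the call graph depth-first from the entry function and emits each callee once, callees before callers. It skips any function the user asks to leave out. `Function` and `collect_local_functions` are hypothetical names. A real implementation would operate on FunctionProtos or on onnxscript's own function objects.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Function:
    """Hypothetical stand-in for a scripted function and the functions it calls."""

    name: str
    callees: List["Function"] = field(default_factory=list)


def collect_local_functions(entry, exclude=()):
    """Names of functions to embed as model-local functions.

    Callees are listed before their callers and each function appears once.
    Functions named in `exclude` are not embedded, and neither is anything
    reachable only through them.
    """
    seen, order = set(), []

    def visit(fn):
        if fn.name in seen or fn.name in exclude:
            return
        seen.add(fn.name)
        for callee in fn.callees:
            visit(callee)
        order.append(fn.name)

    # The entry function becomes the main graph, so only its callees are visited.
    for callee in entry.callees:
        visit(callee)
    return order


relu = Function("Relu")
layer = Function("Layer", [relu, Function("Hardmax")])
model = Function("Model", [layer, relu])
print(collect_local_functions(model))                    # ['Relu', 'Hardmax', 'Layer']
print(collect_local_functions(model, exclude={"Relu"}))  # ['Hardmax', 'Layer']
```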

[Converter support] Support creating ONNX dialect in other IRs

onnx script is going to be great for the conversion process in pytorch. The torch exporter uses the onnx dialect in TorchScript, before going through a few additional passes and eventually generating an onnx proto. (So we cannot create onnx protos directly)

It would be very helpful to be able to delegate the graph building process to another object / entity. For example, we can create a wrapper around torch's graph.op method so that each graph building call is delegated to graph.op, allowing it to build a torch script graph.

One way of doing this can be exposing the graph building APIs so we don't need to rely on @script and the source code for constructing the graph.

cc @BowenBao

Make `op.Cast` a type narrowing function?

def aten_elu__int(
    self: IntType,
    alpha: float = 1.0,
    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
) -> TensorType:
    return op.Elu(op.Cast(self, to=onnxscript.FLOAT), alpha=alpha)

pylance will complain:

Argument of type "BFLOAT16 | BOOL | DOUBLE | FLOAT | FLOAT16 | INT16 | INT32 | INT64 | INT8 | STRING | UINT16 | UINT32 | UINT64 | UINT8" cannot be assigned to parameter "X" of type "DOUBLE | FLOAT | FLOAT16" in function "Elu"
  Type "BFLOAT16 | BOOL | DOUBLE | FLOAT | FLOAT16 | INT16 | INT32 | INT64 | INT8 | STRING | UINT16 | UINT32 | UINT64 | UINT8" cannot be assigned to type "DOUBLE | FLOAT | FLOAT16"
    Type "BFLOAT16" cannot be assigned to type "DOUBLE | FLOAT | FLOAT16"
      "BFLOAT16" is incompatible with "DOUBLE"
      "BFLOAT16" is incompatible with "FLOAT"
      "BFLOAT16" is incompatible with "FLOAT16"

Making `to=` take a generic and making Cast return a TypeGuard may help.

https://mypy.readthedocs.io/en/stable/type_narrowing.html#user-defined-type-guards

cc @abock
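A note on the suggestion above. A `TypeGuard` (PEP 647) narrows types only through a predicate function that returns `bool`. For a function like `Cast`, which returns a new value, a generic return type is the more direct tool. The sketch below shows this with hypothetical stand-in classes, not onnxscript's real tensor types. Because `to` is annotated as `Type[T]`, a type checker infers `Cast(x, to=FLOAT)` as `FLOAT`, which satisfies Elu's `DOUBLE | FLOAT | FLOAT16` parameter.

```python
from typing import Type, TypeVar, Union


class TensorType:
    """Hypothetical base class standing in for onnxscript tensor types."""


class FLOAT(TensorType): ...
class FLOAT16(TensorType): ...
class DOUBLE(TensorType): ...
class INT64(TensorType): ...


T = TypeVar("T", bound=TensorType)


def Cast(x: TensorType, to: Type[T]) -> T:
    """Generic over `to`: a type checker infers the result type from the argument."""
    return to()  # placeholder; a real Cast would convert the data


def Elu(X: Union[DOUBLE, FLOAT, FLOAT16], alpha: float = 1.0) -> Union[DOUBLE, FLOAT, FLOAT16]:
    return X  # placeholder body


def aten_elu_int(self: INT64, alpha: float = 1.0) -> TensorType:
    # Cast(self, to=FLOAT) is inferred as FLOAT, so this call type-checks.
    return Elu(Cast(self, to=FLOAT), alpha=alpha)
```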

lintrunner tweaks

Creating a pile-on issue for problems with lintrunner.

  • The regex-based linters are wildly slow, and what's more, they result in reading and writing the same set of files many times over. Ideally all the operations should be performed in a single pass across the codebase.
    • Specifically the SPACES (#214) linter takes many minutes to complete when running against the statically generated code. This linter should just trim lines... no grep necessary.
    • We can use a checker to check on .editorconfig directly #227
  • It seems lintrunner does not work appropriately on Windows... this is a barrier to any dev working in Windows. WSL should not be a workaround.
    • the flake8 linter does not find the .flake8 config file when running in Windows, black-isort the same. We need to figure out if that is because lintrunner doesn't set cwd properly on Windows justinchuby/lintrunner-adapters#29
    • We need to port all the grep linters to Windows because there is no grep. Update: all grep linters have been removed.
    • suo/lintrunner#22 suo/lintrunner#23
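Trimming trailing whitespace needs no grep: one read and at most one write per file is enough. The sketch below is a hypothetical stand-alone helper, not a lintrunner adapter. Opening files with `newline=""` keeps CRLF line endings intact, which is relevant to the Windows problems listed above.

```python
import os
import re
import tempfile

# Spaces/tabs at the end of a line, before an optional "\r" (CRLF files).
_TRAILING_WS = re.compile(r"[ \t]+(?=\r?$)", re.MULTILINE)


def trim_trailing_whitespace(path):
    """Trim trailing whitespace with one read and at most one write.

    Returns True if the file changed. newline="" keeps line endings as-is.
    """
    with open(path, encoding="utf-8", newline="") as f:
        text = f.read()
    trimmed = _TRAILING_WS.sub("", text)
    if trimmed == text:
        return False
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write(trimmed)
    return True


with tempfile.TemporaryDirectory() as tmp:
    sample = os.path.join(tmp, "sample.py")
    with open(sample, "w", encoding="utf-8", newline="") as f:
        f.write("x = 1   \nprint(x)\t\r\n")
    changed = trim_trailing_whitespace(sample)
    with open(sample, encoding="utf-8", newline="") as f:
        result = f.read()
    unchanged_second_pass = not trim_trailing_whitespace(sample)
```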

Improve error messages when a function does not return

I created a function and forgot to add a return statement

def aten_lt(self, other):
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    op.Less(self, other)

The message is

onnxscript/main.py:105: in transform
    result = script_check(ast, opset, env, src, default_opset=default_opset)
onnxscript/main.py:53: in script_check
    return convert.top_level_stmt(f)
onnxscript/converter.py:1337: in top_level_stmt
    analysis.do_liveness_analysis(stmt, self.message)
onnxscript/analysis.py:157: in do_liveness_analysis
    live = visit(s, live)
onnxscript/analysis.py:98: in visit
    live = do_visit(stmt, live_out)
onnxscript/analysis.py:152: in do_visit
    raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))
onnxscript/converter.py:235: in message
    return self.source_of(node).msg(error_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <onnxscript.converter.Converter object at 0x7f257b1f5e70>, node = <ast.Expr object at 0x7f257b1f5d80>

    def source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
>       return sourceinfo.SourceInfo(node, self.source, self.current_fn.name)
E       AttributeError: 'NoneType' object has no attribute 'name'

onnxscript/converter.py:231: AttributeError
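A dedicated check that runs before liveness analysis could produce a clearer diagnostic. The sketch below is hypothetical; `check_returns` is not an onnxscript function. Return statements are not permitted inside control-flow statements, so the sketch simply requires each function body to end with a `return`. If it does not, the check reports the line of the last statement.

```python
import ast
import textwrap


def check_returns(source):
    """Report top-level functions whose body does not end with `return`."""
    errors = []
    for node in ast.parse(source).body:
        if isinstance(node, ast.FunctionDef) and not isinstance(node.body[-1], ast.Return):
            last = node.body[-1]
            errors.append(
                f"{node.name}: line {last.lineno}: "
                "function does not end with a return statement"
            )
    return errors


SOURCE = textwrap.dedent('''
    def aten_lt(self, other):
        # lt.Tensor(Tensor self, Tensor other) -> Tensor
        op.Less(self, other)
''')
print(check_returns(SOURCE))
# ['aten_lt: line 4: function does not end with a return statement']
```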

Support Annotated types

We would like to annotate the types with typing.Annotated (https://docs.python.org/3/library/typing.html#typing.Annotated) to include run time value checking for the ATen lib. Annotated should only be used for attributes.

@onnxscript.script()
@atenop("aten::elu")
def Elu(
    self,
    alpha: float = 1.0,
    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,   # Here, just need to take the float type out
    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
):
    # del scale
    # del input_scale
    return op.Elu(self, alpha=alpha)
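The hypothetical sketch below covers both halves of this request. `base_type` strips `Annotated[...]` down to the plain type that the converter needs. `check_attributes` runs the `Is[...]` predicates against the attribute values, including defaults, at run time. The `Is` class here is a minimal stand-in written for this example. `Elu` returns its input as a placeholder instead of calling `op.Elu`.

```python
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints


class Is:
    """Hypothetical value constraint used as Annotated metadata: Is[predicate]."""

    def __init__(self, predicate):
        self.predicate = predicate

    def __class_getitem__(cls, predicate):
        return cls(predicate)


def base_type(hint):
    """Strip Annotated[...] down to the underlying type, e.g. float."""
    return get_args(hint)[0] if get_origin(hint) is Annotated else hint


def check_attributes(fn, **attributes):
    """Raise ValueError if an attribute (or its default) fails an Is[...] check."""
    hints = get_type_hints(fn, include_extras=True)
    bound = inspect.signature(fn).bind_partial(**attributes)
    bound.apply_defaults()
    for name, value in bound.arguments.items():
        for meta in getattr(hints.get(name), "__metadata__", ()):
            if isinstance(meta, Is) and not meta.predicate(value):
                raise ValueError(f"{fn.__name__}: unsupported value {name}={value!r}")


def Elu(
    self,
    alpha: float = 1.0,
    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
):
    return self  # placeholder for op.Elu(self, alpha=alpha)
```

For example, `check_attributes(Elu, alpha=0.5)` passes, while `check_attributes(Elu, scale=2.0)` raises `ValueError`.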

ATen/Torch Op Library in ONNX Script Design

WIP in https://microsoft.sharepoint.com/:w:/t/ONNX2/EcUaSHDlDiBFvGGX5BC49Z0B0mHhO7s_6uLeVrOoDE4n2w?e=ZXXIgy. Will move here when stable.

This design doc reflects thinking until Jan 2023. The design may have evolved since then, especially around what we call a "GraphBuilder" in this doc; some assumptions may no longer be accurate; but it should capture the gist for the torch_lib.

Created: December 2022
Updated: January 2023
Authors: @justinchuby , @fatcat-z , @xiaowuhu

Objective

This document aims to provide a design of the ATen Op Library in ONNX
Script and its integration with the new torch.fx exporter.

Motivation and context

https://github.com/microsoft/onnx-converters-private/issues/124

Goals and non-Goals

Goals

  • Create a library of aten ops expressed in onnx-script independent of
    the PyTorch -> ONNX exporter, with focus on creating a delightful
    dev experience. The op library will be used in the latest version of
    torch.onnx exporter.
  • Design mechanisms the torch.onnx exporter can leverage to produce an
    ONNX model using the function library.
  • Representation (in ONNX proto), data structure and interfaces (as
    part of onnx-script library) for preserving metadata around how the
    graph is manipulated to facilitate downstream debugging uses.
  • Extension
    • Designed to provide multi exporter support: the same techniques
      should also support the sklearn model converter.
    • Design a consistent and delightful experience for users who want
      to write plain onnxscript to create ML models.

Non-goals

  • Reach high coverage for aten ops
    • The focus should be laying the foundation for the dev
      experience, identifying gaps in onnx-script to enable rapid
      parallel work.
    • Some coverage should still be achieved, prioritizing ops that are
      representative of the authoring experience, commonly used, and
      able to help us identify gaps in onnx-script.
    • We need to identify ops needed for the dynamo demo and
      prioritize those.
  • Custom ONNX ops
    • The ATen lib aims to use ONNX native operators. Support for
      custom operators is currently out of scope.
  • Multi ONNX opset version support
    • We design opset conversion to be the responsibility of
      downstream optimization in order to reduce complexity in both
      the exporter and the function lib. (See assumptions and Open
      questions below)
    • Only the latest opset version is supported, and the function lib
      will continue to track the latest opset version
  • Shape inference and Graph optimization
    • Shape inference is handled by onnx and/or exporters.

Design principles

  • Robustness, accuracy, dev experience.
    • Functions created should have a comprehensive test for
      robustness and accuracy.
    • The development experience should be speedy and easy.
    • Maintain invariants across components
  • Flip and switch migration
    • The new exporter should only need to create bridges at a scale
      much smaller than the existing symbolic functions for it to
      leverage the ATen library.
  • Delightful user experience
    • Failures are clear and actionable.
  • Test first development process
  • Parallelizable work
  • Meta information preserving for downstream tasks
  • Function specs stay with the implementation in a single place.
  • Fail early and fast
    • Raise errors as early as possible

Architecture

We design the onnxscript ATen library and its integration with exporters
into three major components: the function library itself, the
GraphBuilder, and the exporter side logic.

The function library

The function library is a collection of ONNX functions written in
onnxscript. As such, it only contains logic representable by an ONNX
function proto. Each function matches, as closely as possible, the
signature of a torch.ops.aten operator. The function decomposes the ATen
operator with one or more ONNX operators.

A function has two roles, (1) decompose the ATen op to ONNX IR, and
(2) specify the requirements for inputs (expected input types etc.).
(2) is necessary because ONNX function cannot handle dtype dependent
logic. We will need additional code around the functions to bridge the
inputs and/or dispatch to different function overloads on the exporter
side. Requirements for inputs serve as meta information that can be
leveraged by the exporter for it to decide which function to use.

Based on the constraints of ONNX functions and principles of separation
of responsibilities, a function does not verify inputs or
handle/emit any errors. It should be regarded as a description of
the decomposition logic (which gets translated into a proto). Components
that leverage the functions are responsible for verifying inputs.

All functions will be tested by PyTorch's OpInfo database, using
onnxscript's eager mode, to make sure the numerical output matches that
of the ATen operator.

The function library additionally provides a default op-name-to-function
mapping for lookup in the exporter.
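Below is a minimal sketch of such a mapping, combined with the input requirements described above. Each registration records the dtypes an overload accepts. A dispatcher then picks the first overload whose requirements match. `register`, `dispatch`, the dtype strings, and the two Elu overloads are all hypothetical. The overloads return labels in place of real ONNX functions.

```python
from typing import Callable, Dict, FrozenSet, List, Tuple

# Hypothetical registry: ATen op name -> [(accepted input dtypes, function), ...]
_REGISTRY: Dict[str, List[Tuple[FrozenSet[str], Callable]]] = {}


def register(op_name, dtypes):
    """Record a function as an overload of `op_name` accepting `dtypes`."""

    def decorator(fn):
        _REGISTRY.setdefault(op_name, []).append((frozenset(dtypes), fn))
        return fn

    return decorator


@register("aten::elu", dtypes={"float16", "float32", "float64"})
def aten_elu(self, alpha=1.0):
    return "Elu"  # stands in for an ONNX function


@register("aten::elu", dtypes={"int32", "int64"})
def aten_elu_int(self, alpha=1.0):
    return "Cast + Elu"  # stands in for an ONNX function


def dispatch(op_name, dtype):
    """Return the first overload whose input requirements accept `dtype`."""
    for accepted, fn in _REGISTRY.get(op_name, []):
        if dtype in accepted:
            return fn
    raise KeyError(f"no overload of {op_name} accepts dtype {dtype!r}")


print(dispatch("aten::elu", "int64").__name__)  # aten_elu_int
```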

Graph Builder <or just Graph, with good APIs>

Having all the functions built, we still need a mechanism to synthesize
them into an onnx model graph. In the current version of torch.onnx
exporter, we use torchscript as the IR for the graph. As we envision the
new exporter, it makes sense for all onnx-related logic to be handled by
onnx. The proposed Graph Builder component will replace torchscript to
store the graph being built, and provide limited but necessary graph
manipulation (TBD, e.g. graph stitching) capabilities. Its main
responsibilities are:

  1. Maintain one, or potentially more, representations of ONNX (sub)
    graphs and (TBD, the relationship between framework native
    input/output/operator representation and its internal
    representation)

  2. Provide an API to acquire the operator information from the
    exporter, as well as inputs/outputs of the whole graph.

  3. Define the operator information needed and a protocol for the
    exporter to supply the information. The protocol should define
    traits that will be implemented by the exporter.

  4. Provide a protocol for the exporter to supply stack traces and
    diagnostics information and preserve them through the exported model
    in a sensible representation (most likely SARIF).

  5. Provide the capability of executing any (sub)graph being built for
    debuggability, such that users can e.g. examine any intermediate
    results or run dissect algorithms on the partial graphs and
    programmatically compare results.

  6. Serialize graphs into ONNX model proto.

  7. Build in guards for catching graph errors during graph
    construction and provide useful error messages, leveraging
    onnxscript/onnx capabilities and input requirements specified by
    individual ONNX functions.

  8. Provide a general mechanism to insert glue nodes to bridge dtype
    discrepancy between input tensors and what functions expect, based
    on the function input requirements.

  9. Ensure the produced onnx model is valid, and produce diagnosable
    errors as early as it can.

  10. (TBD) Maintain a Python source representation of the model when a
    model cannot be represented solely by onnx

Graph optimization and transformations are out of scope in this design.

Graph building experience

To provide a consistent experience for both exporter developers and
onnxscript authors, we propose to use the same interface as onnxscript
ops to produce graph by implementing a graph capturing Evaluator (which
internally talks to the GraphBuilder).

The following example builds an ONNX graph:

graphbuilder = GraphBuilder()
a = graphbuilder.input(dtype=..., shape=[2, 3])
b = graphbuilder.input(dtype=..., shape=...)
c = op.Mul(a, b)
d = aten_lib.ops.elu(c) # A function from the function library
print(a.rank()) # Python logic that is not possible in onnx functions
graphbuilder.set_output(c, d)

This way, exporter developers and other users do not have to use the
lower-level graph manipulation APIs provided by the GraphBuilder, such
as make_node.
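The graph-capturing Evaluator idea can be sketched in a few lines of plain Python. In the sketch below, `GraphBuilder`, `GraphCapturingOps`, `Value`, and `Node` are hypothetical and do not reflect onnxscript's API. Calling `op.Mul(a, b)` does not compute anything; it appends a node to the builder and returns a symbolic value. `op.Elu(c)` stands in for the function-library call in the example above.

```python
import itertools
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Value:
    name: str
    shape: Optional[List[int]] = None


@dataclass
class Node:
    op_type: str
    inputs: List[str]
    output: str


class GraphBuilder:
    """Hypothetical builder that records ONNX-like nodes instead of computing."""

    def __init__(self):
        self.inputs: List[Value] = []
        self.nodes: List[Node] = []
        self.outputs: List[str] = []
        self._ids = itertools.count()

    def input(self, shape):
        value = Value(f"input_{len(self.inputs)}", shape)
        self.inputs.append(value)
        return value

    def record(self, op_type, *args):
        output = Value(f"{op_type.lower()}_{next(self._ids)}")
        self.nodes.append(Node(op_type, [arg.name for arg in args], output.name))
        return output

    def set_output(self, *values):
        self.outputs = [value.name for value in values]


class GraphCapturingOps:
    """Hypothetical evaluator: `op.Mul(a, b)` appends a Mul node to the builder."""

    def __init__(self, builder):
        self._builder = builder

    def __getattr__(self, op_type):
        return lambda *args: self._builder.record(op_type, *args)


graphbuilder = GraphBuilder()
op = GraphCapturingOps(graphbuilder)
a = graphbuilder.input(shape=[2, 3])
b = graphbuilder.input(shape=[2, 3])
c = op.Mul(a, b)
d = op.Elu(c)        # stands in for aten_lib.ops.elu(c)
rank = len(a.shape)  # ordinary Python logic runs while the graph is built
graphbuilder.set_output(c, d)
```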

Communication protocols

TODO

Correctness validation

TODO

Exporter

The exporter is responsible for capturing a computation graph and
transforming it into a composition of ATen ops. It should also

  1. Capture and transform the computation graph to ATen level ops.
  2. Provide the necessary tensor/op information for graph construction
    by storing them according to the protocol GraphBuilder provides.
  3. Leverage the op to function mapping and decide which function to use
    to build the graph. In particular, the dtype dependent logic is
    controlled by the exporter so that any changes made by PyTorch can
    be directly reflected in the exporter, without the need to update
    a dependency (GraphBuilder).
  4. Provide the stack traces based on the protocol.
    Exporters should not need to create/manipulate onnx nodes in most cases
    except when dealing with complex conditional logic.

Interaction among components

We envision the export process roughly as the following.

  1. Exporter captures and transforms the computation graph to
    eliminate inplace and "out" ops and inline/convert most ops to
    aten, as well as adding explicit type promotion logic.
  2. Exporter walks through the framework native graph (e.g. fx
    graph)
  3. For each node, Exporter converts the input/output information
    into a format defined by the protocol provided by GraphBuilder.
  4. Exporter dispatches to Logic to select functions. This is
    usually programmatic and in rare cases symbolic function style. The
    graph is created by calling onnxscript ops in the Logic, which
    internally communicates with GraphBuilder to record the graph.
    GraphBuilder also records the output name of the ATen op to
    associate it with the framework native output (from the exporter) so
    that it can be uniquely identified by GraphBuilder in the ONNX
    graph.
  5. GraphBuilder creates and validates ONNX nodes based on the
    supplied info, which includes the code stacks information.
  6. When the graph is translated, Exporter calls GraphBuilder to
    make a model proto.
  7. GraphBuilder makes a model and appends all the function protos
    used, deduplicated, onto the model.
  8. Exporter saves the model on disk.

Design considerations

User experience

A core goal of this design is to provide a delightful experience. We
consider the experience from perspectives of (1) onnx-script Aten
library authors (2) Exporter developers (3) Exporter users.

Authoring experience

A delightful experience should include

  1. A speedy authoring experience enabled by onnxscript and its
    typing/intellisense support
  2. Easy tests enabled by PyTorch OpInfo
  3. Debuggability in functions using eager evaluation
  4. A robust input validation mechanism
  5. Easy-to-navigate code
    a. Naming: functions are easy to locate

Contribution workflow

TODO

Exporter developers

  • Intuitive APIs
    • Clear contracts with protocols when interacting with the graph.
  • Debuggability and diagnosability
    • Debugger support. Users should be able to use pdb and set
      breakpoints.
    • Errors should point to where they happen and provide
      actionable information.

Exporter users

  • Node history traces (be able to answer: how did this node come
    about?)

Correctness: Code gen signatures

We generate the function signatures from PyTorch's native_functions.yaml
to ensure that they are correct. This will serve as a one-time tool to
get us started. Keeping the tool running to update the implementations
can be more work than desired and is not planned.

Example generated signatures: #223
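To make the idea concrete, here is a simplified sketch of turning one schema line into a function stub. It only handles flat schemas such as the `add.Tensor` one quoted later in this document; the type mapping, the `aten_` prefix and the `TensorType` return annotation are assumptions, and a real generator would need a full schema parser (splitting on commas breaks for list defaults like `[1, 1]`).

```python
import re

# Assumed mapping from a few ATen schema types to annotation names; anything
# else (Tensor?, int[], Tensor(a!), ...) falls back to "Any".
_TYPE_MAP = {"Tensor": "TensorType", "Scalar": "float", "int": "int",
             "float": "float", "bool": "bool"}

_SCHEMA_RE = re.compile(r"(\w+)(?:\.(\w+))?\((.*)\)\s*->\s*(.+)")


def schema_to_stub(schema: str) -> str:
    """Turn a flat ATen schema string into a Python function stub (simplified)."""
    match = _SCHEMA_RE.fullmatch(schema.strip())
    if match is None:
        raise ValueError(f"Unrecognized schema: {schema!r}")
    name, _overload, arg_text, _returns = match.groups()
    params = []
    for raw in (part.strip() for part in arg_text.split(",")):
        if not raw:
            continue
        if raw == "*":
            params.append("*")  # keyword-only marker carries over unchanged
            continue
        type_and_name, _, default = raw.partition("=")
        arg_type, arg_name = type_and_name.rsplit(" ", 1)
        param = f"{arg_name}: {_TYPE_MAP.get(arg_type, 'Any')}"
        if default:
            param += f" = {default}"
        params.append(param)
    return f"def aten_{name}({', '.join(params)}) -> TensorType:\n    ..."
```

For `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor` this produces a stub whose parameters are `self: TensorType, other: TensorType, *, alpha: float = 1`.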

Correctness > potential model performance

To support graph conversion in the exporter, we focus on ensuring that
the model is correct and complete.

  1. For example, we prefer retaining the If nodes in a function that
    has different modes based on an attribute.
  2. As another example, we retain Cast nodes when needed to ensure
    the functions work on any data type torch supports.
  3. We favor general ops like MatMul over specialized ops like Gemm in
    the function lib.

Graph optimization is out of scope and is assumed to be done by
downstream optimization passes and runtimes.

Quantization support

The ATen function library assumes all inputs are non-quantized.
Quantized tensors and operators should be managed by the exporter.
GraphBuilder can provide a procedure to represent a dequant-requant
block.
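A dequant-requant block wraps a float-only computation between a dequantize step and a requantize step. The sketch below computes the values such a block produces, using the per-tensor affine formulas of ONNX `DequantizeLinear`, y = (x - zero_point) * scale, and `QuantizeLinear`, y = saturate(round(x / scale) + zero_point), on plain Python lists with a uint8 range. The helper names are hypothetical, and reusing the input scale and zero point for the output is a simplifying assumption; GraphBuilder would emit the corresponding ONNX nodes rather than compute values.

```python
def dequantize_linear(q_values, scale, zero_point):
    # DequantizeLinear: y = (x - zero_point) * scale
    return [(q - zero_point) * scale for q in q_values]


def quantize_linear(values, scale, zero_point, qmin=0, qmax=255):
    # QuantizeLinear: y = saturate(round(x / scale) + zero_point).
    # Python's round() rounds half to even, matching ONNX's rounding mode.
    return [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in values]


def dequant_requant(q_values, scale, zero_point, float_op):
    """Run a float-only op on quantized data via a dequant -> op -> requant block."""
    real_values = dequantize_linear(q_values, scale, zero_point)
    return quantize_linear(float_op(real_values), scale, zero_point)
```

Saturation matters at the edges: doubling a value near the top of the uint8 range clamps to 255 instead of wrapping around.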

Training support

TBD. Needs input.

Context data (code traces) preservation

TODO: This needs a separate design

Input Validation protocol

See GraphBuilder

Code gen dispatching functions based on policy

Code-generated dispatching helps ensure correctness and scale coverage.
It also makes it easy to set breakpoints and trace code execution. TODO
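One possible shape for such dispatch code is a plain registry keyed by ATen names, which is ordinary Python and therefore easy to step through in a debugger. Everything here is an assumption for illustration: `register`, `dispatch` and `_REGISTRY` are hypothetical names, and the "policy" shown (prefer the exact overload, then fall back to the base op name) is just one example of a rule the generated code could encode.

```python
from typing import Callable, Dict

_REGISTRY: Dict[str, Callable] = {}


def register(name: str):
    """Record an implementation under an ATen name such as 'aten::add.Tensor'."""

    def decorator(fn: Callable) -> Callable:
        _REGISTRY[name] = fn
        return fn

    return decorator


def dispatch(qualified_name: str) -> Callable:
    """Pick an implementation: exact overload first, then the base op name."""
    if qualified_name in _REGISTRY:
        return _REGISTRY[qualified_name]
    base = qualified_name.split(".", 1)[0]  # 'aten::add.Tensor' -> 'aten::add'
    if base in _REGISTRY:
        return _REGISTRY[base]
    raise KeyError(f"No implementation registered for {qualified_name!r}")


@register("aten::add")
def aten_add(self, other, alpha=1):
    # Plain-Python stand-in for an onnxscript function body.
    return self + other * alpha


@register("aten::add.Scalar")
def aten_add_scalar(self, other, alpha=1):
    # Overload-specific entry; here `other` is a Python scalar.
    return self + other * alpha
```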

Code reuse

Functions need to call other functions, which requires delayed script
compilation. TODO

Error handling principles and debug experience

Maintain invariants across components. Avoid raising errors in the
middle of computation logic. TODO

Testing

Most of the tests make sure that the ATen functions in onnxscript are
robust to different inputs and match the torch eager mode output. We
will leverage PyTorch OpInfo to generate sample test inputs and
onnxscript eager mode evaluation to compute the function output.

Quantized ops tend not to have OpInfos. We can (1) create simple tests,
or (2) work with the torch team to define OpInfos for those ops when
high coverage is needed.
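The core of such a test loop can be sketched without torch: run the candidate and a reference on every sample and collect mismatches. `SampleInput`, `sample_inputs_add`, `reference_add` and `eager_aten_add` are hypothetical stand-ins for an OpInfo sample, its sample generator, torch eager mode, and the onnxscript function under eager evaluation; plain floats keep the example self-contained.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Tuple


@dataclass
class SampleInput:
    """Hypothetical stand-in for an OpInfo sample: positional args plus kwargs."""

    args: Tuple[Any, ...]
    kwargs: Dict[str, Any] = field(default_factory=dict)


def sample_inputs_add() -> Iterable[SampleInput]:
    # Hypothetical samples; OpInfo would supply many shapes and dtypes.
    yield SampleInput((1.0, 2.0))
    yield SampleInput((1.5, -4.0), {"alpha": 2.0})
    yield SampleInput((0.1, 0.2), {"alpha": 0.5})


def check_against_reference(
    candidate: Callable, reference: Callable, samples: Iterable[SampleInput],
    rel_tol: float = 1e-6, abs_tol: float = 1e-9,
) -> List[str]:
    """Return a description of every sample where candidate disagrees with reference."""
    failures = []
    for sample in samples:
        expected = reference(*sample.args, **sample.kwargs)
        actual = candidate(*sample.args, **sample.kwargs)
        if not math.isclose(actual, expected, rel_tol=rel_tol, abs_tol=abs_tol):
            failures.append(f"{sample}: expected {expected}, got {actual}")
    return failures


def reference_add(self, other, alpha=1.0):
    # Plays the role of torch eager mode: self + alpha * other.
    return self + alpha * other


def eager_aten_add(self, other, alpha=1.0):
    # Plays the role of the onnxscript function evaluated in eager mode.
    if alpha != 1:
        other = other * alpha
    return self + other
```

Comparing with a tolerance rather than exact equality leaves room for small floating-point differences between the two runtimes.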

Alternatives considered

aten torch decomp

PyTorch has a decomp (_decomp, _prim) library that decomposes aten ops
into more primitive aten ops. It is currently used by torch dynamo to
ease backend support. While we should leverage this feature in the
dynamo implementation, we should still aim to implement most of the
decompositions in this library so that the code can be reused for other
frameworks and carries enough information for downstream compiler
optimizations.

Risks and concerns: how can it fail

Risk: Not all operators can be expressed in onnxscript

Response: This includes dtype-dependent logic. While such logic cannot
be expressed in ONNX functions, we can use the graph building experience
to capture it statically in the ONNX model.
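As an illustration of capturing dtype-dependent logic statically: when the input dtype is known at graph-building time, an ordinary Python branch decides which nodes to emit, so the resulting graph contains no dtype test at all. The example assumes true division (as in `torch.div` without a rounding mode), which yields floating-point results for integer inputs; `build_div` and the string op list are hypothetical, not GraphBuilder's API.

```python
# Assumed set of integer dtype names (ONNX TensorProto spellings).
INTEGER_DTYPES = {"INT8", "INT16", "INT32", "INT64", "UINT8"}


def build_div(dtype: str) -> list:
    """Return the ONNX ops a builder would emit for true division (hypothetical).

    The dtype is known while the graph is being built, so this Python `if`
    runs at export time and leaves no dtype check inside the ONNX model.
    """
    ops = []
    if dtype in INTEGER_DTYPES:
        # Integer inputs: cast both operands so Div does not do integer division.
        ops += ["Cast(to=FLOAT)", "Cast(to=FLOAT)"]
    ops.append("Div")
    return ops
```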

Risk: onnxscript may not have APIs for building control flow ops

Response: We need to collect examples and design the GraphBuilder
capability to support these kinds of subgraphs.

Risk: There are a lot of ATen operators

...

Risk: Performance: Functions may contain unnecessary logic from ATen that is not involved in particular computation

Example

We define the ATen add operator in onnxscript as:

def aten_add(self, other, alpha: float = 1) -> TensorType:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

    if alpha != 1:
        other = op.Mul(other, alpha)

    return op.Add(self, other)

Note that the alpha attribute requires a conditional.

Response: The library and the exporter are concerned with the
correctness of the model and preserve as much logic as possible.
Optimization is the responsibility of downstream passes and runtimes.

Open questions

  1. Object model for the sources and how they should be represented in
    onnx
  2. What should we do about prim torch ops? Do they resemble ATen ops?
    We assume they do.
    a. Wei-Sheng: eventually we want to support them for model
    coverage. However, this does not help ONNX Runtime: if we decompose
    big ops into prims, ORT needs to fuse them back.
  3. How should we handle custom ONNX ops?
  4. Do we consider different ONNX opset versions?
    a. No. It is the responsibility of downstream version conversion
    passes/tools

Assumptions made

  1. onnxscript/onnx will handle opset conversion, so implementation
    only needs to be done on the latest opset.
  2. The function library is only used to handle ATen operators, for both
    dynamo and fx traced graphs. Cases for prim:: ops (if, etc.) are
    rare and can be handled case by case.
  3. The decomposition logic to onnx cannot be easily generated, so we
    still need to hand-write it as onnxscript functions.

Glossary

  • "Functions", "onnx functions", "onnx functions written in onnxscript" are used interchangeably.
  • "Function lib" The ATen op library being described in this document.
  • "Proto" An onnx protobuf serialization.
  • "Input/output" When not clearly stated, they refer to the input and
    output of an operator, instead of a graph.
  • "Op", "operator" An operator in a computation graph.
  • "ATen Op" A computational operator defined by PyTorch, e.g. one
    specified in native_functions.yaml. It also broadly refers to the
    "torch IR" defined in
    https://docs-preview.pytorch.org/90644/ir.html , or the "prims"
    operators.
  • "Exporter", "Converter" Model converters. Usually the torch.onnx
    PyTorch->ONNX model converter is considered.

Show what types of inputs are allowed

It would be helpful for eager-mode calls to show which input types are allowed or expected, the name of the offending argument, and how to fix the call.

The current message is:

Traceback (most recent call last):
  File "/home/justinchu/dev/onnx-script/playground/test_func.py", line 13, in <module>
    result = LeakyRelu(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32), 0.1)
  File "/home/justinchu/dev/onnx-script/onnxscript/values.py", line 192, in __call__
    return self._usercall(*args, **kwargs)
  File "/home/justinchu/dev/onnx-script/onnxscript/values.py", line 203, in _usercall
    raise TypeError(f"Unexpected input type {type(a)} for an input {i}.")
TypeError: Unexpected input type <class 'float'> for an input 1.

https://github.com/microsoft/onnx-script/blob/de7b00a57894d5cbbfb8a739c5a49f80f73a34cd/onnxscript/values.py#L203
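A sketch of a more helpful message: it uses `inspect.signature` to recover the parameter name, then reports the received type, the accepted types, and a suggested fix. `check_inputs` is a hypothetical helper, the set of accepted types is passed in rather than taken from onnxscript, and the hint to pass attributes by keyword is an assumption about how such calls should be written.

```python
import inspect
from typing import Callable, Tuple


def check_inputs(fn: Callable, allowed: Tuple[type, ...], *args) -> None:
    """Raise a TypeError that names the offending parameter and the accepted types."""
    param_names = list(inspect.signature(fn).parameters)
    expected = ", ".join(t.__name__ for t in allowed)
    for index, value in enumerate(args):
        if isinstance(value, allowed):
            continue
        name = param_names[index] if index < len(param_names) else f"#{index}"
        raise TypeError(
            f"{fn.__name__}(): input '{name}' (position {index}) has type "
            f"{type(value).__name__}, but inputs must be one of: {expected}. "
            f"If '{name}' is an attribute, try passing it by keyword "
            f"({name}={value!r})."
        )
```

For the `LeakyRelu(x, 0.1)` call above, a message of this form would name `alpha`, say that a `float` was received, and suggest `alpha=0.1`.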

onnxscript produced graph fails onnx graph checker

Traceback (most recent call last):
  File "/home/justinchu/dev/onnx-script/onnxscript/poc/os_graph_builder.py", line 145, in <module>
    onnx_model = gb.make_model(model_name)
  File "/home/justinchu/dev/onnx-script/onnxscript/poc/os_graph_builder.py", line 77, in make_model
    checker.check_model(self.onnx_model)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/onnx/checker.py", line 106, in check_model
    C.check_model(protobuf_string)
onnx.onnx_cpp2py_export.checker.ValidationError: Graph must be in single static assignment (SSA) form, however 'other' has been used as output names multiple times.

From the script:

@script()
def aten_add(self, other, alpha: float = 1) -> TensorType:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    if alpha != 1:
        other = op.Mul(other, alpha)  # type: ignore[arg-type]
    return op.Add(self, other)

cc @gramalingam
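The checker error means that a value name is produced by more than one node (or also names a graph input), which ONNX's single static assignment rule forbids. One way a converter can avoid this is to give every reassignment a fresh name and rewrite later uses, as in this sketch over simplified `(op_type, inputs, outputs)` tuples. `to_ssa` is a hypothetical helper rather than onnxscript's actual fix, and it does not handle If subgraphs, where each branch must also yield the reassigned value.

```python
def to_ssa(graph_inputs, nodes):
    """Rename reassigned values so every name is defined exactly once.

    nodes: list of (op_type, input_names, output_names) in execution order.
    """
    defined = set(graph_inputs)  # every name that already has a producer
    latest = {name: name for name in graph_inputs}  # source name -> current SSA name
    counters = {}
    result = []
    for op_type, inputs, outputs in nodes:
        # Uses refer to the most recent definition of each source-level name.
        new_inputs = [latest.get(name, name) for name in inputs]
        new_outputs = []
        for name in outputs:
            fresh = name
            while fresh in defined:  # reassignment: pick an unused suffix
                counters[name] = counters.get(name, 0) + 1
                fresh = f"{name}_{counters[name]}"
            defined.add(fresh)
            latest[name] = fresh
            new_outputs.append(fresh)
        result.append((op_type, new_inputs, new_outputs))
    return result
```

Applied to a flattened form of `aten_add`, the `Mul` output becomes `other_1`, and the following `Add` reads `other_1` instead of the original graph input `other`.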

Delay compile functions

All functions decorated with script() are compiled at import time. With a potential library of roughly 1000 or more functions, this may translate to a long import time (yet to be measured). Instead, we could compile each function just in time, when it is first used.
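A minimal sketch of just-in-time compilation: the decorator stores the Python function and calls the compiler only on first use, caching the result. `lazy_script` and `LazyScript` are hypothetical names, and `compile_fn` stands in for whatever `script()` does today.

```python
import functools


class LazyScript:
    """Wraps a Python function and compiles it only on first use (hypothetical)."""

    def __init__(self, fn, compile_fn):
        self._fn = fn
        self._compile_fn = compile_fn
        self._compiled = None
        functools.update_wrapper(self, fn)  # keep __name__, __doc__, etc.

    @property
    def compiled(self):
        # Compile on first access and cache the result for later calls.
        if self._compiled is None:
            self._compiled = self._compile_fn(self._fn)
        return self._compiled

    def __call__(self, *args, **kwargs):
        return self.compiled(*args, **kwargs)


def lazy_script(compile_fn):
    """Decorator factory: like script(), but defers compile_fn until first use."""

    def decorator(fn):
        return LazyScript(fn, compile_fn)

    return decorator
```

Importing a module full of such functions then costs only the wrapper creation; each compile happens once, on the first call. The trade-off is that conversion errors surface at first use rather than at import.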

Create comparison utility to compare 2 FunctionProto or 2 ModelProto

For testing purposes, we would like to create a utility to compare 2 FunctionProtos. And, similarly, 2 ModelProtos.

For robust checking, it should ideally support the following features, but it is okay to start with something simpler and add these features later.

  • Ability to handle differences in some tensor names (to handle temporary name generators)
  • Ability to handle some reordering of Nodes that is immaterial (to handle variation in order of evaluation)
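One way to get both features is to identify every value by *what computes it* instead of by its name, then compare the multisets of node signatures. The sketch below does this over simplified `(inputs, nodes, outputs)` graphs rather than real protos; `graphs_equivalent` is hypothetical, it ignores attributes, it assumes nodes are listed in topological order (as ONNX requires), and its nested-tuple signatures are fine for small test graphs but can get slow for large ones.

```python
from collections import Counter


def _structural_signatures(graph):
    """Name-independent signatures for a (inputs, nodes, outputs) graph.

    Nodes are (op_type, input_names, output_names) tuples listed in
    topological order. Attributes are ignored here.
    """
    inputs, nodes, outputs = graph
    sig = {name: ("input", i) for i, name in enumerate(inputs)}  # by position
    node_sigs = Counter()
    for op_type, node_inputs, node_outputs in nodes:
        in_sigs = tuple(sig[name] for name in node_inputs)
        for index, name in enumerate(node_outputs):
            # A value is identified by what computes it, not by its name.
            sig[name] = (op_type, in_sigs, index)
        node_sigs[(op_type, in_sigs)] += 1
    return node_sigs, tuple(sig[name] for name in outputs)


def graphs_equivalent(a, b):
    """True if a and b differ only in value names and immaterial node order."""
    return _structural_signatures(a) == _structural_signatures(b)
```

Swapping two independent nodes or renaming temporaries leaves the result unchanged, while swapping the operands of a non-commutative op such as Sub does not.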

Support `@script` in python interactive mode / dynamic

Currently @script does not work in Python interactive mode because inspect cannot find the source code of functions defined at the prompt. However, there are hacks we can use.

dill implements a hack that reads the input buffer to retrieve the source code: https://github.com/uqfoundation/dill/blob/master/dill/source.py#L326-L415.

An example of its usage is in the taichi project. It wraps the dill code in the sourceinspect library that supports inspecting code in more environments: https://github.com/taichi-dev/sourceinspect
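A sketch of how the lookup could degrade gracefully: try `inspect.getsource` first, and only then consult optional fallback providers. `get_function_source` is a hypothetical helper; a fallback could wrap something like dill's `dill.source.getsource` or the sourceinspect library, but wiring either in is left as an assumption here.

```python
import inspect
from typing import Callable, Iterable


def get_function_source(fn: Callable, fallbacks: Iterable[Callable[[Callable], str]] = ()) -> str:
    """Return fn's source via inspect, then try each fallback provider in order."""
    try:
        return inspect.getsource(fn)
    except (OSError, TypeError) as exc:  # e.g. functions typed at the prompt
        last_error = exc
    for fallback in fallbacks:
        try:
            return fallback(fn)
        except (OSError, TypeError) as exc:
            last_error = exc
    raise OSError(f"Could not retrieve source for {fn.__name__!r}") from last_error
```

Keeping the fallbacks pluggable means the default path needs no extra dependency, and the hack only runs when the standard lookup has already failed.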
