WIP in https://microsoft.sharepoint.com/:w:/t/ONNX2/EcUaSHDlDiBFvGGX5BC49Z0B0mHhO7s_6uLeVrOoDE4n2w?e=ZXXIgy. Will move here when stable.
This design doc reflects thinking until Jan 2023. The design may have evolved since then, especially around what we call a "GraphBuilder" in this doc; some assumptions may no longer be accurate; but it should capture the gist for the torch_lib.
Created: December 2022
Updated: January 2023
Authors: @justinchuby , @fatcat-z , @xiaowuhu
Objective
This document aims to provide a design of the ATen Op Library in ONNX
Script and its integration with the new torch.fx exporter.
Motivation and context
https://github.com/microsoft/onnx-converters-private/issues/124
Goals and non-Goals
Goals
- Create a library of aten ops expressed in onnx-script, independent of
  the PyTorch -> ONNX exporter, with a focus on creating a delightful
  dev experience. The op library will be used in the latest version of
  the torch.onnx exporter.
- Design mechanisms the torch.onnx exporter can leverage to produce an
ONNX model using the function library.
- Representation (in ONNX proto), data structure and interfaces (as
part of onnx-script library) for preserving metadata around how the
graph is manipulated to facilitate downstream debugging uses.
- Extension
- Designed to provide multi-exporter support: the same techniques
  should also support Sklearn model converters.
- Design a consistent and delightful experience for users who want
to write plain onnxscript to create ML models.
Non-goals
- Reach high coverage for aten ops
- The focus should be laying the foundation for the dev
experience, identifying gaps in onnx-script to enable rapid
parallel work.
- Some coverage should still be achieved, prioritizing ops that are
  representative of the authoring experience, commonly used, and
  able to help us identify gaps in onnx-script.
- We need to identify ops needed for the dynamo demo and
prioritize those.
- Custom ONNX ops
- The ATen lib aims to use ONNX native operators. Support for
custom operators is currently out of scope.
- Multi ONNX opset version support
- We design opset conversion to be the responsibility of
downstream optimization in order to reduce complexity in both
the exporter and the function lib. (See assumptions and Open
questions below)
- Only the latest opset version is supported, and the function lib
will continue to track the latest opset version
- Shape inference and Graph optimization
- Shape inference is handled by onnx and/or exporters.
Design principles
- Robustness, accuracy, dev experience.
- Functions created should have a comprehensive test for
robustness and accuracy.
- The development experience should be speedy and easy.
- Maintain invariants across components
- Flip and switch migration
- The new exporter should only need to create bridges at a scale
much smaller than the existing symbolic functions for it to
leverage the ATen library.
- Delightful user experience
- Failures are clear and actionable.
- Test first development process
- Parallelizable work
- Meta information preservation for downstream tasks
- Function specs stay with the implementation in a single place.
- Fail early and fast
- Raise errors as early as possible
Architecture
We design the onnxscript ATen library and its integration with exporters
into three major components: the function library itself, the
GraphBuilder, and the exporter side logic.
The function library
The function library is a collection of ONNX functions written in
onnxscript. As such, it only contains logic representable by an ONNX
function proto. Each function matches, as closely as possible, the
signature of a torch.ops.aten operator. The function decomposes the ATen
operator with one or more ONNX operators.
A function has two roles: (1) decompose the ATen op to ONNX IR, and
(2) specify the requirements for its inputs (expected input types, etc.).
(2) is necessary because ONNX functions cannot handle dtype-dependent
logic. We will need additional code around the functions to bridge the
inputs and/or dispatch to different function overloads on the exporter
side. The input requirements serve as meta information that can be
leveraged by the exporter to decide which function to use.
Based on the constraints of ONNX functions and the principle of
separation of responsibilities, a function does not verify inputs or
handle/emit any errors. It should be regarded as a description of
the decomposition logic (which gets translated into a proto). Components
that leverage the functions are responsible for verifying inputs.
All functions will be tested by PyTorch's OpInfo database, using
onnxscript's eager mode, to make sure the numerical output matches that
of the ATen operator.
The function library additionally provides a default op-name-to-function
mapping for lookup in the exporter.
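As a sketch of what such a default mapping could look like, the snippet below keys onnxscript functions by qualified ATen op name. The registry shape and function bodies here are illustrative assumptions, not the actual torch_lib API:

```python
# Hypothetical sketch of the default op-name-to-function mapping.
# The function names and registry shape are assumptions, not the
# actual torch_lib implementation.

def aten_add(self, other, alpha: float = 1):
    """Placeholder standing in for the onnxscript implementation."""
    ...

def aten_elu(self, alpha: float = 1.0):
    """Placeholder standing in for the onnxscript implementation."""
    ...

# Keyed by the qualified ATen op name so the exporter can look up
# the corresponding onnxscript function for each node it visits.
ATEN_FUNCTION_MAP = {
    "aten::add": aten_add,
    "aten::elu": aten_elu,
}

def lookup(op_name: str):
    """Return the onnxscript function for an ATen op, or None."""
    return ATEN_FUNCTION_MAP.get(op_name)
```

The exporter remains free to override or extend this default mapping, since dispatch is its responsibility.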
Graph Builder <or just Graph, with good APIs>
Having all the functions built, we still need a mechanism to synthesize
them into an onnx model graph. In the current version of torch.onnx
exporter, we use torchscript as the IR for the graph. As we envision the
new exporter, it makes sense for all onnx-related logic to be handled by
onnx. The proposed Graph Builder component will replace torchscript to
store the graph being built, and provide limited but necessary graph
manipulation (TBD, e.g. graph stitching) capabilities. Its main
responsibilities are:
- Maintain one, or potentially more, representations of ONNX (sub)
  graphs and (TBD, the relationship between framework native
  input/output/operator representation and its internal
  representation)
- Provide an API to acquire the operator information from the
  exporter, as well as inputs/outputs of the whole graph.
- Define the operator information needed and a protocol for the
  exporter to supply the information. The protocol should define
  traits that will be implemented by the exporter.
- Provide a protocol for the exporter to supply stack traces and
  diagnostics information and preserve them through the exported model
  in a sensible representation (most likely SARIF).
- Provide the capability of executing any (sub)graph being built for
  debuggability, such that users can e.g. examine any intermediate
  results or run dissect algorithms on the partial graphs and
  programmatically compare results.
- Serialize graphs into ONNX model proto.
- Build in guards for catching graph errors during graph
  construction and provide useful error messages, leveraging
  onnxscript/onnx capabilities and input requirements specified by
  individual ONNX functions.
- Provide a general mechanism to insert glue nodes to bridge dtype
  discrepancies between input tensors and what functions expect, based
  on the function input requirements.
- Ensure the produced onnx model is valid, and produce diagnosable
  errors as early as it can.
- (TBD) Maintain a Python source representation of the model when a
  model cannot be represented solely by onnx
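A minimal sketch of how the dtype-bridging responsibility could work, assuming each function publishes its expected input dtypes as metadata. The class and function names below are hypothetical placeholders, not the actual GraphBuilder API:

```python
from dataclasses import dataclass

# Hypothetical metadata a function could publish about its inputs.
@dataclass
class InputRequirement:
    name: str
    allowed_dtypes: frozenset  # e.g. {"float32", "float64"}
    cast_to: str  # dtype to cast to when the actual dtype is not allowed

def plan_glue_nodes(requirements, actual_dtypes):
    """Return a list of (input_name, target_dtype) Cast nodes to insert.

    An input needs a Cast glue node only when its actual dtype falls
    outside the dtypes the function declares it can handle.
    """
    casts = []
    for req in requirements:
        actual = actual_dtypes[req.name]
        if actual not in req.allowed_dtypes:
            casts.append((req.name, req.cast_to))
    return casts

reqs = [InputRequirement("self", frozenset({"float32", "float64"}), "float32")]
# An int64 input would get a Cast-to-float32 glue node;
# a float32 input would pass through unchanged.
```

The key design point is that the decision is driven entirely by declarative per-function requirements, so the GraphBuilder needs no op-specific knowledge.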
Graph optimization and transformations are out of scope in this design.
Graph building experience
To provide a consistent experience for both exporter developers and
onnxscript authors, we propose to use the same interface as onnxscript
ops to produce graphs, by implementing a graph-capturing Evaluator (which
internally talks to the GraphBuilder).
The following example builds an ONNX graph:
graphbuilder = GraphBuilder()
a = graphbuilder.input(dtype=..., shape=[2, 3])
b = graphbuilder.input(dtype=..., shape=...)
c = op.Mul(a, b)
d = aten_lib.ops.elu(c) # A function from the function library
print(a.rank()) # Python logic that is not possible in onnx functions
graphbuilder.set_output(c, d)
This way exporter developers and other users do not have to use the
lower-level graph manipulation APIs provided by the GraphBuilder, like
make_node, etc.
Communication protocols
TODO
Correctness validation
TODO
Exporter
The exporter is responsible for capturing a computation graph and
transforming it into a composition of ATen ops. It should also
- Capture and transform the computation graph to ATen level ops.
- Provide the necessary tensor/op information for graph construction
by storing them according to the protocol GraphBuilder provides.
- Leverage the op-to-function mapping and decide which function to use
  to build the graph. In particular, the dtype-dependent logic is
  controlled by the exporter so that any changes made by PyTorch can
  be directly reflected in the exporter, without the need to update
  a dependency (GraphBuilder)
- Provide the stack traces based on the protocol.
Exporters should not need to create/manipulate onnx nodes in most cases
except when dealing with complex conditional logic.
Interaction among components
We envision the export process roughly as the following.
- Exporter captures and transforms the computation graph to
eliminate inplace and "out" ops and inline/convert most ops to
aten, as well as adding explicit type promotion logic.
- Exporter walks through the framework native graph (e.g. fx
graph)
- For each node, Exporter converts the input/output information
into a format defined by the protocol provided by GraphBuilder.
- Exporter dispatches to Logic to select functions. This is
usually programmatic and in rare cases symbolic function style. The
graph is created by calling onnxscript ops in the Logic, which
internally communicates with GraphBuilder to record the graph.
GraphBuilder also records the output name of the ATen op to
associate it with the framework native output (from the exporter) so
that it can be uniquely identified by GraphBuilder in the ONNX
graph.
- GraphBuilder creates and validates ONNX nodes based on the
supplied info, which includes the code stacks information.
- When the graph is translated, Exporter calls GraphBuilder to
make a model proto.
- GraphBuilder makes a model and appends all the function protos
used, deduplicated, onto the model.
- Exporter saves the model on disk.
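The walkthrough above can be condensed into a short dispatch loop. Everything in this sketch (class names, node dicts, method signatures) is a hypothetical stand-in for the real exporter and GraphBuilder interfaces:

```python
# Hypothetical sketch of the exporter/GraphBuilder interaction.
class GraphBuilderStub:
    """Stand-in that records nodes; the real GraphBuilder also
    validates inputs and tracks stack traces per node."""

    def __init__(self):
        self.nodes = []

    def add_node(self, func_name, inputs, outputs):
        self.nodes.append({"op": func_name, "inputs": inputs, "outputs": outputs})

    def to_model_proto(self):
        # The real builder would serialize to an ONNX model proto and
        # append the deduplicated function protos; here we just return
        # the recorded graph.
        return {"graph": self.nodes}

def export(fx_nodes, function_map):
    """Walk a framework-native graph and record each op via the builder."""
    builder = GraphBuilderStub()
    for node in fx_nodes:
        func = function_map.get(node["target"])  # dispatch step
        if func is None:
            raise NotImplementedError(f"No function for {node['target']}")
        builder.add_node(func, node["inputs"], node["outputs"])
    return builder.to_model_proto()

# Example framework-native graph with a single aten::add node.
model = export(
    [{"target": "aten::add", "inputs": ["x", "y"], "outputs": ["z"]}],
    {"aten::add": "aten_add"},
)
```

Recording the framework-native output names alongside each node is what lets the GraphBuilder uniquely associate ONNX values with their source, as described above.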
Design considerations
User experience
A core goal of this design is to provide a delightful experience. We
consider the experience from the perspectives of (1) onnx-script ATen
library authors, (2) exporter developers, and (3) exporter users.
Authoring experience
A delightful experience should include
- A speedy authoring experience enabled by onnxscript and its
typing/intellisense support
- Easy tests enabled by PyTorch OpInfo
- Debuggability in functions using eager evaluation
- A robust input validation mechanism
- Easy to navigate code
a. Naming: easy to locate
Contribution workflow
TODO
Exporter developers
- Intuitive APIs
- Clear contracts with protocols when interacting with the graph.
- Debug and diagnosability
- Debugger support. Users should be able to use pdb and set
breakpoints.
- Errors should point to where they happen and provide
  actionable information.
Exporter users
- Node history traces (be able to answer: how did this node come
about?)
Correctness: Code gen signatures
We code gen the function signatures from PyTorch's native_functions.yaml
to ensure correctness of the function signatures. This will serve as a
one-time tool to get us started. Having a tool to keep the
implementation updated can be more work than desired and is not planned.
Example generated signatures: #223
Correctness > potential model performance
To support graph conversion in the exporter, we focus on ensuring that
the model is correct and complete.
- For example, we prefer retaining the If nodes in a function that
  has different modes based on an attribute.
- As another example, we retain Cast nodes when needed to ensure
  the functions work on any data type torch supports.
- We favor general ops like MatMul over specialized ops like Gemm in
  the function lib.
Graph optimization is out of scope and is assumed to be done by
downstream optimization passes and runtimes.
Quantization support
The ATen function library assumes all inputs are non-quantized.
Quantized tensors and operators should be managed by the exporter.
GraphBuilder can provide a procedure to represent a dequant-requant
block.
Training support
TBD. Needs input.
Context data (code traces) preservation
TODO: This needs a separate design
Input Validation protocol
See GraphBuilder
Code gen dispatching functions based on policy
This helps ensure correctness and scale coverage. We can also set
breakpoints and trace code execution easily. TODO
Code reuse
Functions call other functions. We need delayed script compilation. TODO
Error handling principles and debug experience
Maintain invariants across components. Avoid raising errors in the
middle of computation logic. TODO
Testing
Most of the tests are to make sure the aten functions in onnxscript are
robust to different inputs and match the torch eager mode output. We
will leverage the PyTorch OpInfo database for generating sample test
inputs and the onnxscript eager mode evaluation for getting the
function output.
Quantized ops tend to not have OpInfos. We can (1) create simple tests,
or (2) work with the torch team to define OpInfos for those ops when
high coverage is needed.
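The testing pattern can be illustrated with plain Python. In practice OpInfo generates the sample inputs and torch eager mode is the reference; the stand-in functions below are simplified assumptions for illustration only:

```python
import math

# Stand-in for the torch eager-mode reference implementation.
def torch_add_reference(self, other, alpha=1):
    return self + alpha * other

# Stand-in for the onnxscript function evaluated in eager mode.
def aten_add_eager(self, other, alpha=1):
    if alpha != 1:
        other = other * alpha
    return self + other

def check_op(candidate, reference, sample_inputs, rtol=1e-5):
    """Compare candidate against reference for every sampled input set."""
    for args, kwargs in sample_inputs:
        expected = reference(*args, **kwargs)
        actual = candidate(*args, **kwargs)
        assert math.isclose(actual, expected, rel_tol=rtol), (args, kwargs)

# OpInfo would supply many dtype/shape combinations; a toy stand-in:
samples = [((1.0, 2.0), {}), ((1.5, -2.0), {"alpha": 3})]
check_op(aten_add_eager, torch_add_reference, samples)
```

The same comparison loop, parameterized over the OpInfo database, is what gives every function in the library robustness and accuracy coverage.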
Alternatives considered
aten torch decomp
PyTorch has a decomp (_decomp,_prim) library that decomposes aten ops
with more primitive aten ops. It is currently used by torch dynamo for
easing backend support. While we should leverage this feature in the
dynamo implementation, we should still aim to implement most of the
decompositions in this library so that code can be reused for other
frameworks and provide enough information for the downstream compiler
optimization.
Risks, concerns; how can it fail
Risk: Not all operators can be expressed in onnxscript
Response: This includes dtype dependent logic. While they cannot be
expressed by ONNX functions, we can use the graph building experience to
capture the logic statically into the ONNX model.
Risk: onnxscript may not have APIs for building control flow ops
Response: We need to take examples and design the graphbuilder
capability to support this kind of subgraph.
Risk: There are a lot of ATen operators
...
Risk: Performance: Functions may contain unnecessary logic from ATen that is not involved in a particular computation
Example
We define the ATen add operator in onnxscript as
def aten_add(self, other, alpha: float = 1) -> TensorType:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    if alpha != 1:
        other = op.Mul(other, alpha)
    return op.Add(self, other)
Note that there is an attribute that needs a conditional.
Response: The library and the exporter are concerned with the
correctness of the model and preserve as much logic as possible.
Optimization is the responsibility of downstream passes and runtimes.
Open questions
- Object model for the sources and how they should be represented in
onnx
- What should we do about prim torch ops? Do they resemble ATen ops?
We assume they do.
a. Wei-Sheng: eventually we want to support them for model
coverage. However, ONNXRuntime won't care since if we decompose
big ops into prims, ORT needs to fuse them back.
- How should we handle custom ONNX ops?
- Do we consider different ONNX opset versions?
a. No. It is the responsibility of downstream version conversion
passes/tools
Assumptions made
- Onnxscript/onnx will handle opset conversion, so implementation
  only needs to be done on the latest opset.
- The function library is only used to handle ATen operators, for both
dynamo and fx traced graphs. Cases for prim:: ops (if, etc.) are
rare and can be handled case by case.
- The decomposition logic to onnx cannot be easily generated, so we
still need to hand write them as onnxscript functions.
Glossary
"Functions", "onnx functions", "onnx functions written in onnxscript"
are used interchangeably.
"Function lib"
The ATen op library being described in this document.
"Proto"
An onnx protobuf serialization.
"Input/output"
When not clearly stated, they refer to the input and
output of an operator, instead of a graph.
"Op", "operator"
An operator in a computation graph.
"ATen Op"
A computational operator defined by PyTorch, e.g. one
specified in native_functions.yaml. It also broadly refers to the
"torch IR" defined in
https://docs-preview.pytorch.org/90644/ir.html , or the "prims"
operators.
"Exporter", "Converter"
Model converters. Usually the torch.onnx
PyTorch->ONNX model converter is considered.