TorchLib function authoring guide
Updated: July 2023
Authors: @justinchuby @titaiwangms
TorchLib functions are pure data. This means we avoid defining runtime behavior as code in the functions.
The main goal of torchlib is to convert PyTorch models to ONNX models. Therefore, we first need to understand the function signatures in PyTorch, namely the ATen operators. native_functions.yaml defines all the native functions in PyTorch:
```yaml
- func: func_name(ArgType arg0[=default], ArgType arg1[=default], ...) -> Return
  variants: function, method
  dispatch:
    CPU: func_cpu
    CUDA: func_cuda
```
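To make this concrete, here is a minimal sketch of how an entry such as `- func: abs(Tensor self) -> Tensor` maps to a torchlib function whose signature mirrors the ATen parameter names and order. The module paths and the `TReal` choice follow onnxscript's layout as of mid-2023 and may differ from the actual source:

```python
# Sketch: the native_functions.yaml entry
#   - func: abs(Tensor self) -> Tensor
# becomes a torchlib function that mirrors the ATen schema (kept as the docstring).
from onnxscript.function_libs.torch_lib.tensor_typing import TReal
from onnxscript.onnx_opset import opset18 as op


def aten_abs(self: TReal) -> TReal:
    """abs(Tensor self) -> Tensor"""
    return op.Abs(self)
```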
The developer should pay close attention to the `ArgType`: each `ArgType` maps to a different `TypeVar` in torchlib.
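For example (a hedged sketch; the exact `TypeVar` used in torchlib may differ), a `Tensor` argument becomes a tensor `TypeVar` input, while a `Scalar` argument that parameterizes the ONNX op becomes a Python-typed attribute:

```python
# Sketch: mapping "leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor".
# Tensor -> a tensor TypeVar from tensor_typing (here TFloat).
# Scalar -> a Python float, passed as the ONNX LeakyRelu "alpha" attribute.
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op


def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
    """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""
    return op.LeakyRelu(self, alpha=negative_slope)
```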
The decorator `torch_op` is used to officially register the function into torchlib.
```python
def torch_op(
    name: str | tuple[str, ...],
    *,
    registry: Optional[Registry] = None,
    trace_only: bool = False,
    private: bool = False,
    complex: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
    """Register a torch op.

    Args:
        name: Qualified ATen name of the function. E.g. "aten::relu", "aten::add.Tensor".
            Or a tuple of names e.g. ("aten::add.Scalar", "aten::add.Tensor").
            Default overloads should be specified by omitting the overload part,
            i.e. "aten::relu" instead of "aten::relu.default".
        registry: Registry to register the function to. If None, the default registry is used.
        trace_only: Whether the function should only be traced and not compiled.
        private: Whether the function is private (not directly exposed). It should
            be true for all functions with names starting with "_".
        complex: Whether the function supports complex.
    """
    ...
```
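A typical use looks like the following (a hedged sketch; the import paths reflect onnxscript's layout as of mid-2023):

```python
# Sketch: registering an implementation for aten::relu with the default registry.
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TReal
from onnxscript.onnx_opset import opset18 as op


@torch_op("aten::relu")
def aten_relu(self: TReal) -> TReal:
    """relu(Tensor self) -> Tensor"""
    return op.Relu(self)


# Several overloads can share one implementation by passing a tuple of names,
# e.g. @torch_op(("aten::add.Tensor", "aten::add.Scalar")).
```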
`trace_only` extends `script()` to support complicated control flow via the class `TracedOnnxFunction`: instead of compiling the whole function, including its control flow, into an `OnnxFunction`, torchlib only traces it as a normal Python function to accommodate control flow that cannot be compiled.
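The following hypothetical operator (a sketch only, not an op that exists in torchlib) illustrates the pattern: the `if` branch runs in Python at trace time, so only the chosen branch ends up in the exported graph:

```python
# Sketch of a trace_only function: Python control flow is evaluated while
# tracing, so no ONNX If/Loop construct is generated for it.
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op


@torch_op("aten::fancy_activation", trace_only=True)  # hypothetical ATen name
def aten_fancy_activation(self: TFloat, approximate: str = "none") -> TFloat:
    """fancy_activation(Tensor self, str approximate="none") -> Tensor (hypothetical)"""
    if approximate == "tanh":  # evaluated in Python during tracing
        return op.Tanh(self)
    return op.Relu(self)
```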
When authoring a new function:

- Name a function starting with the namespace it's from, for example `aten_abs` or `prims_abs`.
- Correctly annotate the inputs and attributes according to `native_functions.yaml`.
Use an existing `TypeVar` in tensor_typing, or create a new one, to match the `ArgType` from `native_functions.yaml`. In most cases, inputs should all be tensor types and attributes should be primitive types. However, depending on the implementation of the `OnnxFunction`, the scenario changes case by case to align with the requirements of the ONNX operators used in the function.
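For reference, a `TypeVar` in tensor_typing is essentially a `typing.TypeVar` bound to a union of ONNX tensor types. The sketch below mirrors, but may not exactly match, the definitions in `onnxscript/function_libs/torch_lib/tensor_typing.py`:

```python
# Sketch of tensor_typing-style TypeVars that constrain an input to a set of
# ONNX tensor element types.
from typing import TypeVar, Union

from onnxscript import DOUBLE, FLOAT, FLOAT16, INT8, INT16, INT32, INT64

_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
_IntType = Union[INT8, INT16, INT32, INT64]

# Inputs annotated with TFloat accept any floating-point tensor; TReal accepts
# floating-point or signed-integer tensors.
TFloat = TypeVar("TFloat", bound=_FloatType)
TReal = TypeVar("TReal", bound=Union[_FloatType, _IntType])
```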