TorchLib function authoring guide

Ti-Tai Wang edited this page Aug 21, 2023 · 10 revisions

Updated: July 2023
Authors: @justinchuby @titaiwangms

Concepts

TorchLib functions are pure data. This means we avoid defining runtime behavior as code in the functions.

Check native_functions.yaml

The main goal of torchlib is to convert PyTorch models to ONNX models. To do that, we first need to understand the function signatures in PyTorch, namely those of the ATen operators. native_functions.yaml defines all the native functions in PyTorch.

- func: func_name(ArgType arg0[=default], ArgType arg1[=default], ...) -> Return
  variants: function, method
  dispatch:
    CPU: func_cpu
    CUDA: func_cuda

Pay close attention to each ArgType: different ArgTypes map to different TypeVars in torchlib.
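To make the func: entry format concrete, here is a minimal sketch that splits a schema string into its name, overload, arguments, and return type using only the standard library. The helper names (parse_func_schema, SchemaArg, ParsedSchema) are hypothetical and not part of PyTorch or torchlib, and the parser assumes simple schemas in which ArgTypes contain no spaces.

```python
import re
from typing import List, NamedTuple, Optional


class SchemaArg(NamedTuple):
    arg_type: str            # ArgType as written, e.g. "Tensor", "Scalar", "int[]?"
    name: str                # argument name, e.g. "self", "alpha"
    default: Optional[str]   # default as written in the schema, or None
    keyword_only: bool       # True if the argument appears after "*"


class ParsedSchema(NamedTuple):
    name: str
    overload: Optional[str]
    args: List[SchemaArg]
    returns: str


def _split_top_level(text: str) -> List[str]:
    """Split on commas that are not nested inside () or [] brackets."""
    parts, current, depth = [], [], 0
    for char in text:
        if char in "([":
            depth += 1
        elif char in ")]":
            depth -= 1
        if char == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(char)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


def parse_func_schema(schema: str) -> ParsedSchema:
    """Parse 'name[.overload](ArgType arg[=default], ...) -> Return'."""
    match = re.fullmatch(r"\s*([\w:]+)(?:\.(\w+))?\((.*)\)\s*->\s*(.+?)\s*", schema)
    if match is None:
        raise ValueError(f"Unrecognized schema: {schema!r}")
    name, overload, arg_list, returns = match.groups()
    args = []
    keyword_only = False
    for item in _split_top_level(arg_list):
        if item == "*":
            # Everything after a bare "*" is keyword-only.
            keyword_only = True
            continue
        arg_type, rest = item.split(None, 1)
        arg_name, _, default = rest.partition("=")
        args.append(
            SchemaArg(
                arg_type,
                arg_name.strip(),
                default.strip() if "=" in rest else None,
                keyword_only,
            )
        )
    return ParsedSchema(name, overload, args, returns)


parsed = parse_func_schema(
    "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
)
for arg in parsed.args:
    print(arg)
```

Listing the ArgType of every argument this way makes it easier to decide, argument by argument, which torchlib TypeVar or primitive type the annotation should use.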

Implement OnnxFunction/TracedOnnxFunction

The decorator: torch_op

The decorator torch_op is used to officially register the function into torchlib.

def torch_op(
    name: str | tuple[str, ...],
    *,
    registry: Optional[Registry] = None,
    trace_only: bool = False,
    private: bool = False,
    complex: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
    """Register a torch op.

    Args:
        name: Qualified ATen name of the function. E.g. "aten::relu", "aten::add.Tensor".
            Or a tuple of names e.g. ("aten::add.Scalar", "aten::add.Tensor").
            Default overloads should be specified by omitting the overload part,
            i.e. "aten::relu" instead of "aten::relu.default".
        registry: Registry to register the function to. If None, the default registry is used.
        trace_only: Whether the function should only be traced and not compiled.
        private: Whether the function is private (not directly exposed). It should
            be true for all functions with names starting with "_".
        complex: Whether the function supports complex.
    """
    ...

trace_only extends script() to cover complicated control flow through the TracedOnnxFunction class. Instead of compiling the whole function, control flow included, into an OnnxFunction, it only traces the function as a normal Python function, which accommodates control flow that script() does not support.
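The registration pattern described by the docstring can be illustrated with a toy decorator. This is a sketch only, not torchlib's implementation: ToyRegistry, toy_torch_op, and the "toy::abs_or_zero" name are hypothetical, and the decorated functions are plain Python rather than ONNX Script. It shows how a single name or a tuple of names can map to one function, and how a trace_only flag can be recorded at registration time.

```python
from typing import Callable, Dict, Set, Tuple, Union


class ToyRegistry:
    """Hypothetical stand-in for a registry: maps qualified names to functions."""

    def __init__(self) -> None:
        self.functions: Dict[str, Callable] = {}
        self.trace_only_names: Set[str] = set()


def toy_torch_op(
    name: Union[str, Tuple[str, ...]],
    *,
    registry: ToyRegistry,
    trace_only: bool = False,
) -> Callable[[Callable], Callable]:
    """Toy decorator: register one function under one or more qualified names."""
    names = (name,) if isinstance(name, str) else tuple(name)

    def decorator(func: Callable) -> Callable:
        for qualified_name in names:
            registry.functions[qualified_name] = func
            if trace_only:
                # Remember that this function is traced, not compiled.
                registry.trace_only_names.add(qualified_name)
        return func

    return decorator


registry = ToyRegistry()


# One function registered for two overloads, as in the docstring's tuple example.
@toy_torch_op(("aten::add.Scalar", "aten::add.Tensor"), registry=registry)
def aten_add(self, other, alpha=1.0):
    return self + alpha * other


# Python-level branching is the kind of control flow that the guide says
# trace_only is meant to accommodate.
@toy_torch_op("toy::abs_or_zero", registry=registry, trace_only=True)
def toy_abs_or_zero(self, keep: bool = True):
    if keep:
        return abs(self)
    return 0.0


print(sorted(registry.functions))
print(sorted(registry.trace_only_names))
```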

Function signature

  • Name a function starting with the namespace it's from. For example, aten_abs or prims_abs.
  • Correctly annotate the inputs and attributes according to native_functions.yaml.

Use an existing TypeVar in tensor_typing, or create one, to match the ArgType from native_functions.yaml. In most cases, inputs should all be tensor types and attributes should be primitive types. However, depending on the implementation of the OnnxFunction, this can vary case by case to meet the requirements of the ONNX operators used in the function.
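The naming and annotation rules above can be sketched with stand-in types. In the snippet below, FLOAT, DOUBLE, and the locally defined TFloat are hypothetical placeholders written in the style of tensor_typing, not imports from torchlib, and split_inputs_and_attributes is an illustrative helper rather than the logic onnxscript actually uses. The function body is omitted on purpose, since only the signature is being illustrated.

```python
import inspect
from typing import Callable, List, Tuple, TypeVar, Union


class FLOAT:
    """Hypothetical stand-in for a float32 tensor type."""


class DOUBLE:
    """Hypothetical stand-in for a float64 tensor type."""


# Stand-in TypeVar covering the float tensor types defined above.
TFloat = TypeVar("TFloat", bound=Union[FLOAT, DOUBLE])


def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
    """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""
    # Signature-only sketch: the name starts with the "aten" namespace, the
    # Tensor argument is annotated with a tensor TypeVar (an input), and the
    # Scalar argument is annotated with a primitive type (an attribute).
    raise NotImplementedError("signature-only sketch")


def split_inputs_and_attributes(func: Callable) -> Tuple[List[str], List[str]]:
    """Classify parameters by annotation: TypeVar -> input, primitive -> attribute."""
    inputs, attributes = [], []
    for name, param in inspect.signature(func).parameters.items():
        if isinstance(param.annotation, TypeVar):
            inputs.append(name)
        elif param.annotation in (int, float, bool, str):
            attributes.append(name)
    return inputs, attributes


print(split_inputs_and_attributes(aten_leaky_relu))
```

A check like this is only a reading aid for signatures. Whether a given Scalar argument should be an attribute or a tensor input still depends on the ONNX operators used in the function body.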

Function body

Test