Non-nested graph definitions? #2154

Open
dpotop opened this issue Apr 1, 2025 · 3 comments
Labels
category: question A question about something

Comments

dpotop commented Apr 1, 2025

Hello,

ONNX and onnxscript already allow hierarchical graphs for operators such as Scan. However, as far as I understand, a nested graph must be defined inline, locally inside the function that calls Scan. This precludes any form of parametric graph definition, which could come in handy, for instance, when defining a general RNN operator that takes a Cell graph as an argument.

I'd like to know whether such a parametric graph definition style is compatible with onnxscript principles (nothing seems to prohibit it in ONNX, but I may be missing something). If it is, I'd like to help add it to the formalism, although some hints on how to do it would help.

Regards,
Dumitru

justinchuby added the category: question label on Apr 1, 2025
justinchuby (Collaborator)

@gramalingam

justinchuby (Collaborator)

@dpotop would you like to propose an example usage to show what you have in mind?

dpotop (Author) commented Apr 3, 2025

> @dpotop would you like to propose an example usage to show what you have in mind?

Definitely! I will use the simplest example I can find in the distribution. Its original form is:

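```python
from onnxscript import INT64, graph, script
from onnxscript import opset15 as op


@script()
def cumulative_sum(X: INT64["N"]):
    """Test use of a nested-function as a graph-attribute, using the Scan operator."""

    # The Scan body must be defined here, nested inside the script function.
    @graph()
    def Body(sum_in, next):
        sum_out = sum_in + next
        scan_out = op.Identity(sum_out)
        return sum_out, scan_out

    zero = op.Constant(value_int=0)
    _, result = op.Scan(zero, X, body=Body, num_scan_inputs=1)
    return result
```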
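As a reading aid, the plain-Python model below spells out what Scan computes with this body: the state (`sum_in`/`sum_out`) is threaded through the iterations, and each `scan_out` is stacked into the result. It is not onnxscript code; `scan`, `sum_body`, and `cumulative_sum` here are hypothetical names, and it only covers one state variable and one scan input along axis 0. It also happens to pass the body as an ordinary argument, which is the style the question asks for.

```python
from typing import Callable, List, Tuple

# A Scan body maps (state, element) -> (new_state, per-step output).
BodyFn = Callable[[int, int], Tuple[int, int]]


def scan(body: BodyFn, init: int, xs: List[int]) -> Tuple[int, List[int]]:
    """Plain-Python model of Scan with one state and one scan input (axis 0)."""
    state = init
    outputs = []
    for x in xs:
        state, out = body(state, x)
        outputs.append(out)  # scan outputs are stacked along axis 0
    return state, outputs


def sum_body(sum_in: int, next_: int) -> Tuple[int, int]:
    # Mirrors Body: sum_out = sum_in + next; scan_out = Identity(sum_out)
    sum_out = sum_in + next_
    return sum_out, sum_out


def cumulative_sum(xs: List[int]) -> List[int]:
    _, result = scan(sum_body, 0, xs)  # initial state plays the role of `zero`
    return result


print(cumulative_sum([1, 2, 3, 4]))  # [1, 3, 6, 10]
```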

Ideally, I'd like to allow something like:

```python
# Proposed usage: Body is defined at top level and passed to
# cumulative_sum as a parameter.
@graph()
def Body(sum_in, next):
    sum_out = sum_in + next
    scan_out = op.Identity(sum_out)
    return sum_out, scan_out


@script()
def cumulative_sum(body_graph, X: INT64["N"]):
    """Test use of a nested-function as a graph-attribute, using the Scan operator."""
    zero = op.Constant(value_int=0)
    _, result = op.Scan(zero, X, body=body_graph, num_scan_inputs=1)
    return result
```

Note that body_graph is passed as a parameter to cumulative_sum.

If this is too difficult (because @graph() objects must be nested), maybe we could reuse the ideas behind the inliner Python package, and the code could be:

```python
@inline
def body_fun(sum_in, next):
    sum_out = sum_in + next
    scan_out = op.Identity(sum_out)
    return sum_out, scan_out


@script()
def cumulative_sum(body_fun, X: INT64["N"]):
    """Test use of a nested-function as a graph-attribute, using the Scan operator."""

    @graph()
    def Body(sum_in, next):
        return body_fun(sum_in, next)

    zero = op.Constant(value_int=0)
    _, result = op.Scan(zero, X, body=Body, num_scan_inputs=1)
    return result
```

In this second case, the difficulty is that the inliner package is deprecated and would need some update work.
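The inlining idea itself can be illustrated without the inliner package. The sketch below uses only Python's standard `ast` module (Python 3.9+ for `ast.unparse`) to replace a `return body_fun(...)` inside a nested function with `body_fun`'s statements, renaming its parameters to the caller's argument names. `InlineReturnCalls` and `inline_source` are hypothetical names and do not reproduce the inliner package's API. The sketch only handles calls whose arguments are plain names, ignores possible name clashes, and assumes `body_fun` is known when the source is rewritten (so in the example it is no longer a parameter of `cumulative_sum`). Feeding the rewritten source to onnxscript's converter is not shown.

```python
import ast
import copy
from typing import Dict


class _Rename(ast.NodeTransformer):
    """Renames variables according to `mapping` (callee param -> caller arg)."""

    def __init__(self, mapping: Dict[str, str]):
        self.mapping = mapping

    def visit_Name(self, node: ast.Name) -> ast.Name:
        node.id = self.mapping.get(node.id, node.id)
        return node


class InlineReturnCalls(ast.NodeTransformer):
    """Hypothetical inliner: replaces `return f(a, b)` with f's body when f is known."""

    def __init__(self, funcs: Dict[str, ast.FunctionDef]):
        self.funcs = funcs

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # rewrite nested defs such as `Body` first
        new_body = []
        for stmt in node.body:
            call = stmt.value if isinstance(stmt, ast.Return) else None
            if (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
                    and call.func.id in self.funcs
                    and all(isinstance(a, ast.Name) for a in call.args)):
                callee = self.funcs[call.func.id]
                params = [p.arg for p in callee.args.args]
                mapping = {p: a.id for p, a in zip(params, call.args)}
                # Copy the callee's statements (including its return), renamed.
                new_body.extend(_Rename(mapping).visit(s) for s in copy.deepcopy(callee.body))
            else:
                new_body.append(stmt)
        node.body = new_body
        return node


def inline_source(outer_src: str, helper_srcs: Dict[str, str]) -> str:
    funcs = {name: ast.parse(src).body[0] for name, src in helper_srcs.items()}
    tree = InlineReturnCalls(funcs).visit(ast.parse(outer_src))
    return ast.unparse(ast.fix_missing_locations(tree))


OUTER = '''
def cumulative_sum(X):
    @graph()
    def Body(sum_in, next):
        return body_fun(sum_in, next)

    zero = op.Constant(value_int=0)
    _, result = op.Scan(zero, X, body=Body, num_scan_inputs=1)
    return result
'''

BODY_FUN = '''
def body_fun(a, b):
    s = a + b
    return s, op.Identity(s)
'''

print(inline_source(OUTER, {"body_fun": BODY_FUN}))
```

Working at the source level keeps the nested `@graph()` form that onnxscript already accepts, at the cost of needing the helper's source when the script function is defined.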
