Pattern match on default attributes? #2012

leshabirukov opened this issue Jan 16, 2025 · 11 comments

@leshabirukov
Contributor

Problem: I want to rewrite Conv nodes, but only those with dilations == [1, 1]. The issue is the default value. To check dilations, I need to include the attribute in the pattern, but this causes the pattern match to fail on nodes that rely on the default dilations.

Probably related: #1627

Possible fix: substitute defaults on deserialization. I suspect such a function already exists.

rewr_c.py.txt

@justinchuby
Collaborator

Thanks. So a node that spells out a default value and a node that omits it are stored differently in the graph. They are semantically the same but syntactically different. We need to decide what the behavior of the matcher should be.
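
To make the distinction concrete, here is a minimal sketch in plain Python. The dicts are only a stand-in for a node's attribute map (they are not the onnxscript IR), and `naive_attr_match` is a hypothetical helper that illustrates strict, syntactic attribute matching.

```python
# Two Conv nodes that behave identically: one spells out the default
# dilations, the other omits the attribute entirely.
explicit_conv = {"op_type": "Conv", "attributes": {"dilations": [1, 1]}}
implicit_conv = {"op_type": "Conv", "attributes": {}}


def naive_attr_match(node, required_attrs):
    """Hypothetical strict matcher: every required attribute must be stored
    on the node with exactly the required value."""
    attrs = node["attributes"]
    return all(
        name in attrs and attrs[name] == value
        for name, value in required_attrs.items()
    )


pattern_attrs = {"dilations": [1, 1]}
explicit_matches = naive_attr_match(explicit_conv, pattern_attrs)
implicit_matches = naive_attr_match(implicit_conv, pattern_attrs)
print(explicit_matches, implicit_matches)
```

Under this strict reading only the node that stores the attribute matches, even though both nodes compute the same convolution.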

@justinchuby added the "category: question" and "module: rewriter" labels on Jan 16, 2025
@leshabirukov
Contributor Author

Well, it could be done outside of the matcher, if one could force defaults for a node.

But as I found, the defaults may not be formalized at all!

from onnx import helper, reference

schema = reference.op_run.get_schema('Conv', 15)
if schema.attributes:
    for _, attr in sorted(schema.attributes.items()):
        default_value = helper.get_attribute_value(attr.default_value)
        print(f'{attr.name=:20}{default_value=}')

attr.name='auto_pad'      default_value=b'NOTSET'
attr.name='dilations'     default_value=None
attr.name='group'         default_value=1
attr.name='kernel_shape'  default_value=None
attr.name='pads'          default_value=None
attr.name='strides'       default_value=None

onnx.reference.op_run.get_schema('Conv', 11).attributes['dilations']

OpSchema.Attribute(name='dilations', type=<AttrType.INTS: 7>, description='dilation value along each spatial axis of the filter. If not present, the dilation defaults is 1 along each spatial axis.', default_value=, required=False)

I wonder how this is handled in onnxruntime. Is it just hardcoded?

@leshabirukov
Contributor Author

I would rather put the defaults handler outside of the matcher, as a function like force_defaults(node) => node, forced keys, probably paired with a strip_defaults counterpart.
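
A minimal sketch of what such a pair could look like, in plain Python over an attribute dict. Everything here is an assumption made for illustration: `CONV_DEFAULTS` is a hand-written table (not pulled from the ONNX schema), the functions work on dicts rather than IR nodes, and the rank-dependent defaults follow the wording of the Conv specification (1 along each spatial axis for dilations and strides, 0 for pads).

```python
import copy

# Hypothetical defaults table for Conv, written by hand for this sketch.
# Static defaults are plain values; rank-dependent ones are callables that
# take the number of spatial axes.
CONV_DEFAULTS = {
    "auto_pad": "NOTSET",
    "group": 1,
    "dilations": lambda rank: [1] * rank,
    "strides": lambda rank: [1] * rank,
    "pads": lambda rank: [0] * (2 * rank),
}


def _resolve(default, spatial_rank):
    """Evaluate a rank-dependent default, or return a static one unchanged."""
    return default(spatial_rank) if callable(default) else default


def force_defaults(attrs, spatial_rank, defaults=CONV_DEFAULTS):
    """Return (new_attrs, forced_keys) with every missing default filled in."""
    new_attrs = copy.deepcopy(attrs)
    forced_keys = []
    for name, default in defaults.items():
        if name not in new_attrs:
            new_attrs[name] = _resolve(default, spatial_rank)
            forced_keys.append(name)
    return new_attrs, forced_keys


def strip_defaults(attrs, spatial_rank, defaults=CONV_DEFAULTS):
    """Return a copy of attrs without the entries that equal their default."""
    return {
        name: value
        for name, value in attrs.items()
        if name not in defaults
        or value != _resolve(defaults[name], spatial_rank)
    }


attrs = {"kernel_shape": [3, 3], "strides": [2, 2]}
full, forced = force_defaults(attrs, spatial_rank=2)
minimal = strip_defaults(full, spatial_rank=2)
print(forced)
print(minimal)
```

Returning the forced keys alongside the attributes would let a caller later remove only what was added, instead of stripping every value that happens to equal a default.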

@Johansmm
Contributor

Johansmm commented Mar 5, 2025

I had the same problem. For now I solved it through a condition function:

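from onnxscript.rewriter.pattern import RewriteRule

def target_pattern(op, x, w):
    # Name the Conv output so the condition function can reach the node
    return op.Conv(x, w, _outputs=["y"])

def replacement_pattern(op, x, w, **__):
    # Naïve pattern
    return op.Identity(op.Conv(x, w))

def condition_fn(*_, y, **__):
    ir_node = y.producer()
    dilations = ir_node.attributes.get("dilations", None)
    # A missing attribute means the default: dilation 1 along each spatial axis
    return dilations is None or list(dilations.value) == [1, 1]

rule = RewriteRule(target_pattern, replacement_pattern, condition_fn)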

I tried the following target_pattern:

def target_pattern(op, x, w, dilations=[1, 1]):
    return op.Conv(x, w, dilations=dilations)

But unfortunately the condition was never triggered.

Let me know if there is a simpler solution.

@justinchuby changed the title from "substitute defaults on deserialization" to "Pattern match on attributes?" on Mar 5, 2025
@justinchuby changed the title from "Pattern match on attributes?" to "Pattern match on default attributes?" on Mar 5, 2025
@leshabirukov
Contributor Author

@Johansmm hey, yes, it is better than my force/strip_defaults!

@Johansmm
Contributor

Johansmm commented Mar 9, 2025

But @leshabirukov what do you think of my suggestion? Let the user define the default value in the target function:

def target_pattern(op, x, w, dilations=[1, 1]):
    return op.Conv(x, w, dilations=dilations)

Then replacement_pattern receives dilations = [1, 1] if the matcher finds a Conv node without this attribute. Otherwise (dilations is defined on the Conv node), the value must be equal to [1, 1] to match.
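
The semantics proposed above can be sketched in a few lines of plain Python. This is not how the onnxscript matcher works today; `match_attr_with_default` is a hypothetical helper, and the node is reduced to a dict of attributes.

```python
_MISSING = object()


def match_attr_with_default(node_attrs, name, pattern_default):
    """Return (matched, bound_value) under the proposed semantics.

    - attribute absent on the node: match, bind the pattern's default
    - attribute present: match only if it equals the pattern's default
    """
    value = node_attrs.get(name, _MISSING)
    if value is _MISSING:
        return True, pattern_default
    if value == pattern_default:
        return True, value
    return False, None


absent = match_attr_with_default({}, "dilations", [1, 1])
same = match_attr_with_default({"dilations": [1, 1]}, "dilations", [1, 1])
other = match_attr_with_default({"dilations": [2, 2]}, "dilations", [1, 1])
print(absent, same, other)
```

In the first two cases the replacement function would receive [1, 1]; in the third the rule would not fire.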

@leshabirukov
Contributor Author

@Johansmm, I hesitate here.

I agree, it is neat and has understandable semantics.

Questions to think about:
It has the same issue I don't like in my own solution: it does not preserve the defaulted / not-defaulted state. I prefer 'defaulted' because it leads to a minimal data structure, but a general user may have other preferences.

Are the defaults supposed to be anchored to the ONNX standard? I mean, is it OK to set the defaulted Pad mode to 'reflect' (because your hardware prefers it), even though the standard says the default mode is 'constant'?
If we obey the standard, it is better to pull the defaults from the schema. But the schema default for Conv's dilations is funny.

Implementation may unveil other issues. I can't say exactly how to parse the following:

    def target_pattern(op, x, w, dilations=[1, 1]):
        y = op.Conv(x, w, dilations=dilations)
        return op.Conv(y, w, dilations=dilations)
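
One possible reading of such a pattern, sketched in plain Python: each Conv resolves the shared parameter to its stored value or to the pattern's default, and the pattern matches only if every node resolves to that default. Both helpers are hypothetical and operate on attribute dicts, not on the rewriter's IR; other readings (for example, binding per node) are equally conceivable.

```python
def resolve(node_attrs, name, pattern_default):
    """Stored value if the node has the attribute, else the pattern's default."""
    return node_attrs.get(name, pattern_default)


def match_shared_attr(nodes_attrs, name, pattern_default):
    """Hypothetical rule: every node must resolve `name` to the pattern default."""
    return all(
        resolve(attrs, name, pattern_default) == pattern_default
        for attrs in nodes_attrs
    )


# First Conv stores dilations explicitly, second Conv relies on the default.
mixed_ok = match_shared_attr([{"dilations": [1, 1]}, {}], "dilations", [1, 1])
# First Conv uses a different dilation, so the chain should not match.
mixed_bad = match_shared_attr([{"dilations": [2, 2]}, {}], "dilations", [1, 1])
print(mixed_ok, mixed_bad)
```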

@Johansmm
Contributor

@leshabirukov I agree with you that it might be better to infer the default values from the schema, but I don't think it is the purpose of onnxscript to hardcode all the parameters that are not defined in the ONNX schemas.

What I like about my solution is that replacement_pattern receives exactly the value I define in target_pattern: whether I define dilations=[1,1] or dilations="hello_world", that is the value I get as the dilations argument of replacement_pattern, and with it I can decide what I want to do. Naive example:

def target_pattern(op, x, w, dilations = "hello_world"):
    return op.Conv(x, w, dilations=dilations)

def replacement_pattern(op, x, w, dilations):
    if dilations == "hello_world":
        # Remove dilations
        y = op.Conv(x, w)
    else:
        # Include dilations
        y = op.Conv(x, w, dilations=dilations)
    return y

I'm not sure it's a good solution, but at least I think the user has more control.

@leshabirukov
Contributor Author

Meh. I hesitate to judge, because I am not the one who will implement it. Behind the nice and convenient python-function-as-pattern interface, I sense a lot of support work to keep it consistent with Python rules and users' expectations.
As I see it, the function's formal parameters are bound 'from the inside out': in the 'dilations=[1,1]' notation, the 'dilations' part "goes backward" to bind parameters of the restriction (condition) and patch (replacement) functions, while at the same time '[1,1]' "goes forward" and affects the pattern body. It makes me feel kind of uneasy.

And mechanics that already work sometimes fail:
#2036

@Johansmm
Contributor

@leshabirukov and how about including your force_defaults idea through an additional parameter? Something like what _allow_other_inputs and _allow_other_attributes do?

e.g.:

def target_pattern(op, x, w):
    return op.Conv(x, w, force_defaults=True)

def replacement_pattern(op, x, w, dilations, **__):
    # Naïve pattern
    return op.Identity(op.Conv(x, w, dilations=dilations))

rule = RewriteRule(target_pattern, replacement_pattern)

I could work on this, but I would like to define a workable solution.

@leshabirukov
Contributor Author

@Johansmm, don't let my groans discourage you! I have no objection to the proposed syntax; it is nice and clear. It is the implementation that concerns me. If you or somebody else manages to make it work smoothly, it would be a powerful addition to the rewriter.

Still, I would take a pause and gather more use cases; the good thing is that the condition function variant is workable. Clarification from the devs about the schema defaults would be especially great.

As for force_defaults, I still can't see how it would work for an optional input such as Conv's bias.
