
Commit f6b32b4

titaiwangms authored and Zhao-Xu Luo committed
Support Gemma3 with Clip fused attention (microsoft#24280)
### Description

Essentially, the vision model is traced differently (this time without an attention mask), and the input indices of op.Add and op.MatMul can differ. In addition, fp16 and fp32 need different tracing patterns (op.Cast). This change:

1. Adds another traced pattern to CLIP attention to cover the no-attention_mask case.
2. Accepts a different input index on op.Add and op.MatMul (more general matching).
3. Handles the different fp16 and fp32 patterns (op.Cast after op.Softmax).
4. Refactors test_fastgelu.py to cover torch.onnx.export(..., dynamo=True).
5. Adds a Gemma3 vision attention (SigLip) test covering both fp16 and fp32.

### Motivation and Context

These changes are needed to optimize the Gemma3 multi-modal model: https://huggingface.co/google/gemma-3-4b-it

NOTE: related follow-ups (upstream optimizations to the onnxscript optimizer): microsoft/onnxscript#2158, microsoft/onnxscript#2156
1 parent: bb1376c · commit: f6b32b4
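For orientation, the qk patterns the fusion must now recognize (points 1-3 of the description) can be summarized as backward walks from the attention's final MatMul (scores x V). The following is a sketch distilled from the fusion_attention_clip.py diff below; `QK_PATTERNS` is an illustrative name, not part of the ORT source:

```python
# Each entry pairs the op types walked backward from the final MatMul with
# the parent input index taken at each step (0-based; see diff for usage).
QK_PATTERNS = [
    (["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]),  # fp32, with attention_mask
    (["Softmax", "Mul", "MatMul"], [0, 0, 0]),  # fp32, no attention_mask
    (["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"],  # fp16 inserts a Cast pair
     [0, 0, 0, 0, 0, 0]),                                  # after Softmax, with mask
    (["Cast", "Cast", "Softmax", "Mul", "MatMul"],         # fp16, no attention_mask
     [0, 0, 0, 0, 0]),
]
```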

File tree

9 files changed (+336 −56 lines)

onnxruntime/python/tools/transformers/fusion_attention_clip.py (+62 −21)
@@ -126,7 +126,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             if node_before_layer_norm is None:
                 continue
             child = self.model.find_first_child_by_type(
-                node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False
+                node_before_layer_norm,
+                "LayerNormalization",
+                input_name_to_nodes,
+                False,
             )
             if child is None:
                 continue
@@ -146,19 +149,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         qkv_nodes = self.model.match_parent_path(
             normalize_node,
             ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
-            [1, 1, 0, 0, 0],
+            [1, None, 0, 0, 0],
         )
         if qkv_nodes is None:
             logger.debug("fuse_attention: failed to match qkv path")
             return
-
-        reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[2], qkv_nodes[3], qkv_nodes[-1]
+        reshape_qkv, transpose_qkv, matmul_qkv = (
+            qkv_nodes[2],
+            qkv_nodes[3],
+            qkv_nodes[-1],
+        )

         v_nodes = self.model.match_parent_path(
-            matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
+            matmul_qkv,
+            ["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 0, 0, None],
         )
         if v_nodes is None:
-            v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1])
+            v_nodes = self.model.match_parent_path(
+                matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]
+            )
         if v_nodes is None:
             logger.debug("fuse_attention: failed to match v path")
             return
@@ -182,17 +192,30 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         )
         if qk_nodes is None:
             qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
-            if qk_nodes is None:
-                qk_nodes = self.model.match_parent_path(
-                    matmul_qkv, ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0, 0, 0]
-                )
-                if qk_nodes is None:
-                    logger.debug("fuse_attention: failed to match qk path")
-                    return
-                else:
-                    add_mask = qk_nodes[3]
-            else:
+            if qk_nodes is not None:
                 add_mask = qk_nodes[1]
+            else:
+                # If attention mask is not used, we can still match the qk path.
+                qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
+                if qk_nodes is None:
+                    # Cast nodes are added in the model for fp16.
+                    qk_nodes = self.model.match_parent_path(
+                        matmul_qkv,
+                        ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"],
+                        [0, 0, 0, 0, 0, 0],
+                    )
+                    if qk_nodes is not None:
+                        add_mask = qk_nodes[3]
+                    else:
+                        # If attention mask is not used, we can still match the qk path.
+                        qk_nodes = self.model.match_parent_path(
+                            matmul_qkv,
+                            ["Cast", "Cast", "Softmax", "Mul", "MatMul"],
+                            [0, 0, 0, 0, 0],
+                        )
+                        if qk_nodes is None:
+                            logger.debug("fuse_attention: failed to match qk path")
+                            return
         else:
             assert len(add_mask_indices) == 1
             causal_mask_input_index = 1 - add_mask_indices[0]
@@ -201,10 +224,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         matmul_qk = qk_nodes[-1]

         q_nodes = self.model.match_parent_path(
-            matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
+            matmul_qk,
+            ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
+            [0, 0, 0, 0, None, None],
         )
         if q_nodes is None:
-            q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 1])
+            q_nodes = self.model.match_parent_path(
+                matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]
+            )
         if q_nodes is None:
             logger.debug("fuse_attention: failed to match q path")
             return
@@ -216,10 +243,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         add_q, matmul_q = q_nodes[-2], q_nodes[-1]

         k_nodes = self.model.match_parent_path(
-            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
+            matmul_qk,
+            ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 0, 0, 0, None],
         )
         if k_nodes is None:
-            k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1])
+            k_nodes = self.model.match_parent_path(
+                matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]
+            )
         if k_nodes is None:
             logger.debug("fuse_attention: failed to match k path")
             return
@@ -242,7 +273,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         # 4D Add after Q x K'
         add_qk_nodes = self.model.match_parent_path(
             add_mask,
-            ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze", "Reshape", "Reshape", "Cast"],
+            [
+                "Where",
+                "Sub",
+                "Cast",
+                "Expand",
+                "Unsqueeze",
+                "Unsqueeze",
+                "Reshape",
+                "Reshape",
+                "Cast",
+            ],
            [1, 2, 1, 0, 0, 0, 0, 0, 0],
         )
         if add_qk_nodes is not None:
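The recurring generalization in this file is replacing a hard-coded input index (e.g. `[1, 0, 0, 1]`) with `None`, which tells the matcher to try every input of the current node; that is what makes the fusion robust to op.Add and op.MatMul having their operands in either order. A minimal sketch of that matching idea follows, as a toy stand-in for `OnnxModel.match_parent_path` (the real helper lives in onnx_model.py and handles more cases, such as reporting which index matched):

```python
def match_parent_path(node, op_types, input_indices, output_name_to_node):
    """Toy matcher: walk backward from `node`, one parent per step.

    op_types[i] is the expected op type of the i-th parent; input_indices[i]
    selects which input edge to follow, and None tries every input.
    Returns the matched parents in order, or None on mismatch.
    """
    matched = []
    current = node
    for op_type, index in zip(op_types, input_indices):
        # None means "try every input", tolerating swapped operand order
        # on commutative ops such as Add.
        candidates = range(len(current.input)) if index is None else [index]
        parent = None
        for i in candidates:
            if i >= len(current.input):
                continue
            producer = output_name_to_node.get(current.input[i])
            if producer is not None and producer.op_type == op_type:
                parent = producer
                break
        if parent is None:
            return None  # path mismatch; caller falls back to the next pattern
        matched.append(parent)
        current = parent
    return matched
```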

onnxruntime/python/tools/transformers/fusion_fastgelu.py (+12 −5)
@@ -177,13 +177,12 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
             return
         mul_after_mul_half = children[0]

+        # root_node could be None when root_input is graph input
         root_node = self.model.get_parent(
             mul_after_mul_half,
             0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1,
             output_name_to_node,
         )
-        if root_node is None:
-            return

         mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
         if mul_before_tanh is None:
@@ -197,7 +196,13 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
         if add_before_tanh is None:
             return

-        mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node])
+        mul_after_pow = self.model.match_parent(
+            add_before_tanh,
+            "Mul",
+            None,
+            output_name_to_node,
+            exclude=[root_node] if root_node else [],
+        )
         if mul_after_pow is None:
             return

@@ -212,7 +217,9 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
         if not self.model.has_constant_input(pow, 3.0):
             return

-        if pow.input[0] != root_node.output[0]:
+        root_input = mul_after_mul_half.input[0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1]
+
+        if pow.input[0] != root_input:
             return

         subgraph_nodes = [
@@ -236,7 +243,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
         self.nodes_to_remove.extend(subgraph_nodes)
         fused_node = helper.make_node(
             "FastGelu",
-            inputs=[root_node.output[0]],
+            inputs=[root_input],
             outputs=mul_after_mul_half.output,
             name=self.model.create_node_name("FastGelu"),
         )
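For context on this change: fuse_2 matches the tanh approximation of GELU written with an explicit Pow(x, 3). When x is a graph input it has no producer node, so root_node can legitimately be None; the fusion therefore resolves the fused node's input from the tensor name (root_input) instead of root_node.output[0]. A plain PyTorch rendering of the expression being fused (illustrative, not ORT code):

```python
import math

import torch


def fast_gelu(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)));
    # fuse_2 matches this shape via its Pow(x, 3), Mul, Add, and Tanh nodes.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
```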

onnxruntime/test/python/transformers/test_gelu_fusions.py (+42 −29)
@@ -3,6 +3,7 @@
 import unittest

 import torch
+from parameterized import parameterized
 from parity_utilities import find_transformers_source

 if find_transformers_source():
@@ -43,16 +44,6 @@ def forward(self, x):
         return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


-test_cases = [
-    ("huggingface", "Gelu", HuggingfaceGelu),
-    ("huggingface", "FastGelu", HuggingfaceFastGelu),
-    ("huggingface", "QuickGelu", HuggingfaceQuickGelu),
-    ("huggingface", "FastGelu", HuggingfaceTorchGeluTanh),
-    ("megatron", "Gelu", MegatronGelu),
-    ("megatron", "FastGelu", MegatronFastGelu),
-]
-
-
 class TestGeluFusions(unittest.TestCase):
     def verify_node_count(self, bert_model, expected_node_count, test_name):
         for op_type, count in expected_node_count.items():
@@ -62,25 +53,47 @@ def verify_node_count(self, bert_model, expected_node_count, test_name):
             print(f"{op}: {len(bert_model.get_nodes_by_op_type(op))} expected={counter}")
         self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count)

-    def test_fusions(self):
-        for test_case in test_cases:
-            source, operator, model_class = test_case
-            model = model_class()
-            dummy_input = torch.ones(3, dtype=torch.float32)
-            test_name = f"{operator}_{source}"
-            onnx_path = f"{test_name}.onnx"
-            torch.onnx.export(
-                model,
-                (dummy_input),
-                onnx_path,
-                input_names=["input"],
-                output_names=["output"],
-            )
-            optimizer = optimize_model(onnx_path, "bert")
-            # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx")
-            os.remove(onnx_path)
-            expected_node_count = {operator: 1}
-            self.verify_node_count(optimizer, expected_node_count, test_name)
+    @parameterized.expand(
+        [
+            (("huggingface", "Gelu", HuggingfaceGelu), True),
+            (("huggingface", "FastGelu", HuggingfaceFastGelu), True),
+            (("huggingface", "QuickGelu", HuggingfaceQuickGelu), True),
+            (("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), True),
+            (("megatron", "Gelu", MegatronGelu), True),
+            (("megatron", "FastGelu", MegatronFastGelu), True),
+            (("huggingface", "Gelu", HuggingfaceGelu), False),
+            (("huggingface", "FastGelu", HuggingfaceFastGelu), False),
+            (("huggingface", "QuickGelu", HuggingfaceQuickGelu), False),
+            (("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), False),
+            (("megatron", "Gelu", MegatronGelu), False),
+            (("megatron", "FastGelu", MegatronFastGelu), False),
+        ]
+    )
+    def test_fusions(self, test_case, dynamo):
+        source, operator, model_class = test_case
+        model = model_class()
+        dummy_input = torch.ones(3, dtype=torch.float32)
+        test_name = f"{operator}_{source}"
+        onnx_path = f"{test_name}.onnx"
+        torch.onnx.export(
+            model,
+            (dummy_input,),
+            onnx_path,
+            input_names=["input"],
+            output_names=["output"],
+            dynamo=dynamo,
+            optimize=True,  # Only meaningful when dynamo is True
+        )
+        optimizer = optimize_model(onnx_path, "bert")
+        # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx")
+        os.remove(onnx_path)
+        # Remove the associated .data file (dynamo)
+        data_path = onnx_path + ".data"
+        if os.path.exists(data_path):
+            os.remove(data_path)
+        expected_node_count = {operator: 1}
+
+        self.verify_node_count(optimizer, expected_node_count, test_name)


 if __name__ == "__main__":
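As a standalone sketch of the new test flow: the dynamo exporter may store initializers in a `<model>.onnx.data` sidecar file, which is why the test now deletes it after optimization. `TinyGelu` below is a made-up module, not from the test file:

```python
import os

import torch


class TinyGelu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")


# Export via both the legacy TorchScript exporter and the dynamo exporter.
for dynamo in (False, True):
    path = f"tiny_gelu_dynamo_{dynamo}.onnx"
    torch.onnx.export(
        TinyGelu(),
        (torch.ones(3, dtype=torch.float32),),
        path,
        input_names=["input"],
        output_names=["output"],
        dynamo=dynamo,
        optimize=True,  # only meaningful on the dynamo path
    )
    # ... load/optimize the model here ...
    os.remove(path)
    data_path = path + ".data"  # dynamo may write initializers here
    if os.path.exists(data_path):
        os.remove(data_path)
```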
