@@ -126,7 +126,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
126
126
if node_before_layer_norm is None :
127
127
continue
128
128
child = self .model .find_first_child_by_type (
129
- node_before_layer_norm , "LayerNormalization" , input_name_to_nodes , False
129
+ node_before_layer_norm ,
130
+ "LayerNormalization" ,
131
+ input_name_to_nodes ,
132
+ False ,
130
133
)
131
134
if child is None :
132
135
continue
@@ -146,19 +149,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
146
149
qkv_nodes = self .model .match_parent_path (
147
150
normalize_node ,
148
151
["Add" , "MatMul" , "Reshape" , "Transpose" , "MatMul" ],
149
- [1 , 1 , 0 , 0 , 0 ],
152
+ [1 , None , 0 , 0 , 0 ],
150
153
)
151
154
if qkv_nodes is None :
152
155
logger .debug ("fuse_attention: failed to match qkv path" )
153
156
return
154
-
155
- reshape_qkv , transpose_qkv , matmul_qkv = qkv_nodes [2 ], qkv_nodes [3 ], qkv_nodes [- 1 ]
157
+ reshape_qkv , transpose_qkv , matmul_qkv = (
158
+ qkv_nodes [2 ],
159
+ qkv_nodes [3 ],
160
+ qkv_nodes [- 1 ],
161
+ )
156
162
157
163
v_nodes = self .model .match_parent_path (
158
- matmul_qkv , ["Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 0 , None ]
164
+ matmul_qkv ,
165
+ ["Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ],
166
+ [1 , 0 , 0 , 0 , None ],
159
167
)
160
168
if v_nodes is None :
161
- v_nodes = self .model .match_parent_path (matmul_qkv , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 1 ])
169
+ v_nodes = self .model .match_parent_path (
170
+ matmul_qkv , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , None ]
171
+ )
162
172
if v_nodes is None :
163
173
logger .debug ("fuse_attention: failed to match v path" )
164
174
return
@@ -182,17 +192,30 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
182
192
)
183
193
if qk_nodes is None :
184
194
qk_nodes = self .model .match_parent_path (matmul_qkv , ["Softmax" , "Add" , "Mul" , "MatMul" ], [0 , 0 , 0 , 0 ])
185
- if qk_nodes is None :
186
- qk_nodes = self .model .match_parent_path (
187
- matmul_qkv , ["Cast" , "Cast" , "Softmax" , "Add" , "Mul" , "MatMul" ], [0 , 0 , 0 , 0 , 0 , 0 ]
188
- )
189
- if qk_nodes is None :
190
- logger .debug ("fuse_attention: failed to match qk path" )
191
- return
192
- else :
193
- add_mask = qk_nodes [3 ]
194
- else :
195
+ if qk_nodes is not None :
195
196
add_mask = qk_nodes [1 ]
197
+ else :
198
+ # If attention mask is not used, we can still match the qk path.
199
+ qk_nodes = self .model .match_parent_path (matmul_qkv , ["Softmax" , "Mul" , "MatMul" ], [0 , 0 , 0 ])
200
+ if qk_nodes is None :
201
+ # Cast nodes are added in the model for fp16.
202
+ qk_nodes = self .model .match_parent_path (
203
+ matmul_qkv ,
204
+ ["Cast" , "Cast" , "Softmax" , "Add" , "Mul" , "MatMul" ],
205
+ [0 , 0 , 0 , 0 , 0 , 0 ],
206
+ )
207
+ if qk_nodes is not None :
208
+ add_mask = qk_nodes [3 ]
209
+ else :
210
+ # If attention mask is not used, we can still match the qk path.
211
+ qk_nodes = self .model .match_parent_path (
212
+ matmul_qkv ,
213
+ ["Cast" , "Cast" , "Softmax" , "Mul" , "MatMul" ],
214
+ [0 , 0 , 0 , 0 , 0 ],
215
+ )
216
+ if qk_nodes is None :
217
+ logger .debug ("fuse_attention: failed to match qk path" )
218
+ return
196
219
else :
197
220
assert len (add_mask_indices ) == 1
198
221
causal_mask_input_index = 1 - add_mask_indices [0 ]
@@ -201,10 +224,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
201
224
matmul_qk = qk_nodes [- 1 ]
202
225
203
226
q_nodes = self .model .match_parent_path (
204
- matmul_qk , ["Reshape" , "Transpose" , "Reshape" , "Mul" , "Add" , "MatMul" ], [0 , 0 , 0 , 0 , None , None ]
227
+ matmul_qk ,
228
+ ["Reshape" , "Transpose" , "Reshape" , "Mul" , "Add" , "MatMul" ],
229
+ [0 , 0 , 0 , 0 , None , None ],
205
230
)
206
231
if q_nodes is None :
207
- q_nodes = self .model .match_parent_path (matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [0 , 0 , 0 , 1 ])
232
+ q_nodes = self .model .match_parent_path (
233
+ matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [0 , 0 , 0 , None ]
234
+ )
208
235
if q_nodes is None :
209
236
logger .debug ("fuse_attention: failed to match q path" )
210
237
return
@@ -216,10 +243,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
216
243
add_q , matmul_q = q_nodes [- 2 ], q_nodes [- 1 ]
217
244
218
245
k_nodes = self .model .match_parent_path (
219
- matmul_qk , ["Transpose" , "Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 0 , 0 , None ]
246
+ matmul_qk ,
247
+ ["Transpose" , "Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ],
248
+ [1 , 0 , 0 , 0 , 0 , None ],
220
249
)
221
250
if k_nodes is None :
222
- k_nodes = self .model .match_parent_path (matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 1 ])
251
+ k_nodes = self .model .match_parent_path (
252
+ matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , None ]
253
+ )
223
254
if k_nodes is None :
224
255
logger .debug ("fuse_attention: failed to match k path" )
225
256
return
@@ -242,7 +273,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
242
273
# 4D Add after Q x K'
243
274
add_qk_nodes = self .model .match_parent_path (
244
275
add_mask ,
245
- ["Where" , "Sub" , "Cast" , "Expand" , "Unsqueeze" , "Unsqueeze" , "Reshape" , "Reshape" , "Cast" ],
276
+ [
277
+ "Where" ,
278
+ "Sub" ,
279
+ "Cast" ,
280
+ "Expand" ,
281
+ "Unsqueeze" ,
282
+ "Unsqueeze" ,
283
+ "Reshape" ,
284
+ "Reshape" ,
285
+ "Cast" ,
286
+ ],
246
287
[1 , 2 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ],
247
288
)
248
289
if add_qk_nodes is not None :
0 commit comments