@@ -108,9 +108,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
108
108
std::ostringstream oss;
109
109
InitVarStub (oss, seqlen_k_);
110
110
shader.MainFunctionBody () << oss.str ();
111
- shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.kv_sequence_length * uniforms.K;\n " ;
111
+ shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.kv_sequence_length * uniforms.K;\n " ;
112
112
if (has_present_key_) {
113
- shader.MainFunctionBody () << " let presentKeyOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.present_sequence_length * uniforms.K;\n " ;
113
+ shader.MainFunctionBody () << " let presentKeyOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.present_sequence_length * uniforms.K;\n " ;
114
114
}
115
115
116
116
shader.MainFunctionBody () << " var value = f32_val_t(0);\n "
@@ -123,7 +123,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
123
123
124
124
if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
125
125
shader.MainFunctionBody () << " if (n + local_id.y < past_sequence_length) {\n "
126
- << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.past_sequence_length * uniforms.K;\n "
126
+ << " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.past_sequence_length * uniforms.K;\n "
127
127
<< " tileK[idx] = " << (past_present_share_buffer_ ? " present_key" : " past_key" ) << " [pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n "
128
128
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
129
129
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n "
@@ -181,7 +181,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
181
181
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1 );
182
182
183
183
AttentionProbsProgram program{" AttentionProbs" , feed_past_key, has_present_key, has_attention_bias, tile_size,
184
- components, parameters.is_first_prompt_ , parameters. n_reps , seqlen_k, parameters.past_present_share_buffer_ };
184
+ components, parameters.is_first_prompt_ , seqlen_k, parameters.past_present_share_buffer_ };
185
185
program.AddInputs ({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
186
186
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
187
187
if (feed_past_key) {
@@ -331,9 +331,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
331
331
std::ostringstream oss;
332
332
InitVarStub (oss, seqlen_k_);
333
333
shader.MainFunctionBody () << oss.str ();
334
- shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
334
+ shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
335
335
if (has_present_value_) {
336
- shader.MainFunctionBody () << " let presentValueOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.present_sequence_length + n;\n " ;
336
+ shader.MainFunctionBody () << " let presentValueOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.N * uniforms.present_sequence_length + n;\n " ;
337
337
}
338
338
339
339
shader.MainFunctionBody () << " var value = output_value_t(0);\n "
@@ -346,7 +346,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
346
346
347
347
if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
348
348
shader.MainFunctionBody () << " if (w + local_id.y < past_sequence_length) {\n "
349
- << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.past_sequence_length + n;\n "
349
+ << " let pastValueOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.N * uniforms.past_sequence_length + n;\n "
350
350
<< " tileK[idx] = " << (past_present_share_buffer_ ? " present_value" : " past_value" ) << " [pastValueOffset + (w + local_id.y) * uniforms.N];\n "
351
351
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
352
352
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n "
@@ -400,7 +400,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
400
400
const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 2 : 1 );
401
401
constexpr int tile_size = 12 ;
402
402
int tile_n_size = tile_size * components;
403
- VxAttentionScoreProgram program{" VxAttentionScore" , feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_ , parameters. n_reps , seqlen_k, parameters.past_present_share_buffer_ };
403
+ VxAttentionScoreProgram program{" VxAttentionScore" , feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_ , seqlen_k, parameters.past_present_share_buffer_ };
404
404
program.AddInputs ({{probs, ProgramTensorMetadataDependency::TypeAndRank},
405
405
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
406
406
if (feed_past_value) {
0 commit comments