Skip to content

Commit bb1376c

Browse files
fs-eire (Zhao-Xu Luo)
authored and
Zhao-Xu Luo
committed
[WebGPU] fix cache key of AttentionProbs/VxAttentionScore (microsoft#24309)
### Description: Fix the cache-key inconsistency of the AttentionProbs/VxAttentionScore programs. `n_reps` is already provided via uniforms, so the shader code should read it from uniforms instead of using a hardcoded value.
1 parent 61cde37 commit bb1376c

File tree

4 files changed

+13
-16
lines changed

4 files changed

+13
-16
lines changed

.github/actions/webgpu-validate-shader-key/action.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ runs:
2222
working-directory: ${{ github.action_path }}
2323

2424
- name: Validate shader keys (native log)
25-
if: ${{ !inputs.is_chromium_log != 'true' }}
25+
if: ${{ inputs.is_chromium_log != 'true' }}
2626
shell: cmd
2727
run: |
2828
node validate-shader-key.js < "${{ inputs.log_file_path }}"

.github/workflows/windows-web-ci-workflow.yml

-1
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,6 @@ jobs:
200200

201201
- name: Validate shader keys - WebGPU EP
202202
if: ${{ inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }}
203-
continue-on-error: true
204203
uses: ./.github/actions/webgpu-validate-shader-key
205204
with:
206205
log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log

onnxruntime/contrib_ops/webgpu/bert/attention.cc

+8-8
Original file line number | Diff line number | Diff line change
@@ -108,9 +108,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
108108
std::ostringstream oss;
109109
InitVarStub(oss, seqlen_k_);
110110
shader.MainFunctionBody() << oss.str();
111-
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
111+
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
112112
if (has_present_key_) {
113-
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n";
113+
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n";
114114
}
115115

116116
shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
@@ -123,7 +123,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
123123

124124
if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
125125
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
126-
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
126+
<< " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n"
127127
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
128128
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
129129
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
@@ -181,7 +181,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
181181
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
182182

183183
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
184-
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
184+
components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
185185
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
186186
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
187187
if (feed_past_key) {
@@ -331,9 +331,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
331331
std::ostringstream oss;
332332
InitVarStub(oss, seqlen_k_);
333333
shader.MainFunctionBody() << oss.str();
334-
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
334+
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n";
335335
if (has_present_value_) {
336-
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n";
336+
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n";
337337
}
338338

339339
shader.MainFunctionBody() << "var value = output_value_t(0);\n"
@@ -346,7 +346,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
346346

347347
if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
348348
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
349-
<< " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n"
349+
<< " let pastValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n"
350350
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
351351
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
352352
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
@@ -400,7 +400,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
400400
const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 2 : 1);
401401
constexpr int tile_size = 12;
402402
int tile_n_size = tile_size * components;
403-
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
403+
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
404404
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
405405
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
406406
if (feed_past_value) {

onnxruntime/contrib_ops/webgpu/bert/attention.h

+4-6
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
3434
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
3535
public:
3636
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
37-
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
38-
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
37+
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
38+
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
3939
}
4040

4141
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -60,7 +60,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
6060
bool has_attention_bias_;
6161
int tile_size_;
6262
int components_;
63-
int n_reps_;
6463
const Tensor* seqlen_k_;
6564
bool past_present_share_buffer_;
6665
bool is_first_prompt_;
@@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
9089

9190
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
9291
public:
93-
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
94-
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
92+
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
93+
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
9594
}
9695

9796
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -114,7 +113,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
114113
bool feed_past_value_;
115114
bool has_present_value_;
116115
int tile_size_;
117-
int n_reps_;
118116
const Tensor* seqlen_k_;
119117
bool past_present_share_buffer_;
120118
bool is_first_prompt_;

0 commit comments

Comments
 (0)