-
Notifications
You must be signed in to change notification settings - Fork 207
/
Copy pathmodeling_qwen2_vl_network.py
1965 lines (1686 loc) · 87.2 KB
/
modeling_qwen2_vl_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Paddle Qwen2-VL model."""
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import paddle
import paddle.distributed as dist
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.auto_parallel.local_layer import LocalLayer
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddlenlp.transformers.configuration_utils import PretrainedConfig
from paddlenlp.transformers.linear_utils import Linear
from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
from paddlenlp.transformers.model_utils import PretrainedModel
from paddlemix.models.flash_attn_utils import (
create_attention_module,
has_flash_attn_func,
)
from ppdiffusers.utils import logging
from ...activations import ACT2FN
from .bert_padding import ( # index_first_axis,; pad_input,
IndexFirstAxis,
IndexPutFirstAxis,
unpad_input,
)
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
# Flash-attention callables as reported by `has_flash_attn_func()`.
# NOTE(review): presumably these are None when flash attention is unavailable - confirm in flash_attn_utils.
flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
# True when the device string reported by Paddle contains "npu".
_IS_NPU = "npu" in paddle.get_device()
def get_triangle_upper_mask(x, mask=None):
    """Return `mask` if one is supplied, otherwise build an additive causal mask shaped after `x`.

    The generated mask takes the shape of `x` with axis 1 set to 1, holds the most negative finite
    value of `x.dtype` strictly above the main diagonal and zero elsewhere, and is excluded from
    gradient computation.
    """
    if mask is not None:
        return mask
    # x: [bsz, n_head, q_len, kv_seq_len] -> mask: [bsz, 1, q_len, kv_seq_len]
    mask_shape = x.shape
    mask_shape[1] = 1
    filled = paddle.full(mask_shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
    causal = paddle.triu(filled, diagonal=1)
    causal.stop_gradient = True
    return causal
def _compute_default_rope_parameters(
    config: Optional["PretrainedConfig"] = None,
    device: Optional["paddle.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["paddle.Tensor", float]:
    """
    Computes the inverse frequencies according to the original RoPE implementation.

    Exactly one of `config` and `rope_kwargs` must be supplied.

    Args:
        config (`PretrainedConfig`, *optional*):
            The model configuration. `rope_theta` is read from it, together with `partial_rotary_factor`
            (defaults to 1.0 when absent) and `head_dim` (defaults to
            `hidden_size // num_attention_heads` when absent).
        device (`paddle.device`, *optional*):
            Accepted for interface compatibility; not used by this function.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            Backward-compatible alternative to `config`; must contain the keys `"base"` and `"dim"`.
    Returns:
        Tuple of (`paddle.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (always 1.0 for this type of RoPE).
    Raises:
        ValueError: If both `config` and `rope_kwargs` are given, or if neither is given.
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)
    else:
        # Without this guard `base` and `dim` would be unbound and the failure would surface
        # as an opaque UnboundLocalError further down.
        raise ValueError(
            "`_compute_default_rope_parameters` requires either `config` or `**rope_kwargs` "
            "(with `base` and `dim`) to be provided."
        )
    attention_factor = 1.0  # Unused in this type of RoPE
    # Compute the inverse frequencies: 1 / base^(2i / dim) for i in [0, dim / 2)
    inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim))
    return inv_freq, attention_factor
# Registry mapping a RoPE type name to the function that computes its inverse frequencies.
# Only the "default" type is registered here.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
}
def _get_unpad_data(attention_mask):
    """Derive the indexing data used to strip padding from a batch.

    Returns:
        A tuple `(indices, cu_seqlens, max_seqlen_in_batch)`: the flat positions of the non-zero mask
        entries, the per-row mask sums accumulated along axis 0 with a leading zero prepended, and the
        largest per-row mask sum as a Python number.
    """
    tokens_per_row = attention_mask.sum(axis=-1, dtype="int32")
    flat_mask = attention_mask.flatten()
    indices = paddle.nonzero(flat_mask, as_tuple=False).flatten()
    longest = tokens_per_row.max().item()
    running_total = paddle.cumsum(tokens_per_row, axis=0)
    # Prepend one zero so that consecutive entries delimit each sequence.
    cu_seqlens = F.pad(running_total, (1, 0), data_format="NCL")
    return indices, cu_seqlens, longest
def is_casual_mask(attention_mask):
    """
    Return True when `attention_mask` is equal, element by element, to its own upper-triangular part.
    """
    upper_part = paddle.triu(attention_mask)
    return (upper_part == attention_mask).all().item()
def _make_causal_mask(input_ids_shape, past_key_values_length):
    """
    Build the boolean causal mask used for self-attention.

    The result has shape `[batch_size, 1, tgt_len, tgt_len + past_len]`; every cached (past) position is
    marked visible and the current positions follow a lower-triangular pattern.
    """
    batch_size, target_length = input_ids_shape
    total_length = target_length + past_key_values_length
    visible = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
    if past_key_values_length > 0:
        # [tgt_len, past_len] block of ones placed in front of the triangular part.
        past_block = paddle.ones([target_length, past_key_values_length], dtype="bool")
        visible = paddle.concat([past_block, visible], axis=-1)
    return visible[None, None, :, :].expand([batch_size, 1, target_length, total_length])
def _expand_2d_mask(mask, dtype, tgt_length):
    """
    Broadcast a `[batch_size, src_length]` attention mask to `[batch_size, 1, tgt_length, src_length]`.

    The mask is converted to bool and detached from gradient computation. `tgt_length` falls back to the
    source length when it is None. The `dtype` argument is accepted but not used.
    """
    batch_size = mask.shape[0]
    src_length = mask.shape[-1]
    if tgt_length is None:
        tgt_length = src_length
    bool_mask = mask[:, None, None, :].astype("bool")
    bool_mask.stop_gradient = True
    return bool_mask.expand([batch_size, 1, tgt_length, src_length])
@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
    """
    Output container for the Qwen2VL causal (autoregressive) language model.

    Args:
        loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Next-token prediction (language modeling) loss.
        logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Scores produced by the language modeling head for every vocabulary token, before SoftMax.
        past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            One entry per layer (`config.n_layers` in total), each holding 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
            These are the cached self-attention keys and values; feed them back through the
            `past_key_values` input to speed up sequential decoding.
        hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tensors of shape `(batch_size, sequence_length, hidden_size)`: one for the embedding output
            (if the model has an embedding layer) plus one for the output of each layer.
        attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One tensor per layer of shape `(batch_size, num_heads, sequence_length, sequence_length)`
            holding the post-softmax attention weights used to compute the weighted average in the
            self-attention heads.
        rope_deltas (`paddle.Tensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    loss: Optional[paddle.Tensor] = None
    logits: paddle.Tensor = None
    past_key_values: Optional[List[paddle.Tensor]] = None
    hidden_states: Optional[Tuple[paddle.Tensor]] = None
    attentions: Optional[Tuple[paddle.Tensor]] = None
    rope_deltas: Optional[paddle.Tensor] = None
class Qwen2VLRotaryEmbedding(nn.Layer):
    """Rotary position embedding that produces cos/sin tables for three-component position ids.

    The layer can be parameterized either through `config` or, for backward compatibility, through the
    individual keyword arguments (`dim`, `base`, ...). The inverse frequencies are computed once at
    construction time by the function registered in `ROPE_INIT_FUNCTIONS` for the selected RoPE type.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[Qwen2VLConfig] = None,
    ):
        """Set up the inverse frequencies and the cos/sin cache.

        Args:
            dim: Rotary dimension, used only when `config` is None.
            max_position_embeddings: Length of the cos/sin cache built at construction time; also the
                cached/original maximum sequence length when `config` is None.
            base: RoPE base, used only when `config` is None.
            device: Forwarded to the RoPE init function.
            scaling_factor: Stored as `"factor"` in the backward-compatibility kwargs.
            rope_type: Key into `ROPE_INIT_FUNCTIONS`, used only when `config` is None.
            config: Model configuration; when given, the RoPE type comes from `config.rope_scaling`
                (`"rope_type"`, falling back to `"type"`) or is `"default"` when that is None.
        """
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        self.inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.original_inv_freq = self.inv_freq
        # NOTE(review): this always uses the `max_position_embeddings` argument, so with a `config` the
        # value of `max_seq_len_cached` set above is overwritten - confirm whether that is intended.
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """Precompute `cos_cached` / `sin_cached` for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = paddle.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.inv_freq = inv_freq
            self.max_seq_len_cached = seq_len
        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.inv_freq = self.original_inv_freq
            self.max_seq_len_cached = self.original_max_seq_len

    @paddle.no_grad()
    def forward(self, x, position_ids):
        """Compute the cos/sin rotary tables for `position_ids`, cast to the dtype of `x`.

        Args:
            x: Tensor whose dtype determines the dtype of the returned tables.
            position_ids: Position ids with a leading axis of size 3 (shape `(3, bs, positions)`).
        Returns:
            Tuple `(cos, sin)`, each of shape `[3, bs, positions, dim]`.
        """
        if "dynamic" in self.rope_type:
            # NOTE(review): relies on `x.device` being available on the tensor - confirm for Paddle.
            self._dynamic_frequency_update(position_ids, device=x.device)
        # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
        # So we expand the inv_freq to shape (3, ...)
        inv_freq_expanded = (
            self.inv_freq[None, None, :, None].astype("float32").expand([3, position_ids.shape[1], -1, 1])
        )
        position_ids_expanded = position_ids[:, :, None, :].astype("float32")  # shape (3, bs, 1, positions)
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285).
        # AMP must be explicitly disabled here: `auto_cast()` with no argument *enables* mixed precision,
        # which would let the matmul below run in reduced precision.
        with paddle.amp.auto_cast(False):
            # inv_freq_expanded shape: [3, bs, dim/2, 1]
            # position_ids_expanded shape: [3, bs, 1, positions]
            # Result shape after matmul: [3, bs, dim/2, positions]
            # After transpose: [3, bs, positions, dim/2]
            freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded)
            freqs = freqs.transpose([0, 1, 3, 2])
            emb = paddle.concat((freqs, freqs), axis=-1)
            cos = emb.cos()
            sin = emb.sin()
        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        return cos.astype(x.dtype), sin.astype(x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last axis of `x`, negating the half that moves to the front."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    # Output has the same shape as the input.
    return paddle.concat([-back, front], axis=-1)
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding extends the 1D rotary position embedding. The input sequence
        holds vision (image / video) embeddings and text embeddings, or text embeddings only. For the vision
        part the rotary embedding is applied separately along the temporal, height and width dimensions, which
        is why the channel dimension is split into chunks assigned in turn to those three components. For the
        text part the three position indices are always identical, so the result is the same as the ordinary
        1D rotary embedding used by modern LLMs.

    Args:
        q (`paddle.Tensor`): The query tensor.
        k (`paddle.Tensor`): The key tensor.
        cos (`paddle.Tensor`): The cosine part of the rotary embedding; its leading axis indexes the three
            position components.
        sin (`paddle.Tensor`): The sine part of the rotary embedding, laid out like `cos`.
        mrope_section (`List(int)`):
            Channel sizes of the temporal, height and width sections; the list is repeated twice to cover
            the full last axis of `cos` / `sin`.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis along which the selected cos/sin tables are unsqueezed so that they broadcast against
            q and k. Use 1 when q and k are `[batch_size, heads, seq_len, head_dim]` and 2 when they are
            `[batch_size, seq_len, heads, head_dim]`.
    Returns:
        `tuple(paddle.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    doubled_sections = mrope_section * 2

    def _pick_sections(table):
        # Chunk number `idx` along the channel axis is taken from position component `idx % 3`.
        chunks = table.split(doubled_sections, axis=-1)
        picked = [chunk[idx % 3] for idx, chunk in enumerate(chunks)]
        return paddle.concat(x=picked, axis=-1).unsqueeze(axis=unsqueeze_dim)

    cos = _pick_sections(cos)
    sin = _pick_sections(sin)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> paddle.Tensor:
    """Rotate `tensor` by the angles in `freqs`, computing in float32 with AMP disabled.

    The cos/sin of `freqs` are unsqueezed at axis 1, tiled twice along the last axis and given a leading
    axis of size 1 before being applied. The result is cast back to the dtype `tensor` had on entry.
    """
    input_dtype = tensor.dtype

    def _as_table(values):
        return values.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")

    with paddle.amp.auto_cast(False):
        tensor_fp32 = tensor.astype(dtype="float32")
        cos_table = _as_table(freqs.cos())
        sin_table = _as_table(freqs.sin())
        rotated = tensor_fp32 * cos_table + rotate_half(tensor_fp32) * sin_table
    return paddle.cast(rotated, input_dtype)
class VisionRotaryEmbedding(nn.Layer):
    """Produces rotary angle tables (position times inverse frequency) for the vision tower."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """Store `inv_freq = 1 / theta^(i / dim)` for the even indices `i` in `[0, dim)`."""
        super().__init__()
        exponents = paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim
        self.inv_freq = 1.0 / theta**exponents

    def forward(self, seqlen: int) -> paddle.Tensor:
        """Return the outer product of the positions `0 .. seqlen - 1` with `inv_freq`."""
        positions = paddle.arange(seqlen).cast(self.inv_freq.dtype)
        return paddle.outer(x=positions, y=self.inv_freq)
class PatchEmbed(nn.Layer):
    """Projects flattened spatio-temporal patches to embedding vectors with a bias-free 3D convolution."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        """Create the projection; kernel and stride both equal one full patch, so patches do not overlap."""
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        patch_extent = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3D(in_channels, embed_dim, kernel_size=patch_extent, stride=patch_extent, bias_attr=False)

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        """Reshape the input into individual patches, project them and return a `[-1, embed_dim]` tensor."""
        weight_dtype = self.proj.weight.dtype
        patches = hidden_states.reshape(
            [-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size]
        )
        # `paddle.cast` is used rather than `.to(dtype=...)`: per the original note, a 'Variable'
        # object has no attribute 'to'.
        projected = self.proj(paddle.cast(patches, dtype=weight_dtype))
        return projected.reshape([-1, self.embed_dim])
class PatchMerger(nn.Layer):
    """Layer-normalizes patch features, groups them and maps each group to `dim` channels with an MLP."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        """Build the LayerNorm over `context_dim` and the two-layer GELU MLP over the merged width."""
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.ln_q = nn.LayerNorm(context_dim, epsilon=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Normalize `x`, reshape it to rows of `hidden_size` features and run the MLP."""
        normalized = self.ln_q(x)
        merged = normalized.reshape([-1, self.hidden_size])
        return self.mlp(merged)
class VisionMlp(nn.Layer):
    """Two-layer feed-forward network: linear, activation, linear."""

    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        """Create both linear layers and look up the activation named `hidden_act` in `ACT2FN`."""
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> paddle.Tensor:
        """Apply `fc1`, the activation and `fc2` in sequence."""
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
class VisionAttention(nn.Layer):
    """Eager multi-head self-attention over packed vision tokens.

    Tokens from several sequences are concatenated along axis 0; `cu_seqlens` marks the sequence
    boundaries, and attention is restricted to tokens within the same sequence.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """Create the fused QKV projection and the output projection."""
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
        self.proj = nn.Linear(dim, dim)
        # Needed by `forward` to scale the attention logits.
        self.head_dim = dim // num_heads

    def forward(
        self, hidden_states: paddle.Tensor, cu_seqlens: paddle.Tensor, rotary_pos_emb: paddle.Tensor = None
    ) -> paddle.Tensor:
        """Run block-diagonal self-attention.

        Args:
            hidden_states: Packed token features; axis 0 is the total token count.
            cu_seqlens: Boundaries of the packed sequences; consecutive entries delimit one sequence.
            rotary_pos_emb: Rotary angles applied to the queries and keys.
        Returns:
            Tensor with `seq_length` rows holding the projected attention output.
        """
        seq_length = hidden_states.shape[0]
        # [seq, 3 * dim] -> [3, seq, heads, head_dim] -> q, k, v
        q, k, v = (
            self.qkv(hidden_states).reshape([seq_length, 3, self.num_heads, -1]).transpose([1, 0, 2, 3]).unbind(0)
        )
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        # Boolean block-diagonal mask: True where query and key belong to the same sequence.
        attention_mask = paddle.zeros([1, seq_length, seq_length], dtype="bool")
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        # Convert to an additive mask: 0 where attention is allowed, dtype-min elsewhere.
        zero = paddle.zeros(attention_mask.shape, dtype=hidden_states.dtype)
        neg_inf = paddle.full_like(attention_mask, paddle.finfo(hidden_states.dtype).min, dtype=hidden_states.dtype)
        attention_mask = paddle.where(attention_mask, zero, neg_inf)
        # [seq, heads, head_dim] -> [heads, seq, head_dim]
        q = q.transpose([1, 0, 2])
        k = k.transpose([1, 0, 2])
        v = v.transpose([1, 0, 2])
        attn_weights = paddle.matmul(q, k.transpose([0, 2, 1])) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        # Softmax is evaluated in float32 for numerical stability, then cast back to the query dtype so
        # the matmul with `v` sees matching dtypes when the model runs in half precision.
        attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32").astype(q.dtype)
        attn_output = paddle.matmul(attn_weights, v)
        # [heads, seq, head_dim] -> [seq, heads * head_dim]
        attn_output = attn_output.transpose([1, 0, 2])
        attn_output = attn_output.reshape([seq_length, -1])
        attn_output = self.proj(attn_output)
        return attn_output
class VisionFlashAttention2(nn.Layer):
    """Self-attention over packed vision tokens that dispatches to a flash-attention kernel."""

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """Create the fused QKV projection and the output projection."""
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
        self.proj = nn.Linear(dim, dim)
        # Needed by `forward` to derive the softmax scale.
        self.head_dim = dim // num_heads

    def forward(
        self, hidden_states: paddle.Tensor, cu_seqlens: paddle.Tensor, rotary_pos_emb: paddle.Tensor = None
    ) -> paddle.Tensor:
        """Project to q/k/v, apply the rotary embedding to q and k, attend with flash attention, project out.

        On NPU devices the NPU flash-attention op is used; otherwise the variable-length flash-attention
        function is called with `cu_seqlens` delimiting the packed sequences.
        """
        seq_length = tuple(hidden_states.shape)[0]
        fused = self.qkv(hidden_states).reshape([seq_length, 3, self.num_heads, -1])
        q, k, v = fused.transpose(perm=[1, 0, 2, 3]).unbind(axis=0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
        # The kernels do not support float32, so the inputs are cast to bfloat16.
        q_bf16 = q.astype("bfloat16")
        k_bf16 = k.astype("bfloat16")
        v_bf16 = v.astype("bfloat16")
        if _IS_NPU:
            # TODO: flash_attn_unpadded
            attended = paddle.nn.functional.flash_attention_npu(
                q_bf16,
                k_bf16,
                v_bf16,
                is_varlen=True,
                batch_size=1,
                seq_length=seq_length,
            )
            attn_output = attended.reshape([seq_length, -1])
        else:
            longest = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            # TODO: the softmax scale has to be supplied manually.
            softmax_scale = self.head_dim**-0.5
            # flash_attn_unpadded; element 0 of the returned value is the attention output.
            attended = flash_attn_varlen_func(
                q_bf16,
                k_bf16,
                v_bf16,
                cu_seqlens,
                cu_seqlens,
                longest,
                longest,
                scale=softmax_scale,
            )[0]
            attn_output = attended.squeeze(0).reshape([seq_length, -1])
        return self.proj(attn_output)
class Qwen2VLVisionBlock(nn.Layer):
    """Pre-norm vision transformer block: attention and MLP sub-layers, each with a residual connection."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        """Build the two LayerNorms, the attention module and the MLP.

        The attention module is obtained from `create_attention_module(config, "vision", auto=True)`;
        the `attn_implementation` argument is accepted but not used here.
        """
        super().__init__()
        embed_dim = config.embed_dim
        self.norm1 = nn.LayerNorm(embed_dim, epsilon=1e-6)
        self.norm2 = nn.LayerNorm(embed_dim, epsilon=1e-6)
        mlp_width = int(embed_dim * config.mlp_ratio)
        self.attn = create_attention_module(config, "vision", auto=True)
        self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_width, hidden_act=config.hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
        """Apply normalized attention and then the normalized MLP, adding each result to the input."""
        attended = self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        hidden_states = hidden_states + attended
        fed_forward = self.mlp(self.norm2(hidden_states))
        return hidden_states + fed_forward
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: paddle.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: paddle.dtype,
    min_dtype: float,
    cache_position: paddle.Tensor,
    batch_size: int,
):
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`; a mask that is already 4D is returned unchanged.

    Args:
        attention_mask (`paddle.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. May be None.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`paddle.dtype`):
            The dtype to use for the 4D attention mask.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`paddle.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # A 4D mask is assumed to be already inverted; no inversion or slicing is applied.
        return attention_mask

    mask_2d = paddle.full([sequence_length, target_length], fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask_2d = paddle.triu(x=mask_2d, diagonal=1)
    # Keep the fill value only at key positions beyond each query's cache position.
    mask_2d *= paddle.arange(target_length) > cache_position.reshape([-1, 1])
    mask_4d = mask_2d[None, None, :, :].expand(shape=[batch_size, 1, -1, -1])
    if attention_mask is None:
        return mask_4d

    # Clone before writing into a slice of the expanded mask.
    mask_4d = mask_4d.clone()
    mask_length = attention_mask.shape[-1]
    covered = mask_4d[:, :, :, :mask_length]
    # Entries that are zero after adding the 2D mask are filled with the minimum value.
    is_padding = (covered + attention_mask[:, None, None, :]) == 0
    mask_4d[:, :, :, :mask_length] = covered.masked_fill(mask=is_padding, value=min_dtype)
    return mask_4d
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
class Qwen2RMSNorm(nn.Layer):
    """Root-mean-square layer normalization with a learned per-channel scale."""

    def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm.

        Creates a scale parameter of shape `[hidden_size]` initialized to 1.0 in the default dtype.
        The `config` argument is accepted but not used here.
        """
        super().__init__()
        self.weight = paddle.create_parameter(
            shape=[hidden_size],
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(1.0),
        )
        self.variance_epsilon = eps

    def _rms_normalize(self, hidden_states):
        """Scale `hidden_states` by the reciprocal root of its float32 mean square over the last axis."""
        mean_square = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
        return paddle.rsqrt(mean_square + self.variance_epsilon) * hidden_states

    def forward(self, hidden_states):
        """Normalize `hidden_states` and multiply by the learned weight.

        In dynamic mode the normalization runs with AMP disabled. When the weight is float16 or
        bfloat16 the normalized tensor is cast to the weight dtype before scaling.
        """
        if paddle.in_dynamic_mode():
            with paddle.amp.auto_cast(False):
                normalized = self._rms_normalize(hidden_states)
        else:
            normalized = self._rms_normalize(hidden_states)
        if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
            normalized = paddle.cast(normalized, self.weight.dtype)
        return normalized * self.weight
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
class Qwen2MLP(nn.Layer):
    """Gated feed-forward network: `down_proj(act(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """Create the three bias-free projections and look up the activation from `config.hidden_act`."""
        super().__init__()
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate
        self.fuse_attention_ffn = config.fuse_attention_ffn
        self.gate_proj = Linear(hidden, intermediate, bias_attr=False)  # w1
        self.up_proj = Linear(hidden, intermediate, bias_attr=False)  # w3
        self.down_proj = Linear(intermediate, hidden, bias_attr=False)  # w2
        self.act_fn = ACT2FN[config.hidden_act]
        # Fused SwiGLU is disabled by default; when enabled the activation receives both branches.
        self.fuse_swiglu = False

    def forward(self, x):
        """Compute the gated activation and project back to the hidden size."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        if self.fuse_swiglu:
            gated = self.act_fn(gate, up)
        else:
            gated = self.act_fn(gate) * up
        return self.down_proj(gated)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """
    Repeat every key/value head `n_rep` times, equivalent to paddle.repeat_interleave(x, axis=1, repeats=n_rep).

    The input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim); it is returned as-is when `n_rep` is 1.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    with_repeat_axis = hidden_states.unsqueeze(axis=2)
    repeated = with_repeat_axis.expand([batch, num_key_value_heads, n_rep, slen, head_dim])
    return repeated.reshape([batch, num_key_value_heads * n_rep, slen, head_dim])
class Qwen2VLAttention(nn.Layer):
    """
    Eager (matmul + softmax) multi-headed attention used by the Qwen2-VL decoder layers.

    Queries use `config.num_attention_heads` heads, keys/values use `config.num_key_value_heads` heads
    that are expanded with `repeat_kv` before the scores are computed. Rotary position embeddings are
    applied through `apply_multimodal_rotary_pos_emb` with `config.rope_scaling["mrope_section"]`.
    """

    def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
        """
        Args:
            config: model configuration; the attention geometry, rope settings and dropout are read from it.
            layer_idx: index of the decoder layer owning this module. Only stored; a warning is logged if missing.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_scaling = config.rope_scaling

        self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
        self.k_proj = Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True)
        self.v_proj = Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True)
        self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)

        self.rotary_emb = Qwen2VLRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[paddle.Tensor] = None,
    ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
        """
        Compute self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape `[batch, q_len, hidden_size]`.
            attention_mask: optional additive mask, added to the scaled scores before the softmax.
            position_ids: position ids forwarded to the rotary embedding.
            past_key_value: optional `(key, value)` pair from earlier steps, each laid out as
                `[batch, num_key_value_heads, past_len, head_dim]`; new keys/values are appended on axis 2.
            output_attentions: when True the (post-dropout) attention probabilities are returned.
            use_cache: when True the concatenated `(key, value)` pair is returned for the next step.
            cache_position: accepted for interface compatibility; not used by this implementation.

        Returns:
            `(attn_output, attn_weights or None, (key, value) or None)`.

        Raises:
            ValueError: if the attention output does not have shape `[batch, num_heads, q_len, head_dim]`.
        """
        bsz, q_len, _ = hidden_states.shape

        try:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        except Exception:
            # Best-effort fallback kept from the original code: retry once with the input cast to the
            # configured dtype. Only `Exception` is caught so Ctrl-C / SystemExit still propagate.
            hidden_states = hidden_states.astype(self.config.dtype)
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # [bs, seq_len, heads * head_dim] -> [bs, seq_len, heads, head_dim] -> [bs, heads, seq_len, head_dim]
        # (a 0 in the target shape keeps the corresponding input dimension).
        target_query_shape = [0, 0, self.num_heads, self.head_dim]
        target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
        new_perm = [0, 2, 1, 3]
        query_states = query_states.reshape(shape=target_query_shape).transpose(new_perm)
        key_states = key_states.reshape(shape=target_key_value_shape).transpose(new_perm)
        value_states = value_states.reshape(shape=target_key_value_shape).transpose(new_perm)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        # Append the new keys/values to the cached ones along the sequence axis (axis 2 in this layout).
        if past_key_value is not None:
            key_states = paddle.concat([past_key_value[0], key_states], axis=2)
            value_states = paddle.concat([past_key_value[1], value_states], axis=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat k/v heads if n_kv_heads < n_heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scores and softmax are computed in float32.
        query_states = query_states.astype("float32")
        key_states = key_states.astype("float32")
        value_states = value_states.astype("float32")

        attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
        attn_weights = nn.functional.dropout(x=attn_weights, p=self.attention_dropout, training=self.training)

        # The weighted sum is done in the configured dtype.
        attn_output = paddle.matmul(attn_weights.cast(self.config.dtype), value_states.cast(self.config.dtype))

        if attn_output.shape != [bsz, self.num_heads, q_len, self.head_dim]:
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.shape}"
            )

        # [bs, heads, q_len, head_dim] -> [bs, q_len, heads * head_dim]
        attn_output = attn_output.transpose([0, 2, 1, 3])
        attn_output = attn_output.reshape([bsz, q_len, -1])
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class Qwen2VLFlashAttention2(Qwen2VLAttention):
"""
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
as the weights of the module stays untouched. The only required change would be on the forward pass
where it needs to correctly call the public API of flash attention and deal with padding tokens
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
config.max_window_layers layers.
"""
def __init__(self, *args, **kwargs):
    """Forward all arguments unchanged to the parent attention constructor; no extra state is created here."""
    super().__init__(*args, **kwargs)
def forward(
    self,
    hidden_states: paddle.Tensor,
    attention_mask: Optional[paddle.Tensor] = None,
    position_ids: Optional[paddle.Tensor] = None,
    past_key_value: Optional[Tuple[paddle.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
    """
    Compute self-attention over `hidden_states` through `self._flash_attention_forward`.

    Args:
        hidden_states: input of shape `[batch, q_len, hidden_size]`.
        attention_mask: optional mask, forwarded as-is to `_flash_attention_forward`.
        position_ids: position ids forwarded to the rotary embedding.
        past_key_value: optional `(key, value)` pair from earlier steps, each laid out as
            `[batch, num_key_value_heads, past_len, head_dim]`; new keys/values are appended on axis 2.
        output_attentions: accepted for interface compatibility. This path never materialises the
            attention probabilities, so the second return value is always `None`.
        use_cache: when True the concatenated `(key, value)` pair is returned for the next step.
        cache_position: accepted for interface compatibility; not used by this implementation.

    Returns:
        `(attn_output, None, (key, value) or None)`.
    """
    bsz, q_len, _ = tuple(hidden_states.shape)

    try:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
    except Exception:
        # Best-effort fallback kept from the original code: retry once with a bfloat16 input.
        # Only `Exception` is caught so Ctrl-C / SystemExit still propagate.
        # NOTE(review): Qwen2VLAttention.forward casts to `self.config.dtype` here instead of a
        # hard-coded "bfloat16" -- confirm whether float16 checkpoints must be supported.
        hidden_states = hidden_states.astype("bfloat16")
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    # [bs, seq_len, heads * head_dim] -> [bs, seq_len, heads, head_dim] -> [bs, heads, seq_len, head_dim]
    # (a 0 in the target shape keeps the corresponding input dimension).
    target_query_shape = [0, 0, self.num_heads, self.head_dim]
    target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
    new_perm = [0, 2, 1, 3]
    query_states = query_states.reshape(shape=target_query_shape).transpose(new_perm)
    key_states = key_states.reshape(shape=target_key_value_shape).transpose(new_perm)
    value_states = value_states.reshape(shape=target_key_value_shape).transpose(new_perm)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )

    # Append the new keys/values to the cached ones along the sequence axis (axis 2 in this layout).
    if past_key_value is not None:
        key_states = paddle.concat([past_key_value[0], key_states], axis=2)
        value_states = paddle.concat([past_key_value[1], value_states], axis=2)
    past_key_value = (key_states, value_states) if use_cache else None

    # Repeat k/v heads if n_kv_heads < n_heads.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Reshape to the layout expected by flash attention: [bs, seq_len, heads, head_dim].
    query_states = query_states.transpose(perm=[0, 2, 1, 3])
    key_states = key_states.transpose(perm=[0, 2, 1, 3])
    value_states = value_states.transpose(perm=[0, 2, 1, 3])

    # Attention dropout is not forwarded here, so the helper runs with its default dropout of 0.0.
    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
    )

    attn_output = attn_output.reshape([bsz, q_len, -1])
    attn_output = self.o_proj(attn_output)

    # The original only assigned `attn_weights` when `output_attentions` was False, so requesting the
    # weights raised UnboundLocalError at the return statement. No weights exist on this path.
    attn_weights = None
    return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
    self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
    """
    Call the flash-attention kernel that matches the current device and mask.

    On the non-NPU path with a mask, the inputs are first unpadded with `self._unpad_input`, run through
    the variable-length kernel, and the result is scattered back to `[batch_size, query_length, -1]`.
    On every other path the kernel output is returned unchanged.

    Args:
        query_states (`paddle.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`paddle.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`paddle.Tensor`):
            Input value states to be passed to Flash Attention API
        attention_mask (`paddle.Tensor`, *optional*):
            Mask handed to the NPU kernel or to `self._unpad_input`; `None` selects the dense kernels.
        query_length (`int`):
            Number of query positions per sample. Causal masking is disabled when it equals 1.
        dropout (`float`, *optional*):
            Attention dropout
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). Only forwarded to
            the variable-length kernel on the non-NPU path; the other kernel calls below are made without
            a scale argument.

    Returns:
        `paddle.Tensor`: the attention output.
    """
    # A single query position has nothing to mask causally.
    causal = self.is_causal and query_length != 1

    if _IS_NPU:
        if attention_mask is not None:
            attn_output = paddle.nn.functional.flash_attention_npu(  # TODO: flash_attn_unpadded
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout=dropout,
                causal=causal,
                is_varlen=True,
            )
        else:
            # Run the kernel in bfloat16 and restore the caller's dtype afterwards.
            dtype = query_states.dtype
            attn_output = paddle.nn.functional.flash_attention_npu(  # TODO: flash_attn_unpadded
                query_states.astype("bfloat16"),
                key_states.astype("bfloat16"),
                value_states.astype("bfloat16"),
                attn_mask=attention_mask,
                dropout=dropout,
                causal=causal,
            )
            attn_output = attn_output.astype(dtype)
    else:
        if softmax_scale is None:
            # The varlen kernel needs the scale spelled out explicitly.
            # The original overwrote any caller-supplied value here; it is now only a default.
            head_dim = query_states.shape[-1]
            softmax_scale = head_dim**-0.5

        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(  # TODO: flash_attn_unpadded
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                scale=softmax_scale,  # keyword is `scale`, not `softmax_scale`
                dropout=dropout,
                causal=causal,
            )[0]

            # Scatter the unpadded rows back into a zero-padded buffer (equivalent of `pad_input`).
            attn_output = IndexPutFirstAxis.apply(attn_output_unpad, indices_q, batch_size * query_length)
            attn_output = attn_output.reshape([batch_size, query_length, -1])
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                causal=causal,  # this call takes no softmax scale argument
            )[0]

    # NOTE: an earlier draft reshaped the output to [batch_size, query_length, num_heads * head_dim] here to
    # account for parallel strategies; the caller performs the final reshape instead.
    return attn_output
def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
# Note: This function was named _upad_input() in torch transformers/modeling_flash_attention_utils.py
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
# TODO:cuda error
key_layer = IndexFirstAxis.apply(
key_layer.reshape([batch_size * kv_seq_len, num_key_value_heads, head_dim]), indices_k
)
value_layer = IndexFirstAxis.apply(
value_layer.reshape([batch_size * kv_seq_len, num_key_value_heads, head_dim]), indices_k
)
if query_length == kv_seq_len:
query_layer = IndexFirstAxis.apply(
query_layer.reshape([batch_size * kv_seq_len, self.num_heads, head_dim]), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = paddle.arange(