
Commit 580e7ae

Merge branch 'main' into enable-hotswap-testing-ci

2 parents: d2e6c9c + 0dec414

11 files changed: +304 −44 lines

Diff for: docs/source/en/api/pipelines/wan.md (+54)

````diff
@@ -133,6 +133,60 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```

+### First and Last Frame Interpolation
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+
+
+model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanImageToVideoPipeline.from_pretrained(
+    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
+
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+    aspect_ratio = image.height / image.width
+    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+    image = image.resize((width, height))
+    return image, height, width
+
+def center_crop_resize(image, height, width):
+    # Calculate resize ratio to match first frame dimensions
+    resize_ratio = max(width / image.width, height / image.height)
+
+    # Resize the image
+    width = round(image.width * resize_ratio)
+    height = round(image.height * resize_ratio)
+    size = [width, height]
+    image = TF.center_crop(image, size)
+
+    return image, height, width
+
+first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
+if last_frame.size != first_frame.size:
+    last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipe(
+    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+```
+
 ### Video to Video Generation

 ```python
````
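The `aspect_ratio_resize` helper in the example above snaps both output dimensions down to a multiple of `mod_value` so the frame divides evenly through the VAE's spatial downsampling and the transformer's patchify step. A minimal sketch of that arithmetic, assuming the typical Wan values `vae_scale_factor_spatial = 8` and `patch_size[1] = 2` (so `mod_value = 16`); the input size here is hypothetical:

```python
import numpy as np

# Hypothetical input: a 1080x1920 portrait image, pixel budget of 720*1280.
max_area = 720 * 1280
aspect_ratio = 1920 / 1080          # height / width
mod_value = 8 * 2                   # vae_scale_factor_spatial * patch_size[1] (assumed)

# Pick the largest height/width with roughly this area and aspect ratio,
# then round each down to the nearest multiple of mod_value.
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value

print(height, width)  # 1280 720 -- both divisible by 16, area stays within budget
```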

Diff for: examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py (+18, −8)

```diff
@@ -1915,17 +1915,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     free_memory()

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1949,7 +1954,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             lr_scheduler,
         )
     else:
-        print("I SHOULD BE HERE")
         transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
         )
@@ -1961,8 +1965,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
```
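The reworked bookkeeping exists because the learning-rate scheduler must be sized in process-scaled steps, while the dataloader seen before `accelerator.prepare` is still unsharded. A small sketch of the arithmetic with hypothetical numbers (none of these values come from the diff):

```python
import math

# Hypothetical run: 1000 batches, 4 GPUs, gradient accumulation of 2, 10 epochs.
len_train_dataloader = 1000
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 10

# Before accelerator.prepare the dataloader is unsharded, so first estimate
# the per-process length, then derive optimizer updates per epoch.
len_after_sharding = math.ceil(len_train_dataloader / num_processes)             # 250
updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)  # 125

# The scheduler is stepped once per process per update, hence the extra factor.
num_training_steps_for_scheduler = num_train_epochs * num_processes * updates_per_epoch
print(num_training_steps_for_scheduler)  # 5000
```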

Diff for: examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py (+1, −2)

```diff
@@ -33,7 +33,6 @@
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
-    StableDiffusionLoraLoaderMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
@@ -300,7 +299,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):


 class StableDiffusionXLControlNetAdapterInpaintPipeline(
-    DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
+    DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
```
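The swap matters because the SDXL LoRA loader mixin knows about SDXL's second text encoder, which the SD-1.x mixin does not. A minimal sketch of what the corrected base class enables, assuming an already constructed `pipe` of this class and a placeholder LoRA repository id:

```python
# `pipe` is assumed to be an instance of
# StableDiffusionXLControlNetAdapterInpaintPipeline constructed elsewhere.
# StableDiffusionXLLoraLoaderMixin contributes load_lora_weights(), which maps
# LoRA layers onto the UNet and both SDXL text encoders.
pipe.load_lora_weights("some-user/some-sdxl-lora")  # placeholder repo id
pipe.fuse_lora(lora_scale=0.8)  # optional: bake the LoRA into the base weights
```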

Diff for: examples/dreambooth/train_dreambooth_flux.py (+18, −7)

```diff
@@ -1407,17 +1407,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1444,8 +1449,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
```
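The same rework lands in this script. The reason it estimates `len_train_dataloader_after_sharding` before `accelerator.prepare` is that at that point `len(train_dataloader)` still counts every batch, while each of the N processes will only see roughly 1/N of them after sharding. A toy illustration (hypothetical counts):

```python
import math

total_batches = 1000  # len(train_dataloader) before accelerator.prepare
for num_processes in (1, 2, 4, 8):
    # Estimated per-process dataloader length after sharding.
    per_process = math.ceil(total_batches / num_processes)
    print(num_processes, per_process)
# 1 1000
# 2 500
# 4 250
# 8 125
```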

Diff for: examples/dreambooth/train_dreambooth_lora_flux.py (+18, −7)

```diff
@@ -1524,17 +1524,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     free_memory()

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1561,8 +1566,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
```
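One more angle on the same change: the warmup count is scaled exactly like the total, because accelerate's prepared scheduler steps the underlying scheduler once per process per optimizer update. Below is a plain-PyTorch emulation of that behavior under a linear warmup; the per-process stepping loop is an assumption about accelerate's default `step_with_optimizer` behavior, not code from the script:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)

num_processes = 4
warmup_updates = 10                            # args.lr_warmup_steps (hypothetical)
warmup_steps = warmup_updates * num_processes  # what get_scheduler would receive

sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda step: min(1.0, step / warmup_steps)  # linear warmup
)

# Emulate DDP: each optimizer update triggers num_processes scheduler steps,
# so after `warmup_updates` updates the warmup is exactly complete.
for _update in range(warmup_updates):
    opt.step()
    for _ in range(num_processes):
        sched.step()
print(opt.param_groups[0]["lr"])  # 1.0 -> warmup finished on schedule
```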

Diff for: examples/dreambooth/train_dreambooth_lora_sdxl.py (+19, −7)

```diff
@@ -1523,17 +1523,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1550,7 +1555,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
```

Diff for: scripts/convert_wan_to_diffusers.py (+43, −2)

```diff
@@ -39,6 +39,24 @@
     "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
     "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
     "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # for the FLF2V model
+    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
 }

 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
             "text_dim": 4096,
         },
     }
+    elif model_type == "Wan-FLF2V-14B-720P":
+        config = {
+            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
+            "diffusers_config": {
+                "image_dim": 1280,
+                "added_kv_proj_dim": 5120,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 36,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "rope_max_seq_len": 1024,
+                "pos_embed_seq_len": 257 * 2,
+            },
+        }
     return config


@@ -393,11 +433,12 @@ def get_args():
     vae = convert_vae()
     text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+    flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
     scheduler = UniPCMultistepScheduler(
-        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )

-    if "I2V" in args.model_type:
+    if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
         )
```
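The rename dict maps substrings of the original checkpoint's keys onto their diffusers counterparts; a conversion script typically applies it as a substring-replacement pass over the state dict. The loop itself is outside this diff, so the sketch below illustrates the pattern rather than the script's exact code:

```python
from typing import Dict

import torch

def rename_state_dict_keys(
    state_dict: Dict[str, torch.Tensor], keys_rename_dict: Dict[str, str]
) -> Dict[str, torch.Tensor]:
    # Replace every mapped substring in each key, e.g.
    # "blocks.0.self_attn.q.weight" -> "blocks.0.attn1.to_q.weight".
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in keys_rename_dict.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed

# Toy usage with one of the mappings added in this diff:
sd = {"blocks.0.self_attn.q.weight": torch.zeros(1)}
print(rename_state_dict_keys(sd, {"self_attn.q": "attn1.to_q"}))
# {'blocks.0.attn1.to_q.weight': tensor([0.])}
```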

Diff for: src/diffusers/models/transformers/transformer_hidream_image.py (+2, −2)

```diff
@@ -918,5 +918,5 @@ def forward(
         unscale_lora_layers(self, lora_scale)

         if not return_dict:
-            return (output, hidden_states_masks)
-        return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
```
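This fix is needed because `Transformer2DModelOutput` declares only a `sample` field, so the old `mask=` keyword would raise a `TypeError` as soon as `return_dict=True` was used; the tuple branch is trimmed to match. A quick check, assuming a recent diffusers where the class lives in `diffusers.models.modeling_outputs`:

```python
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput

output = torch.zeros(1, 4, 8, 8)

ok = Transformer2DModelOutput(sample=output)  # the only declared field
print(ok.sample.shape)                        # torch.Size([1, 4, 8, 8])

try:
    Transformer2DModelOutput(sample=output, mask=None)  # what the old code did
except TypeError as e:
    print("TypeError:", e)  # unexpected keyword argument 'mask'
```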
