[Tests] Enable more general testing for torch.compile() with LoRA hotswapping #11322

Open · wants to merge 13 commits into base: main
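For context, a minimal sketch of the hotswap-under-compile flow that the tests in this PR exercise, using the enable_lora_hotswap / load_lora_adapter model APIs visible in the diff below (the checkpoint id and adapter paths are placeholders, not part of this PR):

import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained("placeholder/tiny-unet")  # placeholder checkpoint id
# Must be called before the first adapter is loaded; target_rank should cover the largest LoRA rank to be swapped in.
model.enable_lora_hotswap(target_rank=16)
model.load_lora_adapter("adapter0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)
model = torch.compile(model, mode="reduce-overhead")

# ... run inference with adapter0 ...

# Swap the second adapter's weights into the existing LoRA modules in place;
# with hotswap=True this should not trigger a recompilation.
model.load_lora_adapter("adapter1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)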
150 changes: 80 additions & 70 deletions tests/models/test_modeling_common.py
@@ -59,7 +59,6 @@
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
floats_tensor,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
@@ -1720,7 +1719,7 @@ def test_push_to_hub_library_name(self):
@require_peft_backend
@require_peft_version_greater("0.14.0")
@is_torch_compile
class TestLoraHotSwappingForModel(unittest.TestCase):
class LoraHotSwappingForModelTesterMixin:
"""Test that hotswapping does not result in recompilation on the model directly.

We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively tested there.
@@ -1741,48 +1740,24 @@ def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)

def get_small_unet(self):
# from diffusers UNet2DConditionModelTests
torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 8,
"attention_head_dim": 2,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
"sample_size": 16,
}
model = UNet2DConditionModel(**init_dict)
return model.to(torch_device)

def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
def get_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
from peft import LoraConfig

unet_lora_config = LoraConfig(
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config

def get_dummy_input(self):
# from UNet2DConditionModelTests
batch_size = 4
num_channels = 4
sizes = (16, 16)

noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return lora_config

return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
def get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]

def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
@@ -1800,23 +1775,27 @@ def check_model_hotswap(do_compile, rank0, rank1, target_modules0, target_
fine.
"""
# create 2 adapters with different ranks and alphas
dummy_input = self.get_dummy_input()
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)

unet = self.get_small_unet()
unet.add_adapter(lora_config0, adapter_name="adapter0")
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
output0_before = unet(**dummy_input)["sample"]
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]

unet.add_adapter(lora_config1, adapter_name="adapter1")
unet.set_adapter("adapter1")
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
output1_before = unet(**dummy_input)["sample"]
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]

# sanity checks:
tol = 5e-3
@@ -1826,40 +1805,43 @@ def check_model_hotswap(do_compile, rank0, rank1, target_modules0, target_

with tempfile.TemporaryDirectory() as tmp_dirname:
# save the adapter checkpoints
unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del unet
model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del model

# load the first adapter
unet = self.get_small_unet()
torch.manual_seed(0)
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

if do_compile or (rank0 != rank1):
# preparation is only needed when the model is compiled or when the two ranks differ
unet.enable_lora_hotswap(target_rank=max_rank)
model.enable_lora_hotswap(target_rank=max_rank)

file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

if do_compile:
unet = torch.compile(unet, mode="reduce-overhead")
model = torch.compile(model, mode="reduce-overhead")

with torch.inference_mode():
output0_after = unet(**dummy_input)["sample"]
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

# hotswap the 2nd adapter
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
output1_after = unet(**dummy_input)["sample"]
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with self.assertRaisesRegex(ValueError, msg):
unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_model(self, rank0, rank1):
@@ -1877,58 +1859,86 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
if "unet" not in self.model_class.__name__.lower():
return

target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
if "unet" not in self.model_class.__name__.lower():
return

target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from a
# non-attention block.
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q"]
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)

target_modules.append(self.get_linear_module_name_other_than_attn(model))
del model

with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# check that an error is raised if enable_lora_hotswap is called after an adapter was already added
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)

msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with self.assertRaisesRegex(RuntimeError, msg):
unet.enable_lora_hotswap(target_rank=32)
model.enable_lora_hotswap(target_rank=32)

def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
# check that only a warning is issued if enable_lora_hotswap is called after an adapter was already added and check_compiled="warn"
from diffusers.loaders.peft import logger

lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with self.assertLogs(logger=logger, level="WARNING") as cm:
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in log for log in cm.output)

def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check that the error/warning can be suppressed
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")

def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that an invalid `check_compiled` value raises an error
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with self.assertRaisesRegex(ValueError, msg):
unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

def test_hotswap_second_adapter_targets_more_layers_raises(self):
# check the error and log
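Aside: the recompilation checks in the tests above rely on TorchDynamo's error_on_recompile flag, which the mixin enables via torch._dynamo.config.patch so that any recompilation fails the test instead of happening silently. A standalone sketch of that pattern, with a toy function and made-up shapes:

import torch

def fn(x):
    return torch.nn.functional.relu(x) * 2

compiled = torch.compile(fn)
compiled(torch.randn(4, 8))  # first call triggers compilation

with torch._dynamo.config.patch(error_on_recompile=True):
    compiled(torch.randn(4, 8))    # same shape: reuses the compiled graph, no recompilation
    # compiled(torch.randn(2, 8))  # a changed shape would normally recompile; under this flag it raises instead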
4 changes: 2 additions & 2 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -22,7 +22,7 @@
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_modeling_common import ModelTesterMixin
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin


enable_full_determinism()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
return ip_state_dict


class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
class FluxTransformerTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
6 changes: 4 additions & 2 deletions tests/models/unets/test_models_unet_2d_condition.py
@@ -53,7 +53,7 @@
torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin


if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs


class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class UNet2DConditionModelTests(
ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
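Other model test modules can opt into the same checks by mixing LoraHotSwappingForModelTesterMixin into their tester class. A hedged sketch of what that looks like (the class name is illustrative; the tiny UNet config and dummy inputs are taken from the helpers this PR removes from the mixin, and the mixin only relies on model_class plus prepare_init_args_and_inputs_for_common()):

import unittest

import torch

from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_modeling_common import LoraHotSwappingForModelTesterMixin


class TinyUNetHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):  # illustrative name
    model_class = UNet2DConditionModel
    main_input_name = "sample"

    def prepare_init_args_and_inputs_for_common(self):
        # Tiny config and dummy inputs mirroring the former get_small_unet()/get_dummy_input() helpers.
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 8,
            "attention_head_dim": 2,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 1,
            "sample_size": 16,
        }
        inputs_dict = {
            "sample": floats_tensor((4, 4, 16, 16)).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
            "encoder_hidden_states": floats_tensor((4, 4, 8)).to(torch_device),
        }
        return init_dict, inputs_dict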