diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 88343a128bb1..9471449561b9 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -142,6 +142,7 @@ jobs:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
+          RUN_COMPILE: yes
         run: |
           python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index f82a2407f333..83fad5ee4c71 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -59,7 +59,6 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
-    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1720,7 +1719,7 @@ def test_push_to_hub_library_name(self):
 @require_peft_backend
 @require_peft_version_greater("0.14.0")
 @is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
     """Test that hotswapping does not result in recompilation on the model directly.
 
     We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
@@ -1741,48 +1740,24 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
 
-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
         # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig
 
-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
             target_modules=target_modules,
             init_lora_weights=False,
             use_dora=False,
         )
-        return unet_lora_config
-
-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+        return lora_config
 
-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]
 
     def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
@@ -1800,23 +1775,27 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         fine.
""" # create 2 adapters with different ranks and alphas - dummy_input = self.get_dummy_input() + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) if target_modules1 is None: target_modules1 = target_modules0[:] - lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0) - lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1) + lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1) - unet = self.get_small_unet() - unet.add_adapter(lora_config0, adapter_name="adapter0") + model.add_adapter(lora_config0, adapter_name="adapter0") with torch.inference_mode(): - output0_before = unet(**dummy_input)["sample"] + torch.manual_seed(0) + output0_before = model(**inputs_dict)["sample"] - unet.add_adapter(lora_config1, adapter_name="adapter1") - unet.set_adapter("adapter1") + model.add_adapter(lora_config1, adapter_name="adapter1") + model.set_adapter("adapter1") with torch.inference_mode(): - output1_before = unet(**dummy_input)["sample"] + torch.manual_seed(0) + output1_before = model(**inputs_dict)["sample"] # sanity checks: tol = 5e-3 @@ -1826,40 +1805,43 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ with tempfile.TemporaryDirectory() as tmp_dirname: # save the adapter checkpoints - unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") - unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") - del unet + model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") + model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") + del model # load the first adapter - unet = self.get_small_unet() + torch.manual_seed(0) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + if do_compile or (rank0 != rank1): # no need to prepare if the model is not compiled or if the ranks are identical - unet.enable_lora_hotswap(target_rank=max_rank) + model.enable_lora_hotswap(target_rank=max_rank) file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") - unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) + model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) if do_compile: - unet = torch.compile(unet, mode="reduce-overhead") + model = torch.compile(model, mode="reduce-overhead") with torch.inference_mode(): - output0_after = unet(**dummy_input)["sample"] + output0_after = model(**inputs_dict)["sample"] assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) # hotswap the 2nd adapter - unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) + model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) # we need to call forward to potentially trigger recompilation with torch.inference_mode(): - output1_after = unet(**dummy_input)["sample"] + output1_after = model(**inputs_dict)["sample"] assert torch.allclose(output1_before, output1_after, atol=tol, 
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
@@ -1876,6 +1858,9 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
@@ -1883,52 +1868,77 @@ def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
+        # block.
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        # It's important to add this context to raise an error on recompilation
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
+
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)
 
     def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger
 
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
         with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
             assert any(msg in log for log in cm.output)
 
     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
 
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
             self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
 
     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
 
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
         with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
 
     def test_hotswap_second_adapter_targets_more_layers_raises(self):
         # check the error and log
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index c88b3dac8216..4238ca844d1c 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -22,7 +22,7 @@
 from diffusers.models.embeddings import ImageProjection
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin
 
 
 enable_full_determinism()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
     return ip_state_dict
 
 
-class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+class FluxTransformerTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
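
For context on what the new mixin expects from its hosts: `LoraHotSwappingForModelTesterMixin` only uses hooks that `ModelTesterMixin` subclasses already provide, namely a `model_class` attribute and `prepare_init_args_and_inputs_for_common()` returning an `(init_dict, inputs_dict)` pair whose forward output can be indexed with `"sample"`. A minimal sketch of a host test class is below; the class name is hypothetical, other `ModelTesterMixin` hooks are omitted for brevity, and the config/input values simply mirror what the removed `get_small_unet`/`get_dummy_input` helpers hard-coded.

```python
# Hypothetical host class for illustration only; FluxTransformerTests (above) and
# UNet2DConditionModelTests (below) are the actual adopters in this diff.
import unittest

import torch

from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import torch_device

from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin


class SmallUNetHotSwapTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel
    main_input_name = "sample"

    def prepare_init_args_and_inputs_for_common(self):
        # Small UNet config, taken from the old get_small_unet helper.
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 8,
            "attention_head_dim": 2,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 1,
            "sample_size": 16,
        }
        # Dummy forward inputs with the same shapes the old get_dummy_input used.
        inputs_dict = {
            "sample": torch.randn(4, 4, 16, 16).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
            "encoder_hidden_states": torch.randn(4, 4, 8).to(torch_device),
        }
        return init_dict, inputs_dict
```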
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index d01a0b493520..94a5d641a717 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -53,7 +53,7 @@
     torch_device,
 )
 
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
 
 
 if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs
 
 
-class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(
+    ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
+):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
     # We override the items here because the unet under consideration is small.
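
A note on how these tests get picked up: the mixin is decorated with `@is_torch_compile` (imported from `diffusers.utils.testing_utils` in the hunk above), and the workflow hunk at the top exports `RUN_COMPILE: yes` so the nightly CUDA job actually runs the compile-heavy hotswap tests. Assuming that marker follows the usual `RUN_*` environment-flag pattern in diffusers' test utilities, it behaves roughly like the sketch below (illustrative only, not the actual helper), and a local run would look something like `RUN_COMPILE=yes python -m pytest tests/models/unets/test_models_unet_2d_condition.py -k hotswap -v`.

```python
# Illustrative sketch of an environment-variable gate like RUN_COMPILE; the real
# is_torch_compile marker lives in diffusers.utils.testing_utils and may differ in detail.
import os
import unittest


def _flag_enabled(name: str, default: bool = False) -> bool:
    # Treat "yes"/"true"/"1"/"on" (any casing) as enabled, mirroring common RUN_* flags.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}


_run_compile_tests = _flag_enabled("RUN_COMPILE")


def is_torch_compile(test_case):
    """Skip the decorated test (or test class) unless RUN_COMPILE is set."""
    return unittest.skipUnless(_run_compile_tests, "test requires RUN_COMPILE=yes")(test_case)
```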