diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4e62f3ef6182..af2f10fc4da5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -568,6 +568,8 @@ title: UniDiffuser - local: api/pipelines/value_guided_sampling title: Value-guided sampling + - local: api/pipelines/visualcloze + title: VisualCloze - local: api/pipelines/wan title: Wan - local: api/pipelines/wuerstchen diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 6a8e82a692e0..95b50ce60826 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -89,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation | | [Value-guided planning](value_guided_sampling) | value guided sampling | | [Wuerstchen](wuerstchen) | text2image | +| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting | ## DiffusionPipeline diff --git a/docs/source/en/api/pipelines/visualcloze.md b/docs/source/en/api/pipelines/visualcloze.md new file mode 100644 index 000000000000..a7d7e1ba292e --- /dev/null +++ b/docs/source/en/api/pipelines/visualcloze.md @@ -0,0 +1,200 @@ + + +# VisualCloze + +[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://arxiv.org/abs/2504.07960) is an in-context learning based universal image generation framework that can 1) support various in-domain tasks, 2) generalize to unseen tasks through in-context learning, 3) unify multiple tasks into one step and generate both target image and intermediate results, and 4) support reverse-engineering a set of conditions from a 
target image. + +The abstract from the paper is: + +*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.* + +## Inference + +### Model loading + +VisualCloze releases two models for diffusers, `VisualClozePipeline-384` and `VisualClozePipeline-512`, trained at resolutions of 384 and 512, respectively. +The resolution means that each image is resized, keeping its aspect ratio, so that its area is approximately `resolution` × `resolution` pixels before the images are concatenated into a grid layout. 
+Because images are generated at this resolution, VisualCloze uses [SDEdit](https://arxiv.org/abs/2108.01073) to upsample the generated images. +```python +import torch +from diffusers import VisualClozePipeline + +pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +### Input prompts +VisualCloze supports a wide variety of tasks. Pass a task prompt that describes the intent of the generation task and, optionally, a content prompt that captions the image to be generated. When no content prompt is needed, pass `None` instead. + +### Input images + +The input images should be a `List[List[Image|None]]`. Every row except the last is an in-context example. The last row is the current query, in which the image to be generated is set to `None`. +For batch inference, the input images should be a `List[List[List[Image|None]]]`, and the input prompts should be a `List[str|None]`. + +### Resolution + +By default, the model first generates an image with a resolution of ${model.resolution}^2$, and then upsamples it by a factor of three. Set the `upsampling_height` and `upsampling_width` parameters to generate images of a different size. + + +### Examples + + +More examples covering a wide range of tasks can be found in the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repo](https://github.com/lzyhha/VisualCloze). +The examples in this document cover mask2image, edge-detection, and subject-driven generation. 
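As a side note on the input layout and resolution setting described above, here is a minimal pure-Python sketch that needs neither a GPU nor diffusers. The helpers `area_preserving_size` and `target_position` are hypothetical names used for illustration only. They mirror the arithmetic of `VisualClozeProcessor.preprocess_image` added in this PR, and the `divisible=16` default is an assumption (the processor derives it as `2 * vae_scale_factor`). Placeholder strings stand in for PIL images.

```python
from typing import List, Optional, Tuple


def area_preserving_size(width: int, height: int, resolution: int = 384, divisible: int = 16) -> Tuple[int, int]:
    """Return (new_width, new_height) with an area close to resolution**2 and the same aspect ratio."""
    aspect_ratio = width / height
    target_area = resolution * resolution
    new_h = int((target_area / aspect_ratio) ** 0.5)
    new_w = int(new_h * aspect_ratio)
    # Round each side down to a multiple of `divisible`, but never below one multiple.
    new_w = max(new_w // divisible, 1) * divisible
    new_h = max(new_h // divisible, 1) * divisible
    return new_w, new_h


def target_position(rows: List[List[Optional[object]]]) -> List[int]:
    """Mark target slots in the query (last) row: 1 where the image is None, else 0."""
    return [1 if img is None else 0 for img in rows[-1]]


# One sample: every row but the last is an in-context example, the last row is the query.
single_sample = [
    ["example_mask", "example_photo"],  # in-context example
    ["query_mask", None],               # query: None marks the image to generate
]
# Batch inference adds one more list level, with one prompt (or None) per sample.
batch = [single_sample, single_sample]
content_prompts = ["a caption for the first sample", None]

print((len(single_sample), len(single_sample[0])))  # grid layout: (2, 2)
print(target_position(single_sample))               # [0, 1]
print(area_preserving_size(1000, 500))              # (528, 256)
print(area_preserving_size(512, 512, 512))          # (512, 512)
```

The sketch only illustrates shapes and sizes. The pipeline itself performs the resizing, cropping, and masking when the nested list is passed as `image`.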
+ +#### Example for mask2image + +```python +from diffusers.utils import load_image + +# Load in-context images (make sure the paths are correct and accessible) +image_paths = [ + # in-context examples + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/2c4e256fa512cb7e7f433f4c7f9101de_sam2_mask.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/2c4e256fa512cb7e7f433f4c7f9101de.jpg'), + ], + # query with the target image + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/9c565b1aad76b22f5bb836744a93561a_sam2_mask.jpg'), + None, # No image needed for the target image + ], +] + +# Task and content prompt +task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." +content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. +The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. +Its plumage is a mix of dark brown and golden hues, with intricate feather details. +The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. +The foreground includes rugged rocks and patches of green moss. 
Photorealistic, medium depth of field, +soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, +tranquil, majestic, wildlife photography.""" + +# Run the pipeline +image_result = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + upsampling_width=1344, + upsampling_height=768, + upsampling_strength=0.4, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0][0] + +# Save the resulting image +image_result.save("visualcloze.png") +``` + +#### Example for edge-detection + +```python +# Load in-context images (make sure the paths are correct and accessible) +image_paths = [ + # in-context examples + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/de5a8b250bf407aa7e04913562dcba90.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/de5a8b250bf407aa7e04913562dcba90_hed_512.jpg'), + ], + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/5bf755ed9dbb9b3e223e7ba35232b06e.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/5bf755ed9dbb9b3e223e7ba35232b06e_hed_512.jpg'), + ], + # query with the target image + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/53b3f413257bee9e499b823b44623b1a.jpg'), + None, # No image needed for the target image + ], +] + +# Task and content prompt +task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task." 
+content_prompt = "" + +# Run the pipeline +image_result = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + upsampling_width=864, + upsampling_height=1152, + upsampling_strength=0.4, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0][0] + +# Save the resulting image +image_result.save("visualcloze.png") +``` + +#### Example for subject-driven generation + +```python +# Load in-context images (make sure the paths are correct and accessible) +image_paths = [ + # in-context examples + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00004-of-00022-7170_reference.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00004-of-00022-7170_depth-anything-v2_Large.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00004-of-00022-7170_target.jpg'), + ], + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00013-of-00022-4696_reference.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00013-of-00022-4696_depth-anything-v2_Large.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00013-of-00022-4696_target.jpg'), + ], + # query with the target image + [ + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00005-of-00022-4396_reference.jpg'), + load_image('https://huggingface.co/VisualCloze/VisualCloze/resolve/main/examples/data-00005-of-00022-4396_depth-anything-v2_Large.jpg'), + None, # No image needed for the target image + ], +] + +# Task and content prompt +task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object, +[IMAGE2] depth map revealing gray-toned spatial layers and results in +[IMAGE3] an image with artistic 
qualitya high-quality image with exceptional detail.""" +content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring, +this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene.""" + +# Run the pipeline +image_result = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + upsampling_width=1024, + upsampling_height=1024, + upsampling_strength=0.2, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0][0] + +# Save the resulting image +image_result.save("visualcloze.png") +``` + +## VisualClozePipeline + +[[autodoc]] VisualClozePipeline + - all + - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f51a4ef2b3f6..913facc524b8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -513,6 +513,7 @@ "VersatileDiffusionPipeline", "VersatileDiffusionTextToImagePipeline", "VideoToVideoSDPipeline", + "VisualClozePipeline", "VQDiffusionPipeline", "WanImageToVideoPipeline", "WanPipeline", @@ -1086,6 +1087,7 @@ VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, + VisualClozePipeline, VQDiffusionPipeline, WanImageToVideoPipeline, WanPipeline, diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index c7847d160ffe..f27738dfe2a8 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -14,7 +14,7 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -1317,3 +1317,234 @@ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: in samples = samples[:, :, start_y:end_y, start_x:end_x] return samples + + +class VisualClozeProcessor(VaeImageProcessor): + """ + Image 
processor for the VisualCloze pipeline. + + This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and + mask generation. + + Args: + resolution (int, optional): + Target resolution for processing images. Each image will be resized to this resolution before being + concatenated to avoid the out-of-memory error. Defaults to 384. + *args: Additional arguments passed to [~image_processor.VaeImageProcessor] + **kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor] + """ + + def __init__(self, *args, resolution: int = 384, **kwargs): + super().__init__(*args, **kwargs) + self.resolution = resolution + + def preprocess_image( + self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int + ) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]: + """ + Preprocesses input images for the VisualCloze pipeline. + + This function handles the preprocessing of input images by: + 1. Resizing and cropping images to maintain consistent dimensions + 2. Converting images to the Tensor format for the VAE + 3. Normalizing pixel values + 4. 
Tracking image sizes and positions of target images + + Args: + input_images (List[List[Optional[Image.Image]]]): + A nested list of PIL Images where: + - Outer list represents different samples, including in-context examples and the query + - Inner list contains images for the task + - In the last row, condition images are provided and the target images are placed as None + vae_scale_factor (int): + The scale factor used by the VAE for resizing images + + Returns: + Tuple containing: + - List[List[torch.Tensor]]: Preprocessed images in tensor format + - List[List[List[int]]]: Dimensions of each processed image [height, width] + - List[int]: Target positions indicating which images are to be generated + """ + n_samples, n_task_images = len(input_images), len(input_images[0]) + divisible = 2 * vae_scale_factor + + processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)] + resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)] + target_position: List[int] = [] + + # Process each sample + for i in range(n_samples): + # Determine size from first non-None image + for j in range(n_task_images): + if input_images[i][j] is not None: + aspect_ratio = input_images[i][j].width / input_images[i][j].height + target_area = self.resolution * self.resolution + new_h = int((target_area / aspect_ratio) ** 0.5) + new_w = int(new_h * aspect_ratio) + + new_w = max(new_w // divisible, 1) * divisible + new_h = max(new_h // divisible, 1) * divisible + resize_size[i] = (new_w, new_h) + break + + # Process all images in the sample + for j in range(n_task_images): + if input_images[i][j] is not None: + target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1]) + processed_images[i].append(target) + if i == n_samples - 1: + target_position.append(0) + else: + blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0)) + processed_images[i].append(blank) + if i == n_samples - 1: + 
target_position.append(1) + + # Ensure a consistent width across images when there are multiple target images + if len(target_position) > 1 and sum(target_position) > 1: + new_w = resize_size[n_samples - 1][0] or 384 + for i in range(len(processed_images)): + for j in range(len(processed_images[i])): + if processed_images[i][j] is not None: + new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width)) + new_w = int(new_w / 16) * 16 + new_h = int(new_h / 16) * 16 + processed_images[i][j] = self.resize(processed_images[i][j], new_h, new_w) + + # Convert to tensors and normalize + image_sizes = [] + for i in range(len(processed_images)): + image_sizes.append([[img.height, img.width] for img in processed_images[i]]) + for j, image in enumerate(processed_images[i]): + image = self.pil_to_numpy(image) + image = self.numpy_to_pt(image) + image = self.normalize(image) + processed_images[i][j] = image + + return processed_images, image_sizes, target_position + + def preprocess_mask( + self, input_images: List[List[Image.Image]], target_position: List[int] + ) -> List[List[torch.Tensor]]: + """ + Generate masks for the VisualCloze pipeline. 
+ + Args: + input_images (List[List[Image.Image]]): + Processed images from preprocess_image + target_position (List[int]): + Binary list marking the positions of target images (1 for target, 0 for condition) + + Returns: + List[List[torch.Tensor]]: + A nested list of mask tensors (1 for target positions, 0 for condition images) + """ + mask = [] + for i, row in enumerate(input_images): + if i == len(input_images) - 1: # Query row + row_masks = [ + torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position + ] + else: # In-context examples + row_masks = [ + torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position + ] + mask.append(row_masks) + return mask + + def preprocess_image_upsampling( + self, + input_images: List[List[Image.Image]], + height: int, + width: int, + ) -> Tuple[List[List[Image.Image]], List[List[List[int]]]]: + """Process images for the upsampling stage in the VisualCloze pipeline. + + Args: + input_images: Input image to process + height: Target height + width: Target width + + Returns: + Tuple of processed image and its size + """ + image = self.resize(input_images[0][0], height, width) + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + image = self.normalize(image) + + input_images[0][0] = image + image_sizes = [[[height, width]]] + return input_images, image_sizes + + def preprocess_mask_upsampling(self, input_images: List[List[Image.Image]]) -> List[List[torch.Tensor]]: + return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]] + + def get_layout_prompt(self, size: Tuple[int, int]) -> str: + layout_instruction = ( + f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side." + ) + return layout_instruction + + def preprocess( + self, + task_prompt: Union[str, List[str]], + content_prompt: Union[str, List[str]], + input_images: 
Optional[List[List[List[Optional[str]]]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + upsampling: bool = False, + vae_scale_factor: int = 16, + ) -> Dict: + """Process visual cloze inputs. + + Args: + task_prompt: Task description(s) + content_prompt: Content description(s) + input_images: List of images or None for the target images + height: Optional target height for upsampling stage + width: Optional target width for upsampling stage + upsampling: Whether this is in the upsampling processing stage + + Returns: + Dictionary containing processed images, masks, prompts and metadata + """ + if isinstance(task_prompt, str): + task_prompt = [task_prompt] + content_prompt = [content_prompt] + input_images = [input_images] + + output = { + "init_image": [], + "mask": [], + "task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))], + "content_prompt": content_prompt, + "layout_prompt": [], + "target_position": [], + "image_size": [], + } + for i in range(len(task_prompt)): + if upsampling: + layout_prompt = None + else: + layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0]))) + + if upsampling: + cur_processed_images, cur_image_size = self.preprocess_image_upsampling( + input_images[i], height=height, width=width + ) + cur_mask = self.preprocess_mask_upsampling(cur_processed_images) + else: + cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image( + input_images[i], vae_scale_factor=vae_scale_factor + ) + cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position) + + output["target_position"].append(cur_target_position) + + output["image_size"].append(cur_image_size) + output["init_image"].append(cur_processed_images) + output["mask"].append(cur_mask) + output["layout_prompt"].append(layout_prompt) + + return output diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 011f23ed371c..2b75393f69ed 
100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -278,6 +278,7 @@ _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] + _import_structure["visualcloze"] = ["VisualClozePipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] @@ -722,6 +723,7 @@ UniDiffuserPipeline, UniDiffuserTextDecoder, ) + from .visualcloze import VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, diff --git a/src/diffusers/pipelines/visualcloze/__init__.py b/src/diffusers/pipelines/visualcloze/__init__.py new file mode 100644 index 000000000000..aba314868efa --- /dev/null +++ b/src/diffusers/pipelines/visualcloze/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_visualcloze"] = ["VisualClozePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_visualcloze import 
VisualClozePipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze.py new file mode 100644 index 000000000000..b2fa9a9a5e24 --- /dev/null +++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze.py @@ -0,0 +1,1201 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VisualClozeProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..flux.pipeline_output import FluxPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import VisualClozePipeline + >>> from diffusers.utils import load_image + + >>> image = [ + ... # in-context examples + ... [ + ... load_image( + ... "https://github.com/lzyhha/VisualCloze/tree/main/examples/examples/5bf755ed9dbb9b3e223e7ba35232b06e/5bf755ed9dbb9b3e223e7ba35232b06e_depth-anything-v2_Large.jpg" + ... ), + ... load_image( + ... "https://github.com/lzyhha/VisualCloze/tree/main/examples/examples/5bf755ed9dbb9b3e223e7ba35232b06e/5bf755ed9dbb9b3e223e7ba35232b06e.jpg" + ... ), + ... ], + ... # query with the target image + ... [ + ... load_image( + ... "https://github.com/lzyhha/VisualCloze/tree/main/examples/examples/2b74476568f7562a6aa832d423132ed3/2b74476568f7562a6aa832d423132ed3_depth-anything-v2_Large.jpg" + ... ), + ... None, + ... ], + ... 
] + >>> task_prompt = "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity." + >>> content_prompt = "Group photo of five young adults enjoying a rooftop gathering at dusk. The group is positioned in the center, with three women and two men smiling and embracing. The woman on the far left wears a floral top and holds a drink, looking slightly to the right. Next to her, a woman in a denim jacket stands close to a woman in a white blouse, both smiling directly at the camera. The fourth woman, in an orange top, stands close to the man on the far right, who wears a red shirt and blue blazer, smiling broadly. The background features a cityscape with a tall building and string lights hanging overhead, creating a warm, festive atmosphere. Soft natural lighting, warm color palette, shallow depth of field, intimate and joyful mood, slightly blurred background, urban rooftop setting, evening ambiance." + + >>> pipe = VisualClozePipeline.from_pretrained( + ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU + + >>> image = pipe( + ... task_prompt=task_prompt, + ... content_prompt=content_prompt, + ... image=image, + ... upsampling_height=1024, + ... upsampling_width=1024, + ... upsampling_strength=0.4, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... 
).images[0][0] + >>> image.save("visualcloze.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class VisualClozePipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The VisualCloze pipeline for image generation with visual context. Reference: + https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual + in-context examples. + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + resolution (`int`, *optional*, defaults to 384): + The resolution of each image when concatenating images from the query and in-context examples. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + resolution: int = 384, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.resolution = resolution + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VisualClozeProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = 
prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + layout_prompt: Union[str, List[str]], + task_prompt: Union[str, List[str]], + content_prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + layout_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the number of in-context examples and the number of images involved in + the task. + task_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + if isinstance(layout_prompt, str): + layout_prompt = [layout_prompt] + task_prompt = [task_prompt] + content_prompt = [content_prompt] + + def _preprocess(prompt, content=False): + if prompt is not None: + return f"The last image of the last row depicts: {prompt}" if content else prompt + else: + return "" + + prompt = [ + f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip() + for i in range(len(layout_prompt)) + ] + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids 
+ + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + image, + task_prompt, + content_prompt, + upsampling_height, + upsampling_width, + strength, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`upsampling_height` has to be divisible by {self.vae_scale_factor * 2} but is {upsampling_height}. 
Dimensions will be resized accordingly" + ) + if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`upsampling_width` has to be divisible by {self.vae_scale_factor * 2} but is {upsampling_width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Validate prompt inputs + if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None: + raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ") + + if task_prompt is None and content_prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. 
") + + # Validate prompt types and consistency + if task_prompt is None: + raise ValueError("`task_prompt` is missing.") + + if task_prompt is not None and not isinstance(task_prompt, (str, list)): + raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}") + + if content_prompt is not None and not isinstance(content_prompt, (str, list)): + raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}") + + if isinstance(task_prompt, list) or isinstance(content_prompt, list): + if not isinstance(task_prompt, list) or not isinstance(content_prompt, list): + raise ValueError( + f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, " + f"got {type(task_prompt)} and {type(content_prompt)}" + ) + if len(content_prompt) != len(task_prompt): + raise ValueError("`task_prompt` and `content_prompt` must have the same length when they are lists.") + + for sample in image: + if not isinstance(sample, list) or not isinstance(sample[0], list): + raise ValueError("Each sample in the batch must have a 2D list of images.") + if len({len(row) for row in sample}) != 1: + raise ValueError("Each in-context example and query should contain the same number of images.") + if not any(img is None for img in sample[-1]): + raise ValueError("There are no targets in the query, which should be represented as None.") + for row in sample[:-1]: + if any(img is None for img in row): + raise ValueError("Images are missing in in-context examples.") + + # Validate embeddings + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + + # Validate sequence length + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype): + latent_image_ids = [] + + for idx, img in enumerate(image, start=1): + img = img.squeeze(0) + channels, height, width = img.shape + + num_patches_h = height // vae_scale_factor // 2 + num_patches_w = width // vae_scale_factor // 2 + + patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype) + patch_ids[..., 0] = idx + patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None] + patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :] + + patch_ids = patch_ids.reshape(-1, 3) + latent_image_ids.append(patch_ids) + + return torch.cat(latent_image_ids, dim=0) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents with _latents->_latents_upsampling + def _unpack_latents_upsampling(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @staticmethod + def _unpack_latents(latents, sizes, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + start = 0 + unpacked_latents = [] + for i in range(len(sizes)): + cur_size = sizes[i] + height = cur_size[0][0] // vae_scale_factor + width = sum([size[1] for size in cur_size]) // vae_scale_factor + + end = start + (height * width) // 4 + + cur_latents = latents[:, start:end] + cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5) + cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width) + + unpacked_latents.append(cur_latents) + + start = end + + return unpacked_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + input_image, + input_mask, + timestep, + batch_size, + dtype, + device, + generator, + vae_scale_factor, + upsampling=False, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + def _prepare_single_batch(image, mask, gen): + """Helper function to prepare latents for a single batch.""" + with torch.autocast("cuda", dtype): + # Concatenate images and masks along width dimension + image = [torch.cat(img, dim=3).to(device, non_blocking=True) for img in image] + mask = [torch.cat(m, dim=3).to(device, non_blocking=True) for m in mask] + + # Generate latent image IDs + latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype) + + # Encode images to latent space + with torch.no_grad(): + if not upsampling: + # For initial encoding, use actual images + image_latent = [self._encode_vae_image(img, gen) for img in image] + masked_image_latent = [img.clone() for img in image_latent] + else: + # For post-upsampling, use zero images for masked latents + image_latent = [self._encode_vae_image(img, gen) for img in image] + masked_image_latent = [self._encode_vae_image(img * 0, gen) for img in image] + + for i in range(len(image_latent)): + # Rearrange latents and masks for patch processing + num_channels_latents, height, width = image_latent[i].shape[1:] + image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width) + masked_image_latent[i] = self._pack_latents( + masked_image_latent[i], 1, num_channels_latents, height, width + ) + + # Rearrange masks for patch processing + num_channels_latents, height, 
width = mask[i].shape[1:] + mask[i] = mask[i].view( + 1, + num_channels_latents, + height // vae_scale_factor, + vae_scale_factor, + width // vae_scale_factor, + vae_scale_factor, + ) + mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4) + mask[i] = mask[i].reshape( + 1, + num_channels_latents * (vae_scale_factor**2), + height // vae_scale_factor, + width // vae_scale_factor, + ) + mask[i] = self._pack_latents( + mask[i], + 1, + num_channels_latents * (vae_scale_factor**2), + height // vae_scale_factor, + width // vae_scale_factor, + ) + + # Concatenate along batch dimension + image_latent = torch.cat(image_latent, dim=1) + masked_image_latent = torch.cat(masked_image_latent, dim=1) + mask = torch.cat(mask, dim=1) + + return image_latent, masked_image_latent, mask, latent_image_ids + + # Process each batch + masked_image_latents = [] + image_latents = [] + masks = [] + latent_image_ids = [] + + for i in range(len(input_image)): + _image_latent, _masked_image_latent, _mask, _latent_image_ids = _prepare_single_batch( + input_image[i], input_mask[i], generator if isinstance(generator, torch.Generator) else generator[i] + ) + masked_image_latents.append(_masked_image_latent) + image_latents.append(_image_latent) + masks.append(_mask) + latent_image_ids.append(_latent_image_ids) + + # Concatenate all batches + masked_image_latents = torch.cat(masked_image_latents, dim=0) + image_latents = torch.cat(image_latents, dim=0) + masks = torch.cat(masks, dim=0) + + # Handle batch size expansion + if batch_size > masked_image_latents.shape[0]: + if batch_size % masked_image_latents.shape[0] == 0: + # Expand batches by repeating + additional_image_per_prompt = batch_size // masked_image_latents.shape[0] + masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0) + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + masks = torch.cat([masks] * additional_image_per_prompt, dim=0) + else: + raise ValueError( + f"Cannot expand 
batch size from {masked_image_latents.shape[0]} to {batch_size}. " + "Batch sizes must be multiples of each other." + ) + + # Add noise to latents + noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype) + + # Combine masked latents with masks + masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype) + + return latents, masked_image_latents, latent_image_ids[0] + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + task_prompt: Union[str, List[str]] = None, + content_prompt: Union[str, List[str]] = None, + image: Optional[torch.FloatTensor] = None, + upsampling_height: Optional[int] = None, + upsampling_width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 30.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + upsampling_strength: float = 1.0, + ): + r""" + Function invoked when calling the VisualCloze pipeline for generation. 
+ + Args: + task_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + upsampling_height (`int`, *optional*): + The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By + default, the image is upsampled by a factor of three, and the base resolution is determined by the + resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is + specified, the other will be automatically set based on the aspect ratio. + upsampling_width (`int`, *optional*): + The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By + default, the image is upsampled by a factor of three, and the base resolution is determined by the + resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is + specified, the other will be automatically set based on the aspect ratio. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. 
Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + upsampling_strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and + 1. The generated image is used as a starting point and more noise is added the higher the + `upsampling_strength`. The number of denoising steps depends on the amount of noise initially added. + When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full + number of iterations specified in `num_inference_steps`. 
A value of 0 skips the upsampling step and + outputs the results at the resolution of `self.resolution`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + task_prompt, + content_prompt, + upsampling_height, + upsampling_width, + upsampling_strength, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + processor_output = self.image_processor.preprocess( + task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor + ) + + # Define call parameters + if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str): + batch_size = 1 + elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list): + batch_size = len(processor_output["task_prompt"]) + + # Generate the target image latents by denoising the initial noise + # using the provided prompts and guidance scale + cloze_latents = self.denoise( + processor_output, + batch_size=batch_size, + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + strength=1, + ) + + # Crop the target image + # Since the generated image is a 
concatenation of the conditional and target regions, + # we need to extract only the target regions based on their positions + images = [] + for b in range(len(cloze_latents)): + cur_image_size = processor_output["image_size"][b % batch_size] + cur_target_position = processor_output["target_position"][b % batch_size] + cur_latent = self._unpack_latents(cloze_latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1] + cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor + cur_image = self.vae.decode(cur_latent, return_dict=False)[0] + cur_image = self.image_processor.postprocess(cur_image)[0] + + start = 0 + cropped = [] + for i, size in enumerate(cur_image_size[-1]): + if cur_target_position[i]: + cropped.append(cur_image.crop((start, 0, start + size[1], size[0]))) + start += size[1] + images.append(cropped) + + # Upsampling the generated images + n_target_per_sample = [] + upsampling_image = [] + upsampling_task_prompt = [] + upsampling_content_prompt = [] + upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else [] + for i in range(len(images)): + n_target_per_sample.append(len(images[i])) + for image in images[i]: + upsampling_image.append([[image]]) + upsampling_task_prompt.append(None) + upsampling_content_prompt.append(processor_output["content_prompt"][i % batch_size]) + if not isinstance(generator, (torch.Generator,)): + upsampling_generator.append(generator[i % num_images_per_prompt]) + + base_width, base_height = upsampling_image[0][0][0].size + if upsampling_height is None and upsampling_width is None: + upsampling_height = int(base_height * 3 / self.vae_scale_factor) * self.vae_scale_factor + upsampling_width = int(base_width * 3 / self.vae_scale_factor) * self.vae_scale_factor + elif upsampling_height is None: + upsampling_height = base_height * (upsampling_width / base_width) + upsampling_height = int(upsampling_height / self.vae_scale_factor) * self.vae_scale_factor + elif 
upsampling_width is None: + upsampling_width = base_width * (upsampling_height / base_height) + upsampling_width = int(upsampling_width / self.vae_scale_factor) * self.vae_scale_factor + + divisible = 2 * self.vae_scale_factor + upsampling_height = int(upsampling_height // divisible) * divisible + upsampling_width = int(upsampling_width // divisible) * divisible + + processor_output = self.image_processor.preprocess( + upsampling_task_prompt, + upsampling_content_prompt, + upsampling_image, + upsampling=True, + height=upsampling_height, + width=upsampling_width, + vae_scale_factor=self.vae_scale_factor, + ) + + # Upsample the generated images through SDEdit (https://arxiv.org/abs/2108.01073), + # which enhances details and fixes flaws. + # The amount of noise added to the initial image in `processor_output["init_image"]` + # is determined by the `upsampling_strength` parameter. + upsampling_latents = self.denoise( + processor_output, + batch_size=len(upsampling_image), + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + generator=generator, + latents=latents, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + strength=upsampling_strength, + upsampling=True, + ) + + if output_type == "latent": + image = upsampling_latents + else: + latents = self._unpack_latents_upsampling( + upsampling_latents, upsampling_height, upsampling_width, self.vae_scale_factor + ) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if output_type == "pil": + # Each sample in the batch may have multiple output images. When returning as PIL images, + # these images cannot be concatenated. 
Therefore, for each sample, + # a list is used to represent all the output images. + output = [] + start = 0 + for n in n_target_per_sample: + output.append(image[start : start + n]) + start += n + else: + output = image + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return FluxPipelineOutput(images=output) + + def denoise( + self, + processor_output: dict = None, + batch_size: int = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 30.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + strength: float = 1.0, + upsampling: bool = False, + ): + device = self._execution_device + + # 1. Prepare prompt embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + layout_prompt=processor_output["layout_prompt"], + task_prompt=processor_output["task_prompt"], + content_prompt=processor_output["content_prompt"], + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 2. 
Prepare timesteps + # Calculate sequence length and shift factor + image_seq_len = sum( + (size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2) + for sample in processor_output["image_size"][0] + for size in sample + ) + + # Calculate noise schedule parameters + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + # Get timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + + # 3. Prepare latent variables + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents, masked_image_latents, latent_image_ids = self.prepare_latents( + processor_output["init_image"], + processor_output["mask"], + latent_timestep, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + vae_scale_factor=self.vae_scale_factor, + upsampling=upsampling, + ) + + # Calculate warmup steps + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Prepare guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 4. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=torch.cat((latents, masked_image_latents), dim=2), + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # Compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # Some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # Call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # Update the progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # XLA optimization + if XLA_AVAILABLE: + xm.mark_step() + + return latents diff --git a/tests/pipelines/visualcloze/__init__.py b/tests/pipelines/visualcloze/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze.py new file mode 100644 index 000000000000..f2fb2fa99fcd --- /dev/null +++ 
b/tests/pipelines/visualcloze/test_pipeline_visualcloze.py @@ -0,0 +1,203 @@ +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, VisualClozePipeline +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = VisualClozePipeline + params = frozenset( + [ + "task_prompt", + "content_prompt", + "upsampling_height", + "upsampling_width", + "guidance_scale", + "prompt_embeds", + "pooled_prompt_embeds", + "upsampling_strength", + ] + ) + batch_params = frozenset(["task_prompt", "content_prompt", "image"]) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=12, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=6, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[2, 2, 2], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = 
AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + "resolution": 32, + } + + def get_dummy_inputs(self, device, seed=0): + # Create example images to simulate the input format required by VisualCloze + context_image = [ + Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8)) + for _ in range(2) + ] + query_image = [ + Image.fromarray( + floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8) + ), + None, + ] + + # Create an image list that conforms to the VisualCloze input format + image = [ + context_image, # In-Context example + query_image, # Query image + ] + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.", + "content_prompt": "A beautiful landscape with mountains and a lake", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "upsampling_height": 32, + "upsampling_width": 32, + "max_sequence_length": 48, + "output_type": "np", + "upsampling_strength": 0.4, + } + return inputs + + def test_visualcloze_different_prompts(self): + pipe = 
self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["content_prompt"] = "A different landscape with forests and rivers" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different + assert max_diff > 1e-6 + + def test_visualcloze_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"upsampling_height": height, "upsampling_width": width}) + image = pipe(**inputs).images[0][0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-3) + + def test_upsampling_strength(self, expected_min_diff=1e-1): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + # Test a low and a high upsampling strength + inputs["upsampling_strength"] = 0.2 + output_low_strength = pipe(**inputs).images[0] + + inputs["upsampling_strength"] = 0.8 + output_high_strength = pipe(**inputs).images[0] + + # Different upsampling strengths should produce different outputs + max_diff = np.abs(output_low_strength - output_high_strength).max() + assert max_diff > expected_min_diff + + def test_different_task_prompts(self, expected_min_diff=1e-1): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = 
self.get_dummy_inputs(torch_device) + + output_original = pipe(**inputs).images[0] + + inputs["task_prompt"] = "A different task description for image generation" + output_different_task = pipe(**inputs).images[0] + + # Different task prompts should produce different outputs + max_diff = np.abs(output_original - output_different_task).max() + assert max_diff > expected_min_diff
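
The size rounding that `test_visualcloze_image_output_shape` checks can be reproduced outside the pipeline. The following is a minimal sketch that mirrors the upsampling-size logic in `__call__` as it appears in this diff. `resolve_upsampling_size` is a hypothetical helper name, not part of diffusers, and the default `vae_scale_factor=8` is an assumption for illustration (the dummy components in the test build a much smaller VAE).

```python
from typing import Optional, Tuple


def resolve_upsampling_size(
    base_width: int,
    base_height: int,
    upsampling_width: Optional[int] = None,
    upsampling_height: Optional[int] = None,
    vae_scale_factor: int = 8,  # assumed value, for illustration only
) -> Tuple[int, int]:
    """Hypothetical helper mirroring the size logic in the diff; returns (height, width)."""
    if upsampling_height is None and upsampling_width is None:
        # Neither side given: upsample 3x, rounded down to a multiple of the VAE scale factor.
        upsampling_height = int(base_height * 3 / vae_scale_factor) * vae_scale_factor
        upsampling_width = int(base_width * 3 / vae_scale_factor) * vae_scale_factor
    elif upsampling_height is None:
        # Only the width given: derive the height from the base aspect ratio.
        upsampling_height = base_height * (upsampling_width / base_width)
        upsampling_height = int(upsampling_height / vae_scale_factor) * vae_scale_factor
    elif upsampling_width is None:
        # Only the height given: derive the width from the base aspect ratio.
        upsampling_width = base_width * (upsampling_height / base_height)
        upsampling_width = int(upsampling_width / vae_scale_factor) * vae_scale_factor

    # Both sides are finally rounded down to a multiple of 2 * vae_scale_factor.
    divisible = 2 * vae_scale_factor
    upsampling_height = int(upsampling_height // divisible) * divisible
    upsampling_width = int(upsampling_width // divisible) * divisible
    return upsampling_height, upsampling_width


if __name__ == "__main__":
    print(resolve_upsampling_size(384, 384))  # default 3x upsampling
    print(resolve_upsampling_size(32, 32, upsampling_width=57, upsampling_height=72))
```

With both sides given, the result matches the test's expectation of `size - size % (vae_scale_factor * 2)`, so a requested `(72, 57)` becomes `(64, 48)` under the assumed scale factor of 8.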