Add basic implementation of AuraFlowImg2ImgPipeline #11340


Status: Open · wants to merge 4 commits into base: main
Conversation

@AstraliteHeart (Contributor) commented Apr 16, 2025

What does this PR do?

Adds a very mechanical conversion of other img2img pipelines (mostly SD3/Flux) to support AuraFlow. It seems to require a bit more strength (0.75+) than SDXL (my main point of reference for I2I), but it works fine and does not complain about GGUF (I still need to check compilation).

Fixes # (issue)

Before submitting

Who can review?

@cloneofsimo @sayakpaul
@yiyixuxu @asomoza

@sayakpaul (Member)

Thanks for yet another contribution! Could you post a snippet and some results?

@AstraliteHeart (Contributor, Author)

Apologies, I had to clean things up and make the tests actually work.

The docstrings included in the CL should serve as a good snippet, i.e.

```python
import torch
from diffusers import AuraFlowImg2ImgPipeline
import requests
from PIL import Image
from io import BytesIO

# download an initial image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "A fantasy landscape, trending on artstation"
image = pipe(prompt=prompt, image=init_image, strength=0.75, num_inference_steps=50).images[0]
image.save("aura_flow_img2img.png")
```

Unfortunately, it seems my math may be wrong somewhere?

With strength 0.75: [image]

With strength 0.85: [image]

With strength 0.95: [image]

But with 0.9: [image]

@DN6, any ideas?

Comment on lines 422 to 426
# Compute latents
latents = mean + std * sample

# Scale latents
latents = latents * self.vae.config.scaling_factor
Member

Contributor (Author)

The VAE config has ('latents_mean', None), ('latents_std', None), so I believe the code would be a no-op, but I implemented it anyway.

Member

Oh okay then init_latents *= self.vae.config.scaling_factor should be just fine. We can safely remove the other things.
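For reference, the SD3-style normalization being simplified here can be sketched in plain Python; the helper name and scalar arguments below are illustrative, not the actual pipeline code:

```python
def normalize_latents(sample, scaling_factor, latents_mean=None, latents_std=None):
    """Sketch of SD3-style latent normalization (illustrative names).

    When the VAE config carries latents_mean/latents_std, the sample is
    shifted and rescaled; when both are None (as in the AuraFlow VAE
    config), it reduces to plain scaling by scaling_factor.
    """
    if latents_mean is not None and latents_std is not None:
        return (sample - latents_mean) * scaling_factor / latents_std
    return sample * scaling_factor
```

With both stats unset, this degenerates to the single `init_latents *= self.vae.config.scaling_factor` line, which is why the extra branches can be dropped.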

@DN6 (Collaborator) left a comment

Looking good 👍🏽 I think some parts of the implementation have to be adjusted to look more similar to the existing Flux pipelines.


return timesteps, num_inference_steps - t_start

def prepare_img2img_latents(
Collaborator

Can this be placed under prepare_latents like in the other Img2Img pipelines?

Contributor (Author)

Not sure if this is what you intended, but done?
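For context, in the flow-matching img2img pipelines the latent preparation covers both the text-to-image and image-to-image paths; a toy scalar sketch of the blending math (illustrative names; in spirit this mirrors the scheduler's scale_noise step):

```python
def prepare_latents(noise, image_latents=None, sigma=None):
    # Text-to-image path: start from pure noise.
    if image_latents is None:
        return noise
    # Image-to-image path: flow-matching interpolation between the
    # encoded input image and noise at the starting sigma.
    return sigma * noise + (1.0 - sigma) * image_latents
```

At sigma = 1 the input image is ignored entirely (pure noise), and at sigma = 0 the latents are exactly the encoded image, which is why the choice of starting sigma (i.e. strength) matters so much here.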

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

# 5. Prepare timesteps
timesteps, num_inference_steps = self.get_timesteps(
Collaborator

I think the timesteps need to be adjusted for strength and shift, no?

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

Contributor (Author)

done, probably :)
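For reference, the strength-based truncation used by the other img2img pipelines boils down to skipping the first part of the schedule; a rough sketch (not the exact diffusers code):

```python
def get_timesteps(timesteps, num_inference_steps, strength):
    # Skip the first (1 - strength) fraction of the schedule so denoising
    # starts from a partially noised version of the input image rather
    # than from pure noise.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:], num_inference_steps - t_start
```

With strength = 1.0 the full schedule is kept (equivalent to text-to-image); with strength = 0.75 roughly the first quarter of the steps is dropped.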

@AstraliteHeart (Contributor, Author) commented Apr 18, 2025

Unfortunately, I'm still seeing visual noise instead of the image at some values of strength, so something in my math must be wrong.

@sayakpaul (Member)

Do those values generally tend to be higher?

@AstraliteHeart (Contributor, Author)

[image]

@sayakpaul (Member)

A bit hard to see, sorry.

@AstraliteHeart (Contributor, Author)

Weird, I can click on the image to get the full-sized one (with an extra click).
It's the 0.88-0.94 strength range that produces visual noise, and I think images above 0.94 don't even use the initial image.

@AstraliteHeart (Contributor, Author)

I still have no idea what is going on, but I think the code is correct; AF may just require some special handling. Here are my observations:

  1. When using certain input images, the VAE encoder, responsible for creating the initial latent representation (x0) of the input image, produces a latent distribution (latent_dist) where the standard deviation (std) component consistently collapses to effectively zero (e.g., std=0.0000, corresponding to a highly negative logvar).

  2. This variance collapse is observed even when ensuring the VAE was loaded and operated entirely in float32 precision. This confirms the issue is not merely an fp16 underflow problem during VAE computation but rather suggests the SDXL VAE predicts near-zero variance for this type of input (I was aware of different issues with SDXL VAE but not this specific behavior).

  3. Because the predicted std is zero, the subsequent step of sampling the initial latent variable x0 from this distribution (mean + std * noise) becomes deterministic, effectively yielding only the mean component (x0 = mean).

  4. To address the lack of variance in x0, I attempted an experiment in which the logvar output by the VAE was manually "clamped" to a minimum value (tested min_logvar = -10.0, giving std ≈ 0.0067, and min_logvar = -4.0, giving std ≈ 0.135) before sampling x0. This successfully introduced non-zero variance into the initial latent state. At this point my assumption was simply that the issue lies in the VAE.

  5. Despite successfully injecting variance into x0 via clamping, the pipeline still produced noise/corrupted images when run at high strength values (e.g., strength=0.9).

  6. Crucially, the pipeline works reasonably well at lower strength values (e.g., strength=0.7), producing recognizable image outputs that incorporate the initial image structure.

The core issue no longer seems to be only the deterministic x0 caused by the initial VAE variance collapse (as fixing that didn't solve the high-strength problem). Instead, the failure at high strength (0.9) may stem from an instability in the denoising process itself when initiated from the very high noise levels corresponding to these high strengths. The process is stable when starting from the lower noise levels associated with moderate strength (0.7).
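The deterministic sampling and the clamping experiment in points 3-4 can be sketched as a toy scalar version (names illustrative, not the pipeline code):

```python
import math

def sample_initial_latent(mean, logvar, noise, min_logvar=None):
    # Reparameterized sampling of x0 from the VAE posterior:
    # x0 = mean + std * noise, with std = exp(0.5 * logvar).
    # Optionally clamp logvar from below to force a non-zero std,
    # as in the experiment described above.
    if min_logvar is not None:
        logvar = max(logvar, min_logvar)
    std = math.exp(0.5 * logvar)
    return mean + std * noise

# When the posterior collapses (logvar highly negative), std ~= 0 and
# sampling effectively returns the mean. Clamping logvar at -10 gives
# std ~= 0.0067; clamping at -4 gives std ~= 0.135.
```

As noted above, injecting variance this way did not fix the high-strength failure, pointing at the denoising trajectory rather than the initial sample.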

@sayakpaul (Member)

Thanks for the detailed analysis. Does this behavior vary between AuraFlow and AuraFlow 0.3?

@AstraliteHeart (Contributor, Author)

Yes, I am seeing the same behavior in 0.2.

I focused too much on the VAE in the last comment - I don't think it's the root cause (after all, SDXL works just fine), and perhaps the real issue is some kind of numerical instability we are facing?

I've included 3 videos, generated by taking a snapshot of the model state every 5 steps - before, at the moment of the issue, and after. Looking at them, I can't spot anything weird that would explain the problem.

combined_steps_s0.85.mp4
combined_steps_s0.90.mp4
combined_steps_s0.95.mp4

I've also attempted to affect the problematic strength range by changing the number of steps or the guidance scale, but it had no effect. Interestingly, enabling use_karras_sigmas=True on the scheduler seems to "fix" the issue, as I can no longer hit the noise output, but I still see a very sharp change from "low-strength i2i" to "just t2i" at around 0.98 strength. Feels like I am missing something super obvious here.
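For what it's worth, the Karras schedule interpolates in sigma^(1/rho) space, which concentrates steps near the low-sigma end and may change where the regime break lands. A plain-Python sketch of the standard formula (illustrative, not the diffusers implementation):

```python
def karras_sigmas(sigma_min, sigma_max, n, rho=7.0):
    # Karras et al. (2022) schedule: interpolate linearly between
    # sigma_max**(1/rho) and sigma_min**(1/rho), then raise back to
    # the rho-th power, giving a monotonically decreasing schedule
    # that samples the low-noise region more densely.
    ramp = [i / (n - 1) for i in range(n)]
    min_inv = sigma_min ** (1.0 / rho)
    max_inv = sigma_max ** (1.0 / rho)
    return [(max_inv + t * (min_inv - max_inv)) ** rho for t in ramp]
```

Because the high-sigma end is traversed in fewer, coarser steps, a strength-based truncation lands on a different effective starting noise level than with the default linear schedule, which could mask (rather than solve) the instability.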
