Update app.py
app.py CHANGED
@@ -4,16 +4,12 @@ import random
 import spaces
 import torch
 import time
-import logging
 from diffusers import DiffusionPipeline, AutoencoderTiny
 # Using AttnProcessor2_0 for potential speedup with PyTorch 2.x
 from diffusers.models.attention_processor import AttnProcessor2_0
 # Assuming custom_pipeline defines FluxWithCFGPipeline correctly
 from custom_pipeline import FluxWithCFGPipeline

-# --- Setup Logging ---
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
 # --- Torch Optimizations ---
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
@@ -34,50 +30,36 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = None # Initialize pipe to None

 try:
-    logging.info("Loading diffusion pipeline...")
     pipe = FluxWithCFGPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
     )
-    logging.info("Loading VAE...")
     pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)

-    logging.info(f"Moving pipeline to {device}...")
     pipe.to(device)

     # Apply optimizations
-    logging.info("Setting attention processor...")
     pipe.unet.set_attn_processor(AttnProcessor2_0())
     pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too

-    logging.info("Loading and fusing LoRA...")
     pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
     pipe.set_adapters(["better"], adapter_weights=[1.0])
     pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
     pipe.unload_lora_weights() # Unload after fusing
-    logging.info("LoRA fused and unloaded.")

     # --- Compilation (Major Speed Optimization) ---
-
-
-    # logging.info("Compiling VAE Encoder...")
-    # pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
-    # logging.info("Model compilation finished.")
+    pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
+    pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)

     # Clear cache after setup
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        logging.info("CUDA cache cleared after setup.")

 except Exception as e:
-
-    # Display error in Gradio if UI is already built, otherwise just log and exit.
-    # For simplicity here, we'll rely on the Gradio UI showing an error if `pipe` is None later.
-    # If running script directly, consider `sys.exit()`
-    # raise gr.Error(f"Failed to load models. Check logs for details. Error: {e}")
+    print(e)


 # --- Inference Function ---
-@spaces.GPU(
+@spaces.GPU() # Slightly increased duration buffer
 def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
     """Generates an image using the FLUX pipeline with error handling."""

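Note on the compilation hunk above: `torch.compile(..., mode="reduce-overhead", fullgraph=True)` captures a graph for the tiny VAE's encoder and decoder and replays it with low per-call overhead, but the first call for a given input shape pays the compilation cost. A minimal, self-contained sketch of the same pattern on a toy module (illustrative only, not part of app.py):

```python
import torch

# Toy stand-in for the VAE decoder; TinyDecoder is illustrative, not from app.py.
class TinyDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

decoder = TinyDecoder()
# Same flags as in the hunk: cache a graph and replay it with minimal overhead.
compiled_decoder = torch.compile(decoder, mode="reduce-overhead", fullgraph=True)

latents = torch.randn(1, 4, 64, 64)
_ = compiled_decoder(latents)  # first call compiles (slow)
_ = compiled_decoder(latents)  # same shape afterwards reuses the cached graph (fast)
```

In the Space, that one-time cost should land on the first generation after startup; later requests at the same resolution can reuse the compiled graph.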
@@ -85,10 +67,7 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
         raise gr.Error("Diffusion pipeline failed to load. Cannot generate images.")

     if not prompt or prompt.strip() == "":
-        # Return a blank image or previous result if prompt is empty?
-        # For now, raise warning and return None.
         gr.Warning("Prompt is empty. Please enter a description.")
-        # Returning None for image, original seed, and error message
         return None, seed, "Error: Empty prompt"

     start_time = time.time()
@@ -105,8 +84,6 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
     # Clamp steps
     steps_to_use = max(MIN_INFERENCE_STEPS, min(steps_to_use, MAX_INFERENCE_STEPS))

-    logging.info(f"Generating image with prompt: '{prompt}', seed: {seed}, size: {width}x{height}, steps: {steps_to_use}")
-
     try:
         # Ensure generator is on the correct device
         generator = torch.Generator(device=device).manual_seed(int(float(seed)))
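The hunk above clamps the step count and seeds a `torch.Generator` with `int(float(seed))`, which tolerates seed values that arrive from the UI as floats or strings. A self-contained sketch of that clamp-and-seed pattern; the constants and helper name here are stand-ins, not the ones defined elsewhere in app.py:

```python
import random
import torch

# Stand-in limits; app.py defines its own constants elsewhere.
MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS = 1, 8
MAX_SEED = 2**32 - 1

def resolve_generation_params(seed, randomize_seed: bool, steps: int, device: str = "cpu"):
    # Clamp the requested steps into the allowed range.
    steps = max(MIN_INFERENCE_STEPS, min(steps, MAX_INFERENCE_STEPS))
    # Optionally replace the seed, then coerce it robustly to int.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(int(float(seed)))
    return steps, seed, generator

print(resolve_generation_params("7.0", randomize_seed=False, steps=20))
```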
@@ -127,18 +104,15 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig

         latency = time.time() - start_time
         latency_str = f"Latency: {latency:.2f} seconds (Steps: {steps_to_use})"
-        logging.info(f"Image generated successfully. {latency_str}")
         return result_img, seed, latency_str

     except torch.cuda.OutOfMemoryError as e:
-        logging.error(f"CUDA OutOfMemoryError: {e}", exc_info=True)
         # Clear cache and suggest reducing size/steps
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         raise gr.Error("GPU ran out of memory. Try reducing the image width/height or the number of inference steps.")

     except Exception as e:
-        logging.error(f"Error during image generation: {e}", exc_info=True)
         # Clear cache just in case
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
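Both `except` branches above free cached GPU memory and turn low-level failures into a `gr.Error` that the UI can show. A generic sketch of the same guard factored into a helper (illustrative, not the app's code):

```python
import torch
import gradio as gr

def run_with_oom_guard(step):
    """Run `step()`, converting CUDA OOM into a friendly, actionable UI error."""
    try:
        return step()
    except torch.cuda.OutOfMemoryError:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise gr.Error("GPU ran out of memory. Try a smaller resolution or fewer steps.")
```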
@@ -150,14 +124,12 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
 # It's triggered by changes in prompt or sliders when realtime is enabled.
 def handle_realtime_update(realtime_enabled: bool, prompt: str, seed: int, width: int, height: int, randomize_seed: bool, num_inference_steps: int):
     if realtime_enabled and pipe is not None:
-        logging.debug("Realtime update triggered.")
         # Call generate_image directly. Errors within generate_image will be caught and raised as gr.Error.
         # We don't set is_enhance=True for realtime updates.
         return generate_image(prompt, seed, width, height, randomize_seed, num_inference_steps, is_enhance=False)
     else:
         # If realtime is disabled or pipe failed, don't update the image, seed, or latency.
         # Return gr.update() for each output component to indicate no change.
-        logging.debug("Realtime update skipped (disabled or pipe error).")
         return gr.update(), gr.update(), gr.update()


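`handle_realtime_update` returns `gr.update()` for every output when realtime is off, which tells Gradio to leave those components untouched. A minimal sketch of the same no-op pattern; the component names are illustrative:

```python
import gradio as gr

def maybe_echo(live: bool, text: str):
    if live:
        return text
    return gr.update()  # realtime off: leave the output exactly as it is

with gr.Blocks() as sketch:
    live = gr.Checkbox(label="Realtime", value=True)
    box_in = gr.Textbox(label="Prompt")
    box_out = gr.Textbox(label="Preview")
    box_in.input(fn=maybe_echo, inputs=[live, box_in], outputs=box_out, queue=False)
```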
@@ -225,7 +197,8 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None
+        concurrency_limit=None,
+        fn_kwargs={"is_enhance": True} # Pass the flag to indicate enhance
     )

     generateBtn.click(
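The binding above passes `is_enhance=True` through `fn_kwargs`. If the installed Gradio version does not accept a `fn_kwargs` argument on event listeners, `functools.partial` is one common way to bind such a constant keyword instead; a small self-contained sketch with an illustrative function:

```python
from functools import partial

def render(prompt: str, is_enhance: bool = False) -> str:
    # Illustrative stand-in for a handler that takes an is_enhance flag.
    return f"{prompt} (enhanced)" if is_enhance else prompt

enhance_fn = partial(render, is_enhance=True)
print(enhance_fn("a photo of a cat"))  # -> "a photo of a cat (enhanced)"
```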
@@ -251,9 +224,8 @@ with gr.Blocks() as demo:
         concurrency_limit=None
     )

-
-
-            return next(generate_image(*args[1:]))
+    # Removed the intermediate realtime_generation function.
+    # handle_realtime_update checks the realtime toggle internally.

     prompt.submit(
         fn=generate_image,
@@ -266,7 +238,7 @@ with gr.Blocks() as demo:

     for component in [prompt, width, height, num_inference_steps]:
         component.input(
-            fn=
+            fn=handle_realtime_update, # Call the wrapper that checks the toggle
             inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
@@ -274,6 +246,17 @@ with gr.Blocks() as demo:
             queue=False,
             concurrency_limit=None
         )
+
+    # Also trigger realtime on seed change if randomize is off
+    seed.input(
+        fn=handle_realtime_update,
+        inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
+        outputs=[result, seed, latency],
+        show_progress="hidden",
+        trigger_mode="always_last",
+        queue=False,
+        concurrency_limit=None
+    )

 # Launch the app
-demo.launch()
+demo.launch()
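The new `seed.input(...)` listener uses `trigger_mode="always_last"`: if further inputs arrive while an event is still running, only the most recent value is processed once that run finishes. A small self-contained sketch of the behaviour; component names are illustrative:

```python
import gradio as gr

def show_value(value: float) -> str:
    return f"Latest value: {value}"

with gr.Blocks() as sketch:
    slider = gr.Slider(0, 100, value=42, label="Seed")
    readout = gr.Textbox(label="Readout")
    slider.input(
        fn=show_value,
        inputs=slider,
        outputs=readout,
        trigger_mode="always_last",  # rerun once with the latest value after the current run
        queue=False,
    )
```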