"""Z-Image-Turbo Gradio Space with runtime-selectable LoRA weights.

Builds the diffusion pipeline once at import time, lets the user pick a
LoRA ``.safetensors`` file from a Hugging Face repo, streams low-res
previews followed by the final image, and uploads each result
(image + prompt) to a Hugging Face model repo as a best-effort side task.
"""

import os
import shutil
import tempfile
import traceback
import uuid
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import ZImagePipeline
from huggingface_hub import HfApi, list_repo_files
from PIL import Image

# Repo that receives generated images + prompts (override via env var).
HF_MODEL = os.environ.get("HF_UPLOAD_REPO", "rahul7star/Zimg-Feb2026")


def upload_image_and_prompt_cpu(input_image, prompt_text) -> str:
    """Upload a generated image and its prompt to the ``HF_MODEL`` repo.

    Files land under ``YYYY-MM-DD/Upload-Image-<8 hex>/`` so concurrent
    runs never collide.

    Args:
        input_image: PIL image to save as ``final_image.png``.
        prompt_text: Prompt string to save as ``summary.txt``.

    Returns:
        The repo-relative folder path the files were uploaded to.
    """
    api = HfApi()
    today_str = datetime.now().strftime("%Y-%m-%d")
    unique_subfolder = f"Upload-Image-{uuid.uuid4().hex[:8]}"
    hf_folder = f"{today_str}/{unique_subfolder}"
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

    # ---- stage both payloads in temp files ----
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
        input_image.save(tmp_img.name, format="PNG")
        tmp_img_path = tmp_img.name

    summary_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(prompt_text)

    try:
        api.upload_file(
            path_or_fileobj=tmp_img_path,
            path_in_repo=f"{hf_folder}/final_image.png",
            repo_id=HF_MODEL,
            repo_type="model",
            token=token,
        )
        api.upload_file(
            path_or_fileobj=summary_file,
            path_in_repo=f"{hf_folder}/summary.txt",
            repo_id=HF_MODEL,
            repo_type="model",
            token=token,
        )
    finally:
        # Cleanup previously ran only on success, leaking temp files when
        # an upload raised; always remove the staging files.
        for path in (tmp_img_path, summary_file):
            try:
                os.remove(path)
            except OSError:
                pass

    return hf_folder


# ============================================================
# CONFIG
# ============================================================
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
DEFAULT_LORA_REPO = "rahul7star/ZImageLora"
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ============================================================
# GLOBAL STATE
# ============================================================
pipe = None
CURRENT_LORA_REPO = None  # repo of the LoRA selected in the UI
CURRENT_LORA_FILE = None  # filename of the LoRA selected in the UI


# ============================================================
# LOGGING
# ============================================================
def log(msg):
    """Print *msg* and return it unchanged (handy for chaining)."""
    print(msg)
    return msg


# ============================================================
# PIPELINE BUILD (ONCE)
# ============================================================
try:
    pipe = ZImagePipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
    pipe.to(DEVICE)
    log("✅ Pipeline built successfully")
except Exception:
    log("❌ Pipeline build failed")
    log(traceback.format_exc())
    pipe = None


# ============================================================
# HELPERS
# ============================================================
def list_loras_from_repo(repo_id: str):
    """Return the ``.safetensors`` filenames in *repo_id* ([] on error)."""
    try:
        files = list_repo_files(repo_id)
    except Exception as e:
        log(f"❌ Failed to list LoRAs: {e}")
        return []
    return [f for f in files if f.endswith(".safetensors")]


# ============================================================
# IMAGE GENERATION (SAFE LORA LOGIC)
# ============================================================
@spaces.GPU()
def generate_image(prompt, height, width, steps, seed, guidance_scale):
    """Generator callback: yields ``(final_image, previews, log_text)``.

    Streams up to five quarter-resolution previews, then the final image.
    The selected LoRA is loaded for this run only and unloaded afterwards
    so state never leaks between runs.
    """
    LOGS = []
    print(prompt)

    if pipe is None:
        # In a generator a bare ``return value`` never reaches the UI;
        # yield the error state first, then stop.
        yield None, [], "❌ Pipeline not initialized"
        return

    # Gradio number widgets can deliver floats; the pipeline wants ints.
    height, width = int(height), int(width)
    steps, seed = int(steps), int(seed)

    generator = torch.Generator().manual_seed(seed)
    placeholder = Image.new("RGB", (width, height), (255, 255, 255))
    previews = []

    # ---- Always start clean: drop any LoRA left over from a prior run ----
    try:
        pipe.unload_lora_weights()
    except Exception:
        pass

    # ---- Load the selected LoRA for this run only ----
    if CURRENT_LORA_FILE:
        try:
            pipe.load_lora_weights(CURRENT_LORA_REPO, weight_name=CURRENT_LORA_FILE)
            LOGS.append(f"🧩 LoRA loaded: {CURRENT_LORA_FILE}")
        except Exception as e:
            LOGS.append(f"❌ LoRA load failed: {e}")

    # ---- Preview steps (quarter-res, progressively more steps) ----
    try:
        num_previews = min(5, steps)
        for i in range(num_previews):
            out = pipe(
                prompt=prompt,
                height=height // 4,
                width=width // 4,
                num_inference_steps=i + 1,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            previews.append(out.images[0].resize((width, height)))
            yield None, previews, "\n".join(LOGS)
    except Exception as e:
        LOGS.append(f"⚠️ Preview failed: {e}")

    # ---- Final image ----
    try:
        out = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        final_img = out.images[0]
        previews.append(final_img)
        LOGS.append("✅ Image generated")

        # ============================================
        # HF UPLOAD (CPU SAFE) — best-effort, never fails the generation
        # ============================================
        try:
            folder = upload_image_and_prompt_cpu(final_img, prompt)
            LOGS.append(f"☁️ Uploaded to {folder}")
        except Exception as upload_error:
            LOGS.append(f"⚠️ Upload error: {upload_error}")

        yield final_img, previews, "\n".join(LOGS)
    except Exception as e:
        LOGS.append(f"❌ Generation failed: {e}")
        yield placeholder, previews, "\n".join(LOGS)
    finally:
        # ---- CRITICAL: unload after run ----
        try:
            pipe.unload_lora_weights()
            # NOTE: appended after the last yield, so the UI won't show it.
            LOGS.append("🧹 LoRA unloaded")
        except Exception:
            pass


# ============================================================
# GRADIO UI
# ============================================================
css = """
.gradio-container { max-width: 100% !important; padding: 16px 32px !important; }
.section { margin-bottom: 12px; }
.generate-btn {
    background: linear-gradient(90deg, #4b6cb7, #182848) !important;
    color: white !important;
    font-weight: 600;
    height: 46px;
    border-radius: 10px;
}
.secondary-btn { height: 42px; border-radius: 10px; }
textarea, input { border-radius: 10px !important; }
"""

with gr.Blocks(
    title="Z-Image-Turbo (Runtime LoRA)",
    css=css,
) as demo:
    gr.Markdown(
        """
💬 Join Our Discord Community

Get support • Share results • Discuss LoRAs • Report bugs

🚀 Join Discord
""",
        elem_id="discord-banner",
    )

    # ======================================================
    # MAIN LAYOUT
    # ======================================================
    with gr.Row():
        # ================= LEFT PANEL =================
        with gr.Column(scale=5):
            # -------- Prompt --------
            prompt = gr.Textbox(
                label="Prompt",
                value="boat in ocean",
                lines=4,
                placeholder="Describe the image you want to generate…",
            )

            # -------- LoRA Controls (next to prompt) --------
            gr.Markdown("### 🧩 LoRA Controls")
            lora_repo = gr.Textbox(
                label="LoRA Repository",
                value=DEFAULT_LORA_REPO,
                lines=2,
                placeholder="username/repo (e.g. rahul7star/ZImageLora)",
            )
            lora_dropdown = gr.Dropdown(
                label="LoRA File",
                choices=[],
                interactive=True,
            )
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh LoRA List", elem_classes="secondary-btn")
                clear_lora_btn = gr.Button("❌ Clear LoRA", elem_classes="secondary-btn")

            # -------- Generation Controls --------
            gr.Markdown("### ⚙️ Generation Settings")
            with gr.Row():
                width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
                height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            with gr.Row():
                steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                guidance = gr.Slider(0, 10, value=0.0, step=0.5, label="Guidance")
            seed = gr.Number(value=42, label="Seed", precision=0)

            run_btn = gr.Button("🚀 Generate Image", elem_classes="generate-btn")
            logs_box = gr.Textbox(
                label="Logs",
                lines=10,
                interactive=False,
            )

        # ================= RIGHT PANEL =================
        with gr.Column(scale=7):
            final_image = gr.Image(label="Final Image", height=520)
            gallery = gr.Gallery(label="Generation Steps", columns=4, height=260)

    # ======================================================
    # CALLBACKS
    # ======================================================
    def refresh_loras(repo):
        """Repopulate the LoRA dropdown from *repo*."""
        files = list_loras_from_repo(repo)
        return gr.update(choices=files, value=files[0] if files else None)

    refresh_btn.click(
        refresh_loras,
        inputs=[lora_repo],
        outputs=[lora_dropdown],
    )

    def select_lora(lora_file, repo):
        """Remember the chosen LoRA; it is applied inside generate_image."""
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        CURRENT_LORA_FILE = lora_file
        CURRENT_LORA_REPO = repo
        return f"🧩 Selected LoRA: {lora_file}"

    lora_dropdown.change(
        select_lora,
        inputs=[lora_dropdown, lora_repo],
        outputs=[logs_box],
    )

    def clear_lora():
        """Forget the selected LoRA and drop any loaded weights."""
        global CURRENT_LORA_FILE, CURRENT_LORA_REPO
        CURRENT_LORA_FILE = None
        CURRENT_LORA_REPO = None
        try:
            pipe.unload_lora_weights()
        except Exception:
            pass
        return (
            gr.update(value=None),
            "🧹 LoRA cleared — base model will be used.",
        )

    clear_lora_btn.click(
        clear_lora,
        outputs=[lora_dropdown, logs_box],
    )

    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed, guidance],
        outputs=[final_image, gallery, logs_box],
    )

demo.launch()