import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import gradio as gr
from model import (
    UNet,
    VQVAE,
    LinearNoiseScheduler,
    get_tokenizer_and_model,
    get_text_representation,
    dataset_params,
    diffusion_params,
    ldm_params,
    autoencoder_params,
    train_params,
)
from huggingface_hub import hf_hub_download
import spaces
import json

print("Gradio version:", gr.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Currently running on {device}")
if device.type == "cuda":
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Download config and checkpoint files from the HF Hub
config_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
)
with open(config_path, "r") as f:
    config = json.load(f)

ldm_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
)
vae_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
)
# Instantiate and load the models
unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
vae = VQVAE(
    config["dataset_params"]["image_channels"], config["autoencoder_params"]
).to(device)

unet_state = torch.load(ldm_ckpt_path, map_location=device)
unet.load_state_dict(unet_state["model_state_dict"])
print(f"Loaded UNet checkpoint from epoch {unet_state['epoch']}")

vae_state = torch.load(vae_ckpt_path, map_location=device)
vae.load_state_dict(vae_state["model_state_dict"])

unet.eval()
vae.eval()
print("Model and checkpoints loaded successfully!")
def sample_ddpm_inference(text_prompt):
    """
    Given a text prompt, sample from the latent diffusion model and yield
    intermediate decoded images (PIL images) as reverse diffusion progresses.
    Image (mask) conditioning is supported by the model but not exposed in this
    UI, so the mask defaults to zeros.
    """
    mask_image_pil = None  # No mask upload in this UI; a zero mask is used instead.

    # Guidance scale controls the strength of classifier-free guidance.
    guidance_scale = 1.0
    image_display_rate = 10

    # Create the noise scheduler.
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_params["num_timesteps"],
        beta_start=diffusion_params["beta_start"],
        beta_end=diffusion_params["beta_end"],
    )

    # Get the conditioning config from ldm_params.
    condition_config = ldm_params.get("condition_config", None)
    condition_types = (
        condition_config.get("condition_types", [])
        if condition_config is not None
        else []
    )

    # Load the text tokenizer/model for conditioning.
    text_model_type = condition_config["text_condition_config"]["text_embed_model"]
    text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)

    # Empty-prompt representation for classifier-free guidance.
    empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)

    # Text representation of the input prompt.
    text_prompt_embed = get_text_representation(
        [text_prompt], text_tokenizer, text_model, device
    )
    # Prepare image conditioning:
    # if the user uploaded a mask image (a PIL image), convert it; otherwise use zeros.
    if "image" in condition_types:
        if mask_image_pil is not None:
            mask_transform = transforms.Compose(
                [
                    transforms.Resize(
                        (
                            ldm_params["condition_config"]["image_condition_config"][
                                "image_condition_h"
                            ],
                            ldm_params["condition_config"]["image_condition_config"][
                                "image_condition_w"
                            ],
                        )
                    ),
                    transforms.ToTensor(),
                ]
            )
            mask_tensor = (
                mask_transform(mask_image_pil).unsqueeze(0).to(device)
            )  # (1, channels, H, W)
        else:
            # Create a zero mask with the required number of channels (e.g. 18).
            ic = ldm_params["condition_config"]["image_condition_config"][
                "image_condition_input_channels"
            ]
            H = ldm_params["condition_config"]["image_condition_config"][
                "image_condition_h"
            ]
            W = ldm_params["condition_config"]["image_condition_config"][
                "image_condition_w"
            ]
            mask_tensor = torch.zeros((1, ic, H, W), device=device)
    else:
        mask_tensor = None
    # Build conditioning dictionaries for classifier-free guidance:
    # for the unconditional input, use the empty text embedding and a zero mask.
    uncond_input = {}
    cond_input = {}
    if "text" in condition_types:
        uncond_input["text"] = empty_text_embed
        cond_input["text"] = text_prompt_embed
    if "image" in condition_types:
        # Zeros for the unconditional pass, the provided mask for the conditional pass.
        uncond_input["image"] = torch.zeros_like(mask_tensor)
        cond_input["image"] = mask_tensor

    # Determine the latent shape from the VQ-VAE: (batch, z_channels, H_lat, W_lat).
    # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
    latent_size = dataset_params["image_size"] // (
        2 ** sum(autoencoder_params["down_sample"])
    )
    batch = train_params["num_samples"]
    z_channels = autoencoder_params["z_channels"]

    # Sample the initial latent noise.
    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
    # Sampling loop (reverse diffusion).
    T = diffusion_params["num_timesteps"]
    for i in reversed(range(T)):
        t = torch.full((batch,), i, dtype=torch.long, device=device)
        with torch.no_grad():
            # Conditional noise prediction.
            noise_pred_cond = unet(xt, t, cond_input)
            if guidance_scale > 1:
                noise_pred_uncond = unet(xt, t, uncond_input)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = noise_pred_cond
            xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)

            if i % image_display_rate == 0 or i == 0:
                # Decode the current latent into an image.
                generated = vae.decode(xt)
                generated = torch.clamp(generated, -1, 1)
                generated = (generated + 1) / 2  # scale to [0, 1]
                grid = make_grid(generated, nrow=1)
                pil_img = transforms.ToPILImage()(grid.cpu())
                yield pil_img
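
# Minimal usage sketch (hypothetical, not part of the Space): because the function above
# is a generator, it can also be driven outside Gradio to save intermediate previews.
#
# for step_idx, preview in enumerate(sample_ddpm_inference("She is a woman with blond hair.")):
#     preview.save(f"preview_{step_idx:03d}.png")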
| css_str = """ | |
| body { | |
| background-color: #121212; | |
| color: #e0e0e0; | |
| font-family: Arial, sans-serif; | |
| } | |
| .container { | |
| max-width: 700px; | |
| margin: 15px auto; | |
| } | |
| h1 { | |
| font-size: 36px; | |
| font-weight: bold; | |
| text-align: center; | |
| color: #ffffff; | |
| } | |
| .description { | |
| font-size: 16px; | |
| text-align: center; | |
| color: #b0b0b0; | |
| } | |
| """ | |
with gr.Blocks(css=css_str) as demo:
    gr.HTML("<div class='container'>")
    gr.Markdown("<h1>Conditioned Latent Diffusion of Faces</h1>")
    gr.Markdown(
        "<div class='description'>Enter a text prompt and watch the reverse diffusion process denoise the latent into an image of a face.</div>"
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="E.g., 'He is a man with brown hair.'",
            )
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_button.click(
        fn=sample_ddpm_inference,
        inputs=[text_input],
        outputs=[output_image],
    )
    gr.HTML("</div>")
| if __name__ == "__main__": | |
| demo.launch(share=True) | |