import json
import os
import types
from urllib.parse import urlparse

import cv2
import diffusers
import gradio as gr
import numpy as np
import spaces
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torch.nn import functional as F
from torchdiffeq import odeint_adjoint as odeint

from echoflow.common import instantiate_class_from_config, unscale_latents

# These model classes look unused, but load_model() resolves class names from
# config.json via globals(), so they must stay imported.
from echoflow.common.models import (
    ContrastiveModel,
    DiffuserSTDiT,
    ResNet18,
    SegDiTTransformer2DModel,
)

torch.set_grad_enabled(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

print(f"Using device: {device}")

# Latent shape: batch, frames, channels, height, width.
B, T, C, H, W = 1, 64, 4, 28, 28

VIEWS = ["A4C", "PSAX", "PLAX"]


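# A Hub "tree" URL such as
#   https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4
# splits into parts = [user, repo, "tree", revision, *subfolder], giving
# repo_id = "HReynaud/EchoFlow" and subfolder = "lifm/FMiT-S2-4f4".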
def load_model(path):
    if path.startswith("http"):
        parsed_url = urlparse(path)
        if "huggingface.co" in parsed_url.netloc:
            parts = parsed_url.path.strip("/").split("/")
            repo_id = "/".join(parts[:2])

            subfolder = None
            if len(parts) > 4:
                subfolder = "/".join(parts[4:])

            local_root = "./tmp"
            local_dir = os.path.join(local_root, repo_id.replace("/", "_"))
            if subfolder:
                # hf_hub_download mirrors the subfolder layout under local_root.
                local_dir = os.path.join(local_root, subfolder)
            os.makedirs(local_root, exist_ok=True)

            config_file = hf_hub_download(
                repo_id=repo_id,
                subfolder=subfolder,
                filename="config.json",
                local_dir=local_root,
                repo_type="model",
                token=os.getenv("READ_HF_TOKEN"),
                local_dir_use_symlinks=False,
            )

            assert os.path.exists(config_file)

            hf_hub_download(
                repo_id=repo_id,
                filename="diffusion_pytorch_model.safetensors",
                subfolder=subfolder,
                local_dir=local_root,
                local_dir_use_symlinks=False,
                token=os.getenv("READ_HF_TOKEN"),
            )

            path = local_dir

    # Note: this loader expects Hub URLs; config_file is set in the branch above.
    model_root = os.path.dirname(config_file)
    json_path = os.path.join(model_root, "config.json")
    assert os.path.exists(json_path)

    with open(json_path, "r") as f:
        config = json.load(f)

    klass_name = config["_class_name"]
    klass = getattr(diffusers, klass_name, None) or globals().get(klass_name, None)
    assert (
        klass is not None
    ), f"Could not find class {klass_name} in diffusers or global scope."
    assert hasattr(
        klass, "from_pretrained"
    ), f"Class {klass_name} does not support 'from_pretrained'."

    return klass.from_pretrained(path)


def load_reid(path):
    parsed_url = urlparse(path)
    parts = parsed_url.path.strip("/").split("/")
    repo_id = "/".join(parts[:2])
    subfolder = "/".join(parts[4:])

    local_root = "./tmp"

    config_file = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename="config.yaml",
        local_dir=local_root,
        repo_type="model",
        token=os.getenv("READ_HF_TOKEN"),
        local_dir_use_symlinks=False,
    )

    weights_file = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename="backbone.safetensors",
        local_dir=local_root,
        repo_type="model",
        token=os.getenv("READ_HF_TOKEN"),
        local_dir_use_symlinks=False,
    )

    config = OmegaConf.load(config_file)
    backbone = instantiate_class_from_config(config.backbone)
    backbone = ContrastiveModel.patch_backbone(
        backbone, config.model.args.in_channels, config.model.args.out_channels
    )
    state_dict = load_file(weights_file)
    backbone.load_state_dict(state_dict)
    backbone = backbone.to(device, dtype=dtype)
    backbone.eval()
    return backbone


def get_vae_scaler(path):
    scaler = torch.load(path)
    scaler = {k: v.to(device) for k, v in scaler.items()}
    return scaler


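# Latent image generator ("lifm"): a flow-matching transformer that produces
# a latent image conditioned on the LV mask and view class.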
lifm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4")
lifm = lifm.to(device, dtype=dtype)
lifm.eval()

# VAE that decodes latents back to pixel space, with its channel-wise scaler.
vae = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/vae/avae-4f4")
vae = vae.to(device, dtype=dtype)
vae.eval()
vae_scaler = get_vae_scaler("assets/scaling.pt")

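# Privacy-filter assets: per-view re-identification encoders, feature banks
# of the training anatomies, and correlation thresholds tau (see
# check_privacy below).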
reid = {
    "anatomies": {
        "A4C": torch.cat(
            [
                torch.load("assets/anatomies_dynamic.pt"),
                torch.load("assets/anatomies_ped_a4c.pt"),
            ],
            dim=0,
        ),
        "PSAX": torch.load("assets/anatomies_ped_psax.pt"),
        "PLAX": torch.load("assets/anatomies_lvh.pt"),
    },
    "models": {
        "A4C": load_reid(
            "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/dynamic-4f4"
        ),
        "PSAX": load_reid(
            "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/ped_psax-4f4"
        ),
        "PLAX": load_reid(
            "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/lvh-4f4"
        ),
    },
    "tau": {
        "A4C": 0.9997,
        "PSAX": 0.9997,
        "PLAX": 0.9997,
    },
}

# Latent video generator ("lvfm"): animates a privacy-checked latent image
# into a T-frame sequence conditioned on ejection fraction.
lvfm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lvfm/FMvT-S2-4f4")
lvfm = lvfm.to(device, dtype=dtype)
lvfm.eval()


def load_default_mask():
    """Load the default mask from disk. If not found, return a blank black mask."""
    default_mask_path = os.path.join("assets", "default_mask.png")
    try:
        if os.path.exists(default_mask_path):
            mask = Image.open(default_mask_path).convert("L")
            # Match the editor canvas size and stretch to full contrast.
            mask = mask.resize((400, 400), Image.Resampling.LANCZOS)
            mask = ImageOps.autocontrast(mask, cutoff=0)
            return np.array(mask)
    except Exception as e:
        print(f"Error loading default mask: {e}")

    # Fall back to a blank canvas.
    return np.zeros((400, 400), dtype=np.uint8)


def preprocess_mask(mask):
    """Ensure the mask is properly formatted for the model."""
    if mask is None:
        return np.zeros((112, 112), dtype=np.uint8)

    # gr.ImageEditor returns a dict; the drawn result lives in "composite".
    if isinstance(mask, dict) and "composite" in mask:
        mask = mask["composite"]

    # Work in PIL for the remaining operations.
    if isinstance(mask, np.ndarray):
        mask_pil = Image.fromarray(mask)
    else:
        mask_pil = mask

    # Grayscale, full-range contrast, then hard-threshold to a binary mask.
    mask_pil = mask_pil.convert("L")
    mask_pil = ImageOps.autocontrast(mask_pil, cutoff=0)
    mask_pil = mask_pil.point(lambda p: 255 if p > 127 else 0)

    # Downsample to the resolution expected by the segmentation conditioning.
    mask_pil = mask_pil.resize((112, 112), Image.Resampling.LANCZOS)

    return np.array(mask_pil)


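# Sampling integrates the learned velocity field from t=1 (Gaussian noise)
# down to t=0 (data), the flow-matching convention used throughout this app.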
@spaces.GPU(duration=3)
@torch.no_grad()
def generate_latent_image(mask, class_selection, sampling_steps=50):
    """Generate a latent image based on mask, class selection, and sampling steps."""
    # Binarize the drawn mask and resize it to the latent resolution.
    mask = preprocess_mask(mask)
    mask = torch.from_numpy(mask).to(device, dtype=dtype)
    mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, 112, 112]
    mask = F.interpolate(mask, size=(H, W), mode="bilinear", align_corners=False)
    mask = 1.0 * (mask > 0)

    # View class as an integer label.
    class_idx = VIEWS.index(class_selection)
    class_idx = torch.tensor([class_idx], device=device, dtype=torch.long)

    # Integration grid from noise (t=1) to data (t=0).
    timesteps = torch.linspace(
        1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
    )

    forward_kwargs = {
        "class_labels": class_idx,
        "segmentation": mask,
    }

    z_1 = torch.randn(
        (B, C, H, W),
        device=device,
        dtype=dtype,
    )

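    # torchdiffeq's odeint integrates dy/dt = f(t, y), while the DiT forward
    # expects (sample, timestep, **conditioning), so temporarily swap in an
    # adapter that reorders the arguments and injects the conditioning.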
    lifm.forward_original = lifm.forward

    def new_forward(self, t, y, *args, **kwargs):
        kwargs = {**kwargs, **forward_kwargs}
        return self.forward_original(y, t.view(1), *args, **kwargs).sample

    lifm.forward = types.MethodType(new_forward, lifm)

    # Fixed-grid Euler integration over `timesteps`; atol/rtol only apply to
    # adaptive solvers and are ignored here.
    with torch.autocast("cuda"):
        latent_image = odeint(
            lifm,
            z_1,
            timesteps,
            atol=1e-5,
            rtol=1e-5,
            adjoint_params=lifm.parameters(),
            method="euler",
        )[-1]

    lifm.forward = lifm.forward_original

    latent_image = latent_image.detach().cpu().numpy()

    return latent_image


@spaces.GPU(duration=3)
@torch.no_grad()
def decode_images(latents):
    """Decode latent representations to pixel space using the VAE.

    Args:
        latents: A numpy array or tensor of shape [B, C, H, W] for a single
            image, or [B, C, T, H, W] for sequences/animations.

    Returns:
        A uint8 numpy array of decoded images: [H, W, 3] for a single image,
        or [B, C, T, H, W] for sequences.
    """
    global vae
    if latents is None:
        return None

    vae = vae.to(device, dtype=dtype)
    vae.eval()

    if not isinstance(latents, torch.Tensor):
        latents = torch.from_numpy(latents).to(device, dtype=dtype)

    # Undo the channel-wise latent scaling before decoding.
    latents = unscale_latents(latents, vae_scaler)

    # A 5D input is a [B, C, T, H, W] sequence; fold time into the batch.
    is_sequence = len(latents.shape) == 5
    if is_sequence:
        latents = rearrange(latents[0], "c t h w -> t c h w")

    with torch.no_grad():
        # Decode one frame at a time to bound memory usage.
        decoded = []
        for i in range(latents.shape[0]):
            decoded.append(vae.decode(latents[i : i + 1].float()).sample)
        decoded = torch.cat(decoded, dim=0)

    # Map from [-1, 1] to [0, 255].
    decoded = (decoded + 1) * 128
    decoded = decoded.clamp(0, 255).to(torch.uint8).cpu()

    if is_sequence:
        decoded = rearrange(decoded, "t c h w -> c t h w").unsqueeze(0)
    else:
        decoded = decoded.squeeze()  # [3, H, W]
        decoded = decoded.permute(1, 2, 0)  # [H, W, 3]

    return decoded.numpy()


def decode_latent_to_pixel(latent_image):
    """Decode a single latent image to pixel space."""
    if latent_image is None:
        return None

    # Add a batch dimension if needed.
    if len(latent_image.shape) == 3:
        latent_image = latent_image[None, ...]

    decoded_image = decode_images(latent_image)
    decoded_image = cv2.resize(
        decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
    )

    return decoded_image


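# The privacy filter embeds the generated latent with a re-identification
# encoder and rejects it when its correlation with any training anatomy
# exceeds the per-view threshold tau.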
@spaces.GPU(duration=3)
@torch.no_grad()
def check_privacy(latent_image_numpy, class_selection):
    """Check if the latent image is too similar to database images."""
    latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
    reid_model = reid["models"][class_selection].to(device, dtype=dtype)
    real_anatomies = reid["anatomies"][class_selection]  # precomputed features
    tau = reid["tau"][class_selection]

    with torch.no_grad():
        features = reid_model(latent_image).sigmoid().cpu()

    # Correlate the generated embedding (last row) against every training
    # anatomy and keep the strongest match.
    corr = torch.corrcoef(torch.cat([real_anatomies, features], dim=0))[-1, :-1]
    corr = corr.max()

    if corr > tau:
        return (
            None,
            "⚠️ **Warning:** Generated image is too similar to training data. Privacy check failed.",
        )
    else:
        return (
            latent_image_numpy,
            "✅ **Success:** Generated image passed privacy check.",
        )


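# Video generation supports classifier-free guidance:
#   pred = uncond + cfg_scale * (cond - uncond)
# where cfg_scale == 1.0 reduces to the purely conditional prediction.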
@spaces.GPU(duration=3)
@torch.no_grad()
def generate_animation(
    latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
):
    """Generate an animated sequence of latent images based on EF."""
    print("Generating animation...")

    if latent_image is None:
        return None

    # EF conditioning in [0, 1]; -1 marks the unconditional branch for CFG.
    lvefs = torch.tensor([ejection_fraction / 100.0], device=device, dtype=dtype)
    lvefs = lvefs[:, None, None].to(device, dtype)
    uncond_lvefs = -1 * torch.ones_like(lvefs)

    # Repeat the reference latent across all T frames; zeros serve as the
    # unconditional reference.
    ref_images = torch.from_numpy(latent_image).to(device, dtype)
    ref_images = ref_images[:, :, None, :, :]
    ref_images = ref_images.repeat(1, 1, T, 1, 1)
    uncond_images = torch.zeros_like(ref_images)

    # Integration grid from noise (t=1) to data (t=0).
    timesteps = torch.linspace(
        1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
    )

    forward_kwargs = {
        "encoder_hidden_states": lvefs,
        "cond_image": ref_images,
    }

    z_1 = torch.randn(
        (B, C, T, H, W),
        device=device,
        dtype=dtype,
    )

    # Same adapter trick as in generate_latent_image, plus optional CFG.
    lvfm.forward_original = lvfm.forward

    def new_forward(self, t, y, *args, **kwargs):
        kwargs = {**kwargs, **forward_kwargs}

        pred = self.forward_original(y, t.repeat(y.size(0)), *args, **kwargs).sample

        if cfg_scale != 1.0:
            uncond_kwargs = {
                "encoder_hidden_states": uncond_lvefs,
                "cond_image": uncond_images,
            }
            uncond_pred = self.forward_original(
                y, t.repeat(y.size(0)), *args, **uncond_kwargs
            ).sample

            pred = uncond_pred + cfg_scale * (pred - uncond_pred)

        return pred

    lvfm.forward = types.MethodType(new_forward, lvfm)

    with torch.autocast("cuda"):
        synthetic_video = odeint(
            lvfm,
            z_1,
            timesteps,
            atol=1e-5,
            rtol=1e-5,
            adjoint_params=lvfm.parameters(),
            method="euler",
        )[-1]

    lvfm.forward = lvfm.forward_original

    print("Animation generated")

    return synthetic_video.detach().cpu()


@spaces.GPU(duration=3)
@torch.no_grad()
def decode_animation(latent_animation):
    """Decode a latent animation to pixel space and write it to an MP4 file."""
    if latent_animation is None:
        return None

    if not isinstance(latent_animation, torch.Tensor):
        latent_animation = torch.from_numpy(latent_animation)
    latent_animation = latent_animation.to(device, dtype=dtype)

    # Ensure a batch dimension: [B, C, T, H, W].
    if len(latent_animation.shape) == 4:
        latent_animation = latent_animation[None, ...]

    decoded = decode_images(latent_animation)  # [B, C, T, H, W], uint8

    # [C, T, H, W] -> [T, H, W, C] for video writing.
    decoded = np.transpose(decoded[0], (1, 2, 3, 0))

    # Upscale each frame to the display size.
    decoded = np.stack(
        [
            cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
            for frame in decoded
        ]
    )

    # Write the frames to disk. cv2.VideoWriter expects BGR, but these frames
    # are effectively grayscale, so the channel order does not matter.
    temp_file = "temp_video_2.mp4"
    fps = 32
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))

    for frame in decoded:
        out.write(frame)
    out.release()

    return temp_file


def convert_latent_to_display(latent_image):
    """Convert a multi-channel latent image to grayscale for display."""
    if latent_image is None:
        return None

    # Collapse batch and channel dimensions down to a single 2D map.
    if len(latent_image.shape) == 4:
        display_image = np.squeeze(latent_image, axis=0)
        display_image = np.mean(display_image, axis=0)
    elif len(latent_image.shape) == 3:
        display_image = np.mean(latent_image, axis=0)
    else:
        display_image = latent_image

    # Normalize to [0, 1], then scale to uint8.
    display_image = (display_image - display_image.min()) / (
        display_image.max() - display_image.min() + 1e-8
    )
    display_image = (display_image * 255).astype(np.uint8)

    # Upscale with nearest-neighbor to preserve the latent's blocky structure.
    display_image = cv2.resize(
        display_image, (400, 400), interpolation=cv2.INTER_NEAREST
    )

    return display_image


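# Latents are C=4 channels at 28x28, so for display they are averaged across
# channels and normalized per frame; this is only a rough visualization.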
@spaces.GPU(duration=3)
@torch.no_grad()
def latent_animation_to_grayscale(latent_animation):
    """Convert a multi-channel latent animation to a grayscale MP4 for display."""
    if latent_animation is None:
        return None

    if torch.is_tensor(latent_animation):
        latent_animation = latent_animation.detach().cpu().numpy()

    # [B, C, T, H, W] -> [T, C, H, W]
    if len(latent_animation.shape) == 5:
        latent_animation = np.squeeze(latent_animation, axis=0)
        latent_animation = np.transpose(latent_animation, (1, 0, 2, 3))

    # Average across channels: [T, H, W].
    latent_animation = np.mean(latent_animation, axis=1)

    # Normalize each frame independently to [0, 1], then scale to uint8.
    min_vals = latent_animation.min(axis=(1, 2), keepdims=True)
    max_vals = latent_animation.max(axis=(1, 2), keepdims=True)
    latent_animation = (latent_animation - min_vals) / (max_vals - min_vals + 1e-8)
    latent_animation = (latent_animation * 255).astype(np.uint8)

    # Upscale each frame with nearest-neighbor.
    resized_frames = []
    for frame in latent_animation:
        resized = cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
        resized_frames.append(resized)
    grayscale_video = np.stack(resized_frames)

    # Replicate the single channel to 3 channels for the video writer.
    grayscale_video = grayscale_video[..., None].repeat(3, axis=-1)

    # Write the frames to a temporary MP4.
    temp_file = "temp_video.mp4"
    fps = 32
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))

    for frame in grayscale_video:
        out.write(frame)
    out.release()

    return temp_file


def load_view_mask(view):
    """Load the default LV mask for the given view into the editor format."""
    mask_path = f"assets/{view.lower()}_seg.png"
    try:
        mask_image = Image.open(mask_path).convert("L")
        mask_image = mask_image.resize((400, 400), Image.Resampling.LANCZOS)
        mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
        mask_array = np.array(mask_image)

        # Build the dict format gr.ImageEditor expects.
        editor_value = {
            "background": np.zeros((400, 400), dtype=np.uint8),
            "layers": [mask_array],
            "composite": mask_array,
        }
        return editor_value
    except Exception as e:
        print(f"Error loading mask for view {view}: {e}")
        return None


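# Injected into the page <head>: once the examples table renders, shorten its
# column headers (which default to the full component labels) so the table
# stays compact.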
| custom_js = """ |
| <script> |
| console.log("Hello, world!"); |
| (function() { |
| // Poll every 100ms for the existence of the header row |
| const intervalId = setInterval(() => { |
| console.log("Polling for header row"); |
| const headerRow = document.querySelector("tr.tr-head"); |
| if (headerRow) { |
| const headers = headerRow.querySelectorAll("th"); |
| headers.forEach(cell => { |
| const text = cell.innerText.trim(); |
| if (text === "Binary Mask") { |
| cell.innerText = "Mask"; |
| } else if (text === "View Class") { |
| cell.innerText = "View"; |
| } else if (text === "Number of Sampling Steps") { |
| cell.innerText = "Img Samp. Steps"; |
| } else if (text === "Ejection Fraction (%)") { |
| cell.innerText = "EF %"; |
| } else if (text === "Number of Sampling Steps.") { |
| cell.innerText = "Video Samp. Steps"; |
| } else if (text === "Classifier-Free Guidance Scale") { |
| cell.innerText = "CFG"; |
| } else if (text === "Filtered Latent Image") { |
| cell.innerText = "Filtered Image"; |
| } |
| }); |
| clearInterval(intervalId); |
| console.log("Headers updated."); |
| } |
| }, 500); |
| })(); |
| </script> |
| """ |


def create_demo():
    # Blank canvas for the mask editor.
    black_background = np.zeros((400, 400), dtype=np.uint8)

    # Pre-fill the editor with the default A4C mask.
    try:
        mask_image = Image.open("assets/a4c_seg.png").convert("L")
        mask_image = mask_image.resize((400, 400), Image.Resampling.LANCZOS)
        mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
        mask_image = mask_image.point(lambda p: 255 if p > 127 else 0)
        mask_array = np.array(mask_image)

        editor_value = {
            "background": black_background,
            "layers": [mask_array],
            "composite": mask_array,
        }
    except Exception as e:
        print(f"Error loading mask image: {e}")
        editor_value = black_background

    # Components are created with render=False so gr.Examples can reference
    # them before they are placed in the layout below.
    mask_input = gr.ImageEditor(
        label="Binary Mask",
        height=400,
        width=400,
        image_mode="L",
        value=editor_value,
        type="numpy",
        brush=gr.Brush(
            colors=["#ffffff"],
            color_mode="fixed",
            default_size=20,
            default_color="#ffffff",
        ),
        eraser=gr.Eraser(default_size=20),
        show_download_button=True,
        sources=[],
        canvas_size=(400, 400),
        fixed_canvas=True,
        layers=False,
        render=False,
    )

    class_selection = gr.Radio(
        choices=["A4C", "PSAX", "PLAX"],
        label="View Class",
        value="A4C",
        render=False,
    )

    sampling_steps = gr.Slider(
        minimum=1,
        maximum=200,
        value=100,
        step=1,
        label="Number of Sampling Steps",
        render=False,
    )

    ef_slider = gr.Slider(
        minimum=0,
        maximum=100,
        value=65,
        label="Ejection Fraction (%)",
        render=False,
    )

    # The trailing period distinguishes this label from the image sampler's,
    # which matters for the header-renaming script in custom_js.
    animation_steps = gr.Slider(
        minimum=1,
        maximum=200,
        value=100,
        step=1,
        label="Number of Sampling Steps.",
        render=False,
    )

    cfg_slider = gr.Slider(
        minimum=0,
        maximum=10,
        value=1,
        step=1,
        label="Classifier-Free Guidance Scale",
        render=False,
    )

    latent_image_display = gr.Image(
        label="Latent Image",
        type="numpy",
        height=400,
        width=400,
        render=False,
    )

    decoded_image_display = gr.Image(
        label="Decoded Image",
        type="numpy",
        height=400,
        width=400,
        render=False,
    )

    privacy_status = gr.Markdown(render=False)

    filtered_latent_display = gr.Image(
        label="Filtered Latent Image",
        type="numpy",
        height=400,
        width=400,
        render=False,
    )

    latent_animation_display = gr.Video(
        label="Latent Video",
        format="mp4",
        render=False,
        autoplay=True,
        loop=True,
    )

    decoded_animation_display = gr.Video(
        label="Decoded Video",
        format="mp4",
        render=False,
        autoplay=True,
        loop=True,
    )

    with gr.Blocks(theme=gr.themes.Soft(), head=custom_js) as demo:
        gr.Markdown(
            "# EchoFlow: A Foundation Model for Cardiac Ultrasound Image and Video Generation"
        )
        gr.Markdown("## Preprint: https://arxiv.org/abs/2503.22357")
        gr.Markdown("## Dataset Generation Pipeline")

        gr.Markdown(
            """
This demo showcases EchoFlow's ability to generate synthetic echocardiogram images and videos while preserving patient privacy. The pipeline consists of four main steps:

1. **Latent Image Generation**: Draw a mask to indicate the region where the Left Ventricle should appear. Select the desired cardiac view, then click "Generate Latent Image". This produces a latent image, which can be decoded into a pixel-space image by clicking "Decode to Pixel Space".
2. **Privacy Filter**: Clicking "Run Privacy Check" compares the generated image against a database of all training anatomies to ensure it is sufficiently different from real patient data.
3. **Latent Video Generation**: If the privacy check passes, the latent image can be animated into a video with the desired Ejection Fraction.
4. **Video Decoding**: The video can be decoded back to pixel space by clicking "Decode Video".

### ⚙️ Parameters
- **Sampling Steps**: Higher values produce better quality but take longer
- **Ejection Fraction**: Controls the strength of the heart contraction in the animation
- **CFG Scale**: Controls how closely the animation follows the specified conditions
"""
        )

        def load_example(
            mask,
            view,
            steps,
            ef,
            anim_steps,
            cfg,
            latent,
            decoded,
            status,
            filtered,
            latent_vid,
            decoded_vid,
        ):
            # Examples are fully precomputed; simply pass the stored values
            # through to the corresponding output components.
            return [
                mask,
                view,
                steps,
                ef,
                anim_steps,
                cfg,
                latent,
                decoded,
                status,
                filtered,
                latent_vid,
                decoded_vid,
            ]

        examples = gr.Examples(
            examples=[
                [
                    {
                        "background": np.zeros((400, 400), dtype=np.uint8),
                        "layers": [
                            np.array(
                                Image.open("assets/a4c_seg.png")
                                .convert("L")
                                .resize((400, 400))
                            )
                        ],
                        "composite": np.array(
                            Image.open("assets/a4c_seg.png")
                            .convert("L")
                            .resize((400, 400))
                        ),
                    },
                    "A4C",
                    100,
                    65,
                    100,
                    1.0,
                    Image.open("assets/examples/a4c_latent.png"),
                    Image.open("assets/examples/a4c_decoded.png"),
                    "✅ **Success:** Generated image passed privacy check.",
                    Image.open("assets/examples/a4c_filtered.png"),
                    "assets/examples/a4c_latent.mp4",
                    "assets/examples/a4c_decoded.mp4",
                ],
                [
                    {
                        "background": np.zeros((400, 400), dtype=np.uint8),
                        "layers": [
                            np.array(
                                Image.open("assets/psax_seg.png")
                                .convert("L")
                                .resize((400, 400))
                            )
                        ],
                        "composite": np.array(
                            Image.open("assets/psax_seg.png")
                            .convert("L")
                            .resize((400, 400))
                        ),
                    },
                    "PSAX",
                    100,
                    65,
                    100,
                    1.0,
                    Image.open("assets/examples/psax_latent.png"),
                    Image.open("assets/examples/psax_decoded.png"),
                    "✅ **Success:** Generated image passed privacy check.",
                    Image.open("assets/examples/psax_filtered.png"),
                    "assets/examples/psax_latent.mp4",
                    "assets/examples/psax_decoded.mp4",
                ],
                [
                    {
                        "background": np.zeros((400, 400), dtype=np.uint8),
                        "layers": [
                            np.array(
                                Image.open("assets/plax_seg.png")
                                .convert("L")
                                .resize((400, 400))
                            )
                        ],
                        "composite": np.array(
                            Image.open("assets/plax_seg.png")
                            .convert("L")
                            .resize((400, 400))
                        ),
                    },
                    "PLAX",
                    100,
                    65,
                    100,
                    1.0,
                    Image.open("assets/examples/plax_latent.png"),
                    Image.open("assets/examples/plax_decoded.png"),
                    "✅ **Success:** Generated image passed privacy check.",
                    Image.open("assets/examples/plax_filtered.png"),
                    "assets/examples/plax_latent.mp4",
                    "assets/examples/plax_decoded.mp4",
                ],
            ],
            inputs=[
                mask_input,
                class_selection,
                sampling_steps,
                ef_slider,
                animation_steps,
                cfg_slider,
                latent_image_display,
                decoded_image_display,
                privacy_status,
                filtered_latent_display,
                latent_animation_display,
                decoded_animation_display,
            ],
            fn=load_example,
            label="Click on an example to see the results immediately.",
            examples_per_page=3,
        )

        with gr.Row():
            # Column 1: latent image generation.
            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co/spaces/HReynaud/EchoFlow/resolve/main/assets/h1.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Latent Image Generation")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("Draw the LV mask (white = region of interest)")

                        mask_input.render()
                        class_selection.render()
                        sampling_steps.render()

                        generate_btn = gr.Button(
                            "Generate Latent Image", variant="primary"
                        )

                        latent_image_display.render()

                        decode_btn = gr.Button(
                            "Decode to Pixel Space (Optional)",
                            interactive=False,
                            variant="primary",
                        )

                        decoded_image_display.render()

            # Column 2: privacy filter.
            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co/spaces/HReynaud/EchoFlow/resolve/main/assets/h2.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Privacy Filter")
                gr.Markdown(
                    "Checks if the generated image is too similar to training data"
                )

                privacy_btn = gr.Button(
                    "Run Privacy Check", interactive=False, variant="primary"
                )

                privacy_status.render()

                filtered_latent_display.render()

            # Column 3: latent video generation.
            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co/spaces/HReynaud/EchoFlow/resolve/main/assets/h3.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Latent Video Generation")

                ef_slider.render()
                animation_steps.render()
                cfg_slider.render()

                animate_btn = gr.Button(
                    "Generate Video", interactive=False, variant="primary"
                )

                latent_animation_display.render()

            # Column 4: video decoding.
            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co/spaces/HReynaud/EchoFlow/resolve/main/assets/h4.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Video Decoding")

                decode_animation_btn = gr.Button(
                    "Decode Video", interactive=False, variant="primary"
                )

                decoded_animation_display.render()

        # Hidden states holding the raw latents between pipeline steps.
        latent_image_state = gr.State(None)
        filtered_latent_state = gr.State(None)
        latent_animation_state = gr.State(None)

        # Swap in the default mask whenever the view class changes.
        class_selection.change(
            fn=load_view_mask,
            inputs=[class_selection],
            outputs=[mask_input],
            queue=False,
        )

        generate_btn.click(
            fn=generate_latent_image,
            inputs=[mask_input, class_selection, sampling_steps],
            outputs=[latent_image_state],
            queue=True,
        ).then(
            fn=convert_latent_to_display,
            inputs=[latent_image_state],
            outputs=[latent_image_display],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[latent_image_state],
            outputs=[decode_btn],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[latent_image_state],
            outputs=[privacy_btn],
            queue=False,
        )

        decode_btn.click(
            fn=decode_latent_to_pixel,
            inputs=[latent_image_state],
            outputs=[decoded_image_display],
            queue=True,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[decoded_image_display],
            outputs=[privacy_btn],
            queue=False,
        )

        privacy_btn.click(
            fn=check_privacy,
            inputs=[latent_image_state, class_selection],
            outputs=[filtered_latent_state, privacy_status],
            queue=True,
        ).then(
            fn=convert_latent_to_display,
            inputs=[filtered_latent_state],
            outputs=[filtered_latent_display],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[filtered_latent_state],
            outputs=[animate_btn],
            queue=False,
        )

        animate_btn.click(
            fn=generate_animation,
            inputs=[filtered_latent_state, ef_slider, animation_steps, cfg_slider],
            outputs=[latent_animation_state],
            queue=True,
        ).then(
            fn=latent_animation_to_grayscale,
            inputs=[latent_animation_state],
            outputs=[latent_animation_display],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[latent_animation_state],
            outputs=[decode_animation_btn],
            queue=False,
        )

        decode_animation_btn.click(
            fn=decode_animation,
            inputs=[latent_animation_state],
            outputs=[decoded_animation_display],
            queue=True,
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()