Spaces:
Runtime error
Runtime error
| import argparse | |
| import torch | |
| import gradio as gr | |
| from torchvision import transforms | |
| from runner import MaskGIT | |
| import numpy as np | |
| import random | |
| import torchvision.utils as vutils | |
class Args(argparse.Namespace):
    """Fixed configuration for the MaskGIT demo, used instead of parsed CLI flags.

    Subclasses ``argparse.Namespace`` so an instance can be handed to code
    that expects an argparse result (here ``runner.MaskGIT``). Every setting
    is a class attribute, so instances inherit the values below.
    """

    # Data / logging settings; presumably required by MaskGIT's constructor
    # even though the demo does not train — TODO confirm in runner.MaskGIT.
    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""
    # NOTE(review): presumably the id of the [MASK] token (one past a
    # 1024-entry codebook) — confirm in runner.MaskGIT.
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True
    device = "cpu"
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    # Side length of the token grid. It is derived from img_size so the two
    # values cannot drift apart. The factor 16 is presumably the VQGAN
    # downsampling rate — TODO confirm.
    patch_size = img_size // 16
def set_seed(seed):
    """Seed every random number generator used during sampling.

    Seeds Python's ``random``, NumPy, and torch (CPU and CUDA), then disables
    cuDNN and requests deterministic cuDNN kernels. A non-positive seed
    leaves all RNG and backend state untouched.

    Args:
        seed: Seed value. Gradio numeric widgets can deliver it as a float
            (e.g. ``60.0``), which ``np.random.seed`` rejects, so it is
            truncated to ``int`` first.
    """
    seed = int(seed)
    if seed <= 0:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
    np.random.seed(seed)
    random.seed(seed)
    # The real flag is ``enabled``; assigning ``cudnn.enable`` only creates an
    # unused attribute and leaves cuDNN switched on.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
# Build the configuration and the model once, at import time. Every call to
# synthesize_image then reuses this same `args` / `maskgit` pair.
args = Args()
# NOTE(review): presumably loads the checkpoints named by args.vqgan_folder /
# args.vit_folder (resume=True) — confirm in runner.MaskGIT.
maskgit = MaskGIT(args)
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Generate ``nb_img`` class-conditional samples with the global MaskGIT model.

    Args:
        cls: Class index used as the label for every sample.
        sm_temp: Passed through to ``maskgit.sample`` as ``sm_temp``.
        w: Passed through to ``maskgit.sample`` as ``w``.
        r_temp: Passed through to ``maskgit.sample`` as ``r_temp``.
        step: Number of sampling steps, passed to ``maskgit.sample``.
        seed: RNG seed; a non-positive value leaves the RNG state unseeded.
        nb_img: Number of images to generate.

    Returns:
        A PIL image: the samples tiled with ``make_grid`` (2 per row, no
        padding, normalized for display).

    Raises:
        ValueError: if ``nb_img`` or ``step`` is less than 1.
    """
    # gr.Number hands floats to the callback unless precision=0.
    # Integer-valued parameters must be real ints: `[x] * 1.0` raises
    # TypeError, and the step count sizes the sampling schedule.
    cls, step, seed, nb_img = int(cls), int(step), int(seed), int(nb_img)
    if nb_img < 1:
        raise ValueError(f"nb_img must be >= 1, got {nb_img}")
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    set_seed(seed)
    with torch.no_grad():
        labels = torch.full((nb_img,), cls, dtype=torch.long, device=args.device)
        gen_sample = maskgit.sample(
            nb_sample=nb_img,
            labels=labels,
            sm_temp=sm_temp,
            w=w,
            randomize="linear",
            r_temp=r_temp,
            sched_mode="arccos",
            step=step,
        )[0]  # first element of sample()'s return value holds the images
    grid = vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True)
    return transforms.ToPILImage()(grid)
# Gradio interface. The input order must match synthesize_image's positional
# parameters: (cls, sm_temp, w, r_temp, step, seed, nb_img).
# precision=0 makes gr.Number deliver ints for the integer-valued parameters.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(31, label="Class index (cls)", precision=0),
        gr.Number(1.3, label="sm_temp"),
        gr.Number(25, label="w"),
        gr.Number(4.5, label="r_temp"),
        gr.Number(16, label="Sampling steps (step)", precision=0, minimum=1),
        gr.Slider(0, 1000, 60, step=1, label="Seed"),
        gr.Number(1, label="Number of images (nb_img)", precision=0, minimum=1, maximum=4),
    ],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)
# Launch the Gradio app (at import, as before).
app.launch()