Spaces:
Runtime error
Runtime error
change path and name of app.py
Browse files- gradio_app.py → app.py +77 -77
gradio_app.py → app.py
RENAMED
|
@@ -1,77 +1,77 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import torch
|
| 3 |
-
import gradio as gr
|
| 4 |
-
from torchvision import transforms
|
| 5 |
-
from runner import MaskGIT
|
| 6 |
-
import numpy as np
|
| 7 |
-
import random
|
| 8 |
-
import torchvision.utils as vutils
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class Args(argparse.Namespace):
|
| 12 |
-
data_folder = ""
|
| 13 |
-
vqgan_folder =
|
| 14 |
-
writer_log = ""
|
| 15 |
-
data = ""
|
| 16 |
-
mask_value = 1024
|
| 17 |
-
seed = 1
|
| 18 |
-
channel = 3
|
| 19 |
-
num_workers = 0
|
| 20 |
-
iter = 0
|
| 21 |
-
global_epoch = 0
|
| 22 |
-
lr = 1e-4
|
| 23 |
-
drop_label = 0.1
|
| 24 |
-
resume = True
|
| 25 |
-
device = "cpu"
|
| 26 |
-
print(device)
|
| 27 |
-
debug = True
|
| 28 |
-
test_only = False
|
| 29 |
-
is_master = True
|
| 30 |
-
is_multi_gpus = False
|
| 31 |
-
vit_size = "base"
|
| 32 |
-
vit_folder =
|
| 33 |
-
img_size = 256
|
| 34 |
-
patch_size = 256 // 16
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def set_seed(seed):
|
| 38 |
-
if seed > 0:
|
| 39 |
-
torch.manual_seed(seed)
|
| 40 |
-
torch.cuda.manual_seed(seed)
|
| 41 |
-
np.random.seed(seed)
|
| 42 |
-
random.seed(seed)
|
| 43 |
-
torch.backends.cudnn.enable = False
|
| 44 |
-
torch.backends.cudnn.deterministic = True
|
| 45 |
-
|
| 46 |
-
args = Args()
|
| 47 |
-
maskgit = MaskGIT(args)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# Function to perform image synthesis
|
| 51 |
-
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
|
| 52 |
-
# Perform image synthesis using your model
|
| 53 |
-
set_seed(seed)
|
| 54 |
-
with torch.no_grad():
|
| 55 |
-
labels = [cls] * nb_img
|
| 56 |
-
labels = torch.LongTensor(labels).to(args.device)
|
| 57 |
-
gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
|
| 58 |
-
randomize="linear", r_temp=r_temp, sched_mode="arccos",
|
| 59 |
-
step=step)[0]
|
| 60 |
-
|
| 61 |
-
# Post-process the output image (adjust based on your needs)
|
| 62 |
-
output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))
|
| 63 |
-
|
| 64 |
-
return output_image
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
# Gradio Interface
|
| 68 |
-
app = gr.Interface(
|
| 69 |
-
fn=synthesize_image,
|
| 70 |
-
inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
|
| 71 |
-
gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
|
| 72 |
-
outputs=gr.Image(),
|
| 73 |
-
title="Image Synthesis using MaskGIT",
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
# Launch the Gradio app
|
| 77 |
-
app.launch(
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from runner import MaskGIT
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
import torchvision.utils as vutils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Args(argparse.Namespace):
|
| 12 |
+
data_folder = ""
|
| 13 |
+
vqgan_folder = "pretrained_maskgit/VQGAN"
|
| 14 |
+
writer_log = ""
|
| 15 |
+
data = ""
|
| 16 |
+
mask_value = 1024
|
| 17 |
+
seed = 1
|
| 18 |
+
channel = 3
|
| 19 |
+
num_workers = 0
|
| 20 |
+
iter = 0
|
| 21 |
+
global_epoch = 0
|
| 22 |
+
lr = 1e-4
|
| 23 |
+
drop_label = 0.1
|
| 24 |
+
resume = True
|
| 25 |
+
device = "cpu"
|
| 26 |
+
print(device)
|
| 27 |
+
debug = True
|
| 28 |
+
test_only = False
|
| 29 |
+
is_master = True
|
| 30 |
+
is_multi_gpus = False
|
| 31 |
+
vit_size = "base"
|
| 32 |
+
vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
|
| 33 |
+
img_size = 256
|
| 34 |
+
patch_size = 256 // 16
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def set_seed(seed):
|
| 38 |
+
if seed > 0:
|
| 39 |
+
torch.manual_seed(seed)
|
| 40 |
+
torch.cuda.manual_seed(seed)
|
| 41 |
+
np.random.seed(seed)
|
| 42 |
+
random.seed(seed)
|
| 43 |
+
torch.backends.cudnn.enable = False
|
| 44 |
+
torch.backends.cudnn.deterministic = True
|
| 45 |
+
|
| 46 |
+
args = Args()
|
| 47 |
+
maskgit = MaskGIT(args)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Function to perform image synthesis
|
| 51 |
+
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
|
| 52 |
+
# Perform image synthesis using your model
|
| 53 |
+
set_seed(seed)
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
labels = [cls] * nb_img
|
| 56 |
+
labels = torch.LongTensor(labels).to(args.device)
|
| 57 |
+
gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
|
| 58 |
+
randomize="linear", r_temp=r_temp, sched_mode="arccos",
|
| 59 |
+
step=step)[0]
|
| 60 |
+
|
| 61 |
+
# Post-process the output image (adjust based on your needs)
|
| 62 |
+
output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))
|
| 63 |
+
|
| 64 |
+
return output_image
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Gradio Interface
|
| 68 |
+
app = gr.Interface(
|
| 69 |
+
fn=synthesize_image,
|
| 70 |
+
inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
|
| 71 |
+
gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
|
| 72 |
+
outputs=gr.Image(),
|
| 73 |
+
title="Image Synthesis using MaskGIT",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Launch the Gradio app
|
| 77 |
+
app.launch()
|