llvictorll commited on
Commit
4272d88
·
verified ·
1 Parent(s): 9c4679e

change path and name of app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py → app.py +77 -77
gradio_app.py → app.py RENAMED
@@ -1,77 +1,77 @@
1
- import argparse
2
- import torch
3
- import gradio as gr
4
- from torchvision import transforms
5
- from runner import MaskGIT
6
- import numpy as np
7
- import random
8
- import torchvision.utils as vutils
9
-
10
-
11
- class Args(argparse.Namespace):
12
- data_folder = ""
13
- vqgan_folder = r"C:\Users\vbesnier\Experiment\VQGAN"
14
- writer_log = ""
15
- data = ""
16
- mask_value = 1024
17
- seed = 1
18
- channel = 3
19
- num_workers = 0
20
- iter = 0
21
- global_epoch = 0
22
- lr = 1e-4
23
- drop_label = 0.1
24
- resume = True
25
- device = "cpu"
26
- print(device)
27
- debug = True
28
- test_only = False
29
- is_master = True
30
- is_multi_gpus = False
31
- vit_size = "base"
32
- vit_folder = r"C:\Users\vbesnier\Experiment\MaskGIT\current.pth"
33
- img_size = 256
34
- patch_size = 256 // 16
35
-
36
-
37
- def set_seed(seed):
38
- if seed > 0:
39
- torch.manual_seed(seed)
40
- torch.cuda.manual_seed(seed)
41
- np.random.seed(seed)
42
- random.seed(seed)
43
- torch.backends.cudnn.enable = False
44
- torch.backends.cudnn.deterministic = True
45
-
46
- args = Args()
47
- maskgit = MaskGIT(args)
48
-
49
-
50
- # Function to perform image synthesis
51
- def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
52
- # Perform image synthesis using your model
53
- set_seed(seed)
54
- with torch.no_grad():
55
- labels = [cls] * nb_img
56
- labels = torch.LongTensor(labels).to(args.device)
57
- gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
58
- randomize="linear", r_temp=r_temp, sched_mode="arccos",
59
- step=step)[0]
60
-
61
- # Post-process the output image (adjust based on your needs)
62
- output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))
63
-
64
- return output_image
65
-
66
-
67
- # Gradio Interface
68
- app = gr.Interface(
69
- fn=synthesize_image,
70
- inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
71
- gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
72
- outputs=gr.Image(),
73
- title="Image Synthesis using MaskGIT",
74
- )
75
-
76
- # Launch the Gradio app
77
- app.launch(share=True)
 
1
+ import argparse
2
+ import torch
3
+ import gradio as gr
4
+ from torchvision import transforms
5
+ from runner import MaskGIT
6
+ import numpy as np
7
+ import random
8
+ import torchvision.utils as vutils
9
+
10
+
11
+ class Args(argparse.Namespace):
12
+ data_folder = ""
13
+ vqgan_folder = "pretrained_maskgit/VQGAN"
14
+ writer_log = ""
15
+ data = ""
16
+ mask_value = 1024
17
+ seed = 1
18
+ channel = 3
19
+ num_workers = 0
20
+ iter = 0
21
+ global_epoch = 0
22
+ lr = 1e-4
23
+ drop_label = 0.1
24
+ resume = True
25
+ device = "cpu"
26
+ print(device)
27
+ debug = True
28
+ test_only = False
29
+ is_master = True
30
+ is_multi_gpus = False
31
+ vit_size = "base"
32
+ vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
33
+ img_size = 256
34
+ patch_size = 256 // 16
35
+
36
+
37
+ def set_seed(seed):
38
+ if seed > 0:
39
+ torch.manual_seed(seed)
40
+ torch.cuda.manual_seed(seed)
41
+ np.random.seed(seed)
42
+ random.seed(seed)
43
+ torch.backends.cudnn.enable = False
44
+ torch.backends.cudnn.deterministic = True
45
+
46
+ args = Args()
47
+ maskgit = MaskGIT(args)
48
+
49
+
50
+ # Function to perform image synthesis
51
+ def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
52
+ # Perform image synthesis using your model
53
+ set_seed(seed)
54
+ with torch.no_grad():
55
+ labels = [cls] * nb_img
56
+ labels = torch.LongTensor(labels).to(args.device)
57
+ gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
58
+ randomize="linear", r_temp=r_temp, sched_mode="arccos",
59
+ step=step)[0]
60
+
61
+ # Post-process the output image (adjust based on your needs)
62
+ output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))
63
+
64
+ return output_image
65
+
66
+
67
+ # Gradio Interface
68
+ app = gr.Interface(
69
+ fn=synthesize_image,
70
+ inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
71
+ gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
72
+ outputs=gr.Image(),
73
+ title="Image Synthesis using MaskGIT",
74
+ )
75
+
76
+ # Launch the Gradio app
77
+ app.launch()