HZSDU committed
Commit aadbb8e · verified · 1 Parent(s): d858b89

Add files using upload-large-folder tool

Files changed (50)
  1. ad.ipynb +517 -0
  2. checkpoints/ostracoda_cyclegan/test_opt.txt +42 -0
  3. configs/config.yaml +52 -0
  4. cyclegan_model/data/__init__.py +89 -0
  5. cyclegan_model/data/__pycache__/__init__.cpython-310.pyc +0 -0
  6. cyclegan_model/data/__pycache__/base_dataset.cpython-310.pyc +0 -0
  7. cyclegan_model/data/__pycache__/image_folder.cpython-310.pyc +0 -0
  8. cyclegan_model/data/__pycache__/single_dataset.cpython-310.pyc +0 -0
  9. cyclegan_model/data/__pycache__/unaligned_dataset.cpython-310.pyc +0 -0
  10. cyclegan_model/data/base_dataset.py +167 -0
  11. cyclegan_model/data/image_folder.py +65 -0
  12. cyclegan_model/data/single_dataset.py +40 -0
  13. cyclegan_model/data/unaligned_dataset.py +71 -0
  14. cyclegan_model/model/__init__.py +63 -0
  15. cyclegan_model/model/__pycache__/__init__.cpython-310.pyc +0 -0
  16. cyclegan_model/model/__pycache__/base_model.cpython-310.pyc +0 -0
  17. cyclegan_model/model/__pycache__/cycle_gan_model.cpython-310.pyc +0 -0
  18. cyclegan_model/model/__pycache__/networks.cpython-310.pyc +0 -0
  19. cyclegan_model/model/__pycache__/test_model.cpython-310.pyc +0 -0
  20. cyclegan_model/model/base_model.py +230 -0
  21. cyclegan_model/model/cycle_gan_model.py +229 -0
  22. cyclegan_model/model/networks.py +1091 -0
  23. cyclegan_model/model/test_model.py +70 -0
  24. cyclegan_model/options/__init__.py +0 -0
  25. cyclegan_model/options/__pycache__/__init__.cpython-310.pyc +0 -0
  26. cyclegan_model/options/__pycache__/base_options.cpython-310.pyc +0 -0
  27. cyclegan_model/options/__pycache__/test_options.cpython-310.pyc +0 -0
  28. cyclegan_model/options/base_options.py +138 -0
  29. cyclegan_model/options/test_options.py +24 -0
  30. cyclegan_model/util/__init__.py +0 -0
  31. cyclegan_model/util/__pycache__/__init__.cpython-310.pyc +0 -0
  32. cyclegan_model/util/__pycache__/image_pool.cpython-310.pyc +0 -0
  33. cyclegan_model/util/__pycache__/util.cpython-310.pyc +0 -0
  34. cyclegan_model/util/image_pool.py +54 -0
  35. cyclegan_model/util/util.py +103 -0
  36. data/content/12.jpg +0 -0
  37. data/content/15.jpg +0 -0
  38. data/content/27032.jpg +0 -0
  39. data/style/13.png +0 -0
  40. data/style/2.jpg +0 -0
  41. data/style/9.jpg +0 -0
  42. flux_ad/__pycache__/utils.cpython-310.pyc +0 -0
  43. flux_ad/main.ipynb +96 -0
  44. flux_ad/mypipeline.py +381 -0
  45. flux_ad/utils.py +232 -0
  46. inpaint_model/model/__init__.py +0 -0
  47. inpaint_model/model/__pycache__/__init__.cpython-310.pyc +0 -0
  48. inpaint_model/model/__pycache__/networks.cpython-310.pyc +0 -0
  49. inpaint_model/model/networks.py +562 -0
  50. matting/image_matting.py +274 -0
ad.ipynb ADDED
@@ -0,0 +1,517 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "c7dbfe5c",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Texture Synthesis"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "8a06ff4b",
14
+ "metadata": {},
15
+ "source": [
16
+ "### Texture Synthesis via Optimization"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "6349220d",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "import os\n",
27
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n",
28
+ "from accelerate.utils import set_seed\n",
29
+ "from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel\n",
30
+ "from pipeline_sd import ADPipeline\n",
31
+ "from utils import *\n",
32
+ "\n",
33
+ "model_name = \"/root/models/stable-diffusion-v1-5\"\n",
34
+ "vae = \"\"\n",
35
+ "lr = 0.05\n",
36
+ "iters = 1\n",
37
+ "seed = 42\n",
38
+ "width, height = 512, 512\n",
39
+ "weight = 0\n",
40
+ "batch_size = 3\n",
41
+ "mixed_precision = \"bf16\"\n",
42
+ "num_inference_steps = 300\n",
43
+ "enable_gradient_checkpoint = False\n",
44
+ "start_layer, end_layer = 10, 16\n",
45
+ "\n",
46
+ "\n",
47
+ "style_image = [\"./data/texture/4.jpg\"]\n",
48
+ "content_image = \"\"\n",
49
+ "\n",
50
+ "\n",
51
+ "scheduler = DDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\n",
52
+ "pipe = ADPipeline.from_pretrained(\n",
53
+ " model_name, scheduler=scheduler, safety_checker=None\n",
54
+ ")\n",
55
+ "if vae != \"\":\n",
56
+ " vae = AutoencoderKL.from_pretrained(vae)\n",
57
+ " pipe.vae = vae\n",
58
+ "pipe.classifier = pipe.unet\n",
59
+ "set_seed(seed)\n",
60
+ "\n",
61
+ "style_image = torch.cat([load_image(path, size=(512, 512)) for path in style_image])\n",
62
+ "rec_style_image = pipe.latent2image(pipe.image2latent(style_image))\n",
63
+ "if content_image == \"\":\n",
64
+ " content_image = None\n",
65
+ "else:\n",
66
+ " content_image = load_image(content_image, size=(width, height))\n",
67
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
68
+ "result = pipe.optimize(\n",
69
+ " lr=lr,\n",
70
+ " batch_size=batch_size,\n",
71
+ " iters=iters,\n",
72
+ " width=width,\n",
73
+ " height=height,\n",
74
+ " weight=weight,\n",
75
+ " controller=controller,\n",
76
+ " style_image=style_image,\n",
77
+ " content_image=content_image,\n",
78
+ " mixed_precision=mixed_precision,\n",
79
+ " num_inference_steps=num_inference_steps,\n",
80
+ " enable_gradient_checkpoint=enable_gradient_checkpoint,\n",
81
+ ")\n",
82
+ "\n",
83
+ "save_image(style_image, \"style.png\")\n",
84
+ "save_image(result, \"output.png\")\n",
85
+ "show_image(\"style.png\", title=\"style image\")\n",
86
+ "show_image(\"output.png\", title=\"generated\")"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "id": "c4ba9b1d",
92
+ "metadata": {},
93
+ "source": [
94
+ "### Texture Synthesis via Sample"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "67a535c1",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "import os\n",
105
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n",
106
+ "from accelerate.utils import set_seed\n",
107
+ "from diffusers import AutoencoderKL, DDIMScheduler\n",
108
+ "from pipeline_sd import ADPipeline\n",
109
+ "from utils import *\n",
110
+ "\n",
111
+ "model_name = \"/root/models/stable-diffusion-v1-5\"\n",
112
+ "vae = \"\"\n",
113
+ "lr = 0.05\n",
114
+ "iters =3\n",
115
+ "seed = 42\n",
116
+ "width, height = 512, 512\n",
117
+ "weight = 0.\n",
118
+ "mixed_precision = \"bf16\"\n",
119
+ "num_inference_steps = 50\n",
120
+ "guidance_scale = 1\n",
121
+ "num_images_per_prompt = 3\n",
122
+ "enable_gradient_checkpoint = False\n",
123
+ "start_layer, end_layer = 10, 16\n",
124
+ "\n",
125
+ "\n",
126
+ "style_image = [\"./data/texture/8.jpg\"]\n",
127
+ "content_image = \"\"\n",
128
+ "\n",
129
+ "scheduler = DDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\n",
130
+ "pipe = ADPipeline.from_pretrained(model_name, scheduler=scheduler, safety_checker=None)\n",
131
+ "if vae != \"\":\n",
132
+ " vae = AutoencoderKL.from_pretrained(vae)\n",
133
+ " pipe.vae = vae\n",
134
+ "pipe.classifier = pipe.unet\n",
135
+ "set_seed(seed)\n",
136
+ "\n",
137
+ "style_image = torch.cat([load_image(path, size=(512, 512)) for path in style_image])\n",
138
+ "rec_style_image = pipe.latent2image(pipe.image2latent(style_image))\n",
139
+ "if content_image == \"\":\n",
140
+ " content_image = None\n",
141
+ "else:\n",
142
+ " content_image = load_image(content_image, size=(width, height))\n",
143
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
144
+ "\n",
145
+ "result = pipe.sample(\n",
146
+ " lr=lr,\n",
147
+ " adain=False,\n",
148
+ " iters=iters,\n",
149
+ " width=width,\n",
150
+ " height=height,\n",
151
+ " weight=weight,\n",
152
+ " controller=controller,\n",
153
+ " style_image=style_image,\n",
154
+ " content_image=content_image,\n",
155
+ " prompt=\"\",\n",
156
+ " negative_prompt=\"\",\n",
157
+ " mixed_precision=mixed_precision,\n",
158
+ " num_inference_steps=num_inference_steps,\n",
159
+ " guidance_scale=guidance_scale,\n",
160
+ " num_images_per_prompt=num_images_per_prompt,\n",
161
+ " enable_gradient_checkpoint=enable_gradient_checkpoint,\n",
162
+ ")\n",
163
+ "\n",
164
+ "save_image(style_image, \"style.png\")\n",
165
+ "save_image(result, \"output.png\")\n",
166
+ "show_image(\"style.png\", title=\"style image\")\n",
167
+ "show_image(\"output.png\", title=\"generated\")"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "markdown",
172
+ "id": "c69ae623",
173
+ "metadata": {},
174
+ "source": [
175
+ "### Texture Synthesis via MultiDiffusion"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "6d059173",
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "import os\n",
186
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n",
187
+ "from accelerate.utils import set_seed\n",
188
+ "from diffusers import AutoencoderKL, DDIMScheduler\n",
189
+ "from pipeline_sd import ADPipeline\n",
190
+ "from utils import *\n",
191
+ "\n",
192
+ "model_name = \"/root/models/stable-diffusion-v1-5\"\n",
193
+ "vae = \"\"\n",
194
+ "lr = 0.05\n",
195
+ "iters = 2\n",
196
+ "seed = 42\n",
197
+ "width, height = 512*2, 512\n",
198
+ "weight = 0.0\n",
199
+ "mixed_precision = \"bf16\"\n",
200
+ "num_inference_steps = 50\n",
201
+ "guidance_scale = 1\n",
202
+ "num_images_per_prompt = 1\n",
203
+ "enable_gradient_checkpoint = False\n",
204
+ "start_layer, end_layer = 10, 16\n",
205
+ "\n",
206
+ "\n",
207
+ "style_image = [\"./data/texture/17.jpg\"]\n",
208
+ "content_image = \"\"\n",
209
+ "\n",
210
+ "scheduler = DDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\n",
211
+ "pipe = ADPipeline.from_pretrained(model_name, scheduler=scheduler, safety_checker=None)\n",
212
+ "if vae != \"\":\n",
213
+ " vae = AutoencoderKL.from_pretrained(vae)\n",
214
+ " pipe.vae = vae\n",
215
+ "\n",
216
+ "pipe.classifier = pipe.unet\n",
217
+ "set_seed(seed)\n",
218
+ "\n",
219
+ "style_image = torch.cat([load_image(path, size=(512, 512)) for path in style_image])\n",
220
+ "if content_image == \"\":\n",
221
+ " content_image = None\n",
222
+ "else:\n",
223
+ " content_image = load_image(content_image, size=(width, height))\n",
224
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
225
+ "\n",
226
+ "result = pipe.panorama(\n",
227
+ " lr=lr,\n",
228
+ " iters=iters,\n",
229
+ " width=width,\n",
230
+ " height=height,\n",
231
+ " weight=weight,\n",
232
+ " controller=controller,\n",
233
+ " style_image=style_image,\n",
234
+ " content_image=content_image,\n",
235
+ " prompt=\"\",\n",
236
+ " negative_prompt=\"\",\n",
237
+ " stride=8,\n",
238
+ " view_batch_size=8,\n",
239
+ " mixed_precision=mixed_precision,\n",
240
+ " num_inference_steps=num_inference_steps,\n",
241
+ " guidance_scale=guidance_scale,\n",
242
+ " num_images_per_prompt=num_images_per_prompt,\n",
243
+ " enable_gradient_checkpoint=enable_gradient_checkpoint,\n",
244
+ ")\n",
245
+ "\n",
246
+ "save_image(style_image, \"style.png\")\n",
247
+ "save_image(result, \"output.png\")\n",
248
+ "show_image(\"style.png\", title=\"style image\")\n",
249
+ "show_image(\"output.png\", title=\"generated\")"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "id": "25e4e702",
255
+ "metadata": {},
256
+ "source": [
257
+ "## Style/Appearance Transfer"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "4badee8f",
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "import os\n",
268
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n",
269
+ "from accelerate.utils import set_seed\n",
270
+ "from diffusers import AutoencoderKL, DDIMScheduler\n",
271
+ "from pipeline_sd import ADPipeline\n",
272
+ "from utils import *\n",
273
+ "\n",
274
+ "model_name = \"/root/models/stable-diffusion-v1-5\"\n",
275
+ "vae = \"\"\n",
276
+ "lr = 0.05\n",
277
+ "iters = 1\n",
278
+ "seed = 42\n",
279
+ "width = 512\n",
280
+ "height = 512\n",
281
+ "weight = 0.25\n",
282
+ "batch_size = 1\n",
283
+ "mixed_precision = \"bf16\"\n",
284
+ "num_inference_steps = 200\n",
285
+ "guidance_scale = 1\n",
286
+ "num_images_per_prompt = 1\n",
287
+ "enable_gradient_checkpoint = False\n",
288
+ "start_layer, end_layer = 10, 16\n",
289
+ "\n",
290
+ "\n",
291
+ "style_image = [\"./data/style/12.jpg\"]\n",
292
+ "content_image = \"./data/content/deer.jpg\"\n",
293
+ "\n",
294
+ "\n",
295
+ "scheduler = DDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\n",
296
+ "pipe = ADPipeline.from_pretrained(\n",
297
+ " model_name, scheduler=scheduler, safety_checker=None\n",
298
+ ")\n",
299
+ "if vae != \"\":\n",
300
+ " vae = AutoencoderKL.from_pretrained(vae)\n",
301
+ " pipe.vae = vae\n",
302
+ "\n",
303
+ "pipe.classifier = pipe.unet\n",
304
+ "set_seed(seed)\n",
305
+ "\n",
306
+ "style_image = torch.cat([load_image(path, size=(512, 512)) for path in style_image])\n",
307
+ "if content_image == \"\":\n",
308
+ " content_image = None\n",
309
+ "else:\n",
310
+ " content_image = load_image(content_image, size=(width, height))\n",
311
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
312
+ "result = pipe.optimize(\n",
313
+ " lr=lr,\n",
314
+ " batch_size=batch_size,\n",
315
+ " iters=iters,\n",
316
+ " width=width,\n",
317
+ " height=height,\n",
318
+ " weight=weight,\n",
319
+ " controller=controller,\n",
320
+ " style_image=style_image,\n",
321
+ " content_image=content_image,\n",
322
+ " mixed_precision=mixed_precision,\n",
323
+ " num_inference_steps=num_inference_steps,\n",
324
+ " enable_gradient_checkpoint=enable_gradient_checkpoint,\n",
325
+ ")\n",
326
+ "\n",
327
+ "save_image(style_image, \"style.png\")\n",
328
+ "save_image(content_image, \"content.png\")\n",
329
+ "save_image(result, \"output.png\")\n",
330
+ "show_image(\"style.png\", title=\"style image\")\n",
331
+ "show_image(\"content.png\", title=\"content image\")\n",
332
+ "show_image(\"output.png\", title=\"generated\")"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "markdown",
337
+ "id": "088ca839",
338
+ "metadata": {},
339
+ "source": [
340
+ "## Style-specific T2I Generation "
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "markdown",
345
+ "id": "02efdafd",
346
+ "metadata": {},
347
+ "source": [
348
+ "### Style-specific T2I Generation with SD1.5"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "id": "15c9fb96",
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "import os\n",
359
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n",
360
+ "from pipeline_sd import ADPipeline\n",
361
+ "from diffusers import DDIMScheduler, AutoencoderKL\n",
362
+ "import torch\n",
363
+ "from utils import *\n",
364
+ "from accelerate.utils import set_seed\n",
365
+ "\n",
366
+ "\n",
367
+ "model_name = \"/root/models/stable-diffusion-v1-5\"\n",
368
+ "vae = \"\"\n",
369
+ "lr = 0.015\n",
370
+ "iters = 3\n",
371
+ "seed = 42\n",
372
+ "mixed_precision = \"bf16\"\n",
373
+ "num_inference_steps = 50\n",
374
+ "guidance_scale = 7.5\n",
375
+ "num_images_per_prompt = 3\n",
376
+ "enable_gradient_checkpoint = False\n",
377
+ "start_layer, end_layer = 10, 16\n",
378
+ "\n",
379
+ "prompt = \"A deer\"\n",
380
+ "style_image = [\"./data/style/1.jpg\"]\n",
381
+ "\n",
382
+ "scheduler = DDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\n",
383
+ "pipe = ADPipeline.from_pretrained(\n",
384
+ " model_name, scheduler=scheduler, safety_checker=None\n",
385
+ ")\n",
386
+ "if vae != \"\":\n",
387
+ " vae = AutoencoderKL.from_pretrained(vae)\n",
388
+ " pipe.vae = vae\n",
389
+ "\n",
390
+ "pipe.classifier = pipe.unet\n",
391
+ "set_seed(seed)\n",
392
+ "\n",
393
+ "style_image = torch.cat([load_image(path, size=(512, 512)) for path in style_image])\n",
394
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
395
+ "\n",
396
+ "result = pipe.sample(\n",
397
+ " controller=controller,\n",
398
+ " iters=iters,\n",
399
+ " lr=lr,\n",
400
+ " adain=True,\n",
401
+ " height=512,\n",
402
+ " width=512,\n",
403
+ " mixed_precision=\"bf16\",\n",
404
+ " style_image=style_image,\n",
405
+ " prompt=prompt,\n",
406
+ " negative_prompt=\"\",\n",
407
+ " guidance_scale=guidance_scale,\n",
408
+ " num_inference_steps=num_inference_steps,\n",
409
+ " num_images_per_prompt=num_images_per_prompt,\n",
410
+ " enable_gradient_checkpoint=enable_gradient_checkpoint\n",
411
+ ")\n",
412
+ "\n",
413
+ "save_image(style_image, \"style.png\")\n",
414
+ "save_image(result, \"output.png\")\n",
415
+ "show_image(\"style.png\", title=\"style image\")\n",
416
+ "show_image(\"output.png\", title=prompt)\n"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "id": "dd75eac7",
422
+ "metadata": {},
423
+ "source": [
424
+ "### Style-specific T2I Generation with SDXL"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "1541fd6b",
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": [
434
+ "import os\n",
435
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n",
436
+ "from pipeline_sdxl import ADPipeline\n",
437
+ "from diffusers import DDIMScheduler, AutoencoderKL\n",
438
+ "import torch\n",
439
+ "from utils import *\n",
440
+ "from accelerate.utils import set_seed\n",
441
+ "\n",
442
+ "\n",
443
+ "model_name = \"/root/models/stable-diffusion-xl-base-1.0/\"\n",
444
+ "vae = \"\"\n",
445
+ "lr = 0.015\n",
446
+ "iters = 2\n",
447
+ "seed = 42\n",
448
+ "mixed_precision = \"bf16\"\n",
449
+ "num_inference_steps = 50\n",
450
+ "guidance_scale = 7\n",
451
+ "num_images_per_prompt = 5\n",
452
+ "enable_gradient_checkpoint = True\n",
453
+ "start_layer, end_layer = 64, 70\n",
454
+ "\n",
455
+ "prompt = \"A rocket\"\n",
456
+ "style_image = [\"./data/style/1.png\"]\n",
457
+ "\n",
458
+ "scheduler = DDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\n",
459
+ "pipe = ADPipeline.from_pretrained(\n",
460
+ " model_name, scheduler=scheduler, safety_checker=None\n",
461
+ ")\n",
462
+ "if vae != \"\":\n",
463
+ " vae = AutoencoderKL.from_pretrained(vae)\n",
464
+ " pipe.vae = vae\n",
465
+ "\n",
466
+ "pipe.classifier = pipe.unet\n",
467
+ "set_seed(seed)\n",
468
+ "\n",
469
+ "style_image = torch.cat([load_image(path, size=(1024, 1024)) for path in style_image])\n",
470
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
471
+ "\n",
472
+ "result = pipe.sample(\n",
473
+ " controller=controller,\n",
474
+ " iters=iters,\n",
475
+ " lr=lr,\n",
476
+ " adain=True,\n",
477
+ " height=1024,\n",
478
+ " width=1024,\n",
479
+ " mixed_precision=\"bf16\",\n",
480
+ " style_image=style_image,\n",
481
+ " prompt=prompt,\n",
482
+ " negative_prompt=\"\",\n",
483
+ " guidance_scale=guidance_scale,\n",
484
+ " num_inference_steps=num_inference_steps,\n",
485
+ " num_images_per_prompt=num_images_per_prompt,\n",
486
+ " enable_gradient_checkpoint=enable_gradient_checkpoint\n",
487
+ ")\n",
488
+ "\n",
489
+ "save_image(style_image, \"style.png\")\n",
490
+ "save_image(result, \"output.png\")\n",
491
+ "show_image(\"style.png\", title=\"style image\")\n",
492
+ "show_image(\"output.png\", title=prompt)\n"
493
+ ]
494
+ }
495
+ ],
496
+ "metadata": {
497
+ "kernelspec": {
498
+ "display_name": "ad",
499
+ "language": "python",
500
+ "name": "python3"
501
+ },
502
+ "language_info": {
503
+ "codemirror_mode": {
504
+ "name": "ipython",
505
+ "version": 3
506
+ },
507
+ "file_extension": ".py",
508
+ "mimetype": "text/x-python",
509
+ "name": "python",
510
+ "nbconvert_exporter": "python",
511
+ "pygments_lexer": "ipython3",
512
+ "version": "3.10.16"
513
+ }
514
+ },
515
+ "nbformat": 4,
516
+ "nbformat_minor": 5
517
+ }
checkpoints/ostracoda_cyclegan/test_opt.txt ADDED
@@ -0,0 +1,42 @@
1
+ ----------------- Options ---------------
2
+ aspect_ratio: 1.0
3
+ batch_size: 1
4
+ checkpoints_dir: ./checkpoints
5
+ crop_size: 256
6
+ dataroot: None
7
+ dataset_mode: single
8
+ direction: AtoB
9
+ display_winsize: 256
10
+ epoch: latest
11
+ eval: False
12
+ gpu_ids: 0
13
+ init_gain: 0.02
14
+ init_type: normal
15
+ input_nc: 3
16
+ isTrain: False [default: None]
17
+ load_iter: 0 [default: 0]
18
+ load_size: 256
19
+ max_dataset_size: inf
20
+ model: test
21
+ model_suffix:
22
+ n_layers_D: 3
23
+ name: ostracoda_cyclegan
24
+ ndf: 64
25
+ netD: basic
26
+ netG: resnet_9blocks
27
+ ngf: 64
28
+ no_dropout: False
29
+ no_flip: False
30
+ norm: instance
31
+ num_test: 100
32
+ num_threads: 4
33
+ output_nc: 3
34
+ phase: test
35
+ preprocess: resize_and_crop
36
+ results_dir: ./results/
37
+ serial_batches: False
38
+ suffix:
39
+ use_wandb: False
40
+ verbose: False
41
+ wandb_project_name: CycleGAN-and-pix2pix
42
+ ----------------- End -------------------
configs/config.yaml ADDED
@@ -0,0 +1,52 @@
1
+ # data parameters
2
+ dataset_name: imagenet
3
+ data_with_subfolder: True
4
+
5
+ train_data_path: traindata/train
6
+ val_data_path: traindata/val
7
+ resume: checkpoints/imagenet/hole_benchmark
8
+
9
+
10
+ batch_size: 4
11
+ image_shape: [256, 256, 3]
12
+ mask_shape: [128, 128]
13
+ mask_batch_same: True
14
+ max_delta_shape: [32, 32]
15
+ margin: [0, 0]
16
+ discounted_mask: True
17
+ spatial_discounting_gamma: 0.9
18
+ random_crop: True
19
+ mask_type: hole # hole | mosaic
20
+ mosaic_unit_size: 12
21
+
22
+ # training parameters
23
+ expname: benchmark
24
+ cuda: True
25
+ gpu_ids: [0] # set the GPU ids to use, e.g. [0] or [1, 2]
26
+ num_workers: 4
27
+ lr: 0.0001
28
+ beta1: 0.5
29
+ beta2: 0.9
30
+ n_critic: 5
31
+ niter: 480000
32
+ print_iter: 100
33
+ viz_iter: 1000
34
+ viz_max_out: 16
35
+ snapshot_save_iter: 5000
36
+
37
+ # loss weight
38
+ coarse_l1_alpha: 1.2
39
+ l1_loss_alpha: 1.2
40
+ ae_loss_alpha: 1.2
41
+ global_wgan_loss_alpha: 1.
42
+ gan_loss_alpha: 0.001
43
+ wgan_gp_lambda: 10
44
+
45
+ # network parameters
46
+ netG:
47
+ input_dim: 3
48
+ ngf: 32
49
+
50
+ netD:
51
+ input_dim: 3
52
+ ndf: 64
cyclegan_model/data/__init__.py ADDED
@@ -0,0 +1,89 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import importlib
14
+ import torch.utils.data
15
+ from cyclegan_model.data.base_dataset import BaseDataset
16
+
17
+
18
+ def find_dataset_using_name(dataset_name):
19
+ """Import the module "data/[dataset_name]_dataset.py".
20
+
21
+ In the file, the class called DatasetNameDataset() will
22
+ be instantiated. It has to be a subclass of BaseDataset,
23
+ and it is case-insensitive.
24
+ """
25
+ dataset_filename = "cyclegan_model.data." + dataset_name + "_dataset"
26
+ datasetlib = importlib.import_module(dataset_filename)
27
+
28
+ dataset = None
29
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30
+ for name, cls in datasetlib.__dict__.items():
31
+ if name.lower() == target_dataset_name.lower() \
32
+ and issubclass(cls, BaseDataset):
33
+ dataset = cls
34
+
35
+ if dataset is None:
36
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37
+
38
+ return dataset
39
+
40
+
41
+ def get_option_setter(dataset_name):
42
+ """Return the static method <modify_commandline_options> of the dataset class."""
43
+ dataset_class = find_dataset_using_name(dataset_name)
44
+ return dataset_class.modify_commandline_options
45
+
46
+
47
+ def create_dataset(opt):
48
+ """Create a dataset given the option.
49
+
50
+ This function wraps the class CustomDatasetDataLoader.
51
+ This is the main interface between this package and 'train.py'/'test.py'
52
+ """
53
+ data_loader = CustomDatasetDataLoader(opt)
54
+ dataset = data_loader.load_data()
55
+ return dataset
56
+
57
+
58
+ class CustomDatasetDataLoader():
59
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
60
+
61
+ def __init__(self, opt):
62
+ """Initialize this class
63
+
64
+ Step 1: create a dataset instance given the name [dataset_mode]
65
+ Step 2: create a multi-threaded data loader.
66
+ """
67
+ self.opt = opt
68
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
69
+ self.dataset = dataset_class(opt)
70
+ print("dataset [%s] was created" % type(self.dataset).__name__)
71
+ self.dataloader = torch.utils.data.DataLoader(
72
+ self.dataset,
73
+ batch_size=opt.batch_size,
74
+ shuffle=not opt.serial_batches,
75
+ num_workers=int(opt.num_threads))
76
+
77
+ def load_data(self):
78
+ return self
79
+
80
+ def __len__(self):
81
+ """Return the number of data in the dataset"""
82
+ return min(len(self.dataset), self.opt.max_dataset_size)
83
+
84
+ def __iter__(self):
85
+ """Return a batch of data"""
86
+ for i, data in enumerate(self.dataloader):
87
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
88
+ break
89
+ yield data
cyclegan_model/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.92 kB).
 
cyclegan_model/data/__pycache__/base_dataset.cpython-310.pyc ADDED
Binary file (6.15 kB).
 
cyclegan_model/data/__pycache__/image_folder.cpython-310.pyc ADDED
Binary file (2.43 kB).
 
cyclegan_model/data/__pycache__/single_dataset.cpython-310.pyc ADDED
Binary file (2 kB).
 
cyclegan_model/data/__pycache__/unaligned_dataset.cpython-310.pyc ADDED
Binary file (2.99 kB).
 
cyclegan_model/data/base_dataset.py ADDED
@@ -0,0 +1,167 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ self.root = opt.dataroot
31
+
32
+ @staticmethod
33
+ def modify_commandline_options(parser, is_train):
34
+ """Add new dataset-specific options, and rewrite default values for existing options.
35
+
36
+ Parameters:
37
+ parser -- original option parser
38
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39
+
40
+ Returns:
41
+ the modified parser.
42
+ """
43
+ return parser
44
+
45
+ @abstractmethod
46
+ def __len__(self):
47
+ """Return the total number of images in the dataset."""
48
+ return 0
49
+
50
+ @abstractmethod
51
+ def __getitem__(self, index):
52
+ """Return a data point and its metadata information.
53
+
54
+ Parameters:
55
+ index - - a random integer for data indexing
56
+
57
+ Returns:
58
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
59
+ """
60
+ pass
61
+
62
+
63
+ def get_params(opt, size):
64
+ w, h = size
65
+ new_h = h
66
+ new_w = w
67
+ if opt.preprocess == 'resize_and_crop':
68
+ new_h = new_w = opt.load_size
69
+ elif opt.preprocess == 'scale_width_and_crop':
70
+ new_w = opt.load_size
71
+ new_h = opt.load_size * h // w
72
+
73
+ x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74
+ y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75
+
76
+ flip = random.random() > 0.5
77
+
78
+ return {'crop_pos': (x, y), 'flip': flip}
79
+
80
+
81
+ def get_transform(opt, params=None, grayscale=False, method=transforms.InterpolationMode.BICUBIC, convert=True):
82
+ transform_list = []
83
+ if grayscale:
84
+ transform_list.append(transforms.Grayscale(1))
85
+ if 'resize' in opt.preprocess:
86
+ osize = [opt.load_size, opt.load_size]
87
+ transform_list.append(transforms.Resize(osize, method))
88
+ elif 'scale_width' in opt.preprocess:
89
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
90
+
91
+ if 'crop' in opt.preprocess:
92
+ if params is None:
93
+ transform_list.append(transforms.RandomCrop(opt.crop_size))
94
+ else:
95
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96
+
97
+ if opt.preprocess == 'none':
98
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
99
+
100
+ if not opt.no_flip:
101
+ if params is None:
102
+ transform_list.append(transforms.RandomHorizontalFlip())
103
+ elif params['flip']:
104
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105
+
106
+ if convert:
107
+ transform_list += [transforms.ToTensor()]
108
+ if grayscale:
109
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
110
+ else:
111
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112
+ return transforms.Compose(transform_list)
113
+
114
+
115
+ def __transforms2pil_resize(method):
116
+ mapper = {transforms.InterpolationMode.BILINEAR: Image.BILINEAR,
117
+ transforms.InterpolationMode.BICUBIC: Image.BICUBIC,
118
+ transforms.InterpolationMode.NEAREST: Image.NEAREST,
119
+ transforms.InterpolationMode.LANCZOS: Image.LANCZOS,}
120
+ return mapper[method]
121
+
122
+
123
+ def __make_power_2(img, base, method=transforms.InterpolationMode.BICUBIC):
124
+ method = __transforms2pil_resize(method)
125
+ ow, oh = img.size
126
+ h = int(round(oh / base) * base)
127
+ w = int(round(ow / base) * base)
128
+ if h == oh and w == ow:
129
+ return img
130
+
131
+ __print_size_warning(ow, oh, w, h)
132
+ return img.resize((w, h), method)
133
+
134
+
135
+ def __scale_width(img, target_size, crop_size, method=transforms.InterpolationMode.BICUBIC):
136
+ method = __transforms2pil_resize(method)
137
+ ow, oh = img.size
138
+ if ow == target_size and oh >= crop_size:
139
+ return img
140
+ w = target_size
141
+ h = int(max(target_size * oh / ow, crop_size))
142
+ return img.resize((w, h), method)
143
+
144
+
145
+ def __crop(img, pos, size):
146
+ ow, oh = img.size
147
+ x1, y1 = pos
148
+ tw = th = size
149
+ if (ow > tw or oh > th):
150
+ return img.crop((x1, y1, x1 + tw, y1 + th))
151
+ return img
152
+
153
+
154
+ def __flip(img, flip):
155
+ if flip:
156
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
157
+ return img
158
+
159
+
160
+ def __print_size_warning(ow, oh, w, h):
161
+ """Print warning information about image size(only print once)"""
162
+ if not hasattr(__print_size_warning, 'has_printed'):
163
+ print("The image size needs to be a multiple of 4. "
164
+ "The loaded image size was (%d, %d), so it was adjusted to "
165
+ "(%d, %d). This adjustment will be done to all images "
166
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
167
+ __print_size_warning.has_printed = True
cyclegan_model/data/image_folder.py ADDED
@@ -0,0 +1,65 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+
12
+ IMG_EXTENSIONS = [
13
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
14
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
15
+ '.tif', '.TIF', '.tiff', '.TIFF',
16
+ ]
17
+
18
+
19
+ def is_image_file(filename):
20
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21
+
22
+
23
+ def make_dataset(dir, max_dataset_size=float("inf")):
24
+ images = []
25
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
26
+
27
+ for root, _, fnames in sorted(os.walk(dir)):
28
+ for fname in fnames:
29
+ if is_image_file(fname):
30
+ path = os.path.join(root, fname)
31
+ images.append(path)
32
+ return images[:min(max_dataset_size, len(images))]
33
+
34
+
35
+ def default_loader(path):
36
+ return Image.open(path).convert('RGB')
37
+
38
+
39
+ class ImageFolder(data.Dataset):
40
+
41
+ def __init__(self, root, transform=None, return_paths=False,
42
+ loader=default_loader):
43
+ imgs = make_dataset(root)
44
+ if len(imgs) == 0:
45
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
46
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
47
+
48
+ self.root = root
49
+ self.imgs = imgs
50
+ self.transform = transform
51
+ self.return_paths = return_paths
52
+ self.loader = loader
53
+
54
+ def __getitem__(self, index):
55
+ path = self.imgs[index]
56
+ img = self.loader(path)
57
+ if self.transform is not None:
58
+ img = self.transform(img)
59
+ if self.return_paths:
60
+ return img, path
61
+ else:
62
+ return img
63
+
64
+ def __len__(self):
65
+ return len(self.imgs)
cyclegan_model/data/single_dataset.py ADDED
@@ -0,0 +1,40 @@
1
+ from cyclegan_model.data.base_dataset import BaseDataset, get_transform
2
+ from cyclegan_model.data.image_folder import make_dataset
3
+ from PIL import Image
4
+
5
+
6
+ class SingleDataset(BaseDataset):
7
+ """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
8
+
9
+ It can be used for generating CycleGAN results only for one side with the model option '--model test'.
10
+ """
11
+
12
+ def __init__(self, opt):
13
+ """Initialize this dataset class.
14
+
15
+ Parameters:
16
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
17
+ """
18
+ BaseDataset.__init__(self, opt)
19
+ self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
20
+ input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
21
+ self.transform = get_transform(opt, grayscale=(input_nc == 1))
22
+
23
+ def __getitem__(self, index):
24
+ """Return a data point and its metadata information.
25
+
26
+ Parameters:
27
+ index - - a random integer for data indexing
28
+
29
+ Returns a dictionary that contains A and A_paths
30
+ A(tensor) - - an image in one domain
31
+ A_paths(str) - - the path of the image
32
+ """
33
+ A_path = self.A_paths[index]
34
+ A_img = Image.open(A_path).convert('RGB')
35
+ A = self.transform(A_img)
36
+ return {'A': A, 'A_paths': A_path}
37
+
38
+ def __len__(self):
39
+ """Return the total number of images in the dataset."""
40
+ return len(self.A_paths)
cyclegan_model/data/unaligned_dataset.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ from cyclegan_model.data.base_dataset import BaseDataset, get_transform
3
+ from cyclegan_model.data.image_folder import make_dataset
4
+ from PIL import Image
5
+ import random
6
+
7
+
8
+ class UnalignedDataset(BaseDataset):
9
+ """
10
+ This dataset class can load unaligned/unpaired datasets.
11
+
12
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
13
+ and from domain B '/path/to/data/trainB' respectively.
14
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
15
+ Similarly, you need to prepare two directories:
16
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
17
+ """
18
+
19
+ def __init__(self, opt):
20
+ """Initialize this dataset class.
21
+
22
+ Parameters:
23
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
24
+ """
25
+ BaseDataset.__init__(self, opt)
26
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
27
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
28
+
29
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
30
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
31
+ self.A_size = len(self.A_paths) # get the size of dataset A
32
+ self.B_size = len(self.B_paths) # get the size of dataset B
33
+ btoA = self.opt.direction == 'BtoA'
34
+ input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
35
+ output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
36
+ self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
37
+ self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))
38
+
39
+ def __getitem__(self, index):
40
+ """Return a data point and its metadata information.
41
+
42
+ Parameters:
43
+ index (int) -- a random integer for data indexing
44
+
45
+ Returns a dictionary that contains A, B, A_paths and B_paths
46
+ A (tensor) -- an image in the input domain
47
+ B (tensor) -- its corresponding image in the target domain
48
+ A_paths (str) -- image paths
49
+ B_paths (str) -- image paths
50
+ """
51
+ A_path = self.A_paths[index % self.A_size] # make sure index is within the range
52
+ if self.opt.serial_batches: # make sure index_B is within the range
53
+ index_B = index % self.B_size
54
+ else: # randomize the index for domain B to avoid fixed pairs.
55
+ index_B = random.randint(0, self.B_size - 1)
56
+ B_path = self.B_paths[index_B]
57
+ A_img = Image.open(A_path).convert('RGB')
58
+ B_img = Image.open(B_path).convert('RGB')
59
+ # apply image transformation
60
+ A = self.transform_A(A_img)
61
+ B = self.transform_B(B_img)
62
+
63
+ return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
64
+
65
+ def __len__(self):
66
+ """Return the total number of images in the dataset.
67
+
68
+ As we have two datasets with potentially different number of images,
69
+ we take the maximum of the two.
70
+ """
71
+ return max(self.A_size, self.B_size)
cyclegan_model/model/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from cyclegan_model.model.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "cyclegan_model.model." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class specified by opt.model.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+ """
60
+ model = find_model_using_name(opt.model)
61
+ instance = model(opt)
62
+ print("model [%s] was created" % type(instance).__name__)
63
+ return instance
cyclegan_model/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.18 kB).
 
cyclegan_model/model/__pycache__/base_model.cpython-310.pyc ADDED
Binary file (10 kB).
 
cyclegan_model/model/__pycache__/cycle_gan_model.cpython-310.pyc ADDED
Binary file (7.94 kB).
 
cyclegan_model/model/__pycache__/networks.cpython-310.pyc ADDED
Binary file (31.9 kB).
 
cyclegan_model/model/__pycache__/test_model.cpython-310.pyc ADDED
Binary file (3.13 kB).
 
cyclegan_model/model/base_model.py ADDED
@@ -0,0 +1,230 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+
7
+
8
+ class BaseModel(ABC):
9
+ """This class is an abstract base class (ABC) for models.
10
+ To create a subclass, you need to implement the following five functions:
11
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12
+ -- <set_input>: unpack data from dataset and apply preprocessing.
13
+ -- <forward>: produce intermediate results.
14
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16
+ """
17
+
18
+ def __init__(self, opt):
19
+ """Initialize the BaseModel class.
20
+
21
+ Parameters:
22
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23
+
24
+ When creating your custom class, you need to implement your own initialization.
25
+ In this function, you should first call <BaseModel.__init__(self, opt)>
26
+ Then, you need to define four lists:
27
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
28
+ -- self.model_names (str list): define networks used in our training.
29
+ -- self.visual_names (str list): specify the images that you want to display and save.
30
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31
+ """
32
+ self.opt = opt
33
+ self.gpu_ids = opt.gpu_ids
34
+ self.isTrain = opt.isTrain
35
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
+ torch.backends.cudnn.benchmark = True
39
+ self.loss_names = []
40
+ self.model_names = []
41
+ self.visual_names = []
42
+ self.optimizers = []
43
+ self.image_paths = []
44
+ self.metric = 0 # used for learning rate policy 'plateau'
45
+
46
+ @staticmethod
47
+ def modify_commandline_options(parser, is_train):
48
+ """Add new model-specific options, and rewrite default values for existing options.
49
+
50
+ Parameters:
51
+ parser -- original option parser
52
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
53
+
54
+ Returns:
55
+ the modified parser.
56
+ """
57
+ return parser
58
+
59
+ @abstractmethod
60
+ def set_input(self, input):
61
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
62
+
63
+ Parameters:
64
+ input (dict): includes the data itself and its metadata information.
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def forward(self):
70
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
71
+ pass
72
+
73
+ @abstractmethod
74
+ def optimize_parameters(self):
75
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
76
+ pass
77
+
78
+ def setup(self, opt):
79
+ """Load and print networks; create schedulers
80
+
81
+ Parameters:
82
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
83
+ """
84
+ if self.isTrain:
85
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86
+ if not self.isTrain or opt.continue_train:
87
+ load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
88
+ self.load_networks(load_suffix)
89
+ self.print_networks(opt.verbose)
90
+
91
+ def eval(self):
92
+ """Make models eval mode during test time"""
93
+ for name in self.model_names:
94
+ if isinstance(name, str):
95
+ net = getattr(self, 'net' + name)
96
+ net.eval()
97
+
98
+ def test(self):
99
+ """Forward function used in test time.
100
+
101
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
102
+ It also calls <compute_visuals> to produce additional visualization results
103
+ """
104
+ with torch.no_grad():
105
+ self.forward()
106
+ self.compute_visuals()
107
+
108
+ def compute_visuals(self):
109
+ """Calculate additional output images for visdom and HTML visualization"""
110
+ pass
111
+
112
+ def get_image_paths(self):
113
+ """ Return image paths that are used to load current data"""
114
+ return self.image_paths
115
+
116
+ def update_learning_rate(self):
117
+ """Update learning rates for all the networks; called at the end of every epoch"""
118
+ old_lr = self.optimizers[0].param_groups[0]['lr']
119
+ for scheduler in self.schedulers:
120
+ if self.opt.lr_policy == 'plateau':
121
+ scheduler.step(self.metric)
122
+ else:
123
+ scheduler.step()
124
+
125
+ lr = self.optimizers[0].param_groups[0]['lr']
126
+ print('learning rate %.7f -> %.7f' % (old_lr, lr))
127
+
128
+ def get_current_visuals(self):
129
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
130
+ visual_ret = OrderedDict()
131
+ for name in self.visual_names:
132
+ if isinstance(name, str):
133
+ visual_ret[name] = getattr(self, name)
134
+ return visual_ret
135
+
136
+ def get_current_losses(self):
137
+ """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
138
+ errors_ret = OrderedDict()
139
+ for name in self.loss_names:
140
+ if isinstance(name, str):
141
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
142
+ return errors_ret
143
+
144
+ def save_networks(self, epoch):
145
+ """Save all the networks to the disk.
146
+
147
+ Parameters:
148
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
149
+ """
150
+ for name in self.model_names:
151
+ if isinstance(name, str):
152
+ save_filename = '%s_net_%s.pth' % (epoch, name)
153
+ save_path = os.path.join(self.save_dir, save_filename)
154
+ net = getattr(self, 'net' + name)
155
+
156
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
157
+ torch.save(net.module.cpu().state_dict(), save_path)
158
+ net.cuda(self.gpu_ids[0])
159
+ else:
160
+ torch.save(net.cpu().state_dict(), save_path)
161
+
162
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
163
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
164
+ key = keys[i]
165
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
166
+ if module.__class__.__name__.startswith('InstanceNorm') and \
167
+ (key == 'running_mean' or key == 'running_var'):
168
+ if getattr(module, key) is None:
169
+ state_dict.pop('.'.join(keys))
170
+ if module.__class__.__name__.startswith('InstanceNorm') and \
171
+ (key == 'num_batches_tracked'):
172
+ state_dict.pop('.'.join(keys))
173
+ else:
174
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
175
+
176
+ def load_networks(self, epoch):
177
+ """Load all the networks from the disk.
178
+
179
+ Parameters:
180
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
181
+ """
182
+ for name in self.model_names:
183
+ if isinstance(name, str):
184
+ load_filename = '%s_net_%s.pth' % (epoch, name)
185
+ load_path = os.path.join(self.save_dir, load_filename)
186
+ net = getattr(self, 'net' + name)
187
+ if isinstance(net, torch.nn.DataParallel):
188
+ net = net.module
189
+ print('loading the model from %s' % load_path)
190
+ # if you are using PyTorch newer than 0.4 (e.g., built from
191
+ # GitHub source), you can remove str() on self.device
192
+ state_dict = torch.load(load_path, map_location=str(self.device))
193
+ if hasattr(state_dict, '_metadata'):
194
+ del state_dict._metadata
195
+
196
+ # patch InstanceNorm checkpoints prior to 0.4
197
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
198
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
199
+ net.load_state_dict(state_dict)
200
+
201
+ def print_networks(self, verbose):
202
+ """Print the total number of parameters in the network and (if verbose) network architecture
203
+
204
+ Parameters:
205
+ verbose (bool) -- if verbose: print the network architecture
206
+ """
207
+ print('---------- Networks initialized -------------')
208
+ for name in self.model_names:
209
+ if isinstance(name, str):
210
+ net = getattr(self, 'net' + name)
211
+ num_params = 0
212
+ for param in net.parameters():
213
+ num_params += param.numel()
214
+ if verbose:
215
+ print(net)
216
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
217
+ print('-----------------------------------------------')
218
+
219
+ def set_requires_grad(self, nets, requires_grad=False):
220
+ """Set requires_grad=False for all the networks to avoid unnecessary computations
221
+ Parameters:
222
+ nets (network list) -- a list of networks
223
+ requires_grad (bool) -- whether the networks require gradients or not
224
+ """
225
+ if not isinstance(nets, list):
226
+ nets = [nets]
227
+ for net in nets:
228
+ if net is not None:
229
+ for param in net.parameters():
230
+ param.requires_grad = requires_grad
cyclegan_model/model/cycle_gan_model.py ADDED
@@ -0,0 +1,229 @@
1
+ import torch
2
+ import itertools
3
+ from cyclegan_model.util.image_pool import ImagePool
4
+ from .base_model import BaseModel
5
+ from . import networks
6
+ from .networks import cal_gradient_penalty
7
+
8
+ class CycleGANModel(BaseModel):
9
+ """
10
+ This class implements the CycleGAN model, for learning image-to-image translation without paired data.
11
+
12
+ The model training requires '--dataset_mode unaligned' dataset.
13
+ By default, it uses a '--netG resnet_9blocks' ResNet generator,
14
+ a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
15
+ and a least-square GANs objective ('--gan_mode lsgan').
16
+
17
+ CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
18
+ """
19
+ @staticmethod
20
+ def modify_commandline_options(parser, is_train=True):
21
+ """Add new dataset-specific options, and rewrite default values for existing options.
22
+
23
+ Parameters:
24
+ parser -- original option parser
25
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
26
+
27
+ Returns:
28
+ the modified parser.
29
+
30
+ For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
31
+ A (source domain), B (target domain).
32
+ Generators: G_A: A -> B; G_B: B -> A.
33
+ Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
34
+ Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
35
+ Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
36
+ Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
37
+ Dropout is not used in the original CycleGAN paper.
38
+ """
39
+ parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout
40
+ if is_train:
41
+ parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
42
+ parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
43
+ parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
44
+
45
+ return parser
46
+
47
+ def __init__(self, opt):
48
+ """Initialize the CycleGAN class.
49
+
50
+ Parameters:
51
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
52
+ """
53
+ BaseModel.__init__(self, opt)
54
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
55
+ self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
56
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
57
+ visual_names_A = ['real_A', 'fake_B', 'rec_A']
58
+ visual_names_B = ['real_B', 'fake_A', 'rec_B']
59
+ if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
60
+ visual_names_A.append('idt_B')
61
+ visual_names_B.append('idt_A')
62
+
63
+ self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
64
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
65
+ if self.isTrain:
66
+ self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
67
+ else: # during test time, only load Gs
68
+ self.model_names = ['G_A', 'G_B']
69
+
70
+ # define networks (both Generators and discriminators)
71
+ # The naming is different from those used in the paper.
72
+ # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
73
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
74
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
75
+ self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
76
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
77
+
78
+ if self.isTrain: # define discriminators
79
+ self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
80
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
81
+ self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
82
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
83
+
84
+ if self.isTrain:
85
+ if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
86
+ assert(opt.input_nc == opt.output_nc)
87
+ self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
88
+ self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
89
+ # define loss functions
90
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
91
+ self.criterionCycle = torch.nn.L1Loss()
92
+ self.criterionIdt = torch.nn.L1Loss()
93
+
94
+
95
+
96
+ #self.criterionCycle = lpips.LPIPS(net='alex').to(self.device)
97
+ #self.criterionIdt = lpips.LPIPS(net='alex').to(self.device)
98
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
99
+ self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=0.0001, betas=(opt.beta1, 0.999))
100
+ self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=0.0003, betas=(opt.beta1, 0.999))
101
+ self.optimizers.append(self.optimizer_G)
102
+ self.optimizers.append(self.optimizer_D)
103
+
104
+ def set_input(self, input):
105
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
106
+
107
+ Parameters:
108
+ input (dict): include the data itself and its metadata information.
109
+
110
+ The option 'direction' can be used to swap domain A and domain B.
111
+ """
112
+ AtoB = self.opt.direction == 'AtoB'
113
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
114
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
115
+
116
+ def forward(self):
117
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
118
+ self.fake_B = self.netG_A(self.real_A) # G_A(A)
119
+ self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
120
+ self.fake_A = self.netG_B(self.real_B) # G_B(B)
121
+ self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
122
+
123
+ def backward_D_basic(self, netD, real, fake):
124
+ """Calculate GAN loss for the discriminator
125
+
126
+ Parameters:
127
+ netD (network) -- the discriminator D
128
+ real (tensor array) -- real images
129
+ fake (tensor array) -- images generated by a generator
130
+
131
+ Return the discriminator loss.
132
+ We also call loss_D.backward() to calculate the gradients.
133
+ """
134
+ # Real
135
+ pred_real = netD(real)
136
+ gradient_penalty, gradients = cal_gradient_penalty(netD, real, fake, self.device) # gradient penalty is computed here but not added to loss_D below
137
+
138
+
139
+ loss_D_real = self.criterionGAN(pred_real, True)
140
+ # Fake
141
+ pred_fake = netD(fake.detach())
142
+ loss_D_fake = self.criterionGAN(pred_fake, False)
143
+ # Combined loss and calculate gradients
144
+ loss_D = (loss_D_real + loss_D_fake) * 0.5
145
+ loss_D.backward()
146
+ return loss_D
147
+
148
+ def backward_D_A(self):
149
+ """Calculate GAN loss for discriminator D_A"""
150
+ fake_B = self.fake_B_pool.query(self.fake_B)
151
+ self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
152
+
153
+ def backward_D_B(self):
154
+ """Calculate GAN loss for discriminator D_B"""
155
+ fake_A = self.fake_A_pool.query(self.fake_A)
156
+ self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
157
+
158
+ def backward_G(self,retain_graph=False):
159
+ """Calculate the loss for generators G_A and G_B"""
160
+ lambda_idt = self.opt.lambda_identity
161
+ lambda_A = self.opt.lambda_A
162
+ lambda_B = self.opt.lambda_B
163
+ # Identity loss
164
+ if lambda_idt > 0:
165
+ # G_A should be identity if real_B is fed: ||G_A(B) - B||
166
+ self.idt_A = self.netG_A(self.real_B)
167
+ self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
168
+ # G_B should be identity if real_A is fed: ||G_B(A) - A||
169
+ self.idt_B = self.netG_B(self.real_A)
170
+ self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
171
+ else:
172
+ self.loss_idt_A = 0
173
+ self.loss_idt_B = 0
174
+
175
+ # GAN loss D_A(G_A(A))
176
+ self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
177
+ # GAN loss D_B(G_B(B))
178
+ self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
179
+ # Forward cycle loss || G_B(G_A(A)) - A||
180
+ self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
181
+ # Backward cycle loss || G_A(G_B(B)) - B||
182
+ self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
183
+ # combined loss and calculate gradients
184
+ self.loss_G = self.loss_G_A + self.loss_G_B + 0.7*self.loss_cycle_A +0.7*self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
185
+
186
+ self.loss_G.backward(retain_graph=retain_graph)
187
+ #self.loss_G.backward(retain_graph=True)
188
+
189
+ def optimize_parameters(self):
190
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
191
+ torch.autograd.set_detect_anomaly(True) # enable autograd anomaly detection for easier debugging
192
+ # forward
193
+ self.forward() # compute fake images and reconstruction images.
194
+ # G_A and G_B
195
+ self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
196
+ '''
197
+ self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
198
+ self.backward_G() # calculate gradients for G_A and G_B
199
+ self.optimizer_G.step() # update G_A and G_B's weights
200
+ '''
201
+ '''
202
+ self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
203
+ for i in range(2): # train the generators twice
204
+ self.backward_G(retain_graph=(i < 1)) # calculate gradients for G_A
205
+ self.optimizer_G.step() # update G_A and G_B's weights
206
+ # D_A and D_B
207
+ '''
208
+
209
+ self.optimizer_G.zero_grad()
210
+ self.backward_G(retain_graph=True) # keep the computation graph for a second backward pass
211
+ grad_cache_G = [p.grad.clone() for p in itertools.chain(self.netG_A.parameters(), self.netG_B.parameters())]
212
+
213
+ # second generator backward pass
214
+ self.optimizer_G.zero_grad()
215
+ self.backward_G(retain_graph=False)
216
+
217
+ # gradient fusion: add the gradients from the two passes together
218
+ for p, cache_g in zip(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), grad_cache_G):
219
+ if p.grad is not None:
220
+ p.grad += cache_g
221
+
222
+ # apply the parameter update
223
+ self.optimizer_G.step()
224
+
225
+ self.set_requires_grad([self.netD_A, self.netD_B], True)
226
+ self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
227
+ self.backward_D_A() # calculate gradients for D_A
228
+ self.backward_D_B() # calculate gradients for D_B
229
+ self.optimizer_D.step() # update D_A and D_B's weights
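The whole training step is driven through set_input and optimize_parameters. A rough training-loop sketch around this class (the opt object, the unaligned dataloader, and the epoch count are assumed inputs, not files in this upload; opt.isTrain is assumed True):

from cyclegan_model.model.cycle_gan_model import CycleGANModel

def train(opt, dataloader, n_epochs):
    model = CycleGANModel(opt)           # builds G_A, G_B, D_A, D_B and both optimizers
    model.setup(opt)                     # BaseModel.setup: load/print networks, create schedulers
    for epoch in range(n_epochs):
        for data in dataloader:          # each item is a dict with 'A' and 'B' image batches
            model.set_input(data)        # move the unaligned pair onto the model's device
            model.optimize_parameters()  # forward, two generator backward passes, then D_A/D_B step
        print(epoch, model.get_current_losses())  # losses listed in self.loss_names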
cyclegan_model/model/networks.py ADDED
@@ -0,0 +1,1091 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ import random
7
+ import torch.nn.functional as F
8
+
9
+ ###############################################################################
10
+ # Helper Functions
11
+ ###############################################################################
12
+
13
+
14
+ class Identity(nn.Module):
15
+ def forward(self, x):
16
+ return x
17
+
18
+
19
+ def get_norm_layer(norm_type='instance'):
20
+ """Return a normalization layer
21
+
22
+ Parameters:
23
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
24
+
25
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
26
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
27
+ """
28
+ if norm_type == 'batch':
29
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
30
+ elif norm_type == 'instance':
31
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
32
+ elif norm_type == 'none':
33
+ def norm_layer(x):
34
+ return Identity()
35
+ else:
36
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
37
+ return norm_layer
38
+
39
+
40
+ def get_scheduler(optimizer, opt):
41
+ """Return a learning rate scheduler
42
+
43
+ Parameters:
44
+ optimizer -- the optimizer of the network
45
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
46
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
47
+
48
+ For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
49
+ and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
50
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
51
+ See https://pytorch.org/docs/stable/optim.html for more details.
52
+ """
53
+ if opt.lr_policy == 'linear':
54
+ def lambda_rule(epoch):
55
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
56
+ return lr_l
57
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
58
+ elif opt.lr_policy == 'step':
59
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
60
+ elif opt.lr_policy == 'plateau':
61
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
62
+ elif opt.lr_policy == 'cosine':
63
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
64
+ else:
65
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
66
+ return scheduler
67
+
68
+
69
+ def init_weights(net, init_type='normal', init_gain=0.02):
70
+ """Initialize network weights.
71
+
72
+ Parameters:
73
+ net (network) -- network to be initialized
74
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
75
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
76
+
77
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
78
+ work better for some applications. Feel free to try yourself.
79
+ """
80
+ def init_func(m): # define the initialization function
81
+ classname = m.__class__.__name__
82
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
83
+ if init_type == 'normal':
84
+ init.normal_(m.weight.data, 0.0, init_gain)
85
+ elif init_type == 'xavier':
86
+ init.xavier_normal_(m.weight.data, gain=init_gain)
87
+ elif init_type == 'kaiming':
88
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
89
+ elif init_type == 'orthogonal':
90
+ init.orthogonal_(m.weight.data, gain=init_gain)
91
+ else:
92
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
93
+ if hasattr(m, 'bias') and m.bias is not None:
94
+ init.constant_(m.bias.data, 0.0)
95
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
96
+ init.normal_(m.weight.data, 1.0, init_gain)
97
+ init.constant_(m.bias.data, 0.0)
98
+
99
+ print('initialize network with %s' % init_type)
100
+ net.apply(init_func) # apply the initialization function <init_func>
101
+
102
+
103
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
104
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
105
+ Parameters:
106
+ net (network) -- the network to be initialized
107
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
108
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
109
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
110
+
111
+ Return an initialized network.
112
+ """
113
+ if len(gpu_ids) > 0:
114
+ assert(torch.cuda.is_available())
115
+ net.to(gpu_ids[0])
116
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
117
+ init_weights(net, init_type, init_gain=init_gain)
118
+ return net
119
+
120
+
121
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
122
+ """Create a generator
123
+
124
+ Parameters:
125
+ input_nc (int) -- the number of channels in input images
126
+ output_nc (int) -- the number of channels in output images
127
+ ngf (int) -- the number of filters in the last conv layer
128
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
129
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
130
+ use_dropout (bool) -- if use dropout layers.
131
+ init_type (str) -- the name of our initialization method.
132
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
133
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
134
+
135
+ Returns a generator
136
+
137
+ Our current implementation provides two types of generators:
138
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
139
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
140
+
141
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
142
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
143
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
144
+
145
+
146
+ The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
147
+ """
148
+ net = None
149
+ norm_layer = get_norm_layer(norm_type=norm)
150
+
151
+ if netG == 'resnet_9blocks':
152
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
153
+ elif netG == 'resnet_6blocks':
154
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
155
+ elif netG == 'unet_128':
156
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
157
+ elif netG == 'unet_256':
158
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
159
+ elif netG == 'resnet_attention':
160
+ net = ResnetGeneratorWithAttention(input_nc,output_nc,ngf,norm_layer=norm_layer,use_dropout=use_dropout,n_blocks=9)
161
+ elif netG == 'resnet_unet_attention':
162
+ net = Unet_SEA_ResnetGenerator(input_nc,output_nc, ngf, norm_layer, use_dropout)
163
+ elif netG == 'resnet_skip_attention':
164
+ net = ResnetGeneratorWithAttentionAndSkipConnection(input_nc,output_nc,ngf,norm_layer=norm_layer,use_dropout=use_dropout,n_blocks=12)
165
+ else:
166
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
167
+ return init_net(net, init_type, init_gain, gpu_ids)
168
+
169
+
170
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
171
+ """Create a discriminator
172
+
173
+ Parameters:
174
+ input_nc (int) -- the number of channels in input images
175
+ ndf (int) -- the number of filters in the first conv layer
176
+ netD (str) -- the architecture's name: basic | n_layers | pixel
177
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
178
+ norm (str) -- the type of normalization layers used in the network.
179
+ init_type (str) -- the name of the initialization method.
180
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
181
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
182
+
183
+ Returns a discriminator
184
+
185
+ Our current implementation provides three types of discriminators:
186
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
187
+ It can classify whether 70×70 overlapping patches are real or fake.
188
+ Such a patch-level discriminator architecture has fewer parameters
189
+ than a full-image discriminator and can work on arbitrarily-sized images
190
+ in a fully convolutional fashion.
191
+
192
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
193
+ with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
194
+
195
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
196
+ It encourages greater color diversity but has no effect on spatial statistics.
197
+
198
+ The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
199
+ """
200
+ net = None
201
+ norm_layer = get_norm_layer(norm_type=norm)
202
+
203
+ if netD == 'basic': # default PatchGAN classifier
204
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
205
+ elif netD == 'n_layers': # more options
206
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
207
+ elif netD == 'pixel': # classify if each pixel is real or fake
208
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
209
+ else:
210
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
211
+ return init_net(net, init_type, init_gain, gpu_ids)
212
+
213
+
214
+ ##############################################################################
215
+ # Classes
216
+ ##############################################################################
217
+ class GANLoss(nn.Module):
218
+ """Define different GAN objectives.
219
+
220
+ The GANLoss class abstracts away the need to create the target label tensor
221
+ that has the same size as the input.
222
+ """
223
+
224
+ def __init__(self, gan_mode):
225
+ super(GANLoss, self).__init__()
226
+ target_real_label = random.randint(7, 12) * 0.1
227
+ target_fake_label = random.randint(0, 3) * 0.1
228
+ self.register_buffer('real_label', torch.tensor(target_real_label))
229
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
230
+ self.gan_mode = gan_mode
231
+ if gan_mode == 'lsgan':
232
+ self.loss = nn.MSELoss()
233
+ elif gan_mode == 'vanilla':
234
+ self.loss = nn.BCEWithLogitsLoss()
235
+ elif gan_mode in ['wgangp']:
236
+ self.loss = None
237
+ else:
238
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
239
+ '''
240
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
241
+ """ Initialize the GANLoss class.
242
+
243
+ Parameters:
244
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
245
+ target_real_label (bool) - - label for a real image
246
+ target_fake_label (bool) - - label of a fake image
247
+
248
+ Note: Do not use sigmoid as the last layer of Discriminator.
249
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
250
+ """
251
+ super(GANLoss, self).__init__()
252
+ self.register_buffer('real_label', torch.tensor(target_real_label))
253
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
254
+ self.gan_mode = gan_mode
255
+ if gan_mode == 'lsgan':
256
+ self.loss = nn.MSELoss()
257
+ elif gan_mode == 'vanilla':
258
+ self.loss = nn.BCEWithLogitsLoss()
259
+ elif gan_mode in ['wgangp']:
260
+ self.loss = None
261
+ else:
262
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
263
+ '''
264
+ def get_target_tensor(self, prediction, target_is_real):
265
+ """Create label tensors with the same size as the input.
266
+
267
+ Parameters:
268
+ prediction (tensor) - - typically the prediction from a discriminator
269
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
270
+
271
+ Returns:
272
+ A label tensor filled with ground truth label, and with the size of the input
273
+ """
274
+
275
+ if target_is_real:
276
+ target_tensor = self.real_label
277
+ else:
278
+ target_tensor = self.fake_label
279
+ return target_tensor.expand_as(prediction)
280
+
281
+ def __call__(self, prediction, target_is_real):
282
+ """Calculate loss given Discriminator's output and ground truth labels.
283
+
284
+ Parameters:
285
+ prediction (tensor) - - typically the prediction output from a discriminator
286
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
287
+
288
+ Returns:
289
+ the calculated loss.
290
+ """
291
+ if self.gan_mode in ['lsgan', 'vanilla']:
292
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
293
+ loss = self.loss(prediction, target_tensor)
294
+ elif self.gan_mode == 'wgangp':
295
+ if target_is_real:
296
+ loss = -prediction.mean()
297
+ else:
298
+ loss = prediction.mean()
299
+ return loss
300
+
301
+
302
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
303
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
304
+
305
+ Arguments:
306
+ netD (network) -- discriminator network
307
+ real_data (tensor array) -- real images
308
+ fake_data (tensor array) -- generated images from the generator
309
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
310
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
311
+ constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
312
+ lambda_gp (float) -- weight for this loss
313
+
314
+ Returns the gradient penalty loss
315
+ """
316
+ if lambda_gp > 0.0:
317
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
318
+ interpolatesv = real_data
319
+ elif type == 'fake':
320
+ interpolatesv = fake_data
321
+ elif type == 'mixed':
322
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
323
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
324
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
325
+ else:
326
+ raise NotImplementedError('{} not implemented'.format(type))
327
+ interpolatesv.requires_grad_(True)
328
+ disc_interpolates = netD(interpolatesv)
329
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
330
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
331
+ create_graph=True, retain_graph=True, only_inputs=True)
332
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the gradients
333
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
334
+ return gradient_penalty, gradients
335
+ else:
336
+
337
+
338
+
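Note that backward_D_basic in cycle_gan_model.py calls this function but currently discards the returned penalty instead of adding it to loss_D. A small standalone check of the function itself (toy discriminator and random tensors, purely illustrative; nothing below ships in this commit):

import torch
import torch.nn as nn
from cyclegan_model.model.networks import cal_gradient_penalty

netD = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)  # toy PatchGAN-like discriminator
real = torch.randn(2, 3, 64, 64)
fake = torch.randn(2, 3, 64, 64)
gp, grads = cal_gradient_penalty(netD, real, fake, device='cpu', type='mixed', lambda_gp=10.0)
print(float(gp), grads.shape)  # grads is flattened to (batch, C*H*W)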
339
+ class ResnetGenerator(nn.Module):
340
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
341
+
342
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
343
+ """
344
+
345
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
346
+ """Construct a Resnet-based generator
347
+
348
+ Parameters:
349
+ input_nc (int) -- the number of channels in input images
350
+ output_nc (int) -- the number of channels in output images
351
+ ngf (int) -- the number of filters in the last conv layer
352
+ norm_layer -- normalization layer
353
+ use_dropout (bool) -- if use dropout layers
354
+ n_blocks (int) -- the number of ResNet blocks
355
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
356
+ """
357
+ assert(n_blocks >= 0)
358
+ super(ResnetGenerator, self).__init__()
359
+ if type(norm_layer) == functools.partial:
360
+ use_bias = norm_layer.func == nn.InstanceNorm2d
361
+ else:
362
+ use_bias = norm_layer == nn.InstanceNorm2d
363
+
364
+ model = [nn.ReflectionPad2d(3),
365
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
366
+ norm_layer(ngf),
367
+ nn.ReLU(True)]
368
+
369
+ n_downsampling = 2
370
+ for i in range(n_downsampling): # add downsampling layers
371
+ mult = 2 ** i
372
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
373
+ norm_layer(ngf * mult * 2),
374
+ nn.ReLU(True)]
375
+
376
+ mult = 2 ** n_downsampling
377
+ for i in range(n_blocks): # add ResNet blocks
378
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
379
+
380
+ for i in range(n_downsampling): # add upsampling layers
381
+ mult = 2 ** (n_downsampling - i)
382
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
383
+ kernel_size=3, stride=2,
384
+ padding=1, output_padding=1,
385
+ bias=use_bias),
386
+ norm_layer(int(ngf * mult / 2)),
387
+ nn.ReLU(True)]
388
+ model += [nn.ReflectionPad2d(3)]
389
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
390
+ model += [nn.Tanh()]
391
+
392
+ self.model = nn.Sequential(*model)
393
+
394
+ def forward(self, input):
395
+ """Standard forward"""
396
+ return self.model(input)
397
+
398
+
399
+ class ResnetBlock(nn.Module):
400
+ """Define a Resnet block"""
401
+
402
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
403
+ """Initialize the Resnet block
404
+
405
+ A resnet block is a conv block with skip connections
406
+ We construct a conv block with build_conv_block function,
407
+ and implement skip connections in <forward> function.
408
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
409
+ """
410
+ super(ResnetBlock, self).__init__()
411
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
412
+
413
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
414
+ """Construct a convolutional block.
415
+
416
+ Parameters:
417
+ dim (int) -- the number of channels in the conv layer.
418
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
419
+ norm_layer -- normalization layer
420
+ use_dropout (bool) -- if use dropout layers.
421
+ use_bias (bool) -- if the conv layer uses bias or not
422
+
423
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
424
+ """
425
+ conv_block = []
426
+ p = 0
427
+ if padding_type == 'reflect':
428
+ conv_block += [nn.ReflectionPad2d(1)]
429
+ elif padding_type == 'replicate':
430
+ conv_block += [nn.ReplicationPad2d(1)]
431
+ elif padding_type == 'zero':
432
+ p = 1
433
+ else:
434
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
435
+
436
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
437
+ if use_dropout:
438
+ conv_block += [nn.Dropout(0.5)]
439
+
440
+ p = 0
441
+ if padding_type == 'reflect':
442
+ conv_block += [nn.ReflectionPad2d(1)]
443
+ elif padding_type == 'replicate':
444
+ conv_block += [nn.ReplicationPad2d(1)]
445
+ elif padding_type == 'zero':
446
+ p = 1
447
+ else:
448
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
449
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
450
+
451
+ return nn.Sequential(*conv_block)
452
+
453
+ def forward(self, x):
454
+ """Forward function (with skip connections)"""
455
+ out = x + self.conv_block(x) # add skip connections
456
+ return out
457
+
458
+
459
+ class UnetGenerator(nn.Module):
460
+ """Create a Unet-based generator"""
461
+
462
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
463
+ """Construct a Unet generator
464
+ Parameters:
465
+ input_nc (int) -- the number of channels in input images
466
+ output_nc (int) -- the number of channels in output images
467
+ num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
468
+ an image of size 128x128 will become of size 1x1 at the bottleneck
469
+ ngf (int) -- the number of filters in the last conv layer
470
+ norm_layer -- normalization layer
471
+
472
+ We construct the U-Net from the innermost layer to the outermost layer.
473
+ It is a recursive process.
474
+ """
475
+ super(UnetGenerator, self).__init__()
476
+ # construct unet structure
477
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
478
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
479
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
480
+ # gradually reduce the number of filters from ngf * 8 to ngf
481
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
482
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
483
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
484
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
485
+
486
+ def forward(self, input):
487
+ """Standard forward"""
488
+ return self.model(input)
489
+
490
+
491
+ class UnetSkipConnectionBlock(nn.Module):
492
+ """Defines the Unet submodule with skip connection.
493
+ X -------------------identity----------------------
494
+ |-- downsampling -- |submodule| -- upsampling --|
495
+ """
496
+
497
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
498
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
499
+ """Construct a Unet submodule with skip connections.
500
+
501
+ Parameters:
502
+ outer_nc (int) -- the number of filters in the outer conv layer
503
+ inner_nc (int) -- the number of filters in the inner conv layer
504
+ input_nc (int) -- the number of channels in input images/features
505
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
506
+ outermost (bool) -- if this module is the outermost module
507
+ innermost (bool) -- if this module is the innermost module
508
+ norm_layer -- normalization layer
509
+ use_dropout (bool) -- if use dropout layers.
510
+ """
511
+ super(UnetSkipConnectionBlock, self).__init__()
512
+ self.outermost = outermost
513
+ if type(norm_layer) == functools.partial:
514
+ use_bias = norm_layer.func == nn.InstanceNorm2d
515
+ else:
516
+ use_bias = norm_layer == nn.InstanceNorm2d
517
+ if input_nc is None:
518
+ input_nc = outer_nc
519
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
520
+ stride=2, padding=1, bias=use_bias)
521
+ downrelu = nn.LeakyReLU(0.2, True)
522
+ downnorm = norm_layer(inner_nc)
523
+ uprelu = nn.ReLU(True)
524
+ upnorm = norm_layer(outer_nc)
525
+
526
+ if outermost:
527
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
528
+ kernel_size=4, stride=2,
529
+ padding=1)
530
+ down = [downconv]
531
+ up = [uprelu, upconv, nn.Tanh()]
532
+ model = down + [submodule] + up
533
+ elif innermost:
534
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
535
+ kernel_size=4, stride=2,
536
+ padding=1, bias=use_bias)
537
+ down = [downrelu, downconv]
538
+ up = [uprelu, upconv, upnorm]
539
+ model = down + up
540
+ else:
541
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
542
+ kernel_size=4, stride=2,
543
+ padding=1, bias=use_bias)
544
+ down = [downrelu, downconv, downnorm]
545
+ up = [uprelu, upconv, upnorm]
546
+
547
+ if use_dropout:
548
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
549
+ else:
550
+ model = down + [submodule] + up
551
+
552
+ self.model = nn.Sequential(*model)
553
+
554
+ def forward(self, x):
555
+ if self.outermost:
556
+ return self.model(x)
557
+ else: # add skip connections
558
+ return torch.cat([x, self.model(x)], 1)
559
+
560
+
561
+ class NLayerDiscriminator(nn.Module):
562
+ """Defines a PatchGAN discriminator"""
563
+
564
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
565
+ """Construct a PatchGAN discriminator
566
+
567
+ Parameters:
568
+ input_nc (int) -- the number of channels in input images
569
+ ndf (int) -- the number of filters in the last conv layer
570
+ n_layers (int) -- the number of conv layers in the discriminator
571
+ norm_layer -- normalization layer
572
+ """
573
+ super(NLayerDiscriminator, self).__init__()
574
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
575
+ use_bias = norm_layer.func == nn.InstanceNorm2d
576
+ else:
577
+ use_bias = norm_layer == nn.InstanceNorm2d
578
+
579
+ kw = 4
580
+ padw = 1
581
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
582
+ nf_mult = 1
583
+ nf_mult_prev = 1
584
+ for n in range(1, n_layers): # gradually increase the number of filters
585
+ nf_mult_prev = nf_mult
586
+ nf_mult = min(2 ** n, 8)
587
+ sequence += [
588
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
589
+ norm_layer(ndf * nf_mult),
590
+ nn.LeakyReLU(0.2, True)
591
+ ]
592
+
593
+ nf_mult_prev = nf_mult
594
+ nf_mult = min(2 ** n_layers, 8)
595
+ sequence += [
596
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
597
+ norm_layer(ndf * nf_mult),
598
+ nn.LeakyReLU(0.2, True)
599
+ ]
600
+
601
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
602
+ self.model = nn.Sequential(*sequence)
603
+
604
+ def forward(self, input):
605
+ """Standard forward."""
606
+ return self.model(input)
607
+
608
+
609
+ class PixelDiscriminator(nn.Module):
610
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
611
+
612
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
613
+ """Construct a 1x1 PatchGAN discriminator
614
+
615
+ Parameters:
616
+ input_nc (int) -- the number of channels in input images
617
+ ndf (int) -- the number of filters in the last conv layer
618
+ norm_layer -- normalization layer
619
+ """
620
+ super(PixelDiscriminator, self).__init__()
621
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
622
+ use_bias = norm_layer.func == nn.InstanceNorm2d
623
+ else:
624
+ use_bias = norm_layer == nn.InstanceNorm2d
625
+
626
+ self.net = [
627
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
628
+ nn.LeakyReLU(0.2, True),
629
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
630
+ norm_layer(ndf * 2),
631
+ nn.LeakyReLU(0.2, True),
632
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
633
+
634
+ self.net = nn.Sequential(*self.net)
635
+
636
+ def forward(self, input):
637
+ """Standard forward."""
638
+ return self.net(input)
639
+
640
+ class Self_Attention(nn.Module):
641
+
642
+ def __init__(self, in_dim, activation):
643
+ super(Self_Attention, self).__init__()
644
+ self.chanel_in = in_dim
645
+ self.activation = activation
646
+ ## query_conv, key_conv, and value_conv below correspond to Wg, Wf, and Wh in the self-attention formulation
647
+ self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) # a C^ x C projection
648
+ self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) # a C^ x C projection
649
+ self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) # a C x C projection
650
+ self.gamma = nn.Parameter(torch.zeros(1)) # learnable gamma that scales the attention output; initialized to 0
651
+
652
+ self.softmax = nn.Softmax(dim=-1)
653
+
654
+ def forward(self, x):
655
+ m_batchsize, C, width, height = x.size()
656
+ ## proj_query and proj_key below are each (C^ x C)(C x N) = C^ x N
657
+ proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B x N x C^; permute performs the transpose
658
+ proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)
659
+ energy = torch.bmm(proj_query, proj_key) # batched matrix product giving the attention energy
660
+ attention = self.softmax(energy) # BX (N) X (N)
661
+ proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
662
+
663
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
664
+ out = out.view(m_batchsize, C, width, height)
665
+
666
+ out = self.gamma * out + x
667
+ return out
668
+
669
+ class ResnetBlockWithAttention(nn.Module):
670
+ """Define a Resnet block"""
671
+
672
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
673
+ """Initialize the Resnet block
674
+
675
+ A resnet block is a conv block with skip connections
676
+ We construct a conv block with build_conv_block function,
677
+ and implement skip connections in <forward> function.
678
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
679
+ """
680
+ super(ResnetBlockWithAttention, self).__init__()
681
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
682
+ self.attention=Self_Attention(dim,'relu')
683
+
684
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
685
+ """Construct a convolutional block.
686
+
687
+ Parameters:
688
+ dim (int) -- the number of channels in the conv layer.
689
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
690
+ norm_layer -- normalization layer
691
+ use_dropout (bool) -- if use dropout layers.
692
+ use_bias (bool) -- if the conv layer uses bias or not
693
+
694
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
695
+ """
696
+ conv_block = []
697
+ p = 0
698
+ if padding_type == 'reflect':
699
+ conv_block += [nn.ReflectionPad2d(1)]
700
+ elif padding_type == 'replicate':
701
+ conv_block += [nn.ReplicationPad2d(1)]
702
+ elif padding_type == 'zero':
703
+ p = 1
704
+ else:
705
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
706
+
707
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
708
+ if use_dropout:
709
+ conv_block += [nn.Dropout(0.5)]
710
+
711
+ p = 0
712
+ if padding_type == 'reflect':
713
+ conv_block += [nn.ReflectionPad2d(1)]
714
+ elif padding_type == 'replicate':
715
+ conv_block += [nn.ReplicationPad2d(1)]
716
+ elif padding_type == 'zero':
717
+ p = 1
718
+ else:
719
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
720
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
721
+
722
+ return nn.Sequential(*conv_block)
723
+
724
+ def forward(self, x):
725
+ """Forward function (with skip connections)"""
726
+ out = x + self.conv_block(x) + self.attention(x) # add skip connections
727
+ return out
728
+
729
+
730
+ class ResnetGeneratorWithAttentionAndSkipConnection(nn.Module):
731
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
732
+ padding_type='reflect'):
733
+ super(ResnetGeneratorWithAttentionAndSkipConnection, self).__init__()
734
+ assert n_blocks >= 0
735
+ if type(norm_layer) == functools.partial:
736
+ use_bias = norm_layer.func == nn.InstanceNorm2d
737
+ else:
738
+ use_bias = norm_layer == nn.InstanceNorm2d
739
+
740
+ self.n_downsampling = n_downsampling = 2
741
+
742
+ # Initial convolution block
743
+ self.encoder_initial = nn.Sequential(
744
+ nn.ReflectionPad2d(3),
745
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
746
+ norm_layer(ngf),
747
+ nn.ReLU(True)
748
+ )
749
+
750
+ # Downsampling layers
751
+ self.encoder_down = nn.ModuleList()
752
+ for i in range(n_downsampling):
753
+ mult = 2 ** i
754
+ self.encoder_down.append(nn.Sequential(
755
+ nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
756
+ norm_layer(ngf * mult * 2),
757
+ nn.ReLU(True)
758
+ ))
759
+
760
+ # Middle ResNet blocks with attention
761
+ mult = 2 ** n_downsampling
762
+ self.middle = nn.Sequential()
763
+ for i in range(n_blocks):
764
+ self.middle.add_module(f'resnet_block_{i}', ResnetBlockWithAttention(
765
+ ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
766
+ use_dropout=use_dropout, use_bias=use_bias))
767
+
768
+ # Upsampling layers and skip connections
769
+ self.decoder_up = nn.ModuleList()
770
+ self.skip_convs = nn.ModuleList()
771
+ self.fusion_convs = nn.ModuleList()
772
+ self.fusion_norms = nn.ModuleList()
773
+ self.fusion_act = nn.ModuleList()
774
+
775
+ for j in range(n_downsampling):
776
+ # Upsampling layer
777
+ mult = 2 ** (n_downsampling - j)
778
+ self.decoder_up.append(nn.Sequential(
779
+ nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
780
+ kernel_size=3, stride=2,
781
+ padding=1, output_padding=1,
782
+ bias=use_bias),
783
+ norm_layer(int(ngf * mult / 2)),
784
+ nn.ReLU(True))
785
+ )
786
+
787
+ # Skip connection processing
788
+ encoder_layer_index = n_downsampling - j - 1
789
+ encoder_mult = 2 ** (encoder_layer_index + 1)
790
+ decoder_mult_out = int(ngf * mult / 2)
791
+
792
+ self.skip_convs.append(nn.Conv2d(ngf * encoder_mult, decoder_mult_out, kernel_size=1, bias=use_bias))
793
+
794
+ self.fusion_convs.append(
795
+ nn.Conv2d(decoder_mult_out * 2, decoder_mult_out, kernel_size=3, padding=1, bias=use_bias))
796
+ self.fusion_norms.append(norm_layer(decoder_mult_out))
797
+ self.fusion_act.append(nn.ReLU(True))
798
+
799
+ # Final output layer
800
+ self.output = nn.Sequential(
801
+ nn.ReflectionPad2d(3),
802
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
803
+ nn.Tanh()
804
+ )
805
+
806
+ def forward(self, input):
807
+ # Encoder
808
+ x = self.encoder_initial(input)
809
+ skips = []
810
+ for down_layer in self.encoder_down:
811
+ x = down_layer(x)
812
+ skips.append(x)
813
+
814
+ # Middle blocks
815
+ x = self.middle(x)
816
+
817
+ # Decoder with skip connections
818
+ for j in range(self.n_downsampling):
819
+ x = self.decoder_up[j](x)
820
+
821
+ # Get corresponding skip connection
822
+ skip_index = self.n_downsampling - j - 1
823
+ skip = skips[skip_index]
824
+
825
+ # Process skip connection
826
+ skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
827
+ skip = self.skip_convs[j](skip)
828
+
829
+ # Concatenate and fuse
830
+ x = torch.cat([x, skip], dim=1)
831
+ x = self.fusion_convs[j](x)
832
+ x = self.fusion_norms[j](x)
833
+ x = self.fusion_act[j](x)
834
+
835
+ # Final output
836
+ return self.output(x)
837
+
838
+
839
+ class ResnetGeneratorWithAttention(nn.Module):
840
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
841
+
842
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
843
+ """
844
+
845
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
846
+ """Construct a Resnet-based generator
847
+
848
+ Parameters:
849
+ input_nc (int) -- the number of channels in input images
850
+ output_nc (int) -- the number of channels in output images
851
+ ngf (int) -- the number of filters in the last conv layer
852
+ norm_layer -- normalization layer
853
+ use_dropout (bool) -- if use dropout layers
854
+ n_blocks (int) -- the number of ResNet blocks
855
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
856
+ """
857
+ assert(n_blocks >= 0)
858
+ super(ResnetGeneratorWithAttention, self).__init__()
859
+ if type(norm_layer) == functools.partial:
860
+ use_bias = norm_layer.func == nn.InstanceNorm2d
861
+ else:
862
+ use_bias = norm_layer == nn.InstanceNorm2d
863
+
864
+ model = [nn.ReflectionPad2d(3),
865
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
866
+ norm_layer(ngf),
867
+ nn.ReLU(True)]
868
+
869
+ n_downsampling = 2
870
+ for i in range(n_downsampling): # add downsampling layers
871
+ mult = 2 ** i
872
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
873
+ norm_layer(ngf * mult * 2),
874
+ nn.ReLU(True)]
875
+
876
+ mult = 2 ** n_downsampling
877
+ for i in range(n_blocks): # add ResNet blocks
878
+
879
+ model += [ResnetBlockWithAttention(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
880
+
881
+ for i in range(n_downsampling): # add upsampling layers
882
+ mult = 2 ** (n_downsampling - i)
883
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
884
+ kernel_size=3, stride=2,
885
+ padding=1, output_padding=1,
886
+ bias=use_bias),
887
+ norm_layer(int(ngf * mult / 2)),
888
+ nn.ReLU(True)]
889
+ model += [nn.ReflectionPad2d(3)]
890
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
891
+ model += [nn.Tanh()]
892
+
893
+ self.model = nn.Sequential(*model)
894
+
895
+ def forward(self, input):
896
+ """Standard forward"""
897
+ return self.model(input)
898
+
899
+
900
+ class Unet_SEA_ResnetGenerator(nn.Module):
901
+
902
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
903
+ padding_type='reflect'):
904
+ super(Unet_SEA_ResnetGenerator, self).__init__()
905
+ if type(norm_layer) == functools.partial:
906
+ use_bias = norm_layer.func == nn.InstanceNorm2d
907
+ else:
908
+ use_bias = norm_layer == nn.InstanceNorm2d
909
+ self.pad = nn.ReflectionPad2d(3)
910
+ self.Down_conv1 = nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias) # first downsampling layer
911
+ self.conv_norm = norm_layer(input_nc)
912
+ self.relu = nn.ReLU(True)
913
+ self.Down_conv2 = nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias) # second downsampling layer
914
+ self.SA = Self_Attention(ngf * 2, 'relu')
915
+ self.Down_conv3 = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias) # third downsampling layer
916
+ self.Sa_block_1 = SEA_ResnetBlock_1(ngf * 4, padding_type=padding_type, norm_layer=norm_layer,
917
+ use_dropout=use_dropout, use_bias=use_bias)
918
+ self.Sa_block_2 = SEA_ResnetBlock_1(ngf * 4, padding_type=padding_type, norm_layer=norm_layer,
919
+ use_dropout=use_dropout, use_bias=use_bias)
920
+ self.Sa_block_3 = SEA_ResnetBlock_1(ngf * 4, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
921
+ use_bias=use_bias)
922
+ self.Up_conv1 = nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1,
923
+ bias=use_bias)
924
+ self.Up_conv2 = nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1,
925
+ bias=use_bias)
926
+ self.Up_conv3 = nn.Conv2d(ngf * 2, output_nc, kernel_size=7, padding=0)
927
+ self.tan = nn.Tanh()
928
+
929
+ def forward(self, x):
930
+ x1 = self.relu(self.conv_norm(self.Down_conv1(self.pad(x))))
931
+ x2 = self.relu(self.conv_norm(self.Down_conv2(x1)))
932
+ x3 = self.relu(self.conv_norm(self.Down_conv3(x2)))
933
+ x4 = self.Sa_block_3(self.Sa_block_2(self.Sa_block_1(x3)))
934
+ x = torch.cat([x4, x3], 1)
935
+ x = self.relu(self.conv_norm(self.Up_conv1(x)))
936
+ x = torch.cat([x, x2], 1)
937
+ x = self.relu(self.conv_norm(self.Up_conv2(x)))
938
+ x = torch.cat([x, x1], 1)
939
+ x = self.tan(self.Up_conv3(self.pad(x)))
940
+ return x
941
+
942
+ class SEA_ResnetBlock_1(nn.Module):
943
+
944
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
945
+ super(SEA_ResnetBlock_1, self).__init__()
946
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
947
+ self.self_attention = Self_Attention(dim, 'relu')
948
+
949
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
950
+ conv_block = []
951
+ p = 0
952
+ if padding_type == 'reflect':
953
+ conv_block += [nn.ReflectionPad2d(1)]
954
+ elif padding_type == 'replicate':
955
+ conv_block += [nn.ReplicationPad2d(1)]
956
+ elif padding_type == 'zero':
957
+ p = 1
958
+ else:
959
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
960
+
961
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
962
+ if use_dropout:
963
+ conv_block += [nn.Dropout(0.5)]
964
+
965
+ p = 0
966
+ if padding_type == 'reflect':
967
+ conv_block += [nn.ReflectionPad2d(1)]
968
+ elif padding_type == 'replicate':
969
+ conv_block += [nn.ReplicationPad2d(1)]
970
+ elif padding_type == 'zero':
971
+ p = 1
972
+ else:
973
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
974
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
975
+
976
+ return nn.Sequential(*conv_block)
977
+
978
+ def forward(self, x):
979
+ out = self.self_attention(x) + self.conv_block(x) + x # add skip connections
980
+ return out
981
+
982
+
983
+ class Discriminator(nn.Module):
984
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
985
+ super(Discriminator, self).__init__()
986
+ # 256 x 256
987
+ self.conv1 = nn.Sequential(nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1),
988
+ nn.ELU(True),
989
+ norm_layer(ndf),
990
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
991
+ nn.ELU(True),
992
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
993
+ norm_layer(ndf),
994
+ nn.AvgPool2d(2,2))
995
+
996
+
997
+ # 128 x 128
998
+ self.conv2 = nn.Sequential(nn.Conv2d(ndf,2*ndf,kernel_size=3,stride=1,padding=1),
999
+ nn.ELU(True),
1000
+ norm_layer(ndf*2),
1001
+ nn.Conv2d(2*ndf,2*ndf,kernel_size=3,stride=1,padding=1),
1002
+ nn.ELU(True),
1003
+ norm_layer(ndf*2),
1004
+ nn.AvgPool2d(2,2))
1005
+ # 64 x 64
1006
+ self.conv3 = nn.Sequential(nn.Conv2d(2*ndf,3*ndf,kernel_size=3,stride=1,padding=1),
1007
+ nn.ELU(True),
1008
+ norm_layer(ndf*3),
1009
+ nn.Conv2d(3*ndf,3*ndf,kernel_size=3,stride=1,padding=1),
1010
+ nn.ELU(True),
1011
+ norm_layer(ndf*3),
1012
+ nn.AvgPool2d(2,2))
1013
+ # 32 x 32
1014
+ self.conv4 = nn.Sequential(nn.Conv2d(3*ndf,4*ndf,kernel_size=3,stride=1,padding=1),
1015
+ nn.ELU(True),
1016
+ norm_layer(ndf*4),
1017
+ nn.Conv2d(4*ndf,4*ndf,kernel_size=3,stride=1,padding=1),
1018
+ nn.ELU(True),
1019
+ norm_layer(ndf*4),
1020
+ nn.AvgPool2d(2,2))
1021
+ # 16 x 16
1022
+ self.conv5 = nn.Sequential(nn.Conv2d(4*ndf,5*ndf,kernel_size=3,stride=1,padding=1),
1023
+ nn.ELU(True),
1024
+ norm_layer(ndf*5),
1025
+ nn.Conv2d(5*ndf,5*ndf,kernel_size=3,stride=1,padding=1),
1026
+ nn.ELU(True),
1027
+ norm_layer(ndf*5),
1028
+ nn.AvgPool2d(2,2))
1029
+ # 8 x 8
1030
+
1031
+
1032
+ self.embed1 = nn.Linear(ndf * 5 * 8 * 8, 64)
1033
+ self.embed2 = nn.Linear(64, ndf * 8 * 8)
1034
+
1035
+ # 8 x 8
1036
+ self.deconv1 = nn.Sequential(nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1037
+ nn.ELU(True),
1038
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1039
+ nn.ELU(True),
1040
+ nn.Upsample(size=None,scale_factor=2,mode='bilinear',align_corners=False))
1041
+ # 16 x 16
1042
+ self.deconv2 = nn.Sequential(nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1043
+ nn.ELU(True),
1044
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1045
+ nn.ELU(True),
1046
+ nn.Upsample(size=None,scale_factor=2,mode='bilinear',align_corners=False))
1047
+ # 32 x 32
1048
+ self.deconv3 = nn.Sequential(nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1049
+ nn.ELU(True),
1050
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1051
+ nn.ELU(True),
1052
+ nn.Upsample(size=None,scale_factor=2,mode='bilinear',align_corners=False))
1053
+ # 64 x 64
1054
+ self.deconv4 = nn.Sequential(nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1055
+ nn.ELU(True),
1056
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1057
+ nn.ELU(True),
1058
+ nn.Upsample(size=None,scale_factor=2,mode='bilinear',align_corners=False))
1059
+ # 128 x 128
1060
+ self.deconv5 = nn.Sequential(nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1061
+ nn.ELU(True),
1062
+ nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1),
1063
+ nn.ELU(True),
1064
+ nn.Upsample(size=None,scale_factor=2,mode='bilinear',align_corners=False))
1065
+ # 256 x 256
1066
+ self.deconv6 = nn.Sequential(nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
1067
+ nn.ELU(True),
1068
+ nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
1069
+ nn.ELU(True),
1070
+ nn.Conv2d(ndf, input_nc, kernel_size=3, stride=1, padding=1),
1071
+ nn.Tanh())
1072
+ self.ndf = ndf
1073
+
1074
+ def forward(self, x):
1075
+ out = self.conv1(x)
1076
+ out = self.conv2(out)
1077
+ out = self.conv3(out)
1078
+ out = self.conv4(out)
1079
+ out = self.conv5(out)
1080
+ out = out.view(out.size(0), self.ndf * 5 * 8 * 8)
1081
+
1082
+ out = self.embed1(out)
1083
+ out = self.embed2(out)
1084
+ out = out.view(out.size(0), self.ndf, 8, 8)
1085
+ out = self.deconv1(out)
1086
+ out = self.deconv2(out)
1087
+ out = self.deconv3(out)
1088
+ out = self.deconv4(out)
1089
+ out = self.deconv5(out)
1090
+ out = self.deconv6(out)
1091
+ return out
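A minimal smoke test for the autoencoder-style Discriminator defined above (illustrative only, not part of the upload; it assumes the class is importable from cyclegan_model/model/networks.py and that inputs are 256x256, which the hard-coded 8x8 bottleneck requires):

import torch
import torch.nn as nn
from cyclegan_model.model.networks import Discriminator  # class defined in this file

netD = Discriminator(input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d)
x = torch.randn(1, 3, 256, 256)      # five AvgPool2d(2) stages reduce this to 8x8
with torch.no_grad():
    recon = netD(x)                  # encode -> 64-dim embedding -> decode -> Tanh
print(recon.shape)                   # torch.Size([1, 3, 256, 256])

Unlike a PatchGAN, this discriminator outputs a full-resolution reconstruction of its input rather than a grid of real/fake logits.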
cyclegan_model/model/test_model.py ADDED
@@ -0,0 +1,70 @@
1
+ from .base_model import BaseModel
2
+ from . import networks
3
+
4
+
5
+ class TestModel(BaseModel):
6
+ """ This TesteModel can be used to generate CycleGAN results for only one direction.
7
+ This model will automatically set '--dataset_mode single', which only loads the images from one collection.
8
+
9
+ See the test instruction for more details.
10
+ """
11
+ @staticmethod
12
+ def modify_commandline_options(parser, is_train=True):
13
+ """Add new dataset-specific options, and rewrite default values for existing options.
14
+
15
+ Parameters:
16
+ parser -- original option parser
17
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
18
+
19
+ Returns:
20
+ the modified parser.
21
+
22
+ The model can only be used during test time. It requires '--dataset_mode single'.
23
+ You need to specify the network using the option '--model_suffix'.
24
+ """
25
+ assert not is_train, 'TestModel cannot be used during training time'
26
+ parser.set_defaults(dataset_mode='single')
27
+ parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
28
+
29
+ return parser
30
+
31
+ def __init__(self, opt):
32
+ """Initialize the pix2pix class.
33
+
34
+ Parameters:
35
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
36
+ """
37
+ assert(not opt.isTrain)
38
+ BaseModel.__init__(self, opt)
39
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
40
+ self.loss_names = []
41
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
42
+ self.visual_names = ['real', 'fake']
43
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
44
+ self.model_names = ['G' + opt.model_suffix] # only generator is needed.
45
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
46
+ opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
47
+
48
+ # assigns the model to self.netG_[suffix] so that it can be loaded
49
+ # please see <BaseModel.load_networks>
50
+ setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
51
+
52
+ def set_input(self, input):
53
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
54
+
55
+ Parameters:
56
+ input: a dictionary that contains the data itself and its metadata information.
57
+
58
+ We need to use 'single_dataset' dataset mode. It only loads images from one domain.
59
+ """
60
+ self.real = input['A'].to(self.device)
61
+ self.image_paths = input['A_paths']
62
+
63
+ def forward(self):
64
+ """Run forward pass."""
65
+ self.fake = self.netG(self.real) # G(real)
66
+
67
+ def optimize_parameters(self):
68
+ """No optimization for test model."""
69
+ pass
70
+
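The file above only defines the model; the sketch below shows how it would typically be driven for single-direction inference. It is illustrative rather than part of the upload, and it assumes the create_model/create_dataset factories and the BaseModel.setup/test helpers follow the upstream CycleGAN codebase (cyclegan_model/model/__init__.py and cyclegan_model/data/__init__.py).

import sys
from cyclegan_model.options.test_options import TestOptions
from cyclegan_model.model import create_model      # assumed factory (upstream convention)
from cyclegan_model.data import create_dataset     # assumed factory (upstream convention)

sys.argv = ['test', '--dataroot', './data/content', '--model', 'test', '--no_dropout']
opt = TestOptions().parse()
opt.num_threads = 0        # usual test-time settings: single worker,
opt.batch_size = 1         # batch size 1, fixed order, no flipping
opt.serial_batches = True
opt.no_flip = True

dataset = create_dataset(opt)
model = create_model(opt)
model.setup(opt)           # loads [epoch]_net_G[model_suffix].pth from checkpoints_dir/name
if opt.eval:
    model.eval()

for i, data in enumerate(dataset):
    if i >= opt.num_test:
        break
    model.set_input(data)                    # expects {'A': tensor, 'A_paths': [...]}
    model.test()                             # forward() under torch.no_grad()
    visuals = model.get_current_visuals()    # {'real': ..., 'fake': ...}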
cyclegan_model/options/__init__.py ADDED
File without changes
cyclegan_model/options/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes). View file
 
cyclegan_model/options/__pycache__/base_options.cpython-310.pyc ADDED
Binary file (6.88 kB). View file
 
cyclegan_model/options/__pycache__/test_options.cpython-310.pyc ADDED
Binary file (1.15 kB). View file
 
cyclegan_model/options/base_options.py ADDED
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ import os
3
+ from cyclegan_model.util import util
4
+ import torch
5
+ import cyclegan_model.model as model
6
+ import cyclegan_model.data as data
7
+
8
+
9
+ class BaseOptions():
10
+ """This class defines options used during both training and test time.
11
+
12
+ It also implements several helper functions such as parsing, printing, and saving the options.
13
+ It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
14
+ """
15
+
16
+ def __init__(self):
17
+ """Reset the class; indicates the class hasn't been initailized"""
18
+ self.initialized = False
19
+
20
+ def initialize(self, parser):
21
+ """Define the common options that are used in both training and test."""
22
+ # basic parameters
23
+ parser.add_argument('--dataroot',type=str,help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
24
+ parser.add_argument('--name', type=str, default='ostracoda_cyclegan', help='name of the experiment. It decides where to store samples and models')
25
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
26
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
27
+ # model parameters
28
+ parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
29
+ parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
30
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
31
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
32
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
33
+ parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
34
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
35
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
36
+ parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
37
+ parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
38
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
39
+ parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
40
+ # dataset parameters
41
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
42
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
43
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
44
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
45
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
46
+ parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
47
+ parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
48
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
49
+ parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
50
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
51
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
52
+ # additional parameters
53
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
54
+ parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
55
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
56
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
57
+ # wandb parameters
58
+ parser.add_argument('--use_wandb', action='store_true', help='if specified, then init wandb logging')
59
+ parser.add_argument('--wandb_project_name', type=str, default='CycleGAN-and-pix2pix', help='specify wandb project name')
60
+ self.initialized = True
61
+ return parser
62
+
63
+ def gather_options(self):
64
+ """Initialize our parser with basic options(only once).
65
+ Add additional model-specific and dataset-specific options.
66
+ These options are defined in the <modify_commandline_options> function
67
+ in model and dataset classes.
68
+ """
69
+ if not self.initialized: # check if it has been initialized
70
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
71
+ parser = self.initialize(parser)
72
+
73
+ # get the basic options
74
+ opt, _ = parser.parse_known_args()
75
+
76
+ # modify model-related parser options
77
+ model_name = opt.model
78
+ model_option_setter = model.get_option_setter(model_name)
79
+ parser = model_option_setter(parser, self.isTrain)
80
+ opt, _ = parser.parse_known_args() # parse again with new defaults
81
+
82
+ # modify dataset-related parser options
83
+ dataset_name = opt.dataset_mode
84
+ dataset_option_setter = data.get_option_setter(dataset_name)
85
+ parser = dataset_option_setter(parser, self.isTrain)
86
+
87
+ # save and return the parser
88
+ self.parser = parser
89
+ return parser.parse_args()
90
+
91
+ def print_options(self, opt):
92
+ """Print and save options
93
+
94
+ It will print both current options and default values (if different).
95
+ It will save options into a text file / [checkpoints_dir] / opt.txt
96
+ """
97
+ message = ''
98
+ message += '----------------- Options ---------------\n'
99
+ for k, v in sorted(vars(opt).items()):
100
+ comment = ''
101
+ default = self.parser.get_default(k)
102
+ if v != default:
103
+ comment = '\t[default: %s]' % str(default)
104
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
105
+ message += '----------------- End -------------------'
106
+ print(message)
107
+
108
+ # save to the disk
109
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
110
+ util.mkdirs(expr_dir)
111
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
112
+ with open(file_name, 'wt') as opt_file:
113
+ opt_file.write(message)
114
+ opt_file.write('\n')
115
+
116
+ def parse(self):
117
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
118
+ opt = self.gather_options()
119
+ opt.isTrain = self.isTrain # train or test
120
+
121
+ # process opt.suffix
122
+ if opt.suffix:
123
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
124
+ opt.name = opt.name + suffix
125
+
126
+
127
+ # set gpu ids
128
+ str_ids = opt.gpu_ids.split(',')
129
+ opt.gpu_ids = []
130
+ for str_id in str_ids:
131
+ id = int(str_id)
132
+ if id >= 0:
133
+ opt.gpu_ids.append(id)
134
+ if len(opt.gpu_ids) > 0:
135
+ torch.cuda.set_device(opt.gpu_ids[0])
136
+
137
+ self.opt = opt
138
+ return self.opt
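As a small worked example of the --suffix handling in parse() above (illustrative values only): the template is expanded with the parsed options and appended to the experiment name.

opts = {'name': 'ostracoda_cyclegan', 'model': 'cycle_gan',
        'netG': 'resnet_9blocks', 'load_size': 286,
        'suffix': '{model}_{netG}_size{load_size}'}
name = opts['name'] + ('_' + opts['suffix'].format(**opts) if opts['suffix'] else '')
print(name)   # -> ostracoda_cyclegan_cycle_gan_resnet_9blocks_size286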
cyclegan_model/options/test_options.py ADDED
@@ -0,0 +1,24 @@
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TestOptions(BaseOptions):
5
+ """This class includes test options.
6
+
7
+ It also includes shared options defined in BaseOptions.
8
+ """
9
+
10
+ def initialize(self, parser):
11
+ parser = BaseOptions.initialize(self, parser) # define shared options
12
+ parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
13
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
14
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
15
+ # Dropout and BatchNorm behave differently during training and test.
16
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
17
+ parser.add_argument('--num_test', type=int, default=100, help='how many test images to run')
18
+
19
+ # rewrite default values
20
+ # parser.set_defaults(model='test')
21
+ # To avoid cropping, the load_size should be the same as crop_size
22
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
23
+ self.isTrain = False
24
+ return parser
cyclegan_model/util/__init__.py ADDED
File without changes
cyclegan_model/util/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (166 Bytes). View file
 
cyclegan_model/util/__pycache__/image_pool.cpython-310.pyc ADDED
Binary file (1.81 kB). View file
 
cyclegan_model/util/__pycache__/util.cpython-310.pyc ADDED
Binary file (3.2 kB). View file
 
cyclegan_model/util/image_pool.py ADDED
@@ -0,0 +1,54 @@
1
+ import random
2
+ import torch
3
+
4
+
5
+ class ImagePool():
6
+ """This class implements an image buffer that stores previously generated images.
7
+
8
+ This buffer enables us to update discriminators using a history of generated images
9
+ rather than the ones produced by the latest generators.
10
+ """
11
+
12
+ def __init__(self, pool_size):
13
+ """Initialize the ImagePool class
14
+
15
+ Parameters:
16
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17
+ """
18
+ self.pool_size = pool_size
19
+ if self.pool_size > 0: # create an empty pool
20
+ self.num_imgs = 0
21
+ self.images = []
22
+
23
+ def query(self, images):
24
+ """Return an image from the pool.
25
+
26
+ Parameters:
27
+ images: the latest generated images from the generator
28
+
29
+ Returns images from the buffer.
30
+
31
+ By 50/100, the buffer will return input images.
32
+ By 50/100, the buffer will return images previously stored in the buffer,
33
+ and insert the current images to the buffer.
34
+ """
35
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
36
+ return images
37
+ return_images = []
38
+ for image in images:
39
+ image = torch.unsqueeze(image.data, 0)
40
+ if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
41
+ self.num_imgs = self.num_imgs + 1
42
+ self.images.append(image)
43
+ return_images.append(image)
44
+ else:
45
+ p = random.uniform(0, 1)
46
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48
+ tmp = self.images[random_id].clone()
49
+ self.images[random_id] = image
50
+ return_images.append(tmp)
51
+ else: # by another 50% chance, the buffer will return the current image
52
+ return_images.append(image)
53
+ return_images = torch.cat(return_images, 0) # collect all the images and return
54
+ return return_images
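A short, illustrative exercise of the ImagePool above (not part of the upload): with a pool of size 2, the first two queries fill the buffer and return the inputs unchanged, while later queries return either the fresh image or a stored one with equal probability.

import torch
from cyclegan_model.util.image_pool import ImagePool

pool = ImagePool(pool_size=2)
for step in range(4):
    fake = torch.full((1, 3, 8, 8), float(step))   # stand-in for a generator output
    out = pool.query(fake)                         # same batch shape as the input
    print(step, out.unique())                      # from step 2 on, may show an older value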
cyclegan_model/util/util.py ADDED
@@ -0,0 +1,103 @@
1
+ """This module contains simple helper functions """
2
+ from __future__ import print_function
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+
8
+
9
+ def tensor2im(input_image, imtype=np.uint8):
10
+ """"Converts a Tensor array into a numpy image array.
11
+
12
+ Parameters:
13
+ input_image (tensor) -- the input image tensor array
14
+ imtype (type) -- the desired type of the converted numpy array
15
+ """
16
+ if not isinstance(input_image, np.ndarray):
17
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
18
+ image_tensor = input_image.data
19
+ else:
20
+ return input_image
21
+ image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
22
+ if image_numpy.shape[0] == 1: # grayscale to RGB
23
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
24
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
25
+ else: # if it is a numpy array, do nothing
26
+ image_numpy = input_image
27
+ return image_numpy.astype(imtype)
28
+
29
+
30
+ def diagnose_network(net, name='network'):
31
+ """Calculate and print the mean of average absolute(gradients)
32
+
33
+ Parameters:
34
+ net (torch network) -- Torch network
35
+ name (str) -- the name of the network
36
+ """
37
+ mean = 0.0
38
+ count = 0
39
+ for param in net.parameters():
40
+ if param.grad is not None:
41
+ mean += torch.mean(torch.abs(param.grad.data))
42
+ count += 1
43
+ if count > 0:
44
+ mean = mean / count
45
+ print(name)
46
+ print(mean)
47
+
48
+
49
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
50
+ """Save a numpy image to the disk
51
+
52
+ Parameters:
53
+ image_numpy (numpy array) -- input numpy array
54
+ image_path (str) -- the path of the image
55
+ """
56
+
57
+ image_pil = Image.fromarray(image_numpy)
58
+ h, w, _ = image_numpy.shape
59
+
60
+ if aspect_ratio > 1.0:
61
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
62
+ if aspect_ratio < 1.0:
63
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
64
+ image_pil.save(image_path)
65
+
66
+
67
+ def print_numpy(x, val=True, shp=False):
68
+ """Print the mean, min, max, median, std, and size of a numpy array
69
+
70
+ Parameters:
71
+ val (bool) -- if print the values of the numpy array
72
+ shp (bool) -- if print the shape of the numpy array
73
+ """
74
+ x = x.astype(np.float64)
75
+ if shp:
76
+ print('shape,', x.shape)
77
+ if val:
78
+ x = x.flatten()
79
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
80
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
81
+
82
+
83
+ def mkdirs(paths):
84
+ """create empty directories if they don't exist
85
+
86
+ Parameters:
87
+ paths (str list) -- a list of directory paths
88
+ """
89
+ if isinstance(paths, list) and not isinstance(paths, str):
90
+ for path in paths:
91
+ mkdir(path)
92
+ else:
93
+ mkdir(paths)
94
+
95
+
96
+ def mkdir(path):
97
+ """create a single empty directory if it didn't exist
98
+
99
+ Parameters:
100
+ path (str) -- a single directory path
101
+ """
102
+ if not os.path.exists(path):
103
+ os.makedirs(path)
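An illustrative round trip through the helpers above (assumes a [-1, 1] tensor, as produced by the Tanh generators; not part of the upload):

import torch
from cyclegan_model.util import util

fake = torch.rand(1, 3, 64, 64) * 2 - 1           # pretend generator output in [-1, 1]
img = util.tensor2im(fake)                        # HxWx3 uint8 array in [0, 255]
util.mkdirs('./results/demo')
util.save_image(img, './results/demo/fake.png')   # aspect_ratio defaults to 1.0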
data/content/12.jpg ADDED
data/content/15.jpg ADDED
data/content/27032.jpg ADDED
data/style/13.png ADDED
data/style/2.jpg ADDED
data/style/9.jpg ADDED
flux_ad/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.07 kB). View file
 
flux_ad/main.ipynb ADDED
@@ -0,0 +1,96 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Style-specific T2I Generation with Flux.1-dev"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
18
+ "from accelerate.utils import set_seed\n",
19
+ "from mypipeline import AttDistPipeline\n",
20
+ "from utils import *\n",
21
+ "\n",
22
+ "\n",
23
+ "model_name = \"/root/models/FLUX.1-dev\"\n",
24
+ "lr = 0.01\n",
25
+ "iters = 2\n",
26
+ "seed = 42\n",
27
+ "width = 512\n",
28
+ "height = 512\n",
29
+ "mixed_precision = \"bf16\"\n",
30
+ "num_inference_steps = 50\n",
31
+ "guidance_scale = 3.5\n",
32
+ "enable_gradient_checkpoint = True\n",
33
+ "start_layer, end_layer = 50, 57\n",
34
+ "start_time = 9999\n",
35
+ "prompt=\"A panda\"\n",
36
+ "\n",
37
+ "\n",
38
+ "pipe = AttDistPipeline.from_pretrained(\n",
39
+ " model_name, torch_dtype=torch.float16)\n",
40
+ "\n",
41
+ "\n",
42
+ "memory_efficient(pipe)\n",
43
+ "set_seed(seed)\n",
44
+ "loss_fn = torch.nn.L1Loss()\n",
45
+ "\n",
46
+ "style_image = [\"..//data/style/1.jpg\"]\n",
47
+ "style_image = torch.cat([load_image(path, size=(512, 512)) for path in style_image])\n",
48
+ "\n",
49
+ "\n",
50
+ "controller = Controller(self_layers=(start_layer, end_layer))\n",
51
+ "\n",
52
+ "result = pipe.sample(\n",
53
+ " lr=lr,\n",
54
+ " prompt=prompt,\n",
55
+ " loss_fn=loss_fn,\n",
56
+ " iters=iters,\n",
57
+ " width=width,\n",
58
+ " height=height,\n",
59
+ " start_time=start_time,\n",
60
+ " controller=controller,\n",
61
+ " style_image=style_image,\n",
62
+ " guidance_scale=guidance_scale,\n",
63
+ " mixed_precision=mixed_precision,\n",
64
+ " num_inference_steps=num_inference_steps,\n",
65
+ " enable_gradient_checkpoint=enable_gradient_checkpoint,\n",
66
+ ")\n",
67
+ "\n",
68
+ "save_image(style_image, \"style.png\")\n",
69
+ "save_image(result, \"output.png\")\n",
70
+ "show_image(\"style.png\", title=\"style image\")\n",
71
+ "show_image(\"output.png\", title=prompt)"
72
+ ]
73
+ }
74
+ ],
75
+ "metadata": {
76
+ "kernelspec": {
77
+ "display_name": "ad",
78
+ "language": "python",
79
+ "name": "python3"
80
+ },
81
+ "language_info": {
82
+ "codemirror_mode": {
83
+ "name": "ipython",
84
+ "version": 3
85
+ },
86
+ "file_extension": ".py",
87
+ "mimetype": "text/x-python",
88
+ "name": "python",
89
+ "nbconvert_exporter": "python",
90
+ "pygments_lexer": "ipython3",
91
+ "version": "3.10.16"
92
+ }
93
+ },
94
+ "nbformat": 4,
95
+ "nbformat_minor": 2
96
+ }
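For orientation (editorial note, not part of the notebook): register_attn_control in flux_ad/utils.py counts attention layers in the order the blocks run, double-stream first and then single-stream. With FLUX.1-dev's 19 double-stream and 38 single-stream transformer blocks, the range (start_layer, end_layer) = (50, 57) used above should therefore select the last seven single-stream attention layers. A quick sanity check, assuming pipe is the loaded pipeline:

n_double = len(pipe.transformer.transformer_blocks)          # 19 in FLUX.1-dev
n_single = len(pipe.transformer.single_transformer_blocks)   # 38 in FLUX.1-dev
print(n_double, n_single, n_double + n_single)               # 19 38 57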
flux_ad/mypipeline.py ADDED
@@ -0,0 +1,381 @@
1
+ import inspect
2
+ from accelerate import Accelerator
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import utils
9
+ from diffusers import FluxPipeline
10
+
11
+
12
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
13
+ def retrieve_latents(
14
+ encoder_output: torch.Tensor,
15
+ generator: Optional[torch.Generator] = None,
16
+ sample_mode: str = "sample",
17
+ ):
18
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
19
+ return encoder_output.latent_dist.sample(generator)
20
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
21
+ return encoder_output.latent_dist.mode()
22
+ elif hasattr(encoder_output, "latents"):
23
+ return encoder_output.latents
24
+ else:
25
+ raise AttributeError("Could not access latents of provided encoder_output")
26
+
27
+
28
+ class AttDistPipeline(FluxPipeline):
29
+
30
+ def freeze(self):
31
+ self.transformer.requires_grad_(False)
32
+ self.text_encoder.requires_grad_(False)
33
+ self.text_encoder_2.requires_grad_(False)
34
+ self.vae.requires_grad_(False)
35
+
36
+ @torch.no_grad()
37
+ def image2latent(self, image):
38
+ dtype = next(self.vae.parameters()).dtype
39
+ device = self._execution_device
40
+ image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
41
+ latent = retrieve_latents(self.vae.encode(image))
42
+ latent = (
43
+ latent - self.vae.config.shift_factor
44
+ ) * self.vae.config.scaling_factor
45
+ return latent
46
+
47
+ @torch.no_grad()
48
+ def latent2image(self, latent, height, width):
49
+ dtype = next(self.vae.parameters()).dtype
50
+ device = self._execution_device
51
+ latent = latent.to(device=device, dtype=dtype)
52
+ latents = self._unpack_latents(latent, height, width, self.vae_scale_factor)
53
+ latents = (
54
+ latents / self.vae.config.scaling_factor
55
+ ) + self.vae.config.shift_factor
56
+ image = self.vae.decode(latents, return_dict=False)[0]
57
+ return (image * 0.5 + 0.5).clamp(0, 1)
58
+
59
+ def sample(
60
+ self,
61
+ style_image=None,
62
+ controller=None,
63
+ loss_fn=None,
64
+ start_time=9999,
65
+ lr=0.05,
66
+ iters=2,
67
+ mixed_precision="no",
68
+ enable_gradient_checkpoint=False,
69
+ prompt: Union[str, List[str]] = None,
70
+ prompt_2: Optional[Union[str, List[str]]] = None,
71
+ height: Optional[int] = None,
72
+ width: Optional[int] = None,
73
+ num_inference_steps: int = 28,
74
+ # timesteps: List[int] = None,
75
+ guidance_scale: float = 3.5,
76
+ num_images_per_prompt: Optional[int] = 1,
77
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
78
+ latents: Optional[torch.FloatTensor] = None,
79
+ prompt_embeds: Optional[torch.FloatTensor] = None,
80
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
81
+ output_type: Optional[str] = "pil",
82
+ return_dict: bool = True,
83
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
84
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
85
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
86
+ max_sequence_length: int = 512,
87
+ ):
88
+ height = height or self.default_sample_size * self.vae_scale_factor
89
+ width = width or self.default_sample_size * self.vae_scale_factor
90
+ device = self._execution_device
91
+ accelerator = Accelerator(
92
+ mixed_precision=mixed_precision, gradient_accumulation_steps=1
93
+ )
94
+ weight_dtype = torch.float32
95
+ if accelerator.mixed_precision == "fp16":
96
+ weight_dtype = torch.float16
97
+ elif accelerator.mixed_precision == "bf16":
98
+ weight_dtype = torch.bfloat16
99
+
100
+ # Move the transformer, VAE and text encoders to the device and cast them to weight_dtype
101
+ self.transformer.to(accelerator.device, dtype=weight_dtype)
102
+ self.vae.to(accelerator.device, dtype=weight_dtype)
103
+ self.text_encoder.to(accelerator.device, dtype=weight_dtype)
104
+ self.text_encoder_2.to(accelerator.device, dtype=weight_dtype)
105
+ self.transformer = accelerator.prepare(self.transformer)
106
+ if enable_gradient_checkpoint:
107
+ self.transformer.enable_gradient_checkpointing()
108
+ # self.transformer.train()
109
+
110
+ (null_embeds, null_pooled_embeds, null_text_ids) = self.encode_prompt(
111
+ prompt="",
112
+ prompt_2=prompt_2,
113
+ )
114
+ (
115
+ prompt_embeds,
116
+ pooled_prompt_embeds,
117
+ text_ids,
118
+ ) = self.encode_prompt(
119
+ prompt=prompt,
120
+ prompt_2=prompt_2,
121
+ prompt_embeds=prompt_embeds,
122
+ pooled_prompt_embeds=pooled_prompt_embeds,
123
+ device=device,
124
+ num_images_per_prompt=num_images_per_prompt,
125
+ max_sequence_length=max_sequence_length,
126
+ )
127
+ # 4. Prepare latent variables
128
+ num_channels_latents = self.transformer.config.in_channels // 4
129
+ latents, latent_image_ids = self.prepare_latents(
130
+ num_images_per_prompt,
131
+ num_channels_latents,
132
+ height,
133
+ width,
134
+ null_embeds.dtype,
135
+ device,
136
+ generator,
137
+ latents,
138
+ )
139
+
140
+ print(style_image.shape)
141
+ height_, width_ = style_image.shape[2], style_image.shape[3]
142
+ style_latent = self.image2latent(style_image)
143
+ print(style_latent.shape)
144
+ print(latents.shape)
145
+ style_latent = self._pack_latents(style_latent, 1, num_channels_latents, style_latent.shape[2], style_latent.shape[3])
146
+
147
+ _, null_image_id = self.prepare_latents(
148
+ num_images_per_prompt,
149
+ num_channels_latents,
150
+ height_,
151
+ width_,
152
+ null_embeds.dtype,
153
+ device,
154
+ generator,
155
+ style_latent,
156
+ )
157
+
158
+ # 5. Prepare timesteps
159
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
160
+ image_seq_len = latents.shape[1]
161
+ mu = calculate_shift(
162
+ image_seq_len,
163
+ self.scheduler.config.base_image_seq_len,
164
+ self.scheduler.config.max_image_seq_len,
165
+ self.scheduler.config.base_shift,
166
+ self.scheduler.config.max_shift,
167
+ )
168
+ timesteps, num_inference_steps = retrieve_timesteps(
169
+ self.scheduler,
170
+ num_inference_steps,
171
+ device,
172
+ None,
173
+ sigmas,
174
+ mu=mu,
175
+ )
176
+
177
+ timesteps = self.scheduler.timesteps
178
+ print(f"timesteps: {timesteps}")
179
+ self._num_timesteps = len(timesteps)
180
+
181
+ cache = utils.DataCache()
182
+
183
+ utils.register_attn_control(
184
+ self.transformer.transformer_blocks,
185
+ controller=controller,
186
+ cache=cache,
187
+ )
188
+ utils.register_attn_control(
189
+ self.transformer.single_transformer_blocks,
190
+ controller=controller,
191
+ cache=cache,
192
+ )
193
+ # handle guidance
194
+ if self.transformer.config.guidance_embeds:
195
+ guidance = torch.full(
196
+ [1], guidance_scale, device=device, dtype=torch.float32
197
+ )
198
+ guidance = guidance.expand(latents.shape[0])
199
+ else:
200
+ guidance = None
201
+
202
+ null_guidance = torch.full(
203
+ [1], 1, device=device, dtype=torch.float32
204
+ )
205
+
206
+ print(controller.num_self_layers)
207
+
208
+
209
+ pbar = tqdm(timesteps, desc="Sample")
210
+ for i, t in enumerate(pbar):
211
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
212
+ with torch.no_grad():
213
+ noise_pred = self.transformer(
214
+ hidden_states=latents,
215
+ timestep=timestep / 1000,
216
+ guidance=guidance,
217
+ pooled_projections=pooled_prompt_embeds,
218
+ encoder_hidden_states=prompt_embeds,
219
+ txt_ids=text_ids,
220
+ img_ids=latent_image_ids,
221
+ joint_attention_kwargs=None,
222
+ return_dict=False,
223
+ )[0]
224
+
225
+ # compute the previous noisy sample x_t -> x_t-1
226
+ latents = self.scheduler.step(
227
+ noise_pred, t, latents, return_dict=False
228
+ )[0]
229
+ if t < start_time:
230
+ if i < num_inference_steps - 1:
231
+ timestep = timesteps[i+1:i+2]
232
+ # print(timestep)
233
+ noise = torch.randn_like(style_latent)
234
+ # print(style_latent.shape)
235
+ style_latent_ = self.scheduler.scale_noise(style_latent, timestep, noise)
236
+ else:
237
+ timestep = torch.tensor([0], device=style_latent.device)
238
+ style_latent_ = style_latent
239
+
240
+ cache.clear()
241
+ controller.step()
242
+
243
+ _ = self.transformer(
244
+ hidden_states=style_latent_,
245
+ timestep=timestep / 1000,
246
+ guidance=null_guidance,
247
+ pooled_projections=null_pooled_embeds,
248
+ encoder_hidden_states=null_embeds,
249
+ txt_ids=null_text_ids,
250
+ img_ids=null_image_id,
251
+ joint_attention_kwargs=None,
252
+ return_dict=False,
253
+ )[0]
254
+ _, ref_k_list, ref_v_list, _ = cache.get()
255
+
256
+
257
+ latents = utils.adain(latents, style_latent_)
258
+ latents = latents.detach()
259
+ optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
260
+ optimizer = accelerator.prepare(optimizer)
261
+
262
+ for _ in range(iters):
263
+ cache.clear()
264
+ controller.step()
265
+ optimizer.zero_grad()
266
+ _ = self.transformer(
267
+ hidden_states=latents,
268
+ timestep=timestep / 1000,
269
+ guidance=null_guidance,
270
+ pooled_projections=null_pooled_embeds,
271
+ encoder_hidden_states=null_embeds,
272
+ txt_ids=null_text_ids,
273
+ img_ids=latent_image_ids,
274
+ joint_attention_kwargs=None,
275
+ return_dict=False,
276
+ )[0]
277
+ q_list, _, _, self_out_list = cache.get()
278
+ ref_self_out_list = [
279
+ F.scaled_dot_product_attention(
280
+ q,
281
+ ref_k,
282
+ ref_v,
283
+ )
284
+ for q, ref_k, ref_v in zip(q_list, ref_k_list, ref_v_list)
285
+ ]
286
+ style_loss = sum(
287
+ [
288
+ loss_fn(self_out, ref_self_out.detach())
289
+ for self_out, ref_self_out in zip(
290
+ self_out_list, ref_self_out_list
291
+ )
292
+ ]
293
+ )
294
+ loss = style_loss
295
+ accelerator.backward(loss)
296
+ # loss.backward()
297
+ optimizer.step()
298
+
299
+ pbar.set_postfix(loss=loss.item(), time=t.item())
300
+ torch.cuda.empty_cache()
301
+ latents = latents.detach()
302
+ return self.latent2image(latents, height, width)
303
+
304
+
305
+ def calculate_shift(
306
+ image_seq_len,
307
+ base_seq_len: int = 256,
308
+ max_seq_len: int = 4096,
309
+ base_shift: float = 0.5,
310
+ max_shift: float = 1.16,
311
+ ):
312
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
313
+ b = base_shift - m * base_seq_len
314
+ mu = image_seq_len * m + b
315
+ return mu
316
+
317
+
318
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
319
+ def retrieve_timesteps(
320
+ scheduler,
321
+ num_inference_steps: Optional[int] = None,
322
+ device: Optional[Union[str, torch.device]] = None,
323
+ timesteps: Optional[List[int]] = None,
324
+ sigmas: Optional[List[float]] = None,
325
+ **kwargs,
326
+ ):
327
+ r"""
328
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
329
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
330
+
331
+ Args:
332
+ scheduler (`SchedulerMixin`):
333
+ The scheduler to get timesteps from.
334
+ num_inference_steps (`int`):
335
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
336
+ must be `None`.
337
+ device (`str` or `torch.device`, *optional*):
338
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
339
+ timesteps (`List[int]`, *optional*):
340
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
341
+ `num_inference_steps` and `sigmas` must be `None`.
342
+ sigmas (`List[float]`, *optional*):
343
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
344
+ `num_inference_steps` and `timesteps` must be `None`.
345
+
346
+ Returns:
347
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
348
+ second element is the number of inference steps.
349
+ """
350
+ if timesteps is not None and sigmas is not None:
351
+ raise ValueError(
352
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
353
+ )
354
+ if timesteps is not None:
355
+ accepts_timesteps = "timesteps" in set(
356
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
357
+ )
358
+ if not accepts_timesteps:
359
+ raise ValueError(
360
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
361
+ f" timestep schedules. Please check whether you are using the correct scheduler."
362
+ )
363
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
364
+ timesteps = scheduler.timesteps
365
+ num_inference_steps = len(timesteps)
366
+ elif sigmas is not None:
367
+ accept_sigmas = "sigmas" in set(
368
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
369
+ )
370
+ if not accept_sigmas:
371
+ raise ValueError(
372
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
373
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
374
+ )
375
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
376
+ timesteps = scheduler.timesteps
377
+ num_inference_steps = len(timesteps)
378
+ else:
379
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
380
+ timesteps = scheduler.timesteps
381
+ return timesteps, num_inference_steps
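A worked instance of calculate_shift above, using its own default constants and the 512x512 setting from the notebook (illustrative numbers; for FLUX.1-dev the VAE downsamples by 8 and latents are packed 2x2, so one side contributes a factor of 16):

image_seq_len = (512 // 16) ** 2                 # 1024 packed latent tokens for a 512x512 image
m = (1.16 - 0.5) / (4096 - 256)                  # slope between base_shift and max_shift
b = 0.5 - m * 256                                # intercept at base_seq_len
mu = image_seq_len * m + b
print(round(mu, 3))                              # ~0.632, passed to set_timesteps as mu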
flux_ad/utils.py ADDED
@@ -0,0 +1,232 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ from torchvision.transforms import ToTensor
6
+ from torchvision.utils import save_image
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+ def register_attn_control(unet, controller, cache=None):
11
+ def attn_forward(self):
12
+
13
+ def forward(
14
+ hidden_states,
15
+ encoder_hidden_states=None,
16
+ attention_mask=None,
17
+ image_rotary_emb=None,
18
+ *args,
19
+ **kwargs,
20
+ ):
21
+ batch_size, _, _ = (
22
+ hidden_states.shape
23
+ if encoder_hidden_states is None
24
+ else encoder_hidden_states.shape
25
+ )
26
+
27
+ # `sample` projections.
28
+ query = self.to_q(hidden_states)
29
+ key = self.to_k(hidden_states)
30
+ value = self.to_v(hidden_states)
31
+
32
+ inner_dim = key.shape[-1]
33
+ head_dim = inner_dim // self.heads
34
+
35
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
36
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
37
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
38
+
39
+ if self.norm_q is not None:
40
+ query = self.norm_q(query)
41
+ if self.norm_k is not None:
42
+ key = self.norm_k(key)
43
+
44
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
45
+ if encoder_hidden_states is not None:
46
+ # `context` projections.
47
+ encoder_hidden_states_query_proj = self.add_q_proj(
48
+ encoder_hidden_states
49
+ )
50
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
51
+ encoder_hidden_states_value_proj = self.add_v_proj(
52
+ encoder_hidden_states
53
+ )
54
+
55
+ encoder_hidden_states_query_proj = (
56
+ encoder_hidden_states_query_proj.view(
57
+ batch_size, -1, self.heads, head_dim
58
+ ).transpose(1, 2)
59
+ )
60
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
61
+ batch_size, -1, self.heads, head_dim
62
+ ).transpose(1, 2)
63
+ encoder_hidden_states_value_proj = (
64
+ encoder_hidden_states_value_proj.view(
65
+ batch_size, -1, self.heads, head_dim
66
+ ).transpose(1, 2)
67
+ )
68
+
69
+ if self.norm_added_q is not None:
70
+ encoder_hidden_states_query_proj = self.norm_added_q(
71
+ encoder_hidden_states_query_proj
72
+ )
73
+ if self.norm_added_k is not None:
74
+ encoder_hidden_states_key_proj = self.norm_added_k(
75
+ encoder_hidden_states_key_proj
76
+ )
77
+
78
+ # attention
79
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
80
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
81
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
82
+
83
+ if image_rotary_emb is not None:
84
+ from diffusers.models.embeddings import apply_rotary_emb
85
+
86
+ query = apply_rotary_emb(query, image_rotary_emb)
87
+ key = apply_rotary_emb(key, image_rotary_emb)
88
+
89
+ hidden_states = F.scaled_dot_product_attention(
90
+ query, key, value, dropout_p=0.0, is_causal=False
91
+ )
92
+ if controller.cur_self_layer in controller.self_layers:
93
+ # print("cache added")
94
+ cache.add(query, key, value, hidden_states)
95
+ # if encoder_hidden_states is None:
96
+ controller.cur_self_layer += 1
97
+
98
+ hidden_states = hidden_states.transpose(1, 2).reshape(
99
+ batch_size, -1, self.heads * head_dim
100
+ )
101
+
102
+ hidden_states = hidden_states.to(query.dtype)
103
+
104
+ if encoder_hidden_states is not None:
105
+ encoder_hidden_states, hidden_states = (
106
+ hidden_states[:, : encoder_hidden_states.shape[1]],
107
+ hidden_states[:, encoder_hidden_states.shape[1] :],
108
+ )
109
+
110
+ # linear proj
111
+ hidden_states = self.to_out[0](hidden_states)
112
+ # dropout
113
+ hidden_states = self.to_out[1](hidden_states)
114
+ encoder_hidden_states = self.to_add_out(encoder_hidden_states)
115
+
116
+ return hidden_states, encoder_hidden_states
117
+ else:
118
+ return hidden_states
119
+
120
+ return forward
121
+
122
+ def modify_forward(net, count):
123
+ # print(net.named_children())
124
+ for name, subnet in net.named_children():
125
+ if net.__class__.__name__ == "Attention": # spatial Transformer layer
126
+ net.forward = attn_forward(net)
127
+ return count + 1
128
+ elif hasattr(net, "children"):
129
+ count = modify_forward(subnet, count)
130
+ return count
131
+
132
+ cross_att_count = 0
133
+ cross_att_count += modify_forward(unet, 0)
134
+ controller.num_self_layers += cross_att_count
135
+
136
+
137
+ def load_image(image_path, size=None, mode="RGB"):
138
+ img = Image.open(image_path).convert(mode)
139
+ if size is None:
140
+ width, height = img.size
141
+ new_width = (width // 64) * 64
142
+ new_height = (height // 64) * 64
143
+ size = (new_width, new_height)
144
+ img = img.resize(size, Image.BICUBIC)
145
+ return ToTensor()(img).unsqueeze(0)
146
+
147
+
148
+ def adain(source, target, eps=1e-6):
149
+ source_mean, source_std = torch.mean(source, dim=1, keepdim=True), torch.std(
150
+ source, dim=1, keepdim=True
151
+ )
152
+ target_mean, target_std = torch.mean(
153
+ target, dim=(0, 1), keepdim=True
154
+ ), torch.std(target, dim=(0, 1), keepdim=True)
155
+ normalized_source = (source - source_mean) / (source_std + eps)
156
+ transferred_source = normalized_source * target_std + target_mean
157
+
158
+ return transferred_source
159
+
160
+
161
+ def shuffle_tensor(tensor):
162
+ B, C, H, W = tensor.shape
163
+ flat_tensor = tensor.reshape(B, C, -1)
164
+ shuffled_tensor = flat_tensor[:, :, torch.randperm(H * W)].reshape(B, C, H, W)
165
+ return shuffled_tensor
166
+
167
+
168
+ class Controller:
169
+ def step(self):
170
+ self.cur_self_layer = 0
171
+
172
+ def __init__(self, self_layers=(0, 16)):
173
+ self.num_self_layers = 0
174
+ self.cur_self_layer = 0
175
+ self.self_layers = list(range(*self_layers))
176
+
177
+
178
+ class DataCache:
179
+ def __init__(self):
180
+ self.q = []
181
+ self.k = []
182
+ self.v = []
183
+ self.out = []
184
+
185
+ def clear(self):
186
+ self.q.clear()
187
+ self.k.clear()
188
+ self.v.clear()
189
+ self.out.clear()
190
+
191
+ def add(self, q, k, v, out):
192
+ self.q.append(q)
193
+ self.k.append(k)
194
+ self.v.append(v)
195
+ self.out.append(out)
196
+
197
+ def get(self):
198
+ return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()
199
+
200
+
201
+ def memory_efficient(model):
202
+ model.freeze()
203
+ try:
204
+ model.enable_model_cpu_offload()
205
+ except AttributeError:
206
+ print("enable_model_cpu_offload is not supported.")
207
+ try:
208
+ model.enable_vae_slicing()
209
+ except AttributeError:
210
+ print("enable_vae_slicing is not supported.")
211
+
212
+ try:
213
+ model.enable_vae_tiling()
214
+ except AttributeError:
215
+ print("enable_vae_tiling is not supported.")
216
+
217
+ def show_image(path, title, display_height=3, title_fontsize=12):
218
+ img = Image.open(path)
219
+ img_width, img_height = img.size
220
+
221
+ aspect_ratio = img_width / img_height
222
+ display_width = display_height * aspect_ratio
223
+
224
+ plt.figure(figsize=(display_width, display_height))
225
+ plt.imshow(img)
226
+ plt.title(title,
227
+ fontsize=title_fontsize,
228
+ fontweight='bold',
229
+ pad=20)
230
+ plt.axis('off')
231
+ plt.tight_layout()
232
+ plt.show()
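A small, illustrative check of adain above (assumes it is run from the flux_ad directory and that inputs use the packed-latent layout [batch, tokens, channels]; the source is normalized per channel over tokens and rescaled by the target's per-channel statistics, so the output statistics track the target's):

import torch
from utils import adain   # flux_ad/utils.py

src = torch.randn(1, 1024, 64) * 3.0 + 1.0    # "content" latents with arbitrary statistics
tgt = torch.randn(1, 1024, 64) * 0.5 - 2.0    # "style" latents
out = adain(src, tgt)
print(out.mean().item(), out.std().item())     # close to tgt.mean() and tgt.std()
print(out.shape)                               # layout is preserved: (1, 1024, 64)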
inpaint_model/model/__init__.py ADDED
File without changes
inpaint_model/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (166 Bytes). View file
 
inpaint_model/model/__pycache__/networks.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
inpaint_model/model/networks.py ADDED
@@ -0,0 +1,562 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils import spectral_norm as spectral_norm_fn
5
+ from torch.nn.utils import weight_norm as weight_norm_fn
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from torchvision import utils as vutils
9
+ from inpaint_model.utils.tools import extract_image_patches, flow_to_image, \
10
+ reduce_mean, reduce_sum, default_loader, same_padding
11
+
12
+
13
+ class Generator(nn.Module):
14
+ def __init__(self, config, use_cuda, device_ids):
15
+ super(Generator, self).__init__()
16
+ self.input_dim = config['input_dim']
17
+ self.cnum = config['ngf']
18
+ self.use_cuda = use_cuda
19
+ self.device_ids = device_ids
20
+
21
+ self.coarse_generator = CoarseGenerator(self.input_dim, self.cnum, self.use_cuda, self.device_ids)
22
+ self.fine_generator = FineGenerator(self.input_dim, self.cnum, self.use_cuda, self.device_ids)
23
+
24
+ def forward(self, x, mask):
25
+ x_stage1 = self.coarse_generator(x, mask)
26
+ x_stage2, offset_flow = self.fine_generator(x, x_stage1, mask)
27
+ return x_stage1, x_stage2, offset_flow
28
+
29
+
30
+ class CoarseGenerator(nn.Module):
31
+ def __init__(self, input_dim, cnum, use_cuda=True, device_ids=None):
32
+ super(CoarseGenerator, self).__init__()
33
+ self.use_cuda = use_cuda
34
+ self.device_ids = device_ids
35
+
36
+ self.conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
37
+ self.conv2_downsample = gen_conv(cnum, cnum*2, 3, 2, 1)
38
+ self.conv3 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
39
+ self.conv4_downsample = gen_conv(cnum*2, cnum*4, 3, 2, 1)
40
+ self.conv5 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
41
+ self.conv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
42
+
43
+ self.conv7_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 2, rate=2)
44
+ self.conv8_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 4, rate=4)
45
+ self.conv9_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 8, rate=8)
46
+ self.conv10_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 16, rate=16)
47
+
48
+ self.conv11 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
49
+ self.conv12 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
50
+
51
+ self.conv13 = gen_conv(cnum*4, cnum*2, 3, 1, 1)
52
+ self.conv14 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
53
+ self.conv15 = gen_conv(cnum*2, cnum, 3, 1, 1)
54
+ self.conv16 = gen_conv(cnum, cnum//2, 3, 1, 1)
55
+ self.conv17 = gen_conv(cnum//2, input_dim, 3, 1, 1, activation='none')
56
+
57
+ def forward(self, x, mask):
58
+ # For indicating the boundaries of images
59
+ ones = torch.ones(x.size(0), 1, x.size(2), x.size(3))
60
+ if self.use_cuda:
61
+ ones = ones.cuda()
62
+ mask = mask.cuda()
63
+ # 5 x 256 x 256
64
+ x = self.conv1(torch.cat([x, ones, mask], dim=1))
65
+ x = self.conv2_downsample(x)
66
+ # cnum*2 x 128 x 128
67
+ x = self.conv3(x)
68
+ x = self.conv4_downsample(x)
69
+ # cnum*4 x 64 x 64
70
+ x = self.conv5(x)
71
+ x = self.conv6(x)
72
+ x = self.conv7_atrous(x)
73
+ x = self.conv8_atrous(x)
74
+ x = self.conv9_atrous(x)
75
+ x = self.conv10_atrous(x)
76
+ x = self.conv11(x)
77
+ x = self.conv12(x)
78
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
79
+ # cnum*2 x 128 x 128
80
+ x = self.conv13(x)
81
+ x = self.conv14(x)
82
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
83
+ # cnum x 256 x 256
84
+ x = self.conv15(x)
85
+ x = self.conv16(x)
86
+ x = self.conv17(x)
87
+ # 3 x 256 x 256
88
+ x_stage1 = torch.clamp(x, -1., 1.)
89
+
90
+ return x_stage1
91
+
92
+
93
+ class FineGenerator(nn.Module):
94
+ def __init__(self, input_dim, cnum, use_cuda=True, device_ids=None):
95
+ super(FineGenerator, self).__init__()
96
+ self.use_cuda = use_cuda
97
+ self.device_ids = device_ids
98
+
99
+ # 3 x 256 x 256
100
+ self.conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
101
+ self.conv2_downsample = gen_conv(cnum, cnum, 3, 2, 1)
102
+ # cnum*2 x 128 x 128
103
+ self.conv3 = gen_conv(cnum, cnum*2, 3, 1, 1)
104
+ self.conv4_downsample = gen_conv(cnum*2, cnum*2, 3, 2, 1)
105
+ # cnum*4 x 64 x 64
106
+ self.conv5 = gen_conv(cnum*2, cnum*4, 3, 1, 1)
107
+ self.conv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
108
+
109
+ self.conv7_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 2, rate=2)
110
+ self.conv8_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 4, rate=4)
111
+ self.conv9_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 8, rate=8)
112
+ self.conv10_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 16, rate=16)
113
+
114
+ # attention branch
115
+ # 3 x 256 x 256
116
+ self.pmconv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
117
+ self.pmconv2_downsample = gen_conv(cnum, cnum, 3, 2, 1)
118
+ # cnum*2 x 128 x 128
119
+ self.pmconv3 = gen_conv(cnum, cnum*2, 3, 1, 1)
120
+ self.pmconv4_downsample = gen_conv(cnum*2, cnum*4, 3, 2, 1)
121
+ # cnum*4 x 64 x 64
122
+ self.pmconv5 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
123
+ self.pmconv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1, activation='relu')
124
+ self.contextul_attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10,
125
+ fuse=True, use_cuda=self.use_cuda, device_ids=self.device_ids)
126
+ self.pmconv9 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
127
+ self.pmconv10 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
128
+ self.allconv11 = gen_conv(cnum*8, cnum*4, 3, 1, 1)
129
+ self.allconv12 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
130
+ self.allconv13 = gen_conv(cnum*4, cnum*2, 3, 1, 1)
131
+ self.allconv14 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
132
+ self.allconv15 = gen_conv(cnum*2, cnum, 3, 1, 1)
133
+ self.allconv16 = gen_conv(cnum, cnum//2, 3, 1, 1)
134
+ self.allconv17 = gen_conv(cnum//2, input_dim, 3, 1, 1, activation='none')
135
+
136
+ def forward(self, xin, x_stage1, mask):
137
+ x1_inpaint = x_stage1 * mask + xin * (1. - mask)
138
+ # For indicating the boundaries of images
139
+ ones = torch.ones(xin.size(0), 1, xin.size(2), xin.size(3))
140
+ if self.use_cuda:
141
+ ones = ones.cuda()
142
+ mask = mask.cuda()
143
+ # conv branch
144
+ xnow = torch.cat([x1_inpaint, ones, mask], dim=1)
145
+ x = self.conv1(xnow)
146
+ x = self.conv2_downsample(x)
147
+ x = self.conv3(x)
148
+ x = self.conv4_downsample(x)
149
+ x = self.conv5(x)
150
+ x = self.conv6(x)
151
+ x = self.conv7_atrous(x)
152
+ x = self.conv8_atrous(x)
153
+ x = self.conv9_atrous(x)
154
+ x = self.conv10_atrous(x)
155
+ x_hallu = x
156
+ # attention branch
157
+ x = self.pmconv1(xnow)
158
+ x = self.pmconv2_downsample(x)
159
+ x = self.pmconv3(x)
160
+ x = self.pmconv4_downsample(x)
161
+ x = self.pmconv5(x)
162
+ x = self.pmconv6(x)
163
+ x, offset_flow = self.contextul_attention(x, x, mask)
164
+ x = self.pmconv9(x)
165
+ x = self.pmconv10(x)
166
+ pm = x
167
+ x = torch.cat([x_hallu, pm], dim=1)
168
+ # merge two branches
169
+ x = self.allconv11(x)
170
+ x = self.allconv12(x)
171
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
172
+ x = self.allconv13(x)
173
+ x = self.allconv14(x)
174
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
175
+ x = self.allconv15(x)
176
+ x = self.allconv16(x)
177
+ x = self.allconv17(x)
178
+ x_stage2 = torch.clamp(x, -1., 1.)
179
+
180
+ return x_stage2, offset_flow
181
+
182
+
183
+ class ContextualAttention(nn.Module):
184
+ def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10,
185
+ fuse=False, use_cuda=False, device_ids=None):
186
+ super(ContextualAttention, self).__init__()
187
+ self.ksize = ksize
188
+ self.stride = stride
189
+ self.rate = rate
190
+ self.fuse_k = fuse_k
191
+ self.softmax_scale = softmax_scale
192
+ self.fuse = fuse
193
+ self.use_cuda = use_cuda
194
+ self.device_ids = device_ids
195
+
196
+ def forward(self, f, b, mask=None):
197
+ """ Contextual attention layer implementation.
198
+ Contextual attention is first introduced in publication:
199
+ Generative Image Inpainting with Contextual Attention, Yu et al.
200
+ Args:
201
+ f: Input feature to match (foreground).
202
+ b: Input feature for match (background).
203
+ mask: Input mask for b, indicating patches not available.
204
+ ksize: Kernel size for contextual attention.
205
+ stride: Stride for extracting patches from b.
206
+ rate: Dilation for matching.
207
+ softmax_scale: Scaled softmax for attention.
208
+ Returns:
209
+ torch.tensor: output
210
+ """
211
+ # get shapes
212
+ raw_int_fs = list(f.size()) # b*c*h*w
213
+ raw_int_bs = list(b.size()) # b*c*h*w
214
+
215
+ # extract patches from background with stride and rate
216
+ kernel = 2 * self.rate
217
+ # raw_w is extracted for reconstruction
218
+ raw_w = extract_image_patches(b, ksizes=[kernel, kernel],
219
+ strides=[self.rate*self.stride,
220
+ self.rate*self.stride],
221
+ rates=[1, 1],
222
+ padding='same') # [N, C*k*k, L]
223
+ # raw_shape: [N, C, k, k, L]
224
+ raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
225
+ raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
226
+ raw_w_groups = torch.split(raw_w, 1, dim=0)
227
+
228
+ # downscaling foreground option: downscaling both foreground and
229
+ # background for matching and use original background for reconstruction.
230
+ f = F.interpolate(f, scale_factor=1./self.rate, mode='nearest')
231
+ b = F.interpolate(b, scale_factor=1./self.rate, mode='nearest')
232
+ int_fs = list(f.size()) # b*c*h*w
233
+ int_bs = list(b.size())
234
+ f_groups = torch.split(f, 1, dim=0) # split tensors along the batch dimension
235
+ # w shape: [N, C*k*k, L]
236
+ w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
237
+ strides=[self.stride, self.stride],
238
+ rates=[1, 1],
239
+ padding='same')
240
+ # w shape: [N, C, k, k, L]
241
+ w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
242
+ w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
243
+ w_groups = torch.split(w, 1, dim=0)
244
+
245
+ # process mask
246
+ if mask is None:
247
+ mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
248
+ if self.use_cuda:
249
+ mask = mask.cuda()
250
+ else:
251
+ mask = F.interpolate(mask, scale_factor=1./(4*self.rate), mode='nearest')
252
+ int_ms = list(mask.size())
253
+ # m shape: [N, C*k*k, L]
254
+ m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
255
+ strides=[self.stride, self.stride],
256
+ rates=[1, 1],
257
+ padding='same')
258
+ # m shape: [N, C, k, k, L]
259
+ m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
260
+ m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k]
261
+ m = m[0] # m shape: [L, C, k, k]
262
+ # mm shape: [L, 1, 1, 1]
263
+ mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True)==0.).to(torch.float32)
264
+ mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]
265
+
266
+ y = []
267
+ offsets = []
268
+ k = self.fuse_k
269
+ scale = self.softmax_scale # to fit the PyTorch tensor image value range
270
+ fuse_weight = torch.eye(k).view(1, 1, k, k) # 1*1*k*k
271
+ if self.use_cuda:
272
+ fuse_weight = fuse_weight.cuda()
273
+
274
+ for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
275
+ '''
276
+ O => output channel as a conv filter
277
+ I => input channel as a conv filter
278
+ xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
279
+ wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
280
+ raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
281
+ '''
282
+ # conv for compare
283
+ escape_NaN = torch.FloatTensor([1e-4])
284
+ if self.use_cuda:
285
+ escape_NaN = escape_NaN.cuda()
286
+ wi = wi[0] # [L, C, k, k]
287
+ max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
288
+ wi_normed = wi / max_wi
289
+ # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
290
+ xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W
291
+ yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W]
292
+ # conv implementation for fuse scores to encourage large patches
293
+ if self.fuse:
294
+ # make all of depth to spatial resolution
295
+ yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3]) # (B=1, I=1, H=32*32, W=32*32)
296
+ yi = same_padding(yi, [k, k], [1, 1], [1, 1])
297
+ yi = F.conv2d(yi, fuse_weight, stride=1) # (B=1, C=1, H=32*32, W=32*32)
298
+ yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]) # (B=1, 32, 32, 32, 32)
299
+ yi = yi.permute(0, 2, 1, 4, 3)
300
+ yi = yi.contiguous().view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
301
+ yi = same_padding(yi, [k, k], [1, 1], [1, 1])
302
+ yi = F.conv2d(yi, fuse_weight, stride=1)
303
+ yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
304
+ yi = yi.permute(0, 2, 1, 4, 3).contiguous()
305
+ yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]) # (B=1, C=32*32, H=32, W=32)
306
+ # softmax to match
307
+ yi = yi * mm
308
+ yi = F.softmax(yi*scale, dim=1)
309
+ yi = yi * mm # [1, L, H, W]
310
+
311
+ offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W
312
+
313
+ if int_bs != int_fs:
314
+ # Normalize the offset value to match foreground dimension
315
+ times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
316
+ offset = ((offset + 1).float() * times - 1).to(torch.int64)
317
+ offset = torch.cat([offset//int_fs[3], offset%int_fs[3]], dim=1) # 1*2*H*W
318
+
319
+ # deconv for patch pasting
320
+ wi_center = raw_wi[0]
321
+ # yi = F.pad(yi, [0, 1, 0, 1]) # here may need conv_transpose same padding
322
+ yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64)
323
+ y.append(yi)
324
+ offsets.append(offset)
325
+
326
+ y = torch.cat(y, dim=0) # back to the mini-batch
327
+ y = y.contiguous().view(raw_int_fs)  # reshape back to the original foreground shape
328
+
329
+ offsets = torch.cat(offsets, dim=0)
330
+ offsets = offsets.view(int_fs[0], 2, *int_fs[2:])
331
+
332
+ # case1: visualize optical flow: minus current position
333
+ h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3])
334
+ w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1)
335
+ ref_coordinate = torch.cat([h_add, w_add], dim=1)
336
+ if self.use_cuda:
337
+ ref_coordinate = ref_coordinate.cuda()
338
+
339
+ offsets = offsets - ref_coordinate
340
+ # flow = pt_flow_to_image(offsets)
341
+
342
+ flow = torch.from_numpy(flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
343
+ flow = flow.permute(0, 3, 1, 2)
344
+ if self.use_cuda:
345
+ flow = flow.cuda()
346
+ # case2: visualize which pixels are attended
347
+ # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))
348
+
349
+ if self.rate != 1:
350
+ flow = F.interpolate(flow, scale_factor=self.rate*4, mode='nearest')
351
+
352
+ return y, flow
353
+
354
+
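A minimal smoke test of the ContextualAttention module above, run on random tensors (an illustrative sketch, not part of the committed file; it assumes the helpers called inside forward, i.e. extract_image_patches, same_padding, reduce_mean, reduce_sum and flow_to_image, are defined earlier in this networks.py, and the tensor sizes are arbitrary choices):

# --- illustrative sketch (not part of the commit) ---
import torch

attn = ContextualAttention(ksize=3, stride=1, rate=2, fuse=True, use_cuda=False)
f = torch.randn(1, 64, 64, 64)      # foreground features (region to fill)
b = torch.randn(1, 64, 64, 64)      # background features (patch source)
mask = torch.zeros(1, 1, 256, 256)  # full-resolution hole mask; all zeros = no hole
y, flow = attn(f, b, mask)
print(y.shape)     # torch.Size([1, 64, 64, 64]), reconstructed foreground features
print(flow.shape)  # torch.Size([1, 3, 256, 256]), attention-flow visualization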
355
+ def test_contextual_attention(args):
356
+ import cv2
357
+ import os
358
+ # restrict to GPU 2 (set CUDA_VISIBLE_DEVICES to '' to actually run on CPU)
359
+ os.environ['CUDA_VISIBLE_DEVICES'] = '2'
360
+
361
+ def float_to_uint8(img):
362
+ img = img * 255
363
+ return img.astype('uint8')
364
+
365
+ rate = 2
366
+ stride = 1
367
+ grid = rate*stride
368
+
369
+ b = default_loader(args.imageA)
370
+ w, h = b.size
371
+ b = b.resize((w//grid*grid//2, h//grid*grid//2), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
372
+ # b = b.resize((w//grid*grid, h//grid*grid), Image.ANTIALIAS)
373
+ print('Size of imageA: {}'.format(b.size))
374
+
375
+ f = default_loader(args.imageB)
376
+ w, h = f.size
377
+ f = f.resize((w//grid*grid, h//grid*grid), Image.LANCZOS)
378
+ print('Size of imageB: {}'.format(f.size))
379
+
380
+ f, b = transforms.ToTensor()(f), transforms.ToTensor()(b)
381
+ f, b = f.unsqueeze(0), b.unsqueeze(0)
382
+ if torch.cuda.is_available():
383
+ f, b = f.cuda(), b.cuda()
384
+
385
+ contextual_attention = ContextualAttention(ksize=3, stride=stride, rate=rate, fuse=True)
386
+
387
+ if torch.cuda.is_available():
388
+ contextual_attention = contextual_attention.cuda()
389
+
390
+ yt, flow_t = contextual_attention(f, b)
391
+ vutils.save_image(yt, 'vutils' + args.imageOut, normalize=True)
392
+ vutils.save_image(flow_t, 'flow' + args.imageOut, normalize=True)
393
+ # y = tensor_img_to_npimg(yt.cpu()[0])
394
+ # flow = tensor_img_to_npimg(flow_t.cpu()[0])
395
+ # cv2.imwrite('flow' + args.imageOut, flow_t)
396
+
397
+
398
+ class LocalDis(nn.Module):
399
+ def __init__(self, config, use_cuda=True, device_ids=None):
400
+ super(LocalDis, self).__init__()
401
+ self.input_dim = config['input_dim']
402
+ self.cnum = config['ndf']
403
+ self.use_cuda = use_cuda
404
+ self.device_ids = device_ids
405
+
406
+ self.dis_conv_module = DisConvModule(self.input_dim, self.cnum)
407
+ self.linear = nn.Linear(self.cnum*4*8*8, 1)
408
+
409
+ def forward(self, x):
410
+ x = self.dis_conv_module(x)
411
+ x = x.view(x.size()[0], -1)
412
+ x = self.linear(x)
413
+
414
+ return x
415
+
416
+
417
+ class GlobalDis(nn.Module):
418
+ def __init__(self, config, use_cuda=True, device_ids=None):
419
+ super(GlobalDis, self).__init__()
420
+ self.input_dim = config['input_dim']
421
+ self.cnum = config['ndf']
422
+ self.use_cuda = use_cuda
423
+ self.device_ids = device_ids
424
+
425
+ self.dis_conv_module = DisConvModule(self.input_dim, self.cnum)
426
+ self.linear = nn.Linear(self.cnum*4*16*16, 1)
427
+
428
+ def forward(self, x):
429
+ x = self.dis_conv_module(x)
430
+ x = x.view(x.size()[0], -1)
431
+ x = self.linear(x)
432
+
433
+ return x
434
+
435
+
436
+ class DisConvModule(nn.Module):
437
+ def __init__(self, input_dim, cnum, use_cuda=True, device_ids=None):
438
+ super(DisConvModule, self).__init__()
439
+ self.use_cuda = use_cuda
440
+ self.device_ids = device_ids
441
+
442
+ self.conv1 = dis_conv(input_dim, cnum, 5, 2, 2)
443
+ self.conv2 = dis_conv(cnum, cnum*2, 5, 2, 2)
444
+ self.conv3 = dis_conv(cnum*2, cnum*4, 5, 2, 2)
445
+ self.conv4 = dis_conv(cnum*4, cnum*4, 5, 2, 2)
446
+
447
+ def forward(self, x):
448
+ x = self.conv1(x)
449
+ x = self.conv2(x)
450
+ x = self.conv3(x)
451
+ x = self.conv4(x)
452
+
453
+ return x
454
+
455
+
456
+ def gen_conv(input_dim, output_dim, kernel_size=3, stride=1, padding=0, rate=1,
457
+ activation='elu'):
458
+ return Conv2dBlock(input_dim, output_dim, kernel_size, stride,
459
+ conv_padding=padding, dilation=rate,
460
+ activation=activation)
461
+
462
+
463
+ def dis_conv(input_dim, output_dim, kernel_size=5, stride=2, padding=0, rate=1,
464
+ activation='lrelu'):
465
+ return Conv2dBlock(input_dim, output_dim, kernel_size, stride,
466
+ conv_padding=padding, dilation=rate,
467
+ activation=activation)
468
+
469
+
470
+ class Conv2dBlock(nn.Module):
471
+ def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
472
+ conv_padding=0, dilation=1, weight_norm='none', norm='none',
473
+ activation='relu', pad_type='zero', transpose=False):
474
+ super(Conv2dBlock, self).__init__()
475
+ self.use_bias = True
476
+ # initialize padding
477
+ if pad_type == 'reflect':
478
+ self.pad = nn.ReflectionPad2d(padding)
479
+ elif pad_type == 'replicate':
480
+ self.pad = nn.ReplicationPad2d(padding)
481
+ elif pad_type == 'zero':
482
+ self.pad = nn.ZeroPad2d(padding)
483
+ elif pad_type == 'none':
484
+ self.pad = None
485
+ else:
486
+ assert 0, "Unsupported padding type: {}".format(pad_type)
487
+
488
+ # initialize normalization
489
+ norm_dim = output_dim
490
+ if norm == 'bn':
491
+ self.norm = nn.BatchNorm2d(norm_dim)
492
+ elif norm == 'in':
493
+ self.norm = nn.InstanceNorm2d(norm_dim)
494
+ elif norm == 'none':
495
+ self.norm = None
496
+ else:
497
+ assert 0, "Unsupported normalization: {}".format(norm)
498
+
499
+ if weight_norm == 'sn':
500
+ self.weight_norm = spectral_norm_fn
501
+ elif weight_norm == 'wn':
502
+ self.weight_norm = weight_norm_fn
503
+ elif weight_norm == 'none':
504
+ self.weight_norm = None
505
+ else:
506
+ assert 0, "Unsupported normalization: {}".format(weight_norm)
507
+
508
+ # initialize activation
509
+ if activation == 'relu':
510
+ self.activation = nn.ReLU(inplace=True)
511
+ elif activation == 'elu':
512
+ self.activation = nn.ELU(inplace=True)
513
+ elif activation == 'lrelu':
514
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
515
+ elif activation == 'prelu':
516
+ self.activation = nn.PReLU()
517
+ elif activation == 'selu':
518
+ self.activation = nn.SELU(inplace=True)
519
+ elif activation == 'tanh':
520
+ self.activation = nn.Tanh()
521
+ elif activation == 'none':
522
+ self.activation = None
523
+ else:
524
+ assert 0, "Unsupported activation: {}".format(activation)
525
+
526
+ # initialize convolution
527
+ if transpose:
528
+ self.conv = nn.ConvTranspose2d(input_dim, output_dim,
529
+ kernel_size, stride,
530
+ padding=conv_padding,
531
+ output_padding=conv_padding,
532
+ dilation=dilation,
533
+ bias=self.use_bias)
534
+ else:
535
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride,
536
+ padding=conv_padding, dilation=dilation,
537
+ bias=self.use_bias)
538
+
539
+ if self.weight_norm:
540
+ self.conv = self.weight_norm(self.conv)
541
+
542
+ def forward(self, x):
543
+ if self.pad:
544
+ x = self.conv(self.pad(x))
545
+ else:
546
+ x = self.conv(x)
547
+ if self.norm:
548
+ x = self.norm(x)
549
+ if self.activation:
550
+ x = self.activation(x)
551
+ return x
552
+
553
+
554
+
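A quick sanity check of the discriminator heads defined above (an illustrative sketch; the config values and the 128x128 / 256x256 crop sizes are assumptions implied by the Linear dimensions cnum*4*8*8 and cnum*4*16*16 after four stride-2 convolutions):

# --- illustrative sketch (not part of the commit) ---
import torch

cfg = {'input_dim': 3, 'ndf': 64}  # assumed config values
local_d = LocalDis(cfg, use_cuda=False)
global_d = GlobalDis(cfg, use_cuda=False)
print(local_d(torch.randn(2, 3, 128, 128)).shape)   # torch.Size([2, 1])
print(global_d(torch.randn(2, 3, 256, 256)).shape)  # torch.Size([2, 1])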
555
+ if __name__ == "__main__":
556
+ import argparse
557
+ parser = argparse.ArgumentParser()
558
+ parser.add_argument('--imageA', default='', type=str, help='Image A as background patches to reconstruct image B.')
559
+ parser.add_argument('--imageB', default='', type=str, help='Image B is reconstructed with image A.')
560
+ parser.add_argument('--imageOut', default='result.png', type=str, help='Image B is reconstructed with image A.')
561
+ args = parser.parse_args()
562
+ test_contextual_attention(args)
matting/image_matting.py ADDED
@@ -0,0 +1,274 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.autograd import Variable
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+
11
+
12
+ # Define the U²-Net model (simplified variant: every stage reuses the RSU7 block)
13
+ class REBNCONV(nn.Module):
14
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
15
+ super(REBNCONV, self).__init__()
16
+
17
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
18
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
19
+ self.relu_s1 = nn.ReLU(inplace=True)
20
+
21
+ def forward(self, x):
22
+ hx = x
23
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
24
+ return xout
25
+
26
+
27
+ class RSU7(nn.Module):
28
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
29
+ super(RSU7, self).__init__()
30
+
31
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
32
+
33
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
34
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
35
+
36
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
37
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
38
+
39
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
40
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
41
+
42
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
43
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+
50
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
51
+
52
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
53
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
54
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
55
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
56
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
57
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
58
+
59
+ def forward(self, x):
60
+ hx = x
61
+ hxin = self.rebnconvin(hx)
62
+
63
+ hx1 = self.rebnconv1(hxin)
64
+ hx = self.pool1(hx1)
65
+
66
+ hx2 = self.rebnconv2(hx)
67
+ hx = self.pool2(hx2)
68
+
69
+ hx3 = self.rebnconv3(hx)
70
+ hx = self.pool3(hx3)
71
+
72
+ hx4 = self.rebnconv4(hx)
73
+ hx = self.pool4(hx4)
74
+
75
+ hx5 = self.rebnconv5(hx)
76
+ hx = self.pool5(hx5)
77
+
78
+ hx6 = self.rebnconv6(hx)
79
+ hx7 = self.rebnconv7(hx6)
80
+
81
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
82
+ hx6dup = F.interpolate(hx6d, size=hx5.shape[2:], mode='bilinear', align_corners=False)  # upsample to the skip's size; a fixed scale_factor=2 breaks on odd feature maps
83
+
84
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
85
+ hx5dup = F.interpolate(hx5d, size=hx4.shape[2:], mode='bilinear', align_corners=False)
86
+
87
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
88
+ hx4dup = F.interpolate(hx4d, size=hx3.shape[2:], mode='bilinear', align_corners=False)
89
+
90
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
91
+ hx3dup = F.interpolate(hx3d, size=hx2.shape[2:], mode='bilinear', align_corners=False)
92
+
93
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
94
+ hx2dup = F.interpolate(hx2d, size=hx1.shape[2:], mode='bilinear', align_corners=False)
95
+
96
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
97
+
98
+ return hx1d + hxin
99
+
100
+
101
+ class U2NET(nn.Module):
102
+ def __init__(self, in_ch=3, out_ch=1):
103
+ super(U2NET, self).__init__()
104
+
105
+ self.stage1 = RSU7(in_ch, 32, 64)
106
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
107
+
108
+ self.stage2 = RSU7(64, 32, 128)
109
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
110
+
111
+ self.stage3 = RSU7(128, 64, 256)
112
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
113
+
114
+ self.stage4 = RSU7(256, 128, 512)
115
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
116
+
117
+ self.stage5 = RSU7(512, 256, 512)
118
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
119
+
120
+ self.stage6 = RSU7(512, 512, 512)
121
+
122
+ self.stage5d = RSU7(1024, 256, 512)
123
+ self.stage4d = RSU7(1024, 128, 256)
124
+ self.stage3d = RSU7(512, 64, 128)
125
+ self.stage2d = RSU7(256, 32, 64)
126
+ self.stage1d = RSU7(128, 16, 64)
127
+
128
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
129
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
130
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
131
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
132
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
133
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
134
+
135
+ self.outconv = nn.Conv2d(6, out_ch, 1)
136
+
137
+ def forward(self, x):
138
+ hx = x
139
+
140
+ hx1 = self.stage1(hx)
141
+ hx = self.pool12(hx1)
142
+
143
+ hx2 = self.stage2(hx)
144
+ hx = self.pool23(hx2)
145
+
146
+ hx3 = self.stage3(hx)
147
+ hx = self.pool34(hx3)
148
+
149
+ hx4 = self.stage4(hx)
150
+ hx = self.pool45(hx4)
151
+
152
+ hx5 = self.stage5(hx)
153
+ hx = self.pool56(hx5)
154
+
155
+ hx6 = self.stage6(hx)
156
+ hx6up = F.interpolate(hx6, size=hx5.shape[2:], mode='bilinear', align_corners=False)  # size-matched upsampling, as in RSU7 above
157
+
158
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
159
+ hx5dup = F.interpolate(hx5d, size=hx4.shape[2:], mode='bilinear', align_corners=False)
160
+
161
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
162
+ hx4dup = F.interpolate(hx4d, size=hx3.shape[2:], mode='bilinear', align_corners=False)
163
+
164
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
165
+ hx3dup = F.interpolate(hx3d, size=hx2.shape[2:], mode='bilinear', align_corners=False)
166
+
167
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
168
+ hx2dup = F.interpolate(hx2d, size=hx1.shape[2:], mode='bilinear', align_corners=False)
169
+
170
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
171
+
172
+ d1 = self.side1(hx1d)
173
+
174
+ d2 = self.side2(hx2d)
175
+ d2 = F.interpolate(d2, scale_factor=2, mode='bilinear', align_corners=False)
176
+
177
+ d3 = self.side3(hx3d)
178
+ d3 = F.interpolate(d3, scale_factor=4, mode='bilinear', align_corners=False)
179
+
180
+ d4 = self.side4(hx4d)
181
+ d4 = F.interpolate(d4, scale_factor=8, mode='bilinear', align_corners=False)
182
+
183
+ d5 = self.side5(hx5d)
184
+ d5 = F.interpolate(d5, scale_factor=16, mode='bilinear', align_corners=False)
185
+
186
+ d6 = self.side6(hx6)
187
+ d6 = F.interpolate(d6, scale_factor=32, mode='bilinear', align_corners=False)
188
+
189
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
190
+
191
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(
192
+ d4), torch.sigmoid(d5), torch.sigmoid(d6)
193
+
194
+
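A minimal shape check for the U²-Net variant above (an illustrative sketch with random weights on CPU; the 320x320 size matches the preprocessing below, and it relies on the size-matched upsampling inside RSU7):

# --- illustrative sketch (not part of the commit) ---
import torch

net = U2NET(3, 1).eval()
with torch.no_grad():
    outs = net(torch.randn(1, 3, 320, 320))
print([tuple(o.shape) for o in outs])  # seven outputs, each (1, 1, 320, 320)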
195
+ # Preprocess an image: BGR->RGB, resize to 320x320, scale to [0, 1], ImageNet normalization
196
+ def preprocess_image(image_path):
197
+ img = cv2.imread(image_path)
198
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
199
+ img = cv2.resize(img, (320, 320))
200
+ img = img.astype(np.float32) / 255.0
201
+ img = transforms.ToTensor()(img)
202
+ img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
203
+ return img.unsqueeze(0)
204
+
205
+
206
+ # Post-process the predicted mask: threshold at 0.5 and move to a numpy array
207
+ def postprocess_mask(mask):
208
+ mask = mask.squeeze()
209
+ mask = (mask > 0.5).float()
210
+ mask = mask.cpu().numpy()
211
+ return mask
212
+
213
+
214
+ # Build the matted image: keep original pixels where mask == 1, set the background to black
215
+ def create_binary_image(image_path, mask, output_path):
216
+ original_image = cv2.imread(image_path)
217
+ original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
218
+
219
+ mask_resized = cv2.resize(mask, (original_image.shape[1], original_image.shape[0]))
220
+
221
+ result = np.zeros_like(original_image)
222
+ for c in range(3):
223
+ result[:, :, c] = np.where(mask_resized == 1, original_image[:, :, c], 0)
224
+
225
+ cv2.imwrite(output_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
226
+
227
+
228
+ # Load the U²-Net model weights
229
+ def load_model(model_path):
230
+ net = U2NET(3, 1)
231
+ if torch.cuda.is_available():
232
+ net.load_state_dict(torch.load(model_path))
233
+ net.cuda()
234
+ else:
235
+ net.load_state_dict(torch.load(model_path, map_location='cpu'))
236
+ net.eval()
237
+ return net
238
+
239
+
240
+ # Run matting over every image in a folder
241
+ def process_images(input_folder, output_folder, model_path):
242
+ if not os.path.exists(output_folder):
243
+ os.makedirs(output_folder)
244
+
245
+ model = load_model(model_path)
246
+
247
+ for filename in os.listdir(input_folder):
248
+ if filename.endswith(('.png', '.jpg', '.jpeg')):
249
+ image_path = os.path.join(input_folder, filename)
250
+ output_path = os.path.join(output_folder, filename)
251
+
252
+ img = preprocess_image(image_path)
253
+
254
+ if torch.cuda.is_available():
255
+ inputs = Variable(img.cuda())
256
+ else:
257
+ inputs = Variable(img)
258
+
259
+ d0, d1, d2, d3, d4, d5, d6 = model(inputs)  # d0 is the fused prediction
260
+
261
+ pred_mask = d0
262
+ mask = postprocess_mask(pred_mask)
263
+
264
+ create_binary_image(image_path, mask, output_path)
265
+
266
+ print(f"Processed {filename}")
267
+
268
+
269
+ if __name__ == "__main__":
270
+ model_path = "./u2net.pth"  # replace with the correct path to the model weights
271
+ input_folder = r"D:\learn_torch\pytorch-CycleGAN\datasets\ostracoda\trainA"  # input image folder
272
+ output_folder = r"D:\learn_torch\pytorch-CycleGAN\datasets\ostracoda\trainA_matting"  # output image folder
273
+
274
+ process_images(input_folder, output_folder, model_path)
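For single-image use, the helpers above can be chained as in the sketch below (file names are placeholders; the snippet assumes it runs in the same module or after importing this script):

# --- illustrative sketch (not part of the commit) ---
import torch

model = load_model("./u2net.pth")                   # weights path placeholder
img = preprocess_image("example.jpg")               # input path placeholder
inputs = img.cuda() if torch.cuda.is_available() else img
with torch.no_grad():
    d0, *_ = model(inputs)                          # d0 is the fused prediction
mask = postprocess_mask(d0)                         # 320x320 binary numpy mask
create_binary_image("example.jpg", mask, "example_matted.png")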