climate
YuanGao-YG commited on
Commit
dc726ea
·
verified ·
1 Parent(s): 484df85

Update model/vision.py

Browse files
Files changed (1) hide show
  1. model/vision.py +542 -559
model/vision.py CHANGED
@@ -1,559 +1,542 @@
1
- from __future__ import absolute_import
2
- from __future__ import division
3
- from __future__ import print_function
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- from thop import profile
10
-
11
-
12
- class VISION(nn.Module):
13
- def __init__(self,channel = 16):
14
- super(VISION,self).__init__()
15
- self.aoe = AOE(channel)
16
- self.gsao = GSAO(channel)
17
-
18
- def forward(self,x):
19
- x_aoe = self.aoe(x)
20
- out = self.gsao(x_aoe)
21
-
22
- return out
23
-
24
- class GSAO(nn.Module):
25
- def __init__(self,channel = 16):
26
- super(GSAO,self).__init__()
27
-
28
- self.gsao_left = GSAO_Left(channel)
29
-
30
- self.ssdc = SSDC(channel)
31
-
32
- self.gsao_right = GSAO_Right(channel)
33
-
34
- self.gsao_out = nn.Conv2d(channel,3,kernel_size=1,stride=1,padding=0,bias=False)
35
-
36
- def forward(self,x):
37
-
38
- L,M,S,SS = self.gsao_left(x)
39
- ssdc = self.ssdc(SS)
40
- x_out = self.gsao_right(ssdc,SS,S,M,L)
41
- out = self.gsao_out(x_out)
42
-
43
- return out
44
-
45
-
46
- class AOE(nn.Module):
47
- def __init__(self,channel = 16):
48
- super(AOE,self).__init__()
49
-
50
- self.uoa = UOA(channel)
51
- self.scp = SCP(channel)
52
-
53
- def forward(self,x):
54
- x_in = self.uoa(x)
55
- x_out = self.scp(x_in)#3 16
56
-
57
- return x_out
58
-
59
- class UOA(nn.Module):
60
- def __init__(self,channel = 16):
61
- super(UOA,self).__init__()
62
-
63
- self.Haze_in1 = nn.Conv2d(1,channel,kernel_size=1,stride=1,padding=0,bias=False)
64
- self.Haze_in3 = nn.Conv2d(3,channel,kernel_size=1,stride=1,padding=0,bias=False)
65
- self.Haze_in4 = nn.Conv2d(4,channel,kernel_size=1,stride=1,padding=0,bias=False)
66
-
67
- def forward(self,x):
68
- if x.shape[1] == 1:
69
- x_in = self.Haze_in1(x)#3 16
70
- elif x.shape[1] == 3:
71
- x_in = self.Haze_in3(x)#3 16
72
- elif x.shape[1] == 4:
73
- x_in = self.Haze_in4(x)#3 16
74
-
75
- return x_in
76
-
77
- class SCP(nn.Module):
78
- def __init__(self, channel):
79
- super(SCP, self).__init__()
80
- self.cgm = CGM(channel)
81
- self.cim = CIM(channel)
82
-
83
- def forward(self, x):
84
- x_cgm = self.cgm(x)
85
- x_cim = self.cim(x_cgm, x)
86
-
87
- return x_cim
88
-
89
- class GSAO_Left(nn.Module):
90
- def __init__(self,channel):
91
- super(GSAO_Left,self).__init__()
92
-
93
- self.el = GARO(channel)#16
94
- self.em = GARO(channel*2)#32
95
- self.es = GARO(channel*4)#64
96
- self.ess = GARO(channel*8)#128
97
- self.esss = GARO(channel*16)#256
98
-
99
- self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
100
- self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#16 32
101
- self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#32 64
102
- self.conv_estess = nn.Conv2d(4*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 128
103
-
104
- def forward(self,x):
105
-
106
- elout = self.el(x)#16
107
- x_emin = self.conv_eltem(self.maxpool(elout))#32
108
- emout = self.em(x_emin)
109
- x_esin = self.conv_emtes(self.maxpool(emout))
110
- esout = self.es(x_esin)
111
- x_esin = self.conv_estess(self.maxpool(esout))
112
- essout = self.ess(x_esin)#128
113
-
114
- return elout,emout,esout,essout
115
-
116
- class SSDC(nn.Module):
117
- def __init__(self,channel):
118
- super(SSDC,self).__init__()
119
-
120
- self.s1 = SKO(channel*8)#128
121
- self.s2 = SKO(channel*8)#128
122
-
123
- def forward(self,x):
124
- ssdc1 = self.s1(x) + x
125
- ssdc2 = self.s2(ssdc1) + ssdc1
126
-
127
- return ssdc2
128
-
129
- class GSAO_Right(nn.Module):
130
- def __init__(self,channel):
131
- super(GSAO_Right,self).__init__()
132
-
133
- self.dss = GARO(channel*8)#128
134
- self.ds = GARO(channel*4)#64
135
- self.dm = GARO(channel*2)#32
136
- self.dl = GARO(channel)#16
137
-
138
- self.conv_dssstdss = nn.Conv2d(16*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#256 128
139
- self.conv_dsstds = nn.Conv2d(8*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 64
140
- self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 32
141
- self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False)#32 16
142
-
143
- def _upsample(self,x):
144
- _,_,H,W = x.size()
145
- return F.upsample(x,size=(2*H,2*W),mode='bilinear')
146
-
147
- def forward(self,x,ss,s,m,l):
148
-
149
- dssout = self.dss(x+ss)
150
- x_dsin = self.conv_dsstds(self._upsample(dssout))
151
- dsout = self.ds(x_dsin+s)
152
- x_dmin = self.conv_dstdm(self._upsample(dsout))
153
- dmout = self.dm(x_dmin+m)
154
- x_dlin = self.conv_dmtdl(self._upsample(dmout))
155
- dlout = self.dl(x_dlin+l)
156
-
157
- return dlout
158
-
159
-
160
- class SKO(nn.Module):
161
- def __init__(self, in_ch, M=3, G=1, r=4, stride=1, L=32) -> None:
162
- super().__init__()
163
-
164
- d = max(int(in_ch/r), L)
165
- self.M = M
166
- self.in_ch = in_ch
167
- self.convs = nn.ModuleList([])
168
- for i in range(M):
169
- self.convs.append(
170
- nn.Sequential(
171
- nn.Conv2d(in_ch, in_ch, kernel_size=3+i*2, stride=stride, padding = 1+i, groups=G),
172
- nn.BatchNorm2d(in_ch),
173
- nn.ReLU(inplace=True)
174
- )
175
- )
176
- # print("D:", d)
177
- self.fc = nn.Linear(in_ch, d)
178
- self.fcs = nn.ModuleList([])
179
- for i in range(M):
180
- self.fcs.append(nn.Linear(d, in_ch))
181
- self.softmax = nn.Softmax(dim=1)
182
-
183
- def forward(self, x):
184
- for i, conv in enumerate(self.convs):
185
- fea = conv(x).clone().unsqueeze_(dim=1).clone()
186
- if i == 0:
187
- feas = fea
188
- else:
189
- feas = torch.cat([feas.clone(), fea], dim=1)
190
- fea_U = torch.sum(feas.clone(), dim=1)
191
- fea_s = fea_U.clone().mean(-1).mean(-1)
192
- fea_z = self.fc(fea_s)
193
- for i, fc in enumerate(self.fcs):
194
- vector = fc(fea_z).clone().unsqueeze_(dim=1)
195
- if i == 0:
196
- attention_vectors = vector
197
- else:
198
- attention_vectors = torch.cat([attention_vectors.clone(), vector], dim=1)
199
- attention_vectors = self.softmax(attention_vectors.clone())
200
- attention_vectors = attention_vectors.clone().unsqueeze(-1).unsqueeze(-1)
201
- fea_v = (feas * attention_vectors).clone().sum(dim=1)
202
- return fea_v
203
-
204
-
205
- class GARO(nn.Module):
206
- def __init__(self, channel, norm=False):
207
- super(GARO, self).__init__()
208
-
209
- self.conv_1_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
210
- self.conv_2_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
211
- self.act = nn.PReLU(channel)
212
- self.norm = nn.GroupNorm(num_channels=channel, num_groups=1)
213
-
214
- def _upsample(self, x, y):
215
- _, _, H, W = y.size()
216
- return F.upsample(x, size=(H, W), mode='bilinear')
217
-
218
- def forward(self, x):
219
- x_1 = self.act(self.norm(self.conv_1_1(x)))
220
- x_2 = self.act(self.norm(self.conv_2_1(x_1))) + x
221
-
222
- return x_2
223
-
224
- class CGM(nn.Module):
225
- def __init__(self, channel, prompt_len=3, prompt_size=96, lin_dim=16):
226
- super(CGM, self).__init__()
227
- self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, channel, prompt_size, prompt_size))
228
- self.linear_layer = nn.Linear(lin_dim, prompt_len)
229
- self.conv3x3 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
230
-
231
- def forward(self, x):
232
- B, C, H, W = x.shape
233
- emb = x.mean(dim=(-2, -1))
234
- prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
235
- prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
236
- 1, 1,
237
- 1,
238
- 1).squeeze(
239
- 1)
240
- prompt = torch.sum(prompt, dim=1)
241
- prompt = F.interpolate(prompt, (H, W), mode="bilinear")
242
- prompt = self.conv3x3(prompt)
243
-
244
- return prompt
245
-
246
- class CIM(nn.Module):
247
- def __init__(self, channel):
248
- super(CIM, self).__init__()
249
- self.res = ResBlock(2*channel, 2*channel)
250
- self.conv3x3 = nn.Conv2d(2*channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
251
-
252
- def forward(self, prompt, x):
253
-
254
- x = torch.cat((prompt, x), dim=1)
255
- x = self.res(x)
256
- out = self.conv3x3(x)
257
-
258
- return out
259
-
260
-
261
- class DeformConv2d(nn.Module):
262
- def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
263
- super(DeformConv2d, self).__init__()
264
- self.kernel_size = kernel_size
265
- self.padding = padding
266
- self.stride = stride
267
- self.zero_padding = nn.ZeroPad2d(padding)
268
- self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
269
-
270
- self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
271
- nn.init.constant_(self.p_conv.weight, 0)
272
- self.p_conv.register_backward_hook(self._set_lr)
273
-
274
- self.modulation = modulation
275
- if modulation:
276
- self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
277
- nn.init.constant_(self.m_conv.weight, 0)
278
- self.m_conv.register_backward_hook(self._set_lr)
279
-
280
- @staticmethod
281
- def _set_lr(module, grad_input, grad_output):
282
- grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
283
- grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
284
-
285
- def forward(self, x):
286
- offset = self.p_conv(x)
287
- if self.modulation:
288
- m = torch.sigmoid(self.m_conv(x))
289
-
290
- dtype = offset.data.type()
291
- ks = self.kernel_size
292
- N = offset.size(1) // 2
293
-
294
- if self.padding:
295
- x = self.zero_padding(x)
296
-
297
- p = self._get_p(offset, dtype)
298
-
299
- p = p.contiguous().permute(0, 2, 3, 1)
300
- q_lt = p.detach().floor()
301
- q_rb = q_lt + 1
302
-
303
- q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
304
- q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
305
- q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
306
- q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
307
-
308
- p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
309
-
310
- g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
311
- g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
312
- g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
313
- g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
314
-
315
- x_q_lt = self._get_x_q(x, q_lt, N)
316
- x_q_rb = self._get_x_q(x, q_rb, N)
317
- x_q_lb = self._get_x_q(x, q_lb, N)
318
- x_q_rt = self._get_x_q(x, q_rt, N)
319
-
320
- x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
321
- g_rb.unsqueeze(dim=1) * x_q_rb + \
322
- g_lb.unsqueeze(dim=1) * x_q_lb + \
323
- g_rt.unsqueeze(dim=1) * x_q_rt
324
-
325
- if self.modulation:
326
- m = m.contiguous().permute(0, 2, 3, 1)
327
- m = m.unsqueeze(dim=1)
328
- m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
329
- x_offset *= m
330
-
331
- x_offset = self._reshape_x_offset(x_offset, ks)
332
- out = self.conv(x_offset)
333
-
334
- return out
335
-
336
- def _get_p_n(self, N, dtype):
337
- p_n_x, p_n_y = torch.meshgrid(
338
- torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
339
- torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
340
- p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
341
- p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
342
-
343
- return p_n
344
-
345
- def _get_p_0(self, h, w, N, dtype):
346
- p_0_x, p_0_y = torch.meshgrid(
347
- torch.arange(1, h*self.stride+1, self.stride),
348
- torch.arange(1, w*self.stride+1, self.stride))
349
- p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
350
- p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
351
- p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
352
-
353
- return p_0
354
-
355
- def _get_p(self, offset, dtype):
356
- N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
357
-
358
- p_n = self._get_p_n(N, dtype)
359
- p_0 = self._get_p_0(h, w, N, dtype)
360
- p = p_0 + p_n + offset
361
- return p
362
-
363
- def _get_x_q(self, x, q, N):
364
- b, h, w, _ = q.size()
365
- padded_w = x.size(3)
366
- c = x.size(1)
367
- x = x.contiguous().view(b, c, -1)
368
-
369
- index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y
370
- index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
371
-
372
- x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
373
-
374
- return x_offset
375
-
376
- @staticmethod
377
- def _reshape_x_offset(x_offset, ks):
378
- b, c, h, w, N = x_offset.size()
379
- x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
380
- x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
381
-
382
- return x_offset
383
-
384
- class DeformConv2d(nn.Module):
385
- def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
386
- super(DeformConv2d, self).__init__()
387
- self.kernel_size = kernel_size
388
- self.padding = padding
389
- self.stride = stride
390
- self.zero_padding = nn.ZeroPad2d(padding)
391
- self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
392
-
393
- self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
394
- nn.init.constant_(self.p_conv.weight, 0)
395
- self.p_conv.register_backward_hook(self._set_lr)
396
-
397
- self.modulation = modulation
398
- if modulation:
399
- self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
400
- nn.init.constant_(self.m_conv.weight, 0)
401
- self.m_conv.register_backward_hook(self._set_lr)
402
-
403
- @staticmethod
404
- def _set_lr(module, grad_input, grad_output):
405
- grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
406
- grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
407
-
408
- def forward(self, x):
409
- offset = self.p_conv(x)
410
- if self.modulation:
411
- m = torch.sigmoid(self.m_conv(x))
412
-
413
- dtype = offset.data.type()
414
- ks = self.kernel_size
415
- N = offset.size(1) // 2
416
-
417
- if self.padding:
418
- x = self.zero_padding(x)
419
-
420
- p = self._get_p(offset, dtype)
421
-
422
- p = p.contiguous().permute(0, 2, 3, 1)
423
- q_lt = p.detach().floor()
424
- q_rb = q_lt + 1
425
-
426
- q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
427
- q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
428
- q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
429
- q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
430
-
431
- p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
432
-
433
- g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
434
- g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
435
- g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
436
- g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
437
-
438
- x_q_lt = self._get_x_q(x, q_lt, N)
439
- x_q_rb = self._get_x_q(x, q_rb, N)
440
- x_q_lb = self._get_x_q(x, q_lb, N)
441
- x_q_rt = self._get_x_q(x, q_rt, N)
442
-
443
- x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
444
- g_rb.unsqueeze(dim=1) * x_q_rb + \
445
- g_lb.unsqueeze(dim=1) * x_q_lb + \
446
- g_rt.unsqueeze(dim=1) * x_q_rt
447
-
448
- if self.modulation:
449
- m = m.contiguous().permute(0, 2, 3, 1)
450
- m = m.unsqueeze(dim=1)
451
- m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
452
- x_offset *= m
453
-
454
- x_offset = self._reshape_x_offset(x_offset, ks)
455
- out = self.conv(x_offset)
456
-
457
- return out
458
-
459
- def _get_p_n(self, N, dtype):
460
- p_n_x, p_n_y = torch.meshgrid(
461
- torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
462
- torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
463
- p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
464
- p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
465
-
466
- return p_n
467
-
468
- def _get_p_0(self, h, w, N, dtype):
469
- p_0_x, p_0_y = torch.meshgrid(
470
- torch.arange(1, h*self.stride+1, self.stride),
471
- torch.arange(1, w*self.stride+1, self.stride))
472
- p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
473
- p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
474
- p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
475
-
476
- return p_0
477
-
478
- def _get_p(self, offset, dtype):
479
- N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
480
-
481
- p_n = self._get_p_n(N, dtype)
482
- p_0 = self._get_p_0(h, w, N, dtype)
483
- p = p_0 + p_n + offset
484
- return p
485
-
486
- def _get_x_q(self, x, q, N):
487
- b, h, w, _ = q.size()
488
- padded_w = x.size(3)
489
- c = x.size(1)
490
- x = x.contiguous().view(b, c, -1)
491
-
492
- index = q[..., :N]*padded_w + q[..., N:]
493
- index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
494
-
495
- x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
496
-
497
- return x_offset
498
-
499
- @staticmethod
500
- def _reshape_x_offset(x_offset, ks):
501
- b, c, h, w, N = x_offset.size()
502
- x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
503
- x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
504
-
505
- return x_offset
506
-
507
-
508
- class BasicConv(nn.Module):
509
- def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
510
- super(BasicConv, self).__init__()
511
- if bias and norm:
512
- bias = False
513
-
514
- padding = kernel_size // 2
515
- layers = list()
516
- if transpose:
517
- padding = kernel_size // 2 -1
518
- layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
519
- else:
520
- layers.append(
521
- nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
522
- if norm:
523
- layers.append(nn.BatchNorm2d(out_channel))
524
- if relu:
525
- layers.append(nn.GELU())
526
- self.main = nn.Sequential(*layers)
527
-
528
- def forward(self, x):
529
- return self.main(x)
530
-
531
-
532
- class ResBlock(nn.Module):
533
- def __init__(self, in_channel, out_channel):
534
- super(ResBlock, self).__init__()
535
- self.main = nn.Sequential(
536
- BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
537
- BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
538
- )
539
-
540
- def forward(self, x):
541
- return self.main(x) + x
542
-
543
-
544
- from thop import profile
545
-
546
- if __name__ == '__main__':
547
-
548
- device = "cuda" if torch.cuda.is_available() else "cpu"
549
-
550
- net = VISION().to(device)
551
-
552
- input = torch.randn(1, 4, 512, 512).to(device)
553
- output = net(input)
554
-
555
- macs, params = profile(net, inputs=(input, ))
556
-
557
- print('macs: ', macs, 'params: ', params)
558
- print('macs: %.2f G, params: %.2f M' % (macs / 1000000000.0, params / 1000000.0))
559
- print(output.shape)
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from thop import profile
10
+
11
+
12
+ class VISION(nn.Module):
13
+ def __init__(self,channel = 16):
14
+ super(VISION,self).__init__()
15
+ self.aoe = AOE(channel)
16
+ self.gsao = GSAO(channel)
17
+
18
+ def forward(self,x):
19
+ x_aoe = self.aoe(x)
20
+ out = self.gsao(x_aoe)
21
+
22
+ return out
23
+
24
+ class GSAO(nn.Module):
25
+ def __init__(self,channel = 16):
26
+ super(GSAO,self).__init__()
27
+
28
+ self.gsao_left = GSAO_Left(channel)
29
+
30
+ self.ssdc = SSDC(channel)
31
+
32
+ self.gsao_right = GSAO_Right(channel)
33
+
34
+ self.gsao_out = nn.Conv2d(channel,3,kernel_size=1,stride=1,padding=0,bias=False)
35
+
36
+ def forward(self,x):
37
+
38
+ L,M,S,SS = self.gsao_left(x)
39
+ ssdc = self.ssdc(SS)
40
+ x_out = self.gsao_right(ssdc,SS,S,M,L)
41
+ out = self.gsao_out(x_out)
42
+
43
+ return out
44
+
45
+
46
+ class AOE(nn.Module):
47
+ def __init__(self,channel = 16):
48
+ super(AOE,self).__init__()
49
+
50
+ self.uoa = UOA(channel)
51
+ self.scp = SCP(channel)
52
+
53
+ def forward(self,x):
54
+ x_in = self.uoa(x)
55
+ x_out = self.scp(x_in)#3 16
56
+
57
+ return x_out
58
+
59
+ class UOA(nn.Module):
60
+ def __init__(self,channel = 16):
61
+ super(UOA,self).__init__()
62
+
63
+ self.Haze_in1 = nn.Conv2d(1,channel,kernel_size=1,stride=1,padding=0,bias=False)
64
+ self.Haze_in3 = nn.Conv2d(3,channel,kernel_size=1,stride=1,padding=0,bias=False)
65
+ self.Haze_in4 = nn.Conv2d(4,channel,kernel_size=1,stride=1,padding=0,bias=False)
66
+
67
+ def forward(self,x):
68
+ if x.shape[1] == 1:
69
+ x_in = self.Haze_in1(x)#3 16
70
+ elif x.shape[1] == 3:
71
+ x_in = self.Haze_in3(x)#3 16
72
+ elif x.shape[1] == 4:
73
+ x_in = self.Haze_in4(x)#3 16
74
+
75
+ return x_in
76
+
77
+ class SCP(nn.Module):
78
+ def __init__(self, channel):
79
+ super(SCP, self).__init__()
80
+ self.cgm = CGM(channel)
81
+ self.cim = CIM(channel)
82
+
83
+ def forward(self, x):
84
+ x_cgm = self.cgm(x)
85
+ x_cim = self.cim(x_cgm, x)
86
+
87
+ return x_cim
88
+
89
+ class GSAO_Left(nn.Module):
90
+ def __init__(self,channel):
91
+ super(GSAO_Left,self).__init__()
92
+
93
+ self.el = GARO(channel)#16
94
+ self.em = GARO(channel*2)#32
95
+ self.es = GARO(channel*4)#64
96
+ self.ess = GARO(channel*8)#128
97
+ self.esss = GARO(channel*16)#256
98
+
99
+ self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
100
+ self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#16 32
101
+ self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#32 64
102
+ self.conv_estess = nn.Conv2d(4*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 128
103
+
104
+ def forward(self,x):
105
+
106
+ elout = self.el(x)#16
107
+ x_emin = self.conv_eltem(self.maxpool(elout))#32
108
+ emout = self.em(x_emin)
109
+ x_esin = self.conv_emtes(self.maxpool(emout))
110
+ esout = self.es(x_esin)
111
+ x_esin = self.conv_estess(self.maxpool(esout))
112
+ essout = self.ess(x_esin)#128
113
+
114
+ return elout,emout,esout,essout
115
+
116
+ class SSDC(nn.Module):
117
+ def __init__(self,channel):
118
+ super(SSDC,self).__init__()
119
+
120
+ self.s1 = SKO(channel*8)#128
121
+ self.s2 = SKO(channel*8)#128
122
+
123
+ def forward(self,x):
124
+ ssdc1 = self.s1(x) + x
125
+ ssdc2 = self.s2(ssdc1) + ssdc1
126
+
127
+ return ssdc2
128
+
129
+ class GSAO_Right(nn.Module):
130
+ def __init__(self,channel):
131
+ super(GSAO_Right,self).__init__()
132
+
133
+ self.dss = GARO(channel*8)#128
134
+ self.ds = GARO(channel*4)#64
135
+ self.dm = GARO(channel*2)#32
136
+ self.dl = GARO(channel)#16
137
+
138
+ self.conv_dssstdss = nn.Conv2d(16*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#256 128
139
+ self.conv_dsstds = nn.Conv2d(8*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 64
140
+ self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 32
141
+ self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False)#32 16
142
+
143
+ def _upsample(self,x):
144
+ _,_,H,W = x.size()
145
+ return F.upsample(x,size=(2*H,2*W),mode='bilinear')
146
+
147
+ def forward(self,x,ss,s,m,l):
148
+
149
+ dssout = self.dss(x+ss)
150
+ x_dsin = self.conv_dsstds(self._upsample(dssout))
151
+ dsout = self.ds(x_dsin+s)
152
+ x_dmin = self.conv_dstdm(self._upsample(dsout))
153
+ dmout = self.dm(x_dmin+m)
154
+ x_dlin = self.conv_dmtdl(self._upsample(dmout))
155
+ dlout = self.dl(x_dlin+l)
156
+
157
+ return dlout
158
+
159
+
160
+ class SKO(nn.Module):
161
+ def __init__(self, in_ch, M=3, G=1, r=4, stride=1, L=32) -> None:
162
+ super().__init__()
163
+
164
+ d = max(int(in_ch/r), L)
165
+ self.M = M
166
+ self.in_ch = in_ch
167
+ self.convs = nn.ModuleList([])
168
+ for i in range(M):
169
+ self.convs.append(
170
+ nn.Sequential(
171
+ nn.Conv2d(in_ch, in_ch, kernel_size=3+i*2, stride=stride, padding = 1+i, groups=G),
172
+ nn.BatchNorm2d(in_ch),
173
+ nn.ReLU(inplace=True)
174
+ )
175
+ )
176
+ # print("D:", d)
177
+ self.fc = nn.Linear(in_ch, d)
178
+ self.fcs = nn.ModuleList([])
179
+ for i in range(M):
180
+ self.fcs.append(nn.Linear(d, in_ch))
181
+ self.softmax = nn.Softmax(dim=1)
182
+
183
+ def forward(self, x):
184
+ for i, conv in enumerate(self.convs):
185
+ fea = conv(x).clone().unsqueeze_(dim=1).clone()
186
+ if i == 0:
187
+ feas = fea
188
+ else:
189
+ feas = torch.cat([feas.clone(), fea], dim=1)
190
+ fea_U = torch.sum(feas.clone(), dim=1)
191
+ fea_s = fea_U.clone().mean(-1).mean(-1)
192
+ fea_z = self.fc(fea_s)
193
+ for i, fc in enumerate(self.fcs):
194
+ vector = fc(fea_z).clone().unsqueeze_(dim=1)
195
+ if i == 0:
196
+ attention_vectors = vector
197
+ else:
198
+ attention_vectors = torch.cat([attention_vectors.clone(), vector], dim=1)
199
+ attention_vectors = self.softmax(attention_vectors.clone())
200
+ attention_vectors = attention_vectors.clone().unsqueeze(-1).unsqueeze(-1)
201
+ fea_v = (feas * attention_vectors).clone().sum(dim=1)
202
+ return fea_v
203
+
204
+
205
+ class GARO(nn.Module):
206
+ def __init__(self, channel, norm=False):
207
+ super(GARO, self).__init__()
208
+
209
+ self.conv_1_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
210
+ self.conv_2_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
211
+ self.act = nn.PReLU(channel)
212
+ self.norm = nn.GroupNorm(num_channels=channel, num_groups=1)
213
+
214
+ def _upsample(self, x, y):
215
+ _, _, H, W = y.size()
216
+ return F.upsample(x, size=(H, W), mode='bilinear')
217
+
218
+ def forward(self, x):
219
+ x_1 = self.act(self.norm(self.conv_1_1(x)))
220
+ x_2 = self.act(self.norm(self.conv_2_1(x_1))) + x
221
+
222
+ return x_2
223
+
224
+ class CGM(nn.Module):
225
+ def __init__(self, channel, prompt_len=3, prompt_size=96, lin_dim=16):
226
+ super(CGM, self).__init__()
227
+ self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, channel, prompt_size, prompt_size))
228
+ self.linear_layer = nn.Linear(lin_dim, prompt_len)
229
+ self.conv3x3 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
230
+
231
+ def forward(self, x):
232
+ B, C, H, W = x.shape
233
+ emb = x.mean(dim=(-2, -1))
234
+ prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
235
+ prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
236
+ 1, 1,
237
+ 1,
238
+ 1).squeeze(
239
+ 1)
240
+ prompt = torch.sum(prompt, dim=1)
241
+ prompt = F.interpolate(prompt, (H, W), mode="bilinear")
242
+ prompt = self.conv3x3(prompt)
243
+
244
+ return prompt
245
+
246
+ class CIM(nn.Module):
247
+ def __init__(self, channel):
248
+ super(CIM, self).__init__()
249
+ self.res = ResBlock(2*channel, 2*channel)
250
+ self.conv3x3 = nn.Conv2d(2*channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
251
+
252
+ def forward(self, prompt, x):
253
+
254
+ x = torch.cat((prompt, x), dim=1)
255
+ x = self.res(x)
256
+ out = self.conv3x3(x)
257
+
258
+ return out
259
+
260
+
261
+ class DeformConv2d(nn.Module):
262
+ def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
263
+ super(DeformConv2d, self).__init__()
264
+ self.kernel_size = kernel_size
265
+ self.padding = padding
266
+ self.stride = stride
267
+ self.zero_padding = nn.ZeroPad2d(padding)
268
+ self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
269
+
270
+ self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
271
+ nn.init.constant_(self.p_conv.weight, 0)
272
+ self.p_conv.register_backward_hook(self._set_lr)
273
+
274
+ self.modulation = modulation
275
+ if modulation:
276
+ self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
277
+ nn.init.constant_(self.m_conv.weight, 0)
278
+ self.m_conv.register_backward_hook(self._set_lr)
279
+
280
+ @staticmethod
281
+ def _set_lr(module, grad_input, grad_output):
282
+ grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
283
+ grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
284
+
285
+ def forward(self, x):
286
+ offset = self.p_conv(x)
287
+ if self.modulation:
288
+ m = torch.sigmoid(self.m_conv(x))
289
+
290
+ dtype = offset.data.type()
291
+ ks = self.kernel_size
292
+ N = offset.size(1) // 2
293
+
294
+ if self.padding:
295
+ x = self.zero_padding(x)
296
+
297
+ p = self._get_p(offset, dtype)
298
+
299
+ p = p.contiguous().permute(0, 2, 3, 1)
300
+ q_lt = p.detach().floor()
301
+ q_rb = q_lt + 1
302
+
303
+ q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
304
+ q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
305
+ q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
306
+ q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
307
+
308
+ p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
309
+
310
+ g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
311
+ g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
312
+ g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
313
+ g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
314
+
315
+ x_q_lt = self._get_x_q(x, q_lt, N)
316
+ x_q_rb = self._get_x_q(x, q_rb, N)
317
+ x_q_lb = self._get_x_q(x, q_lb, N)
318
+ x_q_rt = self._get_x_q(x, q_rt, N)
319
+
320
+ x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
321
+ g_rb.unsqueeze(dim=1) * x_q_rb + \
322
+ g_lb.unsqueeze(dim=1) * x_q_lb + \
323
+ g_rt.unsqueeze(dim=1) * x_q_rt
324
+
325
+ if self.modulation:
326
+ m = m.contiguous().permute(0, 2, 3, 1)
327
+ m = m.unsqueeze(dim=1)
328
+ m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
329
+ x_offset *= m
330
+
331
+ x_offset = self._reshape_x_offset(x_offset, ks)
332
+ out = self.conv(x_offset)
333
+
334
+ return out
335
+
336
+ def _get_p_n(self, N, dtype):
337
+ p_n_x, p_n_y = torch.meshgrid(
338
+ torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
339
+ torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
340
+ p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
341
+ p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
342
+
343
+ return p_n
344
+
345
+ def _get_p_0(self, h, w, N, dtype):
346
+ p_0_x, p_0_y = torch.meshgrid(
347
+ torch.arange(1, h*self.stride+1, self.stride),
348
+ torch.arange(1, w*self.stride+1, self.stride))
349
+ p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
350
+ p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
351
+ p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
352
+
353
+ return p_0
354
+
355
+ def _get_p(self, offset, dtype):
356
+ N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
357
+
358
+ p_n = self._get_p_n(N, dtype)
359
+ p_0 = self._get_p_0(h, w, N, dtype)
360
+ p = p_0 + p_n + offset
361
+ return p
362
+
363
+ def _get_x_q(self, x, q, N):
364
+ b, h, w, _ = q.size()
365
+ padded_w = x.size(3)
366
+ c = x.size(1)
367
+ x = x.contiguous().view(b, c, -1)
368
+
369
+ index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y
370
+ index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
371
+
372
+ x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
373
+
374
+ return x_offset
375
+
376
+ @staticmethod
377
+ def _reshape_x_offset(x_offset, ks):
378
+ b, c, h, w, N = x_offset.size()
379
+ x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
380
+ x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
381
+
382
+ return x_offset
383
+
384
+ class DeformConv2d(nn.Module):
385
+ def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
386
+ super(DeformConv2d, self).__init__()
387
+ self.kernel_size = kernel_size
388
+ self.padding = padding
389
+ self.stride = stride
390
+ self.zero_padding = nn.ZeroPad2d(padding)
391
+ self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
392
+
393
+ self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
394
+ nn.init.constant_(self.p_conv.weight, 0)
395
+ self.p_conv.register_backward_hook(self._set_lr)
396
+
397
+ self.modulation = modulation
398
+ if modulation:
399
+ self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
400
+ nn.init.constant_(self.m_conv.weight, 0)
401
+ self.m_conv.register_backward_hook(self._set_lr)
402
+
403
+ @staticmethod
404
+ def _set_lr(module, grad_input, grad_output):
405
+ grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
406
+ grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
407
+
408
+ def forward(self, x):
409
+ offset = self.p_conv(x)
410
+ if self.modulation:
411
+ m = torch.sigmoid(self.m_conv(x))
412
+
413
+ dtype = offset.data.type()
414
+ ks = self.kernel_size
415
+ N = offset.size(1) // 2
416
+
417
+ if self.padding:
418
+ x = self.zero_padding(x)
419
+
420
+ p = self._get_p(offset, dtype)
421
+
422
+ p = p.contiguous().permute(0, 2, 3, 1)
423
+ q_lt = p.detach().floor()
424
+ q_rb = q_lt + 1
425
+
426
+ q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
427
+ q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
428
+ q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
429
+ q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
430
+
431
+ p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
432
+
433
+ g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
434
+ g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
435
+ g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
436
+ g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
437
+
438
+ x_q_lt = self._get_x_q(x, q_lt, N)
439
+ x_q_rb = self._get_x_q(x, q_rb, N)
440
+ x_q_lb = self._get_x_q(x, q_lb, N)
441
+ x_q_rt = self._get_x_q(x, q_rt, N)
442
+
443
+ x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
444
+ g_rb.unsqueeze(dim=1) * x_q_rb + \
445
+ g_lb.unsqueeze(dim=1) * x_q_lb + \
446
+ g_rt.unsqueeze(dim=1) * x_q_rt
447
+
448
+ if self.modulation:
449
+ m = m.contiguous().permute(0, 2, 3, 1)
450
+ m = m.unsqueeze(dim=1)
451
+ m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
452
+ x_offset *= m
453
+
454
+ x_offset = self._reshape_x_offset(x_offset, ks)
455
+ out = self.conv(x_offset)
456
+
457
+ return out
458
+
459
+ def _get_p_n(self, N, dtype):
460
+ p_n_x, p_n_y = torch.meshgrid(
461
+ torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
462
+ torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
463
+ p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
464
+ p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
465
+
466
+ return p_n
467
+
468
+ def _get_p_0(self, h, w, N, dtype):
469
+ p_0_x, p_0_y = torch.meshgrid(
470
+ torch.arange(1, h*self.stride+1, self.stride),
471
+ torch.arange(1, w*self.stride+1, self.stride))
472
+ p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
473
+ p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
474
+ p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
475
+
476
+ return p_0
477
+
478
+ def _get_p(self, offset, dtype):
479
+ N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
480
+
481
+ p_n = self._get_p_n(N, dtype)
482
+ p_0 = self._get_p_0(h, w, N, dtype)
483
+ p = p_0 + p_n + offset
484
+ return p
485
+
486
+ def _get_x_q(self, x, q, N):
487
+ b, h, w, _ = q.size()
488
+ padded_w = x.size(3)
489
+ c = x.size(1)
490
+ x = x.contiguous().view(b, c, -1)
491
+
492
+ index = q[..., :N]*padded_w + q[..., N:]
493
+ index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
494
+
495
+ x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
496
+
497
+ return x_offset
498
+
499
+ @staticmethod
500
+ def _reshape_x_offset(x_offset, ks):
501
+ b, c, h, w, N = x_offset.size()
502
+ x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
503
+ x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
504
+
505
+ return x_offset
506
+
507
+
508
+ class BasicConv(nn.Module):
509
+ def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
510
+ super(BasicConv, self).__init__()
511
+ if bias and norm:
512
+ bias = False
513
+
514
+ padding = kernel_size // 2
515
+ layers = list()
516
+ if transpose:
517
+ padding = kernel_size // 2 -1
518
+ layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
519
+ else:
520
+ layers.append(
521
+ nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
522
+ if norm:
523
+ layers.append(nn.BatchNorm2d(out_channel))
524
+ if relu:
525
+ layers.append(nn.GELU())
526
+ self.main = nn.Sequential(*layers)
527
+
528
+ def forward(self, x):
529
+ return self.main(x)
530
+
531
+
532
+ class ResBlock(nn.Module):
533
+ def __init__(self, in_channel, out_channel):
534
+ super(ResBlock, self).__init__()
535
+ self.main = nn.Sequential(
536
+ BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
537
+ BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
538
+ )
539
+
540
+ def forward(self, x):
541
+ return self.main(x) + x
542
+