esrgan_model_arch.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. # this file is adapted from https://github.com/victorca25/iNNfer
  2. from collections import OrderedDict
  3. import math
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. ####################
  8. # RRDBNet Generator
  9. ####################
  10. class RRDBNet(nn.Module):
  11. def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
  12. act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
  13. finalact=None, gaussian_noise=False, plus=False):
  14. super(RRDBNet, self).__init__()
  15. n_upscale = int(math.log(upscale, 2))
  16. if upscale == 3:
  17. n_upscale = 1
  18. self.resrgan_scale = 0
  19. if in_nc % 16 == 0:
  20. self.resrgan_scale = 1
  21. elif in_nc != 4 and in_nc % 4 == 0:
  22. self.resrgan_scale = 2
  23. fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
  24. rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
  25. norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
  26. gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
  27. LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
  28. if upsample_mode == 'upconv':
  29. upsample_block = upconv_block
  30. elif upsample_mode == 'pixelshuffle':
  31. upsample_block = pixelshuffle_block
  32. else:
  33. raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
  34. if upscale == 3:
  35. upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
  36. else:
  37. upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
  38. HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
  39. HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
  40. outact = act(finalact) if finalact else None
  41. self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
  42. *upsampler, HR_conv0, HR_conv1, outact)
  43. def forward(self, x, outm=None):
  44. if self.resrgan_scale == 1:
  45. feat = pixel_unshuffle(x, scale=4)
  46. elif self.resrgan_scale == 2:
  47. feat = pixel_unshuffle(x, scale=2)
  48. else:
  49. feat = x
  50. return self.model(feat)
  51. class RRDB(nn.Module):
  52. """
  53. Residual in Residual Dense Block
  54. (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
  55. """
  56. def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
  57. norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
  58. spectral_norm=False, gaussian_noise=False, plus=False):
  59. super(RRDB, self).__init__()
  60. # This is for backwards compatibility with existing models
  61. if nr == 3:
  62. self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
  63. norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
  64. gaussian_noise=gaussian_noise, plus=plus)
  65. self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
  66. norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
  67. gaussian_noise=gaussian_noise, plus=plus)
  68. self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
  69. norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
  70. gaussian_noise=gaussian_noise, plus=plus)
  71. else:
  72. RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
  73. norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
  74. gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
  75. self.RDBs = nn.Sequential(*RDB_list)
  76. def forward(self, x):
  77. if hasattr(self, 'RDB1'):
  78. out = self.RDB1(x)
  79. out = self.RDB2(out)
  80. out = self.RDB3(out)
  81. else:
  82. out = self.RDBs(x)
  83. return out * 0.2 + x
  84. class ResidualDenseBlock_5C(nn.Module):
  85. """
  86. Residual Dense Block
  87. The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
  88. Modified options that can be used:
  89. - "Partial Convolution based Padding" arXiv:1811.11718
  90. - "Spectral normalization" arXiv:1802.05957
  91. - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
  92. {Rakotonirina} and A. {Rasoanaivo}
  93. """
  94. def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
  95. norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
  96. spectral_norm=False, gaussian_noise=False, plus=False):
  97. super(ResidualDenseBlock_5C, self).__init__()
  98. self.noise = GaussianNoise() if gaussian_noise else None
  99. self.conv1x1 = conv1x1(nf, gc) if plus else None
  100. self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
  101. norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
  102. spectral_norm=spectral_norm)
  103. self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
  104. norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
  105. spectral_norm=spectral_norm)
  106. self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
  107. norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
  108. spectral_norm=spectral_norm)
  109. self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
  110. norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
  111. spectral_norm=spectral_norm)
  112. if mode == 'CNA':
  113. last_act = None
  114. else:
  115. last_act = act_type
  116. self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
  117. norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
  118. spectral_norm=spectral_norm)
  119. def forward(self, x):
  120. x1 = self.conv1(x)
  121. x2 = self.conv2(torch.cat((x, x1), 1))
  122. if self.conv1x1:
  123. x2 = x2 + self.conv1x1(x)
  124. x3 = self.conv3(torch.cat((x, x1, x2), 1))
  125. x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
  126. if self.conv1x1:
  127. x4 = x4 + x2
  128. x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
  129. if self.noise:
  130. return self.noise(x5.mul(0.2) + x)
  131. else:
  132. return x5 * 0.2 + x
  133. ####################
  134. # ESRGANplus
  135. ####################
  136. class GaussianNoise(nn.Module):
  137. def __init__(self, sigma=0.1, is_relative_detach=False):
  138. super().__init__()
  139. self.sigma = sigma
  140. self.is_relative_detach = is_relative_detach
  141. self.noise = torch.tensor(0, dtype=torch.float)
  142. def forward(self, x):
  143. if self.training and self.sigma != 0:
  144. self.noise = self.noise.to(x.device)
  145. scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
  146. sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
  147. x = x + sampled_noise
  148. return x
  149. def conv1x1(in_planes, out_planes, stride=1):
  150. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  151. ####################
  152. # SRVGGNetCompact
  153. ####################
  154. class SRVGGNetCompact(nn.Module):
  155. """A compact VGG-style network structure for super-resolution.
  156. This class is copied from https://github.com/xinntao/Real-ESRGAN
  157. """
  158. def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
  159. super(SRVGGNetCompact, self).__init__()
  160. self.num_in_ch = num_in_ch
  161. self.num_out_ch = num_out_ch
  162. self.num_feat = num_feat
  163. self.num_conv = num_conv
  164. self.upscale = upscale
  165. self.act_type = act_type
  166. self.body = nn.ModuleList()
  167. # the first conv
  168. self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
  169. # the first activation
  170. if act_type == 'relu':
  171. activation = nn.ReLU(inplace=True)
  172. elif act_type == 'prelu':
  173. activation = nn.PReLU(num_parameters=num_feat)
  174. elif act_type == 'leakyrelu':
  175. activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
  176. self.body.append(activation)
  177. # the body structure
  178. for _ in range(num_conv):
  179. self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
  180. # activation
  181. if act_type == 'relu':
  182. activation = nn.ReLU(inplace=True)
  183. elif act_type == 'prelu':
  184. activation = nn.PReLU(num_parameters=num_feat)
  185. elif act_type == 'leakyrelu':
  186. activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
  187. self.body.append(activation)
  188. # the last conv
  189. self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
  190. # upsample
  191. self.upsampler = nn.PixelShuffle(upscale)
  192. def forward(self, x):
  193. out = x
  194. for i in range(0, len(self.body)):
  195. out = self.body[i](out)
  196. out = self.upsampler(out)
  197. # add the nearest upsampled image, so that the network learns the residual
  198. base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
  199. out += base
  200. return out
  201. ####################
  202. # Upsampler
  203. ####################
  204. class Upsample(nn.Module):
  205. r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
  206. The input data is assumed to be of the form
  207. `minibatch x channels x [optional depth] x [optional height] x width`.
  208. """
  209. def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
  210. super(Upsample, self).__init__()
  211. if isinstance(scale_factor, tuple):
  212. self.scale_factor = tuple(float(factor) for factor in scale_factor)
  213. else:
  214. self.scale_factor = float(scale_factor) if scale_factor else None
  215. self.mode = mode
  216. self.size = size
  217. self.align_corners = align_corners
  218. def forward(self, x):
  219. return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
  220. def extra_repr(self):
  221. if self.scale_factor is not None:
  222. info = f'scale_factor={self.scale_factor}'
  223. else:
  224. info = f'size={self.size}'
  225. info += f', mode={self.mode}'
  226. return info
  227. def pixel_unshuffle(x, scale):
  228. """ Pixel unshuffle.
  229. Args:
  230. x (Tensor): Input feature with shape (b, c, hh, hw).
  231. scale (int): Downsample ratio.
  232. Returns:
  233. Tensor: the pixel unshuffled feature.
  234. """
  235. b, c, hh, hw = x.size()
  236. out_channel = c * (scale**2)
  237. assert hh % scale == 0 and hw % scale == 0
  238. h = hh // scale
  239. w = hw // scale
  240. x_view = x.view(b, c, h, scale, w, scale)
  241. return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
  242. def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
  243. pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
  244. """
  245. Pixel shuffle layer
  246. (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
  247. Neural Network, CVPR17)
  248. """
  249. conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
  250. pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
  251. pixel_shuffle = nn.PixelShuffle(upscale_factor)
  252. n = norm(norm_type, out_nc) if norm_type else None
  253. a = act(act_type) if act_type else None
  254. return sequential(conv, pixel_shuffle, n, a)
  255. def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
  256. pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
  257. """ Upconv layer """
  258. upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
  259. upsample = Upsample(scale_factor=upscale_factor, mode=mode)
  260. conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
  261. pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
  262. return sequential(upsample, conv)
  263. ####################
  264. # Basic blocks
  265. ####################
  266. def make_layer(basic_block, num_basic_block, **kwarg):
  267. """Make layers by stacking the same blocks.
  268. Args:
  269. basic_block (nn.module): nn.module class for basic block. (block)
  270. num_basic_block (int): number of blocks. (n_layers)
  271. Returns:
  272. nn.Sequential: Stacked blocks in nn.Sequential.
  273. """
  274. layers = []
  275. for _ in range(num_basic_block):
  276. layers.append(basic_block(**kwarg))
  277. return nn.Sequential(*layers)
  278. def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
  279. """ activation helper """
  280. act_type = act_type.lower()
  281. if act_type == 'relu':
  282. layer = nn.ReLU(inplace)
  283. elif act_type in ('leakyrelu', 'lrelu'):
  284. layer = nn.LeakyReLU(neg_slope, inplace)
  285. elif act_type == 'prelu':
  286. layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
  287. elif act_type == 'tanh': # [-1, 1] range output
  288. layer = nn.Tanh()
  289. elif act_type == 'sigmoid': # [0, 1] range output
  290. layer = nn.Sigmoid()
  291. else:
  292. raise NotImplementedError(f'activation layer [{act_type}] is not found')
  293. return layer
  294. class Identity(nn.Module):
  295. def __init__(self, *kwargs):
  296. super(Identity, self).__init__()
  297. def forward(self, x, *kwargs):
  298. return x
  299. def norm(norm_type, nc):
  300. """ Return a normalization layer """
  301. norm_type = norm_type.lower()
  302. if norm_type == 'batch':
  303. layer = nn.BatchNorm2d(nc, affine=True)
  304. elif norm_type == 'instance':
  305. layer = nn.InstanceNorm2d(nc, affine=False)
  306. elif norm_type == 'none':
  307. def norm_layer(x): return Identity()
  308. else:
  309. raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
  310. return layer
  311. def pad(pad_type, padding):
  312. """ padding layer helper """
  313. pad_type = pad_type.lower()
  314. if padding == 0:
  315. return None
  316. if pad_type == 'reflect':
  317. layer = nn.ReflectionPad2d(padding)
  318. elif pad_type == 'replicate':
  319. layer = nn.ReplicationPad2d(padding)
  320. elif pad_type == 'zero':
  321. layer = nn.ZeroPad2d(padding)
  322. else:
  323. raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
  324. return layer
  325. def get_valid_padding(kernel_size, dilation):
  326. kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
  327. padding = (kernel_size - 1) // 2
  328. return padding
  329. class ShortcutBlock(nn.Module):
  330. """ Elementwise sum the output of a submodule to its input """
  331. def __init__(self, submodule):
  332. super(ShortcutBlock, self).__init__()
  333. self.sub = submodule
  334. def forward(self, x):
  335. output = x + self.sub(x)
  336. return output
  337. def __repr__(self):
  338. return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
  339. def sequential(*args):
  340. """ Flatten Sequential. It unwraps nn.Sequential. """
  341. if len(args) == 1:
  342. if isinstance(args[0], OrderedDict):
  343. raise NotImplementedError('sequential does not support OrderedDict input.')
  344. return args[0] # No sequential is needed.
  345. modules = []
  346. for module in args:
  347. if isinstance(module, nn.Sequential):
  348. for submodule in module.children():
  349. modules.append(submodule)
  350. elif isinstance(module, nn.Module):
  351. modules.append(module)
  352. return nn.Sequential(*modules)
  353. def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
  354. pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
  355. spectral_norm=False):
  356. """ Conv layer with padding, normalization, activation """
  357. assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
  358. padding = get_valid_padding(kernel_size, dilation)
  359. p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
  360. padding = padding if pad_type == 'zero' else 0
  361. if convtype=='PartialConv2D':
  362. from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
  363. c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
  364. dilation=dilation, bias=bias, groups=groups)
  365. elif convtype=='DeformConv2D':
  366. from torchvision.ops import DeformConv2d # not tested
  367. c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
  368. dilation=dilation, bias=bias, groups=groups)
  369. elif convtype=='Conv3D':
  370. c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
  371. dilation=dilation, bias=bias, groups=groups)
  372. else:
  373. c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
  374. dilation=dilation, bias=bias, groups=groups)
  375. if spectral_norm:
  376. c = nn.utils.spectral_norm(c)
  377. a = act(act_type) if act_type else None
  378. if 'CNA' in mode:
  379. n = norm(norm_type, out_nc) if norm_type else None
  380. return sequential(p, c, n, a)
  381. elif mode == 'NAC':
  382. if norm_type is None and act_type is not None:
  383. a = act(act_type, inplace=False)
  384. n = norm(norm_type, in_nc) if norm_type else None
  385. return sequential(n, a, p, c)