# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py

import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional

from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils.registry import ARCH_REGISTRY


def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have color and illumination statistics
    similar to those of the degraded features.

    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
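

# Illustrative sketch, not part of the original CodeFormer file: shows the tensor
# shapes adaptive_instance_normalization expects. Both inputs are 4D (B, C, H, W);
# the output keeps the content feature's values (normalized per channel) but takes
# on the style feature's per-channel mean and std. The function is defined but
# never called, so importing this module is unaffected.
def _adain_shape_example():
    content = torch.randn(2, 256, 16, 16)  # e.g. quantized (reference) features
    style = torch.randn(2, 256, 16, 16)    # e.g. degraded (low-quality) features
    out = adaptive_instance_normalization(content, style)
    assert out.shape == content.shape  # (2, 256, 16, 16)
    return out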


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
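

# Illustrative sketch, not part of the original CodeFormer file: demonstrates the
# output shape of PositionEmbeddingSine. For an input feature map of shape
# (B, C, H, W), the returned sinusoidal encoding has shape
# (B, 2 * num_pos_feats, H, W). Defined but never called.
def _position_embedding_example():
    pos_layer = PositionEmbeddingSine(num_pos_feats=32, normalize=True)
    feat = torch.randn(1, 64, 16, 16)
    pos = pos_layer(feat)
    assert pos.shape == (1, 64, 16, 16)  # 2 * 32 channels
    return pos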


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt
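

# Illustrative sketch, not part of the original CodeFormer file: one pass through a
# TransformerSALayer. Inputs follow nn.MultiheadAttention's default layout of
# (sequence, batch, embed_dim); in CodeFormer the sequence is the 16x16 = 256 latent
# tokens and query_pos is a learned positional embedding (zeros are used here just
# as a stand-in). Defined but never called.
def _transformer_layer_example():
    layer = TransformerSALayer(embed_dim=512, nhead=8, dim_mlp=1024)
    tokens = torch.randn(256, 1, 512)     # (HW, B, C)
    query_pos = torch.zeros(256, 1, 512)  # placeholder positional embedding
    out = layer(tokens, query_pos=query_pos)
    assert out.shape == tokens.shape
    return out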


class Fuse_sft_block(nn.Module):
    """Fuse encoder features into decoder features with a spatial feature transform (scale and shift)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out
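

# Illustrative sketch, not part of the original CodeFormer file: fusing an encoder
# feature into a decoder feature with Fuse_sft_block. Both inputs must share the
# same channel count and spatial size; with w=0 the residual vanishes and dec_feat
# is returned unchanged. Defined but never called.
def _fuse_sft_example():
    fuse = Fuse_sft_block(in_ch=256, out_ch=256)
    enc_feat = torch.randn(1, 256, 32, 32)
    dec_feat = torch.randn(1, 256, 32, 32)
    out = fuse(enc_feat, dec_feat, w=0.5)
    assert out.shape == dec_feat.shape
    return out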


@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=('32', '64', '128', '256'),
                 fix_modules=('quantize', 'generator')):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only:  # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list:  # fuse after i-th block
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
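

# Illustrative sketch, not part of the original CodeFormer file: a minimal forward
# pass through CodeFormer, run with randomly initialized weights purely to show the
# expected tensor shapes; real use loads pretrained weights (handled by
# modules/codeformer_model.py in this project). The input is a 512x512 RGB face
# crop normalized to [-1, 1]. The fidelity weight w balances quality (w=0, pure
# codebook reconstruction) against fidelity to the degraded input (w=1). Guarded
# so importing this module stays side-effect free.
if __name__ == "__main__":
    with torch.no_grad():
        net = CodeFormer(dim_embd=512, n_head=8, n_layers=9, codebook_size=1024).eval()
        face = torch.randn(1, 3, 512, 512)  # placeholder for a preprocessed face crop
        restored, logits, lq_feat = net(face, w=0.5, adain=True)
        print(restored.shape, logits.shape, lq_feat.shape)
        # expected: (1, 3, 512, 512), (1, 256, 1024), (1, 256, 16, 16)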