esrgan_model.py

import sys

import numpy as np
import torch
from PIL import Image

import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData


def mod2normal(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' in state_dict:
        crt_net = {}
        items = list(state_dict)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']

        for k in items.copy():
            if 'RDB' in k:
                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)

        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
        crt_net['model.3.weight'] = state_dict['upconv1.weight']
        crt_net['model.3.bias'] = state_dict['upconv1.bias']
        crt_net['model.6.weight'] = state_dict['upconv2.weight']
        crt_net['model.6.bias'] = state_dict['upconv2.bias']
        crt_net['model.8.weight'] = state_dict['HRconv.weight']
        crt_net['model.8.bias'] = state_dict['HRconv.bias']
        crt_net['model.10.weight'] = state_dict['conv_last.weight']
        crt_net['model.10.bias'] = state_dict['conv_last.bias']
        state_dict = crt_net
    return state_dict
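
# Illustrative example of the remapping above (derived from the replace rules,
# not an exhaustive spec): an old-arch key such as
# 'RRDB_trunk.5.RDB2.conv3.weight' becomes 'model.1.sub.5.RDB2.conv3.0.weight'.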


def resrgan2normal(state_dict, nb=23):
    # this code is copied from https://github.com/victorca25/iNNfer
    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
        re8x = 0
        crt_net = {}
        items = list(state_dict)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']

        for k in items.copy():
            if "rdb" in k:
                ori_k = k.replace('body.', 'model.1.sub.')
                ori_k = ori_k.replace('.rdb', '.RDB')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)

        crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
        crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']

        crt_net['model.3.weight'] = state_dict['conv_up1.weight']
        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
        crt_net['model.6.bias'] = state_dict['conv_up2.bias']

        if 'conv_up3.weight' in state_dict:
            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
            re8x = 3
            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
            crt_net['model.9.bias'] = state_dict['conv_up3.bias']

        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']

        state_dict = crt_net
    return state_dict


def infer_params(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    plus = False

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if (part_num > scalemin
                    and parts[0] == "model"
                    and parts[2] == "weight"):
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2 ** scale2x

    return in_nc, out_nc, nf, nb, plus, scale
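
# For orientation (illustrative values, not asserted by the code above): a
# stock 4x ESRGAN/RRDBNet checkpoint with 23 blocks and 64 features is expected
# to infer as in_nc=3, out_nc=3, nf=64, nb=23, plus=False, scale=4.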


class UpscalerESRGAN(Upscaler):
    def __init__(self, dirname):
        self.name = "ESRGAN"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        if len(model_paths) == 0:
            # no local models found: register the downloadable default so it is still offered
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
            self.scalers.append(scaler_data)
        for file in model_paths:
            if file.startswith("http"):
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)

            scaler_data = UpscalerData(name, file, self, 4)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
        return img
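
    # Notes on load_model below: `path` may be a local checkpoint or an
    # "http..." URL (in which case the default model is downloaded). Checkpoints
    # wrapped in "params" are treated as compact SRVGG (Real-ESRGAN) models;
    # "params_ema", "body.*" (Real-ESRGAN), "conv_first.*" (old ESRGAN) and
    # plain "model.*" layouts are normalized by the helpers above and loaded
    # into an RRDBNet.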

    def load_model(self, path: str):
        if path.startswith("http"):
            # TODO: this doesn't use `path` at all?
            filename = modelloader.load_file_from_url(
                url=self.model_url,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name}.pth",
            )
        else:
            filename = path

        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

        if "params_ema" in state_dict:
            state_dict = state_dict["params_ema"]
        elif "params" in state_dict:
            state_dict = state_dict["params"]
            num_conv = 16 if "realesr-animevideov3" in filename else 32
            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
            model.load_state_dict(state_dict)
            model.eval()
            return model

        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
            state_dict = resrgan2normal(state_dict, nb)
        elif "conv_first.weight" in state_dict:
            state_dict = mod2normal(state_dict)
        elif "model.0.weight" not in state_dict:
            raise Exception("The file is not a recognized ESRGAN model.")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
        model.load_state_dict(state_dict)
        model.eval()

        return model


def upscale_without_tiling(model, img):
    img = np.array(img)
    img = img[:, :, ::-1]  # RGB -> BGR
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255  # HWC -> CHW, scale to [0, 1]
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_esrgan)

    with torch.no_grad():
        output = model(img)

    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)  # CHW -> HWC, back to [0, 255]
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]  # BGR -> RGB
    return Image.fromarray(output, 'RGB')


def esrgan_upscale(model, img):
    if opts.ESRGAN_tile == 0:
        return upscale_without_tiling(model, img)

    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    newtiles = []
    scale_factor = 1

    for y, h, row in grid.tiles:
        newrow = []
        for tiledata in row:
            x, w, tile = tiledata

            output = upscale_without_tiling(model, tile)
            scale_factor = output.width // tile.width

            newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = images.combine_grid(newgrid)
    return output
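

if __name__ == "__main__":
    # Minimal usage sketch, assuming the surrounding web UI environment
    # (modules.shared, devices, etc.) has already been initialized; the model
    # directory, checkpoint path, and image filenames below are placeholders,
    # not paths shipped with this module.
    upscaler = UpscalerESRGAN("/path/to/models/ESRGAN")
    source = Image.open("input.png").convert("RGB")
    result = upscaler.do_upscale(source, "/path/to/models/ESRGAN/ESRGAN_4x.pth")
    result.save("output.png")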