upscaler.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import os
  2. from abc import abstractmethod
  3. import PIL
  4. from PIL import Image
  5. import modules.shared
  6. from modules import modelloader, shared
  7. LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
  8. NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
  9. class Upscaler:
  10. name = None
  11. model_path = None
  12. model_name = None
  13. model_url = None
  14. enable = True
  15. filter = None
  16. model = None
  17. user_path = None
  18. scalers: []
  19. tile = True
  20. def __init__(self, create_dirs=False):
  21. self.mod_pad_h = None
  22. self.tile_size = modules.shared.opts.ESRGAN_tile
  23. self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
  24. self.device = modules.shared.device
  25. self.img = None
  26. self.output = None
  27. self.scale = 1
  28. self.half = not modules.shared.cmd_opts.no_half
  29. self.pre_pad = 0
  30. self.mod_scale = None
  31. self.model_download_path = None
  32. if self.model_path is None and self.name:
  33. self.model_path = os.path.join(shared.models_path, self.name)
  34. if self.model_path and create_dirs:
  35. os.makedirs(self.model_path, exist_ok=True)
  36. try:
  37. import cv2 # noqa: F401
  38. self.can_tile = True
  39. except Exception:
  40. pass
  41. @abstractmethod
  42. def do_upscale(self, img: PIL.Image, selected_model: str):
  43. return img
  44. def upscale(self, img: PIL.Image, scale, selected_model: str = None):
  45. self.scale = scale
  46. dest_w = int((img.width * scale) // 8 * 8)
  47. dest_h = int((img.height * scale) // 8 * 8)
  48. for _ in range(3):
  49. shape = (img.width, img.height)
  50. img = self.do_upscale(img, selected_model)
  51. if shape == (img.width, img.height):
  52. break
  53. if img.width >= dest_w and img.height >= dest_h:
  54. break
  55. if img.width != dest_w or img.height != dest_h:
  56. img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
  57. return img
  58. @abstractmethod
  59. def load_model(self, path: str):
  60. pass
  61. def find_models(self, ext_filter=None) -> list:
  62. return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
  63. def update_status(self, prompt):
  64. print(f"\nextras: {prompt}", file=shared.progress_print_out)
  65. class UpscalerData:
  66. name = None
  67. data_path = None
  68. scale: int = 4
  69. scaler: Upscaler = None
  70. model: None
  71. def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
  72. self.name = name
  73. self.data_path = path
  74. self.local_data_path = path
  75. self.scaler = upscaler
  76. self.scale = scale
  77. self.model = model
  78. class UpscalerNone(Upscaler):
  79. name = "None"
  80. scalers = []
  81. def load_model(self, path):
  82. pass
  83. def do_upscale(self, img, selected_model=None):
  84. return img
  85. def __init__(self, dirname=None):
  86. super().__init__(False)
  87. self.scalers = [UpscalerData("None", None, self)]
  88. class UpscalerLanczos(Upscaler):
  89. scalers = []
  90. def do_upscale(self, img, selected_model=None):
  91. return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
  92. def load_model(self, _):
  93. pass
  94. def __init__(self, dirname=None):
  95. super().__init__(False)
  96. self.name = "Lanczos"
  97. self.scalers = [UpscalerData("Lanczos", None, self)]
  98. class UpscalerNearest(Upscaler):
  99. scalers = []
  100. def do_upscale(self, img, selected_model=None):
  101. return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
  102. def load_model(self, _):
  103. pass
  104. def __init__(self, dirname=None):
  105. super().__init__(False)
  106. self.name = "Nearest"
  107. self.scalers = [UpscalerData("Nearest", None, self)]