# modelloader.py
  1. from __future__ import annotations
  2. import os
  3. import shutil
  4. import importlib
  5. from urllib.parse import urlparse
  6. from modules import shared
  7. from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
  8. from modules.paths import script_path, models_path
  9. def load_file_from_url(
  10. url: str,
  11. *,
  12. model_dir: str,
  13. progress: bool = True,
  14. file_name: str | None = None,
  15. ) -> str:
  16. """Download a file from `url` into `model_dir`, using the file present if possible.
  17. Returns the path to the downloaded file.
  18. """
  19. os.makedirs(model_dir, exist_ok=True)
  20. if not file_name:
  21. parts = urlparse(url)
  22. file_name = os.path.basename(parts.path)
  23. cached_file = os.path.abspath(os.path.join(model_dir, file_name))
  24. if not os.path.exists(cached_file):
  25. print(f'Downloading: "{url}" to {cached_file}\n')
  26. from torch.hub import download_url_to_file
  27. download_url_to_file(url, cached_file, progress=progress)
  28. return cached_file
  29. def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
  30. """
  31. A one-and done loader to try finding the desired models in specified directories.
  32. @param download_name: Specify to download from model_url immediately.
  33. @param model_url: If no other models are found, this will be downloaded on upscale.
  34. @param model_path: The location to store/find models in.
  35. @param command_path: A command-line argument to search for models in first.
  36. @param ext_filter: An optional list of filename extensions to filter by
  37. @return: A list of paths containing the desired model(s)
  38. """
  39. output = []
  40. try:
  41. places = []
  42. if command_path is not None and command_path != model_path:
  43. pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
  44. if os.path.exists(pretrained_path):
  45. print(f"Appending path: {pretrained_path}")
  46. places.append(pretrained_path)
  47. elif os.path.exists(command_path):
  48. places.append(command_path)
  49. places.append(model_path)
  50. for place in places:
  51. for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
  52. if os.path.islink(full_path) and not os.path.exists(full_path):
  53. print(f"Skipping broken symlink: {full_path}")
  54. continue
  55. if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
  56. continue
  57. if full_path not in output:
  58. output.append(full_path)
  59. if model_url is not None and len(output) == 0:
  60. if download_name is not None:
  61. output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
  62. else:
  63. output.append(model_url)
  64. except Exception:
  65. pass
  66. return output
  67. def friendly_name(file: str):
  68. if file.startswith("http"):
  69. file = urlparse(file).path
  70. file = os.path.basename(file)
  71. model_name, extension = os.path.splitext(file)
  72. return model_name
  73. def cleanup_models():
  74. # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
  75. # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
  76. # somehow auto-register and just do these things...
  77. root_path = script_path
  78. src_path = models_path
  79. dest_path = os.path.join(models_path, "Stable-diffusion")
  80. move_files(src_path, dest_path, ".ckpt")
  81. move_files(src_path, dest_path, ".safetensors")
  82. src_path = os.path.join(root_path, "ESRGAN")
  83. dest_path = os.path.join(models_path, "ESRGAN")
  84. move_files(src_path, dest_path)
  85. src_path = os.path.join(models_path, "BSRGAN")
  86. dest_path = os.path.join(models_path, "ESRGAN")
  87. move_files(src_path, dest_path, ".pth")
  88. src_path = os.path.join(root_path, "gfpgan")
  89. dest_path = os.path.join(models_path, "GFPGAN")
  90. move_files(src_path, dest_path)
  91. src_path = os.path.join(root_path, "SwinIR")
  92. dest_path = os.path.join(models_path, "SwinIR")
  93. move_files(src_path, dest_path)
  94. src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
  95. dest_path = os.path.join(models_path, "LDSR")
  96. move_files(src_path, dest_path)
  97. def move_files(src_path: str, dest_path: str, ext_filter: str = None):
  98. try:
  99. os.makedirs(dest_path, exist_ok=True)
  100. if os.path.exists(src_path):
  101. for file in os.listdir(src_path):
  102. fullpath = os.path.join(src_path, file)
  103. if os.path.isfile(fullpath):
  104. if ext_filter is not None:
  105. if ext_filter not in file:
  106. continue
  107. print(f"Moving {file} from {src_path} to {dest_path}.")
  108. try:
  109. shutil.move(fullpath, dest_path)
  110. except Exception:
  111. pass
  112. if len(os.listdir(src_path)) == 0:
  113. print(f"Removing empty folder: {src_path}")
  114. shutil.rmtree(src_path, True)
  115. except Exception:
  116. pass
def load_upscalers():
    """Discover, instantiate, and register all Upscaler implementations.

    Imports every modules/*_model.py file so their Upscaler subclasses get
    registered, instantiates one scaler per unique subclass, and stores the
    combined, sorted scaler list in shared.sd_upscalers.
    """
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
    modules_dir = os.path.join(shared.script_path, "modules")
    for file in os.listdir(modules_dir):
        if "_model.py" in file:
            model_name = file.replace("_model.py", "")
            full_model = f"modules.{model_name}_model"
            try:
                importlib.import_module(full_model)
            except Exception:
                # A single broken upscaler module should not prevent the rest
                # from loading; it simply won't appear in the list.
                pass
    datas = []
    commandline_options = vars(shared.cmd_opts)

    # some of upscaler classes will not go away after reloading their modules, and we'll end
    # up with two copies of those classes. The newest copy will always be the last in the list,
    # so we go from end to beginning and ignore duplicates
    used_classes = {}
    for cls in reversed(Upscaler.__subclasses__()):
        classname = str(cls)
        if classname not in used_classes:
            used_classes[classname] = cls

    # Iterate in the classes' original registration order again.
    for cls in reversed(used_classes.values()):
        name = cls.__name__
        # Derive the CLI option name, e.g. UpscalerESRGAN -> "esrgan_models_path".
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        commandline_model_path = commandline_options.get(cmd_name, None)
        scaler = cls(commandline_model_path)
        scaler.user_path = commandline_model_path
        scaler.model_download_path = commandline_model_path or scaler.model_path
        datas += scaler.scalers

    shared.sd_upscalers = sorted(
        datas,
        # Special case for UpscalerNone keeps it at the beginning of the list.
        key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
    )