interrogate.py

import os
import sys
from collections import namedtuple
from pathlib import Path
import re

import torch
import torch.hub

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from modules import devices, paths, shared, lowvram, modelloader, errors

blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")
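
# Category files are plain text lists, one candidate phrase per line. A file whose stem
# contains a ".top<N>." marker (matched by re_topn) contributes its best N matches to the
# result instead of only the single best one; all other files default to top 1.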


def category_types():
    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]


def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")

    tmpdir = f"{content_dir}_tmp"
    category_types = ["artists", "flavors", "mediums", "movements"]

    try:
        os.makedirs(tmpdir, exist_ok=True)
        for category_type in category_types:
            torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
        os.rename(tmpdir, content_dir)
    except Exception as e:
        errors.display(e, "downloading default CLIP interrogate categories")
    finally:
        if os.path.exists(tmpdir):
            os.removedirs(tmpdir)
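
# The default category lists are fetched from the pharmapsychotic/clip-interrogator
# repository on first use; downloading into a temporary directory and renaming it at the
# end keeps a half-finished download from being picked up as the real category directory.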


class InterrogateModels:
    blip_model = None
    clip_model = None
    clip_preprocess = None
    dtype = None
    running_on_cpu = None

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.skip_categories = []
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

    def categories(self):
        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
            return self.loaded_categories

        self.loaded_categories = []

        if os.path.exists(self.content_dir):
            self.skip_categories = shared.opts.interrogate_clip_skip_categories
            category_types = []
            for filename in Path(self.content_dir).glob('*.txt'):
                category_types.append(filename.stem)
                if filename.stem in self.skip_categories:
                    continue
                m = re_topn.search(filename.stem)
                topn = 1 if m is None else int(m.group(1))
                with open(filename, "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))

        return self.loaded_categories

    def create_fake_fairscale(self):
        class FakeFairscale:
            def checkpoint_wrapper(self):
                pass

        sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
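
    # Registering the stub above under fairscale's module path lets BLIP's model code import
    # checkpoint_wrapper without the real fairscale package installed; that wrapper appears to
    # matter only for training, not for the captioning done here.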

    def load_blip_model(self):
        self.create_fake_fairscale()
        import models.blip

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "BLIP"),
            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
            ext_filter=[".pth"],
            download_name='model_base_caption_capfilt_large.pth',
        )

        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        import clip

        if self.running_on_cpu:
            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
        else:
            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)

        model.eval()
        model = model.to(devices.device_interrogate)

        return model, preprocess

    def load(self):
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(devices.device_interrogate)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(devices.device_interrogate)

        self.dtype = next(self.clip_model.parameters()).dtype
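
    # Both models run in half precision unless --no-half is given or interrogation runs on
    # the CPU. self.dtype mirrors the CLIP parameters' dtype so that the image tensors built
    # in generate_caption() and interrogate() are created with matching precision.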

    def send_clip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()

    def rank(self, image_features, text_array, top_count=1):
        import clip

        devices.torch_gc()

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
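
    # rank() encodes each candidate line with CLIP's text encoder, L2-normalises the features,
    # and scores them against the already-normalised image features, so the matrix product is
    # cosine similarity; softmax over 100 * cosine gives the probability-like scores that are
    # returned on a 0-100 scale. A runnable sketch of just this math is at the end of the file.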

    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]
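
    # The Normalize() mean/std above are the same statistics CLIP's image preprocessing uses
    # (BLIP shares them), and the resize matches blip_image_eval_size (384px). Beam count and
    # caption length limits come from the interrogate options in the UI settings.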

    def interrogate(self, pil_image):
        # Builds a prompt for the image: a BLIP caption followed by the best CLIP matches
        # from every category file (artists, flavors, mediums, movements by default).
        res = ""
        shared.state.begin(job="interrogate")
        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            # caption with BLIP first, then move it back to RAM before the CLIP ranking pass
            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)

                # append the top-ranked entries from each category list to the caption
                for cat in self.categories():
                    matches = self.rank(image_features, cat.items, top_count=cat.topn)
                    for match, score in matches:
                        if shared.opts.interrogate_return_ranks:
                            res += f", ({match}:{score/100:.3f})"
                        else:
                            res += f", {match}"
        except Exception:
            errors.report("Error interrogating", exc_info=True)
            res += "<error>"

        self.unload()
        shared.state.end()

        return res
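

# --- Illustrative sketches; nothing below is called by the web UI itself. ---

# A minimal, self-contained sketch of the ranking math in rank() above: with L2-normalised
# features, image_features @ text_features.T is cosine similarity, and softmax over
# 100 * cosine produces the probability-like scores that get reported. Feature sizes here
# are only for illustration (CLIP ViT-L/14 embeddings happen to be 768-dimensional).
def _rank_math_sketch(top_count=2):
    image_features = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
    text_features = torch.nn.functional.normalize(torch.randn(5, 768), dim=-1)

    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = similarity.topk(top_count, dim=-1)
    return top_probs, top_labels


# Hedged usage sketch: the web UI reaches this module through shared.interrogator, an
# InterrogateModels instance created elsewhere (category_types() above relies on it).
# Something along these lines, with "example.png" standing in for a real image:
#
#     from PIL import Image
#     prompt = shared.interrogator.interrogate(Image.open("example.png").convert("RGB"))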