deepbooru.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import os
  2. import re
  3. import torch
  4. import numpy as np
  5. from modules import modelloader, paths, deepbooru_model, devices, images, shared
  6. re_special = re.compile(r'([\\()])')
  7. class DeepDanbooru:
  8. def __init__(self):
  9. self.model = None
  10. def load(self):
  11. if self.model is not None:
  12. return
  13. files = modelloader.load_models(
  14. model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
  15. model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
  16. ext_filter=[".pt"],
  17. download_name='model-resnet_custom_v3.pt',
  18. )
  19. self.model = deepbooru_model.DeepDanbooruModel()
  20. self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
  21. self.model.eval()
  22. self.model.to(devices.cpu, devices.dtype)
  23. def start(self):
  24. self.load()
  25. self.model.to(devices.device)
  26. def stop(self):
  27. if not shared.opts.interrogate_keep_models_in_memory:
  28. self.model.to(devices.cpu)
  29. devices.torch_gc()
  30. def tag(self, pil_image):
  31. self.start()
  32. res = self.tag_multi(pil_image)
  33. self.stop()
  34. return res
  35. def tag_multi(self, pil_image, force_disable_ranks=False):
  36. threshold = shared.opts.interrogate_deepbooru_score_threshold
  37. use_spaces = shared.opts.deepbooru_use_spaces
  38. use_escape = shared.opts.deepbooru_escape
  39. alpha_sort = shared.opts.deepbooru_sort_alpha
  40. include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
  41. pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
  42. a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
  43. with torch.no_grad(), devices.autocast():
  44. x = torch.from_numpy(a).to(devices.device)
  45. y = self.model(x)[0].detach().cpu().numpy()
  46. probability_dict = {}
  47. for tag, probability in zip(self.model.tags, y):
  48. if probability < threshold:
  49. continue
  50. if tag.startswith("rating:"):
  51. continue
  52. probability_dict[tag] = probability
  53. if alpha_sort:
  54. tags = sorted(probability_dict)
  55. else:
  56. tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
  57. res = []
  58. filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
  59. for tag in [x for x in tags if x not in filtertags]:
  60. probability = probability_dict[tag]
  61. tag_outformat = tag
  62. if use_spaces:
  63. tag_outformat = tag_outformat.replace('_', ' ')
  64. if use_escape:
  65. tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
  66. if include_ranks:
  67. tag_outformat = f"({tag_outformat}:{probability:.3f})"
  68. res.append(tag_outformat)
  69. return ", ".join(res)
  70. model = DeepDanbooru()