sd_hijack.py

import torch
from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
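
# collect the cross-attention optimization implementations registered by scripts and
# extensions via the list_optimizers callback; only those reporting themselves as
# available on this system are kept, sorted by priority (highest first)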
def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)
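
# select and apply a cross-attention optimization: an explicit option wins, otherwise the
# "Automatic" setting prefers an optimizer whose command-line flag was passed, falling back
# to the highest-priority one; returns the applied optimizer's name, or '' if none was applied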
def apply_optimizations(option=None):
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''
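
# revert the attention and nonlinearity hijacks, restoring the hypernetwork-aware
# CrossAttention forward and the original functions saved at import time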
def undo_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""
    pass
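
# the functions below let training code compute a per-sample weighted loss:
# weighted_forward() stashes the weight tensor on the model, temporarily swaps
# get_loss for weighted_loss, and restores everything afterwards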
def weighted_loss(sd_model, pred, target, mean=True):
    #Calculate the loss normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    #Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    #Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        #Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        #Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            #Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        #If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    #Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
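
# central object that swaps out parts of a loaded Stable Diffusion model: text encoder
# wrappers for prompt emphasis and textual inversion embeddings, the UNet forward, and
# the selected cross-attention optimization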
class StableDiffusionModelHijack:
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()
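
    # wrap the model's text encoder(s) so prompts support emphasis and textual inversion
    # embeddings, patch the UNet forward for sd_unet, and apply the chosen attention
    # optimization; SDXL-style models expose several embedders via m.conditioner, while
    # SD1/SD2-style models use a single cond_stage_model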
    def hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__
                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

        if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
            ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])
        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        self.undo_hijack(m)
        self.hijack(m)
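
# wraps a token embedding layer; after the normal embedding lookup, vectors from textual
# inversion embeddings recorded in self.embeddings.fixes are spliced into the output at
# their token offsets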
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = devices.cond_cast_unet(embedding.vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
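
# monkeypatch torch.nn.Conv2d so every newly created convolution uses circular padding,
# which makes generated images tile seamlessly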
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular

model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer