123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346 |
- import torch
- from torch.nn.functional import silu
- from types import MethodType
- import modules.textual_inversion.textual_inversion
- from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
- from modules.hypernetworks import hypernetwork
- from modules.shared import cmd_opts
- from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
- import ldm.modules.attention
- import ldm.modules.diffusionmodules.model
- import ldm.modules.diffusionmodules.openaimodel
- import ldm.models.diffusion.ddim
- import ldm.models.diffusion.plms
- import ldm.modules.encoders.modules
- import sgm.modules.attention
- import sgm.modules.diffusionmodules.model
- import sgm.modules.diffusionmodules.openaimodel
- import sgm.modules.encoders.modules
# Stock ldm implementations captured at import time, before anything below patches them.
# undo_optimizations() writes the saved nonlinearity and AttnBlock.forward back.
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
# The stubs accept keyword arguments as well, so a silenced call such as
# print(..., end='') is swallowed instead of raising TypeError.
ldm.modules.attention.print = lambda *args, **kwargs: None
ldm.modules.diffusionmodules.model.print = lambda *args, **kwargs: None

# Available optimizations (filled by list_optimizers()) and the one currently applied, if any.
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
def list_optimizers():
    """Refresh the module-level ``optimizers`` list in place.

    Collects the candidates reported by ``script_callbacks.list_optimizers_callback()``,
    drops those whose ``is_available()`` is false, and orders the rest by
    descending ``priority``.
    """
    candidates = script_callbacks.list_optimizers_callback()
    usable = [candidate for candidate in candidates if candidate.is_available()]
    usable.sort(key=lambda candidate: candidate.priority, reverse=True)

    # Replace the contents rather than rebinding, so the list object itself is kept.
    optimizers[:] = usable
def apply_optimizations(option=None):
    """Select and apply a cross attention optimization.

    ``option`` is the optimization title to use; when falsy, the value of
    ``shared.opts.cross_attention_optimization`` is used instead.

    Returns the name of the optimizer that was applied, or ``''`` when no
    optimizers are registered yet or optimization ends up disabled.
    """
    global current_optimizer

    # Always start from the unpatched state.
    undo_optimizations()

    if not optimizers:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    for diffusion_pkg in (ldm.modules.diffusionmodules, sgm.modules.diffusionmodules):
        diffusion_pkg.model.nonlinearity = silu
        diffusion_pkg.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    automatic = selection == "Automatic"

    if automatic:
        # Prefer an optimizer whose command line flag is set; otherwise take the
        # first entry in the list (non-empty thanks to the guard above).
        chosen = next(
            (opt for opt in optimizers if opt.cmd_opt and getattr(shared.cmd_opts, opt.cmd_opt, False)),
            optimizers[0],
        )
    else:
        chosen = next((opt for opt in optimizers if opt.title() == selection), None)

    if selection == "None" or (automatic and shared.cmd_opts.disable_opt_split_attention):
        chosen = None
    elif chosen is None:
        # The requested title matched nothing: fall back to the first optimizer.
        chosen = optimizers[0]

    if chosen is None:
        print("Disabling attention optimization")
        return ''

    print(f"Applying attention optimization: {chosen.name}... ", end='')
    chosen.apply()
    print("done.")
    current_optimizer = chosen
    return current_optimizer.name
def undo_optimizations():
    """Restore the nonlinearity and attention forwards on both ldm and sgm.

    CrossAttention.forward is reset to ``hypernetwork.attention_CrossAttention_forward``;
    the nonlinearity and AttnBlock.forward are reset to the values saved at
    import time.
    """
    # NOTE(review): the sgm modules are reset to the originals captured from ldm,
    # not to sgm's own - confirm that this is intended.
    for package in (ldm, sgm):
        package.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
        package.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        package.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def fix_checkpoint():
    """Do nothing; this function is a no-op.

    Checkpoints are now added and removed in the embedding/hypernet code, since
    torch doesn't want checkpoints to be added when not training (there's a
    warning).
    """
    return None
def weighted_loss(sd_model, pred, target, mean=True):
    """Compute the model's loss, scaled by ``sd_model._custom_loss_weight`` when set.

    The unreduced loss comes from ``sd_model._old_get_loss``; the weight, if
    present, is multiplied in before any reduction. Returns the mean of the
    result when ``mean`` is true, otherwise the unreduced result.
    """
    # Ask for the unreduced loss so the weights can be applied before averaging.
    unreduced = sd_model._old_get_loss(pred, target, mean=False)

    # Weights are optional; without them the loss is left untouched.
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        unreduced *= weight

    if mean:
        return unreduced.mean()
    return unreduced
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run ``sd_model.forward(x, c, *args, **kwargs)`` with a weight-aware loss.

    While the forward call runs, ``w`` is stored on the model as
    ``_custom_loss_weight`` and ``get_loss`` is replaced by ``weighted_loss``.
    Both changes are rolled back afterwards, whether or not the call raised.
    Returns whatever ``sd_model.forward`` returns.
    """
    try:
        # Park the weights where the patched loss can reach them.
        sd_model._custom_loss_weight = w

        # Patch 'get_loss' instead of reimplementing 'forward'. An original that
        # is already saved is kept, so it is never overwritten with the patch.
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # The standard forward now runs with the patched 'get_loss'.
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        # Drop the temporary weights if they were attached.
        if hasattr(sd_model, '_custom_loss_weight'):
            del sd_model._custom_loss_weight

        # Put the original loss function back if one was saved.
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss
def apply_weighted_forward(sd_model):
    """Attach ``weighted_forward`` to ``sd_model`` as a bound method.

    After this call ``sd_model.weighted_forward(x, c, w, ...)`` can be used to
    compute a weighted loss.
    """
    bound_forward = MethodType(weighted_forward, sd_model)
    sd_model.weighted_forward = bound_forward
def undo_weighted_forward(sd_model):
    """Remove the ``weighted_forward`` attribute from ``sd_model`` if it is present.

    A model without the attribute is left as it is; no error is raised.
    """
    try:
        delattr(sd_model, 'weighted_forward')
    except AttributeError:
        # Nothing was attached, so there is nothing to undo.
        return
class StableDiffusionModelHijack:
    """Installs and removes the web UI's patches on a loaded Stable Diffusion model.

    ``hijack`` wraps the model's text encoder(s) so custom embeddings can be
    injected, attaches the weighted forward, applies the attention
    optimization and overrides the UNet forward; ``undo_hijack`` reverses it.
    """

    # Per-batch embedding fixes; read and reset by EmbeddingsWithFixes.forward.
    fixes = None
    # Flat list of the hijacked model and all of its submodules (set by hijack).
    layers = None
    # Last padding state requested through apply_circular.
    circular_enabled = False
    # The model's (wrapped) cond_stage_model while a model is hijacked.
    clip = None
    # Value returned by the most recent successful apply_optimizations() call.
    optimization_method = None
    # Created once when the class body executes, so it lives on the class.
    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        self.comments = []
        self.extra_generation_params = {}
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        """Apply the attention optimization and remember the returned name.

        Any exception is reported through ``errors.display`` and the
        optimizations are rolled back instead of propagating the error.
        """
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as exc:
            errors.display(exc, "applying cross attention optimization")
            undo_optimizations()

    def hijack(self, m):
        """Patch model ``m`` in place (text encoders, loss, attention, UNet forward)."""
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            # Model exposes a conditioner with several embedders; wrap each
            # recognised one and put the wrapper back in its slot.
            text_cond_models = []
            embedders = conditioner.embedders

            for index in range(len(embedders)):
                embedder = embedders[index]
                typename = type(embedder).__name__

                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    replacement = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                elif typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    replacement = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                elif typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    replacement = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                else:
                    # Unrecognised embedder types are left untouched.
                    continue

                embedders[index] = replacement
                text_cond_models.append(replacement)

            # A single wrapped text model stands in directly; otherwise the
            # whole conditioner acts as the cond stage model.
            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        cond_type = type(m.cond_stage_model)
        if cond_type is xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            # NOTE(review): the wrapper is stored as `token_embedding` although it
            # wraps `word_embeddings` - confirm the model actually reads it.
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
        elif cond_type is ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        elif cond_type is ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            # The module itself first, then each child's subtree in order.
            collected = [el]
            for child in el.children():
                collected += flatten(child)
            return collected

        self.layers = flatten(m)

        # Keep the stock UNet forward only the first time, so repeated hijacks
        # never store the replacement as the "original".
        openaimodel = ldm.modules.diffusionmodules.openaimodel
        if not hasattr(openaimodel, 'copy_of_UNetModel_forward_for_webui'):
            openaimodel.copy_of_UNetModel_forward_for_webui = openaimodel.UNetModel.forward

        openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    def undo_hijack(self, m):
        """Unwrap the text encoder of ``m`` and remove the other patches."""
        wrapper = m.cond_stage_model
        wrapper_type = type(wrapper)

        if wrapper_type is sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = wrapper.wrapped
        elif wrapper_type is sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = wrapper.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) is EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif wrapper_type is sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            inner = wrapper.wrapped
            inner.model.token_embedding = inner.model.token_embedding.wrapped
            m.cond_stage_model = inner

        undo_optimizations()
        undo_weighted_forward(m)

        # Padding must be reset while the layer list is still available.
        self.apply_circular(False)
        self.layers = None
        self.clip = None

        openaimodel = ldm.modules.diffusionmodules.openaimodel
        openaimodel.UNetModel.forward = openaimodel.copy_of_UNetModel_forward_for_webui

    def apply_circular(self, enable):
        """Switch every exact ``torch.nn.Conv2d`` in ``self.layers`` between circular and zero padding.

        Does nothing when the requested state equals the recorded one.
        """
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable
        padding_mode = 'circular' if enable else 'zeros'

        for layer in self.layers:
            if type(layer) is torch.nn.Conv2d:
                layer.padding_mode = padding_mode

    def clear_comments(self):
        """Reset the collected comments and extra generation parameters."""
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        """Return ``(token_count, target_token_count)`` for ``text``.

        Both values are ``"-"`` while no text encoder is hijacked.
        """
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])
        target_count = self.clip.get_target_prompt_token_count(token_count)
        return token_count, target_count

    def redo_hijack(self, m):
        """Undo the hijack on ``m`` and immediately apply it again."""
        self.undo_hijack(m)
        self.hijack(m)
class EmbeddingsWithFixes(torch.nn.Module):
    """Embedding layer wrapper that splices custom embedding vectors into its output.

    ``wrapped`` is the original embedding layer. ``embeddings`` is the object
    whose ``fixes`` attribute holds, per batch item, ``(offset, embedding)``
    pairs to insert; that attribute is consumed (set to ``None``) on every call.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        # Take the pending fixes and clear them so they apply to this call only.
        pending = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        # Nothing to splice: no fixes at all, or every per-item list is empty.
        if pending is None or not any(len(item_fixes) for item_fixes in pending):
            return inputs_embeds

        patched = []
        for item_fixes, tensor in zip(pending, inputs_embeds):
            for offset, embedding in item_fixes:
                emb = devices.cond_cast_unet(embedding.vec)
                start = offset + 1
                # Clip the inserted vectors so the sequence length stays unchanged.
                emb_len = min(tensor.shape[0] - start, emb.shape[0])
                pieces = [tensor[0:start], emb[0:emb_len], tensor[start + emb_len:]]
                tensor = torch.cat(pieces)

            patched.append(tensor)

        return torch.stack(patched)
def add_circular_option_to_conv_2d():
    """Patch ``torch.nn.Conv2d.__init__`` so layers built afterwards use circular padding.

    The current constructor is wrapped; the wrapper forces
    ``padding_mode='circular'`` as a keyword argument and forwards everything
    else unchanged. A ``padding_mode`` keyword supplied by the caller is
    overridden, which also keeps the patch working when this function is
    called more than once.
    """
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        # Override the key instead of passing it next to **kwargs: the latter
        # raises "multiple values for keyword argument" when the caller (or an
        # earlier copy of this wrapper) already supplies padding_mode.
        kwargs['padding_mode'] = 'circular'
        return conv2d_constructor(self, *args, **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
# Module-level StableDiffusionModelHijack instance, created when this module is imported.
model_hijack = StableDiffusionModelHijack()
def register_buffer(self, name, attr):
    """Store ``attr`` on ``self`` under ``name``, moving plain tensors to ``devices.device``.

    Fix register buffer bug for Mac OS: a value whose type is exactly
    ``torch.Tensor`` and which lives on another device is moved to
    ``devices.device`` first (and cast to float32 when that device is 'mps').
    Any other value is stored unchanged.
    """
    needs_move = type(attr) == torch.Tensor and attr.device != devices.device
    if needs_move:
        target_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=target_dtype)

    setattr(self, name, attr)
# Replace register_buffer on both the DDIM and PLMS samplers with the function defined above.
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|