import collections
import glob
import os
from copy import deepcopy

from modules import paths, shared, devices, script_callbacks, sd_models


vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
vae_dict = {}


base_vae = None
loaded_vae_file = None
checkpoint_info = None

checkpoints_loaded = collections.OrderedDict()
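
# Illustrative sketch (comments only, never executed) of what this module's state
# roughly looks like once refresh_vae_list() has run and an external VAE has been
# applied to the current checkpoint; the concrete file paths are made up:
#
#   vae_dict           -> {"anything.vae.pt": "/sd/models/VAE/anything.vae.pt", ...}
#   loaded_vae_file    -> "/sd/models/VAE/anything.vae.pt"
#   base_vae           -> deep copy of the checkpoint's built-in first_stage_model weights
#   checkpoint_info    -> sd_checkpoint_info of the model whose base VAE is stored above
#   checkpoints_loaded -> OrderedDict of cached VAE state dicts, oldest entry first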


def get_base_vae(model):
    # return the stored copy of the checkpoint's own VAE, but only if it belongs to this model
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
    return None


def store_base_vae(model):
    # remember the checkpoint's built-in VAE weights so they can be restored later
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        assert not loaded_vae_file, "Trying to store non-base VAE!"
        base_vae = deepcopy(model.first_stage_model.state_dict())
        checkpoint_info = model.sd_checkpoint_info


def delete_base_vae():
    global base_vae, checkpoint_info
    base_vae = None
    checkpoint_info = None


def restore_base_vae(model):
    # put the checkpoint's built-in VAE weights back and forget the external VAE
    global loaded_vae_file
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        print("Restoring base VAE")
        _load_vae_dict(model, base_vae)
        loaded_vae_file = None
    delete_base_vae()
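
# Minimal sketch (comments only) of how the helpers above fit together inside
# load_vae()/reload_vae_weights(); `model` stands for any loaded checkpoint with a
# first_stage_model attribute and `some_vae` for an already-read VAE state dict,
# both assumed rather than defined here:
#
#   store_base_vae(model)            # snapshot the built-in VAE before overwriting it
#   _load_vae_dict(model, some_vae)  # apply the external VAE weights
#   ...
#   restore_base_vae(model)          # later: undo and return to the built-in VAE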


def get_filename(filepath):
    return os.path.basename(filepath)


def refresh_vae_list():
    # rebuild the filename -> filepath mapping of every VAE the UI can offer
    vae_dict.clear()

    # use a local name other than `paths` so the imported modules.paths is not shadowed
    search_paths = [
        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
        os.path.join(sd_models.model_path, '**/*.vae.pt'),
        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
        os.path.join(vae_path, '**/*.ckpt'),
        os.path.join(vae_path, '**/*.pt'),
        os.path.join(vae_path, '**/*.safetensors'),
    ]

    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
        search_paths += [
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
        ]

    if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
        search_paths += [
            os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
        ]

    candidates = []
    for path in search_paths:
        candidates += glob.iglob(path, recursive=True)

    for filepath in candidates:
        name = get_filename(filepath)
        vae_dict[name] = filepath
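
# Example result, as a comment sketch only; the directory layout and file names
# are hypothetical:
#
#   models/VAE/vae-ft-mse.safetensors    -> vae_dict["vae-ft-mse.safetensors"]
#   models/Stable-diffusion/foo.vae.pt   -> vae_dict["foo.vae.pt"]
#
# Keys are bare filenames, so two VAEs with the same name in different folders
# collide and whichever is found last wins.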


def find_vae_near_checkpoint(checkpoint_file):
    # a VAE "near" a checkpoint is any known VAE whose filename starts with the checkpoint's stem
    checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
    for vae_file in vae_dict.values():
        if os.path.basename(vae_file).startswith(checkpoint_path):
            return vae_file

    return None
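
# For example (hypothetical names): with the checkpoint
# "models/Stable-diffusion/modelA.safetensors", a listed "modelA.vae.pt" matches,
# since "modelA.vae.pt".startswith("modelA") is True.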


def resolve_vae(checkpoint_file):
    # the command line argument overrides everything else
    if shared.cmd_opts.vae_path is not None:
        return shared.cmd_opts.vae_path, 'from commandline argument'

    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config

    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
        return vae_near_checkpoint, 'found near the checkpoint'

    if shared.opts.sd_vae == "None":
        return None, None

    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
    if vae_from_options is not None:
        return vae_from_options, 'specified in settings'

    if not is_automatic:
        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")

    return None, None
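
# Resolution priority, summarized from the code above (no new behavior implied):
#   1. the --vae-path command line argument, if given
#   2. a VAE found near the checkpoint, when sd_vae_as_default is enabled or the
#      sd_vae setting is "Automatic"/"auto"
#   3. no VAE, if the sd_vae setting is explicitly "None"
#   4. the VAE selected by name in settings, if it is present in vae_dict
#   5. otherwise no external VAE at all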


def load_vae_dict(filename, map_location):
    # read a VAE checkpoint and drop loss weights and EMA bookkeeping keys
    vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
    vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
    return vae_dict_1


def load_vae(model, vae_file=None, vae_source="from unknown source"):
    global vae_dict, loaded_vae_file

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
        else:
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # clean up cache if limit is reached
        if cache_enabled:
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:  # we need to count the current model
                checkpoints_loaded.popitem(last=False)  # LRU

        # if the VAE used is not in vae_dict, add it so it shows up in the UI;
        # it will be removed on the next refresh, though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file

    elif loaded_vae_file:
        # no VAE requested but an external one is active: restore the checkpoint's own VAE
        restore_base_vae(model)

    loaded_vae_file = vae_file
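
# Cache behavior, illustrated assuming sd_vae_checkpoint_cache == 1: loading
# vae_a.pt and then vae_b.pt keeps both cached (the "+ 1" accounts for the VAE
# currently applied to the model); loading vae_c.pt then evicts vae_a.pt, the
# oldest entry.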


# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.to(devices.dtype_vae)


def clear_loaded_vae():
    global loaded_vae_file
    loaded_vae_file = None


# sentinel that lets reload_vae_weights() tell "no argument given" apart from an explicit None
unspecified = object()


def reload_vae_weights(sd_model=None, vae_file=unspecified):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename

    if vae_file == unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file)
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        # the requested VAE is already loaded, nothing to do
        return

    # move the model off the GPU while its VAE weights are swapped
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("VAE weights loaded.")
    return sd_model
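
# Minimal usage sketch, as comments only. It assumes the webui is running with a
# model already loaded (shared.sd_model is set) and that the VAE filename, which
# is made up here, exists in one of the scanned folders:
#
#   refresh_vae_list()                           # rescan VAE folders into vae_dict
#   vae_file = vae_dict.get("anything.vae.pt")   # pick a discovered VAE by filename
#   reload_vae_weights(vae_file=vae_file)        # swap it into the current model
#   reload_vae_weights(vae_file=None)            # or restore the checkpoint's own VAE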