123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- """
- Supports saving and restoring webui and extensions from a known working set of commits
- """
- import os
- import json
- import time
- import tqdm
- from datetime import datetime
- from collections import OrderedDict
- import git
- from modules import shared, extensions, errors
- from modules.paths_internal import script_path, config_states_dir
# Saved config states keyed by display name ("<name>: <timestamp>"),
# newest first. Refilled in place by list_config_states().
all_config_states = OrderedDict()


def list_config_states():
    """Reload the saved config states found in ``config_states_dir``.

    The directory is created if it does not exist. Every ``*.json`` file in
    it is parsed, tagged with a ``"filepath"`` key holding the path it was
    read from, and stored in the module-level ``all_config_states`` mapping
    under the key ``"<name>: <timestamp>"`` (``name`` defaults to
    ``"Config"``; the timestamp is ``created_at`` rendered via
    ``time.gmtime``). Entries are inserted newest first.

    Files that cannot be read or parsed, that do not contain a JSON object,
    or that lack a numeric ``"created_at"`` are reported on stdout and
    skipped, so one bad file does not break the whole listing.

    Returns:
        The module-level ``all_config_states`` mapping.
    """
    # Cleared and refilled in place; the name is never rebound, so no
    # `global` statement is needed.
    all_config_states.clear()
    os.makedirs(config_states_dir, exist_ok=True)

    config_states = []
    for filename in os.listdir(config_states_dir):
        if not filename.endswith(".json"):
            continue

        path = os.path.join(config_states_dir, filename)
        try:
            with open(path, "r", encoding="utf-8") as f:
                state = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON syntax errors and bad encodings.
            print(f"Error: Could not load config state {path}: {e}")
            continue

        created_at = state.get("created_at") if isinstance(state, dict) else None
        # created_at is used as the sort key and fed to time.gmtime(), so it
        # must be a real number (bool is excluded although it subclasses int).
        if isinstance(created_at, bool) or not isinstance(created_at, (int, float)):
            print(f'Error: Config state {path} has no valid "created_at", skipping')
            continue

        state["filepath"] = path
        config_states.append(state)

    config_states.sort(key=lambda cs: cs["created_at"], reverse=True)

    for cs in config_states:
        timestamp = time.asctime(time.gmtime(cs["created_at"]))
        name = cs.get("name", "Config")
        all_config_states[f"{name}: {timestamp}"] = cs

    return all_config_states
def get_webui_config():
    """Collect git metadata describing the webui checkout at ``script_path``.

    A repository is opened only when ``script_path`` contains a ``.git``
    entry; a failure to open it is reported through ``errors.report``.

    Returns:
        A dict with the keys ``"remote"``, ``"commit_hash"``,
        ``"commit_date"`` and ``"branch"``. Any value that could not be
        determined is ``None``.
    """
    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)

    webui_remote = None
    webui_commit_hash = None
    webui_commit_date = None
    webui_branch = None

    if webui_repo and not webui_repo.bare:
        # Each lookup is best-effort and isolated, so that one failing lookup
        # (e.g. no branch on a detached HEAD, or no default remote) does not
        # discard the values that can still be read.
        try:
            webui_remote = next(webui_repo.remote().urls, None)
        except Exception:
            webui_remote = None

        try:
            head = webui_repo.head.commit
            webui_commit_hash = head.hexsha
            webui_commit_date = head.committed_date
        except Exception:
            webui_commit_hash = None
            webui_commit_date = None

        try:
            webui_branch = webui_repo.active_branch.name
        except Exception:
            webui_branch = None

    return {
        "remote": webui_remote,
        "commit_hash": webui_commit_hash,
        "commit_date": webui_commit_date,
        "branch": webui_branch,
    }
def get_extension_config():
    """Snapshot the state of every entry in ``extensions.extensions``.

    Returns:
        A dict mapping each extension's name to a dict with the keys
        ``name``, ``path``, ``enabled``, ``is_builtin``, ``remote``,
        ``commit_hash``, ``commit_date``, ``branch`` and
        ``have_info_from_repo``, copied from the extension object.
    """
    snapshot = {}

    for extension in extensions.extensions:
        # Called before the attributes are read; presumably this refreshes the
        # git-derived fields copied below (verify against the Extension class).
        extension.read_info_from_repo()

        snapshot[extension.name] = dict(
            name=extension.name,
            path=extension.path,
            enabled=extension.enabled,
            is_builtin=extension.is_builtin,
            remote=extension.remote,
            commit_hash=extension.commit_hash,
            commit_date=extension.commit_date,
            branch=extension.branch,
            have_info_from_repo=extension.have_info_from_repo,
        )

    return snapshot
def get_config():
    """Build a config state describing the current webui and extensions.

    Returns:
        A dict with ``"created_at"`` (the current time as a timestamp),
        ``"webui"`` (result of ``get_webui_config()``) and ``"extensions"``
        (result of ``get_extension_config()``).
    """
    # Dict values are evaluated top to bottom: timestamp first, then webui,
    # then extensions.
    return {
        "created_at": datetime.now().timestamp(),
        "webui": get_webui_config(),
        "extensions": get_extension_config(),
    }
def restore_webui_config(config):
    """Hard-reset the webui checkout to the commit recorded in *config*.

    Args:
        config: A config state dict; its ``"webui"`` entry must be a dict
            carrying a non-empty ``"commit_hash"``.

    Progress and problems are printed or passed to ``errors.report``; nothing
    is raised for a missing entry, a missing repository or a failing git
    command. Always returns ``None``.
    """
    print("* Restoring webui state...")

    if "webui" not in config:
        print("Error: No webui data saved to config")
        return

    webui_config = config["webui"]

    # get_webui_config() stores None when no commit could be read, so a key
    # that is present but empty is treated the same as a missing key.
    if not isinstance(webui_config, dict) or not webui_config.get("commit_hash"):
        print("Error: No commit saved to webui config")
        return

    webui_commit_hash = webui_config["commit_hash"]
    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
        return

    # No .git entry: there is nothing to reset. Without this guard the code
    # below would fail on `None.git` and be reported as a restore error.
    if webui_repo is None:
        print(f"Error: No git repository found at {script_path}, cannot restore webui")
        return

    try:
        webui_repo.git.fetch(all=True)
        webui_repo.git.reset(webui_commit_hash, hard=True)
        print(f"* Restored webui to commit {webui_commit_hash}.")
    except Exception:
        errors.report(f"Error restoring webui to commit {webui_commit_hash}", exc_info=True)
def restore_extension_config(config):
    """Reset non-builtin extensions to the commits recorded in *config*.

    Args:
        config: A config state dict whose ``"extensions"`` entry maps
            extension names to saved entries (``"commit_hash"``,
            ``"enabled"``, ...).

    For every non-builtin extension in ``extensions.extensions``:

    * if it has no saved entry it is marked disabled;
    * otherwise, if the entry has a commit hash, the extension is fetched and
      hard-reset to it, and its disabled flag follows the saved ``"enabled"``
      value (treated as disabled when the key is absent).

    The resulting list of disabled extension names is written to
    ``shared.opts.disabled_extensions`` and saved, then a summary of changed
    and failed extensions is printed. Returns ``None``.
    """
    print("* Restoring extension state...")

    if "extensions" not in config:
        print("Error: No extension data saved to config")
        return

    ext_config = config["extensions"]

    results = []
    disabled = []

    for ext in tqdm.tqdm(extensions.extensions):
        if ext.is_builtin:
            continue

        ext.read_info_from_repo()
        current_commit = ext.commit_hash
        # The commit hash can be empty (saved entries are checked for that
        # below); slicing None would raise TypeError and abort the restore
        # before the disabled list is saved.
        current_short = (current_commit or "<none>")[:8]

        # NOTE(review): get_extension_config() reads `ext.enabled`, while this
        # function writes `ext.disabled` - confirm which attribute the
        # Extension class actually uses.
        if ext.name not in ext_config:
            ext.disabled = True
            disabled.append(ext.name)
            results.append((ext, current_short, False, "Saved extension state not found in config, marking as disabled"))
            continue

        entry = ext_config[ext.name]
        saved_commit = entry.get("commit_hash")

        if saved_commit:
            try:
                ext.fetch_and_reset_hard(saved_commit)
                ext.read_info_from_repo()
                # Only extensions whose commit actually changed are listed.
                if current_commit != saved_commit:
                    results.append((ext, current_short, True, saved_commit[:8]))
            except Exception as ex:
                results.append((ext, current_short, False, ex))
        else:
            results.append((ext, current_short, False, "No commit hash found in config"))

        if not entry.get("enabled", False):
            ext.disabled = True
            disabled.append(ext.name)
        else:
            ext.disabled = False

    shared.opts.disabled_extensions = disabled
    shared.opts.save(shared.config_filename)

    print("* Finished restoring extensions. Results:")
    for ext, prev_commit, success, result in results:
        if success:
            print(f"  + {ext.name}: {prev_commit} -> {result}")
        else:
            print(f"  ! {ext.name}: FAILURE ({result})")
|