123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- import re
- from collections import defaultdict
- from modules import errors
- extra_network_registry = {}
- extra_network_aliases = {}
- def initialize():
- extra_network_registry.clear()
- extra_network_aliases.clear()
- def register_extra_network(extra_network):
- extra_network_registry[extra_network.name] = extra_network
- def register_extra_network_alias(extra_network, alias):
- extra_network_aliases[alias] = extra_network
- def register_default_extra_networks():
- from modules.extra_networks_hypernet import ExtraNetworkHypernet
- register_extra_network(ExtraNetworkHypernet())
- class ExtraNetworkParams:
- def __init__(self, items=None):
- self.items = items or []
- self.positional = []
- self.named = {}
- for item in self.items:
- parts = item.split('=', 2) if isinstance(item, str) else [item]
- if len(parts) == 2:
- self.named[parts[0]] = parts[1]
- else:
- self.positional.append(item)
- def __eq__(self, other):
- return self.items == other.items
- class ExtraNetwork:
- def __init__(self, name):
- self.name = name
- def activate(self, p, params_list):
- """
- Called by processing on every run. Whatever the extra network is meant to do should be activated here.
- Passes arguments related to this extra network in params_list.
- User passes arguments by specifying this in his prompt:
- <name:arg1:arg2:arg3>
- Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
- separated by colon.
- Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
- in this case, all effects of this extra networks should be disabled.
- Can be called multiple times before deactivate() - each new call should override the previous call completely.
- For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
- > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
- params_list will be:
- [
- ExtraNetworkParams(items=["agm", "1.1"]),
- ExtraNetworkParams(items=["ray"])
- ]
- """
- raise NotImplementedError
- def deactivate(self, p):
- """
- Called at the end of processing for housekeeping. No need to do anything here.
- """
- raise NotImplementedError
- def activate(p, extra_network_data):
- """call activate for extra networks in extra_network_data in specified order, then call
- activate for all remaining registered networks with an empty argument list"""
- activated = []
- for extra_network_name, extra_network_args in extra_network_data.items():
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- extra_network = extra_network_aliases.get(extra_network_name, None)
- if extra_network is None:
- print(f"Skipping unknown extra network: {extra_network_name}")
- continue
- try:
- extra_network.activate(p, extra_network_args)
- activated.append(extra_network)
- except Exception as e:
- errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
- for extra_network_name, extra_network in extra_network_registry.items():
- if extra_network in activated:
- continue
- try:
- extra_network.activate(p, [])
- except Exception as e:
- errors.display(e, f"activating extra network {extra_network_name}")
- if p.scripts is not None:
- p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
- def deactivate(p, extra_network_data):
- """call deactivate for extra networks in extra_network_data in specified order, then call
- deactivate for all remaining registered networks"""
- for extra_network_name in extra_network_data:
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- continue
- try:
- extra_network.deactivate(p)
- except Exception as e:
- errors.display(e, f"deactivating extra network {extra_network_name}")
- for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
- continue
- try:
- extra_network.deactivate(p)
- except Exception as e:
- errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
- re_extra_net = re.compile(r"<(\w+):([^>]+)>")
- def parse_prompt(prompt):
- res = defaultdict(list)
- def found(m):
- name = m.group(1)
- args = m.group(2)
- res[name].append(ExtraNetworkParams(items=args.split(":")))
- return ""
- prompt = re.sub(re_extra_net, found, prompt)
- return prompt, res
- def parse_prompts(prompts):
- res = []
- extra_data = None
- for prompt in prompts:
- updated_prompt, parsed_extra_data = parse_prompt(prompt)
- if extra_data is None:
- extra_data = parsed_extra_data
- res.append(updated_prompt)
- return res, extra_data
|