extra_networks.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import re
  2. from collections import defaultdict
  3. from modules import errors
  4. extra_network_registry = {}
  5. extra_network_aliases = {}
  6. def initialize():
  7. extra_network_registry.clear()
  8. extra_network_aliases.clear()
  9. def register_extra_network(extra_network):
  10. extra_network_registry[extra_network.name] = extra_network
  11. def register_extra_network_alias(extra_network, alias):
  12. extra_network_aliases[alias] = extra_network
  13. def register_default_extra_networks():
  14. from modules.extra_networks_hypernet import ExtraNetworkHypernet
  15. register_extra_network(ExtraNetworkHypernet())
  16. class ExtraNetworkParams:
  17. def __init__(self, items=None):
  18. self.items = items or []
  19. self.positional = []
  20. self.named = {}
  21. for item in self.items:
  22. parts = item.split('=', 2) if isinstance(item, str) else [item]
  23. if len(parts) == 2:
  24. self.named[parts[0]] = parts[1]
  25. else:
  26. self.positional.append(item)
  27. def __eq__(self, other):
  28. return self.items == other.items
  29. class ExtraNetwork:
  30. def __init__(self, name):
  31. self.name = name
  32. def activate(self, p, params_list):
  33. """
  34. Called by processing on every run. Whatever the extra network is meant to do should be activated here.
  35. Passes arguments related to this extra network in params_list.
  36. User passes arguments by specifying this in his prompt:
  37. <name:arg1:arg2:arg3>
  38. Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
  39. separated by colon.
  40. Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
  41. in this case, all effects of this extra networks should be disabled.
  42. Can be called multiple times before deactivate() - each new call should override the previous call completely.
  43. For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
  44. > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
  45. params_list will be:
  46. [
  47. ExtraNetworkParams(items=["agm", "1.1"]),
  48. ExtraNetworkParams(items=["ray"])
  49. ]
  50. """
  51. raise NotImplementedError
  52. def deactivate(self, p):
  53. """
  54. Called at the end of processing for housekeeping. No need to do anything here.
  55. """
  56. raise NotImplementedError
  57. def activate(p, extra_network_data):
  58. """call activate for extra networks in extra_network_data in specified order, then call
  59. activate for all remaining registered networks with an empty argument list"""
  60. activated = []
  61. for extra_network_name, extra_network_args in extra_network_data.items():
  62. extra_network = extra_network_registry.get(extra_network_name, None)
  63. if extra_network is None:
  64. extra_network = extra_network_aliases.get(extra_network_name, None)
  65. if extra_network is None:
  66. print(f"Skipping unknown extra network: {extra_network_name}")
  67. continue
  68. try:
  69. extra_network.activate(p, extra_network_args)
  70. activated.append(extra_network)
  71. except Exception as e:
  72. errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
  73. for extra_network_name, extra_network in extra_network_registry.items():
  74. if extra_network in activated:
  75. continue
  76. try:
  77. extra_network.activate(p, [])
  78. except Exception as e:
  79. errors.display(e, f"activating extra network {extra_network_name}")
  80. if p.scripts is not None:
  81. p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
  82. def deactivate(p, extra_network_data):
  83. """call deactivate for extra networks in extra_network_data in specified order, then call
  84. deactivate for all remaining registered networks"""
  85. for extra_network_name in extra_network_data:
  86. extra_network = extra_network_registry.get(extra_network_name, None)
  87. if extra_network is None:
  88. continue
  89. try:
  90. extra_network.deactivate(p)
  91. except Exception as e:
  92. errors.display(e, f"deactivating extra network {extra_network_name}")
  93. for extra_network_name, extra_network in extra_network_registry.items():
  94. args = extra_network_data.get(extra_network_name, None)
  95. if args is not None:
  96. continue
  97. try:
  98. extra_network.deactivate(p)
  99. except Exception as e:
  100. errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
  101. re_extra_net = re.compile(r"<(\w+):([^>]+)>")
  102. def parse_prompt(prompt):
  103. res = defaultdict(list)
  104. def found(m):
  105. name = m.group(1)
  106. args = m.group(2)
  107. res[name].append(ExtraNetworkParams(items=args.split(":")))
  108. return ""
  109. prompt = re.sub(re_extra_net, found, prompt)
  110. return prompt, res
  111. def parse_prompts(prompts):
  112. res = []
  113. extra_data = None
  114. for prompt in prompts:
  115. updated_prompt, parsed_extra_data = parse_prompt(prompt)
  116. if extra_data is None:
  117. extra_data = parsed_extra_data
  118. res.append(updated_prompt)
  119. return res, extra_data