# scripts_postprocessing.py (4.4 KB)

  1. import os
  2. import gradio as gr
  3. from modules import errors, shared
  4. class PostprocessedImage:
  5. def __init__(self, image):
  6. self.image = image
  7. self.info = {}
  8. class ScriptPostprocessing:
  9. filename = None
  10. controls = None
  11. args_from = None
  12. args_to = None
  13. order = 1000
  14. """scripts will be ordred by this value in postprocessing UI"""
  15. name = None
  16. """this function should return the title of the script."""
  17. group = None
  18. """A gr.Group component that has all script's UI inside it"""
  19. def ui(self):
  20. """
  21. This function should create gradio UI elements. See https://gradio.app/docs/#components
  22. The return value should be a dictionary that maps parameter names to components used in processing.
  23. Values of those components will be passed to process() function.
  24. """
  25. pass
  26. def process(self, pp: PostprocessedImage, **args):
  27. """
  28. This function is called to postprocess the image.
  29. args contains a dictionary with all values returned by components from ui()
  30. """
  31. pass
  32. def image_changed(self):
  33. pass
  34. def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
  35. try:
  36. res = func(*args, **kwargs)
  37. return res
  38. except Exception as e:
  39. errors.display(e, f"calling {filename}/{funcname}")
  40. return default
  41. class ScriptPostprocessingRunner:
  42. def __init__(self):
  43. self.scripts = None
  44. self.ui_created = False
  45. def initialize_scripts(self, scripts_data):
  46. self.scripts = []
  47. for script_data in scripts_data:
  48. script: ScriptPostprocessing = script_data.script_class()
  49. script.filename = script_data.path
  50. if script.name == "Simple Upscale":
  51. continue
  52. self.scripts.append(script)
  53. def create_script_ui(self, script, inputs):
  54. script.args_from = len(inputs)
  55. script.args_to = len(inputs)
  56. script.controls = wrap_call(script.ui, script.filename, "ui")
  57. for control in script.controls.values():
  58. control.custom_script_source = os.path.basename(script.filename)
  59. inputs += list(script.controls.values())
  60. script.args_to = len(inputs)
  61. def scripts_in_preferred_order(self):
  62. if self.scripts is None:
  63. import modules.scripts
  64. self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
  65. scripts_order = shared.opts.postprocessing_operation_order
  66. def script_score(name):
  67. for i, possible_match in enumerate(scripts_order):
  68. if possible_match == name:
  69. return i
  70. return len(self.scripts)
  71. script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
  72. return sorted(self.scripts, key=lambda x: script_scores[x.name])
  73. def setup_ui(self):
  74. inputs = []
  75. for script in self.scripts_in_preferred_order():
  76. with gr.Row() as group:
  77. self.create_script_ui(script, inputs)
  78. script.group = group
  79. self.ui_created = True
  80. return inputs
  81. def run(self, pp: PostprocessedImage, args):
  82. for script in self.scripts_in_preferred_order():
  83. shared.state.job = script.name
  84. script_args = args[script.args_from:script.args_to]
  85. process_args = {}
  86. for (name, _component), value in zip(script.controls.items(), script_args):
  87. process_args[name] = value
  88. script.process(pp, **process_args)
  89. def create_args_for_run(self, scripts_args):
  90. if not self.ui_created:
  91. with gr.Blocks(analytics_enabled=False):
  92. self.setup_ui()
  93. scripts = self.scripts_in_preferred_order()
  94. args = [None] * max([x.args_to for x in scripts])
  95. for script in scripts:
  96. script_args_dict = scripts_args.get(script.name, None)
  97. if script_args_dict is not None:
  98. for i, name in enumerate(script.controls):
  99. args[script.args_from + i] = script_args_dict.get(name, None)
  100. return args
  101. def image_changed(self):
  102. for script in self.scripts_in_preferred_order():
  103. script.image_changed()