123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- import inspect
- import os
- from collections import namedtuple
- from typing import Optional, Dict, Any
- from fastapi import FastAPI
- from gradio import Blocks
- from modules import errors, timer
def report_exception(c, job):
    """Report an exception raised while running a script callback.

    c is the ScriptCallback entry being executed (its ``script`` field is
    included in the message); job is the name of the hook that was running.
    The message is forwarded to ``errors.report`` with ``exc_info=True``.
    """
    message = f"Error executing callback {job} for {c.script}"
    errors.report(message, exc_info=True)
class ImageSaveParams:
    """Parameters handed to the before-image-saved / image-saved callbacks.

    Attributes:
        image: the PIL image itself.
        p: p object with processing parameters; either
            StableDiffusionProcessing or an object with the same fields.
        filename: name of the file that the image would be saved to.
        pnginfo: dictionary with parameters for the image's PNG info data;
            infotext will have the key 'parameters'.
    """

    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        self.p = p
        self.filename = filename
        self.pnginfo = pnginfo
class CFGDenoiserParams:
    """Parameters handed to the cfg_denoiser callbacks.

    Attributes:
        x: latent image representation in the process of being denoised.
        image_cond: conditioning image.
        sigma: current sigma noise step value.
        sampling_step: current sampling step number.
        total_sampling_steps: total number of sampling steps planned.
        text_cond: encoder hidden states of text conditioning from the prompt.
        text_uncond: encoder hidden states of text conditioning from the
            negative prompt.
    """

    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
        # Latent and conditioning inputs.
        self.x = x
        self.image_cond = image_cond
        self.sigma = sigma

        # Progress of the sampling loop.
        self.sampling_step = sampling_step
        self.total_sampling_steps = total_sampling_steps

        # Text conditioning (prompt / negative prompt).
        self.text_cond = text_cond
        self.text_uncond = text_uncond
class CFGDenoisedParams:
    """Parameters handed to the cfg_denoised callbacks.

    Attributes:
        x: latent image representation in the process of being denoised.
        sampling_step: current sampling step number.
        total_sampling_steps: total number of sampling steps planned.
        inner_model: inner model reference used for denoising.
    """

    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
        self.x = x
        self.sampling_step = sampling_step
        self.total_sampling_steps = total_sampling_steps
        self.inner_model = inner_model
class AfterCFGCallbackParams:
    """Parameters handed to the after-cfg callbacks.

    Attributes:
        x: latent image representation in the process of being denoised.
        sampling_step: current sampling step number.
        total_sampling_steps: total number of sampling steps planned.
    """

    def __init__(self, x, sampling_step, total_sampling_steps):
        self.x = x
        self.sampling_step = sampling_step
        self.total_sampling_steps = total_sampling_steps
class UiTrainTabParams:
    """Parameters handed to the ui_train_tabs callbacks.

    Attributes:
        txt2img_preview_params: stored as given.
            NOTE(review): presumably the txt2img components used for
            training previews; confirm against the caller.
    """

    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params
class ImageGridLoopParams:
    """Parameters handed to the image_grid callbacks.

    Attributes:
        imgs: the images for the grid, stored as given.
        cols: column count, stored as given.
        rows: row count, stored as given.
    """

    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows
# One registered callback: `script` is the filename of the file that
# registered it, `callback` is the callable to invoke.
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])

# Registry of callback lists keyed by hook name. Every list holds
# ScriptCallback entries in registration order.
callback_map = {
    'callbacks_app_started': [],
    'callbacks_model_loaded': [],
    'callbacks_ui_tabs': [],
    'callbacks_ui_train_tabs': [],
    'callbacks_ui_settings': [],
    'callbacks_before_image_saved': [],
    'callbacks_image_saved': [],
    'callbacks_cfg_denoiser': [],
    'callbacks_cfg_denoised': [],
    'callbacks_cfg_after_cfg': [],
    'callbacks_before_component': [],
    'callbacks_after_component': [],
    'callbacks_image_grid': [],
    'callbacks_infotext_pasted': [],
    'callbacks_script_unloaded': [],
    'callbacks_before_ui': [],
    'callbacks_on_reload': [],
    'callbacks_list_optimizers': [],
    'callbacks_list_unets': [],
}
def clear_callbacks():
    """Empty every callback list in ``callback_map`` in place."""
    for registered in callback_map.values():
        # Clear rather than rebind so the same list objects stay in the map.
        registered.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    """Run every registered 'app started' callback with ``demo`` and ``app``.

    After each callback returns, a startup-timer entry named after the
    basename of the registering script is recorded. An exception raised by
    the callback or by the timer call is reported, and the loop moves on to
    the next callback.
    """
    for entry in callback_map['callbacks_app_started']:
        try:
            entry.callback(demo, app)
            script_name = os.path.basename(entry.script)
            timer.startup_timer.record(script_name)
        except Exception:
            report_exception(entry, 'app_started_callback')
def app_reload_callback():
    """Run every registered 'on reload' callback with no arguments.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_on_reload']
    for entry in registered:
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'callbacks_on_reload')
def model_loaded_callback(sd_model):
    """Run every registered 'model loaded' callback, passing ``sd_model``.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_model_loaded']
    for entry in registered:
        try:
            entry.callback(sd_model)
        except Exception:
            report_exception(entry, 'model_loaded_callback')
def ui_tabs_callback():
    """Collect the tabs contributed by every registered 'ui tabs' callback.

    Each callback is called with no arguments; a falsy return value (such as
    None) contributes nothing. A failing callback is reported and skipped.

    Returns:
        A list with the items returned by all callbacks, in registration order.
    """
    tabs = []
    for entry in callback_map['callbacks_ui_tabs']:
        try:
            contributed = entry.callback() or []
            tabs += contributed
        except Exception:
            report_exception(entry, 'ui_tabs_callback')

    return tabs
def ui_train_tabs_callback(params: UiTrainTabParams):
    """Run every registered 'ui train tabs' callback, passing ``params``.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_ui_train_tabs']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'callbacks_ui_train_tabs')
def ui_settings_callback():
    """Run every registered 'ui settings' callback with no arguments.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_ui_settings']
    for entry in registered:
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'ui_settings_callback')
def before_image_saved_callback(params: ImageSaveParams):
    """Run every registered 'before image saved' callback, passing ``params``.

    All callbacks receive the same ``params`` object. A failing callback is
    reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_before_image_saved']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'before_image_saved_callback')
def image_saved_callback(params: ImageSaveParams):
    """Run every registered 'image saved' callback, passing ``params``.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_image_saved']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'image_saved_callback')
def cfg_denoiser_callback(params: CFGDenoiserParams):
    """Run every registered 'cfg denoiser' callback, passing ``params``.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_cfg_denoiser']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'cfg_denoiser_callback')
def cfg_denoised_callback(params: CFGDenoisedParams):
    """Run every registered 'cfg denoised' callback, passing ``params``.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_cfg_denoised']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'cfg_denoised_callback')
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
    """Run every registered 'after cfg' callback, passing ``params``.

    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_cfg_after_cfg']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'cfg_after_cfg_callback')
def before_component_callback(component, **kwargs):
    """Run every registered 'before component' callback.

    Each callback receives ``component`` plus the keyword arguments given
    here. A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_before_component']
    for entry in registered:
        try:
            entry.callback(component, **kwargs)
        except Exception:
            report_exception(entry, 'before_component_callback')
def after_component_callback(component, **kwargs):
    """Run every registered 'after component' callback.

    Each callback receives ``component`` plus the keyword arguments given
    here. A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_after_component']
    for entry in registered:
        try:
            entry.callback(component, **kwargs)
        except Exception:
            report_exception(entry, 'after_component_callback')
def image_grid_callback(params: ImageGridLoopParams):
    """Run every registered 'image grid' callback, passing ``params``.

    All callbacks receive the same ``params`` object. A failing callback is
    reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_image_grid']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'image_grid')
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
    """Run every registered 'infotext pasted' callback.

    Each callback receives the ``infotext`` string and the ``params`` dict.
    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_infotext_pasted']
    for entry in registered:
        try:
            entry.callback(infotext, params)
        except Exception:
            report_exception(entry, 'infotext_pasted')
def script_unloaded_callback():
    """Run every registered 'script unloaded' callback with no arguments.

    Callbacks run in reverse registration order (last registered first).
    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_script_unloaded']
    for entry in reversed(registered):
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'script_unloaded')
def before_ui_callback():
    """Run every registered 'before ui' callback with no arguments.

    Callbacks run in reverse registration order (last registered first).
    A failing callback is reported and does not stop the remaining ones.
    """
    registered = callback_map['callbacks_before_ui']
    for entry in reversed(registered):
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'before_ui')
def list_optimizers_callback():
    """Build the optimizer list by letting each registered callback fill it.

    Every 'list optimizers' callback is called with the same list and is
    expected to add its entries to it. A failing callback is reported and
    skipped.

    Returns:
        The list after all callbacks have run.
    """
    optimizers = []
    for entry in callback_map['callbacks_list_optimizers']:
        try:
            entry.callback(optimizers)
        except Exception:
            report_exception(entry, 'list_optimizers')

    return optimizers
def list_unets_callback():
    """Build the unet option list by letting each registered callback fill it.

    Every 'list unets' callback is called with the same list and is expected
    to add its entries to it. A failing callback is reported and skipped.

    Returns:
        The list after all callbacks have run.
    """
    unets = []
    for entry in callback_map['callbacks_list_unets']:
        try:
            entry.callback(unets)
        except Exception:
            report_exception(entry, 'list_unets')

    return unets
def add_callback(callbacks, fun):
    """Append ``fun`` to ``callbacks``, tagged with the file that registered it.

    The registering file is the filename of the nearest frame on the call
    stack that does not belong to this module; if every frame belongs to this
    module, the placeholder 'unknown file' is used.

    Args:
        callbacks: the list to append the new ScriptCallback entry to.
        fun: the callable to register.
    """
    # context=0: only filenames are needed, so skip loading source lines for
    # every frame on the stack (the default context=1 reads them from disk).
    caller_file = next(
        (frame.filename for frame in inspect.stack(0) if frame.filename != __file__),
        'unknown file',
    )

    callbacks.append(ScriptCallback(caller_file, fun))
def remove_current_script_callbacks():
    """Remove every callback that was registered by the calling script.

    The calling script is the filename of the nearest frame on the call stack
    that does not belong to this module. If no such frame exists, nothing is
    removed. Lists in ``callback_map`` are filtered in place, so the list
    objects keep their identity and the order of the remaining entries is
    unchanged.
    """
    # context=0: only filenames are needed, so skip loading source lines for
    # every frame on the stack.
    caller_file = next(
        (frame.filename for frame in inspect.stack(0) if frame.filename != __file__),
        'unknown file',
    )
    if caller_file == 'unknown file':
        return

    for callback_list in callback_map.values():
        # Single-pass in-place filter instead of repeated list.remove() calls,
        # each of which rescans the list.
        callback_list[:] = [cb for cb in callback_list if not (cb.script == caller_file)]
def remove_callbacks_for_function(callback_func):
    """Remove every registration of ``callback_func`` from all callback lists.

    Entries are matched by ``cb.callback == callback_func``. Lists in
    ``callback_map`` are filtered in place, so the list objects keep their
    identity and the order of the remaining entries is unchanged.

    Args:
        callback_func: the callable whose registrations should be dropped.
    """
    for callback_list in callback_map.values():
        # Single-pass in-place filter instead of repeated list.remove() calls,
        # each of which rescans the list.
        callback_list[:] = [cb for cb in callback_list if not (cb.callback == callback_func)]
def on_app_started(callback):
    """Register a function to be called once the webui has started.

    The callback receives two arguments: the gradio `Blocks` object (which
    may be None) and the fastapi `FastAPI` application.
    """
    registry = callback_map['callbacks_app_started']
    add_callback(registry, callback)
def on_before_reload(callback):
    """Register a function to be called just before the server reloads.

    The callback is called with no arguments.
    """
    registry = callback_map['callbacks_on_reload']
    add_callback(registry, callback)
def on_model_loaded(callback):
    """Register a function to be called when the stable diffusion model is created.

    The model is passed as the only argument. The function is also called
    when the script is reloaded.
    """
    registry = callback_map['callbacks_model_loaded']
    add_callback(registry, callback)
def on_ui_tabs(callback):
    """Register a function to be called when the UI is creating new tabs.

    The function takes no arguments and must return either None (no tabs to
    add) or a list of tuples of the form

        (gradio_component, title, elem_id)

    where:
      - gradio_component is the gradio component used for the contents of
        the tab (usually gr.Blocks);
      - title is the tab text displayed to the user in the UI;
      - elem_id is the HTML id for the tab.
    """
    registry = callback_map['callbacks_ui_tabs']
    add_callback(registry, callback)
def on_ui_train_tabs(callback):
    """Register a function to be called when the UI is creating tabs inside the train tab.

    The callback receives one argument, a UiTrainTabParams object.
    Create your new tabs with gr.Tab.
    """
    registry = callback_map['callbacks_ui_train_tabs']
    add_callback(registry, callback)
def on_ui_settings(callback):
    """Register a function to be called before UI settings are populated.

    The callback is called with no arguments; add your settings from it by
    using shared.opts.add_option(shared.OptionInfo(...)).
    """
    registry = callback_map['callbacks_ui_settings']
    add_callback(registry, callback)
def on_before_image_saved(callback):
    """Register a function to be called before an image is saved to a file.

    The callback is called with one argument:
      - params: ImageSaveParams - parameters the image is to be saved with.
        Fields of this object may be changed by the callback.
    """
    registry = callback_map['callbacks_before_image_saved']
    add_callback(registry, callback)
def on_image_saved(callback):
    """Register a function to be called after an image is saved to a file.

    The callback is called with one argument:
      - params: ImageSaveParams - parameters the image was saved with.
        Changing fields in this object does nothing.
    """
    registry = callback_map['callbacks_image_saved']
    add_callback(registry, callback)
def on_cfg_denoiser(callback):
    """Register a function to be called in the k-diffusion cfg_denoiser method
    after building the inner model inputs.

    The callback is called with one argument:
      - params: CFGDenoiserParams - parameters to be passed to the inner
        model and sampling state details.
    """
    registry = callback_map['callbacks_cfg_denoiser']
    add_callback(registry, callback)
def on_cfg_denoised(callback):
    """Register a function to be called in the k-diffusion cfg_denoiser method.

    The callback is called with one argument:
      - params: CFGDenoisedParams - the latent being denoised, the sampling
        step counters and a reference to the inner model.

    NOTE(review): the exact point inside cfg_denoiser at which this hook
    fires is not visible from this module; confirm against the caller.
    """
    registry = callback_map['callbacks_cfg_denoised']
    add_callback(registry, callback)
def on_cfg_after_cfg(callback):
    """Register a function to be called in the k-diffusion cfg_denoiser method
    after cfg calculations are completed.

    The callback is called with one argument:
      - params: AfterCFGCallbackParams - parameters to be passed to the
        script for post-processing after cfg calculation.
    """
    registry = callback_map['callbacks_cfg_after_cfg']
    add_callback(registry, callback)
def on_before_component(callback):
    """Register a function to be called before a component is created.

    The callback is called with arguments:
      - component - the gradio component that is about to be created.
      - **kwargs - args to the gradio.components.IOComponent.__init__ function.

    Use the elem_id/label fields of kwargs to figure out which component it
    is. This can be useful to inject your own components somewhere in the
    middle of the vanilla UI.
    """
    registry = callback_map['callbacks_before_component']
    add_callback(registry, callback)
def on_after_component(callback):
    """Register a function to be called after a component is created.

    The callback receives the component followed by keyword arguments, the
    same call shape as for on_before_component.
    """
    registry = callback_map['callbacks_after_component']
    add_callback(registry, callback)
def on_image_grid(callback):
    """Register a function to be called before making an image grid.

    The callback is called with one argument:
      - params: ImageGridLoopParams - parameters to be used for grid
        creation. Can be modified.
    """
    registry = callback_map['callbacks_image_grid']
    add_callback(registry, callback)
def on_infotext_pasted(callback):
    """Register a function to be called before applying an infotext.

    The callback is called with two arguments:
      - infotext: str - raw infotext.
      - result: Dict[str, Any] - parsed infotext parameters.
    """
    registry = callback_map['callbacks_infotext_pasted']
    add_callback(registry, callback)
def on_script_unloaded(callback):
    """Register a function to be called before the script is unloaded.

    Any hooks/hijacks/monkeying about that the script did should be reverted
    in the callback. It is called with no arguments.
    """
    registry = callback_map['callbacks_script_unloaded']
    add_callback(registry, callback)
def on_before_ui(callback):
    """Register a function to be called before the UI is created.

    The callback is called with no arguments.
    """
    registry = callback_map['callbacks_before_ui']
    add_callback(registry, callback)
def on_list_optimizers(callback):
    """Register a function to be called when the UI is making a list of cross
    attention optimization options.

    The function is called with one argument, a list, and shall add objects
    of type modules.sd_hijack_optimizations.SdOptimization to it.
    """
    registry = callback_map['callbacks_list_optimizers']
    add_callback(registry, callback)
def on_list_unets(callback):
    """Register a function to be called when the UI is making a list of
    alternative options for unet.

    The function is called with one argument, a list, and shall add objects
    of type modules.sd_unet.SdUnetOption to it.
    """
    registry = callback_map['callbacks_list_unets']
    add_callback(registry, callback)
|