
Progress query

rambo · 1 year ago · commit 909151359a
4 changed files with 651 additions and 317 deletions
  1. api/api.py (+198 -102)
  2. api/models.py (+161 -64)
  3. processing.py (+255 -130)
  4. shared.py (+37 -21)

api/api.py (+198 -102)

@@ -22,7 +22,7 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
+from PIL import PngImagePlugin, Image
 from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
 from modules.sd_vae import vae_dict
 from modules.sd_models_config import find_checkpoint_config_near_filename
@@ -38,7 +38,8 @@ def script_name_to_index(name, scripts):
     try:
         return [script.title().lower() for script in scripts].index(name.lower())
     except Exception as e:
-        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
+        raise HTTPException(
+            status_code=422, detail=f"Script '{name}' not found") from e
 
 
 def validate_sampler_name(name):
@@ -63,7 +64,8 @@ def decode_base64_to_image(encoding):
         image = Image.open(BytesIO(base64.b64decode(encoding)))
         return image
     except Exception as e:
-        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+        raise HTTPException(
+            status_code=500, detail="Invalid encoded image") from e
 
 
 def encode_pil_to_base64(image):
@@ -76,19 +78,22 @@ def encode_pil_to_base64(image):
                 if isinstance(key, str) and isinstance(value, str):
                     metadata.add_text(key, value)
                     use_metadata = True
-            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
+            image.save(output_bytes, format="PNG", pnginfo=(
+                metadata if use_metadata else None), quality=opts.jpeg_quality)
 
         elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
             if image.mode == "RGBA":
                 image = image.convert("RGB")
             parameters = image.info.get('parameters', None)
             exif_bytes = piexif.dump({
-                "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
+                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
             })
             if opts.samples_format.lower() in ("jpg", "jpeg"):
-                image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
+                image.save(output_bytes, format="JPEG",
+                           exif=exif_bytes, quality=opts.jpeg_quality)
             else:
-                image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
+                image.save(output_bytes, format="WEBP",
+                           exif=exif_bytes, quality=opts.jpeg_quality)
 
         else:
             raise HTTPException(status_code=500, detail="Invalid image format")
@@ -137,11 +142,13 @@ def api_middleware(app: FastAPI):
             "body": vars(e).get('body', ''),
             "errors": str(e),
         }
-        if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
+        # do not print backtrace on known httpexceptions
+        if not isinstance(e, HTTPException):
             message = f"API error: {request.method}: {request.url} {err}"
             if rich_available:
                 print(message)
-                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
+                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[
+                                        anyio, starlette], word_wrap=False, width=min([console.width, 200]))
             else:
                 errors.report(message, exc_info=True)
         return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@@ -174,44 +181,79 @@ class Api:
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
-        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
-        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
-        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
-        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
-        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
-        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
-        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi,
+                           methods=["POST"], response_model=models.TextToImageResponse)
+        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi,
+                           methods=["POST"], response_model=models.ImageToImageResponse)
+        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api,
+                           methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api,
+                           methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi,
+                           methods=["POST"], response_model=models.PNGInfoResponse)
+        self.add_api_route("/sdapi/v1/progress", self.progressapi,
+                           methods=["GET", "POST"], response_model=models.ProgressResponse)
+        self.add_api_route("/sdapi/v1/interrogate",
+                           self.interrogateapi, methods=["POST"])
+        self.add_api_route("/sdapi/v1/interrupt",
+                           self.interruptapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
-        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
-        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
-        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
-        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
-        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
-        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
-        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
-        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
-        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
-        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
-        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
-        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
-        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
-        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
-        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+        self.add_api_route("/sdapi/v1/options", self.get_config,
+                           methods=["GET"], response_model=models.OptionsModel)
+        self.add_api_route("/sdapi/v1/options",
+                           self.set_config, methods=["POST"])
+        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags,
+                           methods=["GET"], response_model=models.FlagsModel)
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers,
+                           methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers,
+                           methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes,
+                           methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models,
+                           methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes,
+                           methods=["GET"], response_model=List[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks,
+                           methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers,
+                           methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models,
+                           methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles,
+                           methods=["GET"], response_model=List[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings,
+                           methods=["GET"], response_model=models.EmbeddingsResponse)
+        self.add_api_route("/sdapi/v1/refresh-checkpoints",
+                           self.refresh_checkpoints, methods=["POST"])
+        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding,
+                           methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork,
+                           methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess,
+                           methods=["POST"], response_model=models.PreprocessResponse)
+        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding,
+                           methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork,
+                           methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/memory", self.get_memory,
+                           methods=["GET"], response_model=models.MemoryResponse)
+        self.add_api_route("/sdapi/v1/unload-checkpoint",
+                           self.unloadapi, methods=["POST"])
+        self.add_api_route("/sdapi/v1/reload-checkpoint",
+                           self.reloadapi, methods=["POST"])
+        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list,
+                           methods=["GET"], response_model=models.ScriptsList)
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info,
+                           methods=["GET"], response_model=List[models.ScriptInfo])
 
         if shared.cmd_opts.api_server_stop:
-            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
-            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
-            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-kill",
+                               self.kill_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-restart",
+                               self.restart_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-stop",
+                               self.stop_webui, methods=["POST"])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -226,19 +268,23 @@ class Api:
             if compare_digest(credentials.password, self.credentials[credentials.username]):
                 return True
 
-        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={
+                            "WWW-Authenticate": "Basic"})
 
     def get_selectable_script(self, script_name, script_runner):
         if script_name is None or script_name == "":
             return None, None
 
-        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
+        script_idx = script_name_to_index(
+            script_name, script_runner.selectable_scripts)
         script = script_runner.selectable_scripts[script_idx]
         return script, script_idx
 
     def get_scripts_list(self):
-        t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
-        i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
+        t2ilist = [
+            script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
+        i2ilist = [
+            script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
 
         return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
 
@@ -258,7 +304,7 @@ class Api:
         return script_runner.scripts[script_idx]
 
     def init_default_script_args(self, script_runner):
-        #find max idx from the scripts in runner and generate a none array to init script_args
+        # find max idx from the scripts in runner and generate a none array to init script_args
         last_arg_index = 1
         for script in script_runner.scripts:
             if last_arg_index < script.args_to:
@@ -268,7 +314,7 @@ class Api:
         script_args[0] = 0
 
         # get default values
-        with gr.Blocks(): # will throw errors calling ui function without this
+        with gr.Blocks():  # will throw errors calling ui function without this
             for script in script_runner.scripts:
                 if script.ui(script.is_img2img):
                     ui_default_values = []
@@ -281,33 +327,43 @@ class Api:
         script_args = default_script_args.copy()
         # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
         if selectable_scripts:
-            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
+            script_args[selectable_scripts.args_from:
+                        selectable_scripts.args_to] = request.script_args
             script_args[0] = selectable_idx + 1
 
         # Now check for always on scripts
         if request.alwayson_scripts:
             for alwayson_script_name in request.alwayson_scripts.keys():
-                alwayson_script = self.get_script(alwayson_script_name, script_runner)
+                alwayson_script = self.get_script(
+                    alwayson_script_name, script_runner)
                 if alwayson_script is None:
-                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
+                    raise HTTPException(
+                        status_code=422, detail=f"always on script {alwayson_script_name} not found")
                 # Selectable script in always on script param check
                 if alwayson_script.alwayson is False:
-                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
+                    raise HTTPException(
+                        status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                 # always on script with no arg should always run so you don't really need to add them to the requests
                 if "args" in request.alwayson_scripts[alwayson_script_name]:
                     # min between arg length in scriptrunner and arg length in the request
                     for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
-                        script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
+                        script_args[alwayson_script.args_from +
+                                    idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
         return script_args
 
     def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
         script_runner = scripts.scripts_txt2img
+        task_id = txt2imgreq.task_id
+        if task_id is None:
+            raise HTTPException(status_code=404, detail="task_id not found")
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
             ui.create_ui()
         if not self.default_script_arg_txt2img:
-            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
-        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
+            self.default_script_arg_txt2img = self.init_default_script_args(
+                script_runner)
+        selectable_scripts, selectable_script_idx = self.get_selectable_script(
+            txt2imgreq.script_name, script_runner)
 
         populate = txt2imgreq.copy(update={  # Override __init__ params
             "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
@@ -319,10 +375,12 @@ class Api:
 
         args = vars(populate)
         args.pop('script_name', None)
-        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+        # will refeed them to the pipeline directly after initializing them
+        args.pop('script_args', None)
         args.pop('alwayson_scripts', None)
 
-        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
+        script_args = self.init_script_args(
+            txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
 
         send_images = args.pop('send_images', True)
         args.pop('save_images', None)
@@ -335,16 +393,20 @@ class Api:
 
                 try:
                     shared.state.begin(job="scripts_txt2img")
+                    shared.state.task_id = task_id
                     if selectable_scripts is not None:
                         p.script_args = script_args
-                        processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+                        processed = scripts.scripts_txt2img.run(
+                            p, *p.script_args)  # Need to pass args as list here
                     else:
-                        p.script_args = tuple(script_args) # Need to pass args as tuple here
+                        # Need to pass args as tuple here
+                        p.script_args = tuple(script_args)
                         processed = process_images(p)
                 finally:
                     shared.state.end()
 
-        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
+        b64images = list(
+            map(encode_pil_to_base64, processed.images)) if send_images else []
 
         return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
 
@@ -352,7 +414,9 @@ class Api:
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
-
+        task_id = img2imgreq.task_id
+        if task_id is None:
+            raise HTTPException(status_code=404, detail="task_id not found")
         mask = img2imgreq.mask
         if mask:
             mask = decode_base64_to_image(mask)
@@ -362,8 +426,10 @@ class Api:
             script_runner.initialize_scripts(True)
             ui.create_ui()
         if not self.default_script_arg_img2img:
-            self.default_script_arg_img2img = self.init_default_script_args(script_runner)
-        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+            self.default_script_arg_img2img = self.init_default_script_args(
+                script_runner)
+        selectable_scripts, selectable_script_idx = self.get_selectable_script(
+            img2imgreq.script_name, script_runner)
 
         populate = img2imgreq.copy(update={  # Override __init__ params
             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
@@ -375,35 +441,43 @@ class Api:
             populate.sampler_index = None  # prevent a warning later on
 
         args = vars(populate)
-        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+        # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+        args.pop('include_init_images', None)
         args.pop('script_name', None)
-        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
+        # will refeed them to the pipeline directly after initializing them
+        args.pop('script_args', None)
         args.pop('alwayson_scripts', None)
 
-        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
+        script_args = self.init_script_args(
+            img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
 
         send_images = args.pop('send_images', True)
         args.pop('save_images', None)
 
         with self.queue_lock:
             with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
-                p.init_images = [decode_base64_to_image(x) for x in init_images]
+                p.init_images = [decode_base64_to_image(
+                    x) for x in init_images]
                 p.scripts = script_runner
                 p.outpath_grids = opts.outdir_img2img_grids
                 p.outpath_samples = opts.outdir_img2img_samples
 
                 try:
                     shared.state.begin(job="scripts_img2img")
+                    shared.state.task_id = task_id
                     if selectable_scripts is not None:
                         p.script_args = script_args
-                        processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+                        processed = scripts.scripts_img2img.run(
+                            p, *p.script_args)  # Need to pass args as list here
                     else:
-                        p.script_args = tuple(script_args) # Need to pass args as tuple here
+                        # Need to pass args as tuple here
+                        p.script_args = tuple(script_args)
                         processed = process_images(p)
                 finally:
                     shared.state.end()
 
-        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
+        b64images = list(
+            map(encode_pil_to_base64, processed.images)) if send_images else []
 
         if not img2imgreq.include_init_images:
             img2imgreq.init_images = None
@@ -417,7 +491,8 @@ class Api:
         reqDict['image'] = decode_base64_to_image(reqDict['image'])
 
         with self.queue_lock:
-            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+            result = postprocessing.run_extras(
+                extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
 
         return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
 
@@ -428,12 +503,13 @@ class Api:
         image_folder = [decode_base64_to_image(x.data) for x in image_list]
 
         with self.queue_lock:
-            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+            result = postprocessing.run_extras(
+                extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
 
         return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
     def pnginfoapi(self, req: models.PNGInfoRequest):
-        if(not req.image.strip()):
+        if (not req.image.strip()):
             return models.PNGInfoResponse(info="")
 
         image = decode_base64_to_image(req.image.strip())
@@ -450,7 +526,9 @@ class Api:
 
     def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
-
+        task_id = req.task_id
+        if len(task_id) <= 0:
+            raise HTTPException(status_code=404, detail="task_id not found")
         if shared.state.job_count == 0:
             return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
 
@@ -460,7 +538,8 @@ class Api:
         if shared.state.job_count > 0:
             progress += shared.state.job_no / shared.state.job_count
         if shared.state.sampling_steps > 0:
-            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+            progress += 1 / shared.state.job_count * \
+                shared.state.sampling_step / shared.state.sampling_steps
 
         time_since_start = time.time() - shared.state.time_start
         eta = (time_since_start/progress)
@@ -495,7 +574,12 @@ class Api:
 
         return models.InterrogateResponse(caption=processed)
 
-    def interruptapi(self):
+    def interruptapi(self, interruptreq: models.InterruptRequest):
+        task_id = interruptreq.task_id
+        if len(task_id) <= 0:
+            raise HTTPException(status_code=404, detail="invalid task")
+        if shared.state.task_id != task_id:
+            raise HTTPException(status_code=404, detail="no match task")
         shared.state.interrupt()
 
         return {}
@@ -517,8 +601,9 @@ class Api:
         options = {}
         for key in shared.opts.data.keys():
             metadata = shared.opts.data_labels.get(key)
-            if(metadata is not None):
-                options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+            if (metadata is not None):
+                options.update({key: shared.opts.data.get(
+                    key, shared.opts.data_labels.get(key).default)})
             else:
                 options.update({key: shared.opts.data.get(key, None)})
 
@@ -539,7 +624,7 @@ class Api:
         return vars(shared.cmd_opts)
 
     def get_samplers(self):
-        return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
+        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]
 
     def get_upscalers(self):
         return [
@@ -571,16 +656,17 @@ class Api:
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
 
     def get_face_restorers(self):
-        return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+        return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
 
     def get_realesrgan_models(self):
-        return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+        return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]
 
     def get_prompt_styles(self):
         styleList = []
         for k in shared.prompt_styles.styles:
             style = shared.prompt_styles.styles[k]
-            styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
+            styleList.append(
+                {"name": style[0], "prompt": style[1], "negative_prompt": style[2]})
 
         return styleList
 
@@ -611,19 +697,19 @@ class Api:
     def create_embedding(self, args: dict):
         try:
             shared.state.begin(job="create_embedding")
-            filename = create_embedding(**args) # create empty embedding
-            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
+            filename = create_embedding(**args)  # create empty embedding
+            # reload embeddings so new one can be immediately used
+            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
             return models.CreateResponse(info=f"create embedding filename: {filename}")
         except AssertionError as e:
             return models.TrainResponse(info=f"create embedding error: {e}")
         finally:
             shared.state.end()
 
-
     def create_hypernetwork(self, args: dict):
         try:
             shared.state.begin(job="create_hypernetwork")
-            filename = create_hypernetwork(**args) # create empty embedding
+            filename = create_hypernetwork(**args)  # create empty embedding
             return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
         except AssertionError as e:
             return models.TrainResponse(info=f"create hypernetwork error: {e}")
@@ -633,7 +719,8 @@ class Api:
     def preprocess(self, args: dict):
         try:
             shared.state.begin(job="preprocess")
-            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+            # quick operation unless blip/booru interrogation is enabled
+            preprocess(**args)
             shared.state.end()
             return models.PreprocessResponse(info='preprocess complete')
         except KeyError as e:
@@ -652,7 +739,8 @@ class Api:
             if not apply_optimizations:
                 sd_hijack.undo_optimizations()
             try:
-                embedding, filename = train_embedding(**args) # can take a long time to complete
+                embedding, filename = train_embedding(
+                    **args)  # can take a long time to complete
             except Exception as e:
                 error = e
             finally:
@@ -694,22 +782,30 @@ class Api:
             import os
             import psutil
             process = psutil.Process(os.getpid())
-            res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
-            ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
-            ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
+            # only rss is cross-platform guaranteed so we dont rely on other values
+            res = process.memory_info()
+            # and total memory is calculated as actual value is not cross-platform safe
+            ram_total = 100 * res.rss / process.memory_percent()
+            ram = {'free': ram_total - res.rss,
+                   'used': res.rss, 'total': ram_total}
         except Exception as err:
-            ram = { 'error': f'{err}' }
+            ram = {'error': f'{err}'}
         try:
             import torch
             if torch.cuda.is_available():
                 s = torch.cuda.mem_get_info()
-                system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
+                system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
                 s = dict(torch.cuda.memory_stats(shared.device))
-                allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
-                reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
-                active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
-                inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
-                warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
+                allocated = {
+                    'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
+                reserved = {
+                    'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
+                active = {
+                    'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
+                inactive = {'current': s['inactive_split_bytes.all.current'],
+                            'peak': s['inactive_split_bytes.all.peak']}
+                warnings = {
+                    'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
                 cuda = {
                     'system': system,
                     'active': active,
@@ -726,7 +822,8 @@ class Api:
 
     def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
-        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
+        uvicorn.run(self.app, host=server_name, port=port,
+                    timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
 
     def kill_webui(self):
         restart.stop_program()
@@ -739,4 +836,3 @@ class Api:
     def stop_webui(request):
         shared.state.server_command = "stop"
         return Response("Stopping.")
-
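
Taken together, the api.py changes above thread a client-supplied task_id through generation, progress polling and interruption: /sdapi/v1/txt2img and /sdapi/v1/img2img now reject requests without a task_id, /sdapi/v1/progress additionally accepts POST and requires a non-empty id, and /sdapi/v1/interrupt only stops the task whose id matches shared.state.task_id. A minimal client-side sketch of that flow could look like the following (the host, prompt and id values are illustrative assumptions, not part of the commit):

    import threading
    import time
    import uuid

    import requests

    BASE = "http://127.0.0.1:7860"  # assumed local WebUI API address
    task_id = uuid.uuid4().hex      # the caller chooses the id; the server only records and matches it

    payload = {
        "prompt": "a photo of a cat",
        "steps": 20,
        "task_id": task_id,         # new field on the txt2img request model
    }

    result = {}

    def submit():
        # txt2img now returns 404 if task_id is missing
        result["txt2img"] = requests.post(f"{BASE}/sdapi/v1/txt2img", json=payload).json()

    worker = threading.Thread(target=submit)
    worker.start()

    # ProgressRequest is bound with Depends(), so task_id travels as a query parameter.
    while worker.is_alive():
        prog = requests.get(f"{BASE}/sdapi/v1/progress", params={"task_id": task_id}).json()
        print(f"progress {prog['progress']:.0%}, eta {prog['eta_relative']:.1f}s")
        time.sleep(1)

    worker.join()
    print(f"{len(result['txt2img']['images'])} image(s) returned")

    # To cancel instead, interrupt now requires the matching id in the JSON body:
    # requests.post(f"{BASE}/sdapi/v1/interrupt", json={"task_id": task_id})

Because the id is caller-generated and the server only stores it on shared.state, uniqueness across concurrent clients is the caller's responsibility.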

api/models.py (+161 -64)

@@ -26,6 +26,7 @@ API_NOT_ALLOWED = [
     "ddim_discretize"
 ]
 
+
 class ModelDef(BaseModel):
     """Assistance Class for Pydantic Dynamic Model Generation"""
 
@@ -46,8 +47,8 @@ class PydanticModelGenerator:
     def __init__(
         self,
         model_name: str = None,
-        class_instance = None,
-        additional_fields = None,
+        class_instance=None,
+        additional_fields=None,
     ):
         def field_type_generator(k, v):
             # field_type = str if not overrides.get(k) else overrides[k]["type"]
@@ -57,13 +58,14 @@ class PydanticModelGenerator:
             return Optional[field_type]
 
         def merge_class_params(class_):
-            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+            all_classes = list(
+                filter(lambda x: x is not object, inspect.getmro(class_)))
             parameters = {}
             for classes in all_classes:
-                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+                parameters = {**parameters, **
+                              inspect.signature(classes.__init__).parameters}
             return parameters
 
-
         self._model_name = model_name
         self._class_data = merge_class_params(class_instance)
 
@@ -74,7 +76,7 @@ class PydanticModelGenerator:
                 field_type=field_type_generator(k, v),
                 field_value=v.default
             )
-            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+            for (k, v) in self._class_data.items() if k not in API_NOT_ALLOWED
         ]
 
         for fields in additional_fields:
@@ -98,6 +100,7 @@ class PydanticModelGenerator:
         DynamicModel.__config__.allow_mutation = True
         return DynamicModel
 
+
 StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingTxt2Img",
     StableDiffusionProcessingTxt2Img,
@@ -108,6 +111,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
         {"key": "send_images", "type": bool, "default": True},
         {"key": "save_images", "type": bool, "default": False},
         {"key": "alwayson_scripts", "type": dict, "default": {}},
+        {"key": "task_id", "type": str, "default": None},
     ]
 ).generate_model()
 
@@ -119,99 +123,162 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
         {"key": "init_images", "type": list, "default": None},
         {"key": "denoising_strength", "type": float, "default": 0.75},
         {"key": "mask", "type": str, "default": None},
-        {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+        {"key": "include_init_images", "type": bool,
+            "default": False, "exclude": True},
         {"key": "script_name", "type": str, "default": None},
         {"key": "script_args", "type": list, "default": []},
         {"key": "send_images", "type": bool, "default": True},
         {"key": "save_images", "type": bool, "default": False},
         {"key": "alwayson_scripts", "type": dict, "default": {}},
+        {"key": "task_id", "type": str, "default": None},
     ]
 ).generate_model()
 
+
 class TextToImageResponse(BaseModel):
-    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    images: List[str] = Field(default=None, title="Image",
+                              description="The generated image in base64 format.")
     parameters: dict
     info: str
 
+
 class ImageToImageResponse(BaseModel):
-    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    images: List[str] = Field(default=None, title="Image",
+                              description="The generated image in base64 format.")
     parameters: dict
     info: str
 
+
 class ExtrasBaseRequest(BaseModel):
-    resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
-    show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
-    gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
-    codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
-    codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
-    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
-    upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
-    upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
-    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
-    upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
-    upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
-    extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
-    upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
+    resize_mode: Literal[0, 1] = Field(
+        default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+    show_extras_results: bool = Field(
+        default=True, title="Show results", description="Should the backend return the generated image?")
+    gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False,
+                                     description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+    codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False,
+                                         description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+    codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False,
+                                     description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8,
+                                    description="By how much to upscale the image, only used when resize_mode=0.")
+    upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1,
+                                    description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+    upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1,
+                                    description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+    upscaling_crop: bool = Field(default=True, title="Crop to fit",
+                                 description="Should the upscaler crop the image to fit in the chosen size?")
+    upscaler_1: str = Field(default="None", title="Main upscaler",
+                            description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+    upscaler_2: str = Field(default="None", title="Secondary upscaler",
+                            description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+    extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1,
+                                                allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+    upscale_first: bool = Field(default=False, title="Upscale first",
+                                description="Should the upscaler run before restoring faces?")
+
 
 class ExtraBaseResponse(BaseModel):
-    html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
+    html_info: str = Field(
+        title="HTML info", description="A series of HTML tags containing the process info.")
+
 
 class ExtrasSingleImageRequest(ExtrasBaseRequest):
-    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+    image: str = Field(default="", title="Image",
+                       description="Image to work on, must be a Base64 string containing the image's data.")
+
 
 class ExtrasSingleImageResponse(ExtraBaseResponse):
-    image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+    image: str = Field(default=None, title="Image",
+                       description="The generated image in base64 format.")
+
 
 class FileData(BaseModel):
-    data: str = Field(title="File data", description="Base64 representation of the file")
+    data: str = Field(title="File data",
+                      description="Base64 representation of the file")
     name: str = Field(title="File name")
 
+
 class ExtrasBatchImagesRequest(ExtrasBaseRequest):
-    imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+    imageList: List[FileData] = Field(
+        title="Images", description="List of images to work on. Must be Base64 strings")
+
 
 class ExtrasBatchImagesResponse(ExtraBaseResponse):
-    images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+    images: List[str] = Field(
+        title="Images", description="The generated images in base64 format.")
+
 
 class PNGInfoRequest(BaseModel):
-    image: str = Field(title="Image", description="The base64 encoded PNG image")
+    image: str = Field(
+        title="Image", description="The base64 encoded PNG image")
+
 
 class PNGInfoResponse(BaseModel):
-    info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
-    items: dict = Field(title="Items", description="An object containing all the info the image had")
+    info: str = Field(
+        title="Image info", description="A string with the parameters used to generate the image")
+    items: dict = Field(
+        title="Items", description="An object containing all the info the image had")
+
+
+class InterruptRequest(BaseModel):
+    task_id: str = Field(default="", title="task id", description="Task ID")
+
 
 class ProgressRequest(BaseModel):
-    skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+    task_id: str = Field(default="", title="task id", description="Task ID")
+    skip_current_image: bool = Field(
+        default=False, title="Skip current image", description="Skip current image serialization")
+
 
 class ProgressResponse(BaseModel):
-    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+    progress: float = Field(
+        title="Progress", description="The progress with a range of 0 to 1")
     eta_relative: float = Field(title="ETA in secs")
-    state: dict = Field(title="State", description="The current state snapshot")
-    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
-    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+    state: dict = Field(
+        title="State", description="The current state snapshot")
+    current_image: str = Field(default=None, title="Current image",
+                               description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+    textinfo: str = Field(default=None, title="Info text",
+                          description="Info text used by WebUI.")
+
 
 class InterrogateRequest(BaseModel):
-    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
-    model: str = Field(default="clip", title="Model", description="The interrogate model used.")
+    image: str = Field(default="", title="Image",
+                       description="Image to work on, must be a Base64 string containing the image's data.")
+    model: str = Field(default="clip", title="Model",
+                       description="The interrogate model used.")
+
 
 class InterrogateResponse(BaseModel):
-    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+    caption: str = Field(default=None, title="Caption",
+                         description="The generated caption for the image.")
+
 
 class TrainResponse(BaseModel):
-    info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+    info: str = Field(
+        title="Train info", description="Response string from train embedding or hypernetwork task.")
+
 
 class CreateResponse(BaseModel):
-    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+    info: str = Field(title="Create info",
+                      description="Response string from create embedding or hypernetwork task.")
+
 
 class PreprocessResponse(BaseModel):
-    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+    info: str = Field(title="Preprocess info",
+                      description="Response string from preprocessing task.")
+
 
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)
-    optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
+    optType = opts.typemap.get(type(metadata.default), type(
+        metadata.default)) if metadata.default else Any
 
     if metadata is not None:
-        fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
+        fields.update({key: (Optional[optType], Field(
+            default=metadata.default, description=metadata.label))})
     else:
         fields.update({key: (Optional[optType], Field())})
 
@@ -220,20 +287,23 @@ OptionsModel = create_model("Options", **fields)
 flags = {}
 _options = vars(parser)['_option_string_actions']
 for key in _options:
-    if(_options[key].dest != 'help'):
+    if (_options[key].dest != 'help'):
         flag = _options[key]
         _type = str
         if _options[key].default is not None:
             _type = type(_options[key].default)
-        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
+        flags.update(
+            {flag.dest: (_type, Field(default=flag.default, description=flag.help))})
 
 FlagsModel = create_model("Flags", **flags)
 
+
 class SamplerItem(BaseModel):
     name: str = Field(title="Name")
     aliases: List[str] = Field(title="Aliases")
     options: Dict[str, str] = Field(title="Options")
 
+
 class UpscalerItem(BaseModel):
     name: str = Field(title="Name")
     model_name: Optional[str] = Field(title="Model Name")
@@ -241,9 +311,11 @@ class UpscalerItem(BaseModel):
     model_url: Optional[str] = Field(title="URL")
     scale: Optional[float] = Field(title="Scale")
 
+
 class LatentUpscalerModeItem(BaseModel):
     name: str = Field(title="Name")
 
+
 class SDModelItem(BaseModel):
     title: str = Field(title="Title")
     model_name: str = Field(title="Model Name")
@@ -252,23 +324,28 @@ class SDModelItem(BaseModel):
     filename: str = Field(title="Filename")
     config: Optional[str] = Field(title="Config file")
 
+
 class SDVaeItem(BaseModel):
     model_name: str = Field(title="Model Name")
     filename: str = Field(title="Filename")
 
+
 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")
     path: Optional[str] = Field(title="Path")
 
+
 class FaceRestorerItem(BaseModel):
     name: str = Field(title="Name")
     cmd_dir: Optional[str] = Field(title="Path")
 
+
 class RealesrganItem(BaseModel):
     name: str = Field(title="Name")
     path: Optional[str] = Field(title="Path")
     scale: Optional[int] = Field(title="Scale")
 
+
 class PromptStyleItem(BaseModel):
     name: str = Field(title="Name")
     prompt: Optional[str] = Field(title="Prompt")
@@ -276,15 +353,24 @@ class PromptStyleItem(BaseModel):
 
 
 class EmbeddingItem(BaseModel):
-    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
-    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
-    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
-    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
-    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+    step: Optional[int] = Field(
+        title="Step", description="The number of steps that were used to train this embedding, if available")
+    sd_checkpoint: Optional[str] = Field(
+        title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+    sd_checkpoint_name: Optional[str] = Field(
+        title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+    shape: int = Field(
+        title="Shape", description="The length of each individual vector in the embedding")
+    vectors: int = Field(
+        title="Vectors", description="The number of vectors in the embedding")
+
 
 class EmbeddingsResponse(BaseModel):
-    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
-    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+    loaded: Dict[str, EmbeddingItem] = Field(
+        title="Loaded", description="Embeddings loaded for the current model")
+    skipped: Dict[str, EmbeddingItem] = Field(
+        title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+
 
 class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
@@ -292,21 +378,32 @@ class MemoryResponse(BaseModel):
 
 
 class ScriptsList(BaseModel):
-    txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
-    img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
+    txt2img: list = Field(default=None, title="Txt2img",
+                          description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None, title="Img2img",
+                          description="Titles of scripts (img2img)")
 
 
 class ScriptArg(BaseModel):
-    label: str = Field(default=None, title="Label", description="Name of the argument in UI")
-    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
-    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
-    maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
-    step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
-    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+    label: str = Field(default=None, title="Label",
+                       description="Name of the argument in UI")
+    value: Optional[Any] = Field(
+        default=None, title="Value", description="Default value of the argument")
+    minimum: Optional[Any] = Field(
+        default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
+    maximum: Optional[Any] = Field(
+        default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
+    step: Optional[Any] = Field(default=None, title="Step",
+                                description="Step for changing value of the argument in UI")
+    choices: Optional[List[str]] = Field(
+        default=None, title="Choices", description="Possible values for the argument")
 
 
 class ScriptInfo(BaseModel):
     name: str = Field(default=None, title="Name", description="Script name")
-    is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
-    is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
-    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+    is_alwayson: bool = Field(default=None, title="IsAlwayson",
+                              description="Flag specifying whether this script is an alwayson script")
+    is_img2img: bool = Field(default=None, title="IsImg2img",
+                             description="Flag specifying whether this script is an img2img script")
+    args: List[ScriptArg] = Field(
+        title="Arguments", description="List of script's arguments")
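 
For reference, a minimal sketch (not part of the diff) of how these script-metadata models serialize, using trimmed copies of the classes above and pydantic's .dict():

    from typing import Any, List, Optional
    from pydantic import BaseModel, Field

    class ScriptArg(BaseModel):
        label: str = Field(default=None, title="Label", description="Name of the argument in UI")
        value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")

    class ScriptInfo(BaseModel):
        name: str = Field(default=None, title="Name", description="Script name")
        is_alwayson: bool = Field(default=None, title="IsAlwayson")
        args: List[ScriptArg] = Field(title="Arguments")

    info = ScriptInfo(name="example script", is_alwayson=False,
                      args=[ScriptArg(label="Strength", value=0.5)])
    print(info.dict())
    # {'name': 'example script', 'is_alwayson': False,
    #  'args': [{'label': 'Strength', 'value': 0.5}]}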

+ 255 - 130
processing.py

@@ -38,7 +38,8 @@ opt_f = 8
 
 def setup_color_correction(image):
     logging.info("Calibrating color correction.")
-    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
+    correction_target = cv2.cvtColor(
+        np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
     return correction_target
 
 
@@ -79,19 +80,23 @@ def apply_overlay(image, paste_loc, index, overlays):
 
 
 def txt2img_image_conditioning(sd_model, x, width, height):
-    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
+    # Inpainting models
+    if sd_model.model.conditioning_key in {'hybrid', 'concat'}:
 
         # The "masked-image" in this case will just be all zeros since the entire image is masked.
-        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
-        image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+        image_conditioning = torch.zeros(
+            x.shape[0], 3, height, width, device=x.device)
+        image_conditioning = sd_model.get_first_stage_encoding(
+            sd_model.encode_first_stage(image_conditioning))
 
         # Add the fake full 1s mask to the first dimension.
-        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+        image_conditioning = torch.nn.functional.pad(
+            image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
         image_conditioning = image_conditioning.to(x.dtype)
 
         return image_conditioning
 
-    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
+    elif sd_model.model.conditioning_key == "crossattn-adm":  # UnCLIP models
 
         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
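 
As an aside, the pad spec (0, 0, 0, 0, 1, 0) above prepends one channel filled with 1.0 (the "fully masked" flag) ahead of the encoded all-zero image; a standalone sketch of just that tensor step:

    import torch

    # stand-in for the encoded masked-image latent: (batch, channels, h, w)
    image_conditioning = torch.zeros(2, 4, 64, 64)

    # pad = (0, 0, 0, 0, 1, 0): last two dims untouched, channel dim gets one
    # leading channel filled with 1.0
    image_conditioning = torch.nn.functional.pad(
        image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)

    print(image_conditioning.shape)           # torch.Size([2, 5, 64, 64])
    print(image_conditioning[:, 0].unique())  # tensor([1.])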
 
@@ -147,9 +152,11 @@ class StableDiffusionProcessing:
         self.s_min_uncond = s_min_uncond or opts.s_min_uncond
         self.s_churn = s_churn or opts.s_churn
         self.s_tmin = s_tmin or opts.s_tmin
-        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
+        # not representable as a standard ui option
+        self.s_tmax = s_tmax or float('inf')
         self.s_noise = s_noise or opts.s_noise
-        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+        self.override_settings = {k: v for k, v in (
+            override_settings or {}).items() if k not in shared.restricted_opts}
         self.override_settings_restore_afterwards = override_settings_restore_afterwards
         self.is_using_inpainting_conditioning = False
         self.disable_extra_networks = False
@@ -191,18 +198,22 @@ class StableDiffusionProcessing:
         return shared.sd_model
 
     def txt2img_image_conditioning(self, x, width=None, height=None):
-        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {
+            'hybrid', 'concat'}
 
         return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
 
     def depth2img_image_conditioning(self, source_image):
         # Use the AddMiDaS helper to format our source image to suit the MiDaS model
         transformer = AddMiDaS(model_type="dpt_hybrid")
-        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
-        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+        transformed = transformer(
+            {"jpg": rearrange(source_image[0], "c h w -> h w c")})
+        midas_in = torch.from_numpy(
+            transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
 
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(
+            self.sd_model.encode_first_stage(source_image))
         conditioning = torch.nn.functional.interpolate(
             self.sd_model.depth_model(midas_in),
             size=conditioning_image.shape[2:],
@@ -211,19 +222,22 @@ class StableDiffusionProcessing:
         )
 
         (depth_min, depth_max) = torch.aminmax(conditioning)
-        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+        conditioning = 2. * (conditioning - depth_min) / \
+            (depth_max - depth_min) - 1.
         return conditioning
 
     def edit_image_conditioning(self, source_image):
-        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
+        conditioning_image = self.sd_model.encode_first_stage(
+            source_image).mode()
 
         return conditioning_image
 
     def unclip_image_conditioning(self, source_image):
         c_adm = self.sd_model.embedder(source_image)
         if self.sd_model.noise_augmentor is not None:
-            noise_level = 0 # TODO: Allow other noise levels?
-            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
+            noise_level = 0  # TODO: Allow other noise levels?
+            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(
+                torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
             c_adm = torch.cat((c_adm, noise_level_emb), 1)
         return c_adm
 
@@ -236,31 +250,41 @@ class StableDiffusionProcessing:
                 conditioning_mask = image_mask
             else:
                 conditioning_mask = np.array(image_mask.convert("L"))
-                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
-                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+                conditioning_mask = conditioning_mask.astype(
+                    np.float32) / 255.0
+                conditioning_mask = torch.from_numpy(
+                    conditioning_mask[None, None])
 
                 # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                 conditioning_mask = torch.round(conditioning_mask)
         else:
-            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+            conditioning_mask = source_image.new_ones(
+                1, 1, *source_image.shape[-2:])
 
         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
+        conditioning_mask = conditioning_mask.to(
+            device=source_image.device, dtype=source_image.dtype)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
-            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+            getattr(self, "inpainting_mask_weight",
+                    shared.opts.inpainting_mask_weight)
         )
 
         # Encode the new masked image using first stage of network.
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(
+            self.sd_model.encode_first_stage(conditioning_image))
 
         # Create the concatenated conditioning tensor to be fed to `c_concat`
-        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
-        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
-        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
-        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+        conditioning_mask = torch.nn.functional.interpolate(
+            conditioning_mask, size=latent_image.shape[-2:])
+        conditioning_mask = conditioning_mask.expand(
+            conditioning_image.shape[0], -1, -1, -1)
+        image_conditioning = torch.cat(
+            [conditioning_mask, conditioning_image], dim=1)
+        image_conditioning = image_conditioning.to(
+            shared.device).type(self.sd_model.dtype)
 
         return image_conditioning
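 
A toy version of the mask handling above, showing the rounding to a hard 0/1 mask and the lerp controlled by the inpainting_mask_weight setting:

    import torch

    source_image = torch.rand(1, 3, 8, 8) * 2 - 1            # dummy image in [-1, 1]
    conditioning_mask = torch.round(torch.rand(1, 1, 8, 8))   # discretized to 0.0 / 1.0
    inpainting_mask_weight = 1.0                              # 1.0 = blank out masked pixels entirely

    conditioning_image = torch.lerp(
        source_image,
        source_image * (1.0 - conditioning_mask),
        inpainting_mask_weight,
    )

    # with weight 1.0 the masked pixels are exactly zero
    masked = conditioning_mask.expand_as(source_image) == 1
    assert torch.all(conditioning_image[masked] == 0)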
 
@@ -313,10 +337,13 @@ class StableDiffusionProcessing:
         if type(self.negative_prompt) == list:
             self.all_negative_prompts = self.negative_prompt
         else:
-            self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+            self.all_negative_prompts = self.batch_size * \
+                self.n_iter * [self.negative_prompt]
 
-        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
-        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(
+            x, self.styles) for x in self.all_prompts]
+        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(
+            x, self.styles) for x in self.all_negative_prompts]
 
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
         """
@@ -356,16 +383,22 @@ class StableDiffusionProcessing:
         return cache[1]
 
     def setup_conds(self):
-        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
-        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+        prompts = prompt_parser.SdConditioning(
+            self.prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(
+            self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
 
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
-        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
-        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+        self.step_multiplier = 2 if sampler_config and sampler_config.options.get(
+            "second_order", False) else 1
+        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts,
+                                              self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning,
+                                             prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
 
     def parse_extra_network_prompts(self):
-        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
+        self.prompts, self.extra_network_data = extra_networks.parse_prompts(
+            self.prompts)
 
 
 class Processed:
@@ -407,14 +440,19 @@ class Processed:
         self.s_noise = p.s_noise
         self.s_min_uncond = p.s_min_uncond
         self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
-        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
-        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
-        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
-        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+        self.prompt = self.prompt if type(
+            self.prompt) != list else self.prompt[0]
+        self.negative_prompt = self.negative_prompt if type(
+            self.negative_prompt) != list else self.negative_prompt[0]
+        self.seed = int(self.seed if type(self.seed) !=
+                        list else self.seed[0]) if self.seed is not None else -1
+        self.subseed = int(self.subseed if type(
+            self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
         self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
 
         self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
-        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [
+            self.negative_prompt]
         self.all_seeds = all_seeds or p.all_seeds or [self.seed]
         self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
         self.infotexts = infotexts or [info]
@@ -471,7 +509,8 @@ def slerp(val, low, high):
 
     omega = torch.acos(dot)
     so = torch.sin(omega)
-    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1) * \
+        low + (torch.sin(val*omega)/so).unsqueeze(1) * high
     return res
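 
slerp blends two noise tensors along the sphere rather than a straight line, which is the usual rationale for using it when mixing main-seed and subseed noise; a minimal usage sketch (the import path is assumed, adjust to wherever this module lives):

    import torch
    from processing import slerp  # assumed import path

    low = torch.randn(4, 64)       # noise from the main seed
    high = torch.randn(4, 64)      # noise from the subseed
    mixed = slerp(0.3, low, high)  # 30% of the way toward the subseed noise
    print(mixed.shape)             # torch.Size([4, 64])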
 
 
@@ -484,12 +523,14 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
     # produce the same images as with two batches [100], [101].
     if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
-        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
+        sampler_noises = [[]
+                          for _ in range(p.sampler.number_of_needed_noises(p))]
     else:
         sampler_noises = None
 
     for i, seed in enumerate(seeds):
-        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
+        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (
+            shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
 
         subnoise = None
         if subseeds is not None:
@@ -527,12 +568,14 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
                 torch.manual_seed(seed + eta_noise_seed_delta)
 
             for j in range(cnt):
-                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
+                sampler_noises[j].append(
+                    devices.randn_without_seed(tuple(noise_shape)))
 
         xs.append(noise)
 
     if sampler_noises is not None:
-        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
+        p.sampler.sampler_noises = [torch.stack(n).to(
+            shared.device) for n in sampler_noises]
 
     x = torch.stack(xs).to(shared.device)
     return x
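 
create_random_tensors seeds the RNG once per image (per the comment earlier in this function), which is what makes a batch with seeds [100, 101] reproduce the two single-image batches [100] and [101]. A simplified standalone sketch of that idea (the real code also handles subseeds and seed resizing):

    import torch

    def noise_for_seeds(seeds, shape=(4, 64, 64)):
        xs = []
        for seed in seeds:
            torch.manual_seed(seed)        # seed once per image, not once per batch
            xs.append(torch.randn(shape))
        return torch.stack(xs)

    batched = noise_for_seeds([100, 101])
    separate = torch.cat([noise_for_seeds([100]), noise_for_seeds([101])])
    assert torch.equal(batched, separate)  # identical regardless of batch size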
@@ -643,7 +686,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "User": p.user if opts.add_user_name_to_info else None,
     }
 
-    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+    generation_params_text = ", ".join(
+        [k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
     prompt_text = p.prompt if use_main_prompt else all_prompts[index]
     negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
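 
The join above is what produces the familiar single-line infotext; a simplified illustration with made-up values (the real code also quotes values and emits a bare key when k == v):

    generation_params = {
        "Steps": 20,
        "Sampler": "Euler a",
        "CFG scale": 7,
        "Seed": 100,
        "Size": "512x512",
        "User": None,          # None entries are dropped
    }
    text = ", ".join(f"{k}: {v}" for k, v in generation_params.items() if v is not None)
    print(text)  # Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 100, Size: 512x512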
@@ -694,7 +738,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
 
     if type(p.prompt) == list:
-        assert(len(p.prompt) > 0)
+        assert (len(p.prompt) > 0)
     else:
         assert p.prompt is not None
 
@@ -713,7 +757,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     if type(seed) == list:
         p.all_seeds = seed
     else:
-        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
+        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0)
+                       for x in range(len(p.all_prompts))]
 
     if type(subseed) == list:
         p.all_subseeds = subseed
@@ -752,12 +797,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 break
 
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            p.negative_prompts = p.all_negative_prompts[n *
+                                                        p.batch_size:(n + 1) * p.batch_size]
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
-            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+            p.subseeds = p.all_subseeds[n *
+                                        p.batch_size:(n + 1) * p.batch_size]
 
             if p.scripts is not None:
-                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+                p.scripts.before_process_batch(
+                    p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
 
             if len(p.prompts) == 0:
                 break
@@ -769,7 +817,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     extra_networks.activate(p, p.extra_network_data)
 
             if p.scripts is not None:
-                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+                p.scripts.process_batch(
+                    p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
 
             # params.txt should be saved after scripts.process_batch, since the
             # infotext could be modified by that callback
@@ -785,17 +834,21 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             for comment in model_hijack.comments:
                 comments[comment] = 1
 
-            p.extra_generation_params.update(model_hijack.extra_generation_params)
+            p.extra_generation_params.update(
+                model_hijack.extra_generation_params)
 
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
-                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds,
+                                        subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+            x_samples_ddim = decode_latent_batch(
+                p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
             x_samples_ddim = torch.stack(x_samples_ddim).float()
-            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+            x_samples_ddim = torch.clamp(
+                (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
             del samples_ddim
 
@@ -807,11 +860,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.scripts is not None:
                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
 
-                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+                p.prompts = p.all_prompts[n *
+                                          p.batch_size:(n + 1) * p.batch_size]
+                p.negative_prompts = p.all_negative_prompts[n *
+                                                            p.batch_size:(n + 1) * p.batch_size]
 
-                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
-                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+                batch_params = scripts.PostprocessBatchListArgs(
+                    list(x_samples_ddim))
+                p.scripts.postprocess_batch_list(
+                    p, batch_params, batch_number=n)
                 x_samples_ddim = batch_params.images
 
             def infotext(index=0, use_main_prompt=False):
@@ -825,7 +882,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                 if p.restore_faces:
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
-                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
+                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(
+                            i), p=p, suffix="-before-face-restoration")
 
                     devices.torch_gc()
 
@@ -841,14 +899,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
-                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
-                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
-                    image = apply_color_correction(p.color_corrections[i], image)
+                        image_without_cc = apply_overlay(
+                            image, p.paste_to, i, p.overlay_images)
+                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(
+                            i), p=p, suffix="-before-color-correction")
+                    image = apply_color_correction(
+                        p.color_corrections[i], image)
 
                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)
 
                 if opts.samples_save and not p.do_not_save_samples:
-                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+                    images.save_image(
+                        image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
 
                 text = infotext(i)
                 infotexts.append(text)
@@ -858,13 +920,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                 if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                     image_mask = p.mask_for_overlay.convert('RGB')
-                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new(
+                        'RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
 
                     if opts.save_mask:
-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(
+                            i), p=p, suffix="-mask")
 
                     if opts.save_mask_composite:
-                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(
+                            i), p=p, suffix="-mask-composite")
 
                     if opts.return_mask:
                         output_images.append(image_mask)
@@ -881,7 +946,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         p.color_corrections = None
 
         index_of_first_image = 0
-        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
+        unwanted_grid_because_of_img_count = len(
+            output_images) < 2 and opts.grid_only_if_multiple
         if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
             grid = images.image_grid(output_images, p.batch_size)
 
@@ -894,7 +960,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 index_of_first_image = 1
 
             if opts.grid_save:
-                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(
+                    use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
     if not p.disable_extra_networks and p.extra_network_data:
         extra_networks.deactivate(p, p.extra_network_data)
@@ -935,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', task_id: str = None, **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -988,7 +1055,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 self.hr_upscale_to_x = self.width
                 self.hr_upscale_to_y = self.height
 
-                self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+                self.width, self.height = old_hires_fix_first_pass_dimensions(
+                    self.width, self.height)
                 self.applied_old_hires_behavior_to = (self.width, self.height)
 
             if self.hr_resize_x == 0 and self.hr_resize_y == 0:
@@ -1017,8 +1085,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                         self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                         self.hr_upscale_to_y = self.hr_resize_y
 
-                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
-                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+                    self.truncate_x = (
+                        self.hr_upscale_to_x - target_w) // opt_f
+                    self.truncate_y = (
+                        self.hr_upscale_to_y - target_h) // opt_f
 
             # special case: the user has chosen to do nothing
             if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
@@ -1032,7 +1102,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 if state.job_count == -1:
                     state.job_count = self.n_iter
 
-                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
+                shared.total_tqdm.updateTotal(
+                    (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                 state.job_count = state.job_count * 2
                 state.processing_has_refined_job_count = True
 
@@ -1043,15 +1114,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(
+            self.sampler_name, self.sd_model)
 
-        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+        latent_scale_mode = shared.latent_upscale_modes.get(
+            self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
         if self.enable_hr and latent_scale_mode is None:
             if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
-                raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+                raise Exception(
+                    f"could not find upscaler named {self.hr_upscaler}")
 
-        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds,
+                                  subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning,
+                                      image_conditioning=self.txt2img_image_conditioning(x))
 
         if not self.enable_hr:
             return samples
@@ -1068,26 +1144,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 return
 
             if not isinstance(image, Image.Image):
-                image = sd_samplers.sample_to_image(image, index, approximation=0)
+                image = sd_samplers.sample_to_image(
+                    image, index, approximation=0)
 
-            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
-            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
+            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [
+            ], iteration=self.iteration, position_in_batch=index)
+            images.save_image(image, self.outpath_samples, "",
+                              seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
 
         if latent_scale_mode is not None:
             for i in range(samples.shape[0]):
                 save_intermediate(samples, i)
 
-            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+            samples = torch.nn.functional.interpolate(samples, size=(
+                target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
 
             # Avoid making the inpainting conditioning unless necessary as
             # this does need some extra compute to decode / encode the image again.
             if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
-                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+                image_conditioning = self.img2img_image_conditioning(
+                    decode_first_stage(self.sd_model, samples), samples)
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
             decoded_samples = decode_first_stage(self.sd_model, samples)
-            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+            lowres_samples = torch.clamp(
+                (decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             batch_images = []
             for i, x_sample in enumerate(lowres_samples):
@@ -1097,7 +1179,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
                 save_intermediate(image, i)
 
-                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
+                image = images.resize_image(
+                    0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                 image = np.array(image).astype(np.float32) / 255.0
                 image = np.moveaxis(image, 2, 0)
                 batch_images.append(image)
@@ -1106,22 +1189,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             decoded_samples = decoded_samples.to(shared.device)
             decoded_samples = 2. * decoded_samples - 1.
 
-            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+            samples = self.sd_model.get_first_stage_encoding(
+                self.sd_model.encode_first_stage(decoded_samples))
 
-            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+            image_conditioning = self.img2img_image_conditioning(
+                decoded_samples, samples)
 
         shared.state.nextjob()
 
         img2img_sampler_name = self.hr_sampler_name or self.sampler_name
 
-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+        # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+        if self.sampler_name in ['PLMS', 'UniPC']:
             img2img_sampler_name = 'DDIM'
 
-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(
+            img2img_sampler_name, self.sd_model)
 
-        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(
+            self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
 
-        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
+        noise = create_random_tensors(
+            samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
         # GC now before running the next img2img to prevent running out of memory
         x = None
@@ -1134,14 +1223,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         with devices.autocast():
             self.calculate_hr_conds()
 
-        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+        sd_models.apply_token_merging(
+            self.sd_model, self.get_token_merging_ratio(for_hr=True))
 
         if self.scripts is not None:
             self.scripts.before_hr(self)
 
-        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc,
+                                              steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
-        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+        sd_models.apply_token_merging(
+            self.sd_model, self.get_token_merging_ratio())
 
         self.is_hr_pass = False
 
@@ -1170,22 +1262,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if type(self.hr_prompt) == list:
             self.all_hr_prompts = self.hr_prompt
         else:
-            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+            self.all_hr_prompts = self.batch_size * \
+                self.n_iter * [self.hr_prompt]
 
         if type(self.hr_negative_prompt) == list:
             self.all_hr_negative_prompts = self.hr_negative_prompt
         else:
-            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+            self.all_hr_negative_prompts = self.batch_size * \
+                self.n_iter * [self.hr_negative_prompt]
 
-        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
-        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(
+            x, self.styles) for x in self.all_hr_prompts]
+        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(
+            x, self.styles) for x in self.all_hr_negative_prompts]
 
     def calculate_hr_conds(self):
         if self.hr_c is not None:
             return
 
-        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
-        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts,
+                                                 self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts,
+                                                self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
 
     def setup_conds(self):
         super().setup_conds()
@@ -1197,7 +1295,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if shared.opts.hires_fix_use_firstpass_conds:
                 self.calculate_hr_conds()
 
-            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+            # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+            elif lowvram.is_enabled(shared.sd_model):
                 with devices.autocast():
                     extra_networks.activate(self, self.hr_extra_network_data)
 
@@ -1210,10 +1309,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         res = super().parse_extra_network_prompts()
 
         if self.enable_hr:
-            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
-            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+            self.hr_prompts = self.all_hr_prompts[self.iteration *
+                                                  self.batch_size:(self.iteration + 1) * self.batch_size]
+            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(
+                self.iteration + 1) * self.batch_size]
 
-            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(
+                self.hr_prompts)
 
         return res
 
@@ -1221,7 +1323,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, task_id: str = None, **kwargs):
         super().__init__(**kwargs)
 
         self.init_images = init_images
@@ -1247,7 +1349,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(
+            self.sampler_name, self.sd_model)
         crop_region = None
 
         image_mask = self.image_mask
@@ -1261,29 +1364,36 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             if self.mask_blur_x > 0:
                 np_mask = np.array(image_mask)
                 kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
-                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+                np_mask = cv2.GaussianBlur(
+                    np_mask, (kernel_size, 1), self.mask_blur_x)
                 image_mask = Image.fromarray(np_mask)
 
             if self.mask_blur_y > 0:
                 np_mask = np.array(image_mask)
                 kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
-                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+                np_mask = cv2.GaussianBlur(
+                    np_mask, (1, kernel_size), self.mask_blur_y)
                 image_mask = Image.fromarray(np_mask)
 
             if self.inpaint_full_res:
                 self.mask_for_overlay = image_mask
                 mask = image_mask.convert('L')
-                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
-                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
+                crop_region = masking.get_crop_region(
+                    np.array(mask), self.inpaint_full_res_padding)
+                crop_region = masking.expand_crop_region(
+                    crop_region, self.width, self.height, mask.width, mask.height)
                 x1, y1, x2, y2 = crop_region
 
                 mask = mask.crop(crop_region)
-                image_mask = images.resize_image(2, mask, self.width, self.height)
+                image_mask = images.resize_image(
+                    2, mask, self.width, self.height)
                 self.paste_to = (x1, y1, x2-x1, y2-y1)
             else:
-                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+                image_mask = images.resize_image(
+                    self.resize_mode, image_mask, self.width, self.height)
                 np_mask = np.array(image_mask)
-                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+                np_mask = np.clip((np_mask.astype(np.float32))
+                                  * 2, 0, 255).astype(np.uint8)
                 self.mask_for_overlay = Image.fromarray(np_mask)
 
             self.overlay_images = []
@@ -1299,16 +1409,19 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             # Save init image
             if opts.save_init_img:
                 self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
-                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+                images.save_image(img, path=opts.outdir_init_images, basename=None,
+                                  forced_filename=self.init_img_hash, save_to_dirs=False)
 
             image = images.flatten(img, opts.img2img_background_color)
 
             if crop_region is None and self.resize_mode != 3:
-                image = images.resize_image(self.resize_mode, image, self.width, self.height)
+                image = images.resize_image(
+                    self.resize_mode, image, self.width, self.height)
 
             if image_mask is not None:
                 image_masked = Image.new('RGBa', (image.width, image.height))
-                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+                image_masked.paste(image.convert("RGBA").convert(
+                    "RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
 
                 self.overlay_images.append(image_masked.convert('RGBA'))
 
@@ -1330,7 +1443,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             imgs.append(image)
 
         if len(imgs) == 1:
-            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+            batch_images = np.expand_dims(
+                imgs[0], axis=0).repeat(self.batch_size, axis=0)
             if self.overlay_images is not None:
                 self.overlay_images = self.overlay_images * self.batch_size
 
@@ -1341,44 +1455,55 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             self.batch_size = len(imgs)
             batch_images = np.array(imgs)
         else:
-            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
+            raise RuntimeError(
+                f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
         image = image.to(shared.device, dtype=devices.dtype_vae)
 
-        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+        self.init_latent = self.sd_model.get_first_stage_encoding(
+            self.sd_model.encode_first_stage(image))
 
         if self.resize_mode == 3:
-            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(
+                self.height // opt_f, self.width // opt_f), mode="bilinear")
 
         if image_mask is not None:
             init_mask = latent_mask
-            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
-            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
+            latmask = init_mask.convert('RGB').resize(
+                (self.init_latent.shape[3], self.init_latent.shape[2]))
+            latmask = np.moveaxis(
+                np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
             latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))
 
-            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
-            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
+            self.mask = torch.asarray(
+                1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
+            self.nmask = torch.asarray(latmask).to(
+                shared.device).type(self.sd_model.dtype)
 
             # this needs to be fixed to be done in sample() using actual seeds for batches
             if self.inpainting_fill == 2:
-                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+                self.init_latent = self.init_latent * self.mask + create_random_tensors(
+                    self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask
 
-        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
+        self.image_conditioning = self.img2img_image_conditioning(
+            image, self.init_latent, image_mask)
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds,
+                                  subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
 
         if self.initial_noise_multiplier != 1.0:
             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
             x *= self.initial_noise_multiplier
 
-        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
+        samples = self.sampler.sample_img2img(
+            self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
         if self.mask is not None:
             samples = samples * self.nmask + self.init_latent * self.mask
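 
The final masked blend keeps the original latent wherever mask is 1 and the newly sampled latent wherever nmask is 1; a toy check:

    import torch

    init_latent = torch.full((1, 4, 8, 8), 5.0)
    sampled = torch.zeros(1, 4, 8, 8)
    nmask = torch.zeros(1, 4, 8, 8)
    nmask[..., 4:] = 1.0                  # right half gets inpainted
    mask = 1.0 - nmask                    # left half is preserved

    blended = sampled * nmask + init_latent * mask
    print(blended[0, 0, 0])  # tensor([5., 5., 5., 5., 0., 0., 0., 0.])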

+ 37 - 21
shared.py

@@ -27,7 +27,8 @@ demo = None
 
 parser = cmd_args.parser
 
-script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
+script_loading.preload_extensions(
+    extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
 script_loading.preload_extensions(extensions_builtin_dir, parser)
 
 if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
@@ -64,10 +65,12 @@ gradio_hf_hub_themes = [
 ]
 
 
-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
+cmd_opts.disable_extension_access = (
+    cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
 
 devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
-    (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
+    (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device(
+    ) for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
 
 devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
@@ -75,7 +78,8 @@ devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae el
 device = devices.device
 weight_load_location = None if cmd_opts.lowram else "cpu"
 
-batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
+batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (
+    cmd_opts.lowvram or cmd_opts.medvram)
 parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file
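
The device assignment is still one generator expression that maps each component to either the CPU (when the component, or 'all', is named in --use-cpu) or the optimal accelerator. Because the wrapped form is hard to scan, here is a rough standalone equivalent; pick_device, the 'cuda' string and the use_cpu list are illustrative stand-ins, not code from this diff:

use_cpu = ['interrogate']    # e.g. the parsed value of --use-cpu interrogate

def pick_device(component):
    # CPU if this component (or 'all') was requested on the CPU, otherwise the best device.
    return 'cpu' if any(y in use_cpu for y in [component, 'all']) else 'cuda'

device_map = {c: pick_device(c) for c in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']}
print(device_map)    # {'sd': 'cuda', 'interrogate': 'cpu', 'gfpgan': 'cuda', ...}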
@@ -96,6 +100,7 @@ class State:
     skipped = False
     interrupted = False
     job = ""
+    task_id = ""
     job_no = 0
     job_count = 0
     processing_has_refined_job_count = False
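
task_id is the only new field in State; together with the existing job_no, job_count, sampling_step, skipped and interrupted fields it gives a progress endpoint enough to answer queries for a specific task. A hedged sketch of such a poller; the function name and response shape are invented for illustration and are not part of this diff:

from modules import shared

def query_progress(task_id: str) -> dict:
    # Illustrative only: report progress for task_id if it is the task currently running.
    state = shared.state
    if state.task_id != task_id:
        return {"task_id": task_id, "active": False}

    completed = state.job_no / state.job_count if state.job_count > 0 else 0.0
    return {
        "task_id": task_id,
        "active": not (state.interrupted or state.skipped),
        "progress": completed,
        "sampling_step": state.sampling_step,
    }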
@@ -221,9 +226,11 @@ class State:
 
         import modules.sd_samplers
         if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            self.assign_current_image(
+                modules.sd_samplers.samples_to_image_grid(self.current_latent))
         else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+            self.assign_current_image(
+                modules.sd_samplers.sample_to_image(self.current_latent))
 
         self.current_image_sampling_step = self.sampling_step
 
@@ -280,8 +287,6 @@ class OptionInfo:
         return self
 
 
-
-
 def options_section(section_identifier, options_dict):
     for v in options_dict.values():
         v.section = section_identifier
@@ -596,10 +601,12 @@ class Options:
                 info = opts.data_labels.get(key, None)
                 comp_args = info.component_args if info else None
                 if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(
+                        f"not possible to set {key} because it is restricted")
 
                 if cmd_opts.hide_ui_dir_config and key in restricted_opts:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(
+                        f"not possible to set {key} because it is restricted")
 
                 self.data[key] = value
                 return
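
Both guards raise the same RuntimeError, so code that assigns to options programmatically should expect the write to be refused. A small usage sketch; the option name is hypothetical and only chosen to look like something that could be restricted:

from modules.shared import opts

try:
    opts.outdir_samples = "/tmp/outputs"    # hypothetical restricted key
except RuntimeError as err:
    print(f"setting rejected: {err}")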
@@ -668,21 +675,25 @@ class Options:
 
         # 1.1.1 quicksettings list migration
         if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
-            self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+            self.data['quicksettings_list'] = [i.strip()
+                                               for i in self.data.get('quicksettings').split(',')]
 
         # 1.4.0 ui_reorder
         if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
-            self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
+            self.data['ui_reorder_list'] = [i.strip()
+                                            for i in self.data.get('ui_reorder').split(',')]
 
         bad_settings = 0
         for k, v in self.data.items():
             info = self.data_labels.get(k, None)
             if info is not None and not self.same_type(info.default, v):
-                print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
+                print(
+                    f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
                 bad_settings += 1
 
         if bad_settings > 0:
-            print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
+            print(
+                f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
 
     def onchange(self, key, func, call=True):
         item = self.data_labels.get(key)
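
Both migrations follow the same shape: when only the legacy comma-separated string is present, split it into a stripped list under the new key. A standalone illustration of the transform (the dict here stands in for Options.data):

data = {"quicksettings": "sd_model_checkpoint, sd_vae , CLIP_stop_at_last_layers"}

if data.get("quicksettings") is not None and data.get("quicksettings_list") is None:
    data["quicksettings_list"] = [i.strip() for i in data["quicksettings"].split(",")]

print(data["quicksettings_list"])
# ['sd_model_checkpoint', 'sd_vae', 'CLIP_stop_at_last_layers']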
@@ -692,9 +703,12 @@ class Options:
             func()
 
     def dumpjson(self):
-        d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
-        d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
-        d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
+        d = {k: self.data.get(k, v.default)
+             for k, v in self.data_labels.items()}
+        d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items(
+        ) if v.comment_before is not None}
+        d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items(
+        ) if v.comment_after is not None}
         return json.dumps(d)
 
     def add_option(self, key, info):
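
dumpjson() still produces the same document after the reflow: one entry per known option (falling back to its default) plus two underscore-prefixed maps holding per-option comments. A self-contained sketch of that shape; FakeOptionInfo and the option names are invented stand-ins for OptionInfo entries:

import json
from dataclasses import dataclass

@dataclass
class FakeOptionInfo:                 # stand-in exposing only the fields dumpjson() reads
    default: object
    comment_before: str = None
    comment_after: str = None

data_labels = {                       # invented option names, for illustration only
    "samples_format": FakeOptionInfo("png"),
    "jpeg_quality": FakeOptionInfo(80, comment_before="quality for saved jpeg images"),
}
data = {"jpeg_quality": 95}           # values the user has overridden

d = {k: data.get(k, v.default) for k, v in data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in data_labels.items() if v.comment_after is not None}
print(json.dumps(d))
# {"samples_format": "png", "jpeg_quality": 95, "_comments_before": {"jpeg_quality": "quality for saved jpeg images"}, "_comments_after": {}}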
@@ -709,7 +723,8 @@ class Options:
             if item.section not in section_ids:
                 section_ids[item.section] = len(section_ids)
 
-        self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
+        self.data_labels = dict(
+            sorted(settings_items, key=lambda x: section_ids[x[1].section]))
 
     def cast_value(self, key, value):
         """casts an arbitrary to the same type as this setting's value with key
@@ -760,7 +775,8 @@ class Shared(sys.modules[__name__].__class__):
         modules.sd_models.model_data.set_sd_model(value)
 
 
-sd_model: LatentDiffusion = None  # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
+# this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
+sd_model: LatentDiffusion = None
 sys.modules[__name__].__class__ = Shared
 
 settings_components = None
@@ -805,7 +821,6 @@ def reload_gradio_theme(theme_name=None):
             gradio_theme = gr.themes.Default(**default_theme_args)
 
 
-
 class TotalTQDM:
     def __init__(self):
         self._tqdm = None
@@ -850,7 +865,8 @@ def natural_sort_key(s, regex=re.compile('([0-9]+)')):
 
 
 def listfiles(dirname):
-    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
+    filenames = [os.path.join(dirname, x) for x in sorted(
+        os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
     return [file for file in filenames if os.path.isfile(file)]
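
listfiles keeps the same contract: non-hidden regular files from dirname, ordered by natural_sort_key so that numbered files sort numerically rather than lexically. The function body below is a minimal reimplementation consistent with the signature shown in the hunk header, included only to illustrate the ordering:

import re

def natural_sort_key(s, regex=re.compile('([0-9]+)')):
    # Split on digit runs; compare numbers numerically and text case-insensitively.
    return [int(part) if part.isdigit() else part.lower() for part in regex.split(s)]

names = ["img10.png", "img2.png", "img1.png"]
print(sorted(names))                        # ['img1.png', 'img10.png', 'img2.png']
print(sorted(names, key=natural_sort_key))  # ['img1.png', 'img2.png', 'img10.png']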