img2img.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import os
  2. from contextlib import closing
  3. from pathlib import Path
  4. import numpy as np
  5. from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
  6. import gradio as gr
  7. from modules import sd_samplers, images as imgutil
  8. from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
  9. from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
  10. from modules.shared import opts, state
  11. from modules.images import save_image
  12. import modules.shared as shared
  13. import modules.processing as processing
  14. from modules.ui import plaintext_to_html
  15. import modules.scripts
  16. def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
  17. processing.fix_seed(p)
  18. images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
  19. is_inpaint_batch = False
  20. if inpaint_mask_dir:
  21. inpaint_masks = shared.listfiles(inpaint_mask_dir)
  22. is_inpaint_batch = bool(inpaint_masks)
  23. if is_inpaint_batch:
  24. print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
  25. print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
  26. save_normally = output_dir == ''
  27. p.do_not_save_grid = True
  28. p.do_not_save_samples = not save_normally
  29. state.job_count = len(images) * p.n_iter
  30. # extract "default" params to use in case getting png info fails
  31. prompt = p.prompt
  32. negative_prompt = p.negative_prompt
  33. seed = p.seed
  34. cfg_scale = p.cfg_scale
  35. sampler_name = p.sampler_name
  36. steps = p.steps
  37. for i, image in enumerate(images):
  38. state.job = f"{i+1} out of {len(images)}"
  39. if state.skipped:
  40. state.skipped = False
  41. if state.interrupted:
  42. break
  43. try:
  44. img = Image.open(image)
  45. except UnidentifiedImageError as e:
  46. print(e)
  47. continue
  48. # Use the EXIF orientation of photos taken by smartphones.
  49. img = ImageOps.exif_transpose(img)
  50. if to_scale:
  51. p.width = int(img.width * scale_by)
  52. p.height = int(img.height * scale_by)
  53. p.init_images = [img] * p.batch_size
  54. image_path = Path(image)
  55. if is_inpaint_batch:
  56. # try to find corresponding mask for an image using simple filename matching
  57. if len(inpaint_masks) == 1:
  58. mask_image_path = inpaint_masks[0]
  59. else:
  60. # try to find corresponding mask for an image using simple filename matching
  61. mask_image_dir = Path(inpaint_mask_dir)
  62. masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
  63. if len(masks_found) == 0:
  64. print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
  65. continue
  66. # it should contain only 1 matching mask
  67. # otherwise user has many masks with the same name but different extensions
  68. mask_image_path = masks_found[0]
  69. mask_image = Image.open(mask_image_path)
  70. p.image_mask = mask_image
  71. if use_png_info:
  72. try:
  73. info_img = img
  74. if png_info_dir:
  75. info_img_path = os.path.join(png_info_dir, os.path.basename(image))
  76. info_img = Image.open(info_img_path)
  77. geninfo, _ = imgutil.read_info_from_image(info_img)
  78. parsed_parameters = parse_generation_parameters(geninfo)
  79. parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
  80. except Exception:
  81. parsed_parameters = {}
  82. p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
  83. p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
  84. p.seed = int(parsed_parameters.get("Seed", seed))
  85. p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
  86. p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
  87. p.steps = int(parsed_parameters.get("Steps", steps))
  88. proc = modules.scripts.scripts_img2img.run(p, *args)
  89. if proc is None:
  90. proc = process_images(p)
  91. for n, processed_image in enumerate(proc.images):
  92. filename = image_path.stem
  93. infotext = proc.infotext(p, n)
  94. relpath = os.path.dirname(os.path.relpath(image, input_dir))
  95. if n > 0:
  96. filename += f"-{n}"
  97. if not save_normally:
  98. os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
  99. if processed_image.mode == 'RGBA':
  100. processed_image = processed_image.convert("RGB")
  101. save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
  102. def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
  103. override_settings = create_override_settings_dict(override_settings_texts)
  104. is_batch = mode == 5
  105. if mode == 0: # img2img
  106. image = init_img.convert("RGB")
  107. mask = None
  108. elif mode == 1: # img2img sketch
  109. image = sketch.convert("RGB")
  110. mask = None
  111. elif mode == 2: # inpaint
  112. image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
  113. alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
  114. mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
  115. mask = ImageChops.lighter(alpha_mask, mask).convert('L')
  116. image = image.convert("RGB")
  117. elif mode == 3: # inpaint sketch
  118. image = inpaint_color_sketch
  119. orig = inpaint_color_sketch_orig or inpaint_color_sketch
  120. pred = np.any(np.array(image) != np.array(orig), axis=-1)
  121. mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
  122. mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
  123. blur = ImageFilter.GaussianBlur(mask_blur)
  124. image = Image.composite(image.filter(blur), orig, mask.filter(blur))
  125. image = image.convert("RGB")
  126. elif mode == 4: # inpaint upload mask
  127. image = init_img_inpaint
  128. mask = init_mask_inpaint
  129. else:
  130. image = None
  131. mask = None
  132. # Use the EXIF orientation of photos taken by smartphones.
  133. if image is not None:
  134. image = ImageOps.exif_transpose(image)
  135. if selected_scale_tab == 1 and not is_batch:
  136. assert image, "Can't scale by because no image is selected"
  137. width = int(image.width * scale_by)
  138. height = int(image.height * scale_by)
  139. assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
  140. p = StableDiffusionProcessingImg2Img(
  141. sd_model=shared.sd_model,
  142. outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
  143. outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
  144. prompt=prompt,
  145. negative_prompt=negative_prompt,
  146. styles=prompt_styles,
  147. seed=seed,
  148. subseed=subseed,
  149. subseed_strength=subseed_strength,
  150. seed_resize_from_h=seed_resize_from_h,
  151. seed_resize_from_w=seed_resize_from_w,
  152. seed_enable_extras=seed_enable_extras,
  153. sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
  154. batch_size=batch_size,
  155. n_iter=n_iter,
  156. steps=steps,
  157. cfg_scale=cfg_scale,
  158. width=width,
  159. height=height,
  160. restore_faces=restore_faces,
  161. tiling=tiling,
  162. init_images=[image],
  163. mask=mask,
  164. mask_blur=mask_blur,
  165. inpainting_fill=inpainting_fill,
  166. resize_mode=resize_mode,
  167. denoising_strength=denoising_strength,
  168. image_cfg_scale=image_cfg_scale,
  169. inpaint_full_res=inpaint_full_res,
  170. inpaint_full_res_padding=inpaint_full_res_padding,
  171. inpainting_mask_invert=inpainting_mask_invert,
  172. override_settings=override_settings,
  173. )
  174. p.scripts = modules.scripts.scripts_img2img
  175. p.script_args = args
  176. p.user = request.username
  177. if shared.cmd_opts.enable_console_prompts:
  178. print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
  179. if mask:
  180. p.extra_generation_params["Mask blur"] = mask_blur
  181. with closing(p):
  182. if is_batch:
  183. assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
  184. process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
  185. processed = Processed(p, [], p.seed, "")
  186. else:
  187. processed = modules.scripts.scripts_img2img.run(p, *args)
  188. if processed is None:
  189. processed = process_images(p)
  190. shared.total_tqdm.clear()
  191. generation_info_js = processed.js()
  192. if opts.samples_log_stdout:
  193. print(generation_info_js)
  194. if opts.do_not_show_images:
  195. processed.images = []
  196. return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")