api.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838
  1. import base64
  2. import io
  3. import os
  4. import time
  5. import datetime
  6. import uvicorn
  7. import gradio as gr
  8. from threading import Lock
  9. from io import BytesIO
  10. from fastapi import APIRouter, Depends, FastAPI, Request, Response
  11. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  12. from fastapi.exceptions import HTTPException
  13. from fastapi.responses import JSONResponse
  14. from fastapi.encoders import jsonable_encoder
  15. from secrets import compare_digest
  16. import modules.shared as shared
  17. from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
  18. from modules.api import models
  19. from modules.shared import opts
  20. from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
  21. from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
  22. from modules.textual_inversion.preprocess import preprocess
  23. from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
  24. from PIL import PngImagePlugin, Image
  25. from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
  26. from modules.sd_vae import vae_dict
  27. from modules.sd_models_config import find_checkpoint_config_near_filename
  28. from modules.realesrgan_model import get_realesrgan_models
  29. from modules import devices
  30. from typing import Dict, List, Any
  31. import piexif
  32. import piexif.helper
  33. from contextlib import closing
  34. def script_name_to_index(name, scripts):
  35. try:
  36. return [script.title().lower() for script in scripts].index(name.lower())
  37. except Exception as e:
  38. raise HTTPException(
  39. status_code=422, detail=f"Script '{name}' not found") from e
  40. def validate_sampler_name(name):
  41. config = sd_samplers.all_samplers_map.get(name, None)
  42. if config is None:
  43. raise HTTPException(status_code=404, detail="Sampler not found")
  44. return name
  45. def setUpscalers(req: dict):
  46. reqDict = vars(req)
  47. reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
  48. reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
  49. return reqDict
  50. def decode_base64_to_image(encoding):
  51. if encoding.startswith("data:image/"):
  52. encoding = encoding.split(";")[1].split(",")[1]
  53. try:
  54. image = Image.open(BytesIO(base64.b64decode(encoding)))
  55. return image
  56. except Exception as e:
  57. raise HTTPException(
  58. status_code=500, detail="Invalid encoded image") from e
  59. def encode_pil_to_base64(image):
  60. with io.BytesIO() as output_bytes:
  61. if opts.samples_format.lower() == 'png':
  62. use_metadata = False
  63. metadata = PngImagePlugin.PngInfo()
  64. for key, value in image.info.items():
  65. if isinstance(key, str) and isinstance(value, str):
  66. metadata.add_text(key, value)
  67. use_metadata = True
  68. image.save(output_bytes, format="PNG", pnginfo=(
  69. metadata if use_metadata else None), quality=opts.jpeg_quality)
  70. elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
  71. if image.mode == "RGBA":
  72. image = image.convert("RGB")
  73. parameters = image.info.get('parameters', None)
  74. exif_bytes = piexif.dump({
  75. "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
  76. })
  77. if opts.samples_format.lower() in ("jpg", "jpeg"):
  78. image.save(output_bytes, format="JPEG",
  79. exif=exif_bytes, quality=opts.jpeg_quality)
  80. else:
  81. image.save(output_bytes, format="WEBP",
  82. exif=exif_bytes, quality=opts.jpeg_quality)
  83. else:
  84. raise HTTPException(status_code=500, detail="Invalid image format")
  85. bytes_data = output_bytes.getvalue()
  86. return base64.b64encode(bytes_data)
  87. def api_middleware(app: FastAPI):
  88. rich_available = False
  89. try:
  90. if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
  91. import anyio # importing just so it can be placed on silent list
  92. import starlette # importing just so it can be placed on silent list
  93. from rich.console import Console
  94. console = Console()
  95. rich_available = True
  96. except Exception:
  97. pass
  98. @app.middleware("http")
  99. async def log_and_time(req: Request, call_next):
  100. ts = time.time()
  101. res: Response = await call_next(req)
  102. duration = str(round(time.time() - ts, 4))
  103. res.headers["X-Process-Time"] = duration
  104. endpoint = req.scope.get('path', 'err')
  105. if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
  106. print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
  107. t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
  108. code=res.status_code,
  109. ver=req.scope.get('http_version', '0.0'),
  110. cli=req.scope.get('client', ('0:0.0.0', 0))[0],
  111. prot=req.scope.get('scheme', 'err'),
  112. method=req.scope.get('method', 'err'),
  113. endpoint=endpoint,
  114. duration=duration,
  115. ))
  116. return res
  117. def handle_exception(request: Request, e: Exception):
  118. err = {
  119. "error": type(e).__name__,
  120. "detail": vars(e).get('detail', ''),
  121. "body": vars(e).get('body', ''),
  122. "errors": str(e),
  123. }
  124. # do not print backtrace on known httpexceptions
  125. if not isinstance(e, HTTPException):
  126. message = f"API error: {request.method}: {request.url} {err}"
  127. if rich_available:
  128. print(message)
  129. console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[
  130. anyio, starlette], word_wrap=False, width=min([console.width, 200]))
  131. else:
  132. errors.report(message, exc_info=True)
  133. return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
  134. @app.middleware("http")
  135. async def exception_handling(request: Request, call_next):
  136. try:
  137. return await call_next(request)
  138. except Exception as e:
  139. return handle_exception(request, e)
  140. @app.exception_handler(Exception)
  141. async def fastapi_exception_handler(request: Request, e: Exception):
  142. return handle_exception(request, e)
  143. @app.exception_handler(HTTPException)
  144. async def http_exception_handler(request: Request, e: HTTPException):
  145. return handle_exception(request, e)
  146. class Api:
  147. def __init__(self, app: FastAPI, queue_lock: Lock):
  148. if shared.cmd_opts.api_auth:
  149. self.credentials = {}
  150. for auth in shared.cmd_opts.api_auth.split(","):
  151. user, password = auth.split(":")
  152. self.credentials[user] = password
  153. self.router = APIRouter()
  154. self.app = app
  155. self.queue_lock = queue_lock
  156. api_middleware(self.app)
  157. self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi,
  158. methods=["POST"], response_model=models.TextToImageResponse)
  159. self.add_api_route("/sdapi/v1/img2img", self.img2imgapi,
  160. methods=["POST"], response_model=models.ImageToImageResponse)
  161. self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api,
  162. methods=["POST"], response_model=models.ExtrasSingleImageResponse)
  163. self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api,
  164. methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
  165. self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi,
  166. methods=["POST"], response_model=models.PNGInfoResponse)
  167. self.add_api_route("/sdapi/v1/progress", self.progressapi,
  168. methods=["GET", "POST"], response_model=models.ProgressResponse)
  169. self.add_api_route("/sdapi/v1/interrogate",
  170. self.interrogateapi, methods=["POST"])
  171. self.add_api_route("/sdapi/v1/interrupt",
  172. self.interruptapi, methods=["POST"])
  173. self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
  174. self.add_api_route("/sdapi/v1/options", self.get_config,
  175. methods=["GET"], response_model=models.OptionsModel)
  176. self.add_api_route("/sdapi/v1/options",
  177. self.set_config, methods=["POST"])
  178. self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags,
  179. methods=["GET"], response_model=models.FlagsModel)
  180. self.add_api_route("/sdapi/v1/samplers", self.get_samplers,
  181. methods=["GET"], response_model=List[models.SamplerItem])
  182. self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers,
  183. methods=["GET"], response_model=List[models.UpscalerItem])
  184. self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes,
  185. methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
  186. self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models,
  187. methods=["GET"], response_model=List[models.SDModelItem])
  188. self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes,
  189. methods=["GET"], response_model=List[models.SDVaeItem])
  190. self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks,
  191. methods=["GET"], response_model=List[models.HypernetworkItem])
  192. self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers,
  193. methods=["GET"], response_model=List[models.FaceRestorerItem])
  194. self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models,
  195. methods=["GET"], response_model=List[models.RealesrganItem])
  196. self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles,
  197. methods=["GET"], response_model=List[models.PromptStyleItem])
  198. self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings,
  199. methods=["GET"], response_model=models.EmbeddingsResponse)
  200. self.add_api_route("/sdapi/v1/refresh-checkpoints",
  201. self.refresh_checkpoints, methods=["POST"])
  202. self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding,
  203. methods=["POST"], response_model=models.CreateResponse)
  204. self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork,
  205. methods=["POST"], response_model=models.CreateResponse)
  206. self.add_api_route("/sdapi/v1/preprocess", self.preprocess,
  207. methods=["POST"], response_model=models.PreprocessResponse)
  208. self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding,
  209. methods=["POST"], response_model=models.TrainResponse)
  210. self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork,
  211. methods=["POST"], response_model=models.TrainResponse)
  212. self.add_api_route("/sdapi/v1/memory", self.get_memory,
  213. methods=["GET"], response_model=models.MemoryResponse)
  214. self.add_api_route("/sdapi/v1/unload-checkpoint",
  215. self.unloadapi, methods=["POST"])
  216. self.add_api_route("/sdapi/v1/reload-checkpoint",
  217. self.reloadapi, methods=["POST"])
  218. self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list,
  219. methods=["GET"], response_model=models.ScriptsList)
  220. self.add_api_route("/sdapi/v1/script-info", self.get_script_info,
  221. methods=["GET"], response_model=List[models.ScriptInfo])
  222. if shared.cmd_opts.api_server_stop:
  223. self.add_api_route("/sdapi/v1/server-kill",
  224. self.kill_webui, methods=["POST"])
  225. self.add_api_route("/sdapi/v1/server-restart",
  226. self.restart_webui, methods=["POST"])
  227. self.add_api_route("/sdapi/v1/server-stop",
  228. self.stop_webui, methods=["POST"])
  229. self.default_script_arg_txt2img = []
  230. self.default_script_arg_img2img = []
  231. def add_api_route(self, path: str, endpoint, **kwargs):
  232. if shared.cmd_opts.api_auth:
  233. return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
  234. return self.app.add_api_route(path, endpoint, **kwargs)
  235. def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
  236. if credentials.username in self.credentials:
  237. if compare_digest(credentials.password, self.credentials[credentials.username]):
  238. return True
  239. raise HTTPException(status_code=401, detail="Incorrect username or password", headers={
  240. "WWW-Authenticate": "Basic"})
  241. def get_selectable_script(self, script_name, script_runner):
  242. if script_name is None or script_name == "":
  243. return None, None
  244. script_idx = script_name_to_index(
  245. script_name, script_runner.selectable_scripts)
  246. script = script_runner.selectable_scripts[script_idx]
  247. return script, script_idx
  248. def get_scripts_list(self):
  249. t2ilist = [
  250. script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
  251. i2ilist = [
  252. script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
  253. return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
  254. def get_script_info(self):
  255. res = []
  256. for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
  257. res += [script.api_info for script in script_list if script.api_info is not None]
  258. return res
  259. def get_script(self, script_name, script_runner):
  260. if script_name is None or script_name == "":
  261. return None, None
  262. script_idx = script_name_to_index(script_name, script_runner.scripts)
  263. return script_runner.scripts[script_idx]
  264. def init_default_script_args(self, script_runner):
  265. # find max idx from the scripts in runner and generate a none array to init script_args
  266. last_arg_index = 1
  267. for script in script_runner.scripts:
  268. if last_arg_index < script.args_to:
  269. last_arg_index = script.args_to
  270. # None everywhere except position 0 to initialize script args
  271. script_args = [None]*last_arg_index
  272. script_args[0] = 0
  273. # get default values
  274. with gr.Blocks(): # will throw errors calling ui function without this
  275. for script in script_runner.scripts:
  276. if script.ui(script.is_img2img):
  277. ui_default_values = []
  278. for elem in script.ui(script.is_img2img):
  279. ui_default_values.append(elem.value)
  280. script_args[script.args_from:script.args_to] = ui_default_values
  281. return script_args
  282. def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
  283. script_args = default_script_args.copy()
  284. # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
  285. if selectable_scripts:
  286. script_args[selectable_scripts.args_from:
  287. selectable_scripts.args_to] = request.script_args
  288. script_args[0] = selectable_idx + 1
  289. # Now check for always on scripts
  290. if request.alwayson_scripts:
  291. for alwayson_script_name in request.alwayson_scripts.keys():
  292. alwayson_script = self.get_script(
  293. alwayson_script_name, script_runner)
  294. if alwayson_script is None:
  295. raise HTTPException(
  296. status_code=422, detail=f"always on script {alwayson_script_name} not found")
  297. # Selectable script in always on script param check
  298. if alwayson_script.alwayson is False:
  299. raise HTTPException(
  300. status_code=422, detail="Cannot have a selectable script in the always on scripts params")
  301. # always on script with no arg should always run so you don't really need to add them to the requests
  302. if "args" in request.alwayson_scripts[alwayson_script_name]:
  303. # min between arg length in scriptrunner and arg length in the request
  304. for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
  305. script_args[alwayson_script.args_from +
  306. idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
  307. return script_args
    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        """Run a txt2img generation for an API request.

        Steps:
        1. Validate the request.
        2. Lazily initialise the txt2img script runner and the cached default
           script args.
        3. Build a StableDiffusionProcessingTxt2Img from the request fields.
        4. Run it while holding ``self.queue_lock``.

        Returns a TextToImageResponse containing:
        - base64 images (empty list when ``send_images`` is false);
        - the request fields as ``parameters``;
        - ``processed.js()`` as ``info``.

        Raises:
            HTTPException: 404 when ``task_id`` is None or the sampler is
                unknown; 422 for unknown/invalid script names.
        """
        script_runner = scripts.scripts_txt2img
        task_id = txt2imgreq.task_id
        if task_id is None:
            raise HTTPException(status_code=404, detail="task_id not found")
        if not script_runner.scripts:
            # Scripts are loaded lazily on first use. NOTE(review): create_ui()
            # is presumably needed so script UIs/arg ranges get set up — confirm.
            script_runner.initialize_scripts(False)
            ui.create_ui()
        if not self.default_script_arg_txt2img:
            # Computed once and cached on the instance.
            self.default_script_arg_txt2img = self.init_default_script_args(
                script_runner)
        selectable_scripts, selectable_script_idx = self.get_selectable_script(
            txt2imgreq.script_name, script_runner)
        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": not txt2imgreq.save_images,
            "do_not_save_grid": not txt2imgreq.save_images,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on
        # vars() is populate's own attribute dict; fields popped below are not
        # forwarded to the StableDiffusionProcessingTxt2Img constructor.
        # NOTE(review): task_id is not popped, so it is forwarded as a keyword
        # argument — confirm the constructor accepts it.
        args = vars(populate)
        args.pop('script_name', None)
        # will refeed them to the pipeline directly after initializing them
        args.pop('script_args', None)
        args.pop('alwayson_scripts', None)
        script_args = self.init_script_args(
            txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
        send_images = args.pop('send_images', True)
        args.pop('save_images', None)
        with self.queue_lock:
            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
                p.scripts = script_runner
                p.outpath_grids = opts.outdir_txt2img_grids
                p.outpath_samples = opts.outdir_txt2img_samples
                try:
                    shared.state.begin(job="scripts_txt2img")
                    # Recorded so interruptapi can match this job by task_id.
                    shared.state.task_id = task_id
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_txt2img.run(
                            p, *p.script_args)  # Need to pass args as list here
                    else:
                        # Need to pass args as tuple here
                        p.script_args = tuple(script_args)
                        processed = process_images(p)
                finally:
                    shared.state.end()
        # Encoding happens after the queue lock has been released.
        b64images = list(
            map(encode_pil_to_base64, processed.images)) if send_images else []
        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
        """Run an img2img generation for an API request.

        Steps:
        1. Decode the optional base64 mask before taking the lock.
        2. Lazily initialise the img2img script runner and the cached default
           script args.
        3. Build a StableDiffusionProcessingImg2Img from the request.
        4. Decode the init images and run the job while holding
           ``self.queue_lock``.

        Unless ``include_init_images`` is set, ``init_images`` and ``mask`` are
        cleared on the request before it is echoed back as ``parameters``.

        Raises:
            HTTPException: 404 when ``init_images`` or ``task_id`` is None or
                the sampler is unknown; 422 for unknown/invalid script names;
                500 for undecodable images.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")
        task_id = img2imgreq.task_id
        if task_id is None:
            raise HTTPException(status_code=404, detail="task_id not found")
        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)
        script_runner = scripts.scripts_img2img
        if not script_runner.scripts:
            # Scripts are loaded lazily on first use. NOTE(review): create_ui()
            # is presumably needed so script UIs/arg ranges get set up — confirm.
            script_runner.initialize_scripts(True)
            ui.create_ui()
        if not self.default_script_arg_img2img:
            # Computed once and cached on the instance.
            self.default_script_arg_img2img = self.init_default_script_args(
                script_runner)
        selectable_scripts, selectable_script_idx = self.get_selectable_script(
            img2imgreq.script_name, script_runner)
        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": not img2imgreq.save_images,
            "do_not_save_grid": not img2imgreq.save_images,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on
        # vars() is populate's own attribute dict; fields popped below are not
        # forwarded to the StableDiffusionProcessingImg2Img constructor.
        # NOTE(review): task_id is not popped, so it is forwarded as a keyword
        # argument — confirm the constructor accepts it.
        args = vars(populate)
        # Ideally the model would exclude this field ("exclude": True), but that
        # does not take effect here for an unknown reason, so drop it manually.
        args.pop('include_init_images', None)
        args.pop('script_name', None)
        # will refeed them to the pipeline directly after initializing them
        args.pop('script_args', None)
        args.pop('alwayson_scripts', None)
        script_args = self.init_script_args(
            img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
        send_images = args.pop('send_images', True)
        args.pop('save_images', None)
        with self.queue_lock:
            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
                # Replace the base64 strings passed via args with decoded PIL images.
                p.init_images = [decode_base64_to_image(
                    x) for x in init_images]
                p.scripts = script_runner
                p.outpath_grids = opts.outdir_img2img_grids
                p.outpath_samples = opts.outdir_img2img_samples
                try:
                    shared.state.begin(job="scripts_img2img")
                    # Recorded so interruptapi can match this job by task_id.
                    shared.state.task_id = task_id
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_img2img.run(
                            p, *p.script_args)  # Need to pass args as list here
                    else:
                        # Need to pass args as tuple here
                        p.script_args = tuple(script_args)
                        processed = process_images(p)
                finally:
                    shared.state.end()
        # Encoding happens after the queue lock has been released.
        b64images = list(
            map(encode_pil_to_base64, processed.images)) if send_images else []
        if not img2imgreq.include_init_images:
            # Keep the echoed parameters small unless the client asked for the inputs back.
            img2imgreq.init_images = None
            img2imgreq.mask = None
        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
  422. def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
  423. reqDict = setUpscalers(req)
  424. reqDict['image'] = decode_base64_to_image(reqDict['image'])
  425. with self.queue_lock:
  426. result = postprocessing.run_extras(
  427. extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
  428. return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
  429. def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
  430. reqDict = setUpscalers(req)
  431. image_list = reqDict.pop('imageList', [])
  432. image_folder = [decode_base64_to_image(x.data) for x in image_list]
  433. with self.queue_lock:
  434. result = postprocessing.run_extras(
  435. extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
  436. return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
  437. def pnginfoapi(self, req: models.PNGInfoRequest):
  438. if (not req.image.strip()):
  439. return models.PNGInfoResponse(info="")
  440. image = decode_base64_to_image(req.image.strip())
  441. if image is None:
  442. return models.PNGInfoResponse(info="")
  443. geninfo, items = images.read_info_from_image(image)
  444. if geninfo is None:
  445. geninfo = ""
  446. items = {**{'parameters': geninfo}, **items}
  447. return models.PNGInfoResponse(info=geninfo, items=items)
  448. def progressapi(self, req: models.ProgressRequest = Depends()):
  449. # copy from check_progress_call of ui.py
  450. task_id = req.task_id
  451. if len(task_id) <= 0:
  452. raise HTTPException(status_code=404, detail="task_id not found")
  453. if shared.state.job_count == 0:
  454. return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
  455. # avoid dividing zero
  456. progress = 0.01
  457. if shared.state.job_count > 0:
  458. progress += shared.state.job_no / shared.state.job_count
  459. if shared.state.sampling_steps > 0:
  460. progress += 1 / shared.state.job_count * \
  461. shared.state.sampling_step / shared.state.sampling_steps
  462. time_since_start = time.time() - shared.state.time_start
  463. eta = (time_since_start/progress)
  464. eta_relative = eta-time_since_start
  465. progress = min(progress, 1)
  466. shared.state.set_current_image()
  467. current_image = None
  468. if shared.state.current_image and not req.skip_current_image:
  469. current_image = encode_pil_to_base64(shared.state.current_image)
  470. return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
  471. def interrogateapi(self, interrogatereq: models.InterrogateRequest):
  472. image_b64 = interrogatereq.image
  473. if image_b64 is None:
  474. raise HTTPException(status_code=404, detail="Image not found")
  475. img = decode_base64_to_image(image_b64)
  476. img = img.convert('RGB')
  477. # Override object param
  478. with self.queue_lock:
  479. if interrogatereq.model == "clip":
  480. processed = shared.interrogator.interrogate(img)
  481. elif interrogatereq.model == "deepdanbooru":
  482. processed = deepbooru.model.tag(img)
  483. else:
  484. raise HTTPException(status_code=404, detail="Model not found")
  485. return models.InterrogateResponse(caption=processed)
  486. def interruptapi(self, interruptreq: models.InterruptRequest):
  487. task_id = interruptreq.task_id
  488. if len(task_id) <= 0:
  489. raise HTTPException(status_code=404, detail="invalid task")
  490. if shared.state.task_id != task_id:
  491. raise HTTPException(status_code=404, detail="no match task")
  492. shared.state.interrupt()
  493. return {}
  494. def unloadapi(self):
  495. unload_model_weights()
  496. return {}
  497. def reloadapi(self):
  498. reload_model_weights()
  499. return {}
  500. def skip(self):
  501. shared.state.skip()
  502. def get_config(self):
  503. options = {}
  504. for key in shared.opts.data.keys():
  505. metadata = shared.opts.data_labels.get(key)
  506. if (metadata is not None):
  507. options.update({key: shared.opts.data.get(
  508. key, shared.opts.data_labels.get(key).default)})
  509. else:
  510. options.update({key: shared.opts.data.get(key, None)})
  511. return options
  512. def set_config(self, req: Dict[str, Any]):
  513. checkpoint_name = req.get("sd_model_checkpoint", None)
  514. if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
  515. raise RuntimeError(f"model {checkpoint_name!r} not found")
  516. for k, v in req.items():
  517. shared.opts.set(k, v)
  518. shared.opts.save(shared.config_filename)
  519. return
  520. def get_cmd_flags(self):
  521. return vars(shared.cmd_opts)
  522. def get_samplers(self):
  523. return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]
  524. def get_upscalers(self):
  525. return [
  526. {
  527. "name": upscaler.name,
  528. "model_name": upscaler.scaler.model_name,
  529. "model_path": upscaler.data_path,
  530. "model_url": None,
  531. "scale": upscaler.scale,
  532. }
  533. for upscaler in shared.sd_upscalers
  534. ]
  535. def get_latent_upscale_modes(self):
  536. return [
  537. {
  538. "name": upscale_mode,
  539. }
  540. for upscale_mode in [*(shared.latent_upscale_modes or {})]
  541. ]
  542. def get_sd_models(self):
  543. return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
  544. def get_sd_vaes(self):
  545. return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
  546. def get_hypernetworks(self):
  547. return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
  548. def get_face_restorers(self):
  549. return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
  550. def get_realesrgan_models(self):
  551. return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]
  552. def get_prompt_styles(self):
  553. styleList = []
  554. for k in shared.prompt_styles.styles:
  555. style = shared.prompt_styles.styles[k]
  556. styleList.append(
  557. {"name": style[0], "prompt": style[1], "negative_prompt": style[2]})
  558. return styleList
  559. def get_embeddings(self):
  560. db = sd_hijack.model_hijack.embedding_db
  561. def convert_embedding(embedding):
  562. return {
  563. "step": embedding.step,
  564. "sd_checkpoint": embedding.sd_checkpoint,
  565. "sd_checkpoint_name": embedding.sd_checkpoint_name,
  566. "shape": embedding.shape,
  567. "vectors": embedding.vectors,
  568. }
  569. def convert_embeddings(embeddings):
  570. return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
  571. return {
  572. "loaded": convert_embeddings(db.word_embeddings),
  573. "skipped": convert_embeddings(db.skipped_embeddings),
  574. }
  575. def refresh_checkpoints(self):
  576. with self.queue_lock:
  577. shared.refresh_checkpoints()
  578. def create_embedding(self, args: dict):
  579. try:
  580. shared.state.begin(job="create_embedding")
  581. filename = create_embedding(**args) # create empty embedding
  582. # reload embeddings so new one can be immediately used
  583. sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
  584. return models.CreateResponse(info=f"create embedding filename: {filename}")
  585. except AssertionError as e:
  586. return models.TrainResponse(info=f"create embedding error: {e}")
  587. finally:
  588. shared.state.end()
  589. def create_hypernetwork(self, args: dict):
  590. try:
  591. shared.state.begin(job="create_hypernetwork")
  592. filename = create_hypernetwork(**args) # create empty embedding
  593. return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
  594. except AssertionError as e:
  595. return models.TrainResponse(info=f"create hypernetwork error: {e}")
  596. finally:
  597. shared.state.end()
  598. def preprocess(self, args: dict):
  599. try:
  600. shared.state.begin(job="preprocess")
  601. # quick operation unless blip/booru interrogation is enabled
  602. preprocess(**args)
  603. shared.state.end()
  604. return models.PreprocessResponse(info='preprocess complete')
  605. except KeyError as e:
  606. return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
  607. except Exception as e:
  608. return models.PreprocessResponse(info=f"preprocess error: {e}")
  609. finally:
  610. shared.state.end()
  611. def train_embedding(self, args: dict):
  612. try:
  613. shared.state.begin(job="train_embedding")
  614. apply_optimizations = shared.opts.training_xattention_optimizations
  615. error = None
  616. filename = ''
  617. if not apply_optimizations:
  618. sd_hijack.undo_optimizations()
  619. try:
  620. embedding, filename = train_embedding(
  621. **args) # can take a long time to complete
  622. except Exception as e:
  623. error = e
  624. finally:
  625. if not apply_optimizations:
  626. sd_hijack.apply_optimizations()
  627. return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
  628. except Exception as msg:
  629. return models.TrainResponse(info=f"train embedding error: {msg}")
  630. finally:
  631. shared.state.end()
  632. def train_hypernetwork(self, args: dict):
  633. try:
  634. shared.state.begin(job="train_hypernetwork")
  635. shared.loaded_hypernetworks = []
  636. apply_optimizations = shared.opts.training_xattention_optimizations
  637. error = None
  638. filename = ''
  639. if not apply_optimizations:
  640. sd_hijack.undo_optimizations()
  641. try:
  642. hypernetwork, filename = train_hypernetwork(**args)
  643. except Exception as e:
  644. error = e
  645. finally:
  646. shared.sd_model.cond_stage_model.to(devices.device)
  647. shared.sd_model.first_stage_model.to(devices.device)
  648. if not apply_optimizations:
  649. sd_hijack.apply_optimizations()
  650. shared.state.end()
  651. return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
  652. except Exception as exc:
  653. return models.TrainResponse(info=f"train embedding error: {exc}")
  654. finally:
  655. shared.state.end()
  656. def get_memory(self):
  657. try:
  658. import os
  659. import psutil
  660. process = psutil.Process(os.getpid())
  661. # only rss is cross-platform guaranteed so we dont rely on other values
  662. res = process.memory_info()
  663. # and total memory is calculated as actual value is not cross-platform safe
  664. ram_total = 100 * res.rss / process.memory_percent()
  665. ram = {'free': ram_total - res.rss,
  666. 'used': res.rss, 'total': ram_total}
  667. except Exception as err:
  668. ram = {'error': f'{err}'}
  669. try:
  670. import torch
  671. if torch.cuda.is_available():
  672. s = torch.cuda.mem_get_info()
  673. system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
  674. s = dict(torch.cuda.memory_stats(shared.device))
  675. allocated = {
  676. 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
  677. reserved = {
  678. 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
  679. active = {
  680. 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
  681. inactive = {'current': s['inactive_split_bytes.all.current'],
  682. 'peak': s['inactive_split_bytes.all.peak']}
  683. warnings = {
  684. 'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
  685. cuda = {
  686. 'system': system,
  687. 'active': active,
  688. 'allocated': allocated,
  689. 'reserved': reserved,
  690. 'inactive': inactive,
  691. 'events': warnings,
  692. }
  693. else:
  694. cuda = {'error': 'unavailable'}
  695. except Exception as err:
  696. cuda = {'error': f'{err}'}
  697. return models.MemoryResponse(ram=ram, cuda=cuda)
  698. def launch(self, server_name, port, root_path):
  699. self.app.include_router(self.router)
  700. uvicorn.run(self.app, host=server_name, port=port,
  701. timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
  702. def kill_webui(self):
  703. restart.stop_program()
  704. def restart_webui(self):
  705. if restart.is_restartable():
  706. restart.restart_program()
  707. return Response(status_code=501)
  708. def stop_webui(request):
  709. shared.state.server_command = "stop"
  710. return Response("Stopping.")