ui_tempdir.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import os
  2. import tempfile
  3. from collections import namedtuple
  4. from pathlib import Path
  5. import gradio.components
  6. from PIL import PngImagePlugin
  7. from modules import shared
  8. Savedfile = namedtuple("Savedfile", ["name"])
  9. def register_tmp_file(gradio, filename):
  10. if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
  11. gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
  12. if hasattr(gradio, 'temp_dirs'): # gradio 3.9
  13. gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
  14. def check_tmp_file(gradio, filename):
  15. if hasattr(gradio, 'temp_file_sets'):
  16. return any(filename in fileset for fileset in gradio.temp_file_sets)
  17. if hasattr(gradio, 'temp_dirs'):
  18. return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
  19. return False
  20. def save_pil_to_file(self, pil_image, dir=None, format="png"):
  21. already_saved_as = getattr(pil_image, 'already_saved_as', None)
  22. if already_saved_as and os.path.isfile(already_saved_as):
  23. register_tmp_file(shared.demo, already_saved_as)
  24. filename = already_saved_as
  25. if not shared.opts.save_images_add_number:
  26. filename += f'?{os.path.getmtime(already_saved_as)}'
  27. return filename
  28. if shared.opts.temp_dir != "":
  29. dir = shared.opts.temp_dir
  30. use_metadata = False
  31. metadata = PngImagePlugin.PngInfo()
  32. for key, value in pil_image.info.items():
  33. if isinstance(key, str) and isinstance(value, str):
  34. metadata.add_text(key, value)
  35. use_metadata = True
  36. file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
  37. pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
  38. return file_obj.name
  39. # override save to file function so that it also writes PNG info
  40. gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
  41. def on_tmpdir_changed():
  42. if shared.opts.temp_dir == "" or shared.demo is None:
  43. return
  44. os.makedirs(shared.opts.temp_dir, exist_ok=True)
  45. register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
  46. def cleanup_tmpdr():
  47. temp_dir = shared.opts.temp_dir
  48. if temp_dir == "" or not os.path.isdir(temp_dir):
  49. return
  50. for root, _, files in os.walk(temp_dir, topdown=False):
  51. for name in files:
  52. _, extension = os.path.splitext(name)
  53. if extension != ".png":
  54. continue
  55. filename = os.path.join(root, name)
  56. os.remove(filename)