ui_gradio_extensions.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import gradio as gr
  3. from modules import localization, shared, scripts
  4. from modules.paths import script_path, data_path
  def webpath(fn):
      """Return a ``file=`` web path for *fn* with its mtime as the query string.

      Files located inside ``script_path`` are referenced relative to it, with
      backslashes converted to forward slashes; any other file is referenced
      by its absolute path. The modification time is appended after ``?``
      (presumably so browsers refetch the file when it changes).

      Raises:
          OSError: propagated from ``os.path.getmtime`` when *fn* cannot be
              stat'ed.
      """
      web_path = None
      if fn.startswith(script_path):
          relative = os.path.relpath(fn, script_path)
          # A bare prefix match also accepts sibling directories such as
          # "<script_path>-extra/..."; relpath then has to climb out with
          # "..", which would yield a broken relative web path. Treat those
          # files as living outside script_path instead.
          climbs_out = relative == os.pardir or relative.startswith(os.pardir + os.sep)
          if not climbs_out:
              web_path = relative.replace('\\', '/')

      if web_path is None:
          web_path = os.path.abspath(fn)

      return f'file={web_path}?{os.path.getmtime(fn)}'
  def javascript_html():
      """Build the ``<script>`` tags that are injected into the page head.

      The tags are emitted in this order: the inline localization script,
      the main ``script.js`` under ``script_path``, every ``.js`` script
      listed by ``scripts.list_scripts``, every ``.mjs`` script (as
      ``type="module"``), and finally an inline ``set_theme(...)`` call when
      a theme was given in ``shared.cmd_opts``.

      Returns:
          str: the concatenated tags, each terminated by a newline.
      """
      # Ensure localization is in `window` before scripts
      tags = [
          f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n',
      ]

      main_script = os.path.join(script_path, "script.js")
      tags.append(f'<script type="text/javascript" src="{webpath(main_script)}"></script>\n')

      # Classic scripts first, then ES modules.
      tags.extend(
          f'<script type="text/javascript" src="{webpath(entry.path)}"></script>\n'
          for entry in scripts.list_scripts("javascript", ".js")
      )
      tags.extend(
          f'<script type="module" src="{webpath(entry.path)}"></script>\n'
          for entry in scripts.list_scripts("javascript", ".mjs")
      )

      theme = shared.cmd_opts.theme
      if theme:
          tags.append(f'<script type="text/javascript">set_theme(\"{theme}\");</script>\n')

      return "".join(tags)
  def css_html():
      """Build the ``<link rel="stylesheet">`` tags for all stylesheets.

      Covers every existing ``style.css`` reported by
      ``scripts.list_files_with_name`` followed by ``user.css`` from
      ``data_path`` when that path exists.

      Returns:
          str: the concatenated link tags (empty string when none apply).
      """
      def link_tag(path):
          return f'<link rel="stylesheet" property="stylesheet" href="{webpath(path)}">'

      # Entries that are not regular files are skipped.
      links = [
          link_tag(candidate)
          for candidate in scripts.list_files_with_name("style.css")
          if os.path.isfile(candidate)
      ]

      user_css = os.path.join(data_path, "user.css")
      if os.path.exists(user_css):
          links.append(link_tag(user_css))

      return "".join(links)
  def reload_javascript():
      """Install a gradio ``TemplateResponse`` wrapper that injects JS and CSS.

      The script tags from ``javascript_html()`` are inserted right before
      ``</head>`` and the stylesheet links from ``css_html()`` right before
      ``</body>`` of every response body produced through
      ``shared.GradioTemplateResponseOriginal``. Both snippets are generated
      once, when this function is called, not per response.
      """
      head_close = f'{javascript_html()}</head>'.encode("utf8")
      body_close = f'{css_html()}</body>'.encode("utf8")

      def template_response(*args, **kwargs):
          response = shared.GradioTemplateResponseOriginal(*args, **kwargs)
          response.body = (
              response.body
              .replace(b'</head>', head_close)
              .replace(b'</body>', body_close)
          )
          # The body changed, so the headers are rebuilt (presumably to
          # refresh the content length).
          response.init_headers()
          return response

      gr.routes.templates.TemplateResponse = template_response
  # Remember gradio's TemplateResponse on `shared` the first time this module
  # runs; if the attribute is already there it is left untouched.
  try:
      shared.GradioTemplateResponseOriginal
  except AttributeError:
      shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse