# styles.py (4.9 KB)

  1. import csv
  2. import os
  3. import os.path
  4. import re
  5. import typing
  6. import shutil
  7. class PromptStyle(typing.NamedTuple):
  8. name: str
  9. prompt: str
  10. negative_prompt: str
  11. def merge_prompts(style_prompt: str, prompt: str) -> str:
  12. if "{prompt}" in style_prompt:
  13. res = style_prompt.replace("{prompt}", prompt)
  14. else:
  15. parts = filter(None, (prompt.strip(), style_prompt.strip()))
  16. res = ", ".join(parts)
  17. return res
  18. def apply_styles_to_prompt(prompt, styles):
  19. for style in styles:
  20. prompt = merge_prompts(style, prompt)
  21. return prompt
  22. re_spaces = re.compile(" +")
  23. def extract_style_text_from_prompt(style_text, prompt):
  24. stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
  25. stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
  26. if "{prompt}" in stripped_style_text:
  27. left, right = stripped_style_text.split("{prompt}", 2)
  28. if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
  29. prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
  30. return True, prompt
  31. else:
  32. if stripped_prompt.endswith(stripped_style_text):
  33. prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
  34. if prompt.endswith(', '):
  35. prompt = prompt[:-2]
  36. return True, prompt
  37. return False, prompt
  38. def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
  39. if not style.prompt and not style.negative_prompt:
  40. return False, prompt, negative_prompt
  41. match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
  42. if not match_positive:
  43. return False, prompt, negative_prompt
  44. match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
  45. if not match_negative:
  46. return False, prompt, negative_prompt
  47. return True, extracted_positive, extracted_negative
  48. class StyleDatabase:
  49. def __init__(self, path: str):
  50. self.no_style = PromptStyle("None", "", "")
  51. self.styles = {}
  52. self.path = path
  53. self.reload()
  54. def reload(self):
  55. self.styles.clear()
  56. if not os.path.exists(self.path):
  57. return
  58. with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
  59. reader = csv.DictReader(file, skipinitialspace=True)
  60. for row in reader:
  61. # Support loading old CSV format with "name, text"-columns
  62. prompt = row["prompt"] if "prompt" in row else row["text"]
  63. negative_prompt = row.get("negative_prompt", "")
  64. self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
  65. def get_style_prompts(self, styles):
  66. return [self.styles.get(x, self.no_style).prompt for x in styles]
  67. def get_negative_style_prompts(self, styles):
  68. return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
  69. def apply_styles_to_prompt(self, prompt, styles):
  70. return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
  71. def apply_negative_styles_to_prompt(self, prompt, styles):
  72. return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
  73. def save_styles(self, path: str) -> None:
  74. # Always keep a backup file around
  75. if os.path.exists(path):
  76. shutil.copy(path, f"{path}.bak")
  77. fd = os.open(path, os.O_RDWR | os.O_CREAT)
  78. with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
  79. # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
  80. # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
  81. writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
  82. writer.writeheader()
  83. writer.writerows(style._asdict() for k, style in self.styles.items())
  84. def extract_styles_from_prompt(self, prompt, negative_prompt):
  85. extracted = []
  86. applicable_styles = list(self.styles.values())
  87. while True:
  88. found_style = None
  89. for style in applicable_styles:
  90. is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
  91. if is_match:
  92. found_style = style
  93. prompt = new_prompt
  94. negative_prompt = new_neg_prompt
  95. break
  96. if not found_style:
  97. break
  98. applicable_styles.remove(found_style)
  99. extracted.append(found_style.name)
  100. return list(reversed(extracted)), prompt, negative_prompt