# image_embedding.py (8.2 KB)
import base64
import json
import math
import warnings
import zlib

import numpy as np
import torch
from PIL import Image, ImageDraw
  8. class EmbeddingEncoder(json.JSONEncoder):
  9. def default(self, obj):
  10. if isinstance(obj, torch.Tensor):
  11. return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
  12. return json.JSONEncoder.default(self, obj)
  13. class EmbeddingDecoder(json.JSONDecoder):
  14. def __init__(self, *args, **kwargs):
  15. json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
  16. def object_hook(self, d):
  17. if 'TORCHTENSOR' in d:
  18. return torch.from_numpy(np.array(d['TORCHTENSOR']))
  19. return d
  20. def embedding_to_b64(data):
  21. d = json.dumps(data, cls=EmbeddingEncoder)
  22. return base64.b64encode(d.encode())
  23. def embedding_from_b64(data):
  24. d = base64.b64decode(data)
  25. return json.loads(d, cls=EmbeddingDecoder)
  26. def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
  27. while True:
  28. seed = (a * seed + c) % m
  29. yield seed % 255
  30. def xor_block(block):
  31. g = lcg()
  32. randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
  33. return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
  34. def style_block(block, sequence):
  35. im = Image.new('RGB', (block.shape[1], block.shape[0]))
  36. draw = ImageDraw.Draw(im)
  37. i = 0
  38. for x in range(-6, im.size[0], 8):
  39. for yi, y in enumerate(range(-6, im.size[1], 8)):
  40. offset = 0
  41. if yi % 2 == 0:
  42. offset = 4
  43. shade = sequence[i % len(sequence)]
  44. i += 1
  45. draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
  46. fg = np.array(im).astype(np.uint8) & 0xF0
  47. return block ^ fg
  48. def insert_image_data_embed(image, data):
  49. d = 3
  50. data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
  51. data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
  52. data_np_high = data_np_ >> 4
  53. data_np_low = data_np_ & 0x0F
  54. h = image.size[1]
  55. next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
  56. next_size = next_size + ((h*d)-(next_size % (h*d)))
  57. data_np_low = np.resize(data_np_low, next_size)
  58. data_np_low = data_np_low.reshape((h, -1, d))
  59. data_np_high = np.resize(data_np_high, next_size)
  60. data_np_high = data_np_high.reshape((h, -1, d))
  61. edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
  62. edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
  63. data_np_low = style_block(data_np_low, sequence=edge_style)
  64. data_np_low = xor_block(data_np_low)
  65. data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
  66. data_np_high = xor_block(data_np_high)
  67. im_low = Image.fromarray(data_np_low, mode='RGB')
  68. im_high = Image.fromarray(data_np_high, mode='RGB')
  69. background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
  70. background.paste(im_low, (0, 0))
  71. background.paste(image, (im_low.size[0]+1, 0))
  72. background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
  73. return background
  74. def crop_black(img, tol=0):
  75. mask = (img > tol).all(2)
  76. mask0, mask1 = mask.any(0), mask.any(1)
  77. col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
  78. row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
  79. return img[row_start:row_end, col_start:col_end]
  80. def extract_image_data_embed(image):
  81. d = 3
  82. outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
  83. black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
  84. if black_cols[0].shape[0] < 2:
  85. print('No Image data blocks found.')
  86. return None
  87. data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
  88. data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
  89. data_block_lower = xor_block(data_block_lower)
  90. data_block_upper = xor_block(data_block_upper)
  91. data_block = (data_block_upper << 4) | (data_block_lower)
  92. data_block = data_block.flatten().tobytes()
  93. data = zlib.decompress(data_block)
  94. return json.loads(data, cls=EmbeddingDecoder)
  95. def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
  96. from modules.images import get_font
  97. if textfont:
  98. warnings.warn(
  99. 'passing in a textfont to caption_image_overlay is deprecated and does nothing',
  100. DeprecationWarning,
  101. stacklevel=2,
  102. )
  103. from math import cos
  104. image = srcimage.copy()
  105. fontsize = 32
  106. factor = 1.5
  107. gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
  108. for y in range(image.size[1]):
  109. mag = 1-cos(y/image.size[1]*factor)
  110. mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
  111. gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
  112. image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
  113. draw = ImageDraw.Draw(image)
  114. font = get_font(fontsize)
  115. padding = 10
  116. _, _, w, h = draw.textbbox((0, 0), title, font=font)
  117. fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
  118. font = get_font(fontsize)
  119. _, _, w, h = draw.textbbox((0, 0), title, font=font)
  120. draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
  121. _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
  122. fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
  123. _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
  124. fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
  125. _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
  126. fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
  127. font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))
  128. draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
  129. draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
  130. draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
  131. return image
  132. if __name__ == '__main__':
  133. testEmbed = Image.open('test_embedding.png')
  134. data = extract_image_data_embed(testEmbed)
  135. assert data is not None
  136. data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
  137. assert data is not None
  138. image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
  139. cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
  140. test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
  141. embedded_image = insert_image_data_embed(cap_image, test_embed)
  142. retrived_embed = extract_image_data_embed(embedded_image)
  143. assert str(retrived_embed) == str(test_embed)
  144. embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
  145. assert embedded_image == embedded_image2
  146. g = lcg()
  147. shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
  148. reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
  149. 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
  150. 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
  151. 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
  152. 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
  153. 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
  154. 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
  155. 204, 86, 73, 222, 44, 198, 118, 240, 97]
  156. assert shared_random == reference_random
  157. hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
  158. assert 12731374 == hunna_kay_random_sum