from __future__ import annotations
import math

import psutil
import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

import sgm.modules.attention
import sgm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    """Base class for an attention optimization: apply() monkeypatches the CrossAttention and
    AttnBlock forwards of ldm/sgm with the implementations below; undo() restores the originals."""

    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])


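# Illustrative sketch, not part of the original module: one way a caller might use
# list_optimizers() together with is_available(), priority and apply(). The helper name
# _demo_select_and_apply_optimization is hypothetical, and the webui's own selection logic
# (which also honors user settings) lives elsewhere; this only shows the intended flow.
def _demo_select_and_apply_optimization():
    candidates = []
    list_optimizers(candidates)

    available = [o for o in candidates if o.is_available()]
    if not available:
        return None

    # pick the highest-priority optimization and monkeypatch the attention forwards with it
    best = max(available, key=lambda o: o.priority)
    best.apply()
    return best

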
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)


def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        # free device memory plus memory reserved by the torch caching allocator but not currently active
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


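# Illustrative sketch, not part of the original module: the v1 path above computes exactly
# softmax(q @ k^T * scale) @ v, just two batch-of-heads rows at a time to cap the size of the
# intermediate attention matrix. The check below, with made-up tensor sizes, only demonstrates
# that the chunked loop and the unchunked einsum agree.
def _demo_v1_chunking_matches_full_attention():
    q = torch.randn(4, 16, 8)   # (batch*heads, tokens, head_dim)
    k = torch.randn(4, 16, 8)
    v = torch.randn(4, 16, 8)
    scale = 8 ** -0.5

    scores = einsum('b i d, b j d -> b i j', q, k) * scale
    full = einsum('b i j, b j d -> b i d', scores.softmax(dim=-1), v)

    chunked = torch.zeros_like(full)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) * scale
        chunked[i:end] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v[i:end])

    return torch.allclose(full, chunked, atol=1e-5)

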
# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


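# Illustrative arithmetic, not part of the original module and using made-up numbers: for fp16
# tensors (element_size() == 2) the heuristic above budgets roughly 3x the size of the q @ k^T
# score matrix. E.g. q of shape (128, 4096, 40) attending to 4096 keys needs
# 128 * 4096 * 4096 * 2 bytes = 4 GiB for the scores, so mem_required = 12 GiB; with 8 GiB free,
# steps = 2**ceil(log2(12/8)) = 2 and the query dimension is processed in two slices of 2048 tokens.
def _demo_split_attention_step_count(batch_x_heads=128, q_tokens=4096, k_tokens=4096,
                                     element_size=2, mem_free_total=8 * 1024**3):
    tensor_size = batch_x_heads * q_tokens * k_tokens * element_size
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier

    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    return steps, q_tokens // steps

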
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)


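# Illustrative sketch, not part of the original module and using made-up sizes: the InvokeAI path
# caps the attention matrix at max_tensor_mb and derives a power-of-two divisor from how far the
# full matrix overshoots that cap. For q of shape (16, 4096, 40) against 4096 keys in fp16,
# size_mb = 512; with a 32 MB budget, div = 1 << int(511 / 32).bit_length() = 16, which is
# <= q.shape[0], so dimension 0 is sliced into chunks of 16 // 16 = 1.
def _demo_tensor_mem_divisor(batch_x_heads=16, q_tokens=4096, k_tokens=4096,
                             element_size=2, max_tensor_mb=32):
    size_mb = batch_x_heads * q_tokens * k_tokens * element_size // (1 << 20)
    if size_mb <= max_tensor_mb:
        return "no slicing"

    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= batch_x_heads:
        return f"slice dim 0 into chunks of {batch_x_heads // div}"
    return f"slice dim 1 into chunks of {max(q_tokens // div, 1)}"

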
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x


def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )


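# Illustrative arithmetic, not part of the original module and using made-up numbers:
# sub_quad_attention() only takes the unchunked fast path when the full q @ k^T matrix fits under
# chunk_threshold_bytes. With chunk_threshold=None and 8 GiB of free VRAM on CUDA, the budget is
# 0.7 * 8 GiB ≈ 5.6 GiB; an fp16 score matrix of shape (128, 4096, 4096) needs 4 GiB, so it would
# run as a single kv chunk.
def _demo_sub_quad_threshold(free_vram=8 * 1024**3, device_type='cuda',
                             batch_x_heads=128, q_tokens=4096, k_tokens=4096, bytes_per_token=2):
    chunk_threshold_bytes = int(free_vram * (0.9 if device_type == 'mps' else 0.7))
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
    return qk_matmul_size_bytes <= chunk_threshold_bytes

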
def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None


def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)


# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states


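# Illustrative sketch, not part of the original module: the handler above relies on
# torch.nn.functional.scaled_dot_product_attention computing softmax(q @ k^T / sqrt(head_dim)) @ v
# over (batch, heads, seq_len, head_dim) tensors. The check below, with made-up sizes, only
# demonstrates that equivalence against a manual implementation.
def _demo_sdp_matches_manual_attention():
    q = torch.randn(2, 8, 64, 40)   # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 8, 64, 40)
    v = torch.randn(2, 8, 64, 40)

    sdp = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    manual = scores.softmax(dim=-1) @ v

    return torch.allclose(sdp, manual, atol=1e-5)

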
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)  # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw   w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3


def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))

        out = out.to(dtype)

        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)


def sdp_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

    out = out.to(dtype)

    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out


def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)


def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out