sd_hijack_inpainting.py

import torch

import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
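
# This module patches ldm's PLMS sampler so that it can handle dict-style
# conditionings (as used by inpainting models); see do_inpainting_hijack()
# at the bottom of the file.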


@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
    b, *_, device = *x.shape, x.device
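
    # Predict the noise e_t for latent x at timestep t, applying
    # classifier-free guidance when an unconditional conditioning is given.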
    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
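
            # Dict conditionings (used by inpainting models) are concatenated
            # key by key; plain tensor conditionings are concatenated directly.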
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = {}
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])
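
            # One batched forward pass covers both halves; chunk(2) splits the
            # result back into unconditional and conditional predictions.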
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t
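
    # Use the model's full-length schedules when stepping over the original
    # timesteps, otherwise the DDIM-subsampled ones.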
    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
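
    # PLMS stepping: bootstrap with a 2nd-order Pseudo Improved Euler step,
    # then switch to Adams-Bashforth multistep formulas built from the noise
    # predictions of previous steps (old_eps).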
    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t


def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
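

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of applying the hijack: call
# do_inpainting_hijack() once before constructing a PLMSSampler, after which
# PLMS sampling accepts the dict conditionings produced by inpainting models.
# `model` is assumed to be an already-loaded ldm LatentDiffusion inpainting
# model; loading it is outside the scope of this sketch.
#
#     from ldm.models.diffusion.plms import PLMSSampler
#
#     do_inpainting_hijack()            # replaces PLMSSampler.p_sample_plms
#     sampler = PLMSSampler(model)      # sampler now handles dict conditionings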