uni_pc.py

import torch
import math
import tqdm


class NoiseScheduleVP:
    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details.)
                alphas_cumprod: A `torch.Tensor`. The cumulative-product alphas for the discrete-time DPM. (See the original DDPM paper for details.)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.

            **Important**: Please pay special attention to the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:

            Args:
                beta_min: A `float` number. The smallest beta for the linear schedule.
                beta_max: A `float` number. The largest beta for the linear schedule.
                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
                T: A `float` number. The ending time of the forward process.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """
        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues, so we manually set the ending time T.
                # Note that T = 0.9946 may not be the optimal setting; however, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0 ** 2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
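
# A minimal usage sketch for NoiseScheduleVP (illustrative only; the linear
# `betas` below are a typical DDPM-style schedule, not values shipped with
# this file). It exercises the marginals and the lambda/t round trip:
#
#     betas = torch.linspace(1e-4, 2e-2, 1000)
#     ns = NoiseScheduleVP('discrete', betas=betas)
#     t = torch.tensor([0.5])
#     alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
#     lam = ns.marginal_lambda(t)       # half-logSNR at t
#     t_back = ns.inverse_lambda(lam)   # recovers t (up to float error)
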

def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    #condition=None,
    #unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
  154. """Create a wrapper function for the noise prediction model.
  155. DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
  156. firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
  157. We support four types of the diffusion model by setting `model_type`:
  158. 1. "noise": noise prediction model. (Trained by predicting noise).
  159. 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
  160. 3. "v": velocity prediction model. (Trained by predicting the velocity).
  161. The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
  162. [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
  163. arXiv preprint arXiv:2202.00512 (2022).
  164. [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
  165. arXiv preprint arXiv:2210.02303 (2022).
  166. 4. "score": marginal score function. (Trained by denoising score matching).
  167. Note that the score function and the noise prediction model follows a simple relationship:
  168. ```
  169. noise(x_t, t) = -sigma_t * score(x_t, t)
  170. ```
  171. We support three types of guided sampling by DPMs by setting `guidance_type`:
  172. 1. "uncond": unconditional sampling by DPMs.
  173. The input `model` has the following format:
  174. ``
  175. model(x, t_input, **model_kwargs) -> noise | x_start | v | score
  176. ``
  177. 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
  178. The input `model` has the following format:
  179. ``
  180. model(x, t_input, **model_kwargs) -> noise | x_start | v | score
  181. ``
  182. The input `classifier_fn` has the following format:
  183. ``
  184. classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
  185. ``
  186. [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
  187. in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
  188. 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
  189. The input `model` has the following format:
  190. ``
  191. model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
  192. ``
  193. And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
  194. [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
  195. arXiv preprint arXiv:2207.12598 (2022).
  196. The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
  197. or continuous-time labels (i.e. epsilon to T).
  198. We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
  199. ``
  200. def model_fn(x, t_continuous) -> noise:
  201. t_input = get_model_input_time(t_continuous)
  202. return noise_pred(model, x, t_input, **model_kwargs)
  203. ``
  204. where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
  205. ===============================================================
  206. Args:
  207. model: A diffusion model with the corresponding format described above.
  208. noise_schedule: A noise schedule object, such as NoiseScheduleVP.
  209. model_type: A `str`. The parameterization type of the diffusion model.
  210. "noise" or "x_start" or "v" or "score".
  211. model_kwargs: A `dict`. A dict for the other inputs of the model function.
  212. guidance_type: A `str`. The type of the guidance for sampling.
  213. "uncond" or "classifier" or "classifier-free".
  214. condition: A pytorch tensor. The condition for the guided sampling.
  215. Only used for "classifier" or "classifier-free" guidance type.
  216. unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
  217. Only used for "classifier-free" guidance type.
  218. guidance_scale: A `float`. The scale for the guided sampling.
  219. classifier_fn: A classifier function. Only used for the classifier guidance.
  220. classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
  221. Returns:
  222. A noise prediction model that accepts the noised data and the continuous time as the inputs.
  223. """
    model_kwargs = model_kwargs or {}
    classifier_kwargs = classifier_kwargs or {}

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input, condition):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous, condition, unconditional_condition):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input, condition)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                if isinstance(condition, dict):
                    assert isinstance(unconditional_condition, dict)
                    c_in = {}
                    for k in condition:
                        if isinstance(condition[k], list):
                            c_in[k] = [torch.cat([
                                unconditional_condition[k][i],
                                condition[k][i]]) for i in range(len(condition[k]))]
                        else:
                            c_in[k] = torch.cat([
                                unconditional_condition[k],
                                condition[k]])
                elif isinstance(condition, list):
                    c_in = []
                    assert isinstance(unconditional_condition, list)
                    for i in range(len(condition)):
                        c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
                else:
                    c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    # "score" is documented above and handled in noise_pred_fn, so it is included here.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
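
# A minimal sketch of wrapping a classifier-free-guided diffusion model.
# `unet`, `alphas_cumprod`, `x_t`, `cond_emb` and `uncond_emb` are hypothetical
# objects supplied by the caller, not defined in this file:
#
#     ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
#     model_fn = model_wrapper(
#         unet, ns, model_type="noise",
#         guidance_type="classifier-free", guidance_scale=7.5,
#     )
#     # In this variant, the condition tensors are passed at call time:
#     noise = model_fn(x_t, t_continuous, cond_emb, uncond_emb)
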

class UniPC:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=True,
        thresholding=False,
        max_val=1.,
        variant='bh1',
        condition=None,
        unconditional_condition=None,
        before_sample=None,
        after_sample=None,
        after_update=None
    ):
        """Construct a UniPC sampler.

        We support both data prediction (`predict_x0=True`) and noise prediction.
        """
        self.model_fn_ = model_fn
        self.noise_schedule = noise_schedule
        self.variant = variant
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val
        self.condition = condition
        self.unconditional_condition = unconditional_condition
        self.before_sample = before_sample
        self.after_sample = after_sample
        self.after_update = after_update

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method (from the "Imagen" paper).
        """
        dims = x0.dim()
        p = 0.995  # the same thresholding ratio as in `data_prediction_fn`; this class stores no separate ratio attribute
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model(self, x, t):
        cond = self.condition
        uncond = self.unconditional_condition
        if self.before_sample is not None:
            x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
        res = self.model_fn_(x, t, cond, uncond)
        if self.after_sample is not None:
            x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)

        if isinstance(res, tuple):
            # (None, pred_x0)
            res = res[1]

        return res

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction at (x, t).
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction at (x, t), with optional thresholding.
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995  # A hyperparameter from the "Imagen" paper.
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
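
    # Sketch of how the skip types differ for steps=5 over [t_T, t_0] = [1., 1e-3]
    # (hypothetical `sampler` instance; all three return steps + 1 = 6 times):
    #
    #     sampler.get_time_steps('time_uniform', 1., 1e-3, 5, 'cpu')    # t spaced linearly
    #     sampler.get_time_steps('time_quadratic', 1., 1e-3, 5, 'cpu')  # sqrt(t) spaced linearly (denser near t_0)
    #     sampler.get_time_steps('logSNR', 1., 1e-3, 5, 'cpu')          # lambda_t spaced linearly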

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in the DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
        return timesteps_outer, orders
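
    # Worked example of the order schedule (assuming order=3, steps=6):
    # K = 6 // 3 + 1 = 3 and steps % 3 == 0, so orders = [3, 2, 1], which
    # spends 3 + 2 + 1 = 6 model steps in total.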

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K - 1, C, H, W)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            #print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0  # keep k defined for the K == 1 case, where the loop below is empty
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0  # keep k defined for the K == 1 case, where the loop below is empty
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0] if self.predict_x0 else h[0]
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, order - 1, C, H, W)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            #print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0
            )

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t

    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
        method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
        atol=0.0078, rtol=0.05, corrector=False,
    ):
        # Note: only the 'multistep' method is implemented below, so it is the default here.
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        if method == 'multistep':
            assert steps >= order, "UniPC order must be <= the number of sampling steps"
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
            assert timesteps.shape[0] - 1 == steps
            with torch.no_grad():
                vec_t = timesteps[0].expand((x.shape[0]))
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
                with tqdm.tqdm(total=steps) as pbar:
                    # Init the first `order` values by lower order multistep DPM-Solver.
                    for init_order in range(1, order):
                        vec_t = timesteps[init_order].expand(x.shape[0])
                        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                        if model_x is None:
                            model_x = self.model_fn(x, vec_t)
                        if self.after_update is not None:
                            self.after_update(x, model_x)
                        model_prev_list.append(model_x)
                        t_prev_list.append(vec_t)
                        pbar.update()

                    for step in range(order, steps + 1):
                        vec_t = timesteps[step].expand(x.shape[0])
                        if lower_order_final:
                            step_order = min(order, steps + 1 - step)
                        else:
                            step_order = order
                        #print('this step order:', step_order)
                        if step == steps:
                            #print('do not run corrector at the last step')
                            use_corrector = False
                        else:
                            use_corrector = True
                        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
                        if self.after_update is not None:
                            self.after_update(x, model_x)
                        for i in range(order - 1):
                            t_prev_list[i] = t_prev_list[i + 1]
                            model_prev_list[i] = model_prev_list[i + 1]
                        t_prev_list[-1] = vec_t
                        # We do not need to evaluate the final model value.
                        if step < steps:
                            if model_x is None:
                                model_x = self.model_fn(x, vec_t)
                            model_prev_list[-1] = model_x
                        pbar.update()
        else:
            raise NotImplementedError(f"Sampling method {method} is not implemented (only 'multistep' is supported).")
        if denoise_to_zero:
            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x
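
# A minimal end-to-end sampling sketch (hypothetical `unet`, `alphas_cumprod`,
# `cond_emb` and `uncond_emb`; x has shape [B, C, H, W]):
#
#     ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
#     model_fn = model_wrapper(unet, ns, model_type="noise",
#                              guidance_type="classifier-free", guidance_scale=7.5)
#     uni_pc = UniPC(model_fn, ns, predict_x0=True, variant='bh1',
#                    condition=cond_emb, unconditional_condition=uncond_emb)
#     x_T = torch.randn(4, 4, 64, 64)
#     x_0 = uni_pc.sample(x_T, steps=20, order=3, skip_type='time_uniform',
#                         method='multistep', lower_order_final=True)
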
#############################################################
# other utility functions
#############################################################


def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
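
# Example: piecewise linear interpolation through keypoints (0, 0) and (1, 2);
# the values are exact because f is linear on [0, 1]:
#
#     xp = torch.tensor([[0., 1.]])        # [C=1, K=2]
#     yp = torch.tensor([[0., 2.]])        # [C=1, K=2]
#     x = torch.tensor([[0.25], [0.75]])   # [N=2, C=1]
#     interpolate_fn(x, xp, yp)            # tensor([[0.5000], [1.5000]])
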

def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dimension `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] with a total of `dims` dimensions.
    """
    return v[(...,) + (None,) * (dims - 1)]
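
# Example: expand_dims turns a per-sample scalar of shape [N] into [N, 1, 1, 1]
# so it broadcasts against an [N, C, H, W] batch:
#
#     v = torch.ones(8)
#     expand_dims(v, 4).shape   # torch.Size([8, 1, 1, 1])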