sampler.py

  1. """SAMPLING ONLY."""
  2. import torch
  3. from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
  4. from modules import shared, devices
class UniPCSampler(object):
    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.before_sample = None
        self.after_sample = None
        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # store the tensor as a plain attribute, moved to the active device first
        if type(attr) == torch.Tensor:
            if attr.device != devices.device:
                attr = attr.to(devices.device)
        setattr(self, name, attr)

    def set_hooks(self, before_sample, after_sample, after_update):
        self.before_sample = before_sample
        self.after_sample = after_sample
        self.after_update = after_update

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               **kwargs
               ):
        # warn if the conditioning batch size does not match the requested batch size
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        # print(f'Data shape for UniPC sampling is {size}')

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        # SD 1.X is "noise", SD 2.X is "v"
        model_type = "v" if self.model.parameterization == "v" else "noise"

        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type=model_type,
            guidance_type="classifier-free",
            #condition=conditioning,
            #unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        uni_pc = UniPC(
            model_fn,
            ns,
            predict_x0=True,
            thresholding=False,
            variant=shared.opts.uni_pc_variant,
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            before_sample=self.before_sample,
            after_sample=self.after_sample,
            after_update=self.after_update,
        )
        x = uni_pc.sample(
            img,
            steps=S,
            skip_type=shared.opts.uni_pc_skip_type,
            method="multistep",
            order=shared.opts.uni_pc_order,
            lower_order_final=shared.opts.uni_pc_lower_order_final,
        )

        return x.to(device), None
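
# Minimal usage sketch (not part of the original file): assumes a loaded
# LatentDiffusion-style model exposing .alphas_cumprod, .betas,
# .parameterization and .apply_model, plus prompt conditioning tensors
# cond/uncond prepared by the caller; the argument values below are
# illustrative, not the webui's actual call sites.
#
#     sampler = UniPCSampler(model)
#     sampler.set_hooks(before_sample=None, after_sample=None, after_update=None)
#     samples, _ = sampler.sample(
#         S=20,                           # number of UniPC steps
#         batch_size=1,
#         shape=(4, 64, 64),              # latent C, H, W for a 512x512 image
#         conditioning=cond,
#         unconditional_conditioning=uncond,
#         unconditional_guidance_scale=7.5,
#     )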