sd_samplers.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
  2. # imports for functions that previously were here and are used by other modules
  3. from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
  4. all_samplers = [
  5. *sd_samplers_kdiffusion.samplers_data_k_diffusion,
  6. *sd_samplers_compvis.samplers_data_compvis,
  7. ]
  8. all_samplers_map = {x.name: x for x in all_samplers}
  9. samplers = []
  10. samplers_for_img2img = []
  11. samplers_map = {}
  12. def find_sampler_config(name):
  13. if name is not None:
  14. config = all_samplers_map.get(name, None)
  15. else:
  16. config = all_samplers[0]
  17. return config
  18. def create_sampler(name, model):
  19. config = find_sampler_config(name)
  20. assert config is not None, f'bad sampler name: {name}'
  21. if model.is_sdxl and config.options.get("no_sdxl", False):
  22. raise Exception(f"Sampler {config.name} is not supported for SDXL")
  23. sampler = config.constructor(model)
  24. sampler.config = config
  25. return sampler
  26. def set_samplers():
  27. global samplers, samplers_for_img2img
  28. hidden = set(shared.opts.hide_samplers)
  29. hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
  30. samplers = [x for x in all_samplers if x.name not in hidden]
  31. samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
  32. samplers_map.clear()
  33. for sampler in all_samplers:
  34. samplers_map[sampler.name.lower()] = sampler.name
  35. for alias in sampler.aliases:
  36. samplers_map[alias.lower()] = sampler.name
  37. set_samplers()