sd_vae_taesd.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. """
  2. Tiny AutoEncoder for Stable Diffusion
  3. (DNN for encoding / decoding SD's latent space)
  4. https://github.com/madebyollin/taesd
  5. """
  6. import os
  7. import torch
  8. import torch.nn as nn
  9. from modules import devices, paths_internal, shared
  10. sd_vae_taesd_models = {}
  11. def conv(n_in, n_out, **kwargs):
  12. return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
  13. class Clamp(nn.Module):
  14. @staticmethod
  15. def forward(x):
  16. return torch.tanh(x / 3) * 3
  17. class Block(nn.Module):
  18. def __init__(self, n_in, n_out):
  19. super().__init__()
  20. self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
  21. self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
  22. self.fuse = nn.ReLU()
  23. def forward(self, x):
  24. return self.fuse(self.conv(x) + self.skip(x))
  25. def decoder():
  26. return nn.Sequential(
  27. Clamp(), conv(4, 64), nn.ReLU(),
  28. Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
  29. Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
  30. Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
  31. Block(64, 64), conv(64, 3),
  32. )
  33. class TAESD(nn.Module):
  34. latent_magnitude = 3
  35. latent_shift = 0.5
  36. def __init__(self, decoder_path="taesd_decoder.pth"):
  37. """Initialize pretrained TAESD on the given device from the given checkpoints."""
  38. super().__init__()
  39. self.decoder = decoder()
  40. self.decoder.load_state_dict(
  41. torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
  42. @staticmethod
  43. def unscale_latents(x):
  44. """[0, 1] -> raw latents"""
  45. return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
  46. def download_model(model_path, model_url):
  47. if not os.path.exists(model_path):
  48. os.makedirs(os.path.dirname(model_path), exist_ok=True)
  49. print(f'Downloading TAESD decoder to: {model_path}')
  50. torch.hub.download_url_to_file(model_url, model_path)
  51. def model():
  52. model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
  53. loaded_model = sd_vae_taesd_models.get(model_name)
  54. if loaded_model is None:
  55. model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
  56. download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
  57. if os.path.exists(model_path):
  58. loaded_model = TAESD(model_path)
  59. loaded_model.eval()
  60. loaded_model.to(devices.device, devices.dtype)
  61. sd_vae_taesd_models[model_name] = loaded_model
  62. else:
  63. raise FileNotFoundError('TAESD model not found')
  64. return loaded_model.decoder