taiyi.py

import torch
from transformers import BertTokenizer, BertModel


# Used by the webui.
class TaiyiCLIPEmbedder(torch.nn.Module):
    """Uses the Taiyi CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(self, version="IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1", device="cuda", max_length=512,
                 use_auth_token=False):
        super().__init__()
        # The tokenizer and text encoder are loaded from subfolders of the Stable Diffusion checkpoint.
        self.tokenizer = BertTokenizer.from_pretrained(version, subfolder="tokenizer", use_auth_token=use_auth_token)
        self.transformer = BertModel.from_pretrained(version, subfolder="text_encoder", use_auth_token=use_auth_token)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        # Put the encoder in eval mode and disable gradients so it stays frozen.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        # Use the per-token hidden states as the conditioning embedding.
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
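

# Minimal usage sketch (not part of the original file). It assumes the
# IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 weights can be fetched from
# Hugging Face and that a CUDA device may not be available, so the module is
# moved explicitly to the chosen device before encoding.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embedder = TaiyiCLIPEmbedder(device=device).to(device)
    # Encode a Chinese prompt; z has shape (batch, max_length, hidden_size).
    z = embedder.encode(["一只可爱的猫咪"])
    print(z.shape)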