xlmr.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from transformers import BertPreTrainedModel, BertConfig
  2. import torch.nn as nn
  3. import torch
  4. from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
  5. from transformers import XLMRobertaModel,XLMRobertaTokenizer
  6. from typing import Optional
  7. class BertSeriesConfig(BertConfig):
  8. def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
  9. super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
  10. self.project_dim = project_dim
  11. self.pooler_fn = pooler_fn
  12. self.learn_encoder = learn_encoder
  13. class RobertaSeriesConfig(XLMRobertaConfig):
  14. def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
  15. super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
  16. self.project_dim = project_dim
  17. self.pooler_fn = pooler_fn
  18. self.learn_encoder = learn_encoder
  19. class BertSeriesModelWithTransformation(BertPreTrainedModel):
  20. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  21. _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
  22. config_class = BertSeriesConfig
  23. def __init__(self, config=None, **kargs):
  24. # modify initialization for autoloading
  25. if config is None:
  26. config = XLMRobertaConfig()
  27. config.attention_probs_dropout_prob= 0.1
  28. config.bos_token_id=0
  29. config.eos_token_id=2
  30. config.hidden_act='gelu'
  31. config.hidden_dropout_prob=0.1
  32. config.hidden_size=1024
  33. config.initializer_range=0.02
  34. config.intermediate_size=4096
  35. config.layer_norm_eps=1e-05
  36. config.max_position_embeddings=514
  37. config.num_attention_heads=16
  38. config.num_hidden_layers=24
  39. config.output_past=True
  40. config.pad_token_id=1
  41. config.position_embedding_type= "absolute"
  42. config.type_vocab_size= 1
  43. config.use_cache=True
  44. config.vocab_size= 250002
  45. config.project_dim = 768
  46. config.learn_encoder = False
  47. super().__init__(config)
  48. self.roberta = XLMRobertaModel(config)
  49. self.transformation = nn.Linear(config.hidden_size,config.project_dim)
  50. self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  51. self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
  52. self.pooler = lambda x: x[:,0]
  53. self.post_init()
  54. def encode(self,c):
  55. device = next(self.parameters()).device
  56. text = self.tokenizer(c,
  57. truncation=True,
  58. max_length=77,
  59. return_length=False,
  60. return_overflowing_tokens=False,
  61. padding="max_length",
  62. return_tensors="pt")
  63. text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
  64. text["attention_mask"] = torch.tensor(
  65. text['attention_mask']).to(device)
  66. features = self(**text)
  67. return features['projection_state']
  68. def forward(
  69. self,
  70. input_ids: Optional[torch.Tensor] = None,
  71. attention_mask: Optional[torch.Tensor] = None,
  72. token_type_ids: Optional[torch.Tensor] = None,
  73. position_ids: Optional[torch.Tensor] = None,
  74. head_mask: Optional[torch.Tensor] = None,
  75. inputs_embeds: Optional[torch.Tensor] = None,
  76. encoder_hidden_states: Optional[torch.Tensor] = None,
  77. encoder_attention_mask: Optional[torch.Tensor] = None,
  78. output_attentions: Optional[bool] = None,
  79. return_dict: Optional[bool] = None,
  80. output_hidden_states: Optional[bool] = None,
  81. ) :
  82. r"""
  83. """
  84. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  85. outputs = self.roberta(
  86. input_ids=input_ids,
  87. attention_mask=attention_mask,
  88. token_type_ids=token_type_ids,
  89. position_ids=position_ids,
  90. head_mask=head_mask,
  91. inputs_embeds=inputs_embeds,
  92. encoder_hidden_states=encoder_hidden_states,
  93. encoder_attention_mask=encoder_attention_mask,
  94. output_attentions=output_attentions,
  95. output_hidden_states=True,
  96. return_dict=return_dict,
  97. )
  98. # last module outputs
  99. sequence_output = outputs[0]
  100. # project every module
  101. sequence_output_ln = self.pre_LN(sequence_output)
  102. # pooler
  103. pooler_output = self.pooler(sequence_output_ln)
  104. pooler_output = self.transformation(pooler_output)
  105. projection_state = self.transformation(outputs.last_hidden_state)
  106. return {
  107. 'pooler_output':pooler_output,
  108. 'last_hidden_state':outputs.last_hidden_state,
  109. 'hidden_states':outputs.hidden_states,
  110. 'attentions':outputs.attentions,
  111. 'projection_state':projection_state,
  112. 'sequence_out': sequence_output
  113. }
  114. class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
  115. base_model_prefix = 'roberta'
  116. config_class= RobertaSeriesConfig