encoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# Based on https://github.com/IDEA-Research/GroundingDINO
# pyre-unsafe
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from .act_ckpt_utils import activation_ckpt_wrapper
from .model_misc import get_activation_fn, get_clones, get_valid_ratio


class TransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer that performs self-attention followed by cross-attention.

    This layer was previously called TransformerDecoderLayer but was renamed to better
    reflect its role in the architecture. It processes input sequences through
    self-attention and then cross-attention with another input (typically image features).

    The layer supports both pre-norm and post-norm configurations, as well as
    positional encoding at different stages of the attention mechanism.
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        """
        Initialize a transformer encoder layer.

        Args:
            activation: Activation function to use in the feedforward network
            cross_attention: Cross-attention module for attending to image features
            d_model: Model dimension/hidden size
            dim_feedforward: Dimension of the feedforward network
            dropout: Dropout probability
            pos_enc_at_attn: Whether to add positional encodings at self-attention
            pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention
            pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention
            pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture
            self_attention: Self-attention module
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

        self.layer_idx = None

    def forward_post(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """
        Forward pass for post-norm architecture.

        In post-norm architecture, normalization is applied after attention and feedforward operations.

        Args:
            tgt: Input tensor to be processed
            memory: Memory tensor for cross-attention
            tgt_mask: Mask for self-attention
            memory_mask: Mask for cross-attention
            tgt_key_padding_mask: Key padding mask for self-attention
            memory_key_padding_mask: Key padding mask for cross-attention
            pos: Positional encoding for memory
            query_pos: Positional encoding for query
            **kwargs: Additional keyword arguments

        Returns:
            Processed tensor
        """
        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
        # Self attention
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross attention to image
        tgt2 = self.cross_attn_image(
            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt: Tensor,
        memory: Tensor,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        # attn_bias: Optional[Tensor] = None,
        # **kwargs,
    ) -> Tensor:
        """
        Forward pass for pre-norm architecture.

        In pre-norm architecture, normalization is applied before attention and feedforward operations.

        Args:
            tgt: Input tensor to be processed
            memory: Memory tensor for cross-attention
            dac: Whether to use Divide-and-Conquer attention
            tgt_mask: Mask for self-attention
            memory_mask: Mask for cross-attention
            tgt_key_padding_mask: Key padding mask for self-attention
            memory_key_padding_mask: Key padding mask for cross-attention
            pos: Positional encoding for memory
            query_pos: Positional encoding for query
            attn_bias: Optional attention bias tensor
            **kwargs: Additional keyword arguments

        Returns:
            Processed tensor
        """
        if dac:
            # we only apply self attention to the first half of the queries
            assert tgt.shape[0] % 2 == 0
            other_tgt = tgt[tgt.shape[0] // 2 :]
            tgt = tgt[: tgt.shape[0] // 2]
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        if dac:
            # Recombine
            tgt = torch.cat((tgt, other_tgt), dim=0)
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            # attn_bias=attn_bias,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        # attn_bias: Optional[Tensor] = None,
        # **kwds: Any,
    ) -> torch.Tensor:
        """
        Forward pass for the transformer encoder layer.

        Args:
            tgt: Input tensor to be processed
            memory: Memory tensor (e.g., image features) for cross-attention
            dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half)
            tgt_mask: Mask for self-attention
            memory_mask: Mask for cross-attention
            tgt_key_padding_mask: Key padding mask for self-attention
            memory_key_padding_mask: Key padding mask for cross-attention
            pos: Positional encoding for memory
            query_pos: Positional encoding for query
            attn_bias: Optional attention bias tensor
            **kwds: Additional keyword arguments

        Returns:
            Processed tensor after self-attention, cross-attention, and feedforward network
        """
        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
        return fwd_fn(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
            # attn_bias=attn_bias,
            # **kwds,
        )
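

# Minimal usage sketch (illustrative only, not part of the original module). It wires
# TransformerEncoderLayer with plain nn.MultiheadAttention modules as stand-ins for the
# project's own attention implementations and assumes get_activation_fn accepts "relu".
# Tensors are batch-first to match nn.MultiheadAttention(batch_first=True), which is
# also how TransformerEncoder below feeds this layer.
def _build_example_layer(d_model: int = 256, num_heads: int = 8) -> TransformerEncoderLayer:
    return TransformerEncoderLayer(
        activation="relu",
        cross_attention=nn.MultiheadAttention(d_model, num_heads, batch_first=True),
        d_model=d_model,
        dim_feedforward=4 * d_model,
        dropout=0.1,
        pos_enc_at_attn=False,
        # keys get no positional encoding here, since TransformerEncoder below does not
        # pass a `pos` tensor for the prompt memory
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=True,
        pre_norm=True,
        self_attention=nn.MultiheadAttention(d_model, num_heads, batch_first=True),
    )


def _example_encoder_layer() -> Tensor:
    d_model, bs, num_queries, num_mem = 256, 2, 100, 50
    layer = _build_example_layer(d_model).eval()
    tgt = torch.randn(bs, num_queries, d_model)  # e.g. flattened image tokens
    memory = torch.randn(bs, num_mem, d_model)  # e.g. text tokens to cross-attend to
    query_pos = torch.randn(bs, num_queries, d_model)  # positional encoding for queries
    out = layer(tgt, memory, query_pos=query_pos)
    assert out.shape == (bs, num_queries, d_model)
    return out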


class TransformerEncoder(nn.Module):
    """
    Transformer encoder that processes multi-level features.

    This encoder takes multi-level features (e.g., from a backbone network) and processes
    them through a stack of transformer encoder layers. It supports features from multiple
    levels (e.g., different resolutions) and can apply activation checkpointing for memory
    efficiency during training.

    Args:
        layer: The encoder layer to be stacked multiple times
        num_layers: Number of encoder layers to stack
        d_model: Model dimension/hidden size
        num_feature_levels: Number of feature levels to process
        frozen: Whether to freeze the parameters of this module
        use_act_checkpoint: Whether to use activation checkpointing during training
    """

    def __init__(
        self,
        layer: nn.Module,
        num_layers: int,
        d_model: int,
        num_feature_levels: int,
        frozen: bool = False,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.num_feature_levels = num_feature_levels
        self.level_embed = None
        if num_feature_levels > 1:
            self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.use_act_checkpoint = use_act_checkpoint
        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        with torch.no_grad():
            reference_points_list = []
            for lvl, (H_, W_) in enumerate(spatial_shapes):
                ref_y, ref_x = torch.meshgrid(
                    torch.linspace(
                        0.5, H_ - 0.5, H_, dtype=torch.float32, device=device
                    ),
                    torch.linspace(
                        0.5, W_ - 0.5, W_, dtype=torch.float32, device=device
                    ),
                )
                ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
                ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
                ref = torch.stack((ref_x, ref_y), -1)
                reference_points_list.append(ref)
            reference_points = torch.cat(reference_points_list, 1)
            reference_points = reference_points[:, :, None] * valid_ratios[:, None]
            return reference_points

    def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
        assert len(srcs) == self.num_feature_levels, (
            "mismatch between expected and received # of feature levels"
        )
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        has_mask = masks is not None and masks[0] is not None
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            if has_mask:
                mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            if has_mask:
                mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )
        level_start_index = torch.cat(
            (
                spatial_shapes.new_zeros((1,)),
                spatial_shapes.prod(1).cumsum(0)[:-1],
            )
        )
        if has_mask:
            valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)
        else:
            valid_ratios = torch.ones(
                (src_flatten.shape[0], self.num_feature_levels, 2),
                device=src_flatten.device,
            )
        return (
            src_flatten,
            mask_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            valid_ratios,
            spatial_shapes,
        )

    def forward(
        self,
        src: List[Tensor],
        src_key_padding_masks: Optional[List[Tensor]] = None,
        pos: Optional[List[Tensor]] = None,
        prompt: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        encoder_extra_kwargs: Optional[Dict] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor]:
        """
        Process multi-level features through the transformer encoder.

        Args:
            src: List of multi-level features, each with shape (batch_size, channels, height, width)
            src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height, width)
            pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height, width)
            prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model)
            prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len)
            encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer

        Returns:
            A tuple containing:
                - output: Processed features with shape (seq_len, batch_size, d_model)
                - key_padding_masks_flatten: Flattened padding masks
                - lvl_pos_embed_flatten: Flattened positional embeddings
                - level_start_index: Starting indices for each feature level
                - spatial_shapes: Spatial dimensions of each feature level
                - valid_ratios: Valid ratios for each feature level
        """
        assert len(src) == self.num_feature_levels, (
            "must be equal to num_feature_levels"
        )
        if src_key_padding_masks is not None:
            assert len(src_key_padding_masks) == self.num_feature_levels
        if pos is not None:
            assert len(pos) == self.num_feature_levels

        # Flatten multilevel feats and add level pos embeds
        (
            src_flatten,
            key_padding_masks_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            valid_ratios,
            spatial_shapes,
        ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos)

        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=src_flatten.device
        )

        output = src_flatten
        for layer in self.layers:
            layer_kwargs = {}
            assert isinstance(layer, TransformerEncoderLayer)
            layer_kwargs["memory"] = prompt
            layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask
            layer_kwargs["query_pos"] = lvl_pos_embed_flatten
            layer_kwargs["tgt"] = output
            layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten
            if self.training:
                assert self.use_act_checkpoint, "activation ckpt not enabled in encoder"
            if encoder_extra_kwargs is not None:
                layer_kwargs.update(encoder_extra_kwargs)
            output = activation_ckpt_wrapper(layer)(
                **layer_kwargs,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
            )

        # return as seq first
        return (
            output.transpose(0, 1),
            (
                key_padding_masks_flatten.transpose(0, 1)
                if key_padding_masks_flatten is not None
                else None
            ),
            lvl_pos_embed_flatten.transpose(0, 1),
            level_start_index,
            spatial_shapes,
            valid_ratios,
        )
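

# Sketch of running TransformerEncoder on two feature levels (illustrative only).
# Assumptions: the _build_example_layer helper above, and that activation_ckpt_wrapper
# simply invokes the wrapped layer when act_ckpt_enable is False (eval mode here).
# level_embed is allocated uninitialized in __init__ and is presumably initialized
# elsewhere in the project, so the sketch initializes it explicitly.
def _example_encoder() -> Tensor:
    d_model, bs = 256, 2
    enc = TransformerEncoder(
        layer=_build_example_layer(d_model),
        num_layers=2,
        d_model=d_model,
        num_feature_levels=2,
    ).eval()
    nn.init.normal_(enc.level_embed)
    # Two feature levels in (bs, c, h, w) layout with matching pos embeds and masks.
    src = [torch.randn(bs, d_model, 32, 32), torch.randn(bs, d_model, 16, 16)]
    pos = [torch.randn_like(s) for s in src]
    masks = [torch.zeros(bs, s.shape[-2], s.shape[-1], dtype=torch.bool) for s in src]
    # The prompt is passed batch-first here, as TransformerEncoderFusion does after its transpose.
    prompt = torch.randn(bs, 12, d_model)
    prompt_mask = torch.zeros(bs, 12, dtype=torch.bool)
    out, *_ = enc(
        src,
        src_key_padding_masks=masks,
        pos=pos,
        prompt=prompt,
        prompt_key_padding_mask=prompt_mask,
    )
    assert out.shape == (32 * 32 + 16 * 16, bs, d_model)  # returned seq-first
    return out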


class TransformerEncoderFusion(TransformerEncoder):
    """
    Transformer encoder that fuses text and image features.

    This encoder extends TransformerEncoder to handle both text and image features,
    with the ability to add pooled text features to image features for better
    cross-modal fusion. It supports torch.compile for performance optimization.

    Args:
        layer: The encoder layer to be stacked multiple times
        num_layers: Number of encoder layers to stack
        d_model: Model dimension/hidden size
        num_feature_levels: Number of feature levels to process
        add_pooled_text_to_img_feat: Whether to add pooled text features to image features
        pool_text_with_mask: Whether to use the mask when pooling text features
        compile_mode: Mode for torch.compile, or None to disable compilation
        **kwargs: Additional arguments to pass to the parent class
    """

    def __init__(
        self,
        layer: nn.Module,
        num_layers: int,
        d_model: int,
        num_feature_levels: int,
        add_pooled_text_to_img_feat: bool = True,
        pool_text_with_mask: bool = False,
        compile_mode: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(
            layer,
            num_layers,
            d_model,
            num_feature_levels,
            **kwargs,
        )
        self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat
        if self.add_pooled_text_to_img_feat:
            self.text_pooling_proj = nn.Linear(d_model, d_model)
        self.pool_text_with_mask = pool_text_with_mask
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        # Not needed here
        return None

    def forward(
        self,
        src: List[Tensor],
        prompt: Tensor,
        src_key_padding_mask: Optional[List[Tensor]] = None,
        src_pos: Optional[List[Tensor]] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        prompt_pos: Optional[Tensor] = None,
        feat_sizes: Optional[List[Tuple[int, int]]] = None,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        # Restore spatial shapes of vision features
        bs = src[0].shape[1]  # seq first
        if feat_sizes is not None:
            assert len(feat_sizes) == len(src)
            if src_key_padding_mask is None:
                src_key_padding_mask = [None] * len(src)
            for i, (h, w) in enumerate(feat_sizes):
                src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
                src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
                src_key_padding_mask[i] = (
                    src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1)
                    if src_key_padding_mask[i] is not None
                    else None
                )
        else:
            assert all(x.dim() == 4 for x in src), (
                "expected list of (bs, c, h, w) tensors"
            )

        if self.add_pooled_text_to_img_feat:
            # Fusion: Add mean pooled text to image features
            pooled_text = pool_text_feat(
                prompt, prompt_key_padding_mask, self.pool_text_with_mask
            )
            pooled_text = self.text_pooling_proj(pooled_text)[
                ..., None, None
            ]  # prompt is seq first
            src = [x.add_(pooled_text) for x in src]

        (
            out,
            key_padding_masks_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
        ) = super().forward(
            src,
            src_key_padding_masks=src_key_padding_mask,
            pos=src_pos,
            prompt=prompt.transpose(0, 1),
            prompt_key_padding_mask=prompt_key_padding_mask,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        return {
            "memory": out,
            "padding_mask": key_padding_masks_flatten,
            "pos_embed": lvl_pos_embed_flatten,
            "memory_text": prompt,
            "level_start_index": level_start_index,
            "spatial_shapes": spatial_shapes,
            "valid_ratios": valid_ratios,
        }
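

# Sketch of the fusion encoder's contract (illustrative only): sequence-first vision
# features plus feat_sizes in, a dictionary of flattened outputs back. Same stand-in
# assumptions as the sketches above.
def _example_fusion_encoder() -> Dict[str, Any]:
    d_model, bs, text_len = 256, 2, 12
    enc = TransformerEncoderFusion(
        layer=_build_example_layer(d_model),
        num_layers=2,
        d_model=d_model,
        num_feature_levels=2,
        add_pooled_text_to_img_feat=True,
        pool_text_with_mask=True,
    ).eval()
    nn.init.normal_(enc.level_embed)
    feat_sizes = [(32, 32), (16, 16)]
    # Vision features and pos embeds are sequence-first: (h * w, bs, d_model).
    src = [torch.randn(h * w, bs, d_model) for h, w in feat_sizes]
    src_pos = [torch.randn_like(s) for s in src]
    # Text prompt is sequence-first: (text_len, bs, d_model); its mask is (bs, text_len).
    prompt = torch.randn(text_len, bs, d_model)
    prompt_mask = torch.zeros(bs, text_len, dtype=torch.bool)
    out = enc(
        src,
        prompt,
        prompt_key_padding_mask=prompt_mask,
        src_pos=src_pos,
        feat_sizes=feat_sizes,
    )
    assert out["memory"].shape == (32 * 32 + 16 * 16, bs, d_model)
    return out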


def pool_text_feat(prompt, prompt_mask, pool_with_mask):
    # prompt has shape (seq, bs, dim)
    if not pool_with_mask:
        return prompt.mean(dim=0)
    # prompt_mask has shape (bs, seq), where False is valid and True is padding
    assert prompt_mask.dim() == 2
    # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
    is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
    # num_valid has shape (bs, 1)
    num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
    # mean pool over all the valid tokens
    pooled_text = (prompt * is_valid).sum(dim=0) / num_valid
    return pooled_text
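

# Small illustration of pool_text_feat (not part of the original module): with
# pool_with_mask=True, padded positions (True in the mask) are excluded from the mean.
def _example_pool_text_feat() -> Tensor:
    seq, bs, dim = 4, 2, 8
    prompt = torch.randn(seq, bs, dim)
    # Mark the last token of the first sample as padding.
    prompt_mask = torch.zeros(bs, seq, dtype=torch.bool)
    prompt_mask[0, -1] = True
    pooled = pool_text_feat(prompt, prompt_mask, pool_with_mask=True)
    # The first sample averages only its 3 valid tokens.
    expected_first = prompt[:3, 0].mean(dim=0)
    assert torch.allclose(pooled[0], expected_first, atol=1e-6)
    return pooled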