|
@@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module):
|
|
|
if self.add_tpos_enc_to_obj_ptrs:
|
|
if self.add_tpos_enc_to_obj_ptrs:
|
|
|
t_diff_max = max_obj_ptrs_in_encoder - 1
|
|
t_diff_max = max_obj_ptrs_in_encoder - 1
|
|
|
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
|
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
|
|
- obj_pos = (
|
|
|
|
|
- torch.tensor(pos_list)
|
|
|
|
|
- .pin_memory()
|
|
|
|
|
- .to(device=device, non_blocking=True)
|
|
|
|
|
|
|
+ obj_pos = torch.tensor(pos_list).to(
|
|
|
|
|
+ device=device, non_blocking=True
|
|
|
)
|
|
)
|
|
|
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
|
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
|
|
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
|
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|