Ver Fonte

remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS (#495)

In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487.

(close https://github.com/facebookresearch/sam2/issues/487)
Ronghang Hu há 1 ano atrás
pai
commit
2b90b9f5ce
1 ficheiros alterados com 2 adições e 4 exclusões
  1. 2 4
      sam2/modeling/sam2_base.py

+ 2 - 4
sam2/modeling/sam2_base.py

@@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module):
                     if self.add_tpos_enc_to_obj_ptrs:
                         t_diff_max = max_obj_ptrs_in_encoder - 1
                         tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-                        obj_pos = (
-                            torch.tensor(pos_list)
-                            .pin_memory()
-                            .to(device=device, non_blocking=True)
+                        obj_pos = torch.tensor(pos_list).to(
+                            device=device, non_blocking=True
                         )
                         obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                         obj_pos = self.obj_ptr_tpos_proj(obj_pos)