
better support for non-CUDA devices (CPU, MPS) (#192)

Ronghang Hu 1 year ago
parent
commit
1034ee2a1a

File diff suppressed because it is too large
+ 52 - 16
notebooks/automatic_mask_generator_example.ipynb


File diff suppressed because it is too large
+ 40 - 11
notebooks/image_predictor_example.ipynb


File diff suppressed because it is too large
+ 34 - 9
notebooks/video_predictor_example.ipynb


+ 3 - 1
sam2/automatic_mask_generator.py

@@ -284,7 +284,9 @@ class SAM2AutomaticMaskGenerator:
         orig_h, orig_w = orig_size
 
         # Run model on this batch
-        points = torch.as_tensor(points, device=self.predictor.device)
+        points = torch.as_tensor(
+            points, dtype=torch.float32, device=self.predictor.device
+        )
         in_points = self.predictor._transforms.transform_coords(
             points, normalize=normalize, orig_hw=im_size
         )
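
The explicit `dtype=torch.float32` matters here presumably because the point grid arrives as a NumPy array, and NumPy defaults to float64, which the MPS backend does not support. A minimal sketch of the failure mode, with made-up coordinates:

```python
import numpy as np
import torch

# Made-up point coordinates; NumPy defaults to float64.
points = np.array([[0.5, 0.5], [0.25, 0.75]])

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Without an explicit dtype, torch.as_tensor would keep float64, which
# MPS rejects. Pinning float32 keeps the tensor portable across devices.
pts = torch.as_tensor(points, dtype=torch.float32, device=device)
print(pts.dtype, pts.device)
```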

+ 6 - 1
sam2/modeling/position_encoding.py

@@ -211,6 +211,11 @@ def apply_rotary_enc(
     # repeat freqs along seq_len dim to match k seq_len
     if repeat_freqs_k:
         r = xk_.shape[-2] // xq_.shape[-2]
-        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+        if freqs_cis.is_cuda:
+            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+        else:
+            # torch.repeat on complex numbers may not be supported on non-CUDA devices
+            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
+            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
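
The two branches are value-for-value equivalent; the fallback simply avoids calling `Tensor.repeat` on a complex tensor, which may be unimplemented on some non-CUDA backends. A quick CPU-side sanity check with made-up shapes:

```python
import torch

# Stand-in for freqs_cis: a 4-D complex tensor, as in apply_rotary_enc.
freqs_cis = torch.polar(torch.ones(2, 3, 4, 8), torch.randn(2, 3, 4, 8))
r = 2  # how many times k's seq_len exceeds q's

repeated = freqs_cis.repeat(1, 1, r, 1)                                    # CUDA path
expanded = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)  # fallback

assert torch.equal(repeated, expanded)  # same shape, same values
```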

+ 2 - 2
sam2/modeling/sam2_base.py

@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
                 # so we load it back to GPU (it's a no-op if it's already on GPU).
-                feats = prev["maskmem_features"].cuda(non_blocking=True)
+                feats = prev["maskmem_features"].to(device, non_blocking=True)
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
-                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                 maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                 # Temporal positional encoding
                 maskmem_enc = (
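
`device` here is the model's compute device rather than a hardcoded CUDA handle. A sketch (not taken from the patch itself) of the selection pattern that makes these `.to(device)` calls portable:

```python
import torch

# Pick the best available backend once; .to(device) is a no-op whenever
# the tensor already lives there, so reloading offloaded features stays
# cheap on GPU and still works on MPS and CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

feats = torch.zeros(1, 64, 32, 32)           # e.g. offloaded to CPU earlier
feats = feats.to(device, non_blocking=True)  # back on the compute device
```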

+ 8 - 4
sam2/sam2_video_predictor.py

@@ -45,11 +45,13 @@ class SAM2VideoPredictor(SAM2Base):
         async_loading_frames=False,
     ):
         """Initialize a inference state."""
+        compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
             async_loading_frames=async_loading_frames,
+            compute_device=compute_device,
         )
         inference_state = {}
         inference_state["images"] = images
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
         # the original video height and width, used for resizing final output scores
         inference_state["video_height"] = video_height
         inference_state["video_width"] = video_width
-        inference_state["device"] = torch.device("cuda")
+        inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
         else:
-            inference_state["storage_device"] = torch.device("cuda")
+            inference_state["storage_device"] = compute_device
         # inputs on each frame
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
                 prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
 
         if prev_out is not None and prev_out["pred_masks"] is not None:
-            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            device = inference_state["device"]
+            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
             # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
             prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
         current_out, _ = self._run_single_frame_inference(
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
         )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
-            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            device = inference_state["device"]
+            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
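
Taken together, these hunks make `inference_state` track two devices: `device` for computation and `storage_device` for parked per-frame outputs. A hypothetical illustration of the round trip (tensor shapes are invented):

```python
import torch

compute_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
inference_state = {
    "device": compute_device,               # where the model runs
    "storage_device": torch.device("cpu"),  # offload_state_to_cpu=True
}

# Cached outputs sit on storage_device and hop back only when a frame
# is re-processed, mirroring the pred_masks handling in the patch.
pred_masks = torch.randn(1, 1, 256, 256, device=inference_state["storage_device"])
prev_sam_mask_logits = pred_masks.to(inference_state["device"], non_blocking=True)
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
```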

+ 21 - 6
sam2/utils/misc.py

@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
     A list of video frames to be loaded asynchronously without blocking session start.
     """
 
-    def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
         self.img_paths = img_paths
         self.image_size = image_size
         self.offload_video_to_cpu = offload_video_to_cpu
@@ -119,6 +127,7 @@ class AsyncVideoFrameLoader:
         # video_height and video_width be filled when loading the first image
         self.video_height = None
         self.video_width = None
+        self.compute_device = compute_device
 
         # load the first frame to fill video_height and video_width and also
         # to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
         img -= self.img_mean
         img /= self.img_std
         if not self.offload_video_to_cpu:
-            img = img.cuda(non_blocking=True)
+            img = img.to(self.compute_device, non_blocking=True)
         self.images[index] = img
         return img
 
@@ -167,6 +176,7 @@ def load_video_frames(
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
     async_loading_frames=False,
+    compute_device=torch.device("cuda"),
 ):
     """
     Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -196,7 +206,12 @@ def load_video_frames(
 
     if async_loading_frames:
         lazy_images = AsyncVideoFrameLoader(
-            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+            img_paths,
+            image_size,
+            offload_video_to_cpu,
+            img_mean,
+            img_std,
+            compute_device,
         )
         return lazy_images, lazy_images.video_height, lazy_images.video_width
 
@@ -204,9 +219,9 @@ def load_video_frames(
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
         images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
     if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        images = images.to(compute_device)
+        img_mean = img_mean.to(compute_device)
+        img_std = img_std.to(compute_device)
     # normalize by mean and std
     images -= img_mean
     images /= img_std
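
On the caller side, the new `compute_device` parameter lets `init_state` thread the model's device through frame loading instead of relying on the CUDA default. A hypothetical usage sketch (the frame directory path is made up):

```python
import torch
from sam2.utils.misc import load_video_frames

if torch.cuda.is_available():
    compute_device = torch.device("cuda")
elif torch.backends.mps.is_available():
    compute_device = torch.device("mps")
else:
    compute_device = torch.device("cpu")

images, video_height, video_width = load_video_frames(
    video_path="./videos/bedroom",  # made-up directory of JPEG frames
    image_size=1024,
    offload_video_to_cpu=False,     # frames go straight to compute_device
    compute_device=compute_device,  # new parameter; defaults to cuda
)
```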

Some files were not shown because too many files changed in this diff