Prechádzať zdrojové kódy

Merge branch 'main' into patch-1

Arun 1 rok pred
rodič
commit
102ddb8899
4 zmenil súbory, kde vykonal 93 pridanie a 36 odobranie
  1. 3 3
      README.md
  2. 10 4
      notebooks/video_predictor_example.ipynb
  3. 48 6
      sam2/sam2_video_predictor.py
  4. 32 23
      setup.py

+ 3 - 3
README.md

@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)
 
     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
 
     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
         ...
 ```
 
-Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
+Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
 
 ## Load from 🤗 Hugging Face
 
@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)
 
     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
 
     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):

Rozdielové dáta súboru neboli zobrazené, pretože súbor je príliš veľký
+ 10 - 4
notebooks/video_predictor_example.ipynb


+ 48 - 6
sam2/sam2_video_predictor.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from collections import OrderedDict
 
 import torch
@@ -163,29 +164,66 @@ class SAM2VideoPredictor(SAM2Base):
         return len(inference_state["obj_idx_to_id"])
 
     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
 
-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
@@ -268,6 +306,10 @@ class SAM2VideoPredictor(SAM2Base):
         )
         return frame_idx, obj_ids, video_res_masks
 
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
@@ -548,7 +590,7 @@ class SAM2VideoPredictor(SAM2Base):
             storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
             # Find all the frames that contain temporary outputs for any objects
             # (these should be the frames that have just received clicks for mask inputs
-            # via `add_new_points` or `add_new_mask`)
+            # via `add_new_points_or_box` or `add_new_mask`)
             temp_frame_inds = set()
             for obj_temp_output_dict in temp_output_dict_per_obj.values():
                 temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())

+ 32 - 23
setup.py

@@ -44,55 +44,64 @@ BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
 # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
 BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
 
+# Catch and skip errors during extension building and print a warning message
+# (note that this message only shows up under verbose build mode
+# "pip install -v -e ." or "python setup.py build_ext -v")
+CUDA_ERROR_MSG = (
+    "{}\n\n"
+    "Failed to build the SAM 2 CUDA extension due to the error above. "
+    "You can still use SAM 2, but some post-processing functionality may be limited "
+    "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
+)
+
 
 def get_extensions():
     if not BUILD_CUDA:
         return []
 
-    srcs = ["sam2/csrc/connected_components.cu"]
-    compile_args = {
-        "cxx": [],
-        "nvcc": [
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    }
-    ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+    try:
+        srcs = ["sam2/csrc/connected_components.cu"]
+        compile_args = {
+            "cxx": [],
+            "nvcc": [
+                "-DCUDA_HAS_FP16=1",
+                "-D__CUDA_NO_HALF_OPERATORS__",
+                "-D__CUDA_NO_HALF_CONVERSIONS__",
+                "-D__CUDA_NO_HALF2_OPERATORS__",
+            ],
+        }
+        ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+    except Exception as e:
+        if BUILD_ALLOW_ERRORS:
+            print(CUDA_ERROR_MSG.format(e))
+            ext_modules = []
+        else:
+            raise e
+
     return ext_modules
 
 
 class BuildExtensionIgnoreErrors(BuildExtension):
-    # Catch and skip errors during extension building and print a warning message
-    # (note that this message only shows up under verbose build mode
-    # "pip install -v -e ." or "python setup.py build_ext -v")
-    ERROR_MSG = (
-        "{}\n\n"
-        "Failed to build the SAM 2 CUDA extension due to the error above. "
-        "You can still use SAM 2, but some post-processing functionality may be limited "
-        "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
-    )
 
     def finalize_options(self):
         try:
             super().finalize_options()
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
 
     def build_extensions(self):
         try:
             super().build_extensions()
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
 
     def get_ext_filename(self, ext_name):
         try:
             return super().get_ext_filename(ext_name)
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
             return "_C.so"
 

Niektoré súbory nie sú zobrazené, pretože je v týchto rozdielových dátach zmenené mnoho súborov