Преглед изворни кода

Merge pull request #205 from facebookresearch/haitham/fix_hf_image_predictor

Fix HF image predictor
Haitham Khedr пре 1 година
родитељ
комит
0db838b117
4 измењених фајлова са 27 додато и 4 уклоњено
  1. 18 0
      sam2/automatic_mask_generator.py
  2. 2 0
      sam2/build_sam.py
  3. 6 3
      sam2/sam2_image_predictor.py
  4. 1 1
      sam2/sam2_video_predictor.py

+ 18 - 0
sam2/automatic_mask_generator.py

@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
         output_mode: str = "binary_mask",
         use_m2m: bool = False,
         multimask_output: bool = True,
+        **kwargs,
     ) -> None:
         """
         Using a SAM 2 model, generates masks for the entire image.
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
         self.use_m2m = use_m2m
         self.multimask_output = multimask_output
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2AutomaticMaskGenerator): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_hf
+
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model, **kwargs)
+
     @torch.no_grad()
     def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
         """

+ 2 - 0
sam2/build_sam.py

@@ -19,6 +19,7 @@ def build_sam2(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    **kwargs,
 ):
 
     if apply_postprocessing:
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",

+ 6 - 3
sam2/sam2_image_predictor.py

@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
         mask_threshold=0.0,
         max_hole_area=0.0,
         max_sprinkle_area=0.0,
+        **kwargs,
     ) -> None:
         """
         Uses SAM-2 to calculate the image embedding for an image, and then
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
           sam_model (Sam-2): The model to use for mask prediction.
           mask_threshold (float): The threshold to use when converting mask logits
             to binary masks. Masks are thresholded at 0 by default.
-          fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
-            the maximum area of fill_hole_area in low_res_masks.
+          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
+            the maximum area of max_hole_area in low_res_masks.
+          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
+            the maximum area of max_sprinkle_area in low_res_masks.
         """
         super().__init__()
         self.model = sam_model
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
         from sam2.build_sam import build_sam2_hf
 
         sam_model = build_sam2_hf(model_id, **kwargs)
-        return cls(sam_model)
+        return cls(sam_model, **kwargs)
 
     @torch.no_grad()
     def set_image(

+ 1 - 1
sam2/sam2_video_predictor.py

@@ -121,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
         from sam2.build_sam import build_sam2_video_predictor_hf
 
         sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
-        return cls(sam_model)
+        return sam_model
 
     def _obj_id_to_idx(self, inference_state, obj_id):
         """Map client-side object id to model-side object index."""