فهرست منبع

accept kwargs in auto_mask_generator

Haitham Khedr 1 سال پیش
والد
کامیت
fd5125b97a
1فایلهای تغییر یافته به همراه18 افزوده شده و 0 حذف شده
  1. 18 0
      sam2/automatic_mask_generator.py

+ 18 - 0
sam2/automatic_mask_generator.py

@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
         output_mode: str = "binary_mask",
         use_m2m: bool = False,
         multimask_output: bool = True,
+        **kwargs,
     ) -> None:
         """
         Using a SAM 2 model, generates masks for the entire image.
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
         self.use_m2m = use_m2m
         self.multimask_output = multimask_output
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2AutomaticMaskGenerator): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_hf
+
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model, **kwargs)
+
     @torch.no_grad()
     def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
         """