|
|
@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
|
|
|
output_mode: str = "binary_mask",
|
|
|
use_m2m: bool = False,
|
|
|
multimask_output: bool = True,
|
|
|
+ **kwargs,
|
|
|
) -> None:
|
|
|
"""
|
|
|
Using a SAM 2 model, generates masks for the entire image.
|
|
|
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
|
|
|
self.use_m2m = use_m2m
|
|
|
self.multimask_output = multimask_output
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
|
|
+ """
|
|
|
+ Load a pretrained model from the Hugging Face hub.
|
|
|
+
|
|
|
+ Arguments:
|
|
|
+ model_id (str): The Hugging Face repository ID.
|
|
|
+ **kwargs: Additional arguments to pass to the model constructor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (SAM2AutomaticMaskGenerator): The loaded model.
|
|
|
+ """
|
|
|
+ from sam2.build_sam import build_sam2_hf
|
|
|
+
|
|
|
+ sam_model = build_sam2_hf(model_id, **kwargs)
|
|
|
+ return cls(sam_model, **kwargs)
|
|
|
+
|
|
|
@torch.no_grad()
|
|
|
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
|
"""
|