|
@@ -127,8 +127,14 @@ from sam2.sam2_video_predictor import SAM2VideoPredictor
|
|
|
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
|
|
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
|
|
|
|
|
|
|
|
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
|
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
|
|
- predictor.set_image(<your_image>)
|
|
|
|
|
- masks, _, _ = predictor.predict(<input_prompts>)
|
|
|
|
|
|
|
+ state = predictor.init_state(<your_video>)
|
|
|
|
|
+
|
|
|
|
|
+ # add new prompts and instantly get the output on the same frame
|
|
|
|
|
+ frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)
|
|
|
|
|
+
|
|
|
|
|
+ # propagate the prompts to get masklets throughout the video
|
|
|
|
|
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
|
|
|
|
+ ...
|
|
|
```
|
|
```
|
|
|
|
|
|
|
|
## Model Description
|
|
## Model Description
|