|
@@ -591,7 +591,8 @@ class SAM2VideoPredictor(SAM2Base):
|
|
|
if frame_idx in obj_output_dict["cond_frame_outputs"]:
|
|
if frame_idx in obj_output_dict["cond_frame_outputs"]:
|
|
|
storage_key = "cond_frame_outputs"
|
|
storage_key = "cond_frame_outputs"
|
|
|
current_out = obj_output_dict[storage_key][frame_idx]
|
|
current_out = obj_output_dict[storage_key][frame_idx]
|
|
|
- pred_masks = current_out["pred_masks"]
|
|
|
|
|
|
|
+ device = inference_state["device"]
|
|
|
|
|
+ pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
|
|
|
if self.clear_non_cond_mem_around_input:
|
|
if self.clear_non_cond_mem_around_input:
|
|
|
# clear non-conditioning memory of the surrounding frames
|
|
# clear non-conditioning memory of the surrounding frames
|
|
|
self._clear_obj_non_cond_mem_around_input(
|
|
self._clear_obj_non_cond_mem_around_input(
|