@@ -44,7 +44,7 @@ class SAM2VideoPredictor(SAM2Base):
         offload_state_to_cpu=False,
         async_loading_frames=False,
     ):
-        """Initialize a inference state."""
+        """Initialize an inference state."""
         compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
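For context on the two flags visible in this hunk, a minimal usage sketch is given below showing how `init_state` is typically reached through the repository's `build_sam2_video_predictor` helper. The checkpoint, config, and frame-directory paths are placeholders, and the keyword arguments simply mirror the signature shown above.

```python
# Minimal usage sketch (not part of this diff). Paths are placeholders; the
# keyword arguments mirror the init_state signature shown in the hunk above.
from sam2.build_sam import build_sam2_video_predictor

checkpoint = "./checkpoints/sam2.1_hiera_large.pt"  # placeholder checkpoint path
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"    # placeholder model config
predictor = build_sam2_video_predictor(model_cfg, checkpoint)

# Build the per-video inference state once, then add prompts and propagate.
inference_state = predictor.init_state(
    video_path="./videos/example_frames",  # placeholder directory of video frames
    offload_state_to_cpu=False,            # keep the inference state on the compute device
    async_loading_frames=False,            # load all frames up front rather than lazily
)
```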
@@ -589,7 +589,7 @@ class SAM2VideoPredictor(SAM2Base):
         # to `propagate_in_video_preflight`).
         consolidated_frame_inds = inference_state["consolidated_frame_inds"]
         for is_cond in [False, True]:
-            # Separately consolidate conditioning and non-conditioning temp outptus
+            # Separately consolidate conditioning and non-conditioning temp outputs
             storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
             # Find all the frames that contain temporary outputs for any objects
             # (these should be the frames that have just received clicks for mask inputs
@@ -598,7 +598,7 @@ class SAM2VideoPredictor(SAM2Base):
             for obj_temp_output_dict in temp_output_dict_per_obj.values():
                 temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
             consolidated_frame_inds[storage_key].update(temp_frame_inds)
-            # consolidate the temprary output across all objects on this frame
+            # consolidate the temporary output across all objects on this frame
             for frame_idx in temp_frame_inds:
                 consolidated_out = self._consolidate_temp_output_across_obj(
                     inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
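The two hunks above sit in the loop that, for each storage key, first collects every frame index holding temporary per-object output and then consolidates those frames one at a time via `_consolidate_temp_output_across_obj`. The standalone sketch below reproduces just the gathering step, with hypothetical plain dicts standing in for the real inference state.

```python
# Standalone sketch of the gathering step (hypothetical data, not the real
# inference state): object id -> {storage_key -> {frame_idx -> output}}.
temp_output_dict_per_obj = {
    0: {"cond_frame_outputs": {5: "mask_obj0_f5"}, "non_cond_frame_outputs": {}},
    1: {"cond_frame_outputs": {5: "mask_obj1_f5", 9: "mask_obj1_f9"}, "non_cond_frame_outputs": {}},
}
consolidated_frame_inds = {"cond_frame_outputs": set(), "non_cond_frame_outputs": set()}

for is_cond in [False, True]:
    # Separately consolidate conditioning and non-conditioning temp outputs.
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
    # Collect every frame index that any object has a temporary output for.
    temp_frame_inds = set()
    for obj_temp_output_dict in temp_output_dict_per_obj.values():
        temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
    consolidated_frame_inds[storage_key].update(temp_frame_inds)
    # In the real method, each frame index collected here is then passed to
    # _consolidate_temp_output_across_obj to merge the per-object outputs.
    print(storage_key, sorted(temp_frame_inds))

# Expected output:
#   non_cond_frame_outputs []
#   cond_frame_outputs [5, 9]
```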