# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import numpy as np
import torch
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

# the PNG palette for DAVIS 2017 dataset
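# (a flat bytes string of 256 RGB triplets, i.e. 768 bytes, in the layout
# accepted by PIL's Image.putpalette)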
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"


def load_ann_png(path):
    """Load a PNG file as a mask and its palette."""
    mask = Image.open(path)
    palette = mask.getpalette()
    mask = np.array(mask).astype(np.uint8)
    return mask, palette


def save_ann_png(path, mask, palette):
    """Save a mask as a PNG file with the given palette."""
    assert mask.dtype == np.uint8
    assert mask.ndim == 2
    output_mask = Image.fromarray(mask)
    output_mask.putpalette(palette)
    output_mask.save(path)


def get_per_obj_mask(mask):
    """Split a mask into per-object masks."""
    object_ids = np.unique(mask)
    object_ids = object_ids[object_ids > 0].tolist()
    per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
    return per_obj_mask
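

# e.g. (hypothetical values) a mask whose pixels take the ids {0, 1, 2} yields
# {1: mask == 1, 2: mask == 2}; the background id 0 is dropped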


def put_per_obj_mask(per_obj_mask, height, width):
    """Combine per-object masks into a single mask."""
    mask = np.zeros((height, width), dtype=np.uint8)
    object_ids = sorted(per_obj_mask)[::-1]
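    # paint the masks in descending object-id order, so that wherever objects
    # overlap, the pixel ends up with the smallest (last-painted) object id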
    for object_id in object_ids:
        object_mask = per_obj_mask[object_id]
        object_mask = object_mask.reshape(height, width)
        mask[object_mask] = object_id
    return mask


def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file):
    """Load masks from a directory as a dict of per-object masks."""
    if not per_obj_png_file:
        input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
        input_mask, input_palette = load_ann_png(input_mask_path)
        per_obj_input_mask = get_per_obj_mask(input_mask)
    else:
        per_obj_input_mask = {}
        # avoid an unbound variable below if the video has no object subdirectories
        input_palette = None
        # each object is stored in a subdirectory named with its zero-padded id ("%03d" % object_id)
        for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
            object_id = int(object_name)
            input_mask_path = os.path.join(
                input_mask_dir, video_name, object_name, f"{frame_name}.png"
            )
            input_mask, input_palette = load_ann_png(input_mask_path)
            per_obj_input_mask[object_id] = input_mask > 0
    return per_obj_input_mask, input_palette
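

# note: in per-object mode, the palette returned above is simply the one from
# the last PNG file read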


def save_masks_to_dir(
    output_mask_dir,
    video_name,
    frame_name,
    per_obj_output_mask,
    height,
    width,
    per_obj_png_file,
    output_palette,
):
    """Save masks to a directory as PNG files."""
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    if not per_obj_png_file:
        output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
        output_mask_path = os.path.join(
            output_mask_dir, video_name, f"{frame_name}.png"
        )
        save_ann_png(output_mask_path, output_mask, output_palette)
    else:
        for object_id, object_mask in per_obj_output_mask.items():
            object_name = f"{object_id:03d}"
            os.makedirs(
                os.path.join(output_mask_dir, video_name, object_name),
                exist_ok=True,
            )
            output_mask = object_mask.reshape(height, width).astype(np.uint8)
            output_mask_path = os.path.join(
                output_mask_dir, video_name, object_name, f"{frame_name}.png"
            )
            save_ann_png(output_mask_path, output_mask, output_palette)
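

# inference runs with gradient tracking disabled and under CUDA bfloat16
# autocast (a CUDA-capable GPU is assumed here; adjust or remove the autocast
# decorator to run on other devices)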
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_inference(
    predictor,
    base_video_dir,
    input_mask_dir,
    output_mask_dir,
    video_name,
    score_thresh=0.0,
    use_all_masks=False,
    per_obj_png_file=False,
):
- """Run VOS inference on a single video with the given predictor."""
- # load the video frames and initialize the inference state on this video
- video_dir = os.path.join(base_video_dir, video_name)
- frame_names = [
- os.path.splitext(p)[0]
- for p in os.listdir(video_dir)
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
- ]
- frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
- inference_state = predictor.init_state(
- video_path=video_dir, async_loading_frames=False
- )
- height = inference_state["video_height"]
- width = inference_state["video_width"]
- input_palette = None
- # fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
- if not use_all_masks:
- # use only the first video's ground-truth mask as the input mask
- input_frame_inds = [0]
    else:
        # use all mask files available in the input_mask_dir as the input masks
        if not per_obj_png_file:
            input_frame_inds = [
                idx
                for idx, name in enumerate(frame_names)
                if os.path.exists(
                    os.path.join(input_mask_dir, video_name, f"{name}.png")
                )
            ]
        else:
            input_frame_inds = [
                idx
                for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
                for idx, name in enumerate(frame_names)
                if os.path.exists(
                    os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
                )
            ]
        # deduplicate the collected indices (a frame may appear once per object)
        input_frame_inds = sorted(set(input_frame_inds))
- # add those input masks to SAM 2 inference state before propagation
- for input_frame_idx in input_frame_inds:
- per_obj_input_mask, input_palette = load_masks_from_dir(
- input_mask_dir=input_mask_dir,
- video_name=video_name,
- frame_name=frame_names[input_frame_idx],
- per_obj_png_file=per_obj_png_file,
- )
- for object_id, object_mask in per_obj_input_mask.items():
- predictor.add_new_mask(
- inference_state=inference_state,
- frame_idx=input_frame_idx,
- obj_id=object_id,
- mask=object_mask,
- )

    # run propagation throughout the video and collect the results in a dict
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    output_palette = input_palette or DAVIS_PALETTE
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        inference_state
    ):
        per_obj_output_mask = {
            out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
        video_segments[out_frame_idx] = per_obj_output_mask

    # write the output masks as palette PNG files to output_mask_dir
    for out_frame_idx, per_obj_output_mask in video_segments.items():
        save_masks_to_dir(
            output_mask_dir=output_mask_dir,
            video_name=video_name,
            frame_name=frame_names[out_frame_idx],
            per_obj_output_mask=per_obj_output_mask,
            height=height,
            width=width,
            per_obj_png_file=per_obj_png_file,
            output_palette=output_palette,
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sam2_cfg",
        type=str,
        default="sam2_hiera_b+.yaml",
        help="SAM 2 model configuration file",
    )
    parser.add_argument(
        "--sam2_checkpoint",
        type=str,
        default="./checkpoints/sam2_hiera_b+.pt",
        help="path to the SAM 2 model checkpoint",
    )
    parser.add_argument(
        "--base_video_dir",
        type=str,
        required=True,
        help="directory containing videos (as JPEG files) to run VOS prediction on",
    )
    parser.add_argument(
        "--input_mask_dir",
        type=str,
        required=True,
        help="directory containing input masks (as PNG files) of each video",
    )
    parser.add_argument(
        "--video_list_file",
        type=str,
        default=None,
        help="text file containing the list of video names to run VOS prediction on",
    )
    parser.add_argument(
        "--output_mask_dir",
        type=str,
        required=True,
        help="directory to save the output masks (as PNG files)",
    )
    parser.add_argument(
        "--score_thresh",
        type=float,
        default=0.0,
        help="threshold for the output mask logits (default: 0.0)",
    )
    parser.add_argument(
        "--use_all_masks",
        action="store_true",
        help="whether to use all available PNG files in input_mask_dir "
        "(default without this flag: just the first PNG file as input to the SAM 2 model; "
        "this flag is usually unnecessary, since semi-supervised VOS evaluation takes its input from the first frame only)",
    )
    parser.add_argument(
        "--per_obj_png_file",
        action="store_true",
        help="whether to use separate per-object PNG files for input and output masks "
        "(default without this flag: all object masks are packed into a single PNG file on each frame following the DAVIS format; "
        "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
    )
    parser.add_argument(
        "--apply_postprocessing",
        action="store_true",
        help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
        "(we don't apply such post-processing in the SAM 2 model evaluation)",
    )
    args = parser.parse_args()

    # per-object PNG files may contain overlapping object masks, so disable the
    # predictor's non-overlapping-mask constraint when that format is used
    hydra_overrides_extra = [
        "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
    ]
    predictor = build_sam2_video_predictor(
        config_file=args.sam2_cfg,
        ckpt_path=args.sam2_checkpoint,
        apply_postprocessing=args.apply_postprocessing,
        hydra_overrides_extra=hydra_overrides_extra,
    )

    if args.use_all_masks:
        print("using all available masks in input_mask_dir as input to the SAM 2 model")
    else:
        print(
            "using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
        )

    # if a video list file is provided, read the video names from the file
    # (otherwise, we use all subdirectories in base_video_dir)
    if args.video_list_file is not None:
        with open(args.video_list_file, "r") as f:
            video_names = [v.strip() for v in f.readlines()]
    else:
        video_names = [
            p
            for p in os.listdir(args.base_video_dir)
            if os.path.isdir(os.path.join(args.base_video_dir, p))
        ]
    print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")

    for n_video, video_name in enumerate(video_names):
        print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
        vos_inference(
            predictor=predictor,
            base_video_dir=args.base_video_dir,
            input_mask_dir=args.input_mask_dir,
            output_mask_dir=args.output_mask_dir,
            video_name=video_name,
            score_thresh=args.score_thresh,
            use_all_masks=args.use_all_masks,
            per_obj_png_file=args.per_obj_png_file,
        )

    print(
        f"completed VOS prediction on {len(video_names)} videos -- "
        f"output masks saved to {args.output_mask_dir}"
    )


if __name__ == "__main__":
    main()
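
# example invocation (a sketch; the script filename and the paths below are
# hypothetical placeholders, substitute your own):
#
#   python vos_inference.py \
#       --base_video_dir /path/to/JPEGImages \
#       --input_mask_dir /path/to/Annotations \
#       --output_mask_dir ./output_masks
#
# add --per_obj_png_file when the input masks are stored as per-object PNG
# files (e.g. SA-V-style annotations)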