# tao_ow.py
  1. # flake8: noqa
  2. # pyre-unsafe
  3. import itertools
  4. import json
  5. import os
  6. from collections import defaultdict
  7. import numpy as np
  8. from scipy.optimize import linear_sum_assignment
  9. from .. import _timing, utils
  10. from ..utils import TrackEvalException
  11. from ._base_dataset import _BaseDataset
  12. class TAO_OW(_BaseDataset):
  13. """Dataset class for TAO tracking"""
  14. @staticmethod
  15. def get_default_dataset_config():
  16. """Default class config values"""
  17. code_path = utils.get_code_path()
  18. default_config = {
  19. "GT_FOLDER": os.path.join(
  20. code_path, "data/gt/tao/tao_training"
  21. ), # Location of GT data
  22. "TRACKERS_FOLDER": os.path.join(
  23. code_path, "data/trackers/tao/tao_training"
  24. ), # Trackers location
  25. "OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
  26. "TRACKERS_TO_EVAL": None, # Filenames of trackers to eval (if None, all in folder)
  27. "CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes)
  28. "SPLIT_TO_EVAL": "training", # Valid: 'training', 'val'
  29. "PRINT_CONFIG": True, # Whether to print current config
  30. "TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
  31. "OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
  32. "TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
  33. "MAX_DETECTIONS": 300, # Number of maximal allowed detections per image (0 for unlimited)
  34. "SUBSET": "all",
  35. }
  36. return default_config
    def __init__(self, config=None):
        """Initialise dataset, checking that all required files are present

        :param config: optional dict of config overrides; missing keys are
            filled from get_default_dataset_config()
        :raises TrackEvalException: if the GT folder does not contain exactly
            one json file, a tracker folder does not contain exactly one json
            file, or tracker display names do not match the tracker list
        """
        super().__init__()
        # Fill non-given config values with defaults
        self.config = utils.init_config(
            config, self.get_default_dataset_config(), self.get_name()
        )
        self.gt_fol = self.config["GT_FOLDER"]
        self.tracker_fol = self.config["TRACKERS_FOLDER"]
        self.should_classes_combine = True
        self.use_super_categories = False
        self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
        self.output_fol = self.config["OUTPUT_FOLDER"]
        if self.output_fol is None:
            self.output_fol = self.tracker_fol
        self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]
        # The GT folder must contain exactly one json annotation file
        gt_dir_files = [
            file for file in os.listdir(self.gt_fol) if file.endswith(".json")
        ]
        if len(gt_dir_files) != 1:
            raise TrackEvalException(
                self.gt_fol + " does not contain exactly one json file."
            )
        with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
            self.gt_data = json.load(f)
        self.subset = self.config["SUBSET"]
        if self.subset != "all":
            # Split GT data into `known`, `unknown` or `distractor`
            self._split_known_unknown_distractor()
            self.gt_data = self._filter_gt_data(self.gt_data)
        # merge categories marked with a merged tag in TAO dataset
        self._merge_categories(self.gt_data["annotations"] + self.gt_data["tracks"])
        # Get sequences to eval and sequence information
        self.seq_list = [
            vid["name"].replace("/", "-") for vid in self.gt_data["videos"]
        ]
        self.seq_name_to_seq_id = {
            vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"]
        }
        # compute mappings from videos to annotation data
        self.videos_to_gt_tracks, self.videos_to_gt_images = self._compute_vid_mappings(
            self.gt_data["annotations"]
        )
        # compute sequence lengths (number of images per video)
        self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]}
        for img in self.gt_data["images"]:
            self.seq_lengths[img["video_id"]] += 1
        self.seq_to_images_to_timestep = self._compute_image_to_timestep_mappings()
        # per-sequence class info: positive (present in GT tracks), negative,
        # and not-exhaustively-labeled category ids
        self.seq_to_classes = {
            vid["id"]: {
                "pos_cat_ids": list(
                    {
                        track["category_id"]
                        for track in self.videos_to_gt_tracks[vid["id"]]
                    }
                ),
                "neg_cat_ids": vid["neg_category_ids"],
                "not_exhaustively_labeled_cat_ids": vid["not_exhaustive_category_ids"],
            }
            for vid in self.gt_data["videos"]
        }
        # Get classes to eval
        considered_vid_ids = [self.seq_name_to_seq_id[vid] for vid in self.seq_list]
        seen_cats = set(
            [
                cat_id
                for vid_id in considered_vid_ids
                for cat_id in self.seq_to_classes[vid_id]["pos_cat_ids"]
            ]
        )
        # only classes with ground truth are evaluated in TAO
        self.valid_classes = [
            cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats
        ]
        # cls_name_to_cls_id_map = {cls['name']: cls['id'] for cls in self.gt_data['categories']}
        # NOTE: open-world evaluation is class-agnostic, so the class list is
        # always the single pseudo-class "object" regardless of CLASSES_TO_EVAL
        if self.config["CLASSES_TO_EVAL"]:
            # self.class_list = [cls.lower() if cls.lower() in self.valid_classes else None
            #                    for cls in self.config['CLASSES_TO_EVAL']]
            self.class_list = ["object"]  # class-agnostic
            if not all(self.class_list):
                raise TrackEvalException(
                    "Attempted to evaluate an invalid class. Only classes "
                    + ", ".join(self.valid_classes)
                    + " are valid (classes present in ground truth data)."
                )
        else:
            # self.class_list = [cls for cls in self.valid_classes]
            self.class_list = ["object"]  # class-agnostic
        # self.class_name_to_class_id = {k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list}
        self.class_name_to_class_id = {"object": 1}  # class-agnostic
        # Get trackers to eval
        if self.config["TRACKERS_TO_EVAL"] is None:
            self.tracker_list = os.listdir(self.tracker_fol)
        else:
            self.tracker_list = self.config["TRACKERS_TO_EVAL"]
        if self.config["TRACKER_DISPLAY_NAMES"] is None:
            self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
        elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
            len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list)
        ):
            self.tracker_to_disp = dict(
                zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"])
            )
        else:
            raise TrackEvalException(
                "List of tracker files and tracker display names do not match."
            )
        self.tracker_data = {tracker: dict() for tracker in self.tracker_list}
        for tracker in self.tracker_list:
            # Each tracker folder must contain exactly one json result file
            tr_dir_files = [
                file
                for file in os.listdir(
                    os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
                )
                if file.endswith(".json")
            ]
            if len(tr_dir_files) != 1:
                raise TrackEvalException(
                    os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
                    + " does not contain exactly one json file."
                )
            with open(
                os.path.join(
                    self.tracker_fol, tracker, self.tracker_sub_fol, tr_dir_files[0]
                )
            ) as f:
                curr_data = json.load(f)
            # limit detections if MAX_DETECTIONS > 0
            if self.config["MAX_DETECTIONS"]:
                curr_data = self._limit_dets_per_image(curr_data)
            # fill missing video ids
            self._fill_video_ids_inplace(curr_data)
            # make track ids unique over whole evaluation set
            self._make_track_ids_unique(curr_data)
            # merge categories marked with a merged tag in TAO dataset
            self._merge_categories(curr_data)
            # get tracker sequence information
            curr_videos_to_tracker_tracks, curr_videos_to_tracker_images = (
                self._compute_vid_mappings(curr_data)
            )
            self.tracker_data[tracker]["vids_to_tracks"] = curr_videos_to_tracker_tracks
            self.tracker_data[tracker]["vids_to_images"] = curr_videos_to_tracker_images
  179. def get_display_name(self, tracker):
  180. return self.tracker_to_disp[tracker]
    def _load_raw_file(self, tracker, seq, is_gt):
        """Load a file (gt or tracker) in the TAO format

        If is_gt, this returns a dict which contains the fields:
        [gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
        [gt_dets]: list (for each timestep) of lists of detections.
        [classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
        keys and corresponding segmentations as values) for each track
        [classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_lengths]: dictionary with class values
        as keys and lists (for each track) as values

        if not is_gt, this returns a dict which contains the fields:
        [tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
        [tracker_dets]: list (for each timestep) of lists of detections.
        [classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
        keys and corresponding segmentations as values) for each track
        [classes_to_dt_track_ids, classes_to_dt_track_areas, classes_to_dt_track_lengths]: dictionary with class values
        as keys and lists as values
        [classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
        """
        seq_id = self.seq_name_to_seq_id[seq]
        # File location
        if is_gt:
            imgs = self.videos_to_gt_images[seq_id]
        else:
            imgs = self.tracker_data[tracker]["vids_to_images"][seq_id]
        # Convert data to required format
        num_timesteps = self.seq_lengths[seq_id]
        img_to_timestep = self.seq_to_images_to_timestep[seq_id]
        data_keys = ["ids", "classes", "dets"]
        if not is_gt:
            data_keys += ["tracker_confidences"]
        raw_data = {key: [None] * num_timesteps for key in data_keys}
        for img in imgs:
            # some tracker data contains images without any ground truth information, these are ignored
            try:
                t = img_to_timestep[img["id"]]
            except KeyError:
                continue
            annotations = img["annotations"]
            raw_data["dets"][t] = np.atleast_2d(
                [ann["bbox"] for ann in annotations]
            ).astype(float)
            raw_data["ids"][t] = np.atleast_1d(
                [ann["track_id"] for ann in annotations]
            ).astype(int)
            # every detection is assigned class 1 ("object") for open-world eval
            raw_data["classes"][t] = np.atleast_1d([1 for _ in annotations]).astype(
                int
            )  # class-agnostic
            if not is_gt:
                raw_data["tracker_confidences"][t] = np.atleast_1d(
                    [ann["score"] for ann in annotations]
                ).astype(float)
        # timesteps without any annotations are padded with empty arrays
        for t, d in enumerate(raw_data["dets"]):
            if d is None:
                raw_data["dets"][t] = np.empty((0, 4)).astype(float)
                raw_data["ids"][t] = np.empty(0).astype(int)
                raw_data["classes"][t] = np.empty(0).astype(int)
                if not is_gt:
                    raw_data["tracker_confidences"][t] = np.empty(0)
        if is_gt:
            key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
        else:
            key_map = {
                "ids": "tracker_ids",
                "classes": "tracker_classes",
                "dets": "tracker_dets",
            }
        for k, v in key_map.items():
            raw_data[v] = raw_data.pop(k)
        # all_classes = [self.class_name_to_class_id[cls] for cls in self.class_list]
        all_classes = [1]  # class-agnostic
        if is_gt:
            classes_to_consider = all_classes
            all_tracks = self.videos_to_gt_tracks[seq_id]
        else:
            # classes_to_consider = self.seq_to_classes[seq_id]['pos_cat_ids'] \
            #     + self.seq_to_classes[seq_id]['neg_cat_ids']
            classes_to_consider = all_classes  # class-agnostic
            all_tracks = self.tracker_data[tracker]["vids_to_tracks"][seq_id]
        # classes_to_tracks = {cls: [track for track in all_tracks if track['category_id'] == cls]
        #     if cls in classes_to_consider else [] for cls in all_classes}
        classes_to_tracks = {
            cls: [track for track in all_tracks] if cls in classes_to_consider else []
            for cls in all_classes
        }  # class-agnostic
        # mapping from classes to track information
        raw_data["classes_to_tracks"] = {
            cls: [
                {
                    det["image_id"]: np.atleast_1d(det["bbox"])
                    for det in track["annotations"]
                }
                for track in tracks
            ]
            for cls, tracks in classes_to_tracks.items()
        }
        raw_data["classes_to_track_ids"] = {
            cls: [track["id"] for track in tracks]
            for cls, tracks in classes_to_tracks.items()
        }
        raw_data["classes_to_track_areas"] = {
            cls: [track["area"] for track in tracks]
            for cls, tracks in classes_to_tracks.items()
        }
        raw_data["classes_to_track_lengths"] = {
            cls: [len(track["annotations"]) for track in tracks]
            for cls, tracks in classes_to_tracks.items()
        }
        if not is_gt:
            # track confidence = mean detection score over the track
            raw_data["classes_to_dt_track_scores"] = {
                cls: np.array(
                    [
                        np.mean([float(x["score"]) for x in track["annotations"]])
                        for track in tracks
                    ]
                )
                for cls, tracks in classes_to_tracks.items()
            }
        if is_gt:
            key_map = {
                "classes_to_tracks": "classes_to_gt_tracks",
                "classes_to_track_ids": "classes_to_gt_track_ids",
                "classes_to_track_lengths": "classes_to_gt_track_lengths",
                "classes_to_track_areas": "classes_to_gt_track_areas",
            }
        else:
            key_map = {
                "classes_to_tracks": "classes_to_dt_tracks",
                "classes_to_track_ids": "classes_to_dt_track_ids",
                "classes_to_track_lengths": "classes_to_dt_track_lengths",
                "classes_to_track_areas": "classes_to_dt_track_areas",
            }
        for k, v in key_map.items():
            raw_data[v] = raw_data.pop(k)
        raw_data["num_timesteps"] = num_timesteps
        raw_data["neg_cat_ids"] = self.seq_to_classes[seq_id]["neg_cat_ids"]
        raw_data["not_exhaustively_labeled_cls"] = self.seq_to_classes[seq_id][
            "not_exhaustively_labeled_cat_ids"
        ]
        raw_data["seq"] = seq
        return raw_data
    @_timing.time
    def get_preprocessed_seq_data(self, raw_data, cls):
        """Preprocess data for a single sequence for a single class ready for evaluation.

        Inputs:
            - raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
            - cls is the class to be evaluated.
        Outputs:
            - data is a dict containing all of the information that metrics need to perform evaluation.
              It contains the following fields:
              [num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
              [gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
              [gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
              [similarity_scores]: list (for each timestep) of 2D NDArrays.
        Notes:
            General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
            1) Extract only detections relevant for the class to be evaluated (including distractor detections).
            2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
                distractor class, or otherwise marked as to be removed.
            3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain
                other criteria (e.g. are too small).
            4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
            After the above preprocessing steps, this function also calculates the number of gt and tracker detections
            and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
            unique within each timestep.
        TAO:
            In TAO, the 4 preproc steps are as follow:
            1) All classes present in the ground truth data are evaluated separately.
            2) No matched tracker detections are removed.
            3) Unmatched tracker detections are removed if there is not ground truth data and the class does not
                belong to the categories marked as negative for this sequence. Additionally, unmatched tracker
                detections for classes which are marked as not exhaustively labeled are removed.
            4) No gt detections are removed.
            Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
            and the tracks from the tracker data are sorted according to the tracker confidence.
        """
        cls_id = self.class_name_to_class_id[cls]
        is_not_exhaustively_labeled = cls_id in raw_data["not_exhaustively_labeled_cls"]
        is_neg_category = cls_id in raw_data["neg_cat_ids"]
        data_keys = [
            "gt_ids",
            "tracker_ids",
            "gt_dets",
            "tracker_dets",
            "tracker_confidences",
            "similarity_scores",
        ]
        data = {key: [None] * raw_data["num_timesteps"] for key in data_keys}
        unique_gt_ids = []
        unique_tracker_ids = []
        num_gt_dets = 0
        num_tracker_dets = 0
        for t in range(raw_data["num_timesteps"]):
            # Only extract relevant dets for this class for preproc and eval (cls)
            gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id)
            gt_class_mask = gt_class_mask.astype(bool)
            gt_ids = raw_data["gt_ids"][t][gt_class_mask]
            gt_dets = raw_data["gt_dets"][t][gt_class_mask]
            tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id)
            tracker_class_mask = tracker_class_mask.astype(bool)
            tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask]
            tracker_dets = raw_data["tracker_dets"][t][tracker_class_mask]
            tracker_confidences = raw_data["tracker_confidences"][t][tracker_class_mask]
            similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][
                :, tracker_class_mask
            ]
            # Match tracker and gt dets (with hungarian algorithm).
            unmatched_indices = np.arange(tracker_ids.shape[0])
            if gt_ids.shape[0] > 0 and tracker_ids.shape[0] > 0:
                matching_scores = similarity_scores.copy()
                # similarities below 0.5 IoU are zeroed so they never count as a match
                matching_scores[matching_scores < 0.5 - np.finfo("float").eps] = 0
                match_rows, match_cols = linear_sum_assignment(-matching_scores)
                actually_matched_mask = (
                    matching_scores[match_rows, match_cols] > 0 + np.finfo("float").eps
                )
                match_cols = match_cols[actually_matched_mask]
                unmatched_indices = np.delete(unmatched_indices, match_cols, axis=0)
            # unmatched tracker dets are removed when there is no GT (and the
            # class is not a negative category), or when the class is not
            # exhaustively labeled for this sequence
            if gt_ids.shape[0] == 0 and not is_neg_category:
                to_remove_tracker = unmatched_indices
            elif is_not_exhaustively_labeled:
                to_remove_tracker = unmatched_indices
            else:
                to_remove_tracker = np.array([], dtype=int)
            # remove all unwanted unmatched tracker detections
            data["tracker_ids"][t] = np.delete(tracker_ids, to_remove_tracker, axis=0)
            data["tracker_dets"][t] = np.delete(tracker_dets, to_remove_tracker, axis=0)
            data["tracker_confidences"][t] = np.delete(
                tracker_confidences, to_remove_tracker, axis=0
            )
            similarity_scores = np.delete(similarity_scores, to_remove_tracker, axis=1)
            data["gt_ids"][t] = gt_ids
            data["gt_dets"][t] = gt_dets
            data["similarity_scores"][t] = similarity_scores
            unique_gt_ids += list(np.unique(data["gt_ids"][t]))
            unique_tracker_ids += list(np.unique(data["tracker_ids"][t]))
            num_tracker_dets += len(data["tracker_ids"][t])
            num_gt_dets += len(data["gt_ids"][t])
        # Re-label IDs such that there are no empty IDs
        if len(unique_gt_ids) > 0:
            unique_gt_ids = np.unique(unique_gt_ids)
            gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
            gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
            for t in range(raw_data["num_timesteps"]):
                if len(data["gt_ids"][t]) > 0:
                    data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)
        if len(unique_tracker_ids) > 0:
            unique_tracker_ids = np.unique(unique_tracker_ids)
            tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
            tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
            for t in range(raw_data["num_timesteps"]):
                if len(data["tracker_ids"][t]) > 0:
                    data["tracker_ids"][t] = tracker_id_map[
                        data["tracker_ids"][t]
                    ].astype(int)
        # Record overview statistics.
        data["num_tracker_dets"] = num_tracker_dets
        data["num_gt_dets"] = num_gt_dets
        data["num_tracker_ids"] = len(unique_tracker_ids)
        data["num_gt_ids"] = len(unique_gt_ids)
        data["num_timesteps"] = raw_data["num_timesteps"]
        data["seq"] = raw_data["seq"]
        # get track representations
        data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id]
        data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id]
        data["gt_track_lengths"] = raw_data["classes_to_gt_track_lengths"][cls_id]
        data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id]
        data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id]
        data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id]
        data["dt_track_lengths"] = raw_data["classes_to_dt_track_lengths"][cls_id]
        data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id]
        data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id]
        data["not_exhaustively_labeled"] = is_not_exhaustively_labeled
        data["iou_type"] = "bbox"
        # sort tracker data tracks by tracker confidence scores
        if data["dt_tracks"]:
            idx = np.argsort(
                [-score for score in data["dt_track_scores"]], kind="mergesort"
            )
            data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx]
            data["dt_tracks"] = [data["dt_tracks"][i] for i in idx]
            data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx]
            data["dt_track_lengths"] = [data["dt_track_lengths"][i] for i in idx]
            data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx]
        # Ensure that ids are unique per timestep.
        self._check_unique_ids(data)
        return data
  466. def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
  467. similarity_scores = self._calculate_box_ious(gt_dets_t, tracker_dets_t)
  468. return similarity_scores
  469. def _merge_categories(self, annotations):
  470. """
  471. Merges categories with a merged tag. Adapted from https://github.com/TAO-Dataset
  472. :param annotations: the annotations in which the classes should be merged
  473. :return: None
  474. """
  475. merge_map = {}
  476. for category in self.gt_data["categories"]:
  477. if "merged" in category:
  478. for to_merge in category["merged"]:
  479. merge_map[to_merge["id"]] = category["id"]
  480. for ann in annotations:
  481. ann["category_id"] = merge_map.get(ann["category_id"], ann["category_id"])
  482. def _compute_vid_mappings(self, annotations):
  483. """
  484. Computes mappings from Videos to corresponding tracks and images.
  485. :param annotations: the annotations for which the mapping should be generated
  486. :return: the video-to-track-mapping, the video-to-image-mapping
  487. """
  488. vids_to_tracks = {}
  489. vids_to_imgs = {}
  490. vid_ids = [vid["id"] for vid in self.gt_data["videos"]]
  491. # compute an mapping from image IDs to images
  492. images = {}
  493. for image in self.gt_data["images"]:
  494. images[image["id"]] = image
  495. for ann in annotations:
  496. ann["area"] = ann["bbox"][2] * ann["bbox"][3]
  497. vid = ann["video_id"]
  498. if ann["video_id"] not in vids_to_tracks.keys():
  499. vids_to_tracks[ann["video_id"]] = list()
  500. if ann["video_id"] not in vids_to_imgs.keys():
  501. vids_to_imgs[ann["video_id"]] = list()
  502. # Fill in vids_to_tracks
  503. tid = ann["track_id"]
  504. exist_tids = [track["id"] for track in vids_to_tracks[vid]]
  505. try:
  506. index1 = exist_tids.index(tid)
  507. except ValueError:
  508. index1 = -1
  509. if tid not in exist_tids:
  510. curr_track = {
  511. "id": tid,
  512. "category_id": ann["category_id"],
  513. "video_id": vid,
  514. "annotations": [ann],
  515. }
  516. vids_to_tracks[vid].append(curr_track)
  517. else:
  518. vids_to_tracks[vid][index1]["annotations"].append(ann)
  519. # Fill in vids_to_imgs
  520. img_id = ann["image_id"]
  521. exist_img_ids = [img["id"] for img in vids_to_imgs[vid]]
  522. try:
  523. index2 = exist_img_ids.index(img_id)
  524. except ValueError:
  525. index2 = -1
  526. if index2 == -1:
  527. curr_img = {"id": img_id, "annotations": [ann]}
  528. vids_to_imgs[vid].append(curr_img)
  529. else:
  530. vids_to_imgs[vid][index2]["annotations"].append(ann)
  531. # sort annotations by frame index and compute track area
  532. for vid, tracks in vids_to_tracks.items():
  533. for track in tracks:
  534. track["annotations"] = sorted(
  535. track["annotations"],
  536. key=lambda x: images[x["image_id"]]["frame_index"],
  537. )
  538. # Computer average area
  539. track["area"] = sum(x["area"] for x in track["annotations"]) / len(
  540. track["annotations"]
  541. )
  542. # Ensure all videos are present
  543. for vid_id in vid_ids:
  544. if vid_id not in vids_to_tracks.keys():
  545. vids_to_tracks[vid_id] = []
  546. if vid_id not in vids_to_imgs.keys():
  547. vids_to_imgs[vid_id] = []
  548. return vids_to_tracks, vids_to_imgs
  549. def _compute_image_to_timestep_mappings(self):
  550. """
  551. Computes a mapping from images to the corresponding timestep in the sequence.
  552. :return: the image-to-timestep-mapping
  553. """
  554. images = {}
  555. for image in self.gt_data["images"]:
  556. images[image["id"]] = image
  557. seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]}
  558. for vid in seq_to_imgs_to_timestep:
  559. curr_imgs = [img["id"] for img in self.videos_to_gt_images[vid]]
  560. curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_index"])
  561. seq_to_imgs_to_timestep[vid] = {
  562. curr_imgs[i]: i for i in range(len(curr_imgs))
  563. }
  564. return seq_to_imgs_to_timestep
  565. def _limit_dets_per_image(self, annotations):
  566. """
  567. Limits the number of detections for each image to config['MAX_DETECTIONS']. Adapted from
  568. https://github.com/TAO-Dataset/
  569. :param annotations: the annotations in which the detections should be limited
  570. :return: the annotations with limited detections
  571. """
  572. max_dets = self.config["MAX_DETECTIONS"]
  573. img_ann = defaultdict(list)
  574. for ann in annotations:
  575. img_ann[ann["image_id"]].append(ann)
  576. for img_id, _anns in img_ann.items():
  577. if len(_anns) <= max_dets:
  578. continue
  579. _anns = sorted(_anns, key=lambda x: x["score"], reverse=True)
  580. img_ann[img_id] = _anns[:max_dets]
  581. return [ann for anns in img_ann.values() for ann in anns]
  582. def _fill_video_ids_inplace(self, annotations):
  583. """
  584. Fills in missing video IDs inplace. Adapted from https://github.com/TAO-Dataset/
  585. :param annotations: the annotations for which the videos IDs should be filled inplace
  586. :return: None
  587. """
  588. missing_video_id = [x for x in annotations if "video_id" not in x]
  589. if missing_video_id:
  590. image_id_to_video_id = {
  591. x["id"]: x["video_id"] for x in self.gt_data["images"]
  592. }
  593. for x in missing_video_id:
  594. x["video_id"] = image_id_to_video_id[x["image_id"]]
  595. @staticmethod
  596. def _make_track_ids_unique(annotations):
  597. """
  598. Makes the track IDs unqiue over the whole annotation set. Adapted from https://github.com/TAO-Dataset/
  599. :param annotations: the annotation set
  600. :return: the number of updated IDs
  601. """
  602. track_id_videos = {}
  603. track_ids_to_update = set()
  604. max_track_id = 0
  605. for ann in annotations:
  606. t = ann["track_id"]
  607. if t not in track_id_videos:
  608. track_id_videos[t] = ann["video_id"]
  609. if ann["video_id"] != track_id_videos[t]:
  610. # Track id is assigned to multiple videos
  611. track_ids_to_update.add(t)
  612. max_track_id = max(max_track_id, t)
  613. if track_ids_to_update:
  614. print("true")
  615. next_id = itertools.count(max_track_id + 1)
  616. new_track_ids = defaultdict(lambda: next(next_id))
  617. for ann in annotations:
  618. t = ann["track_id"]
  619. v = ann["video_id"]
  620. if t in track_ids_to_update:
  621. ann["track_id"] = new_track_ids[t, v]
  622. return len(track_ids_to_update)
  623. def _split_known_unknown_distractor(self):
  624. all_ids = set(
  625. [i for i in range(1, 2000)]
  626. ) # 2000 is larger than the max category id in TAO-OW.
  627. # `knowns` includes 78 TAO_category_ids that corresponds to 78 COCO classes.
  628. # (The other 2 COCO classes do not have corresponding classes in TAO).
  629. self.knowns = {
  630. 4,
  631. 13,
  632. 1038,
  633. 544,
  634. 1057,
  635. 34,
  636. 35,
  637. 36,
  638. 41,
  639. 45,
  640. 58,
  641. 60,
  642. 579,
  643. 1091,
  644. 1097,
  645. 1099,
  646. 78,
  647. 79,
  648. 81,
  649. 91,
  650. 1115,
  651. 1117,
  652. 95,
  653. 1122,
  654. 99,
  655. 1132,
  656. 621,
  657. 1135,
  658. 625,
  659. 118,
  660. 1144,
  661. 126,
  662. 642,
  663. 1155,
  664. 133,
  665. 1162,
  666. 139,
  667. 154,
  668. 174,
  669. 185,
  670. 699,
  671. 1215,
  672. 714,
  673. 717,
  674. 1229,
  675. 211,
  676. 729,
  677. 221,
  678. 229,
  679. 747,
  680. 235,
  681. 237,
  682. 779,
  683. 276,
  684. 805,
  685. 299,
  686. 829,
  687. 852,
  688. 347,
  689. 371,
  690. 382,
  691. 896,
  692. 392,
  693. 926,
  694. 937,
  695. 428,
  696. 429,
  697. 961,
  698. 452,
  699. 979,
  700. 980,
  701. 982,
  702. 475,
  703. 480,
  704. 993,
  705. 1001,
  706. 502,
  707. 1018,
  708. }
  709. # `distractors` is defined as in the paper "Opening up Open-World Tracking"
  710. self.distractors = {
  711. 20,
  712. 63,
  713. 108,
  714. 180,
  715. 188,
  716. 204,
  717. 212,
  718. 247,
  719. 303,
  720. 403,
  721. 407,
  722. 415,
  723. 490,
  724. 504,
  725. 507,
  726. 513,
  727. 529,
  728. 567,
  729. 569,
  730. 588,
  731. 672,
  732. 691,
  733. 702,
  734. 708,
  735. 711,
  736. 720,
  737. 736,
  738. 737,
  739. 798,
  740. 813,
  741. 815,
  742. 827,
  743. 831,
  744. 851,
  745. 877,
  746. 883,
  747. 912,
  748. 971,
  749. 976,
  750. 1130,
  751. 1133,
  752. 1134,
  753. 1169,
  754. 1184,
  755. 1220,
  756. }
  757. self.unknowns = all_ids.difference(self.knowns.union(self.distractors))
  758. def _filter_gt_data(self, raw_gt_data):
  759. """
  760. Filter out irrelevant data in the raw_gt_data
  761. Args:
  762. raw_gt_data: directly loaded from json.
  763. Returns:
  764. filtered gt_data
  765. """
  766. valid_cat_ids = list()
  767. if self.subset == "known":
  768. valid_cat_ids = self.knowns
  769. elif self.subset == "distractor":
  770. valid_cat_ids = self.distractors
  771. elif self.subset == "unknown":
  772. valid_cat_ids = self.unknowns
  773. # elif self.subset == "test_only_unknowns":
  774. # valid_cat_ids = test_only_unknowns
  775. else:
  776. raise Exception("The parameter `SUBSET` is incorrect")
  777. filtered = dict()
  778. filtered["videos"] = raw_gt_data["videos"]
  779. # filtered["videos"] = list()
  780. unwanted_vid = set()
  781. # for video in raw_gt_data["videos"]:
  782. # datasrc = video["name"].split('/')[1]
  783. # if datasrc in data_srcs:
  784. # filtered["videos"].append(video)
  785. # else:
  786. # unwanted_vid.add(video["id"])
  787. filtered["annotations"] = list()
  788. for ann in raw_gt_data["annotations"]:
  789. if (ann["video_id"] not in unwanted_vid) and (
  790. ann["category_id"] in valid_cat_ids
  791. ):
  792. filtered["annotations"].append(ann)
  793. filtered["tracks"] = list()
  794. for track in raw_gt_data["tracks"]:
  795. if (track["video_id"] not in unwanted_vid) and (
  796. track["category_id"] in valid_cat_ids
  797. ):
  798. filtered["tracks"].append(track)
  799. filtered["images"] = list()
  800. for image in raw_gt_data["images"]:
  801. if image["video_id"] not in unwanted_vid:
  802. filtered["images"].append(image)
  803. filtered["categories"] = list()
  804. for cat in raw_gt_data["categories"]:
  805. if cat["id"] in valid_cat_ids:
  806. filtered["categories"].append(cat)
  807. filtered["info"] = raw_gt_data["info"]
  808. filtered["licenses"] = raw_gt_data["licenses"]
  809. return filtered