# fmt: off
# flake8: noqa
# pyre-unsafe
"""COCO Dataset."""
import copy
import itertools
import json
import os
from collections import defaultdict

import numpy as np
from scipy.optimize import linear_sum_assignment

from .. import _timing, utils
from ..config import get_default_dataset_config, init_config
from ..utils import TrackEvalException
from ._base_dataset import _BaseDataset


class COCO(_BaseDataset):
    """Tracking datasets in COCO format."""

    def __init__(self, config=None):
        """Initialize dataset, checking that all required files are present."""
        super().__init__()
        # Fill non-given config values with defaults
        self.config = init_config(config, get_default_dataset_config(), self.get_name())
        self.gt_fol = self.config["GT_FOLDER"]
        self.tracker_fol = self.config["TRACKERS_FOLDER"]
        self.should_classes_combine = True
        self.use_super_categories = False
        self.use_mask = self.config["USE_MASK"]
        self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
        self.output_fol = self.config["OUTPUT_FOLDER"]
        if self.output_fol is None:
            self.output_fol = self.tracker_fol
        self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]

        if self.gt_fol.endswith(".json"):
            with open(self.gt_fol) as f:
                self.gt_data = json.load(f)
        else:
            gt_dir_files = [
                file for file in os.listdir(self.gt_fol) if file.endswith(".json")
            ]
            if len(gt_dir_files) != 1:
                raise TrackEvalException(
                    f"{self.gt_fol} does not contain exactly one json file."
                )
            with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
                self.gt_data = json.load(f)

        # fill missing video ids
        self._fill_video_ids_inplace(self.gt_data["annotations"])

        # get sequences to eval and sequence information
        self.seq_list = [
            vid["name"].replace("/", "-") for vid in self.gt_data["videos"]
        ]
        self.seq_name2seqid = {
            vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"]
        }
        # compute mappings from videos to annotation data
        self.video2gt_track, self.video2gt_image = self._compute_vid_mappings(
            self.gt_data["annotations"]
        )
        # compute sequence lengths
        self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]}
        for img in self.gt_data["images"]:
            self.seq_lengths[img["video_id"]] += 1
        self.seq2images2timestep = self._compute_image_to_timestep_mappings()
        self.seq2cls = {
            vid["id"]: {
                "pos_cat_ids": list(
                    {track["category_id"] for track in self.video2gt_track[vid["id"]]}
                ),
            }
            for vid in self.gt_data["videos"]
        }

        # Get classes to eval
        considered_vid_ids = [self.seq_name2seqid[vid] for vid in self.seq_list]
        seen_cats = {
            cat_id
            for vid_id in considered_vid_ids
            for cat_id in self.seq2cls[vid_id]["pos_cat_ids"]
        }
        # only classes with ground truth are evaluated in TAO
        self.valid_classes = [
            cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats
        ]
        cls_name2clsid_map = {
            cls["name"]: cls["id"] for cls in self.gt_data["categories"]
        }
        if self.config["CLASSES_TO_EVAL"]:
            self.class_list = [
                cls.lower() if cls.lower() in self.valid_classes else None
                for cls in self.config["CLASSES_TO_EVAL"]
            ]
            if not all(self.class_list):
                valid_cls = ", ".join(self.valid_classes)
                raise TrackEvalException(
                    "Attempted to evaluate an invalid class. Only classes "
                    f"{valid_cls} are valid (classes present in ground truth"
                    " data)."
                )
        else:
            self.class_list = list(self.valid_classes)
        self.cls_name2clsid = {
            k: v for k, v in cls_name2clsid_map.items() if k in self.class_list
        }
        self.clsid2cls_name = {
            v: k for k, v in cls_name2clsid_map.items() if k in self.class_list
        }

        # get trackers to eval
        if self.config["TRACKERS_TO_EVAL"] is None:
            self.tracker_list = os.listdir(self.tracker_fol)
        else:
            self.tracker_list = self.config["TRACKERS_TO_EVAL"]
        if self.config["TRACKER_DISPLAY_NAMES"] is None:
            self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
        elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
            len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list)
        ):
            self.tracker_to_disp = dict(
                zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"])
            )
        else:
            raise TrackEvalException(
                "List of tracker files and tracker display names do not match."
            )

        self.tracker_data = {tracker: dict() for tracker in self.tracker_list}
        for tracker in self.tracker_list:
            if self.tracker_sub_fol.endswith(".json"):
                with open(self.tracker_sub_fol) as f:
                    curr_data = json.load(f)
            else:
                tr_dir = os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
                tr_dir_files = [
                    file for file in os.listdir(tr_dir) if file.endswith(".json")
                ]
                if len(tr_dir_files) != 1:
                    raise TrackEvalException(
                        f"{tr_dir} does not contain exactly one json file."
                    )
                with open(os.path.join(tr_dir, tr_dir_files[0])) as f:
                    curr_data = json.load(f)
            # limit detections if MAX_DETECTIONS > 0
            if self.config["MAX_DETECTIONS"]:
                curr_data = self._limit_dets_per_image(curr_data)
            # fill missing video ids
            self._fill_video_ids_inplace(curr_data)
            # make track ids unique over the whole evaluation set
            self._make_tk_ids_unique(curr_data)
            # get tracker sequence information
            curr_vids2tracks, curr_vids2images = self._compute_vid_mappings(curr_data)
            self.tracker_data[tracker]["vids_to_tracks"] = curr_vids2tracks
            self.tracker_data[tracker]["vids_to_images"] = curr_vids2images

    def get_display_name(self, tracker):
        return self.tracker_to_disp[tracker]

    def _load_raw_file(self, tracker, seq, is_gt):
        """Load a file (gt or tracker) in the TAO format.

        If is_gt, this returns a dict which contains the fields:
            [gt_ids, gt_classes]:
                list (for each timestep) of 1D NDArrays (for each det).
            [gt_dets]: list (for each timestep) of lists of detections.
        If not is_gt, this returns a dict which contains the fields:
            [tk_ids, tk_classes]:
                list (for each timestep) of 1D NDArrays (for each det).
            [tk_dets]: list (for each timestep) of lists of detections.
        """
        seq_id = self.seq_name2seqid[seq]
        # file location
        if is_gt:
            imgs = self.video2gt_image[seq_id]
        else:
            imgs = self.tracker_data[tracker]["vids_to_images"][seq_id]

        # convert data to required format
        num_timesteps = self.seq_lengths[seq_id]
        img_to_timestep = self.seq2images2timestep[seq_id]
        data_keys = ["ids", "classes", "dets"]
        # if not is_gt:
        #     data_keys += ["tk_confidences"]
        raw_data = {key: [None] * num_timesteps for key in data_keys}
        for img in imgs:
            # some tracker data contains images without any ground truth info,
            # these are ignored
            if img["id"] not in img_to_timestep:
                continue
            t = img_to_timestep[img["id"]]
            anns = img["annotations"]
            tk_str = utils.get_track_id_str(anns[0])
            if self.use_mask:
                # when using masks, extract segmentation data
                raw_data["dets"][t] = [ann.get("segmentation") for ann in anns]
            else:
                # when using bboxes, extract bbox data
                raw_data["dets"][t] = np.atleast_2d(
                    [ann["bbox"] for ann in anns]
                ).astype(float)
            raw_data["ids"][t] = np.atleast_1d(
                [ann[tk_str] for ann in anns]
            ).astype(int)
            raw_data["classes"][t] = np.atleast_1d(
                [ann["category_id"] for ann in anns]
            ).astype(int)
            # if not is_gt:
            #     raw_data["tk_confidences"][t] = np.atleast_1d(
            #         [ann["score"] for ann in anns]
            #     ).astype(float)

        for t, d in enumerate(raw_data["dets"]):
            if d is None:
                raw_data["dets"][t] = np.empty((0, 4)).astype(float)
                raw_data["ids"][t] = np.empty(0).astype(int)
                raw_data["classes"][t] = np.empty(0).astype(int)
                # if not is_gt:
                #     raw_data["tk_confidences"][t] = np.empty(0)

        if is_gt:
            key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
        else:
            key_map = {"ids": "tk_ids", "classes": "tk_classes", "dets": "tk_dets"}
        for k, v in key_map.items():
            raw_data[v] = raw_data.pop(k)
        raw_data["num_timesteps"] = num_timesteps
        raw_data["seq"] = seq
        return raw_data

    def get_preprocessed_seq_data_thr(self, raw_data, cls, assignment=None):
        """Preprocess data for a single sequence for a single class.

        Inputs:
            raw_data: dict containing the data for the sequence already
                read in by get_raw_seq_data().
            cls: class to be evaluated.
        Outputs:
            gt_ids:
                list (for each timestep) of ids of GT tracks.
            tk_ids:
                list (for each timestep) of ids of predicted tracks (all used
                for TP matching (Det + AssocA)).
            tk_overlap_ids:
                list (for each timestep) of ids of predicted tracks that
                overlap with GTs.
            tk_dets:
                list (for each timestep) of lists of detections corresponding
                to the tk_ids.
            tk_classes:
                list (for each timestep) of lists of classes corresponding
                to the tk_ids.
            tk_confidences:
                list (for each timestep) of lists of confidences corresponding
                to the tk_ids.
            sim_scores:
                similarity scores between gt_ids and tk_ids.
        """
        # cls_id is None when evaluating all classes jointly
        cls_id = self.cls_name2clsid[cls] if cls != "all" else None
        data_keys = [
            "gt_ids",
            "tk_ids",
            "gt_id_map",
            "tk_id_map",
            "gt_dets",
            "gt_classes",
            "gt_class_name",
            "tk_overlap_classes",
            "tk_overlap_ids",
            "tk_class_eval_tk_ids",
            "tk_dets",
            "tk_classes",
            # "tk_confidences",
            "tk_exh_ids",
            "sim_scores",
        ]
        data = {key: [None] * raw_data["num_timesteps"] for key in data_keys}
        unique_gt_ids = []
        unique_tk_ids = []
        num_gt_dets = 0
        num_tk_cls_dets = 0
        num_tk_overlap_dets = 0
        overlap_ious_thr = 0.5
        loc_and_asso_tk_ids = []
        exh_class_tk_ids = []
        sim_scores = raw_data["similarity_scores"]
        for t in range(raw_data["num_timesteps"]):
            # only extract relevant dets for this class for preproc and eval
            if cls == "all":
                gt_class_mask = np.ones_like(raw_data["gt_classes"][t]).astype(bool)
            else:
                gt_class_mask = np.atleast_1d(
                    raw_data["gt_classes"][t] == cls_id
                ).astype(bool)
            # select GT that are not in the classes being evaluated
            if assignment:
                all_gt_ids = list(assignment[t].keys())
                gt_ids_in = raw_data["gt_ids"][t][gt_class_mask]
                gt_ids_out = set(all_gt_ids) - set(gt_ids_in)
                tk_ids_out = {assignment[t][key] for key in gt_ids_out}
            # compute overlapping tracks and add their ids to tk_overlap_ids
            overlap_ids_masks = (
                sim_scores[t][gt_class_mask] >= overlap_ious_thr
            ).any(axis=0)
            overlap_tk_ids_t = raw_data["tk_ids"][t][overlap_ids_masks]
            if assignment:
                data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t) - tk_ids_out)
            else:
                data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t))
            loc_and_asso_tk_ids += data["tk_overlap_ids"][t]
            data["tk_exh_ids"][t] = []
            if cls == "all":
                continue
            # add the track ids of the exclusively annotated class to
            # exh_class_tk_ids
            tk_exh_mask = np.atleast_1d(raw_data["tk_classes"][t] == cls_id)
            tk_exh_mask = tk_exh_mask.astype(bool)
            exh_class_tk_ids_t = raw_data["tk_ids"][t][tk_exh_mask]
            exh_class_tk_ids.append(exh_class_tk_ids_t)
            data["tk_exh_ids"][t] = exh_class_tk_ids_t
        # track ids assigned to GT of other classes have been removed above;
        # deduplicate the remaining ids
        loc_and_asso_tk_ids = list(set(loc_and_asso_tk_ids))
        # remove all unwanted unmatched tracker detections
        for t in range(raw_data["num_timesteps"]):
            # add gt to the data
            if cls == "all":
                gt_class_mask = np.ones_like(raw_data["gt_classes"][t]).astype(bool)
            else:
                gt_class_mask = np.atleast_1d(
                    raw_data["gt_classes"][t] == cls_id
                ).astype(bool)
            data["gt_classes"][t] = cls_id
            data["gt_class_name"][t] = cls
            gt_ids = raw_data["gt_ids"][t][gt_class_mask]
            if self.use_mask:
                gt_dets = [
                    raw_data["gt_dets"][t][ind]
                    for ind in range(len(gt_class_mask))
                    if gt_class_mask[ind]
                ]
            else:
                gt_dets = raw_data["gt_dets"][t][gt_class_mask]
            data["gt_ids"][t] = gt_ids
            data["gt_dets"][t] = gt_dets
            # filter predictions and keep only those that highly overlap
            # with GTs
            tk_mask = np.isin(
                raw_data["tk_ids"][t],
                np.array(loc_and_asso_tk_ids),
                assume_unique=True,
            )
            tk_overlap_mask = np.isin(
                raw_data["tk_ids"][t],
                np.array(data["tk_overlap_ids"][t]),
                assume_unique=True,
            )
            tk_ids = raw_data["tk_ids"][t][tk_mask]
            if self.use_mask:
                tk_dets = [
                    raw_data["tk_dets"][t][ind]
                    for ind in range(len(tk_mask))
                    if tk_mask[ind]
                ]
            else:
                tk_dets = raw_data["tk_dets"][t][tk_mask]
            tracker_classes = raw_data["tk_classes"][t][tk_mask]
            # add overlap classes for computing the FP for the Cls term
            tracker_overlap_classes = raw_data["tk_classes"][t][tk_overlap_mask]
            # tracker_confidences = raw_data["tk_confidences"][t][tk_mask]
            sim_scores_masked = sim_scores[t][gt_class_mask, :][:, tk_mask]
            # add the filtered predictions to the data
            data["tk_classes"][t] = tracker_classes
            data["tk_overlap_classes"][t] = tracker_overlap_classes
            data["tk_ids"][t] = tk_ids
            data["tk_dets"][t] = tk_dets
            # data["tk_confidences"][t] = tracker_confidences
            data["sim_scores"][t] = sim_scores_masked
            data["tk_class_eval_tk_ids"][t] = set(
                list(data["tk_overlap_ids"][t]) + list(data["tk_exh_ids"][t])
            )
            # count the total number of detections
            unique_gt_ids += list(np.unique(data["gt_ids"][t]))
            # the unique track ids are used for association
            unique_tk_ids += list(np.unique(data["tk_ids"][t]))
            num_tk_overlap_dets += len(data["tk_overlap_ids"][t])
            num_tk_cls_dets += len(data["tk_class_eval_tk_ids"][t])
            num_gt_dets += len(data["gt_ids"][t])
        # re-label IDs such that there are no empty IDs
        if len(unique_gt_ids) > 0:
            unique_gt_ids = np.unique(unique_gt_ids)
            gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
            gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
            data["gt_id_map"] = {}
            for gt_id in unique_gt_ids:
                new_gt_id = gt_id_map[gt_id].astype(int)
                data["gt_id_map"][new_gt_id] = gt_id
            for t in range(raw_data["num_timesteps"]):
                if len(data["gt_ids"][t]) > 0:
                    data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)
        if len(unique_tk_ids) > 0:
            unique_tk_ids = np.unique(unique_tk_ids)
            tk_id_map = np.nan * np.ones((np.max(unique_tk_ids) + 1))
            tk_id_map[unique_tk_ids] = np.arange(len(unique_tk_ids))
            data["tk_id_map"] = {}
            for track_id in unique_tk_ids:
                new_track_id = tk_id_map[track_id].astype(int)
                data["tk_id_map"][new_track_id] = track_id
            for t in range(raw_data["num_timesteps"]):
                if len(data["tk_ids"][t]) > 0:
                    data["tk_ids"][t] = tk_id_map[data["tk_ids"][t]].astype(int)
                if len(data["tk_overlap_ids"][t]) > 0:
                    data["tk_overlap_ids"][t] = tk_id_map[
                        data["tk_overlap_ids"][t]
                    ].astype(int)

        # record overview statistics
        data["num_tk_cls_dets"] = num_tk_cls_dets
        data["num_tk_overlap_dets"] = num_tk_overlap_dets
        data["num_gt_dets"] = num_gt_dets
        data["num_tk_ids"] = len(unique_tk_ids)
        data["num_gt_ids"] = len(unique_gt_ids)
        data["num_timesteps"] = raw_data["num_timesteps"]
        data["seq"] = raw_data["seq"]
        self._check_unique_ids(data)
        return data

    @_timing.time
    def get_preprocessed_seq_data(
        self, raw_data, cls, assignment=None, thresholds=None
    ):
        """Preprocess data for a single sequence for a single class."""
        data = {}
        if thresholds is None:
            thresholds = [50, 75]
        elif isinstance(thresholds, int):
            thresholds = [thresholds]
        for thr in thresholds:
            assignment_thr = None
            if assignment is not None:
                assignment_thr = assignment[thr]
            data[thr] = self.get_preprocessed_seq_data_thr(
                raw_data, cls, assignment_thr
            )
        return data

    def _calculate_similarities(self, gt_dets_t, tk_dets_t):
        """Compute similarity scores."""
        if self.use_mask:
            similarity_scores = self._calculate_mask_ious(
                gt_dets_t, tk_dets_t, is_encoded=True, do_ioa=False
            )
        else:
            similarity_scores = self._calculate_box_ious(gt_dets_t, tk_dets_t)
        return similarity_scores

    def _compute_vid_mappings(self, annotations):
        """Computes mappings from videos to corresponding tracks and images."""
        vids_to_tracks = {}
        vids_to_imgs = {}
        vid_ids = [vid["id"] for vid in self.gt_data["videos"]]

        # compute a mapping from image IDs to images
        images = {}
        for image in self.gt_data["images"]:
            images[image["id"]] = image

        tk_str = utils.get_track_id_str(annotations[0])
        for ann in annotations:
            ann["area"] = ann["bbox"][2] * ann["bbox"][3]
            vid = ann["video_id"]
            if vid not in vids_to_tracks:
                vids_to_tracks[vid] = list()
            if vid not in vids_to_imgs:
                vids_to_imgs[vid] = list()

            # fill in vids_to_tracks
            tid = ann[tk_str]
            exist_tids = [track["id"] for track in vids_to_tracks[vid]]
            try:
                index1 = exist_tids.index(tid)
            except ValueError:
                index1 = -1
            if index1 == -1:
                curr_track = {
                    "id": tid,
                    "category_id": ann["category_id"],
                    "video_id": vid,
                    "annotations": [ann],
                }
                vids_to_tracks[vid].append(curr_track)
            else:
                vids_to_tracks[vid][index1]["annotations"].append(ann)

            # fill in vids_to_imgs
            img_id = ann["image_id"]
            exist_img_ids = [img["id"] for img in vids_to_imgs[vid]]
            try:
                index2 = exist_img_ids.index(img_id)
            except ValueError:
                index2 = -1
            if index2 == -1:
                curr_img = {"id": img_id, "annotations": [ann]}
                vids_to_imgs[vid].append(curr_img)
            else:
                vids_to_imgs[vid][index2]["annotations"].append(ann)

        # sort annotations by frame index and compute track area
        for vid, tracks in vids_to_tracks.items():
            for track in tracks:
                track["annotations"] = sorted(
                    track["annotations"],
                    key=lambda x: images[x["image_id"]]["frame_id"],
                )
                # compute average area
                track["area"] = sum(x["area"] for x in track["annotations"]) / len(
                    track["annotations"]
                )

        # ensure all videos are present
        for vid_id in vid_ids:
            if vid_id not in vids_to_tracks:
                vids_to_tracks[vid_id] = []
            if vid_id not in vids_to_imgs:
                vids_to_imgs[vid_id] = []
        return vids_to_tracks, vids_to_imgs

    def _compute_image_to_timestep_mappings(self):
        """Computes a mapping from images to timesteps in a sequence."""
        images = {}
        for image in self.gt_data["images"]:
            images[image["id"]] = image
        seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]}
        for vid in seq_to_imgs_to_timestep:
            curr_imgs = [img["id"] for img in self.video2gt_image[vid]]
            curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_id"])
            seq_to_imgs_to_timestep[vid] = {
                curr_imgs[i]: i for i in range(len(curr_imgs))
            }
        return seq_to_imgs_to_timestep

    def _limit_dets_per_image(self, annotations):
        """Limits the number of detections for each image.

        Adapted from https://github.com/TAO-Dataset/.
        """
        max_dets = self.config["MAX_DETECTIONS"]
        img_ann = defaultdict(list)
        for ann in annotations:
            img_ann[ann["image_id"]].append(ann)
        for img_id, _anns in img_ann.items():
            if len(_anns) <= max_dets:
                continue
            _anns = sorted(_anns, key=lambda x: x["score"], reverse=True)
            img_ann[img_id] = _anns[:max_dets]
        return [ann for anns in img_ann.values() for ann in anns]

    def _fill_video_ids_inplace(self, annotations):
        """Fills in missing video IDs in place.

        Adapted from https://github.com/TAO-Dataset/.
        """
        missing_video_id = [x for x in annotations if "video_id" not in x]
        if missing_video_id:
            image_id_to_video_id = {
                x["id"]: x["video_id"] for x in self.gt_data["images"]
            }
            for x in missing_video_id:
                x["video_id"] = image_id_to_video_id[x["image_id"]]

    @staticmethod
    def _make_tk_ids_unique(annotations):
        """Makes track IDs unique over the whole annotation set.

        Adapted from https://github.com/TAO-Dataset/.
        """
        track_id_videos = {}
        track_ids_to_update = set()
        max_track_id = 0
        tk_str = utils.get_track_id_str(annotations[0])
        for ann in annotations:
            t = int(ann[tk_str])
            if t not in track_id_videos:
                track_id_videos[t] = ann["video_id"]
            if ann["video_id"] != track_id_videos[t]:
                # track id is assigned to multiple videos
                track_ids_to_update.add(t)
            max_track_id = max(max_track_id, t)
        if track_ids_to_update:
            next_id = itertools.count(max_track_id + 1)
            new_tk_ids = defaultdict(lambda: next(next_id))
            for ann in annotations:
                t = ann[tk_str]
                v = ann["video_id"]
                if t in track_ids_to_update:
                    ann[tk_str] = new_tk_ids[t, v]
        return len(track_ids_to_update)