_base_dataset.py

# fmt: off
# flake8: noqa
# pyre-unsafe
import csv
import io
import os
import traceback
import zipfile
from abc import ABC, abstractmethod
from copy import deepcopy

import numpy as np

from .. import _timing
from ..utils import TrackEvalException


class _BaseDataset(ABC):
    @abstractmethod
    def __init__(self):
        self.tracker_list = None
        self.seq_list = None
        self.class_list = None
        self.output_fol = None
        self.output_sub_fol = None
        self.should_classes_combine = True
        self.use_super_categories = False

    # Functions to implement:

    @abstractmethod
    def _load_raw_file(self, tracker, seq, is_gt):
        """Load raw detection data (ground-truth or tracker) for a single sequence."""
        ...

    @_timing.time
    @abstractmethod
    def get_preprocessed_seq_data(self, raw_data, cls):
        """Preprocess raw sequence data and extract the information relevant to a single class."""
        ...

    @abstractmethod
    def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
        """Calculate similarity scores between ground-truth and tracker dets for one timestep."""
        ...
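
    # The helpers below make this abstract interface easy to satisfy for box-based
    # datasets. As a hedged sketch (the subclass name is hypothetical and not part
    # of this file), a 2D-box dataset could implement _calculate_similarities by
    # delegating to the IoU helper defined further down:
    #
    #   class MyBoxDataset(_BaseDataset):
    #       def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
    #           # dets are assumed to be (n, 4) arrays in 'xywh' format in this sketch
    #           return self._calculate_box_ious(gt_dets_t, tracker_dets_t, box_format="xywh")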
    # Helper functions for all datasets:

    @classmethod
    def get_class_name(cls):
        return cls.__name__

    def get_name(self):
        return self.get_class_name()

    def get_output_fol(self, tracker):
        return os.path.join(self.output_fol, tracker, self.output_sub_fol)

    def get_display_name(self, tracker):
        """Can be overwritten if the tracker's name (in files) differs from how it should be displayed.

        By default this method just returns the tracker's name as is.
        """
        return tracker

    def get_eval_info(self):
        """Return info about the dataset needed for the Evaluator."""
        return self.tracker_list, self.seq_list, self.class_list
    @_timing.time
    def get_raw_seq_data(self, tracker, seq):
        """Loads raw data (tracker and ground-truth) for a single tracker on a single sequence.

        Raw data includes all of the information needed for both preprocessing and evaluation, for all classes.
        A later function (get_preprocessed_seq_data) will perform such preprocessing and extract the relevant
        information for the evaluation of each class.

        This returns a dict which contains the fields:
        [num_timesteps]: integer
        [gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]:
            list (for each timestep) of 1D NDArrays (for each det).
        [gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (for each timestep) of lists of detections.
        [similarity_scores]: list (for each timestep) of 2D NDArrays.
        [gt_extras]: dict (for each extra) of lists (for each timestep) of 1D NDArrays (for each det).

        gt_extras contains dataset-specific information used for preprocessing such as occlusion and truncation levels.

        Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are
        independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation
        masks vs 2D boxes vs 3D boxes).
        We calculate the similarity before preprocessing because often both preprocessing and evaluation require it,
        and we don't wish to calculate this twice.
        We calculate similarity between all gt and tracker classes (not just each class individually) to allow for
        calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low.
        """
        # Load raw data.
        raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True)
        raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False)
        raw_data = {**raw_tracker_data, **raw_gt_data}  # Merges dictionaries

        # Calculate similarities for each timestep.
        similarity_scores = []
        for gt_dets_t, tracker_dets_t in zip(raw_data["gt_dets"], raw_data["tracker_dets"]):
            ious = self._calculate_similarities(gt_dets_t, tracker_dets_t)
            similarity_scores.append(ious)
        raw_data["similarity_scores"] = similarity_scores
        return raw_data
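
    # Hedged usage sketch (tracker and sequence names are illustrative, not from this
    # file): the returned dict is aligned per timestep, so the similarity matrix at
    # timestep t has shape (num gt dets at t, num tracker dets at t):
    #
    #   raw_data = dataset.get_raw_seq_data("my_tracker", "seq-0001")
    #   for t in range(raw_data["num_timesteps"]):
    #       n_gt = len(raw_data["gt_ids"][t])
    #       n_tk = len(raw_data["tracker_ids"][t])
    #       assert raw_data["similarity_scores"][t].shape == (n_gt, n_tk)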
    @staticmethod
    def _load_simple_text_file(
        file,
        time_col=0,
        id_col=None,
        remove_negative_ids=False,
        valid_filter=None,
        crowd_ignore_filter=None,
        convert_filter=None,
        is_zipped=False,
        zip_file=None,
        force_delimiters=None,
    ):
        """Function that loads data which is in a commonly used text file format.

        Assumes each det is given by one row of a text file.
        There is no limit to the number or meaning of each column,
        however one column needs to give the timestep of each det (time_col), which defaults to column 0.

        The file dialect (delimiter, number of columns, etc.) is determined automatically.
        This function automatically separates dets by timestep,
        and is much faster than alternatives such as np.loadtxt or pandas.

        If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded.
        These are not excluded from ignore data.

        valid_filter can be used to only include certain classes.
        It is a dict with ints as keys and lists as values,
        such that a row is included if "row[key].lower() in value" for all key/value pairs in the dict.
        If None, all classes are included.

        crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid_filter.

        convert_filter can be used to convert a value that is read into another format.
        This is used most commonly to convert classes given as strings to class ids.
        It is a dict such that the key is the column to convert, and the value is another dict giving the mapping.

        Optionally, input files can be a zip of multiple text files for storage efficiency.

        Returns read_data and ignore_data.
        Each is a dict (with timesteps as string keys) of lists (over dets) of lists (over column values).
        Note that all data is returned as strings, and must be converted to float/int later if needed.
        Note that timesteps will not be present in the returned dict keys if there are no dets for them.
        """
        if remove_negative_ids and id_col is None:
            raise TrackEvalException(
                "remove_negative_ids is True, but id_col is not given."
            )
        if crowd_ignore_filter is None:
            crowd_ignore_filter = {}
        if convert_filter is None:
            convert_filter = {}
        try:
            if is_zipped:  # Either open file directly or within a zip.
                if zip_file is None:
                    raise TrackEvalException(
                        "is_zipped set to True, but no zip_file is given."
                    )
                archive = zipfile.ZipFile(zip_file, "r")
                fp = io.TextIOWrapper(archive.open(file, "r"))
            else:
                fp = open(file)
            read_data = {}
            crowd_ignore_data = {}
            fp.seek(0, os.SEEK_END)
            # Check if file is empty.
            if fp.tell():
                fp.seek(0)
                dialect = csv.Sniffer().sniff(
                    fp.readline(), delimiters=force_delimiters
                )  # Auto determine file structure.
                dialect.skipinitialspace = True  # Deal with extra spaces between columns
                fp.seek(0)
                reader = csv.reader(fp, dialect)
                for row in reader:
                    try:
                        # Deal with extra trailing spaces at the end of rows.
                        if row[-1] == "":
                            row = row[:-1]
                        timestep = str(int(float(row[time_col])))
                        # Read ignore regions separately.
                        is_ignored = False
                        for ignore_key, ignore_value in crowd_ignore_filter.items():
                            if row[ignore_key].lower() in ignore_value:
                                # Convert values in one column (e.g. string to id).
                                for convert_key, convert_value in convert_filter.items():
                                    row[convert_key] = convert_value[row[convert_key].lower()]
                                # Save data separated by timestep.
                                if timestep in crowd_ignore_data.keys():
                                    crowd_ignore_data[timestep].append(row)
                                else:
                                    crowd_ignore_data[timestep] = [row]
                                is_ignored = True
                        if is_ignored:  # If det is an ignore region, it cannot be a normal det.
                            continue
                        # Exclude some dets if not valid.
                        if valid_filter is not None:
                            if any(
                                row[key].lower() not in value
                                for key, value in valid_filter.items()
                            ):
                                continue
                        if remove_negative_ids:
                            if int(float(row[id_col])) < 0:
                                continue
                        # Convert values in one column (e.g. string to id).
                        for convert_key, convert_value in convert_filter.items():
                            row[convert_key] = convert_value[row[convert_key].lower()]
                        # Save data separated by timestep.
                        if timestep in read_data.keys():
                            read_data[timestep].append(row)
                        else:
                            read_data[timestep] = [row]
                    except Exception:
                        exc_str_init = (
                            "In file %s the following line cannot be read correctly: \n"
                            % os.path.basename(file)
                        )
                        exc_str = " ".join([exc_str_init] + row)
                        raise TrackEvalException(exc_str)
            fp.close()
        except Exception:
            print("Error loading file: %s, printing traceback." % file)
            traceback.print_exc()
            raise TrackEvalException(
                "File %s cannot be read because it is either not present or invalidly formatted"
                % os.path.basename(file)
            )
        return read_data, crowd_ignore_data
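
    # Hedged usage sketch (the file name, column indices and class strings below are
    # illustrative, not taken from any particular dataset): read a delimited
    # ground-truth file, keep "pedestrian" rows as valid dets, treat "ignore" rows as
    # crowd-ignore regions, and map class strings to id strings via convert_filter:
    #
    #   read_data, ignore_data = _BaseDataset._load_simple_text_file(
    #       "gt.txt",
    #       time_col=0,
    #       id_col=1,
    #       remove_negative_ids=True,
    #       valid_filter={7: ["pedestrian"]},
    #       crowd_ignore_filter={7: ["ignore"]},
    #       convert_filter={7: {"pedestrian": "1", "ignore": "2"}},
    #   )
    #   # read_data["1"] is then a list of rows (lists of strings) for timestep 1.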
    @staticmethod
    def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False):
        """Calculates the IOU (intersection over union) between two arrays of segmentation masks.

        If is_encoded, a run-length encoding with pycocotools is assumed as the input format, otherwise an input of
        numpy arrays of shape (num_masks, height, width) is assumed and the encoding is performed.
        If do_ioa (intersection over area), then calculates the intersection over the area of masks1 - this is
        commonly used to determine if detections are within a crowd ignore region.

        :param masks1: first set of masks (numpy array of shape (num_masks, height, width) if not encoded,
                       else pycocotools rle encoded format)
        :param masks2: second set of masks (numpy array of shape (num_masks, height, width) if not encoded,
                       else pycocotools rle encoded format)
        :param is_encoded: whether the input is in pycocotools rle encoded format
        :param do_ioa: whether to perform IoA computation
        :return: the IoU/IoA scores
        """
        # Only loaded when run to reduce minimum requirements.
        from pycocotools import mask as mask_utils

        # Use pycocotools for run-length encoding of masks.
        if not is_encoded:
            masks1 = mask_utils.encode(
                np.array(np.transpose(masks1, (1, 2, 0)), order="F")
            )
            masks2 = mask_utils.encode(
                np.array(np.transpose(masks2, (1, 2, 0)), order="F")
            )

        # Use pycocotools for IoU computation of rle encoded masks.
        ious = mask_utils.iou(masks1, masks2, [do_ioa] * len(masks2))
        if len(masks1) == 0 or len(masks2) == 0:
            ious = np.asarray(ious).reshape(len(masks1), len(masks2))
        assert (ious >= 0 - np.finfo("float").eps).all()
        assert (ious <= 1 + np.finfo("float").eps).all()
        return ious
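
    # Hedged sketch (the arrays below are made up for illustration): IoU of two 4x4
    # binary masks that each cover a 2x2 square and overlap in a single pixel, so
    # IoU = 1 / (4 + 4 - 1) = 1/7:
    #
    #   m1 = np.zeros((1, 4, 4), dtype=np.uint8); m1[0, 0:2, 0:2] = 1
    #   m2 = np.zeros((1, 4, 4), dtype=np.uint8); m2[0, 1:3, 1:3] = 1
    #   _BaseDataset._calculate_mask_ious(m1, m2)  # approx [[0.143]]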
    @staticmethod
    def _calculate_box_ious(bboxes1, bboxes2, box_format="xywh", do_ioa=False):
        """Calculates the IOU (intersection over union) between two arrays of boxes.

        Allows variable box formats ('xywh' and 'x0y0x1y1').
        If do_ioa (intersection over area), then calculates the intersection over the area of boxes1 - this is
        commonly used to determine if detections are within a crowd ignore region.
        """
        if box_format == "xywh":
            # Layout: (x0, y0, w, h) -> convert to (x0, y0, x1, y1).
            bboxes1 = deepcopy(bboxes1)
            bboxes2 = deepcopy(bboxes2)
            bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2]
            bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3]
            bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2]
            bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3]
        elif box_format != "x0y0x1y1":
            raise TrackEvalException("box_format %s is not implemented" % box_format)

        # Layout: (x0, y0, x1, y1)
        min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
        max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
        intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum(
            min_[..., 3] - max_[..., 1], 0
        )
        area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
            bboxes1[..., 3] - bboxes1[..., 1]
        )

        if do_ioa:
            ioas = np.zeros_like(intersection)
            valid_mask = area1 > 0 + np.finfo("float").eps
            ioas[valid_mask, :] = (
                intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis]
            )
            return ioas
        else:
            area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
                bboxes2[..., 3] - bboxes2[..., 1]
            )
            union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection
            intersection[area1 <= 0 + np.finfo("float").eps, :] = 0
            intersection[:, area2 <= 0 + np.finfo("float").eps] = 0
            intersection[union <= 0 + np.finfo("float").eps] = 0
            union[union <= 0 + np.finfo("float").eps] = 1
            ious = intersection / union
            return ious
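
    # Hedged sketch (box values are made up): two 'xywh' boxes of size 4x4 whose
    # x-extents overlap by 2 pixels give an intersection of 8, so
    # IoU = 8 / (16 + 16 - 8) = 1/3, and IoA with respect to the first box is 8 / 16 = 0.5:
    #
    #   b1 = np.array([[0.0, 0.0, 4.0, 4.0]])
    #   b2 = np.array([[2.0, 0.0, 4.0, 4.0]])
    #   _BaseDataset._calculate_box_ious(b1, b2)               # approx [[0.333]]
    #   _BaseDataset._calculate_box_ious(b1, b2, do_ioa=True)  # [[0.5]]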
    @staticmethod
    def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0):
        """Calculates the Euclidean distance between two sets of detections, and then converts this into a similarity
        measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance).

        The default zero_distance of 2.0 corresponds to the default used in MOT15_3D, such that a 0.5 similarity
        threshold corresponds to a 1m distance threshold for TPs.
        """
        dist = np.linalg.norm(dets1[:, np.newaxis] - dets2[np.newaxis, :], axis=2)
        sim = np.maximum(0, 1 - dist / zero_distance)
        return sim
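
    # Hedged sketch (points are made up): with the default zero_distance of 2.0, a
    # pair of points 1 unit apart gets similarity 1 - 1/2 = 0.5, and any pair 2 or
    # more units apart is clipped to 0:
    #
    #   d1 = np.array([[0.0, 0.0]])
    #   d2 = np.array([[1.0, 0.0], [3.0, 0.0]])
    #   _BaseDataset._calculate_euclidean_similarity(d1, d2)  # [[0.5, 0.0]]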
    @staticmethod
    def _check_unique_ids(data, after_preproc=False):
        """Check the requirement that the tracker_ids and gt_ids are unique per timestep."""
        gt_ids = data["gt_ids"]
        tracker_ids = data["tracker_ids"]
        for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(gt_ids, tracker_ids)):
            if len(tracker_ids_t) > 0:
                unique_ids, counts = np.unique(tracker_ids_t, return_counts=True)
                if np.max(counts) != 1:
                    duplicate_ids = unique_ids[counts > 1]
                    exc_str_init = (
                        "Tracker predicts the same ID more than once in a single timestep "
                        "(seq: %s, frame: %i, ids:" % (data["seq"], t + 1)
                    )
                    exc_str = (
                        " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")"
                    )
                    if after_preproc:
                        exc_str += (
                            "\n Note that this error occurred after preprocessing (but not before), "
                            "so ids may not be as in file, and something seems wrong with preproc."
                        )
                    raise TrackEvalException(exc_str)
            if len(gt_ids_t) > 0:
                unique_ids, counts = np.unique(gt_ids_t, return_counts=True)
                if np.max(counts) != 1:
                    duplicate_ids = unique_ids[counts > 1]
                    exc_str_init = (
                        "Ground-truth has the same ID more than once in a single timestep "
                        "(seq: %s, frame: %i, ids:" % (data["seq"], t + 1)
                    )
                    exc_str = (
                        " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")"
                    )
                    if after_preproc:
                        exc_str += (
                            "\n Note that this error occurred after preprocessing (but not before), "
                            "so ids may not be as in file, and something seems wrong with preproc."
                        )
                    raise TrackEvalException(exc_str)
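
    # Hedged sketch (the dict below is made up): a duplicated tracker id in any
    # timestep raises a TrackEvalException naming the sequence, frame and ids:
    #
    #   data = {
    #       "seq": "seq-0001",
    #       "gt_ids": [np.array([1, 2])],
    #       "tracker_ids": [np.array([5, 5])],   # id 5 predicted twice at t=0
    #   }
    #   _BaseDataset._check_unique_ids(data)  # raises TrackEvalException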