utils.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # flake8: noqa
  2. # pyre-unsafe
  3. import argparse
  4. import csv
  5. import os
  6. from collections import OrderedDict
  7. def init_config(config, default_config, name=None):
  8. """Initialise non-given config values with defaults"""
  9. if config is None:
  10. config = default_config
  11. else:
  12. for k in default_config.keys():
  13. if k not in config.keys():
  14. config[k] = default_config[k]
  15. if name and config["PRINT_CONFIG"]:
  16. print("\n%s Config:" % name)
  17. for c in config.keys():
  18. print("%-20s : %-30s" % (c, config[c]))
  19. return config
  20. def update_config(config):
  21. """
  22. Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
  23. :param config: the config to update
  24. :return: the updated config
  25. """
  26. parser = argparse.ArgumentParser()
  27. for setting in config.keys():
  28. if type(config[setting]) == list or type(config[setting]) == type(None):
  29. parser.add_argument("--" + setting, nargs="+")
  30. else:
  31. parser.add_argument("--" + setting)
  32. args = parser.parse_args().__dict__
  33. for setting in args.keys():
  34. if args[setting] is not None:
  35. if type(config[setting]) == type(True):
  36. if args[setting] == "True":
  37. x = True
  38. elif args[setting] == "False":
  39. x = False
  40. else:
  41. raise Exception(
  42. "Command line parameter " + setting + "must be True or False"
  43. )
  44. elif type(config[setting]) == type(1):
  45. x = int(args[setting])
  46. elif type(args[setting]) == type(None):
  47. x = None
  48. else:
  49. x = args[setting]
  50. config[setting] = x
  51. return config
  52. def get_code_path():
  53. """Get base path where code is"""
  54. return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
  55. def validate_metrics_list(metrics_list):
  56. """Get names of metric class and ensures they are unique, further checks that the fields within each metric class
  57. do not have overlapping names.
  58. """
  59. metric_names = [metric.get_name() for metric in metrics_list]
  60. # check metric names are unique
  61. if len(metric_names) != len(set(metric_names)):
  62. raise TrackEvalException(
  63. "Code being run with multiple metrics of the same name"
  64. )
  65. fields = []
  66. for m in metrics_list:
  67. fields += m.fields
  68. # check metric fields are unique
  69. if len(fields) != len(set(fields)):
  70. raise TrackEvalException(
  71. "Code being run with multiple metrics with fields of the same name"
  72. )
  73. return metric_names
  74. def write_summary_results(summaries, cls, output_folder):
  75. """Write summary results to file"""
  76. fields = sum([list(s.keys()) for s in summaries], [])
  77. values = sum([list(s.values()) for s in summaries], [])
  78. # In order to remain consistent upon new fields being adding, for each of the following fields if they are present
  79. # they will be output in the summary first in the order below. Any further fields will be output in the order each
  80. # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
  81. # randomly (python < 3.6).
  82. default_order = [
  83. "HOTA",
  84. "DetA",
  85. "AssA",
  86. "DetRe",
  87. "DetPr",
  88. "AssRe",
  89. "AssPr",
  90. "LocA",
  91. "OWTA",
  92. "HOTA(0)",
  93. "LocA(0)",
  94. "HOTALocA(0)",
  95. "MOTA",
  96. "MOTP",
  97. "MODA",
  98. "CLR_Re",
  99. "CLR_Pr",
  100. "MTR",
  101. "PTR",
  102. "MLR",
  103. "CLR_TP",
  104. "CLR_FN",
  105. "CLR_FP",
  106. "IDSW",
  107. "MT",
  108. "PT",
  109. "ML",
  110. "Frag",
  111. "sMOTA",
  112. "IDF1",
  113. "IDR",
  114. "IDP",
  115. "IDTP",
  116. "IDFN",
  117. "IDFP",
  118. "Dets",
  119. "GT_Dets",
  120. "IDs",
  121. "GT_IDs",
  122. ]
  123. default_ordered_dict = OrderedDict(
  124. zip(default_order, [None for _ in default_order])
  125. )
  126. for f, v in zip(fields, values):
  127. default_ordered_dict[f] = v
  128. for df in default_order:
  129. if default_ordered_dict[df] is None:
  130. del default_ordered_dict[df]
  131. fields = list(default_ordered_dict.keys())
  132. values = list(default_ordered_dict.values())
  133. out_file = os.path.join(output_folder, cls + "_summary.txt")
  134. os.makedirs(os.path.dirname(out_file), exist_ok=True)
  135. with open(out_file, "w", newline="") as f:
  136. writer = csv.writer(f, delimiter=" ")
  137. writer.writerow(fields)
  138. writer.writerow(values)
  139. def write_detailed_results(details, cls, output_folder):
  140. """Write detailed results to file"""
  141. sequences = details[0].keys()
  142. fields = ["seq"] + sum([list(s["COMBINED_SEQ"].keys()) for s in details], [])
  143. out_file = os.path.join(output_folder, cls + "_detailed.csv")
  144. os.makedirs(os.path.dirname(out_file), exist_ok=True)
  145. with open(out_file, "w", newline="") as f:
  146. writer = csv.writer(f)
  147. writer.writerow(fields)
  148. for seq in sorted(sequences):
  149. if seq == "COMBINED_SEQ":
  150. continue
  151. writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
  152. writer.writerow(
  153. ["COMBINED"] + sum([list(s["COMBINED_SEQ"].values()) for s in details], [])
  154. )
  155. def load_detail(file):
  156. """Loads detailed data for a tracker."""
  157. data = {}
  158. with open(file) as f:
  159. for i, row_text in enumerate(f):
  160. row = row_text.replace("\r", "").replace("\n", "").split(",")
  161. if i == 0:
  162. keys = row[1:]
  163. continue
  164. current_values = row[1:]
  165. seq = row[0]
  166. if seq == "COMBINED":
  167. seq = "COMBINED_SEQ"
  168. if (len(current_values) == len(keys)) and seq != "":
  169. data[seq] = {}
  170. for key, value in zip(keys, current_values):
  171. data[seq][key] = float(value)
  172. return data
  173. class TrackEvalException(Exception):
  174. """Custom exception for catching expected errors."""
  175. ...