| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- # flake8: noqa
- # pyre-unsafe
- import argparse
- import csv
- import os
- from collections import OrderedDict
def init_config(config, default_config, name=None):
    """Fill in any settings missing from ``config`` using ``default_config``.

    :param config: dict of user-supplied settings, or None. When it is a dict it
        is updated in place; keys already present are never overwritten.
    :param default_config: dict of default settings. When ``config`` is None this
        exact object (not a copy) is returned.
    :param name: optional label; when given and the resulting config has a truthy
        ``PRINT_CONFIG`` entry, every setting is printed under this heading.
    :return: the completed config dict.
    :raises KeyError: if ``name`` is given but the resulting config has no
        ``PRINT_CONFIG`` key.
    """
    if config is not None:
        # Collect the missing keys first (in default order), then fill them in.
        missing_keys = [key for key in default_config if key not in config]
        for key in missing_keys:
            config[key] = default_config[key]
    else:
        config = default_config

    should_print = name and config["PRINT_CONFIG"]
    if should_print:
        print("\n%s Config:" % name)
        for key, value in config.items():
            print("%-20s : %-30s" % (key, value))
    return config
def update_config(config, argv=None):
    """
    Parse command-line arguments and override matching config values.

    One ``--<SETTING>`` option is registered for every key in ``config``:

    * list- or None-valued settings accept one or more values (``nargs="+"``)
      and are stored as a list of strings;
    * bool settings must be given exactly as ``True`` or ``False``;
    * int and float settings are converted with ``int()`` / ``float()``;
    * any other setting is stored as the raw string.

    Settings not given on the command line keep their current value.

    :param config: the config dict to update (modified in place)
    :param argv: optional list of argument strings to parse; defaults to
        ``sys.argv[1:]`` (argparse's behavior when None is passed)
    :return: the updated config
    :raises Exception: if a bool setting is given something other than True/False
    :raises ValueError: if an int/float setting cannot be converted
    """
    parser = argparse.ArgumentParser()
    for setting, default in config.items():
        # dest=setting keeps the args key identical to the config key; otherwise
        # argparse would turn '-' into '_' and the lookup below would KeyError.
        if isinstance(default, list) or default is None:
            parser.add_argument("--" + setting, dest=setting, nargs="+")
        else:
            parser.add_argument("--" + setting, dest=setting)
    args = vars(parser.parse_args(argv))

    for setting, value in args.items():
        if value is None:
            continue  # option not supplied: keep the existing value
        default = config[setting]
        # bool must be checked before int, since bool is a subclass of int.
        if isinstance(default, bool):
            if value == "True":
                x = True
            elif value == "False":
                x = False
            else:
                raise Exception(
                    "Command line parameter " + setting + " must be True or False"
                )
        elif isinstance(default, int):
            x = int(value)
        elif isinstance(default, float):
            x = float(value)
        else:
            x = value
        config[setting] = x
    return config
def get_code_path():
    """Return the absolute path of the parent of the directory containing this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.dirname(module_dir)
def validate_metrics_list(metrics_list):
    """Check that metric names, and the fields they produce, are all distinct.

    :param metrics_list: sequence of metric objects, each exposing ``get_name()``
        and an iterable ``fields`` attribute.
    :return: list of metric names, in the order of ``metrics_list``.
    :raises TrackEvalException: if two metrics share a name, or if any field name
        appears more than once across all metrics.
    """
    metric_names = [metric.get_name() for metric in metrics_list]
    # Name check happens before any ``fields`` attribute is read.
    if len(set(metric_names)) < len(metric_names):
        raise TrackEvalException(
            "Code being run with multiple metrics of the same name"
        )

    all_fields = [field for metric in metrics_list for field in metric.fields]
    if len(set(all_fields)) < len(all_fields):
        raise TrackEvalException(
            "Code being run with multiple metrics with fields of the same name"
        )
    return metric_names
def write_summary_results(summaries, cls, output_folder):
    """Write the combined summary of several metric families to a text file.

    The file ``<output_folder>/<cls>_summary.txt`` gets two space-delimited rows:
    the field names, then the corresponding values.

    :param summaries: iterable of mappings (field name -> value), one per metric
        family. If a field appears in several summaries, the last value wins.
    :param cls: class name used as the file-name prefix.
    :param output_folder: directory to write into; created if missing. An empty
        string means the current working directory.
    """
    # In order to remain consistent upon new fields being added, for each of the following fields if they are present
    # they will be output in the summary first in the order below. Any further fields will be output in the order each
    # metric family is called, and within each family in the order they appear in that family's mapping.
    default_order = [
        "HOTA",
        "DetA",
        "AssA",
        "DetRe",
        "DetPr",
        "AssRe",
        "AssPr",
        "LocA",
        "OWTA",
        "HOTA(0)",
        "LocA(0)",
        "HOTALocA(0)",
        "MOTA",
        "MOTP",
        "MODA",
        "CLR_Re",
        "CLR_Pr",
        "MTR",
        "PTR",
        "MLR",
        "CLR_TP",
        "CLR_FN",
        "CLR_FP",
        "IDSW",
        "MT",
        "PT",
        "ML",
        "Frag",
        "sMOTA",
        "IDF1",
        "IDR",
        "IDP",
        "IDTP",
        "IDFN",
        "IDFP",
        "Dets",
        "GT_Dets",
        "IDs",
        "GT_IDs",
    ]

    # Merge all families: a field keeps the position of its first occurrence and
    # the value of its last occurrence.
    merged = OrderedDict()
    for summary in summaries:
        merged.update(summary)

    # Presence is decided by membership, not by a None placeholder, so a field
    # whose actual value is None is not silently dropped.
    known = set(default_order)
    fields = [f for f in default_order if f in merged]
    fields += [f for f in merged if f not in known]
    values = [merged[f] for f in fields]

    out_file = os.path.join(output_folder, cls + "_summary.txt")
    out_dir = os.path.dirname(out_file)
    if out_dir:  # os.makedirs("") raises FileNotFoundError
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, "w", newline="") as f:
        writer = csv.writer(f, delimiter=" ")
        writer.writerow(fields)
        writer.writerow(values)
def write_detailed_results(details, cls, output_folder):
    """Write per-sequence results of several metric families to a CSV file.

    The file ``<output_folder>/<cls>_detailed.csv`` has a header row
    ``seq,<fields...>``, one row per sequence (sorted by name), and a final
    ``COMBINED`` row taken from each family's ``COMBINED_SEQ`` entry.

    :param details: non-empty sequence of dicts, one per metric family, each
        mapping sequence name -> (field name -> value). The sequences listed are
        those of ``details[0]``; field names are taken from ``COMBINED_SEQ``.
    :param cls: class name used as the file-name prefix.
    :param output_folder: directory to write into; created if missing. An empty
        string means the current working directory.
    """
    sequences = details[0].keys()

    def _values_for(seq):
        # Concatenate this sequence's values across all families, in family order.
        return [value for family in details for value in family[seq].values()]

    fields = ["seq"] + [
        field for family in details for field in family["COMBINED_SEQ"].keys()
    ]
    out_file = os.path.join(output_folder, cls + "_detailed.csv")
    out_dir = os.path.dirname(out_file)
    if out_dir:  # os.makedirs("") raises FileNotFoundError
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(fields)
        for seq in sorted(sequences):
            if seq == "COMBINED_SEQ":
                continue  # written last, under the label "COMBINED"
            writer.writerow([seq] + _values_for(seq))
        writer.writerow(["COMBINED"] + _values_for("COMBINED_SEQ"))
def load_detail(file):
    """Load the detailed per-sequence results for a tracker from a CSV file.

    The first row is the header; its first column (the sequence label) is
    ignored and the remaining columns name the fields. Every following row whose
    value count matches the header and whose sequence name is non-empty becomes
    an entry ``data[seq] = {field: float(value)}``. The ``COMBINED`` row is
    stored under ``COMBINED_SEQ``. Rows with a mismatched column count and blank
    rows are skipped.

    Parsing uses the ``csv`` module (the same module ``write_detailed_results``
    writes with), so quoted cells, e.g. names containing commas, round-trip.

    :param file: path to the CSV file.
    :return: dict mapping sequence name -> dict of field name -> float; empty
        for an empty file.
    :raises ValueError: if a kept cell cannot be converted to float.
    """
    data = {}
    with open(file, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            return data
        keys = header[1:]
        for row in reader:
            if not row:
                continue  # blank line
            seq = row[0]
            if seq == "COMBINED":
                seq = "COMBINED_SEQ"
            current_values = row[1:]
            if len(current_values) == len(keys) and seq != "":
                data[seq] = {
                    key: float(value) for key, value in zip(keys, current_values)
                }
    return data
class TrackEvalException(Exception):
    """Exception for expected evaluation errors that callers may want to catch,
    such as the duplicate metric-name / field-name checks in
    ``validate_metrics_list``."""

    pass
|