config.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # fmt: off
  2. # flake8: noqa
  3. # pyre-unsafe
  4. """Config."""
  5. import argparse
  6. import os
  7. def parse_configs():
  8. """Parse command line."""
  9. default_eval_config = get_default_eval_config()
  10. default_eval_config["DISPLAY_LESS_PROGRESS"] = True
  11. default_dataset_config = get_default_dataset_config()
  12. default_metrics_config = {"METRICS": ["TETA"]}
  13. config = {
  14. **default_eval_config,
  15. **default_dataset_config,
  16. **default_metrics_config,
  17. }
  18. parser = argparse.ArgumentParser()
  19. for setting in config.keys():
  20. if type(config[setting]) == list or type(config[setting]) == type(None):
  21. parser.add_argument("--" + setting, nargs="+")
  22. else:
  23. parser.add_argument("--" + setting)
  24. args = parser.parse_args().__dict__
  25. for setting in args.keys():
  26. if args[setting] is not None:
  27. if type(config[setting]) == type(True):
  28. if args[setting] == "True":
  29. x = True
  30. elif args[setting] == "False":
  31. x = False
  32. else:
  33. raise Exception(
  34. f"Command line parameter {setting} must be True/False"
  35. )
  36. elif type(config[setting]) == type(1):
  37. x = int(args[setting])
  38. elif type(args[setting]) == type(None):
  39. x = None
  40. else:
  41. x = args[setting]
  42. config[setting] = x
  43. eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
  44. dataset_config = {
  45. k: v for k, v in config.items() if k in default_dataset_config.keys()
  46. }
  47. metrics_config = {
  48. k: v for k, v in config.items() if k in default_metrics_config.keys()
  49. }
  50. return eval_config, dataset_config, metrics_config
  51. def get_default_eval_config():
  52. """Returns the default config values for evaluation."""
  53. code_path = get_code_path()
  54. default_config = {
  55. "USE_PARALLEL": True,
  56. "NUM_PARALLEL_CORES": 8,
  57. "BREAK_ON_ERROR": True,
  58. "RETURN_ON_ERROR": False,
  59. "LOG_ON_ERROR": os.path.join(code_path, "error_log.txt"),
  60. "PRINT_RESULTS": True,
  61. "PRINT_ONLY_COMBINED": True,
  62. "PRINT_CONFIG": True,
  63. "TIME_PROGRESS": True,
  64. "DISPLAY_LESS_PROGRESS": True,
  65. "OUTPUT_SUMMARY": True,
  66. "OUTPUT_EMPTY_CLASSES": True,
  67. "OUTPUT_TEM_RAW_DATA": True,
  68. "OUTPUT_PER_SEQ_RES": True,
  69. }
  70. return default_config
  71. def get_default_dataset_config():
  72. """Default class config values"""
  73. code_path = get_code_path()
  74. default_config = {
  75. "GT_FOLDER": os.path.join(
  76. code_path, "data/gt/tao/tao_training"
  77. ), # Location of GT data
  78. "TRACKERS_FOLDER": os.path.join(
  79. code_path, "data/trackers/tao/tao_training"
  80. ), # Trackers location
  81. "OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
  82. "TRACKERS_TO_EVAL": ['TETer'], # Filenames of trackers to eval (if None, all in folder)
  83. "CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes)
  84. "SPLIT_TO_EVAL": "training", # Valid: 'training', 'val'
  85. "PRINT_CONFIG": True, # Whether to print current config
  86. "TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
  87. "OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
  88. "TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
  89. "MAX_DETECTIONS": 0, # Number of maximal allowed detections per image (0 for unlimited)
  90. "USE_MASK": False, # Whether to use mask data for evaluation
  91. }
  92. return default_config
  93. def init_config(config, default_config, name=None):
  94. """Initialize non-given config values with defaults."""
  95. if config is None:
  96. config = default_config
  97. else:
  98. for k in default_config.keys():
  99. if k not in config.keys():
  100. config[k] = default_config[k]
  101. if name and config["PRINT_CONFIG"]:
  102. print("\n%s Config:" % name)
  103. for c in config.keys():
  104. print("%-20s : %-30s" % (c, config[c]))
  105. return config
  106. def update_config(config):
  107. """
  108. Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
  109. :param config: the config to update
  110. :return: the updated config
  111. """
  112. parser = argparse.ArgumentParser()
  113. for setting in config.keys():
  114. if type(config[setting]) == list or type(config[setting]) == type(None):
  115. parser.add_argument("--" + setting, nargs="+")
  116. else:
  117. parser.add_argument("--" + setting)
  118. args = parser.parse_args().__dict__
  119. for setting in args.keys():
  120. if args[setting] is not None:
  121. if type(config[setting]) == type(True):
  122. if args[setting] == "True":
  123. x = True
  124. elif args[setting] == "False":
  125. x = False
  126. else:
  127. raise Exception(
  128. "Command line parameter " + setting + "must be True or False"
  129. )
  130. elif type(config[setting]) == type(1):
  131. x = int(args[setting])
  132. elif type(args[setting]) == type(None):
  133. x = None
  134. else:
  135. x = args[setting]
  136. config[setting] = x
  137. return config
  138. def get_code_path():
  139. """Get base path where code is"""
  140. return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))