# logger.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py

import atexit
import functools
import logging
import sys
import uuid
from typing import Any, Dict, Optional, Union

from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

from training.utils.train_utils import get_machine_local_and_dist_rank, makedir

Scalar = Union[Tensor, ndarray, int, float]


def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
    makedir(log_dir)
    summary_writer_method = SummaryWriter
    return TensorBoardLogger(
        path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
    )
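
# Example usage (a minimal sketch; the directory and kwargs are illustrative,
# and any extra keyword arguments are forwarded to SummaryWriter):
#
#   tb_logger = make_tensorboard_logger("./tb_logs", flush_secs=30)
#   tb_logger.log("train/loss", 0.25, step=100)
#   tb_logger.close()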


class TensorBoardWriterWrapper:
    """
    A wrapper around a SummaryWriter object.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: Optional[str] = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Create a new TensorBoard logger.
        On construction, the logger creates a new events file that logs
        will be written to. If the environment variable `RANK` is defined,
        the logger will only write logs when RANK == 0.

        NOTE: If using the logger with distributed training:
        - This logger can call collective operations.
        - Logs will be written on rank 0 only.
        - The logger must be constructed synchronously *after* initializing
          the distributed process group.

        Args:
            path (str): path to write logs to
            *args, **kwargs: Extra arguments to pass to SummaryWriter
        """
        self._writer: Optional[SummaryWriter] = None
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path
        if self._rank == 0:
            logging.info(
                f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
            )
            self._writer = summary_writer_method(
                log_dir=path,
                *args,
                filename_suffix=filename_suffix or str(uuid.uuid4()),
                **kwargs,
            )
        else:
            logging.debug(
                f"Not logging meters on this host because env RANK: {self._rank} != 0"
            )
        atexit.register(self.close)

    @property
    def writer(self) -> Optional[SummaryWriter]:
        return self._writer

    @property
    def path(self) -> str:
        return self._path

    def flush(self) -> None:
        """Write pending logs to disk."""
        if not self._writer:
            return
        self._writer.flush()

    def close(self) -> None:
        """Close the writer, flushing pending logs to disk.
        Logs cannot be written after `close` is called.
        """
        if not self._writer:
            return
        self._writer.close()
        self._writer = None


class TensorBoardLogger(TensorBoardWriterWrapper):
    """
    A simple logger for TensorBoard.
    """

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Add multiple scalar values to TensorBoard.

        Args:
            payload (dict): dictionary of tag names and scalar values
            step (int): step value to record
        """
        if not self._writer:
            return
        for k, v in payload.items():
            self.log(k, v, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Add scalar data to TensorBoard.

        Args:
            name (str): tag name used to group scalars
            data (float/int/Tensor): scalar data to log
            step (int): step value to record
        """
        if not self._writer:
            return
        # new_style=True writes the scalar with the newer tensor-based proto,
        # which TensorBoard can load faster than the legacy simple_value field.
        self._writer.add_scalar(name, data, global_step=step, new_style=True)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        """Add hyperparameter data to TensorBoard.

        Args:
            hparams (dict): dictionary of hyperparameter names and corresponding values
            meters (dict): dictionary of meter names and corresponding values
        """
        if not self._writer:
            return
        self._writer.add_hparams(hparams, meters)
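
# Example (a minimal sketch; the tag and metric names are illustrative).
# `log_hparams` records a run's hyperparameters alongside final metrics, so
# runs can be compared in TensorBoard's HParams tab:
#
#   tb_logger.log_hparams(
#       {"lr": 1e-3, "batch_size": 32},
#       {"hparams/val_accuracy": 0.91},
#   )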


class Logger:
    """
    A logger class that can interface with multiple loggers. It currently
    supports only TensorBoard for simplicity, but it can be extended with
    your own loggers.
    """

    def __init__(self, logging_conf):
        # allow turning off TensorBoard with "should_log: false" in the config
        tb_config = logging_conf.tensorboard_writer
        tb_should_log = tb_config and tb_config.pop("should_log", True)
        self.tb_logger = instantiate(tb_config) if tb_should_log else None

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        if self.tb_logger:
            self.tb_logger.log_dict(payload, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        if self.tb_logger:
            self.tb_logger.log(name, data, step)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        if self.tb_logger:
            self.tb_logger.log_hparams(hparams, meters)
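
# Example Hydra config consumed by `Logger` (a sketch; everything other than
# the `tensorboard_writer` and `should_log` keys is illustrative):
#
#   logging:
#     tensorboard_writer:
#       _target_: training.utils.logger.make_tensorboard_logger
#       log_dir: ${scratch.log_dir}/tensorboard
#       should_log: true
#
# `should_log` is popped before instantiation; `hydra.utils.instantiate` then
# resolves `_target_` and passes the remaining keys as keyword arguments, so
# any extra keys end up as SummaryWriter kwargs.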


# cache the opened file object, so that different calls to `setup_logging`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    # we tune the buffering value so that the logs are updated frequently
    # (note that the `buffering` argument is in bytes)
    log_buffer_kb = 10 * 1024  # 10 KiB
    io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
    atexit.register(io.close)
    return io
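
# Because of the lru_cache above, repeated calls such as
# setup_logging("train", output_dir=d) and setup_logging("eval", output_dir=d)
# reuse the same open handle for d/log.txt instead of opening the file twice
# and interleaving partially buffered writes.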


def setup_logging(
    name,
    output_dir=None,
    rank=0,
    log_level_primary="INFO",
    log_level_secondary="ERROR",
):
    """
    Set up the logging streams: a stdout handler on every rank and, on the
    master GPU (rank 0), a file handler as well.
    """
    # get the filename if we want to log to the file as well
    log_filename = None
    if output_dir:
        makedir(output_dir)
        if rank == 0:
            log_filename = f"{output_dir}/log.txt"

    logger = logging.getLogger(name)
    logger.setLevel(log_level_primary)

    # create formatter
    FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
    formatter = logging.Formatter(FORMAT)

    # clean up any existing handlers
    for h in logger.handlers:
        logger.removeHandler(h)
    logger.root.handlers = []

    # set up the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    if rank == 0:
        console_handler.setLevel(log_level_primary)
    else:
        console_handler.setLevel(log_level_secondary)

    # we log to file as well if the user wants it
    if log_filename and rank == 0:
        file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
        file_handler.setLevel(log_level_primary)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logging.root = logger
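
# Example (a sketch; in distributed training the rank would normally come from
# the process group, e.g. torch.distributed.get_rank()):
#
#   setup_logging(__name__, output_dir="/tmp/exp", rank=0)
#   logging.info("hello")  # goes to stdout, and to /tmp/exp/log.txt on rank 0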


def shutdown_logging():
    """
    After training is done, ensure that all the logger streams are shut down.
    """
    logging.info("Shutting down loggers...")
    handlers = logging.root.handlers
    for handler in handlers:
        handler.close()