# memmon.py
import threading
import time
from collections import defaultdict

import torch
  5. class MemUsageMonitor(threading.Thread):
  6. run_flag = None
  7. device = None
  8. disabled = False
  9. opts = None
  10. data = None
  11. def __init__(self, name, device, opts):
  12. threading.Thread.__init__(self)
  13. self.name = name
  14. self.device = device
  15. self.opts = opts
  16. self.daemon = True
  17. self.run_flag = threading.Event()
  18. self.data = defaultdict(int)
  19. try:
  20. self.cuda_mem_get_info()
  21. torch.cuda.memory_stats(self.device)
  22. except Exception as e: # AMD or whatever
  23. print(f"Warning: caught exception '{e}', memory monitor disabled")
  24. self.disabled = True
  25. def cuda_mem_get_info(self):
  26. index = self.device.index if self.device.index is not None else torch.cuda.current_device()
  27. return torch.cuda.mem_get_info(index)
  28. def run(self):
  29. if self.disabled:
  30. return
  31. while True:
  32. self.run_flag.wait()
  33. torch.cuda.reset_peak_memory_stats()
  34. self.data.clear()
  35. if self.opts.memmon_poll_rate <= 0:
  36. self.run_flag.clear()
  37. continue
  38. self.data["min_free"] = self.cuda_mem_get_info()[0]
  39. while self.run_flag.is_set():
  40. free, total = self.cuda_mem_get_info()
  41. self.data["min_free"] = min(self.data["min_free"], free)
  42. time.sleep(1 / self.opts.memmon_poll_rate)
  43. def dump_debug(self):
  44. print(self, 'recorded data:')
  45. for k, v in self.read().items():
  46. print(k, -(v // -(1024 ** 2)))
  47. print(self, 'raw torch memory stats:')
  48. tm = torch.cuda.memory_stats(self.device)
  49. for k, v in tm.items():
  50. if 'bytes' not in k:
  51. continue
  52. print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
  53. print(torch.cuda.memory_summary())
  54. def monitor(self):
  55. self.run_flag.set()
  56. def read(self):
  57. if not self.disabled:
  58. free, total = self.cuda_mem_get_info()
  59. self.data["free"] = free
  60. self.data["total"] = total
  61. torch_stats = torch.cuda.memory_stats(self.device)
  62. self.data["active"] = torch_stats["active.all.current"]
  63. self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
  64. self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
  65. self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
  66. self.data["system_peak"] = total - self.data["min_free"]
  67. return self.data
  68. def stop(self):
  69. self.run_flag.clear()
  70. return self.read()