import threading
import time
from collections import defaultdict

import torch
class MemUsageMonitor(threading.Thread):
    """Daemon thread that polls free CUDA memory during a measurement window.

    Lifecycle: construct, ``start()`` the thread, call ``monitor()`` to open a
    measurement window, then ``stop()`` (or ``read()``) to collect results.
    During a window the thread records the minimum observed free memory so
    that peak system-wide usage (``system_peak``) can be derived.  If CUDA
    statistics are unavailable (e.g. AMD hardware, CPU-only build), the
    monitor disables itself at construction time instead of raising.
    """

    run_flag = None   # threading.Event; set while a measurement window is active
    device = None     # torch.device being monitored
    disabled = False  # True when CUDA memory stats are unavailable
    opts = None       # options object exposing memmon_poll_rate (polls per second)
    data = None       # defaultdict(int) of collected measurements, in bytes

    def __init__(self, name, device, opts):
        """Create the monitor for *device*; probe CUDA support once up front.

        name: thread name; device: torch.device to watch; opts: settings
        object providing ``memmon_poll_rate``.
        """
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        try:
            # Probe both APIs once so unsupported backends disable the
            # monitor immediately rather than crashing the polling thread.
            self.cuda_mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return ``(free, total)`` bytes for the monitored device.

        torch.cuda.mem_get_info requires a concrete index, so fall back to
        the current device when ``self.device`` has no explicit index.
        """
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        return torch.cuda.mem_get_info(index)

    def run(self):
        """Thread body: wait for a window, then poll free memory until cleared."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # Fix: reset stats for the monitored device explicitly — with no
            # argument this resets whichever device is current on THIS
            # thread, which need not be self.device.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                # Polling disabled: close the window without sampling.
                self.run_flag.clear()
                continue

            self.data["min_free"] = self.cuda_mem_get_info()[0]

            while self.run_flag.is_set():
                free, total = self.cuda_mem_get_info()
                self.data["min_free"] = min(self.data["min_free"], free)

                # Re-read the poll rate each pass and guard against it being
                # changed to <= 0 mid-window, which would otherwise raise
                # ZeroDivisionError and silently kill this daemon thread.
                poll_rate = self.opts.memmon_poll_rate
                if poll_rate <= 0:
                    self.run_flag.clear()
                    break

                time.sleep(1 / poll_rate)

    def dump_debug(self):
        """Print collected data and raw torch memory stats (MiB, rounded up)."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            # -(v // -x) is ceiling division: round bytes up to whole MiB.
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Open a measurement window: wake the polling thread."""
        self.run_flag.set()

    def read(self):
        """Return the measurement dict, refreshed with current device stats.

        When disabled, returns the (empty) data dict unchanged.
        """
        if not self.disabled:
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            self.data["active"] = torch_stats["active.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """Close the measurement window and return the collected data."""
        self.run_flag.clear()
        return self.read()