Add some error handling for VRAM monitor

commit fabaf4bddb (parent 7e77938230)
Author: EyeDeck
Date:   2022-09-18 05:20:33 -04:00

2 changed files with 29 additions and 17 deletions


```diff
@@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag = threading.Event()
         self.data = defaultdict(int)
 
+        try:
+            torch.cuda.mem_get_info()
+            torch.cuda.memory_stats(self.device)
+        except Exception as e:  # AMD or whatever
+            print(f"Warning: caught exception '{e}', memory monitor disabled")
+            self.disabled = True
+
     def run(self):
         if self.disabled:
             return
```
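The change probes both CUDA memory calls once in `__init__` and flags the monitor as disabled if either raises, so `run()` can bail out early instead of crashing later on backends (e.g. AMD builds) where those APIs are unavailable. A minimal standalone sketch of the same probe-and-disable pattern (the function name is illustrative, not part of the commit):

```python
import torch

def cuda_memstats_available(device=None) -> bool:
    """Return False on backends where the CUDA memory APIs raise."""
    try:
        torch.cuda.mem_get_info()        # -> (free_bytes, total_bytes)
        torch.cuda.memory_stats(device)  # detailed allocator counters
        return True
    except Exception as e:  # broad catch on purpose: degrade, don't crash
        print(f"Warning: caught exception '{e}', memory monitor disabled")
        return False
```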
```diff
@@ -62,13 +69,14 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag.set()
 
     def read(self):
-        free, total = torch.cuda.mem_get_info()
-        self.data["total"] = total
+        if not self.disabled:
+            free, total = torch.cuda.mem_get_info()
+            self.data["total"] = total
 
-        torch_stats = torch.cuda.memory_stats(self.device)
-        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
-        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
-        self.data["system_peak"] = total - self.data["min_free"]
+            torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+            self.data["system_peak"] = total - self.data["min_free"]
 
         return self.data
```
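Note that `read()` still returns `self.data` either way; because the dict is created as `defaultdict(int)` in `__init__`, callers indexing the usual keys on a disabled monitor get 0 rather than a KeyError. A tiny sketch of that degraded path (a standalone mock, not the commit's code):

```python
from collections import defaultdict

data = defaultdict(int)  # as in __init__
disabled = True          # probe failed, e.g. no CUDA memory APIs

def read():
    if not disabled:
        data["total"] = 8 * 1024**3  # dummy value; CUDA queries would land here
    return data

stats = read()
print(stats["total"], stats["active_peak"])  # prints 0 0: safe defaults, no KeyError
```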