# GLEN-model/src/tevatron/utils/gpu_monitor.py
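"""GPU memory monitoring utility.

Provides a lightweight GPUMemoryMonitor that periodically checks CUDA memory
usage during training, warns when usage exceeds a configurable threshold, and
can clear the PyTorch caching allocator's cache.
"""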
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


class GPUMemoryMonitor:
    """Periodically checks CUDA memory usage during training and warns when it
    exceeds a configurable threshold. Monitoring is disabled automatically when
    CUDA is not available."""

    def __init__(self,
                 memory_threshold: float = 0.9,  # warn above 90% of GPU memory
                 check_interval: int = 100,      # check every 100 steps
                 gpu_id: Optional[int] = None):
        self.memory_threshold = memory_threshold
        self.check_interval = check_interval
        self.gpu_id = gpu_id if gpu_id is not None else 0
        self.step_count = 0

        if not torch.cuda.is_available():
            logger.warning("CUDA is not available. GPU monitoring will be disabled.")
            self.enabled = False
        else:
            self.enabled = True
            self.device = torch.device(f"cuda:{self.gpu_id}")

    def check_memory(self) -> bool:
        """Return True if GPU memory usage is below the threshold (or the check
        is skipped this step), False if the threshold is exceeded."""
        if not self.enabled:
            return True

        self.step_count += 1
        if self.step_count % self.check_interval != 0:
            return True

        try:
            # Get GPU memory info for the monitored device
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_reserved = torch.cuda.memory_reserved(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory

            # Calculate memory usage ratio (allocated tensors vs. total device memory)
            memory_ratio = memory_allocated / memory_total
            if memory_ratio > self.memory_threshold:
                logger.warning(
                    f"GPU memory usage ({memory_ratio:.2%}) exceeds threshold "
                    f"({self.memory_threshold:.2%}); "
                    f"reserved by allocator: {memory_reserved / 1024**3:.2f} GB"
                )
                return False
            return True
        except Exception as e:
            logger.error(f"Error checking GPU memory: {e}")
            # Fail open: do not interrupt training because the check itself failed
            return True

    def clear_memory(self):
        """Clear the GPU memory cache (frees cached blocks, not live tensors)."""
        if self.enabled:
            torch.cuda.empty_cache()

    def get_memory_stats(self) -> dict:
        """Get current GPU memory statistics."""
        if not self.enabled:
            return {"enabled": False}

        try:
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_reserved = torch.cuda.memory_reserved(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory
            return {
                "enabled": True,
                "allocated_gb": memory_allocated / 1024**3,
                "reserved_gb": memory_reserved / 1024**3,
                "total_gb": memory_total / 1024**3,
                "usage_ratio": memory_allocated / memory_total,
            }
        except Exception as e:
            logger.error(f"Error getting GPU memory stats: {e}")
            return {"enabled": False, "error": str(e)}