import torch
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class GPUMemoryMonitor:
    """Monitor GPU memory usage and warn when it exceeds a threshold.

    Intended to be polled once per training step via :meth:`check_memory`.
    When CUDA is unavailable the monitor disables itself and every check
    reports OK, so callers need no special-casing on CPU-only machines.
    """

    def __init__(self,
                 memory_threshold: float = 0.9,  # warn above this fraction of total GPU memory
                 check_interval: int = 100,      # run the real CUDA query every N calls
                 gpu_id: Optional[int] = None):
        self.memory_threshold = memory_threshold
        self.check_interval = check_interval
        # Default to device 0 when no explicit GPU id is given.
        self.gpu_id = gpu_id if gpu_id is not None else 0
        self.step_count = 0

        if not torch.cuda.is_available():
            logger.warning("CUDA is not available. GPU monitoring will be disabled.")
            self.enabled = False
        else:
            self.enabled = True
            self.device = torch.device(f"cuda:{self.gpu_id}")

    def check_memory(self) -> bool:
        """Return False when GPU memory usage exceeds the threshold.

        The (relatively costly) CUDA query runs only every
        ``check_interval`` calls; all other calls return True immediately.
        Also returns True when monitoring is disabled or the query fails,
        so monitoring can never abort training on its own.
        """
        if not self.enabled:
            return True

        self.step_count += 1
        if self.step_count % self.check_interval != 0:
            return True

        try:
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory

            # NOTE(review): ratio is based on memory_allocated; the allocator's
            # reserved (cached) memory can be higher — confirm allocated is the
            # intended measure for the threshold.
            memory_ratio = memory_allocated / memory_total

            if memory_ratio > self.memory_threshold:
                # Lazy %-args so formatting only happens when the record is emitted.
                logger.warning(
                    "GPU memory usage (%.2f%%) exceeds threshold (%.2f%%)",
                    memory_ratio * 100,
                    self.memory_threshold * 100,
                )
                return False
            return True
        except Exception:
            # Best-effort by design: log with traceback and report OK rather
            # than letting a monitoring failure look like memory pressure.
            logger.exception("Error checking GPU memory")
            return True

    def clear_memory(self):
        """Release cached GPU memory back to the driver (no-op when disabled)."""
        if self.enabled:
            torch.cuda.empty_cache()

    def get_memory_stats(self) -> dict:
        """Return current GPU memory statistics.

        Returns a dict with keys ``enabled``, ``allocated_gb``,
        ``reserved_gb``, ``total_gb`` and ``usage_ratio`` when monitoring is
        active; ``{"enabled": False}`` when disabled, plus an ``error`` key
        if the CUDA query fails.
        """
        if not self.enabled:
            return {"enabled": False}

        try:
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_reserved = torch.cuda.memory_reserved(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory

            gib = 1024 ** 3
            return {
                "enabled": True,
                "allocated_gb": memory_allocated / gib,
                "reserved_gb": memory_reserved / gib,
                "total_gb": memory_total / gib,
                "usage_ratio": memory_allocated / memory_total,
            }
        except Exception as e:
            logger.exception("Error getting GPU memory stats")
            return {"enabled": False, "error": str(e)}