import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


class GPUMemoryMonitor:
    """Monitor GPU memory usage and warn when it exceeds a threshold."""

    def __init__(self,
                 memory_threshold: float = 0.9,
                 check_interval: int = 100,
                 gpu_id: Optional[int] = None):
        """
        Args:
            memory_threshold: Fraction of total GPU memory (0-1) above which
                check_memory() logs a warning and returns False.
            check_interval: Number of calls to check_memory() between actual
                memory checks.
            gpu_id: CUDA device index to monitor. Defaults to device 0.
        """
        self.memory_threshold = memory_threshold
        self.check_interval = check_interval
        self.gpu_id = gpu_id if gpu_id is not None else 0
        self.step_count = 0

        if not torch.cuda.is_available():
            logger.warning("CUDA is not available. GPU monitoring will be disabled.")
            self.enabled = False
        else:
            self.enabled = True
            self.device = torch.device(f"cuda:{self.gpu_id}")

    def check_memory(self) -> bool:
        """Check whether GPU memory usage is below the threshold.

        Returns False when allocated memory exceeds the threshold; otherwise
        True (including when monitoring is disabled or the check is skipped).
        """
        if not self.enabled:
            return True

        self.step_count += 1
        # Only query CUDA every check_interval steps to keep overhead low.
        if self.step_count % self.check_interval != 0:
            return True

        try:
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory

            memory_ratio = memory_allocated / memory_total

            if memory_ratio > self.memory_threshold:
                logger.warning(
                    f"GPU memory usage ({memory_ratio:.2%}) exceeds "
                    f"threshold ({self.memory_threshold:.2%})"
                )
                return False

            return True

        except Exception as e:
            # Fail open: a monitoring error should not interrupt training.
            logger.error(f"Error checking GPU memory: {e}")
            return True

    def clear_memory(self):
        """Release unused cached GPU memory back to the driver.

        Note that torch.cuda.empty_cache() only frees blocks held by the
        caching allocator; tensors that are still referenced are unaffected.
        """
        if self.enabled:
            torch.cuda.empty_cache()

    def get_memory_stats(self) -> dict:
        """Get current GPU memory statistics (sizes in GiB)."""
        if not self.enabled:
            return {"enabled": False}

        try:
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_reserved = torch.cuda.memory_reserved(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory

            return {
                "enabled": True,
                "allocated_gb": memory_allocated / 1024**3,
                "reserved_gb": memory_reserved / 1024**3,
                "total_gb": memory_total / 1024**3,
                "usage_ratio": memory_allocated / memory_total,
            }
        except Exception as e:
            logger.error(f"Error getting GPU memory stats: {e}")
            return {"enabled": False, "error": str(e)}