import torch
import logging
from typing import Optional

logger = logging.getLogger(__name__)

class GPUMemoryMonitor:
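    """Lightweight CUDA memory monitor for training loops.

    Checks the allocated/total memory ratio every ``check_interval`` calls and
    warns when it exceeds ``memory_threshold``. Becomes a no-op when CUDA is
    unavailable.
    """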
    def __init__(self, 
                 memory_threshold: float = 0.9,  # 90% of GPU memory
                 check_interval: int = 100,      # Check every 100 steps
                 gpu_id: Optional[int] = None):
        self.memory_threshold = memory_threshold
        self.check_interval = check_interval
        self.gpu_id = gpu_id if gpu_id is not None else 0
        self.step_count = 0
        
        if not torch.cuda.is_available():
            logger.warning("CUDA is not available. GPU monitoring will be disabled.")
            self.enabled = False
        else:
            self.enabled = True
            self.device = torch.device(f"cuda:{self.gpu_id}")
            
    def check_memory(self) -> bool:
        """Check if GPU memory usage is below threshold"""
        if not self.enabled:
            return True
            
        self.step_count += 1
        if self.step_count % self.check_interval != 0:
            return True
            
        try:
            # Query allocated memory and total device capacity. The usage
            # ratio is based on memory allocated to tensors; memory reserved
            # by the caching allocator (torch.cuda.memory_reserved) can be
            # higher and is reported separately by get_memory_stats().
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory
            memory_ratio = memory_allocated / memory_total
            
            if memory_ratio > self.memory_threshold:
                logger.warning(f"GPU memory usage ({memory_ratio:.2%}) exceeds threshold ({self.memory_threshold:.2%})")
                return False
                
            return True
            
        except Exception as e:
            logger.error(f"Error checking GPU memory: {str(e)}")
            return True
            
    def clear_memory(self):
        """Clear GPU memory cache"""
        if self.enabled:
            torch.cuda.empty_cache()
            
    def get_memory_stats(self) -> dict:
        """Get current GPU memory statistics"""
        if not self.enabled:
            return {"enabled": False}
            
        try:
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_reserved = torch.cuda.memory_reserved(self.device)
            memory_total = torch.cuda.get_device_properties(self.device).total_memory
            
            return {
                "enabled": True,
                "allocated_gb": memory_allocated / 1024**3,
                "reserved_gb": memory_reserved / 1024**3,
                "total_gb": memory_total / 1024**3,
                "usage_ratio": memory_allocated / memory_total
            }
        except Exception as e:
            logger.error(f"Error getting GPU memory stats: {str(e)}")
            return {"enabled": False, "error": str(e)}