"""
GPU Manager for SPARKNET
Handles GPU allocation, monitoring, and resource management
"""

import os
import torch
from typing import Any, Dict, List, Optional
from contextlib import contextmanager
import pynvml
from loguru import logger


class GPUManager:
    """Manages GPU resources for model deployment and monitoring."""

    def __init__(self, primary_gpu: int = 0, fallback_gpus: Optional[List[int]] = None):
        """
        Initialize GPU Manager.

        Args:
            primary_gpu: Primary GPU device ID (default: 0)
            fallback_gpus: List of fallback GPU IDs (default: [1, 2, 3])
        """
        self.primary_gpu = primary_gpu
        self.fallback_gpus = fallback_gpus or [1, 2, 3]
        self.initialized = False

        # Initialize NVML for GPU monitoring
        try:
            pynvml.nvmlInit()
            self.initialized = True
            logger.info("GPU Manager initialized with NVML")
        except Exception as e:
            logger.warning(f"Failed to initialize NVML: {e}")

        # Detect available GPUs
        self.available_gpus = self._detect_gpus()
        logger.info(f"Detected {len(self.available_gpus)} GPUs: {self.available_gpus}")

    def _detect_gpus(self) -> List[int]:
        """Detect available CUDA GPUs."""
        if not torch.cuda.is_available():
            logger.warning("CUDA not available!")
            return []

        gpu_count = torch.cuda.device_count()
        return list(range(gpu_count))

    def get_gpu_info(self, gpu_id: int) -> Dict[str, Any]:
        """
        Get detailed information about a GPU.

        Args:
            gpu_id: GPU device ID

        Returns:
            Dictionary with GPU information
        """
        if not self.initialized:
            return {"error": "NVML not initialized"}

        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
            temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            name = pynvml.nvmlDeviceGetName(handle)
            # Older pynvml releases return bytes here; normalize to str
            if isinstance(name, bytes):
                name = name.decode("utf-8")

            return {
                "gpu_id": gpu_id,
                "name": name,
                "memory_total": mem_info.total,
                "memory_used": mem_info.used,
                "memory_free": mem_info.free,
                "memory_percent": (mem_info.used / mem_info.total) * 100,
                "gpu_utilization": utilization.gpu,
                "memory_utilization": utilization.memory,
                "temperature": temperature,
            }
        except Exception as e:
            logger.error(f"Error getting GPU {gpu_id} info: {e}")
            return {"error": str(e)}

    def get_all_gpu_info(self) -> List[Dict[str, Any]]:
        """Get information for all available GPUs."""
        return [self.get_gpu_info(gpu_id) for gpu_id in self.available_gpus]

    def get_free_memory(self, gpu_id: int) -> int:
        """
        Get free memory on a GPU in bytes.

        Args:
            gpu_id: GPU device ID

        Returns:
            Free memory in bytes
        """
        info = self.get_gpu_info(gpu_id)
        return info.get("memory_free", 0)

    def select_best_gpu(self, min_memory_gb: float = 8.0) -> Optional[int]:
        """
        Select the best available GPU based on free memory.

        Args:
            min_memory_gb: Minimum required free memory in GB

        Returns:
            GPU ID or None if no suitable GPU found
        """
        min_memory_bytes = min_memory_gb * 1024 ** 3

        # Try primary GPU first
        if self.primary_gpu in self.available_gpus:
            free_mem = self.get_free_memory(self.primary_gpu)
            if free_mem >= min_memory_bytes:
                logger.info(f"Selected primary GPU {self.primary_gpu} ({free_mem / 1024**3:.2f} GB free)")
                return self.primary_gpu

        # Try fallback GPUs
        for gpu_id in self.fallback_gpus:
            if gpu_id in self.available_gpus:
                free_mem = self.get_free_memory(gpu_id)
                if free_mem >= min_memory_bytes:
                    logger.info(f"Selected fallback GPU {gpu_id} ({free_mem / 1024**3:.2f} GB free)")
                    return gpu_id

        logger.warning(f"No GPU found with {min_memory_gb} GB free memory")
        return None

    def set_device(self, gpu_id: int):
        """
        Set the CUDA device.

        Args:
            gpu_id: GPU device ID
        """
        if gpu_id not in self.available_gpus:
            raise ValueError(f"GPU {gpu_id} not available")

        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        torch.cuda.set_device(gpu_id)
        logger.info(f"Set CUDA device to GPU {gpu_id}")

    @contextmanager
    def gpu_context(self, gpu_id: Optional[int] = None, min_memory_gb: float = 8.0):
        """
        Context manager for GPU allocation.

        Args:
            gpu_id: Specific GPU ID or None for auto-selection
            min_memory_gb: Minimum required memory in GB

        Yields:
            GPU device ID
        """
        # Select GPU
        if gpu_id is None:
            gpu_id = self.select_best_gpu(min_memory_gb)
            if gpu_id is None:
                raise RuntimeError("No suitable GPU available")

        # Store the original CUDA_VISIBLE_DEVICES value (None if it was unset)
        original_device = os.environ.get("CUDA_VISIBLE_DEVICES")

        try:
            self.set_device(gpu_id)
            yield gpu_id
        finally:
            # Restore the original CUDA_VISIBLE_DEVICES, removing it if it was unset
            if original_device is not None:
                os.environ["CUDA_VISIBLE_DEVICES"] = original_device
            else:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared CUDA cache")

    def clear_cache(self, gpu_id: Optional[int] = None):
        """
        Clear CUDA cache for a specific GPU or all GPUs.

        Args:
            gpu_id: GPU device ID or None for all GPUs
        """
        if not torch.cuda.is_available():
            logger.warning("CUDA not available; no cache to clear")
            return

        if gpu_id is not None:
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
            logger.info(f"Cleared cache for GPU {gpu_id}")
        else:
            torch.cuda.empty_cache()
            logger.info("Cleared cache for all GPUs")

    def monitor(self) -> str:
        """
        Get a formatted monitoring string for all GPUs.

        Returns:
            Formatted string with GPU status
        """
        info_list = self.get_all_gpu_info()

        lines = ["GPU Status:"]
        for info in info_list:
            if "error" in info:
                lines.append(f"  GPU {info.get('gpu_id', '?')}: Error - {info['error']}")
            else:
                lines.append(
                    f"  GPU {info['gpu_id']}: {info['name']} | "
                    f"Memory: {info['memory_used'] / 1024**3:.2f}/{info['memory_total'] / 1024**3:.2f} GB "
                    f"({info['memory_percent']:.1f}%) | "
                    f"Utilization: {info['gpu_utilization']}% | "
                    f"Temp: {info['temperature']}°C"
                )

        return "\n".join(lines)

    def __del__(self):
        """Cleanup NVML on deletion."""
        if self.initialized:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass


# Global GPU manager instance
_gpu_manager: Optional[GPUManager] = None


def get_gpu_manager() -> GPUManager:
    """Get or create the global GPU manager instance."""
    global _gpu_manager
    if _gpu_manager is None:
        _gpu_manager = GPUManager()
    return _gpu_manager
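

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the SPARKNET API
    # surface): print GPU status, then allocate a device through the context
    # manager. Assumes at least one CUDA GPU meeting the (hypothetical)
    # 4 GB threshold below; adjust min_memory_gb for your hardware.
    manager = get_gpu_manager()
    print(manager.monitor())

    try:
        with manager.gpu_context(min_memory_gb=4.0) as device_id:
            # Tensors created here would typically be placed on the selected GPU
            x = torch.zeros(1, device=f"cuda:{device_id}")
            print(f"Allocated test tensor on GPU {device_id}: {x.device}")
    except RuntimeError as e:
        print(f"No suitable GPU available: {e}")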