"""
Memory Manager for BackgroundFX Pro
- Safe on CPU/CUDA/MPS (mostly CUDA/T4 on Spaces)
- Accepts `device` as str or torch.device
- Optional per-process VRAM cap (env or method)
- Detailed usage reporting (CPU/RAM + VRAM + torch allocator)
- Light and aggressive cleanup paths
- Background monitor (optional)

Env switches:
  BFX_DISABLE_LIMIT=1    -> do not set the VRAM fraction automatically
  BFX_CUDA_FRACTION=0.80 -> per-process VRAM cap fraction (clamped to 0.10..0.95)
"""

from __future__ import annotations

import gc
import logging
import os
import threading
import time
from typing import Any, Callable, Dict, Optional

# psutil is optional; without it, system RAM/swap/process stats are skipped.
try:
    import psutil
except Exception:
    psutil = None

# torch is optional; without it, the manager degrades to CPU-only behavior.
try:
    import torch
except Exception:
    torch = None

logger = logging.getLogger(__name__)


class MemoryManagerError(Exception):
    """Raised for unrecoverable memory-management failures."""


def _bytes_to_gb(x: int | float) -> float:
    """Convert a byte count to GiB, returning 0.0 on bad input."""
    try:
        return float(x) / (1024**3)
    except Exception:
        return 0.0


def _normalize_device(dev) -> "torch.device":
    """Coerce a str / torch.device / unknown object into a torch.device."""
    if torch is None:
        # Without torch, return a minimal stand-in exposing .type and .index.
        class _Fake:
            type = "cpu"
            index = None

        return _Fake()

    if isinstance(dev, str):
        return torch.device(dev)
    if hasattr(dev, "type"):
        return dev
    # Unknown object: fall back to CPU.
    return torch.device("cpu")


def _cuda_index(device) -> Optional[int]:
    """Return the CUDA device index, or None for non-CUDA devices."""
    if getattr(device, "type", "cpu") != "cuda":
        return None
    idx = getattr(device, "index", None)
    if idx is None:
        # torch.device("cuda") carries no index; default to device 0.
        return 0
    return int(idx)


class MemoryManager:
    """
    Comprehensive memory management with VRAM cap + cleanup utilities.
    """

    def __init__(self, device, memory_limit_gb: Optional[float] = None):
        self.device = _normalize_device(device)
        self.device_type = getattr(self.device, "type", "cpu")
        self.cuda_idx = _cuda_index(self.device)

        self.gpu_available = bool(
            torch and self.device_type == "cuda" and torch.cuda.is_available()
        )
        self.mps_available = bool(
            torch and self.device_type == "mps"
            and getattr(torch.backends, "mps", None)
            and torch.backends.mps.is_available()
        )

        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks: list[Callable] = []
        self.monitoring_active = False
        self.monitoring_thread: Optional[threading.Thread] = None
        self.stats = {
            "cleanup_count": 0,
            "peak_memory_usage": 0.0,
            "total_allocated": 0.0,
            "total_freed": 0.0,
        }
        self.applied_fraction: Optional[float] = None

        self._initialize_memory_limits()
        self._maybe_apply_vram_fraction()
        logger.info(f"MemoryManager initialized (device={self.device}, cuda={self.gpu_available})")

    def _initialize_memory_limits(self):
        """Pick a conservative default memory budget for the active device."""
        try:
            if self.gpu_available:
                props = torch.cuda.get_device_properties(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(props.total_memory)
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.80)
                logger.info(
                    f"CUDA memory limit baseline ~{self.memory_limit_gb:.1f}GB "
                    f"(device total {total_gb:.1f}GB)"
                )
            elif self.mps_available:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.50)
                logger.info(f"MPS memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
            else:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.60)
                logger.info(f"CPU memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
        except Exception as e:
            logger.warning(f"Memory limit init failed: {e}")
            if self.memory_limit_gb is None:
                self.memory_limit_gb = 4.0
|
    def _maybe_apply_vram_fraction(self):
        if not self.gpu_available or torch is None:
            return
        # Honor the documented opt-out (BFX_DISABLE_LIMIT=1); a bare truthiness
        # check would also treat "0" as an opt-out.
        if os.environ.get("BFX_DISABLE_LIMIT", "").strip().lower() in {"1", "true", "yes"}:
            return
        frac_env = os.environ.get("BFX_CUDA_FRACTION", "").strip()
        try:
            fraction = float(frac_env) if frac_env else 0.80
        except Exception:
            fraction = 0.80
        applied = self.limit_cuda_memory(fraction=fraction)
        if applied:
            logger.info(f"Per-process CUDA memory fraction set to {applied:.2f} on device {self.cuda_idx or 0}")
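
    # Example: launching with `BFX_CUDA_FRACTION=0.70 python app.py` caps this
    # process at roughly 70% of device VRAM; `BFX_DISABLE_LIMIT=1` skips the cap.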

    def get_memory_usage(self) -> Dict[str, Any]:
        usage: Dict[str, Any] = {
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "timestamp": time.time(),
        }

        # System RAM / swap / process stats (requires psutil).
        if psutil:
            try:
                vm = psutil.virtual_memory()
                usage.update(
                    dict(
                        system_total_gb=round(_bytes_to_gb(vm.total), 3),
                        system_available_gb=round(_bytes_to_gb(vm.available), 3),
                        system_used_gb=round(_bytes_to_gb(vm.used), 3),
                        system_percent=float(vm.percent),
                    )
                )
                swap = psutil.swap_memory()
                usage.update(
                    dict(
                        swap_total_gb=round(_bytes_to_gb(swap.total), 3),
                        swap_used_gb=round(_bytes_to_gb(swap.used), 3),
                        swap_percent=float(swap.percent),
                    )
                )
                proc = psutil.Process()
                mi = proc.memory_info()
                usage.update(
                    dict(
                        process_rss_gb=round(_bytes_to_gb(mi.rss), 3),
                        process_vms_gb=round(_bytes_to_gb(mi.vms), 3),
                    )
                )
            except Exception as e:
                logger.debug(f"psutil stats error: {e}")

        # Device-wide VRAM stats (driver view).
        if self.gpu_available and torch is not None:
            try:
                free_b, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                used_b = total_b - free_b
                usage.update(
                    dict(
                        vram_total_gb=round(_bytes_to_gb(total_b), 3),
                        vram_used_gb=round(_bytes_to_gb(used_b), 3),
                        vram_free_gb=round(_bytes_to_gb(free_b), 3),
                        vram_used_percent=float(used_b / total_b * 100.0) if total_b else 0.0,
                    )
                )
            except Exception as e:
                logger.debug(f"mem_get_info failed: {e}")

            # Torch caching-allocator stats (this process only).
            try:
                idx = self.cuda_idx or 0
                allocated = torch.cuda.memory_allocated(idx)
                reserved = torch.cuda.memory_reserved(idx)
                usage["torch_allocated_gb"] = round(_bytes_to_gb(allocated), 3)
                usage["torch_reserved_gb"] = round(_bytes_to_gb(reserved), 3)
                try:
                    inactive = torch.cuda.memory_stats(idx).get("inactive_split_bytes.all.current", 0)
                    usage["torch_inactive_split_gb"] = round(_bytes_to_gb(inactive), 3)
                except Exception:
                    pass
            except Exception as e:
                logger.debug(f"allocator stats failed: {e}")

        usage["applied_fraction"] = self.applied_fraction

        # Track the peak of whichever figure is most relevant for this device.
        current = usage.get("vram_used_gb", usage.get("system_used_gb", 0.0))
        try:
            if float(current) > float(self.stats["peak_memory_usage"]):
                self.stats["peak_memory_usage"] = float(current)
        except Exception:
            pass

        return usage

    def limit_cuda_memory(self, fraction: Optional[float] = None, max_gb: Optional[float] = None) -> Optional[float]:
        """Cap this process's VRAM use; returns the applied fraction, or None on failure."""
        if not self.gpu_available or torch is None:
            return None

        # Convert an absolute cap in GB into a fraction of total device VRAM.
        if max_gb is not None:
            try:
                _, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(total_b)
                if total_gb <= 0:
                    return None
                fraction = min(max(0.10, max_gb / total_gb), 0.95)
            except Exception as e:
                logger.debug(f"fraction from max_gb failed: {e}")
                return None

        if fraction is None:
            fraction = 0.80
        fraction = float(max(0.10, min(0.95, fraction)))

        try:
            torch.cuda.set_per_process_memory_fraction(fraction, device=self.cuda_idx or 0)
            self.applied_fraction = fraction
            return fraction
        except Exception as e:
            logger.debug(f"set_per_process_memory_fraction failed: {e}")
            return None
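
    # Worked example: on a card reporting 16 GB total, limit_cuda_memory(max_gb=8)
    # resolves to a fraction of 8/16 = 0.50 (then clamped to the 0.10..0.95 range).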

    def cleanup(self) -> None:
        """Light cleanup used frequently between steps."""
        try:
            gc.collect()
        except Exception:
            pass
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        self.stats["cleanup_count"] += 1

    def cleanup_basic(self) -> None:
        """Alias kept for compatibility."""
        self.cleanup()

    def cleanup_aggressive(self) -> None:
        """Aggressive cleanup for OOM recovery or big scene switches."""
        # Let registered owners drop their references before caches are flushed.
        for cb in list(self.cleanup_callbacks):
            try:
                cb()
            except Exception as e:
                logger.debug(f"cleanup callback failed: {e}")
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.synchronize(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.reset_peak_memory_stats(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                if hasattr(torch.cuda, "ipc_collect"):
                    torch.cuda.ipc_collect()
            except Exception:
                pass
        try:
            gc.collect()
            gc.collect()
        except Exception:
            pass
        self.stats["cleanup_count"] += 1

    def register_cleanup_callback(self, callback: Callable):
        """Register a zero-argument callable run at the start of aggressive cleanup."""
        self.cleanup_callbacks.append(callback)

    def start_monitoring(self, interval_seconds: float = 30.0, pressure_callback: Optional[Callable] = None):
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return
        self.monitoring_active = True

        def loop():
            while self.monitoring_active:
                try:
                    pressure = self.check_memory_pressure()
                    if pressure["under_pressure"]:
                        logger.warning(
                            f"Memory pressure: {pressure['pressure_level']} "
                            f"({pressure['usage_percent']:.1f}%)"
                        )
                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")
                        if pressure["pressure_level"] == "critical":
                            self.cleanup_aggressive()
                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")
                time.sleep(interval_seconds)

        self.monitoring_thread = threading.Thread(target=loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        if self.monitoring_active:
            self.monitoring_active = False
            if self.monitoring_thread and self.monitoring_thread.is_alive():
                self.monitoring_thread.join(timeout=5.0)
            logger.info("Memory monitoring stopped")
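
    # Sketch of typical monitor wiring (the lambda is illustrative, not required):
    #   mm.start_monitoring(15.0, pressure_callback=lambda p: mm.cleanup())
    #   ...
    #   mm.stop_monitoring()  # joins the daemon thread before shutdown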

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        usage = self.get_memory_usage()
        info = {
            "under_pressure": False,
            "pressure_level": "normal",
            "usage_percent": 0.0,
            "recommendations": [],
        }

        if self.gpu_available:
            percent = usage.get("vram_used_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Run aggressive memory cleanup",
                        "Reduce frame cache / chunk size",
                        "Lower resolution or disable previews",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor memory usage",
                        "Reduce keyframe interval",
                    ]
        else:
            percent = usage.get("system_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Close other processes",
                        "Reduce resolution",
                        "Split video into chunks",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor usage",
                        "Reduce processing footprint",
                    ]
        return info

    def estimate_memory_requirement(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, float]:
        # Assumes 8-bit RGB frames; the multiplier covers intermediate buffers.
        bytes_per_frame = video_width * video_height * 3
        overhead_multiplier = 3.0
        frames_gb = _bytes_to_gb(bytes_per_frame * frames_in_memory * overhead_multiplier)
        estimate = {
            "frames_memory_gb": round(frames_gb, 3),
            "model_memory_gb": 4.0,
            "system_overhead_gb": 2.0,
        }
        estimate["total_estimated_gb"] = round(
            estimate["frames_memory_gb"] + estimate["model_memory_gb"] + estimate["system_overhead_gb"], 3
        )
        return estimate
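
    # Worked example: at 1920x1080, one RGB frame is 1920*1080*3 ~= 6.2 MB, so
    # five frames with the 3.0x overhead come to ~0.087 GB; adding the fixed
    # model (4.0 GB) and system (2.0 GB) terms gives a total of ~6.087 GB.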

    def can_process_video(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, Any]:
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        usage = self.get_memory_usage()
        if self.gpu_available:
            available = usage.get("vram_free_gb", 0.0)
        else:
            available = usage.get("system_available_gb", 0.0)

        can = estimate["total_estimated_gb"] <= available
        return {
            "can_process": can,
            "estimated_memory_gb": estimate["total_estimated_gb"],
            "available_memory_gb": available,
            "memory_margin_gb": round(available - estimate["total_estimated_gb"], 3),
            "recommendations": [] if can else [
                "Reduce resolution or duration",
                "Process in smaller chunks",
                "Run aggressive cleanup before start",
            ],
        }

    def get_stats(self) -> Dict[str, Any]:
        return {
            "cleanup_count": self.stats["cleanup_count"],
            "peak_memory_usage_gb": self.stats["peak_memory_usage"],
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "applied_fraction": self.applied_fraction,
            "monitoring_active": self.monitoring_active,
            "callbacks_registered": len(self.cleanup_callbacks),
        }

    def __del__(self):
        # Best-effort teardown; torch may already be gone at interpreter shutdown.
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass
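

# Minimal usage sketch, assuming default settings; the numbers printed depend
# entirely on the host, so treat this as illustrative rather than a test.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    mm = MemoryManager("cuda" if torch is not None and torch.cuda.is_available() else "cpu")

    # Feasibility check before committing to a 1080p job.
    verdict = mm.can_process_video(1920, 1080, frames_in_memory=5)
    print("can_process:", verdict["can_process"], "margin_gb:", verdict["memory_margin_gb"])

    # Current usage snapshot and pressure level.
    usage = mm.get_memory_usage()
    print("usage_gb:", {k: v for k, v in usage.items() if k.endswith("_gb")})
    print("pressure:", mm.check_memory_pressure()["pressure_level"])

    # Light cleanup between steps, then final stats.
    mm.cleanup()
    print("stats:", mm.get_stats())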
|