# SPARKNET / src/utils/config.py
# Committed by MHamdan: "Initial commit: SPARKNET framework" (a9dc537)
"""
Configuration management for SPARKNET
Handles loading and validation of configuration files
"""
import yaml
from pathlib import Path
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from loguru import logger
class GPUConfig(BaseModel):
    """GPU device selection: a primary device, fallbacks, and a per-model memory cap."""

    # Device ID used first.
    primary: int = Field(0, description="Primary GPU device ID")
    # Pydantic copies non-hashable defaults for each instance, so this list
    # literal is not shared between GPUConfig objects.
    fallback: list[int] = Field([1, 2, 3], description="Fallback GPU device IDs")
    # Stored as a plain string such as "8GB"; no unit parsing happens here.
    max_memory_per_model: str = Field("8GB", description="Max memory per model")
class OllamaConfig(BaseModel):
    """Connection settings for the Ollama server and the model it uses by default."""

    # Address of the server.
    host: str = Field("localhost", description="Ollama server host")
    port: int = Field(11434, description="Ollama server port")
    # Model name including its tag.
    default_model: str = Field("llama3.2:latest", description="Default model")
    # Expressed in seconds.
    timeout: int = Field(300, description="Request timeout in seconds")
class MemoryConfig(BaseModel):
    """Settings for the memory subsystem: vector store, embeddings, and persistence."""

    # Name of the vector store backend.
    vector_store: str = Field("chromadb", description="Vector store backend")
    embedding_model: str = Field("nomic-embed-text:latest", description="Embedding model")
    max_context_length: int = Field(4096, description="Max context length")
    # Relative paths resolve against the process working directory.
    persist_directory: str = Field(
        "./data/memory",
        description="Memory persistence directory",
    )
class WorkflowConfig(BaseModel):
    """Task execution limits: parallelism, per-task timeout, and retry count."""

    max_parallel_tasks: int = Field(5, description="Max parallel tasks")
    # Expressed in seconds.
    task_timeout: int = Field(600, description="Task timeout in seconds")
    retry_attempts: int = Field(3, description="Retry attempts for failed tasks")
class LoggingConfig(BaseModel):
    """Logging level plus file-sink location, rotation, and retention settings."""

    level: str = Field("INFO", description="Logging level")
    # Optional, so a config may set this to None.
    log_file: Optional[str] = Field("./logs/sparknet.log", description="Log file path")
    # Rotation and retention are kept as human-readable strings.
    rotation: str = Field("100 MB", description="Log rotation size")
    retention: str = Field("7 days", description="Log retention period")
class SparknetConfig(BaseModel):
    """
    Root SPARKNET configuration, made up of one sub-model per section.

    Each section is built with ``default_factory``, so a section missing
    from the input is filled with that section's own defaults.
    """

    # Hardware placement.
    gpu: GPUConfig = Field(default_factory=GPUConfig)
    # Model server connection.
    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
    # Vector memory.
    memory: MemoryConfig = Field(default_factory=MemoryConfig)
    # Task execution limits.
    workflow: WorkflowConfig = Field(default_factory=WorkflowConfig)
    # Log output.
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
def load_config(config_path: Optional[Path] = None) -> SparknetConfig:
    """
    Load configuration from a YAML file, or fall back to defaults.

    Args:
        config_path: Path to the YAML configuration file. A plain ``str`` is
            also accepted. If it is ``None``/empty or the file does not
            exist, the built-in defaults are used.

    Returns:
        SparknetConfig instance. Sections and fields absent from the file
        keep their default values. An empty file yields the defaults.

    Raises:
        TypeError: If the YAML document's top level is not a mapping.
        yaml.YAMLError: If the file is not valid YAML.
        pydantic.ValidationError: If a value fails model validation.
    """
    if not config_path:
        logger.info("Using default configuration")
        return SparknetConfig()

    # Accept str as well as Path; str has no .exists().
    config_path = Path(config_path)
    if not config_path.exists():
        # Keep the fallback-to-defaults behavior, but an explicitly requested
        # file that is missing is usually a mistake, so make it visible.
        logger.warning(
            f"Configuration file {config_path} not found; using default configuration"
        )
        return SparknetConfig()

    logger.info(f"Loading configuration from {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)

    # safe_load returns None for an empty or comment-only document.
    if config_data is None:
        config_data = {}
    if not isinstance(config_data, dict):
        raise TypeError(
            f"Configuration file {config_path} must contain a YAML mapping at the "
            f"top level, got {type(config_data).__name__}"
        )
    return SparknetConfig(**config_data)
def save_config(config: SparknetConfig, config_path: Path):
    """
    Save configuration to a YAML file.

    Missing parent directories are created and an existing file is
    overwritten.

    Args:
        config: SparknetConfig instance to serialize.
        config_path: Destination file path (a plain ``str`` is also accepted).
    """
    # Accept str as well as Path; str has no .parent.
    config_path = Path(config_path)
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        # safe_dump emits only standard YAML tags, so the file can always be
        # read back with yaml.safe_load (which load_config uses).
        # sort_keys=False keeps the model's field declaration order instead of
        # alphabetizing sections and fields.
        yaml.safe_dump(
            config.model_dump(),
            f,
            default_flow_style=False,
            sort_keys=False,
        )
    logger.info(f"Configuration saved to {config_path}")
# Process-wide configuration. None means "not loaded yet"; it is populated
# lazily by get_config() or explicitly by set_config().
_config: Optional[SparknetConfig] = None
def get_config() -> SparknetConfig:
    """
    Return the global configuration, creating it on first access.

    The first call builds it with ``load_config()`` (no path, so defaults);
    every later call returns that same cached instance.
    """
    global _config
    if _config is not None:
        return _config
    _config = load_config()
    return _config
def set_config(config: SparknetConfig):
    """
    Install ``config`` as the global configuration instance.

    Later ``get_config()`` calls return this object instead of loading
    defaults.

    Args:
        config: Configuration object to make global.
    """
    global _config
    # Plain rebinding; the previous instance (if any) is simply dropped.
    _config = config