"""
Configuration management for SPARKNET.

Handles loading and validation of configuration files.
"""

from pathlib import Path
from typing import Any, Dict, Optional, Union

import yaml
from loguru import logger
from pydantic import BaseModel, Field


class GPUConfig(BaseModel):
    """GPU configuration."""

    primary: int = Field(default=0, description="Primary GPU device ID")
    fallback: list[int] = Field(default=[1, 2, 3], description="Fallback GPU device IDs")
    max_memory_per_model: str = Field(default="8GB", description="Max memory per model")


class OllamaConfig(BaseModel):
    """Ollama configuration."""

    host: str = Field(default="localhost", description="Ollama server host")
    port: int = Field(default=11434, description="Ollama server port")
    default_model: str = Field(default="llama3.2:latest", description="Default model")
    timeout: int = Field(default=300, description="Request timeout in seconds")


class MemoryConfig(BaseModel):
    """Memory configuration."""

    vector_store: str = Field(default="chromadb", description="Vector store backend")
    embedding_model: str = Field(
        default="nomic-embed-text:latest", description="Embedding model"
    )
    max_context_length: int = Field(default=4096, description="Max context length")
    persist_directory: str = Field(
        default="./data/memory", description="Memory persistence directory"
    )


class WorkflowConfig(BaseModel):
    """Workflow configuration."""

    max_parallel_tasks: int = Field(default=5, description="Max parallel tasks")
    task_timeout: int = Field(default=600, description="Task timeout in seconds")
    retry_attempts: int = Field(default=3, description="Retry attempts for failed tasks")


class LoggingConfig(BaseModel):
    """Logging configuration."""

    level: str = Field(default="INFO", description="Logging level")
    log_file: Optional[str] = Field(default="./logs/sparknet.log", description="Log file path")
    rotation: str = Field(default="100 MB", description="Log rotation size")
    retention: str = Field(default="7 days", description="Log retention period")


class SparknetConfig(BaseModel):
    """Main SPARKNET configuration.

    Each section is built from its own model's defaults when it is not
    supplied, so a partial configuration file is valid.
    """

    gpu: GPUConfig = Field(default_factory=GPUConfig)
    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
    memory: MemoryConfig = Field(default_factory=MemoryConfig)
    workflow: WorkflowConfig = Field(default_factory=WorkflowConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)


def load_config(config_path: Optional[Union[str, Path]] = None) -> SparknetConfig:
    """
    Load configuration from YAML file or use defaults.

    Args:
        config_path: Path to the YAML configuration file (``str`` or ``Path``).
            If it is ``None``/empty or the file does not exist, the built-in
            defaults are returned instead.

    Returns:
        SparknetConfig instance. An empty YAML file yields the defaults.

    Raises:
        TypeError: If the YAML document's top level is not a mapping.
        yaml.YAMLError: If the file cannot be parsed as YAML.
        pydantic.ValidationError: If a value fails model validation.
    """
    # Falsy check (not `is None`) preserves the original behavior of treating
    # an empty string as "no path given".
    if not config_path:
        logger.info("Using default configuration")
        return SparknetConfig()

    config_path = Path(config_path)
    if not config_path.exists():
        # An explicit path that is missing is more likely a mistake than an
        # intentional request for defaults, so make it visible.
        logger.warning(
            f"Configuration file {config_path} not found; using default configuration"
        )
        return SparknetConfig()

    logger.info(f"Loading configuration from {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)

    # safe_load returns None for an empty or comment-only document; treat
    # that as "no overrides" rather than crashing on `**None`.
    if config_data is None:
        config_data = {}
    elif not isinstance(config_data, dict):
        raise TypeError(
            f"Configuration file {config_path} must contain a YAML mapping at the "
            f"top level, got {type(config_data).__name__}"
        )

    return SparknetConfig(**config_data)


def save_config(config: SparknetConfig, config_path: Union[str, Path]) -> None:
    """
    Save configuration to YAML file.

    Parent directories are created if they do not already exist, and an
    existing file at ``config_path`` is overwritten.

    Args:
        config: SparknetConfig instance
        config_path: Path to save configuration (``str`` or ``Path``)
    """
    config_path = Path(config_path)
    config_path.parent.mkdir(parents=True, exist_ok=True)

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config.model_dump(), f, default_flow_style=False)

    logger.info(f"Configuration saved to {config_path}")


# Global configuration instance, created lazily by get_config().
_config: Optional[SparknetConfig] = None


def get_config() -> SparknetConfig:
    """Get or create the global configuration instance.

    On first call this loads the default configuration via ``load_config()``
    (no file path); subsequent calls return the cached instance.
    """
    global _config
    if _config is None:
        _config = load_config()
    return _config


def set_config(config: SparknetConfig) -> None:
    """Set the global configuration instance returned by ``get_config()``."""
    global _config
    _config = config