| """ |
| Configuration management for SPARKNET |
| Handles loading and validation of configuration files |
| """ |
|
|
| import yaml |
| from pathlib import Path |
| from typing import Dict, Any, Optional |
| from pydantic import BaseModel, Field |
| from loguru import logger |
|
|
|
|
class GPUConfig(BaseModel):
    """GPU configuration: device selection and per-model memory limit."""

    primary: int = Field(default=0, description="Primary GPU device ID")
    # default_factory builds a fresh list for every instance rather than
    # relying on a shared list literal captured at class-definition time.
    fallback: list[int] = Field(
        default_factory=lambda: [1, 2, 3],
        description="Fallback GPU device IDs",
    )
    # Kept as a free-form string (e.g. "8GB"); it is not parsed or validated here.
    max_memory_per_model: str = Field(default="8GB", description="Max memory per model")
|
|
|
|
class OllamaConfig(BaseModel):
    """Connection and request settings for the Ollama server."""

    # Server location.
    host: str = Field("localhost", description="Ollama server host")
    port: int = Field(11434, description="Ollama server port")

    # Model used when a caller does not request a specific one.
    default_model: str = Field("llama3.2:latest", description="Default model")

    # Per-request timeout, in seconds.
    timeout: int = Field(300, description="Request timeout in seconds")
|
|
|
|
class MemoryConfig(BaseModel):
    """Settings for the memory subsystem: vector store, embeddings, persistence."""

    vector_store: str = Field("chromadb", description="Vector store backend")
    embedding_model: str = Field("nomic-embed-text:latest", description="Embedding model")
    max_context_length: int = Field(4096, description="Max context length")
    # Relative paths are resolved against the process working directory.
    persist_directory: str = Field(
        "./data/memory",
        description="Memory persistence directory",
    )
|
|
|
|
class WorkflowConfig(BaseModel):
    """Limits governing task execution within a workflow."""

    # Concurrency cap.
    max_parallel_tasks: int = Field(5, description="Max parallel tasks")
    # Per-task time budget, in seconds.
    task_timeout: int = Field(600, description="Task timeout in seconds")
    # How many times a failed task may be retried.
    retry_attempts: int = Field(3, description="Retry attempts for failed tasks")
|
|
|
|
class LoggingConfig(BaseModel):
    """Log level, destination file, and rotation/retention policy."""

    level: str = Field("INFO", description="Logging level")
    # Optional: None is an accepted value (presumably meaning "no log file";
    # confirm against the code that consumes this setting).
    log_file: Optional[str] = Field("./logs/sparknet.log", description="Log file path")
    # Both values are human-readable strings, e.g. "100 MB" / "7 days".
    rotation: str = Field("100 MB", description="Log rotation size")
    retention: str = Field("7 days", description="Log retention period")
|
|
|
|
class SparknetConfig(BaseModel):
    """
    Root SPARKNET configuration, composed of one section per subsystem.

    Each section is created from its own defaults (via default_factory) when
    it is absent from the input data.
    """

    gpu: GPUConfig = Field(default_factory=GPUConfig)
    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
    memory: MemoryConfig = Field(default_factory=MemoryConfig)
    workflow: WorkflowConfig = Field(default_factory=WorkflowConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
|
|
|
|
def load_config(config_path: Optional[Path] = None) -> SparknetConfig:
    """
    Load configuration from a YAML file, falling back to defaults.

    Args:
        config_path: Path to the YAML configuration file; a plain ``str`` is
            also accepted. When ``None`` (or empty) the built-in defaults are
            used. When the path does not exist, a warning is logged and the
            defaults are used.

    Returns:
        SparknetConfig instance. Sections and keys missing from the file take
        their model defaults; an empty file yields the defaults.

    Raises:
        yaml.YAMLError: If the file is not valid YAML.
        pydantic.ValidationError: If the top-level YAML value is not a mapping
            or a value fails model validation (a subclass of ValueError).
    """
    if not config_path:
        logger.info("Using default configuration")
        return SparknetConfig()

    # Accept str as well as Path; str has no .exists().
    config_path = Path(config_path)
    if not config_path.exists():
        # An explicitly requested file that is missing is worth surfacing,
        # rather than silently looking identical to the no-path case.
        logger.warning(
            f"Configuration file {config_path} not found; using default configuration"
        )
        return SparknetConfig()

    logger.info(f"Loading configuration from {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)

    # safe_load returns None for an empty (or comment-only) file.
    if config_data is None:
        config_data = {}

    # model_validate reports non-mapping input as a ValidationError instead
    # of the opaque TypeError raised by ``**config_data``.
    return SparknetConfig.model_validate(config_data)
|
|
|
|
def save_config(config: SparknetConfig, config_path: Path):
    """
    Save configuration to a YAML file.

    Missing parent directories are created, and an existing file at
    ``config_path`` is overwritten.

    Args:
        config: SparknetConfig instance to serialise.
        config_path: Destination path; a plain ``str`` is also accepted.
    """
    # Accept str as well as Path; str has no .parent.
    config_path = Path(config_path)
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        # Block style keeps the output readable when edited by hand.
        yaml.dump(config.model_dump(), f, default_flow_style=False)
    logger.info(f"Configuration saved to {config_path}")
|
|
|
|
| |
# Module-level cache for the shared configuration instance: filled lazily by
# get_config() and replaced by set_config(); None means "not loaded yet".
_config: Optional[SparknetConfig] = None
|
|
|
|
def get_config() -> SparknetConfig:
    """
    Return the shared configuration, creating it on first access.

    While no instance is cached, ``load_config()`` is called with no path
    (yielding the defaults) and its result is cached. Later calls return
    that cached instance until ``set_config`` installs another one.
    """
    global _config
    if _config is not None:
        return _config
    _config = load_config()
    return _config
|
|
|
|
def set_config(config: SparknetConfig):
    """
    Install ``config`` as the shared configuration instance.

    Subsequent ``get_config`` calls return this object. Passing ``None``
    clears the cache, so the next ``get_config`` call loads a fresh one.

    Args:
        config: Configuration to make globally available.
    """
    global _config
    _config = config
|
|