| """ |
| Client for Czech text correction API with local server auto-start |
| """ |
|
|
| import requests |
| import time |
| import subprocess |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Optional, Dict, List, Any |
| import logging |
|
|
| |
# Module-level logger used by the client below.
logger = logging.getLogger(__name__)
# NOTE: this configures the *root* logger (INFO level) at import time; per the
# stdlib docs, basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
|
|
class CzechCorrectionClient:
    """Client for Czech text correction with automatic local server startup.

    All requests go to the local HTTP server described by ``LOCAL_ENDPOINT``.
    If that server is not healthy, the client tries to launch ``api.py`` from
    the same directory as this file and waits for it to report healthy.
    """

    # Connection settings for the local server. ``timeout`` (seconds) is the
    # per-request timeout for single corrections; batch requests use twice it.
    LOCAL_ENDPOINT = {
        "name": "Local",
        "base_url": "http://localhost:8042",
        "timeout": 3600
    }

    # Maximum number of texts accepted by ``correct_batch``.
    MAX_BATCH_SIZE = 10

    # Seconds to wait for a freshly spawned server to become healthy.
    SERVER_START_TIMEOUT = 120

    def __init__(self, prefer_local: bool = True):
        """
        Initialize the client.

        Args:
            prefer_local: Deprecated and ignored; the local API is always used.
        """
        # Copy so changes to one instance never leak into the class default.
        self.endpoint = dict(self.LOCAL_ENDPOINT)
        self._working_endpoint = None
        self._last_health_check = 0
        # Seconds a successful health check is trusted before re-checking.
        self.health_check_interval = 3600
        # Handle of the server process this client spawned, if any.
        self._server_process = None

    def _server_port(self) -> int:
        """Return the TCP port of ``self.endpoint['base_url']``.

        Falls back to the scheme's default port (443 for https, else 80) when
        the URL has no explicit port.
        """
        from urllib.parse import urlsplit

        parts = urlsplit(self.endpoint['base_url'])
        if parts.port:
            return parts.port
        return 443 if parts.scheme == 'https' else 80

    def _check_endpoint_health(self, endpoint: Dict) -> bool:
        """Return True if ``endpoint`` answers ``GET /api/health`` as healthy.

        A response counts as healthy only when it has HTTP 200 and a JSON body
        whose ``status`` is ``"healthy"``. Any error (connection refused,
        timeout, non-JSON body) is logged at debug level and treated as
        unhealthy.
        """
        try:
            response = requests.get(
                f"{endpoint['base_url']}/api/health",
                timeout=10
            )
            if response.status_code == 200:
                return response.json().get('status') == 'healthy'
        except Exception as e:
            logger.debug(f"Health check failed for {endpoint['name']}: {e}")
        return False

    def _is_port_in_use(self, port: int) -> bool:
        """Return True if binding ``localhost:port`` fails (port already taken)."""
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return False
            except OSError:
                return True

    def _wait_for_server(self) -> bool:
        """Poll the health endpoint until it is healthy, the child dies, or time runs out.

        Returns:
            True once the server is healthy. False if the spawned process has
            exited, or if ``SERVER_START_TIMEOUT`` seconds pass first. On
            timeout the process is left running, because it may just be
            loading models slowly; a later call will wait on it again.
        """
        deadline = time.time() + self.SERVER_START_TIMEOUT
        while time.time() < deadline:
            if self._check_endpoint_health(self.endpoint):
                logger.info("✅ Local API server started successfully")
                return True
            process = self._server_process
            # Stop waiting as soon as the child has crashed.
            if process is not None and process.poll() is not None:
                logger.error(
                    f"Local API server exited early with code {process.returncode}"
                )
                return False
            time.sleep(2)

        logger.error("Server failed to start within timeout")
        return False

    def _start_local_server(self) -> bool:
        """Start the local API server if it is not already running.

        Returns:
            True once a server on the configured port is healthy, else False.
            Unexpected errors are logged and reported as False.
        """
        try:
            port = self._server_port()

            # A server this client spawned earlier may still be loading models.
            # Wait for it instead of launching a duplicate.
            if self._server_process is not None and self._server_process.poll() is None:
                logger.info("Local API server process already started, waiting for it...")
                return self._wait_for_server()

            if self._is_port_in_use(port):
                logger.warning(f"Port {port} is already in use - server may already be running")
                # Give a server that is just coming up a moment before checking.
                time.sleep(2)
                if self._check_endpoint_health(self.endpoint):
                    logger.info(f"✅ Server is already running on port {port}")
                    return True
                logger.error(f"Port {port} is in use but server is not responding to health checks")
                return False

            # The server script is expected to sit next to this file.
            api_service_dir = Path(__file__).resolve().parent
            api_script = api_service_dir / "api.py"
            if not api_script.exists():
                logger.error(f"API script not found at {api_script}")
                return False

            logger.info("Starting local API server...")
            logger.info("This may take 1-2 minutes to load models...")

            env = os.environ.copy()
            env['PORT'] = str(port)

            # Output goes to DEVNULL. Nothing ever reads these streams, and an
            # unread PIPE would block the server once its OS buffer fills.
            # To see server output, run api.py directly.
            self._server_process = subprocess.Popen(
                [sys.executable, str(api_script)],
                cwd=str(api_service_dir),
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True,
            )
            return self._wait_for_server()

        except Exception as e:
            logger.error(f"Failed to start local server: {e}")
            return False

    def _mark_healthy(self) -> None:
        """Cache ``self.endpoint`` as working, stamped with the current time."""
        self._working_endpoint = self.endpoint
        self._last_health_check = time.time()

    def _invalidate_endpoint(self) -> None:
        """Forget the cached endpoint so the next call re-checks (or restarts) it."""
        self._working_endpoint = None

    def _get_working_endpoint(self) -> Optional[Dict]:
        """Return a healthy endpoint, starting the local server if needed.

        A successful check is cached for ``health_check_interval`` seconds.

        Returns:
            The endpoint dict, or None if the server could not be reached or started.
        """
        now = time.time()
        if self._working_endpoint and (now - self._last_health_check < self.health_check_interval):
            return self._working_endpoint

        if self._check_endpoint_health(self.endpoint):
            logger.info(f"Using {self.endpoint['name']} API endpoint")
            self._mark_healthy()
            return self.endpoint

        logger.info("Local API server not running, attempting to start...")
        if self._start_local_server():
            # Stamp after startup, which can take minutes, not before it.
            self._mark_healthy()
            return self.endpoint

        logger.error("Could not start or connect to local API server")
        return None

    def correct_text(self, text: str, include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct Czech text (grammar and punctuation).

        Args:
            text: Text to correct. Empty or whitespace-only text is returned as-is.
            include_timing: Whether to ask the server for processing time.

        Returns:
            On HTTP 200, the server's JSON response as-is. Otherwise a dict with
            'success' False, 'corrected_text' equal to the input, and 'error'.
        """
        if not text or not text.strip():
            return {
                "success": True,
                "corrected_text": text,
                "error": None
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_text": text,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "text": text,
                "options": {"include_timing": include_timing}
            }
            response = requests.post(
                f"{endpoint['base_url']}/api/correct",
                json=payload,
                timeout=endpoint['timeout']
            )
            if response.status_code == 200:
                return response.json()
            return {
                "success": False,
                "corrected_text": text,
                "error": f"API error: {response.status_code}"
            }

        except requests.exceptions.Timeout:
            logger.warning(f"Timeout on {endpoint['name']} API")
            return {
                "success": False,
                "corrected_text": text,
                "error": "Request timeout"
            }

        except requests.exceptions.ConnectionError as e:
            # The server went away. Drop the cached endpoint so the next call
            # re-checks and restarts it, instead of failing for up to an hour.
            self._invalidate_endpoint()
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

    def _correct_individually(self, texts: List[str]) -> Dict[str, Any]:
        """Fallback for ``correct_batch``: correct each text with its own request.

        Returns:
            A dict with 'success' True only if every text was corrected.
            Failed texts are returned unchanged, and 'error' summarizes the
            failures.
        """
        # Resolve (or restart) the server once up front, so an unreachable
        # server fails fast instead of once per text.
        if not self._get_working_endpoint():
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server"
            }

        results = [self.correct_text(text, include_timing=False) for text in texts]
        corrected_texts = [
            result.get('corrected_text', text) for result, text in zip(results, texts)
        ]
        errors = [result.get('error') for result in results if not result.get('success')]
        if errors:
            return {
                "success": False,
                "corrected_texts": corrected_texts,
                "error": f"{len(errors)} of {len(texts)} texts failed individual correction: {errors[0]}"
            }
        return {
            "success": True,
            "corrected_texts": corrected_texts,
            "error": None
        }

    def correct_batch(self, texts: List[str], include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct multiple Czech texts in one request.

        If the batch request fails (non-200 status or any exception), each text
        is retried individually.

        Args:
            texts: Texts to correct (at most ``MAX_BATCH_SIZE``).
            include_timing: Whether to ask the server for processing time.

        Returns:
            On HTTP 200, the server's JSON response as-is. Otherwise a dict with
            'success', 'corrected_texts' and 'error'. 'success' is False if any
            text could not be corrected.
        """
        if not texts:
            return {
                "success": True,
                "corrected_texts": [],
                "error": None
            }

        if len(texts) > self.MAX_BATCH_SIZE:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": f"Batch size exceeds limit ({self.MAX_BATCH_SIZE})"
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "texts": texts,
                "options": {"include_timing": include_timing}
            }
            response = requests.post(
                f"{endpoint['base_url']}/api/correct/batch",
                json=payload,
                timeout=endpoint['timeout'] * 2
            )
            if response.status_code == 200:
                return response.json()
            logger.warning(
                f"Batch API failed with status {response.status_code}, "
                f"falling back to individual corrections"
            )
        except requests.exceptions.ConnectionError as e:
            # Server likely died: let the fallback re-check or restart it.
            self._invalidate_endpoint()
            logger.error(f"Error calling batch API: {e}")
        except Exception as e:
            logger.error(f"Error calling batch API: {e}")

        return self._correct_individually(texts)
|
|
|
|
| |
# Shared client instance, created lazily by get_client() on first use.
_default_client = None


def get_client(prefer_local: bool = True) -> CzechCorrectionClient:
    """
    Return the shared module-level client, creating it on first call.

    Args:
        prefer_local: Deprecated and ignored; the local API is always used.

    Returns:
        The same CzechCorrectionClient instance on every call.
    """
    global _default_client
    if _default_client is not None:
        return _default_client
    _default_client = CzechCorrectionClient(prefer_local=True)
    return _default_client
|
|
def correct_text(text: str, prefer_local: bool = True) -> str:
    """
    Correct Czech text with the shared client and return a plain string.

    Args:
        text: Text to correct.
        prefer_local: Deprecated and ignored; the local API is always used.

    Returns:
        The corrected text, or ``text`` unchanged if correction failed.
    """
    result = get_client(prefer_local=True).correct_text(text)
    # Use .get: a successful result is the server's raw JSON, which is not
    # guaranteed to contain every key, and a missing key must not crash callers.
    if result.get('success'):
        return result.get('corrected_text', text)
    return text
|
|
def correct_batch(texts: List[str], prefer_local: bool = True) -> List[str]:
    """
    Correct several Czech texts with the shared client and return plain strings.

    Args:
        texts: Texts to correct (the client accepts at most 10 per call).
        prefer_local: Deprecated and ignored; the local API is always used.

    Returns:
        The corrected texts, or ``texts`` unchanged if correction failed.
    """
    result = get_client(prefer_local=True).correct_batch(texts)
    # Use .get: a successful result may be the server's raw JSON, which could
    # lack 'success' or carry a null 'corrected_texts'.
    corrected = result.get('corrected_texts') if result.get('success') else None
    return corrected if corrected is not None else texts