| """ |
| FastAPI service for Czech text correction pipeline |
| Combines grammar error correction and punctuation restoration |
| """ |
|
|
import logging
import os
import re
import time
from collections import deque
from contextlib import asynccontextmanager
from typing import Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
|
|
| |
# PyTorch CPU parallelism: OMP_NUM_THREADS (default 12) sizes both the
# intra-op and the inter-op thread pools.
num_threads = int(os.environ.get("OMP_NUM_THREADS", 12))
torch.set_num_threads(num_threads)
torch.set_num_interop_threads(num_threads)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"PyTorch configured to use {num_threads} CPU threads")

# Model state shared by the request handlers; every name stays None until
# `lifespan` loads the models at startup.
gec_model = gec_tokenizer = punct_pipeline = device = None

# Keyword arguments forwarded verbatim to `gec_model.generate`:
# deterministic 8-beam search with neutral repetition/length penalties.
# NOTE(review): max_new_tokens=100000 together with 8 beams is effectively
# unbounded; a degenerate generation could run for a very long time.
GEC_CONFIG = dict(
    num_beams=8,
    do_sample=False,
    repetition_penalty=1.0,
    length_penalty=1.0,
    no_repeat_ngram_size=0,
    early_stopping=True,
    max_new_tokens=100000,
)
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Before the app starts serving, select the torch device and load the GEC
    seq2seq model/tokenizer plus the punctuation token-classification
    pipeline into the module-level globals. After shutdown it only logs;
    nothing is explicitly released.
    """
    global gec_model, gec_tokenizer, punct_pipeline, device

    logger.info("Loading models...")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    logger.info(f"Using device: {device}")

    # Grammar error correction: seq2seq model moved onto the chosen device.
    gec_id = "ufal/byt5-large-geccc-mate"
    logger.info("Loading Czech GEC model...")
    gec_tokenizer = AutoTokenizer.from_pretrained(gec_id)
    gec_model = AutoModelForSeq2SeqLM.from_pretrained(gec_id)
    gec_model = gec_model.to(device)
    logger.info("GEC model loaded successfully")

    # Punctuation restoration: token-classification pipeline; device=0 picks
    # the first GPU, -1 the CPU.
    punct_id = "kredor/punctuate-all"
    logger.info("Loading punctuation model...")
    tok = AutoTokenizer.from_pretrained(punct_id)
    clf = AutoModelForTokenClassification.from_pretrained(punct_id)
    punct_pipeline = pipeline(
        "token-classification",
        model=clf,
        tokenizer=tok,
        device=0 if use_cuda else -1,
    )
    logger.info("Punctuation model loaded successfully")

    logger.info("All models loaded and ready")
    yield
    logger.info("Shutting down...")
|
|
| |
# Application object; models are loaded/unloaded by the `lifespan` hook.
app = FastAPI(
    lifespan=lifespan,
    title="Czech Text Correction API",
    description="API for Czech grammar error correction and punctuation restoration",
    version="1.0.0",
)

# CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# permissive combination; this API visibly uses no cookies or auth, so
# credentials can probably be disabled — confirm with the clients first.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
| |
class CorrectionRequest(BaseModel):
    # Request body shared by the single-text endpoints.
    # Text to process; pydantic rejects anything over 100000 characters.
    text: str = Field(..., max_length=100000, description="Czech text to correct")
    # Free-form flags; /api/correct reads "include_timing". Declared Optional,
    # so a client may send null here.
    options: Optional[Dict] = Field(default={}, description="Optional parameters")
|
|
class CorrectionResponse(BaseModel):
    # Response of the single-text endpoints. On failure success is False,
    # corrected_text is "" and error carries the exception message.
    success: bool
    corrected_text: str
    # Set only when the request asked for timing via options["include_timing"].
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None
|
|
class BatchCorrectionRequest(BaseModel):
    # Request body for /api/correct/batch; at most 10 texts per call. Per-text
    # length is not limited here — the endpoint checks it.
    texts: List[str] = Field(..., max_items=10, description="List of texts to correct")
    # Free-form flags; the batch endpoint reads "include_timing". May be null.
    options: Optional[Dict] = Field(default={}, description="Optional parameters")
|
|
class BatchCorrectionResponse(BaseModel):
    # Response of /api/correct/batch. corrected_texts is index-aligned with
    # the request's texts; on failure it is empty and error is set.
    success: bool
    corrected_texts: List[str]
    # Set only when the request asked for timing via options["include_timing"].
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None
|
|
class HealthResponse(BaseModel):
    # Payload of GET /api/health.
    status: str  # "healthy" once both models are loaded, otherwise "loading"
    models_loaded: bool
    gpu_available: bool  # torch.cuda.is_available() at request time
    device: str  # str() of the torch device in use, or "not initialized"
|
|
class InfoResponse(BaseModel):
    # Payload of GET /api/info: static service metadata.
    name: str
    version: str
    models: Dict[str, str]  # role ("gec", "punctuation") -> model identifier
    capabilities: List[str]  # human-readable feature list
    max_input_length: int  # maximum characters accepted per text
|
|
def apply_gec_correction(text: str) -> str:
    """Run the grammar-error-correction model on a single text.

    Whitespace-only input is returned unchanged without touching the model.
    Otherwise the text is tokenized (truncated to 100000 tokens), moved to
    the active device, decoded with ``GEC_CONFIG`` generation settings, and
    the best sequence is returned without special tokens.
    """
    if text.strip():
        encoded = gec_tokenizer(
            text,
            return_tensors="pt",
            max_length=100000,
            truncation=True,
        )
        model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            generated = gec_model.generate(**model_inputs, **GEC_CONFIG)
        return gec_tokenizer.decode(generated[0], skip_special_tokens=True)
    return text
|
|
def apply_gec_correction_batch(texts: List[str]) -> List[str]:
    """Grammar-correct several texts with one padded ``generate`` call.

    Whitespace-only entries bypass the model and are returned unchanged; the
    result list is index-aligned with *texts*.
    """
    if not texts:
        return []

    # Start from a copy so blank entries are already in place.
    corrected = list(texts)
    pending = [(pos, txt) for pos, txt in enumerate(texts) if txt.strip()]
    if not pending:
        return corrected

    positions = [pos for pos, _ in pending]
    batch = [txt for _, txt in pending]

    encoded = gec_tokenizer(
        batch,
        return_tensors="pt",
        max_length=100000,
        truncation=True,
        padding=True,
    )
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        generated = gec_model.generate(**model_inputs, **GEC_CONFIG)

    decoded = gec_tokenizer.batch_decode(generated, skip_special_tokens=True)
    for pos, fixed in zip(positions, decoded):
        corrected[pos] = fixed
    return corrected
|
|
# The punctuation model's SentencePiece tokenizer prefixes word-initial
# sub-word tokens with U+2581 (LOWER ONE EIGHTH BLOCK). It is spelled as an
# escape so a wrong source-file encoding cannot corrupt it into another
# character.
_SP_WORD_START = "\u2581"

# Model label -> mark appended after the word ("" = no punctuation).
# NOTE(review): assumes the model's id2label uses LABEL_0..LABEL_5 in this
# order; confirm against kredor/punctuate-all's config.json.
_PUNCT_BY_LABEL = {
    "LABEL_0": "",
    "LABEL_1": ".",
    "LABEL_2": ",",
    "LABEL_3": "?",
    "LABEL_4": "-",
    "LABEL_5": ":",
}

# Labels that are themselves punctuation marks are used verbatim, in case the
# model config names its classes by the mark instead of LABEL_n.
_LITERAL_PUNCT_LABELS = frozenset({".", ",", "?", "-", ":"})

# A word already ending in one of these is left alone, so punctuation that
# the text (e.g. the GEC output) already has is not doubled ("máš??").
_TRAILING_PUNCT = (".", ",", "?", "!", ":", ";", "-")

# Whitespace that follows a sentence-ending mark.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.?!])\s+")


def _label_to_punct(label: str) -> str:
    """Return the punctuation mark for model label *label* ('' for none)."""
    if label in _PUNCT_BY_LABEL:
        return _PUNCT_BY_LABEL[label]
    return label if label in _LITERAL_PUNCT_LABELS else ""


def _group_word_predictions(token_results) -> List[tuple]:
    """Merge per-token pipeline output into ``(word, mark)`` pairs in order.

    A token whose raw text starts with the SentencePiece marker begins a new
    word; any other token (except the very first) extends the current word.
    Each word takes the mark predicted for its first token.
    NOTE(review): confirm first-sub-word labelling matches how the model was
    trained.
    """
    groups = []
    for i, token in enumerate(token_results):
        raw = token["word"]
        piece = raw.replace(_SP_WORD_START, "").strip()
        if i > 0 and not raw.startswith(_SP_WORD_START):
            word, mark = groups[-1]
            groups[-1] = (word + piece, mark)
        else:
            groups.append((piece, _label_to_punct(token["entity"])))
    # A bare marker token has an empty piece; drop groups that stayed empty.
    return [(word, mark) for word, mark in groups if word]


def _restore_punctuation(text: str, token_results) -> str:
    """Append predicted marks to the words of *text*, then capitalise sentences.

    Predictions are matched on the lower-cased word (the model sees
    lower-cased text) but the output keeps each word's original casing.
    Marks for a repeated word are consumed in order of appearance, so every
    occurrence gets its own prediction. Whitespace collapses to single spaces.
    """
    pending = {}
    for word, mark in _group_word_predictions(token_results):
        pending.setdefault(word, deque()).append(mark)

    punctuated = []
    for word in text.split():
        queue = pending.get(word.lower())
        mark = queue.popleft() if queue else ""
        if mark and not word.endswith(_TRAILING_PUNCT):
            word += mark
        punctuated.append(word)

    sentences = _SENTENCE_BOUNDARY.split(" ".join(punctuated))
    result = " ".join(s[0].upper() + s[1:] if s else s for s in sentences)

    # Re-attach punctuation that stood as its own whitespace-separated token.
    for p in (",", ".", "?", ":", "!", ";"):
        result = result.replace(f" {p}", p)
    return result


def apply_punctuation(text: str) -> str:
    """Restore punctuation and sentence-initial capitalisation in *text*.

    Whitespace-only input is returned unchanged. The punctuation pipeline is
    queried on the lower-cased text; words in the result keep their original
    casing, and the first letter of each sentence is upper-cased.
    NOTE(review): the pipeline presumably truncates long inputs at the
    model's maximum length, leaving trailing words unpunctuated — confirm.
    """
    if not text.strip():
        return text
    return _restore_punctuation(text, punct_pipeline(text.lower()))


def apply_punctuation_batch(texts: List[str]) -> List[str]:
    """Apply :func:`apply_punctuation` to each text, preserving order.

    Texts are processed one at a time (there is no model-level batching).
    """
    return [apply_punctuation(text) for text in texts]
|
|
def process_text(text: str) -> str:
    """Run the full pipeline: grammar correction first, then punctuation."""
    return apply_punctuation(apply_gec_correction(text))
|
|
@app.post("/api/correct", response_model=CorrectionResponse)
async def correct_text(request: CorrectionRequest):
    """
    Correct Czech text (grammar + punctuation).

    Validation failures (empty or over-long text) are returned as HTTP 400;
    any other error is reported in-band with ``success=False`` and the
    exception message in ``error``.

    NOTE(review): process_text runs model inference synchronously inside
    this ``async def``, so it blocks the event loop (and every other request,
    including /api/health) until it finishes.
    """
    start_time = time.perf_counter()
    try:
        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")

        if len(request.text) > 100000:
            raise HTTPException(status_code=400, detail="Text too long (max 100000 characters)")

        logger.info(f"Single text request received ({len(request.text)} chars)")

        corrected = process_text(request.text)

        processing_time = (time.perf_counter() - start_time) * 1000
        logger.info(f"Completed in {processing_time:.1f}ms")

        response = CorrectionResponse(
            success=True,
            corrected_text=corrected
        )

        # `options` is Optional and may arrive as an explicit null.
        if (request.options or {}).get("include_timing", False):
            response.processing_time_ms = processing_time

        return response

    except HTTPException:
        # Let FastAPI turn deliberate 4xx errors into real HTTP responses
        # instead of a 200 with success=False.
        raise
    except Exception as e:
        logger.exception(f"Error processing text: {e}")
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )
|
|
@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
async def correct_batch(request: BatchCorrectionRequest):
    """
    Correct multiple Czech texts (grammar + punctuation).

    Texts longer than 100000 characters skip both models and come back as
    "[Error: Text too long]" in their slot. An empty ``texts`` list is an
    HTTP 400; any other failure is reported in-band with ``success=False``.

    NOTE(review): inference runs synchronously inside this ``async def`` and
    blocks the event loop for the whole batch.
    """
    start_time = time.perf_counter()
    try:
        if not request.texts:
            raise HTTPException(status_code=400, detail="No texts provided")

        logger.info(f"Batch request received: {len(request.texts)} texts")

        # Oversized texts are replaced by "" so the models skip them; the flag
        # list puts the error marker back in the same slot afterwards.
        too_long = [len(text) > 100000 for text in request.texts]
        validated_texts = [
            "" if oversized else text
            for text, oversized in zip(request.texts, too_long)
        ]

        logger.info(f"Starting GEC batch processing ({len(validated_texts)} texts)...")
        gec_start = time.perf_counter()
        gec_corrected_texts = apply_gec_correction_batch(validated_texts)
        gec_time = (time.perf_counter() - gec_start) * 1000
        logger.info(f"GEC completed in {gec_time:.1f}ms")

        logger.info("Starting punctuation batch processing...")
        punct_start = time.perf_counter()
        final_texts = apply_punctuation_batch(gec_corrected_texts)
        punct_time = (time.perf_counter() - punct_start) * 1000
        logger.info(f"Punctuation completed in {punct_time:.1f}ms")

        corrected_texts = [
            "[Error: Text too long]" if oversized else final
            for final, oversized in zip(final_texts, too_long)
        ]

        processing_time = (time.perf_counter() - start_time) * 1000
        # corrected_texts is non-empty here: an empty request raised above.
        logger.info(
            f"Batch completed: {len(corrected_texts)} texts in {processing_time:.1f}ms "
            f"(avg {processing_time / len(corrected_texts):.1f}ms/text)"
        )

        response = BatchCorrectionResponse(
            success=True,
            corrected_texts=corrected_texts
        )

        # `options` is Optional and may arrive as an explicit null.
        if (request.options or {}).get("include_timing", False):
            response.processing_time_ms = processing_time

        return response

    except HTTPException:
        # Propagate deliberate 4xx errors as real HTTP responses.
        raise
    except Exception as e:
        logger.exception(f"Error processing batch: {e}")
        return BatchCorrectionResponse(
            success=False,
            corrected_texts=[],
            error=str(e)
        )
|
|
@app.post("/api/correct/gec-only", response_model=CorrectionResponse)
async def correct_gec_only(request: CorrectionRequest):
    """
    Apply only grammar error correction (no punctuation).

    Errors are logged and reported in-band with ``success=False``.
    """
    try:
        corrected = apply_gec_correction(request.text)
    except Exception as e:
        # Previously swallowed silently; keep the in-band error but log it.
        logger.exception(f"Error in GEC-only correction: {e}")
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )
    return CorrectionResponse(
        success=True,
        corrected_text=corrected
    )
|
|
@app.post("/api/correct/punct-only", response_model=CorrectionResponse)
async def correct_punct_only(request: CorrectionRequest):
    """
    Apply only punctuation restoration (no grammar correction).

    Errors are logged and reported in-band with ``success=False``.
    """
    try:
        corrected = apply_punctuation(request.text)
    except Exception as e:
        # Previously swallowed silently; keep the in-band error but log it.
        logger.exception(f"Error in punctuation-only correction: {e}")
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )
    return CorrectionResponse(
        success=True,
        corrected_text=corrected
    )
|
|
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """
    Report readiness: "healthy" once both models are loaded, else "loading".
    """
    models_ready = gec_model is not None and punct_pipeline is not None

    if device:
        device_name = str(device)
    else:
        device_name = "not initialized"

    return HealthResponse(
        status="healthy" if models_ready else "loading",
        models_loaded=models_ready,
        gpu_available=torch.cuda.is_available(),
        device=device_name,
    )
|
|
@app.get("/api/info", response_model=InfoResponse)
async def get_info():
    """
    Describe the service: version, model identifiers and capabilities.
    """
    model_ids = {
        "gec": "ufal/byt5-large-geccc-mate",
        "punctuation": "kredor/punctuate-all",
    }
    features = [
        "Grammar error correction",
        "Punctuation restoration",
        "Capitalization",
        "Batch processing",
        "Czech language focus",
    ]
    return InfoResponse(
        name="Czech Text Correction API",
        version="1.0.0",
        models=model_ids,
        capabilities=features,
        max_input_length=100000,
    )
|
|
@app.get("/")
async def root():
    """Landing endpoint pointing at the docs, health and info routes."""
    return dict(
        message="Czech Text Correction API",
        docs="/docs",
        health="/api/health",
        info="/api/info",
    )
|
|
if __name__ == "__main__":
    # Development entry point: serve on all interfaces; PORT overrides the
    # default 8042. `os` is already imported at module level.
    import uvicorn

    port = int(os.environ.get("PORT", 8042))
    uvicorn.run(app, host="0.0.0.0", port=port)