Spaces:
Running
Running
| import json | |
| import logging | |
| import pickle | |
| import shutil | |
| from pathlib import Path | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from sklearn.linear_model import LogisticRegression, LogisticRegressionCV | |
| from config import Config | |
# Hub repository id handed to snapshot_download() when artifacts must be fetched.
REPO_ID = Config.REPO_ID_LANG

# Local directory holding the model artifacts; None when LANG_MODEL is unset/empty.
MODEL_DIR = None if not Config.LANG_MODEL else Path(Config.LANG_MODEL)

# Access token forwarded to snapshot_download().
HF_TOKEN = Config.HF_TOKEN

# Sub-directory that is probed (locally and inside the downloaded snapshot)
# as an alternative location of the artifacts.
ENGLISH_SUBDIR = "English_model"

# Prefer CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Every one of these files must exist for a directory to count as a usable
# artifact directory.
REQUIRED_FILES = (
    "classifier.pkl",
    "scaler.pkl",
    "word_vectorizer.pkl",
    "char_vectorizer.pkl",
    "feature_names.json",
    "metadata.json",
)
def _patch_legacy_logistic_model(model):
    """Give logistic-regression models a ``multi_class`` attribute if missing.

    Any object that is not a ``LogisticRegression`` / ``LogisticRegressionCV``
    instance, or that already has ``multi_class``, is returned untouched.
    When the attribute is absent it is set to ``"auto"`` in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    logistic_types = (LogisticRegression, LogisticRegressionCV)
    if not isinstance(model, logistic_types):
        return model
    if not hasattr(model, "multi_class"):
        # NOTE(review): presumably needed because the pickled estimator and the
        # installed sklearn disagree about this attribute; confirm versions.
        setattr(model, "multi_class", "auto")
    return model
def _has_required_artifacts(model_dir: Path) -> bool:
    """Return True when ``model_dir`` is a directory containing every required file.

    Args:
        model_dir: Directory to inspect.

    Returns:
        False if ``model_dir`` does not exist, is not a directory, or lacks any
        entry listed in ``REQUIRED_FILES``; True otherwise.
    """
    if not (model_dir.exists() and model_dir.is_dir()):
        return False
    for required_name in REQUIRED_FILES:
        if not (model_dir / required_name).exists():
            return False
    return True
def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Locate the directory under ``base_dir`` that holds the model artifacts.

    ``base_dir`` itself is checked first, then ``base_dir / ENGLISH_SUBDIR``.

    Args:
        base_dir: Root directory to search.

    Returns:
        The first candidate that contains all required artifacts, or None when
        neither candidate qualifies.
    """
    search_order = (base_dir, base_dir / ENGLISH_SUBDIR)
    return next(
        (location for location in search_order if _has_required_artifacts(location)),
        None,
    )
def warmup():
    """Make sure the model artifacts are available locally.

    Logs a start message, then downloads the model repository unless a
    complete set of artifacts is already present under ``MODEL_DIR``.

    Raises:
        ValueError: If ``MODEL_DIR`` is None (LANG_MODEL not configured).
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")

    if not _resolve_artifact_dir(MODEL_DIR):
        download_model_repo()
        return

    logging.info("Model artifacts already exist, skipping download.")
def download_model_repo():
    """Fetch the model repository snapshot and copy it into ``MODEL_DIR``.

    Nothing is downloaded when ``MODEL_DIR`` already contains a complete set
    of artifacts. If the snapshot contains an ``ENGLISH_SUBDIR`` directory,
    the contents of that directory are copied; otherwise the whole snapshot
    is copied.

    Raises:
        ValueError: If ``MODEL_DIR`` is None or ``REPO_ID`` is empty.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")

    if _resolve_artifact_dir(MODEL_DIR) is not None:
        logging.info("Model artifacts already exist, skipping download.")
        return

    snapshot_root = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))
    nested_dir = snapshot_root / ENGLISH_SUBDIR
    if nested_dir.is_dir():
        source_dir = nested_dir
    else:
        source_dir = snapshot_root

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok lets the copy merge into the directory created above.
    shutil.copytree(source_dir, MODEL_DIR, dirs_exist_ok=True)
def load_model():
    """Load every model artifact from disk, downloading them first if needed.

    The artifact directory is resolved under ``MODEL_DIR``. When it cannot be
    found, the model repository is downloaded and the lookup is retried once.

    Returns:
        A 6-tuple, in this order, of the objects read from
        ``classifier.pkl`` (after ``_patch_legacy_logistic_model``),
        ``scaler.pkl``, ``word_vectorizer.pkl``, ``char_vectorizer.pkl``,
        ``feature_names.json`` and ``metadata.json``.

    Raises:
        ValueError: If ``MODEL_DIR`` is None (LANG_MODEL not configured), or
            propagated from ``download_model_repo`` when the repo id is unset.
        FileNotFoundError: If the required artifacts are still missing after
            the download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")

    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        raise FileNotFoundError(
            f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
        )

    def _read_pickle(filename):
        # SECURITY: pickle.load can execute arbitrary code while deserializing.
        # These files may come from a downloaded repository, so only point
        # REPO_ID / MODEL_DIR at sources you trust.
        with open(artifact_dir / filename, "rb") as f:
            return pickle.load(f)

    def _read_json(filename):
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale default encoding.
        with open(artifact_dir / filename, "r", encoding="utf-8") as f:
            return json.load(f)

    loaded_classifier = _patch_legacy_logistic_model(_read_pickle("classifier.pkl"))
    loaded_scaler = _read_pickle("scaler.pkl")
    loaded_word_vectorizer = _read_pickle("word_vectorizer.pkl")
    loaded_char_vectorizer = _read_pickle("char_vectorizer.pkl")
    loaded_features = _read_json("feature_names.json")
    loaded_metadata = _read_json("metadata.json")

    return (
        loaded_classifier,
        loaded_scaler,
        loaded_word_vectorizer,
        loaded_char_vectorizer,
        loaded_features,
        loaded_metadata,
    )