Spaces:
Running
Running
File size: 3,694 Bytes
31fda96 0117df3 31fda96 0117df3 31fda96 49fe170 31fda96 8d28be7 0117df3 8d28be7 49fe170 8d28be7 0117df3 183f1c4 0117df3 31fda96 8d28be7 31fda96 0117df3 31fda96 8d28be7 0117df3 8d28be7 0117df3 31fda96 8d28be7 31fda96 49fe170 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | import json
import logging
import pickle
import shutil
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from config import Config
# Repository id handed to snapshot_download() when artifacts must be fetched.
REPO_ID = Config.REPO_ID_LANG
# Local artifact directory; None when Config.LANG_MODEL is empty/unset.
MODEL_DIR = Path(Config.LANG_MODEL) if Config.LANG_MODEL else None
# Access token forwarded to snapshot_download().
HF_TOKEN = Config.HF_TOKEN
# Sub-directory name that is also searched for artifacts, both under
# MODEL_DIR and inside a freshly downloaded snapshot.
ENGLISH_SUBDIR = "English_model"
# Torch device: CUDA when available, otherwise CPU.
# NOTE(review): not referenced by any function visible in this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# File names that must all be present for a directory to count as a
# complete artifact set.
REQUIRED_FILES = (
"classifier.pkl",
"scaler.pkl",
"word_vectorizer.pkl",
"char_vectorizer.pkl",
"feature_names.json",
"metadata.json",
)
def _patch_legacy_logistic_model(model):
    """Give logistic-regression models a ``multi_class`` attribute if missing.

    Only ``LogisticRegression`` / ``LogisticRegressionCV`` instances are
    touched; when such an instance has no ``multi_class`` attribute it is
    set to ``"auto"``. The object is modified in place and the same object
    is always returned, whatever its type.
    """
    if not isinstance(model, (LogisticRegression, LogisticRegressionCV)):
        return model
    if not hasattr(model, "multi_class"):
        # Presumably needed because the pickle and the installed sklearn
        # disagree about this attribute -- verify against the versions in use.
        model.multi_class = "auto"
    return model
def _has_required_artifacts(model_dir: Path, required_files=None) -> bool:
    """Return True when *model_dir* holds every required artifact file.

    Args:
        model_dir: Directory to inspect.
        required_files: Iterable of file names that must all be present.
            When omitted (``None``), the module-level ``REQUIRED_FILES``
            is used, which keeps existing one-argument callers unchanged.

    Returns:
        ``False`` if *model_dir* does not exist or is not a directory, or
        if any required name is missing or is not a regular file;
        ``True`` otherwise.
    """
    # is_dir() is already False for a non-existent path, so a separate
    # exists() check adds nothing.
    if not model_dir.is_dir():
        return False
    if required_files is None:
        required_files = REQUIRED_FILES
    # is_file() instead of exists(): a sub-directory that merely shares an
    # artifact's name would pass exists() and then fail when it is opened.
    return all((model_dir / filename).is_file() for filename in required_files)
def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Find the directory under *base_dir* that contains the artifacts.

    ``base_dir`` itself is tried first, then ``base_dir / ENGLISH_SUBDIR``.

    Returns:
        The first of those two locations accepted by
        ``_has_required_artifacts``, or ``None`` when neither qualifies.
    """
    search_order = (base_dir, base_dir / ENGLISH_SUBDIR)
    return next(
        (location for location in search_order if _has_required_artifacts(location)),
        None,
    )
def warmup():
    """Make sure the model artifacts are available locally.

    Logs a start message, then downloads the model repository unless a
    complete artifact set is already found under ``MODEL_DIR``.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None``.
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")

    existing_dir = _resolve_artifact_dir(MODEL_DIR)
    if existing_dir is None:
        download_model_repo()
    else:
        logging.info("Model artifacts already exist, skipping download.")
def download_model_repo():
    """Download the model repository and copy its artifacts into ``MODEL_DIR``.

    When a complete artifact set already exists under ``MODEL_DIR`` the
    function only logs and returns. Otherwise it fetches a snapshot of
    ``REPO_ID`` and copies into ``MODEL_DIR`` either the snapshot's
    ``ENGLISH_SUBDIR`` folder (if that folder exists) or the whole
    snapshot, merging with whatever is already there.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` or ``REPO_ID`` is empty.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")

    already_present = _resolve_artifact_dir(MODEL_DIR) is not None
    if already_present:
        logging.info("Model artifacts already exist, skipping download.")
        return

    snapshot_root = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))
    english_dir = snapshot_root / ENGLISH_SUBDIR
    # Prefer the dedicated sub-folder when the snapshot ships one.
    if english_dir.is_dir():
        copy_from = english_dir
    else:
        copy_from = snapshot_root

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    shutil.copytree(copy_from, MODEL_DIR, dirs_exist_ok=True)
def _load_pickle(path: Path):
    """Unpickle and return the object stored at *path*."""
    # SECURITY: pickle.load can execute arbitrary code while deserialising.
    # These files come from the downloaded repository, so REPO_ID must only
    # ever point at a trusted source.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_json(path: Path):
    """Parse and return the JSON document stored at *path*."""
    # Explicit UTF-8: without it open() decodes with the platform locale,
    # which mis-reads non-ASCII content on non-UTF-8 systems.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model():
    """Load every model artifact from disk, downloading them first if needed.

    The artifact directory is looked up under ``MODEL_DIR``; when it is not
    found, ``download_model_repo()`` is called once and the lookup repeated.

    Returns:
        A 6-tuple ``(classifier, scaler, word_vectorizer, char_vectorizer,
        feature_names, metadata)``. The first four are unpickled objects
        (the classifier additionally passed through
        ``_patch_legacy_logistic_model``); the last two are parsed JSON.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` (or, via the download
            step, if ``REPO_ID`` is empty).
        FileNotFoundError: If the artifacts are still missing after the
            download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")

    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
        if artifact_dir is None:
            raise FileNotFoundError(
                f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
            )

    loaded_classifier = _patch_legacy_logistic_model(
        _load_pickle(artifact_dir / "classifier.pkl")
    )
    loaded_scaler = _load_pickle(artifact_dir / "scaler.pkl")
    loaded_word_vectorizer = _load_pickle(artifact_dir / "word_vectorizer.pkl")
    loaded_char_vectorizer = _load_pickle(artifact_dir / "char_vectorizer.pkl")
    loaded_features = _load_json(artifact_dir / "feature_names.json")
    loaded_metadata = _load_json(artifact_dir / "metadata.json")

    return (
        loaded_classifier,
        loaded_scaler,
        loaded_word_vectorizer,
        loaded_char_vectorizer,
        loaded_features,
        loaded_metadata,
    )
|