File size: 3,694 Bytes
31fda96
 
 
0117df3
31fda96
 
0117df3
31fda96
49fe170
31fda96
 
 
 
 
8d28be7
 
0117df3
 
8d28be7
 
 
 
 
 
 
 
 
 
 
49fe170
 
 
 
 
 
 
8d28be7
 
 
 
 
 
 
 
 
 
 
 
0117df3
183f1c4
0117df3
31fda96
 
 
8d28be7
 
31fda96
0117df3
 
 
 
31fda96
 
8d28be7
 
 
 
0117df3
8d28be7
 
 
 
0117df3
 
 
31fda96
 
8d28be7
 
 
 
 
 
 
 
 
 
 
31fda96
49fe170
31fda96
8d28be7
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
import logging
import pickle
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV

from config import Config

# Hugging Face Hub repository id that download_model_repo() passes to snapshot_download.
REPO_ID = Config.REPO_ID_LANG
# Local directory holding the model artifacts; None when Config.LANG_MODEL is falsy.
MODEL_DIR = Path(Config.LANG_MODEL) if Config.LANG_MODEL else None
# Token forwarded to snapshot_download as its `token` argument.
HF_TOKEN = Config.HF_TOKEN
# Subdirectory name probed for artifacts, both under MODEL_DIR and inside the downloaded snapshot.
ENGLISH_SUBDIR = "English_model"

# Torch device: CUDA when available, otherwise CPU.
# NOTE(review): not referenced by the functions in this module — presumably imported
# by callers; verify before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File names that must all exist in a directory for it to count as a complete
# artifact set (checked by _has_required_artifacts); load_model() opens each of them.
REQUIRED_FILES = (
    "classifier.pkl",
    "scaler.pkl",
    "word_vectorizer.pkl",
    "char_vectorizer.pkl",
    "feature_names.json",
    "metadata.json",
)


def _patch_legacy_logistic_model(model):
    """Give logistic-regression estimators a ``multi_class`` attribute if absent.

    Only instances of ``LogisticRegression`` / ``LogisticRegressionCV`` are
    touched; when such an instance has no ``multi_class`` attribute it is set
    to ``"auto"``. Any other object is handed back unchanged.

    Returns:
        The same ``model`` object that was passed in (patched in place).
    """
    if not isinstance(model, (LogisticRegression, LogisticRegressionCV)):
        return model
    if not hasattr(model, "multi_class"):
        # Backfill the attribute so code that reads it does not fail on
        # estimators that were saved without it.
        model.multi_class = "auto"
    return model


def _has_required_artifacts(model_dir: Path) -> bool:
    """Report whether ``model_dir`` is a directory containing every required file.

    Args:
        model_dir: Directory to inspect.

    Returns:
        ``True`` only when ``model_dir`` exists, is a directory, and each name
        in ``REQUIRED_FILES`` exists inside it; ``False`` otherwise.
    """
    if not (model_dir.exists() and model_dir.is_dir()):
        return False
    for filename in REQUIRED_FILES:
        if not (model_dir / filename).exists():
            return False
    return True


def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Locate the directory that actually holds the model artifacts.

    ``base_dir`` itself is tried first, then its ``ENGLISH_SUBDIR`` child.

    Args:
        base_dir: Root directory to search.

    Returns:
        The first of the two locations that passes
        ``_has_required_artifacts``, or ``None`` when neither does.
    """
    search_order = (base_dir, base_dir / ENGLISH_SUBDIR)
    return next(
        (location for location in search_order if _has_required_artifacts(location)),
        None,
    )


def warmup():
    """Make sure the model artifacts are available locally.

    Logs a warm-up message, then downloads the model repository via
    ``download_model_repo`` unless a complete artifact set is already found
    under ``MODEL_DIR``.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` (``LANG_MODEL`` not configured).
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if _resolve_artifact_dir(MODEL_DIR) is None:
        download_model_repo()
        return
    logging.info("Model artifacts already exist, skipping download.")


def download_model_repo():
    """Download the model snapshot and copy its files into ``MODEL_DIR``.

    Does nothing (beyond logging) when a complete artifact set is already
    present. Otherwise the repository ``REPO_ID`` is fetched with
    ``snapshot_download``; if the snapshot contains an ``ENGLISH_SUBDIR``
    directory, that directory's contents are copied, else the snapshot root is.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` or ``REPO_ID`` is empty.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")
    if _resolve_artifact_dir(MODEL_DIR) is not None:
        logging.info("Model artifacts already exist, skipping download.")
        return

    snapshot_root = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))

    # Prefer the nested model directory when the snapshot ships one.
    copy_from = snapshot_root
    nested_dir = snapshot_root / ENGLISH_SUBDIR
    if nested_dir.is_dir():
        copy_from = nested_dir

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok lets the copy merge into the already-created MODEL_DIR.
    shutil.copytree(copy_from, MODEL_DIR, dirs_exist_ok=True)


def _load_pickle(path: Path):
    """Unpickle and return the object stored at ``path``."""
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. These artifacts may have been downloaded from a remote repository,
    # so only point REPO_ID / MODEL_DIR at sources you trust.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_json(path: Path):
    """Parse and return the JSON document stored at ``path``."""
    # Explicit UTF-8: without it the decoding depends on the platform locale,
    # which can corrupt or reject non-ASCII content.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model():
    """Load every model artifact from disk, downloading them first if needed.

    The artifact directory is resolved under ``MODEL_DIR``; when no complete
    artifact set is found, ``download_model_repo`` is called once and the
    lookup is retried.

    Returns:
        A 6-tuple ``(classifier, scaler, word_vectorizer, char_vectorizer,
        feature_names, metadata)``. The first four are unpickled objects (the
        classifier is passed through ``_patch_legacy_logistic_model``); the
        last two are the parsed contents of ``feature_names.json`` and
        ``metadata.json``.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` (``LANG_MODEL`` not
            configured), or propagated from ``download_model_repo``.
        FileNotFoundError: If the required files are still missing after the
            download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        raise FileNotFoundError(
            f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
        )

    loaded_classifier = _patch_legacy_logistic_model(
        _load_pickle(artifact_dir / "classifier.pkl")
    )
    loaded_scaler = _load_pickle(artifact_dir / "scaler.pkl")
    loaded_word_vectorizer = _load_pickle(artifact_dir / "word_vectorizer.pkl")
    loaded_char_vectorizer = _load_pickle(artifact_dir / "char_vectorizer.pkl")
    loaded_features = _load_json(artifact_dir / "feature_names.json")
    loaded_metadata = _load_json(artifact_dir / "metadata.json")

    return (
        loaded_classifier,
        loaded_scaler,
        loaded_word_vectorizer,
        loaded_char_vectorizer,
        loaded_features,
        loaded_metadata,
    )