File size: 6,749 Bytes
4ee3607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12e4ca0
4ee3607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc2f1b4
 
4ee3607
 
 
 
dc2f1b4
 
 
 
 
 
 
 
 
 
 
 
 
 
4ee3607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5913e40
 
b217733
5913e40
 
 
4ee3607
 
5913e40
4ee3607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc2f1b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# =============================================================================
# model.py
# HuggingFace Model + Dataset Access Layer
# SmolLM2 Service Space
# Copyright 2026 - Volkan Kücükbudak
# Apache License V2 + ESOL 1.1
# =============================================================================
# Handles:
#   - Model selection (SmolLM2 base repo or private fine-tuned repo) + from_pretrained() kwargs
#   - Dataset read/write (private HF dataset)
#   - Token resolution, first non-empty env var wins:
#     SMOLLM_API_KEY → HF_TOKEN → TEST_TOKEN → HUGGINGFACE_TOKEN → HF_API_TOKEN → None
# =============================================================================

import os
import logging
from datetime import datetime
from typing import Optional
from huggingface_hub import HfApi, login
from datasets import load_dataset, Dataset

logger = logging.getLogger("model")

# ── Token Resolution ──────────────────────────────────────────────────────────
# Environment variables that may carry an HF access token, highest priority
# first. The first one holding a non-empty value wins; TOKEN is None otherwise.
_TOKEN_ENV_VARS = (
    "SMOLLM_API_KEY",
    "HF_TOKEN",
    "TEST_TOKEN",
    "HUGGINGFACE_TOKEN",
    "HF_API_TOKEN",
)
TOKEN = next(filter(None, map(os.environ.get, _TOKEN_ENV_VARS)), None)

# ── Config from ENV ───────────────────────────────────────────────────────────
# Repo IDs; each can be overridden through the environment variable named in
# the lookup, otherwise the literal default is used.
MODEL_REPO = os.environ.get("MODEL_REPO", "HuggingFaceTB/SmolLM2-360M-Instruct")
DATASET_REPO = os.environ.get("DATASET_REPO", "codey-lab/data.universal-mcp-hub")
PRIVATE_MODEL = os.environ.get("PRIVATE_MODEL_REPO", "codey-lab/model.universal-mcp-hub")

# ── HF API ────────────────────────────────────────────────────────────────────
# Lazily created, module-wide HfApi client; stays None until auth succeeds.
_api: Optional[HfApi] = None

def get_api() -> Optional[HfApi]:
    """
    Return the shared authenticated HfApi client, creating it on first use.

    Returns None when no token is configured. If login fails, a warning is
    logged and None is returned; the next call will try again because the
    cache is only filled on success.
    """
    global _api
    # Already built, or nothing to authenticate with: hand back what we have.
    if _api is not None or not TOKEN:
        return _api
    try:
        login(token=TOKEN, add_to_git_credential=False)
        _api = HfApi(token=TOKEN)
        logger.info("HF API authenticated")
    except Exception as e:
        logger.warning(f"HF API auth failed: {type(e).__name__} β€” running unauthenticated")
    return _api


# =============================================================================
# Model Access
# =============================================================================

def get_model_id() -> str:
    """
    Return the model repo ID to load.

    PRIVATE_MODEL is preferred only when its repo contains a ``config.json``
    whose top-level JSON object has a ``model_type`` key, i.e. it looks like
    a real transformers config rather than a placeholder repo. Only the
    config is inspected; weight files are not checked.

    Falls back to MODEL_REPO when there is no authenticated API, no private
    repo is configured, the config is missing or lacks ``model_type``, or any
    listing/download/parse step raises.
    """
    api = get_api()
    if api and PRIVATE_MODEL:
        try:
            files = api.list_repo_files(PRIVATE_MODEL, repo_type="model", token=TOKEN)
            if "config.json" in files:
                # Double-check it's a real model config, not just a README
                from huggingface_hub import hf_hub_download
                import json
                cfg_path = hf_hub_download(PRIVATE_MODEL, "config.json", token=TOKEN)
                # Context manager so the handle is closed; JSON is UTF-8 by spec,
                # so don't depend on the platform's default encoding.
                with open(cfg_path, encoding="utf-8") as fh:
                    cfg = json.load(fh)
                # A valid config is a JSON object; anything else is not usable.
                if isinstance(cfg, dict) and "model_type" in cfg:
                    logger.info(f"Using private model: {PRIVATE_MODEL}")
                    return PRIVATE_MODEL
            logger.info(f"Private repo exists but has no weights yet β€” using base: {MODEL_REPO}")
        except Exception as e:
            logger.info(f"Private model check failed ({type(e).__name__}) β€” using base: {MODEL_REPO}")
    return MODEL_REPO


def get_model_kwargs() -> dict:
    """
    Build the keyword arguments to pass to ``from_pretrained()``.

    Contains ``token`` only when a token was resolved from the environment;
    otherwise an empty dict. A fresh dict is returned on every call.
    """
    return {"token": TOKEN} if TOKEN else {}


# =============================================================================
# Dataset Access
# =============================================================================

def _fetch_logs() -> list:
    """
    Download every row stored under ``data/*.parquet`` in DATASET_REPO.

    Unlike ``load_logs`` this does not swallow errors: whatever
    ``load_dataset`` raises propagates, so callers can tell "nothing stored
    yet" apart from "stored data could not be read".
    """
    ds = load_dataset(
        "parquet",
        data_files={"train": f"hf://datasets/{DATASET_REPO}/data/*.parquet"},
        split="train",
        token=TOKEN,
    )
    return ds.to_list()


def load_logs() -> list:
    """
    Return all stored log entries as a list of dicts (best effort).

    Returns an empty list when no token is configured, or when reading the
    dataset fails for any reason (the failure is logged at INFO level).
    """
    if not TOKEN:
        logger.warning("No token β€” dataset read skipped")
        return []
    try:
        return _fetch_logs()
    except Exception as e:
        logger.info(f"Dataset load: {type(e).__name__}: {e} β€” starting fresh")
        return []


def push_log(entry: dict) -> bool:
    """
    Append a log entry to the HF Dataset and push the whole dataset back.

    Every call rewrites the remote dataset (read all rows, append, push), so
    the existing rows must be read successfully first. If they cannot be read
    for any reason other than "nothing stored yet", the push is skipped
    rather than replacing the remote dataset with this single entry.

    Args:
        entry: dict with prompt, adi, response, model etc. Its "timestamp"
            key is set (overwriting any existing value) to the current UTC
            time as a naive ISO-8601 string. The caller's dict is mutated.

    Returns:
        True on success; False when there is no token, the existing logs
        could not be read, or building/pushing the dataset failed.
    """
    from datetime import timezone

    if not TOKEN:
        logger.warning("No token β€” dataset push skipped")
        return False
    try:
        existing = _fetch_logs()
    except FileNotFoundError as e:
        # `datasets` reports a missing repo / no matching parquet files with
        # FileNotFoundError subclasses: nothing stored yet, start fresh.
        logger.info(f"Dataset load: {type(e).__name__}: {e} β€” starting fresh")
        existing = []
    except Exception as e:
        # Network/auth/parse failures: pushing now would overwrite every
        # stored row with just this entry, so bail out instead.
        logger.warning(f"Dataset push skipped, existing logs unreadable: {type(e).__name__}: {e}")
        return False
    try:
        # Same naive-UTC ISO format as datetime.utcnow(), minus its deprecation.
        entry["timestamp"] = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        existing.append(entry)
        ds = Dataset.from_list(existing)
        ds.push_to_hub(DATASET_REPO, token=TOKEN, private=True)
        logger.info(f"Dataset updated β€” total entries: {len(existing)}")
        return True
    except Exception as e:
        logger.warning(f"Dataset push failed: {type(e).__name__}: {e}")
        return False


def push_model_card(info: dict) -> bool:
    """
    Overwrite README.md in PRIVATE_MODEL with a generated model card.

    The card records the base model, the dataset repo, the upload time
    (naive UTC ISO-8601) and ``info`` rendered as JSON. Useful for tracking
    which weights/config is deployed.

    Args:
        info: Deployment metadata to embed in the card's ```json block.

    Returns:
        True if the upload succeeded; False when there is no authenticated
        API or the upload raised (the error is logged).
    """
    api = get_api()
    if not api:
        return False
    try:
        import json
        from datetime import timezone

        # Emit real JSON inside the ```json fence (a dict's repr uses single
        # quotes / True / None). default=str stringifies values JSON can't
        # encode; repr() is the last resort (e.g. non-string keys, cycles),
        # so an odd payload never blocks the card update.
        try:
            info_text = json.dumps(info, indent=2, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            info_text = repr(info)
        # Same naive-UTC ISO format as datetime.utcnow(), minus its deprecation.
        updated = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        content = f"""---
language: en
license: apache-2.0
base_model: {MODEL_REPO}
---

# SmolLM2 Service

Base: `{MODEL_REPO}`  
Dataset: `{DATASET_REPO}`  
Last updated: {updated}

## Config
```json
{info_text}
```
"""
        api.upload_file(
            path_or_fileobj=content.encode(),
            path_in_repo="README.md",
            repo_id=PRIVATE_MODEL,
            repo_type="model",
            token=TOKEN,
        )
        logger.info(f"Model card updated: {PRIVATE_MODEL}")
        return True
    except Exception as e:
        logger.warning(f"Model card update failed: {type(e).__name__}: {e}")
        return False


# =============================================================================
# Health
# =============================================================================

def status() -> dict:
    """
    Summarise the model/dataset configuration for the health endpoint.

    Keys: token ("set"/"missing"), model_repo, private_model, dataset_repo,
    and hf_api ("authenticated"/"unauthenticated"). Computing hf_api calls
    get_api(), which may attempt a login when no client is cached yet.
    """
    report = dict(
        token="set" if TOKEN else "missing",
        model_repo=MODEL_REPO,
        private_model=PRIVATE_MODEL,
        dataset_repo=DATASET_REPO,
    )
    report["hf_api"] = "authenticated" if get_api() else "unauthenticated"
    return report