"""
Clean, secure ML code: baseline for comparison with vulnerable_ml_code.py.

Demonstrates security best practices:
  - Structured prompts (no string interpolation with user input)
  - Model singletons loaded once at startup and cached
  - torch.no_grad() on all inference paths
  - BF16 dtype for memory efficiency
  - Batched embeddings
  - Parameterised SQL
  - Authentication middleware
  - torch.cuda.empty_cache() after inference
  - No hardcoded secrets
"""

from __future__ import annotations

import os
import secrets
import sqlite3
from functools import lru_cache
from typing import List

import torch
from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel

app = FastAPI(debug=False)  # No debug in production
security_scheme = HTTPBearer()

# ── Secrets from environment (never hardcoded) ───────────────
HF_TOKEN = os.getenv("HF_TOKEN")  # Set in .env, never in code
DB_PATH = os.getenv("DB_PATH", "knowledge.db")

# ── Singleton model loading at startup ───────────────────────

@lru_cache(maxsize=1)
def get_llm():
    """Load the LLM once and cache it; never per-request."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2", token=HF_TOKEN)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    if tokenizer.chat_template is None:
        # GPT-2 also ships without a chat template; set a minimal role-delimited
        # one (illustrative) so apply_chat_template() in /generate does not raise.
        tokenizer.chat_template = (
            "{% for message in messages %}"
            "{{ message['role'] }}: {{ message['content'] }}\n"
            "{% endfor %}"
            "{% if add_generation_prompt %}assistant:{% endif %}"
        )
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16,  # Half the VRAM of float32
        device_map="auto",
    )
    model.eval()
    return tokenizer, model


@lru_cache(maxsize=1)
def get_embedding_model() -> SentenceTransformer:
    """Load embedding model once at startup."""
    return SentenceTransformer("all-MiniLM-L6-v2")
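

@app.on_event("startup")
def _warm_models() -> None:
    # Eagerly populate the lru_cache singletons so the first request does not
    # pay the model-load latency. (Sketch: on_event is deprecated in newer
    # FastAPI in favour of lifespan handlers, but remains functional.)
    get_llm()
    get_embedding_model()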


# ── Auth middleware ───────────────────────────────────────────

def require_auth(credentials: HTTPAuthorizationCredentials = Security(security_scheme)):
    token = credentials.credentials
    valid_token = os.getenv("API_TOKEN", "")
    # Constant-time comparison avoids leaking the token through timing side channels
    if not valid_token or not secrets.compare_digest(token, valid_token):
        raise HTTPException(status_code=401, detail="Unauthorized")
    return token
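

# Example authenticated request (illustrative; assumes the server listens on
# localhost:8000 and API_TOKEN is exported in the client shell):
#   curl -X POST http://localhost:8000/generate \
#        -H "Authorization: Bearer $API_TOKEN" \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'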


# ── Request schemas ───────────────────────────────────────────

class GenerateRequest(BaseModel):
    message: str
    max_new_tokens: int = 200
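
# Note: a stricter variant (sketch) could bound max_new_tokens at the schema
# layer with pydantic's Field(gt=0, le=512), rejecting bad values with a 422
# before they ever reach the model.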


class EmbedRequest(BaseModel):
    documents: List[str]


class SearchRequest(BaseModel):
    query: str


# ── LLM01 Fix: Structured prompt (no string interpolation) ───

@app.post("/generate")
async def generate(body: GenerateRequest, _: str = Depends(require_auth)):
    """
    Chat endpoint β€” uses structured prompt template, never concatenates
    raw user input into the prompt instruction block.
    """
    tokenizer, model = get_llm()

    # Safe: user content is clearly separated from instruction
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": body.message},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():  # No gradient tracking during inference
        outputs = model.generate(
            **inputs,
            max_new_tokens=min(body.max_new_tokens, 512),  # LLM04: bounded output length
            pad_token_id=tokenizer.pad_token_id,  # Avoids the missing-pad-token warning
        )

    # Decode only the newly generated tokens, skipping the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    result_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Drop the GPU references before releasing cached VRAM; copying tensors to
    # CPU and deleting the copies would leave the originals alive on-device
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Return VRAM to the CUDA caching allocator

    # LLM02 Fix: NEVER eval() LLM output; parse structured JSON instead
    return {"result": result_text}


# ── A03 Fix: Parameterised SQL query ─────────────────────────

@app.get("/search")
async def rag_search(query: str, _: str = Depends(require_auth)):
    """Parameterised SQL β€” immune to SQL injection."""
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT * FROM documents WHERE content LIKE ?",
            (f"%{query}%",),  # Parameterised β€” safe
        )
        results = cursor.fetchall()
    finally:
        conn.close()
    return {"results": results}


# ── ML03 Fix: Batched embeddings ─────────────────────────────

@app.post("/embed_documents")
async def embed_documents(body: EmbedRequest, _: str = Depends(require_auth)):
    """Batch-encodes all documents in a single GPU call."""
    model = get_embedding_model()
    # Single batched call; avoids the N+1 pattern of one GPU round-trip per document
    embeddings = model.encode(
        body.documents,
        batch_size=32,
        show_progress_bar=False,
    )
    return {"embeddings": embeddings.tolist()}


# ── A01 Fix: Protected admin endpoint ────────────────────────

@app.post("/admin/retrain")
async def retrain_model(
    data: List[dict],
    _: str = Depends(require_auth),  # Auth required
):
    """Triggers retraining β€” authentication enforced."""
    # Validate data before accepting (LLM03 protection)
    if not data or len(data) > 10_000:
        raise HTTPException(status_code=400, detail="Invalid training data size")
    return {"status": "retraining queued", "samples": len(data)}


# ── A04 Fix: Safe model loading with safetensors ─────────────

@app.post("/load_model")
async def load_model(model_name: str, _: str = Depends(require_auth)):
    """
    Loads a model from HuggingFace Hub only (no arbitrary paths).
    Uses safetensors format β€” no pickle deserialization.
    """
    # Allowlist of approved models only
    ALLOWED_MODELS = {"gpt2", "distilgpt2", "facebook/opt-125m"}
    if model_name not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not in allowlist")

    # use_safetensors=True rejects pickle-based .bin checkpoints outright,
    # rather than merely preferring safetensors when both formats exist
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    )
    del model  # Loaded only to validate; a production service would cache it
    return {"status": "loaded", "model": model_name}