# NOTE(review): the "Spaces: Running" lines here were hosting-page UI chrome
# captured during extraction, not part of the module.
"""
Clean, secure ML code — baseline for comparison with vulnerable_ml_code.py.
Demonstrates security best-practices:
- Structured prompts (no string interpolation with user input)
- Model singleton loaded at startup
- @torch.no_grad on all inference paths
- BF16 dtype for memory efficiency
- Batched embeddings
- Parameterised SQL
- Authentication middleware
- torch.cuda.empty_cache() after inference
- No hardcoded secrets
"""
from __future__ import annotations

import os
import secrets
import sqlite3
from functools import lru_cache
from typing import List

import torch
from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI(debug=False)  # Debug off in production: avoids leaking stack traces
security_scheme = HTTPBearer()  # Bearer-token scheme consumed by require_auth below
# -- Secrets from environment (never hardcoded) ---------------
HF_TOKEN = os.getenv("HF_TOKEN")  # Set in .env, never in code; None when unset
DB_PATH = os.getenv("DB_PATH", "knowledge.db")  # SQLite file used by rag_search
# -- Singleton model loading at startup ------------------------
@lru_cache(maxsize=1)
def get_llm():
    """Load and cache the LLM + tokenizer as a process-wide singleton.

    The original docstring promised "load once" but the body rebuilt the
    model on every call; @lru_cache (already imported at file top, unused)
    makes the promise real.

    Returns:
        tuple: (tokenizer, model) with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16,  # ~50% VRAM vs float32
        device_map="auto",
    )
    model.eval()  # disable dropout/batch-norm updates for inference
    return tokenizer, model
@lru_cache(maxsize=1)
def get_embedding_model() -> SentenceTransformer:
    """Load the sentence-embedding model once and cache it.

    Without the cache the model was re-instantiated on every request,
    contradicting the "load once at startup" contract.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")
# -- Auth middleware -------------------------------------------
def require_auth(credentials: HTTPAuthorizationCredentials = Security(security_scheme)):
    """Validate the Bearer token against the API_TOKEN environment variable.

    Returns:
        str: the presented token, when valid.

    Raises:
        HTTPException: 401 when API_TOKEN is unset or the token mismatches.
    """
    token = credentials.credentials
    valid_token = os.getenv("API_TOKEN", "")
    # compare_digest: constant-time comparison — a plain != leaks timing
    # information an attacker can use to guess the token byte by byte.
    if not valid_token or not secrets.compare_digest(token, valid_token):
        raise HTTPException(status_code=401, detail="Unauthorized")
    return token
# -- Request schemas -------------------------------------------
class GenerateRequest(BaseModel):
    # Raw user chat message; kept separate from the system instruction.
    message: str
    # Requested generation budget; the generate endpoint caps it at 512.
    max_new_tokens: int = 200
class EmbedRequest(BaseModel):
    # Documents to embed; encoded in one batched call by embed_documents.
    documents: List[str]
class SearchRequest(BaseModel):
    # Search text. NOTE(review): rag_search takes a plain `query: str`
    # parameter, so this schema appears unused in the visible code.
    query: str
# -- LLM01 Fix: Structured prompt (no string interpolation) ----
async def generate(body: GenerateRequest, _: str = Depends(require_auth)):
    """Chat endpoint using a structured chat template.

    User content travels as a separate "user" message — never concatenated
    into the system instruction — so it cannot rewrite the instruction block.

    Args:
        body: user message plus an optional max_new_tokens budget.

    Returns:
        dict: {"result": generated text}.
    """
    tokenizer, model = get_llm()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": body.message},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():  # no autograd graph during inference
        outputs = model.generate(
            **inputs,
            # LLM04: clamp to [1, 512] — the old min() alone let a
            # negative request through to model.generate
            max_new_tokens=max(1, min(body.max_new_tokens, 512)),
        )
    result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop the GPU tensor references *before* empty_cache. The old code
    # copied tensors to CPU and deleted only the copies, so the GPU
    # tensors stayed alive and no VRAM was actually reclaimable.
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return freed VRAM to the pool
    # LLM02 Fix: NEVER eval() LLM output — return it as plain text
    return {"result": result_text}
# -- A03 Fix: Parameterised SQL query --------------------------
async def rag_search(query: str, _: str = Depends(require_auth)):
    """LIKE search over the documents table.

    The user-supplied text is bound as a SQL parameter, so it can never be
    interpreted as SQL (injection-safe).
    """
    pattern = f"%{query}%"
    connection = sqlite3.connect(DB_PATH)
    try:
        # connection.execute creates the cursor implicitly
        rows = connection.execute(
            "SELECT * FROM documents WHERE content LIKE ?",
            (pattern,),
        ).fetchall()
    finally:
        connection.close()
    return {"results": rows}
# -- ML03 Fix: Batched embeddings ------------------------------
async def embed_documents(body: EmbedRequest, _: str = Depends(require_auth)):
    """Encode every document in one batched GPU call (no per-item N+1 loop)."""
    encoder = get_embedding_model()
    vectors = encoder.encode(
        body.documents,
        batch_size=32,           # single batched pass over all documents
        show_progress_bar=False,
    )
    return {"embeddings": vectors.tolist()}
# -- A01 Fix: Protected admin endpoint -------------------------
async def retrain_model(
    data: List[dict],
    _: str = Depends(require_auth),  # auth enforced before anything runs
):
    """Queue a retraining job; rejects empty or oversized datasets (LLM03)."""
    sample_count = len(data)
    if sample_count == 0 or sample_count > 10_000:
        raise HTTPException(status_code=400, detail="Invalid training data size")
    return {"status": "retraining queued", "samples": sample_count}
# -- A04 Fix: Safe model loading with safetensors --------------
async def load_model(model_name: str, _: str = Depends(require_auth)):
    """Load an allowlisted model from the HuggingFace Hub.

    Args:
        model_name: must be one of the approved hub IDs below — no
            arbitrary repo IDs or local filesystem paths.

    Returns:
        dict: {"status": "loaded", "model": model_name}.

    Raises:
        HTTPException: 400 when the name is not on the allowlist.
    """
    # Allowlist of approved models only
    ALLOWED_MODELS = frozenset({"gpt2", "distilgpt2", "facebook/opt-125m"})
    if model_name not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not in allowlist")
    # use_safetensors=True *enforces* the no-pickle guarantee the docstring
    # claims — without it, from_pretrained silently falls back to a pickle
    # .bin checkpoint when no safetensors file exists.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    )
    return {"status": "loaded", "model": model_name}