import gzip
import pickle
from collections import Counter

import faiss
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer


# --- Artefact locations and model configuration -----------------------------

# Serialized FAISS index with the reference embeddings (read via faiss.read_index).
INDEX_PATH: str = "faiss.index"
# Gzip-compressed pickle holding a dict with "texts" and "statuses" entries.
META_PATH: str = "metadata.pkl.gz"
# Sentence-transformers model used to embed incoming text.
MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"
# Number of characters per chunk when request text is split before embedding.
# NOTE(review): the all-MiniLM-L6-v2 model card lists a 256 word-piece input
# limit, so 2000-character chunks are likely truncated by the encoder; confirm
# this size matches how the index was built.
CHUNK_SIZE: int = 2000


| | index = faiss.read_index(INDEX_PATH) |
| |
|
| | with gzip.open(META_PATH, "rb") as f: |
| | meta = pickle.load(f) |
| |
|
| | texts = meta["texts"] |
| | statuses = meta["statuses"] |
| |
|
| | |
# ASGI application object; the /predict route is registered on it below.
app = FastAPI(title="Text Embedding Predictor")

# Query-time sentence encoder.
# NOTE(review): presumably the same model that produced the vectors in the
# FAISS index; embeddings from a different model would not be comparable.
model = SentenceTransformer(MODEL_NAME)


class Query(BaseModel):
    """Request body accepted by ``POST /predict``."""

    # Text to classify. It must be non-empty: an empty string yields no chunks
    # and therefore no neighbours to vote with.
    text: str = Field(..., min_length=1)
    # Number of nearest neighbours retrieved per chunk; must be at least 1.
    k: int = Field(5, ge=1)


def split_text(text, chunk_size=CHUNK_SIZE):
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    Splitting is purely by character count, so words may be cut in half. The
    last piece holds the remainder and may be shorter. An empty string yields
    an empty list.

    Raises:
        ValueError: if *chunk_size* is less than 1. Without this check,
            ``range`` fails with a cryptic message for 0 and silently returns
            ``[]`` for negative sizes.
    """
    if chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )
    return [
        text[start:start + chunk_size]
        for start in range(0, len(text), chunk_size)
    ]


@app.post("/predict")
def predict(query: Query):
    """Predict a status for ``query.text`` by nearest-neighbour majority vote.

    The text is cut into ``CHUNK_SIZE``-character chunks with ``split_text``.
    All chunks are embedded in one ``model.encode`` call and searched against
    ``index`` together. Every neighbour returned for every chunk contributes
    its ``statuses`` entry as one vote.

    Returns:
        A dict with these keys:

        - ``prediction``: the most common status among all neighbours. Ties go
          to the status seen first. It is ``None`` if the index returned no
          neighbours (e.g. it is empty).
        - ``votes``: mapping of status to vote count.
        - ``top_k``: up to ``query.k`` neighbour records (``rank``, ``text``,
          ``status``, ``distance``) taken in chunk order.

    NOTE(review): ``top_k`` is not a global ranking across chunks. In practice
    it holds the first chunk's hits. If the overall best matches are wanted,
    sort by distance (the direction depends on the index metric); confirm the
    intent.

    Raises:
        HTTPException: 422 if ``query.text`` is empty or ``query.k`` < 1.
    """
    # An empty text yields no chunks to embed or search, and k < 1 yields no
    # neighbours. Reject both up front instead of failing with a 500 later.
    if not query.text:
        raise HTTPException(status_code=422, detail="text must not be empty")
    if query.k < 1:
        raise HTTPException(status_code=422, detail="k must be at least 1")

    # NOTE(review): doubling backslashes changes the text that gets embedded.
    # Keep this only if the indexed texts were escaped the same way; confirm.
    text_chunks = [chunk.replace("\\", "\\\\") for chunk in split_text(query.text)]

    # FAISS allocates an (n_chunks, k) result buffer and pads it with id -1
    # when k exceeds the number of stored vectors, so cap k at the index size.
    k = min(query.k, index.ntotal)

    all_top_statuses = []
    all_results = []
    if k > 0:
        # One batched encode + search instead of a round-trip per chunk.
        # Row i of `distances` / `indices` belongs to text_chunks[i].
        q_embs = model.encode(text_chunks).astype("float32")
        distances, indices = index.search(q_embs, k)
        for row_distances, row_indices in zip(distances, indices):
            for rank, (idx, dist) in enumerate(zip(row_indices, row_distances)):
                if idx < 0:
                    # -1 means "no neighbour"; statuses[-1] would silently
                    # return the last metadata entry instead.
                    continue
                status = statuses[idx]
                all_top_statuses.append(status)
                all_results.append({
                    "rank": rank + 1,
                    "text": texts[idx],
                    "status": status,
                    "distance": float(dist),
                })

    vote_counts = Counter(all_top_statuses)
    prediction = vote_counts.most_common(1)[0][0] if vote_counts else None

    return {
        "prediction": prediction,
        "votes": dict(vote_counts),
        "top_k": all_results[:query.k],
    }


if __name__ == "__main__":
    # Serve the already-constructed `app` object rather than the "app:app"
    # import string. When this file runs as a script it is the `__main__`
    # module, so an import string makes uvicorn import it again as `app`. With
    # reload=True that happens in a separate worker process. Either way the
    # module-level FAISS index, pickle and model loading runs twice and two
    # copies stay in memory. Auto-reload is a development feature; for local
    # development run e.g. `uvicorn app:app --reload --port 7860` (this
    # assumes the file is app.py, as the old import string did).
    uvicorn.run(app, host="0.0.0.0", port=7860)