# Source: Hugging Face Space by chmielvu — "Update app.py" (commit 1da0f6a, verified)
"""
FastEmbed-based Code Embedding Server
Optimized for CPU Basic (2 vCPU, 16GB RAM)
Models:
- Dense: jinaai/jina-embeddings-v2-small-en (512 dim)
- Sparse: Qdrant/bm25 (BM25, 0.01GB)
- Reranker: jinaai/jina-reranker-v1-turbo-en (0.13GB)
"""
import time
import uuid
from typing import Any, Literal
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel, ConfigDict, Field
from fastembed import TextEmbedding, SparseTextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Model names: FastEmbed-supported identifiers for the three model roles.
DENSE_MODEL = "jinaai/jina-embeddings-v2-small-en"
SPARSE_MODEL = "Qdrant/bm25"
RERANKER_MODEL = "jinaai/jina-reranker-v1-turbo-en"
# Global model cache (loaded once per process, reused across requests).
# Populated lazily by the _get_* helpers; None until first use.
_dense_model: TextEmbedding | None = None
_sparse_model: SparseTextEmbedding | None = None
_reranker_model: TextCrossEncoder | None = None
# FastAPI application instance; title/summary/version surface in the
# auto-generated OpenAPI docs.
app = FastAPI(
    title="FastEmbed Code Embeddings",
    summary="CPU-optimized code embeddings with BM25 sparse and reranking",
    version="2.0.0",
)
def _get_dense_model() -> TextEmbedding:
    """Return the process-wide dense embedding model, constructing it on first call."""
    global _dense_model
    if _dense_model is not None:
        return _dense_model
    _dense_model = TextEmbedding(model_name=DENSE_MODEL)
    return _dense_model
def _get_sparse_model() -> SparseTextEmbedding:
    """Return the process-wide sparse BM25 model, constructing it on first call."""
    global _sparse_model
    if _sparse_model is not None:
        return _sparse_model
    _sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL)
    return _sparse_model
def _get_reranker() -> TextCrossEncoder:
    """Return the process-wide cross-encoder reranker, constructing it on first call."""
    global _reranker_model
    if _reranker_model is not None:
        return _reranker_model
    _reranker_model = TextCrossEncoder(model_name=RERANKER_MODEL)
    return _reranker_model
# ==================== Request Models ====================
class EmbeddingRequest(BaseModel):
    """OpenAI-style request body for dense embeddings."""
    # Tolerate unknown fields so OpenAI-compatible clients aren't rejected.
    model_config = ConfigDict(extra="allow")
    # Single text or batch of texts to embed.
    input: str | list[str]
    # Echoed back in the response; not used to select the actual model.
    model: str = "code-embed"
    # "float" returns a list of floats; "base64" packs float32 bytes.
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: int = 0  # 0 = full dimensions; >0 truncates the vector
class SparseEmbeddingRequest(BaseModel):
    """Request body for sparse BM25 embeddings."""
    # Tolerate unknown fields from permissive clients.
    model_config = ConfigDict(extra="allow")
    # Single text or batch of texts to embed.
    input: str | list[str]
    # Echoed back in the response; not used to select the actual model.
    model: str = "bm25"
class RerankRequest(BaseModel):
    """Request body for cross-encoder reranking."""
    # Tolerate unknown fields from permissive clients.
    model_config = ConfigDict(extra="allow")
    # Search query; capped at 8192 characters.
    query: str = Field(..., max_length=8192)
    # 1-256 candidate documents to score against the query.
    documents: list[str] = Field(..., min_length=1, max_length=256)
    # When True, each result echoes the original document text.
    return_documents: bool = False
    # NOTE(review): accepted but never consulted by the /rerank handler —
    # confirm whether it should change the score output.
    raw_scores: bool = False
    # Echoed back in the response; not used to select the actual model.
    model: str = "code-rerank"
    # When set, only the top_n highest-scoring results are returned.
    top_n: int | None = None
class HybridRequest(BaseModel):
    """Request for hybrid search embeddings (dense + sparse)."""
    # Tolerate unknown fields from permissive clients.
    model_config = ConfigDict(extra="allow")
    # Single text or batch of texts to embed with both models.
    input: str | list[str]
    # Labels echoed in the response's "model" field; not used for selection.
    dense_model: str = "code-embed"
    sparse_model: str = "bm25"
# ==================== Helper Functions ====================
def _now_ts() -> int:
return int(time.time())
def _make_id(prefix: str) -> str:
return f"{prefix}-{uuid.uuid4().hex}"
def _normalize_input(input: str | list[str]) -> list[str]:
if isinstance(input, str):
return [input]
return input
def _truncate_embedding(vector: np.ndarray, dimensions: int) -> np.ndarray:
if dimensions > 0 and dimensions < len(vector):
return vector[:dimensions]
return vector
def _vector_to_payload(vector: np.ndarray, encoding_format: str) -> list[float] | str:
if encoding_format == "base64":
import base64
return base64.b64encode(vector.astype(np.float32).tobytes()).decode()
return vector.tolist()
# ==================== API Endpoints ====================
@app.get("/health")
def health() -> dict[str, str]:
return {"status": "ok", "models": f"{DENSE_MODEL} + {SPARSE_MODEL} + {RERANKER_MODEL}"}
@app.post("/embeddings")
@app.post("/v1/embeddings")
def embeddings(request: EmbeddingRequest) -> dict[str, Any]:
"""Generate dense embeddings using jina-embeddings-v2-base-code."""
texts = _normalize_input(request.input)
model = _get_dense_model()
# Generate embeddings (ONNX-optimized, cached)
embeddings_list = list(model.embed(texts))
data = []
for idx, embedding in enumerate(embeddings_list):
embedding = _truncate_embedding(embedding, request.dimensions)
data.append({
"object": "embedding",
"embedding": _vector_to_payload(embedding, request.encoding_format),
"index": idx,
})
return {
"object": "list",
"data": data,
"model": request.model,
"usage": {"prompt_tokens": sum(len(t.split()) for t in texts), "total_tokens": 0},
"id": _make_id("emb"),
"created": _now_ts(),
}
@app.post("/sparse/embeddings")
@app.post("/v1/sparse/embeddings")
def sparse_embeddings(request: SparseEmbeddingRequest) -> dict[str, Any]:
"""Generate sparse BM25 embeddings."""
texts = _normalize_input(request.input)
model = _get_sparse_model()
# Generate sparse embeddings
sparse_embeddings = list(model.embed(texts))
data = []
for idx, emb in enumerate(sparse_embeddings):
data.append({
"object": "sparse_embedding",
"indices": emb.indices.tolist(),
"values": emb.values.tolist(),
"index": idx,
})
return {
"object": "list",
"data": data,
"model": request.model,
"id": _make_id("sparse"),
"created": _now_ts(),
}
@app.post("/rerank")
@app.post("/v1/rerank")
def rerank(request: RerankRequest) -> dict[str, Any]:
"""Rerank documents using cross-encoder."""
reranker = _get_reranker()
# Compute rerank scores
scores = reranker.rerank(request.query, request.documents)
results = []
for idx, score in enumerate(scores):
item = {"index": idx, "relevance_score": float(score)}
if request.return_documents:
item["document"] = request.documents[idx]
results.append(item)
# Sort by relevance
results.sort(key=lambda x: x["relevance_score"], reverse=True)
if request.top_n is not None:
results = results[:request.top_n]
return {
"object": "rerank",
"results": results,
"model": request.model,
"usage": {
"prompt_tokens": len(request.query.split()),
"total_tokens": sum(len(d.split()) for d in request.documents),
},
"id": _make_id("rerank"),
"created": _now_ts(),
}
@app.post("/hybrid/embeddings")
@app.post("/v1/hybrid/embeddings")
def hybrid_embeddings(request: HybridRequest) -> dict[str, Any]:
"""Generate both dense and sparse embeddings for hybrid search."""
texts = _normalize_input(request.input)
dense_model = _get_dense_model()
sparse_model = _get_sparse_model()
# Generate both
dense_embeddings = list(dense_model.embed(texts))
sparse_embeddings = list(sparse_model.embed(texts))
data = []
for idx, (dense, sparse) in enumerate(zip(dense_embeddings, sparse_embeddings)):
data.append({
"object": "hybrid_embedding",
"dense": {
"vector": dense.tolist(),
"dim": len(dense),
},
"sparse": {
"indices": sparse.indices.tolist(),
"values": sparse.values.tolist(),
},
"index": idx,
})
return {
"object": "list",
"data": data,
"model": f"{request.dense_model} + {request.sparse_model}",
"id": _make_id("hybrid"),
"created": _now_ts(),
}
# ==================== Model Info ====================
@app.get("/models")
def list_models() -> dict[str, Any]:
"""List supported models and their specs."""
return {
"dense": {
"model": DENSE_MODEL,
"dim": 768,
"size_gb": 0.64,
"type": "code-optimized",
},
"sparse": {
"model": SPARSE_MODEL,
"type": "bm25",
"size_gb": 0.01,
"requires_idf": True,
},
"reranker": {
"model": RERANKER_MODEL,
"size_gb": 0.13,
"type": "cross-encoder",
},
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)