import gc
import uuid

import chromadb
import torch
from sentence_transformers import SentenceTransformer
|
|
|
|
class is_docs:
    """Embed documents with a SentenceTransformer model and store them in Chroma.

    The instance owns the embedding model, a persistent Chroma client and the
    collection that embeddings are written to.
    """

    def __init__(
        self,
        model_name: str = "nlpai-lab/KURE-v1",
        cache_folder: str = "/Users/jaewook/PycharmProjects/DS_security_API/weights",
        db_path: str = "../db/docs",
        collection_name: str = "image_embedding",
    ):
        """Load the embedding model and open (or create) the Chroma collection.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
            cache_folder: Directory where the model weights are cached.
            db_path: Directory of the persistent Chroma database.
            collection_name: Name of the collection to open or create.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = (
            SentenceTransformer(
                model_name,
                cache_folder=cache_folder,
                trust_remote_code=True,
            )
            .eval()
            .to(self.device)
        )

        self.client_docs = chromadb.PersistentClient(path=db_path)
        # Chroma reads the distance metric from the "hnsw:space" key; the
        # previous "hnsw" key was stored as plain metadata and had no effect.
        # NOTE(review): a collection that already exists on disk presumably
        # keeps the metric it was created with -- confirm before relying on it.
        self.collection_docs = self.client_docs.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        self.cos_sim = torch.nn.CosineSimilarity(dim=0)

    async def making_embedding_vector(
        self,
        docs: "str | list[str]",
        category: int = 1,
        infer_mode: bool = False,
    ):
        """Encode ``docs`` and, unless ``infer_mode`` is set, persist the vectors.

        Args:
            docs: A single text or a list of texts to embed.
            category: Value stored under the ``"category"`` metadata key.
            infer_mode: When true, only compute the embeddings; store nothing.

        Returns:
            A flat list of floats when ``docs`` is a single string, otherwise
            a list with one embedding (list of floats) per input text.
        """
        is_single = isinstance(docs, str)
        texts = [docs] if is_single else list(docs)
        if not texts:
            return []

        # The context manager is used inside the body because decorating an
        # ``async def`` with ``torch.inference_mode()`` only wraps the creation
        # of the coroutine object, not the execution of its body.
        with torch.inference_mode():
            vectors = self.model.encode(texts).tolist()

        if not infer_mode:
            # One batched insert, with an independent metadata dict per vector.
            self.add_doc_vectors(
                vectors,
                [{"category": category} for _ in vectors],
            )

        # Release Python garbage and cached CUDA blocks after encoding.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return vectors[0] if is_single else vectors

    def add_doc_vectors(self, vectors, metadatas, ids=None):
        """Add one or more embeddings to the collection under unique IDs.

        Args:
            vectors: A single embedding (flat sequence of numbers) or a
                sequence of embeddings.
            metadatas: One metadata dict applied to every vector, or a list
                with one dict per vector.
            ids: Optional explicit ID or list of IDs. When omitted, a random
                UUID is generated per vector so inserts never collide.

        Returns:
            The list of IDs passed to the collection (empty if nothing was
            added).
        """
        if len(vectors) == 0:
            return []
        # A flat sequence of scalars is a single embedding; wrap it so the
        # rest of the method always works on a batch.
        if not hasattr(vectors[0], "__len__"):
            vectors = [vectors]

        if isinstance(metadatas, dict):
            metadatas = [dict(metadatas) for _ in vectors]

        if ids is None:
            ids = [uuid.uuid4().hex for _ in vectors]
        elif isinstance(ids, str):
            ids = [ids]

        self.collection_docs.add(
            embeddings=vectors,
            metadatas=metadatas,
            ids=ids,
        )
        return ids
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print the current working directory.
    import os

    current_dir = os.getcwd()
    print(current_dir)
| |
| |
| |
| |