| | import pandas as pd |
| | import os |
| | import chromadb |
| | from chromadb.utils import embedding_functions |
| | import math |
| |
|
| |
|
| |
|
| |
|
| |
|
def create_domain_identification_database(
    vdb_path: str,
    collection_name: str,
    df: pd.DataFrame,
    batch_size: int = 166,
) -> None:
    """Create a ChromaDB collection holding the domain-identification dataset.

    Each row of ``df`` becomes one record in a new collection:

    * document -- the row's ``query`` value, embedded with the LaBSE
      sentence-transformer model;
    * metadata -- ``{"domain": row.domain, "label": row.label}``;
    * id -- ``"domain_id " + str(index)`` built from the DataFrame index.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to create.
        df (pd.DataFrame): Dataset with at least ``query``, ``domain`` and
            ``label`` columns. Its index is used to build record ids, so it
            presumably must be unique -- confirm against ChromaDB's id rules.
        batch_size (int): Maximum number of records sent per ``add`` call.
            Defaults to 166, the value previously hard-coded.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching the database so a bad argument does not leave
    # an empty, half-created collection behind.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    chroma_client = chromadb.PersistentClient(path=vdb_path)

    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )

    # NOTE: create_collection (not get_or_create_collection) is used, so an
    # existing collection with the same name is expected to be an error.
    domain_identification_collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Build documents, metadata and ids in a single pass over the rows.
    documents = []
    metadatas = []
    ids = []
    for row in df.itertuples():
        documents.append(row.query)
        metadatas.append({"domain": row.domain, "label": row.label})
        ids.append("domain_id " + str(row.Index))

    # Add records in fixed-size chunks, presumably to stay under ChromaDB's
    # per-call batch limit. The final slice may be shorter than batch_size;
    # an empty DataFrame results in no add() calls at all.
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        domain_identification_collection.add(
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )
| |
|
| |
|
| |
|
def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:
    """Drop one collection from the persistent ChromaDB store.

    Args:
        vdb_path (str): Location of the persistent ChromaDB instance on disk.
        collection_name (str): Name of the collection to remove.
    """
    # Open a client on the store and delete the named collection in one step.
    chromadb.PersistentClient(path=vdb_path).delete_collection(collection_name)
| |
|
| |
|
def list_collections_from_vector_db(vdb_path: str) -> None:
    """Print the collections stored in the persistent ChromaDB instance.

    The result of ``list_collections()`` is printed to stdout; nothing is
    returned.

    Args:
        vdb_path (str): Location of the persistent ChromaDB instance on disk.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    available = client.list_collections()
    print(available)
| |
|
| |
|
def get_collection_from_vector_db(vdb_path: str, collection_name: str) -> chromadb.Collection:
    """Open an existing collection from the persistent ChromaDB instance.

    The collection is bound to the LaBSE sentence-transformer embedding
    function, which is the same model the creation routine in this module uses.
    Query texts are therefore embedded consistently with the stored documents.

    Args:
        vdb_path (str): Location of the persistent ChromaDB instance on disk.
        collection_name (str): Name of the collection to fetch.

    Returns:
        chromadb.Collection: The requested collection.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    labse_embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    return client.get_collection(name=collection_name, embedding_function=labse_embedder)
| |
|
| |
|
def retrieval(input_text: str,
              num_results: int,
              collection: chromadb.Collection) -> tuple:
    """Find the stored example most similar to ``input_text``.

    Args:
        input_text (str): The received text (e.g. a news item, post or tweet).
        num_results (int): Number of nearest neighbours requested from the
            collection. Only the first-ranked match is used for the result.
        collection (chromadb.Collection): Collection to search. It is expected
            to hold ``domain`` and ``label`` metadata on each record.

    Returns:
        tuple: ``(domain, label, distance)``. The first two come from the
        first-ranked match's metadata; ``distance`` is that match's distance.

    Raises:
        IndexError: If the query returns no matches (for example, because the
            collection is empty).
    """
    fetched = collection.query(
        query_texts=[input_text],
        n_results=num_results,
    )

    # query() returns one result list per query text; exactly one was sent.
    metadatas = fetched["metadatas"][0]
    distances = fetched["distances"][0]

    # Fail with a descriptive error instead of an opaque subscript failure.
    if not metadatas or not distances:
        raise IndexError(
            "The query returned no matching documents; the collection may be empty."
        )

    best_match = metadatas[0]
    return best_match["domain"], best_match["label"], distances[0]