| import pandas as pd |
| import os |
| import chromadb |
| from chromadb.utils import embedding_functions |
| import math |
|
|
|
|
|
|
|
|
|
|
def create_domain_identification_database(
    vdb_path: str,
    collection_name: str,
    df: pd.DataFrame,
    batch_size: int = 166,
) -> None:
    """Create a ChromaDB collection holding the rows of ``df``.

    Each row becomes one entry in the collection:

    * document: the row's ``query`` value (embedded with the
      ``sentence-transformers/LaBSE`` SentenceTransformer model),
    * metadata: ``{"domain": row.domain, "label": row.label}``,
    * id: ``"domain_id "`` followed by the row's index value.

    Entries are inserted in consecutive slices of at most ``batch_size`` rows.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to create.
        df (pd.DataFrame): Data to store. Must provide the columns ``query``,
            ``domain`` and ``label``.
        batch_size (int): Maximum number of rows sent per ``add`` call.
            Defaults to 166, the value that was previously hard-coded.
            NOTE(review): 166 presumably mirrors the maximum batch size of the
            Chroma build this was written against -- confirm before changing.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        AttributeError: If ``df`` lacks one of the required columns.
    """
    # Fail fast, before anything is created on disk.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    # Build the full payload in a single pass over the frame, and do it
    # before creating the collection so that a malformed DataFrame does not
    # leave an empty collection behind.
    documents = []
    metadatas = []
    ids = []
    for row in df.itertuples():
        documents.append(row.query)
        metadatas.append({"domain": row.domain, "label": row.label})
        ids.append("domain_id " + str(row.Index))

    chroma_client = chromadb.PersistentClient(path=vdb_path)

    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )

    # NOTE(review): create_collection is expected to fail when a collection
    # with this name already exists -- confirm against the Chroma version used.
    domain_identification_collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Slicing past the end of a list is safe, so the final (shorter) batch
    # needs no special case; an empty frame results in no add() calls.
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        domain_identification_collection.add(
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )
    return None
|
|
|
|
|
|
def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:
    """Remove a named collection from a persistent ChromaDB instance.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to delete.
    """
    # Open the on-disk instance and drop the collection in one step.
    chromadb.PersistentClient(path=vdb_path).delete_collection(collection_name)
|
|
|
|
def list_collections_from_vector_db(vdb_path: str) -> None:
    """Print the collections available in a persistent ChromaDB instance.

    The result of the client's ``list_collections`` call is written to
    standard output; nothing is returned.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    available_collections = client.list_collections()
    print(available_collections)
|
|
|
|
def get_collection_from_vector_db(
    vdb_path: str, collection_name: str
) -> chromadb.Collection:
    """Open a persistent ChromaDB instance and return one of its collections.

    The collection is requested together with a SentenceTransformer embedding
    function that uses the ``sentence-transformers/LaBSE`` model.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to retrieve.

    Returns:
        chromadb.Collection: The object returned by the client's
        ``get_collection`` call.
    """
    client = chromadb.PersistentClient(path=vdb_path)

    labse_embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )

    return client.get_collection(
        name=collection_name,
        embedding_function=labse_embedder,
    )
|
|
|
|
def retrieval(
    input_text: str,
    num_results: int,
    collection: chromadb.Collection,
) -> tuple:
    """Return the domain, label and distance of the first hit for ``input_text``.

    The collection is queried with ``input_text`` for ``num_results`` hits,
    but only the first hit of the result is used; the remaining hits are
    discarded.
    NOTE(review): the first hit is presumably the closest match (Chroma is
    expected to order hits by increasing distance) -- confirm.

    Args:
        input_text (str): The received text, e.g. news, posts or tweets.
        num_results (int): Number of examples requested from the collection.
        collection (chromadb.Collection): Collection to search.

    Returns:
        tuple: ``(domain, label, distance)`` where ``domain`` and ``label``
        come from the first hit's metadata and ``distance`` is the first
        entry of the query's distances.

    Raises:
        IndexError: If the query produced no hits (for instance, presumably,
            when the collection is empty).
    """
    query_result = collection.query(
        query_texts=[input_text],
        n_results=num_results,
    )

    # Exactly one query text was sent, so position 0 holds its hits.
    hit_metadatas = query_result["metadatas"][0]
    if not hit_metadatas:
        # Same exception type the bare indexing used to raise, but with a
        # message that explains what went wrong.
        raise IndexError(
            "collection.query returned no results for the given input_text"
        )

    first_hit = hit_metadatas[0]
    domain = first_hit["domain"]
    label = first_hit["label"]
    distance = query_result["distances"][0][0]

    return domain, label, distance