chore: Adds error handling to brute force triplet search

This commit is contained in:
hajdul88 2024-11-26 16:17:57 +01:00
parent c66c43e717
commit db07179856
2 changed files with 62 additions and 28 deletions

View file

@ -1,5 +1,6 @@
import asyncio
from typing import Dict, List
import logging
from typing import List
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@ -76,42 +77,74 @@ def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is
return results_dict
async def brute_force_search(query: str, user: User, top_k: int, collections: List[str] = None) -> list:
async def brute_force_search(
query: str,
user: User,
top_k: int,
collections: List[str] = None
) -> list:
"""
Performs a brute force search to retrieve the top triplets from the graph.
Args:
query (str): The search query.
user (User): The user performing the search.
top_k (int): The number of top results to retrieve.
collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.
Returns:
list: The top triplet results.
"""
if not query or not isinstance(query, str):
raise ValueError("The query must be a non-empty string.")
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
if collections is None:
collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
try:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
except Exception as e:
logging.error("Failed to initialize engines: %s", e)
raise RuntimeError("Initialization error") from e
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
results = await asyncio.gather(
*[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections]
)
try:
results = await asyncio.gather(
*[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections]
)
############################################# :TODO: Change when vector db does not contain duplicates
node_distances = delete_duplicated_vector_db_elements(collections, results)
# node_distances = {collection: result for collection, result in zip(collections, results)}
##############################################
############################################# :TODO: Change when vector db does not contain duplicates
node_distances = delete_duplicated_vector_db_elements(collections, results)
# node_distances = {collection: result for collection, result in zip(collections, results)}
##############################################
memory_fragment = CogneeGraph()
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(graph_engine,
node_properties_to_project=['id',
'description',
'name',
'type',
'text'],
edge_properties_to_project=['relationship_name'])
await memory_fragment.project_graph_from_db(graph_engine,
node_properties_to_project=['id',
'description',
'name',
'type',
'text'],
edge_properties_to_project=['relationship_name'])
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
#:TODO: Change when vectordb contains edge embeddings
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
#:TODO: Change when vectordb contains edge embeddings
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
return results
#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
return results
except Exception as e:
logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e)
send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
raise RuntimeError("An error occurred during brute force search") from e

View file

@ -1,6 +1,7 @@
import cognee
import asyncio
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.brute_force_triplet_search import format_triplets
job_1 = """
CV 1: Relevant
@ -181,8 +182,8 @@ async def main(enable_steps):
# Step 4: Query insights
if enable_steps.get("retriever"):
await brute_force_triplet_search('Who has Phd?')
results = await brute_force_triplet_search('Who has the most experience with graphic design?')
print(format_triplets(results))
if __name__ == '__main__':
# Flags to enable/disable steps