diff --git a/cognee/modules/retrieval/brute_force_triplet_search.py b/cognee/modules/retrieval/brute_force_triplet_search.py index 6fef6104f..ea7c2cb4d 100644 --- a/cognee/modules/retrieval/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/brute_force_triplet_search.py @@ -1,5 +1,6 @@ import asyncio -from typing import Dict, List +import logging +from typing import List from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph @@ -76,42 +77,74 @@ def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is return results_dict -async def brute_force_search(query: str, user: User, top_k: int, collections: List[str] = None) -> list: +async def brute_force_search( + query: str, + user: User, + top_k: int, + collections: List[str] = None +) -> list: + """ + Performs a brute force search to retrieve the top triplets from the graph. + + Args: + query (str): The search query. + user (User): The user performing the search. + top_k (int): The number of top results to retrieve. + collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections. + + Returns: + list: The top triplet results. + """ + if not query or not isinstance(query, str): + raise ValueError("The query must be a non-empty string.") + if top_k <= 0: + raise ValueError("top_k must be a positive integer.") + if collections is None: collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"] - vector_engine = get_vector_engine() - graph_engine = await get_graph_engine() + try: + vector_engine = get_vector_engine() + graph_engine = await get_graph_engine() + except Exception as e: + logging.error("Failed to initialize engines: %s", e) + raise RuntimeError("Initialization error") from e send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) - results = await asyncio.gather( - *[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections] - ) + try: + results = await asyncio.gather( + *[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections] + ) - ############################################# :TODO: Change when vector db does not contain duplicates - node_distances = delete_duplicated_vector_db_elements(collections, results) - # node_distances = {collection: result for collection, result in zip(collections, results)} - ############################################## + ############################################# :TODO: Change when vector db does not contain duplicates + node_distances = delete_duplicated_vector_db_elements(collections, results) + # node_distances = {collection: result for collection, result in zip(collections, results)} + ############################################## - memory_fragment = CogneeGraph() + memory_fragment = CogneeGraph() - await memory_fragment.project_graph_from_db(graph_engine, - node_properties_to_project=['id', - 'description', - 'name', - 'type', - 'text'], - edge_properties_to_project=['relationship_name']) + await memory_fragment.project_graph_from_db(graph_engine, + node_properties_to_project=['id', + 'description', + 'name', + 'type', + 'text'], + edge_properties_to_project=['relationship_name']) - await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - #:TODO: Change when vectordb contains edge embeddings - await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query) + #:TODO: Change when vectordb contains edge embeddings + await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query) - results = await memory_fragment.calculate_top_triplet_importances(k=top_k) + results = await memory_fragment.calculate_top_triplet_importances(k=top_k) - send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) + send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) - #:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db - return results + #:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db + return results + + except Exception as e: + logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e) + send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id) + raise RuntimeError("An error occurred during brute force search") from e diff --git a/examples/python/dynamic_steps_example.py b/examples/python/dynamic_steps_example.py index 68bbb7bce..ed5c97561 100644 --- a/examples/python/dynamic_steps_example.py +++ b/examples/python/dynamic_steps_example.py @@ -1,6 +1,7 @@ import cognee import asyncio from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.brute_force_triplet_search import format_triplets job_1 = """ CV 1: Relevant @@ -181,8 +182,8 @@ async def main(enable_steps): # Step 4: Query insights if enable_steps.get("retriever"): - await brute_force_triplet_search('Who has Phd?') - + results = await brute_force_triplet_search('Who has the most experience with graphic design?') + print(format_triplets(results)) if __name__ == '__main__': # Flags to enable/disable steps