chore: Adds error handling to brute force triplet search
This commit is contained in:
parent
c66c43e717
commit
db07179856
2 changed files with 62 additions and 28 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, List
|
import logging
|
||||||
|
from typing import List
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
|
@ -76,42 +77,74 @@ def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is
|
||||||
return results_dict
|
return results_dict
|
||||||
|
|
||||||
|
|
||||||
async def brute_force_search(query: str, user: User, top_k: int, collections: List[str] = None) -> list:
|
async def brute_force_search(
|
||||||
|
query: str,
|
||||||
|
user: User,
|
||||||
|
top_k: int,
|
||||||
|
collections: List[str] = None
|
||||||
|
) -> list:
|
||||||
|
"""
|
||||||
|
Performs a brute force search to retrieve the top triplets from the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The search query.
|
||||||
|
user (User): The user performing the search.
|
||||||
|
top_k (int): The number of top results to retrieve.
|
||||||
|
collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The top triplet results.
|
||||||
|
"""
|
||||||
|
if not query or not isinstance(query, str):
|
||||||
|
raise ValueError("The query must be a non-empty string.")
|
||||||
|
if top_k <= 0:
|
||||||
|
raise ValueError("top_k must be a positive integer.")
|
||||||
|
|
||||||
if collections is None:
|
if collections is None:
|
||||||
collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]
|
collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
try:
|
||||||
graph_engine = await get_graph_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Failed to initialize engines: %s", e)
|
||||||
|
raise RuntimeError("Initialization error") from e
|
||||||
|
|
||||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
||||||
|
|
||||||
results = await asyncio.gather(
|
try:
|
||||||
*[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections]
|
results = await asyncio.gather(
|
||||||
)
|
*[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections]
|
||||||
|
)
|
||||||
|
|
||||||
############################################# :TODO: Change when vector db does not contain duplicates
|
############################################# :TODO: Change when vector db does not contain duplicates
|
||||||
node_distances = delete_duplicated_vector_db_elements(collections, results)
|
node_distances = delete_duplicated_vector_db_elements(collections, results)
|
||||||
# node_distances = {collection: result for collection, result in zip(collections, results)}
|
# node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||||
##############################################
|
##############################################
|
||||||
|
|
||||||
memory_fragment = CogneeGraph()
|
memory_fragment = CogneeGraph()
|
||||||
|
|
||||||
await memory_fragment.project_graph_from_db(graph_engine,
|
await memory_fragment.project_graph_from_db(graph_engine,
|
||||||
node_properties_to_project=['id',
|
node_properties_to_project=['id',
|
||||||
'description',
|
'description',
|
||||||
'name',
|
'name',
|
||||||
'type',
|
'type',
|
||||||
'text'],
|
'text'],
|
||||||
edge_properties_to_project=['relationship_name'])
|
edge_properties_to_project=['relationship_name'])
|
||||||
|
|
||||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||||
|
|
||||||
#:TODO: Change when vectordb contains edge embeddings
|
#:TODO: Change when vectordb contains edge embeddings
|
||||||
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
|
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
|
||||||
|
|
||||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||||
|
|
||||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
||||||
|
|
||||||
#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
|
#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e)
|
||||||
|
send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
|
||||||
|
raise RuntimeError("An error occurred during brute force search") from e
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import cognee
|
import cognee
|
||||||
import asyncio
|
import asyncio
|
||||||
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
||||||
|
from cognee.modules.retrieval.brute_force_triplet_search import format_triplets
|
||||||
|
|
||||||
job_1 = """
|
job_1 = """
|
||||||
CV 1: Relevant
|
CV 1: Relevant
|
||||||
|
|
@ -181,8 +182,8 @@ async def main(enable_steps):
|
||||||
|
|
||||||
# Step 4: Query insights
|
# Step 4: Query insights
|
||||||
if enable_steps.get("retriever"):
|
if enable_steps.get("retriever"):
|
||||||
await brute_force_triplet_search('Who has Phd?')
|
results = await brute_force_triplet_search('Who has the most experience with graphic design?')
|
||||||
|
print(format_triplets(results))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Flags to enable/disable steps
|
# Flags to enable/disable steps
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue