From 2b1c17404c8102853f997dbe6583fa355aaf4795 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 23 Jul 2025 11:47:22 +0200 Subject: [PATCH] Feature: optimizes query embedding and edge collection search (#1126) ## Description Optimizes query embedding by reducing the number of query embedding calls, and avoids a redundant edge collection search when the edge distances are already available. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .../modules/graph/cognee_graph/CogneeGraph.py | 17 +++++++++-------- .../utils/brute_force_triplet_search.py | 6 +++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index ada8821eb..ba34c897a 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -132,18 +132,19 @@ class CogneeGraph(CogneeAbstractGraph): if node: node.add_attribute("vector_distance", score) - async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None: + async def map_vector_distances_to_graph_edges( + self, vector_engine, query_vector, edge_distances + ) -> None: try: - query_vector = await vector_engine.embed_data([query]) - query_vector = query_vector[0] if query_vector is None or len(query_vector) == 0: raise ValueError("Failed to generate query embedding.") - edge_distances = await vector_engine.search( - collection_name="EdgeType_relationship_name", - query_text=query, - limit=0, - ) + if edge_distances is None: + edge_distances = await vector_engine.search( + collection_name="EdgeType_relationship_name", + query_vector=query_vector, + limit=0, + ) embedding_map = {result.payload["text"]: result.score for result in edge_distances} diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py 
b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 4667f4738..bfe0aa521 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -177,8 +177,12 @@ async def brute_force_search( node_distances = {collection: result for collection, result in zip(collections, results)} + edge_distances = node_distances.get("EdgeType_relationship_name", None) + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query) + await memory_fragment.map_vector_distances_to_graph_edges( + vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances + ) results = await memory_fragment.calculate_top_triplet_importances(k=top_k)