Feature: optimizes query embedding and edge collection search (#1126)

<!-- .github/pull_request_template.md -->

## Description
Optimizes query embedding by reducing the number of query embedding
calls and avoids multiple edge collection searches when they are
available.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
hajdul88 2025-07-23 11:47:22 +02:00 committed by GitHub
parent 59594e01ac
commit 2b1c17404c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 9 deletions

View file

@ -132,18 +132,19 @@ class CogneeGraph(CogneeAbstractGraph):
if node:
node.add_attribute("vector_distance", score)
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
async def map_vector_distances_to_graph_edges(
self, vector_engine, query_vector, edge_distances
) -> None:
try:
query_vector = await vector_engine.embed_data([query])
query_vector = query_vector[0]
if query_vector is None or len(query_vector) == 0:
raise ValueError("Failed to generate query embedding.")
edge_distances = await vector_engine.search(
collection_name="EdgeType_relationship_name",
query_text=query,
limit=0,
)
if edge_distances is None:
edge_distances = await vector_engine.search(
collection_name="EdgeType_relationship_name",
query_vector=query_vector,
limit=0,
)
embedding_map = {result.payload["text"]: result.score for result in edge_distances}

View file

@ -177,8 +177,12 @@ async def brute_force_search(
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
)
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)