Feature: optimizes query embedding and edge collection search (#1126)
<!-- .github/pull_request_template.md --> ## Description Optimizes query embedding by reducing the number of query embedding calls and avoids multiple edge collection searches when they are available. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
c0805b3aea
commit
d7b66dae03
2 changed files with 14 additions and 9 deletions
|
|
@ -132,18 +132,19 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
if node:
|
if node:
|
||||||
node.add_attribute("vector_distance", score)
|
node.add_attribute("vector_distance", score)
|
||||||
|
|
||||||
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
|
async def map_vector_distances_to_graph_edges(
|
||||||
|
self, vector_engine, query_vector, edge_distances
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
query_vector = await vector_engine.embed_data([query])
|
|
||||||
query_vector = query_vector[0]
|
|
||||||
if query_vector is None or len(query_vector) == 0:
|
if query_vector is None or len(query_vector) == 0:
|
||||||
raise ValueError("Failed to generate query embedding.")
|
raise ValueError("Failed to generate query embedding.")
|
||||||
|
|
||||||
edge_distances = await vector_engine.search(
|
if edge_distances is None:
|
||||||
collection_name="EdgeType_relationship_name",
|
edge_distances = await vector_engine.search(
|
||||||
query_text=query,
|
collection_name="EdgeType_relationship_name",
|
||||||
limit=0,
|
query_vector=query_vector,
|
||||||
)
|
limit=0,
|
||||||
|
)
|
||||||
|
|
||||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -177,8 +177,12 @@ async def brute_force_search(
|
||||||
|
|
||||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||||
|
|
||||||
|
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||||
|
|
||||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||||
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
|
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||||
|
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
||||||
|
)
|
||||||
|
|
||||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue