Merge pull request #436 from topoteretes/feature/cog-762-deleting-in-memory-embeddings-from-bruteforce-search-and
feat: deletes on the fly embeddings and uses edge collections
This commit is contained in:
commit
25d8f5e337
2 changed files with 7 additions and 56 deletions
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInte
|
|||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
|
||||
import heapq
|
||||
from graphistry import edges
|
||||
import asyncio
|
||||
|
||||
|
||||
class CogneeGraph(CogneeAbstractGraph):
|
||||
|
|
@ -127,51 +127,25 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
else:
|
||||
print(f"Node with id {node_id} not found in the graph.")
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
|
||||
self, vector_engine, query
|
||||
) -> None: # :TODO: When we calculate edge embeddings in vector db change this similarly to node mapping
|
||||
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
|
||||
try:
|
||||
# Step 1: Generate the query embedding
|
||||
query_vector = await vector_engine.embed_data([query])
|
||||
query_vector = query_vector[0]
|
||||
if query_vector is None or len(query_vector) == 0:
|
||||
raise ValueError("Failed to generate query embedding.")
|
||||
|
||||
# Step 2: Collect all unique relationship types
|
||||
unique_relationship_types = set()
|
||||
for edge in self.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
if relationship_type:
|
||||
unique_relationship_types.add(relationship_type)
|
||||
edge_distances = await vector_engine.get_distance_from_collection_elements(
|
||||
"edge_type_relationship_name", query_text=query
|
||||
)
|
||||
|
||||
# Step 3: Embed all unique relationship types
|
||||
unique_relationship_types = list(unique_relationship_types)
|
||||
relationship_type_embeddings = await vector_engine.embed_data(unique_relationship_types)
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
# Step 4: Map relationship types to their embeddings and calculate distances
|
||||
embedding_map = {}
|
||||
for relationship_type, embedding in zip(
|
||||
unique_relationship_types, relationship_type_embeddings
|
||||
):
|
||||
edge_vector = np.array(embedding)
|
||||
|
||||
# Calculate cosine similarity
|
||||
similarity = np.dot(query_vector, edge_vector) / (
|
||||
np.linalg.norm(query_vector) * np.linalg.norm(edge_vector)
|
||||
)
|
||||
distance = 1 - similarity
|
||||
|
||||
# Round the distance to 4 decimal places and store it
|
||||
embedding_map[relationship_type] = round(distance, 4)
|
||||
|
||||
# Step 4: Assign precomputed distances to edges
|
||||
for edge in self.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
if not relationship_type or relationship_type not in embedding_map:
|
||||
print(f"Edge {edge} has an unknown or missing relationship type.")
|
||||
continue
|
||||
|
||||
# Assign the precomputed distance
|
||||
edge.attributes["vector_distance"] = embedding_map[relationship_type]
|
||||
|
||||
except Exception as ex:
|
||||
|
|
|
|||
|
|
@ -62,24 +62,6 @@ async def brute_force_triplet_search(
|
|||
return retrieved_results
|
||||
|
||||
|
||||
def delete_duplicated_vector_db_elements(
    collections, results
):  #:TODO: This is just for now to fix vector db duplicates
    """Drop duplicate results (by ``.id``) within each collection's result list.

    Pairs each collection name with its corresponding result list and keeps
    only the first occurrence of every result id, preserving order. A warning
    is printed for each duplicate encountered.

    Args:
        collections: Iterable of collection names.
        results: Iterable of result lists, positionally aligned with
            ``collections``; each result must expose an ``id`` attribute.

    Returns:
        dict mapping each collection name to its de-duplicated result list.
    """
    results_dict = {}
    # NOTE: loop variable renamed from `results` to avoid shadowing the parameter.
    for collection, collection_results in zip(collections, results):
        seen_ids = set()
        unique_results = []
        for result in collection_results:
            if result.id not in seen_ids:
                unique_results.append(result)
                seen_ids.add(result.id)
            else:
                print(f"Duplicate found in collection '{collection}': {result.id}")
        results_dict[collection] = unique_results

    return results_dict
|
||||
|
||||
|
||||
async def brute_force_search(
|
||||
query: str, user: User, top_k: int, collections: List[str] = None
|
||||
) -> list:
|
||||
|
|
@ -125,10 +107,7 @@ async def brute_force_search(
|
|||
]
|
||||
)
|
||||
|
||||
############################################# :TODO: Change when vector db does not contain duplicates
|
||||
node_distances = delete_duplicated_vector_db_elements(collections, results)
|
||||
# node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
##############################################
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
memory_fragment = CogneeGraph()
|
||||
|
||||
|
|
@ -140,14 +119,12 @@ async def brute_force_search(
|
|||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
|
||||
#:TODO: Change when vectordb contains edge embeddings
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
||||
|
||||
#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue