From a78fec3a918e80e221cc5260bf12c97f7cb7e22e Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 12 May 2025 13:06:57 +0200 Subject: [PATCH] fix: Fixes collection search limit in brute force triplet search (#814) ## Description ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .github/workflows/test_memgraph.yml | 8 +-- .../graph/memgraph/memgraph_adapter.py | 63 ++++++++++--------- .../utils/brute_force_triplet_search.py | 2 +- cognee/tests/test_memgraph.py | 2 +- 4 files changed, 39 insertions(+), 36 deletions(-) diff --git a/.github/workflows/test_memgraph.yml b/.github/workflows/test_memgraph.yml index e160382f4..b7ea9d837 100644 --- a/.github/workflows/test_memgraph.yml +++ b/.github/workflows/test_memgraph.yml @@ -1,9 +1,9 @@ name: test | memgraph -on: - workflow_dispatch: - pull_request: - types: [labeled, synchronize] +# on: +# workflow_dispatch: +# pull_request: +# types: [labeled, synchronize] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py b/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py index f3a974359..5ef438077 100644 --- a/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +++ b/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py @@ -16,6 +16,7 @@ from cognee.modules.storage.utils import JSONEncoder logger = get_logger("MemgraphAdapter", level=ERROR) + class MemgraphAdapter(GraphDBInterface): def __init__( self, @@ -34,7 +35,7 @@ class MemgraphAdapter(GraphDBInterface): async def get_session(self) -> AsyncSession: async with self.driver.session() as session: yield session - + async def query( self, query: str, @@ -48,7 +49,7 @@ class MemgraphAdapter(GraphDBInterface): except Neo4jError as error: logger.error("Memgraph query error: %s", error, exc_info=True) raise error - + async def has_node(self, node_id: str) -> bool: results = await self.query( """ @@ -59,7 +60,7 @@ class MemgraphAdapter(GraphDBInterface): {"node_id": node_id}, ) return results[0]["node_exists"] if len(results) > 0 else False - + async def add_node(self, node: DataPoint): serialized_properties = self.serialize_properties(node.model_dump()) @@ -102,7 +103,7 @@ class MemgraphAdapter(GraphDBInterface): results = await self.extract_nodes([node_id]) return results[0] if len(results) > 0 else None - + async def extract_nodes(self, node_ids: List[str]): query = """ UNWIND $node_ids AS id @@ -114,15 +115,15 @@ class MemgraphAdapter(GraphDBInterface): results = await self.query(query, params) return [result["node"] for result in results] - + async def delete_node(self, node_id: str): sanitized_id = node_id.replace(":", "_") - + query = "MATCH (node: {{id: $node_id}}) DETACH DELETE node" params = {"node_id": sanitized_id} return await self.query(query, params) - + async def delete_nodes(self, node_ids: list[str]) -> None: query = """ UNWIND $node_ids AS id @@ -132,7 +133,7 @@ class MemgraphAdapter(GraphDBInterface): params = {"node_ids": node_ids} return await self.query(query, params) - + async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: query = """ MATCH (from_node)-[relationship]->(to_node) @@ -145,10 +146,10 @@ class MemgraphAdapter(GraphDBInterface): "to_node_id": str(to_node), "edge_label": edge_label, } - + records = await self.query(query, params) return records[0]["edge_exists"] if records else False - + async def has_edges(self, edges): query = """ UNWIND $edges AS edge @@ -174,7 +175,7 @@ class MemgraphAdapter(GraphDBInterface): except Neo4jError as error: logger.error("Memgraph query error: %s", error, exc_info=True) raise error - + async def add_edge( self, from_node: UUID, @@ -203,7 +204,7 @@ class MemgraphAdapter(GraphDBInterface): } return await self.query(query, params) - + async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: query = """ UNWIND $edges AS edge @@ -217,7 +218,7 @@ class MemgraphAdapter(GraphDBInterface): target_node_id: edge.to_node }, edge.properties, - to_node, + to_node, {} ) YIELD rel RETURN rel""" @@ -242,7 +243,7 @@ class MemgraphAdapter(GraphDBInterface): except Neo4jError as error: logger.error("Memgraph query error: %s", error, exc_info=True) raise error - + async def get_edges(self, node_id: str): query = """ MATCH (n {id: $node_id})-[r]-(m) @@ -255,7 +256,7 @@ class MemgraphAdapter(GraphDBInterface): (result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]}) for result in results ] - + async def get_disconnected_nodes(self) -> list[str]: query = """ // Step 1: Collect all nodes @@ -290,7 +291,7 @@ class MemgraphAdapter(GraphDBInterface): results = await self.query(query) return results[0]["ids"] if len(results) > 0 else [] - + async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]: if edge_label is not None: query = """ @@ -323,7 +324,7 @@ class MemgraphAdapter(GraphDBInterface): ) return [result["predecessor"] for result in results] - + async def get_successors(self, node_id: str, edge_label: str = None) -> list[str]: if edge_label is not None: query = """ @@ -356,14 +357,14 @@ class MemgraphAdapter(GraphDBInterface): ) return [result["successor"] for result in results] - + async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]: predecessors, successors = await asyncio.gather( self.get_predecessors(node_id), self.get_successors(node_id) ) return predecessors + successors - + async def get_connections(self, node_id: UUID) -> list: predecessors_query = """ MATCH (node)<-[relation]-(neighbour) @@ -392,7 +393,7 @@ class MemgraphAdapter(GraphDBInterface): connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2])) return connections - + async def remove_connection_to_predecessors_of( self, node_ids: list[str], edge_label: str ) -> None: @@ -406,7 +407,7 @@ class MemgraphAdapter(GraphDBInterface): params = {"node_ids": node_ids, "edge_label": edge_label} return await self.query(query, params) - + async def remove_connection_to_successors_of( self, node_ids: list[str], edge_label: str ) -> None: @@ -419,13 +420,13 @@ class MemgraphAdapter(GraphDBInterface): params = {"node_ids": node_ids} return await self.query(query, params) - + async def delete_graph(self): query = """MATCH (node) DETACH DELETE node;""" return await self.query(query) - + def serialize_properties(self, properties=dict()): serialized_properties = {} @@ -441,7 +442,7 @@ class MemgraphAdapter(GraphDBInterface): serialized_properties[property_key] = property_value return serialized_properties - + async def get_model_independent_graph_data(self): query_nodes = "MATCH (n) RETURN collect(n) AS nodes" nodes = await self.query(query_nodes) @@ -450,7 +451,7 @@ class MemgraphAdapter(GraphDBInterface): edges = await self.query(query_edges) return (nodes, edges) - + async def get_graph_data(self): query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties" @@ -480,7 +481,7 @@ class MemgraphAdapter(GraphDBInterface): ] return (nodes, edges) - + async def get_filtered_graph_data(self, attribute_filters): """ Fetches nodes and relationships filtered by specified attribute values. @@ -536,7 +537,7 @@ class MemgraphAdapter(GraphDBInterface): return (nodes, edges) async def get_node_labels_string(self): - node_labels_query = f""" + node_labels_query = """ MATCH (n) WITH DISTINCT labels(n) AS labelList UNWIND labelList AS label @@ -552,7 +553,9 @@ class MemgraphAdapter(GraphDBInterface): return node_labels_str async def get_relationship_labels_string(self): - relationship_types_query = "MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;" + relationship_types_query = ( + "MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;" + ) relationship_types_result = await self.query(relationship_types_query) relationship_types = ( relationship_types_result[0]["relationships"] if relationship_types_result else [] @@ -643,7 +646,7 @@ class MemgraphAdapter(GraphDBInterface): WITH n, degree, COUNT(n2) AS triangle_count // Step 4: Compute local clustering coefficient - WITH n, degree, + WITH n, degree, CASE WHEN degree <= 1 THEN 0.0 ELSE (1.0 * triangle_count) / (degree * (degree - 1) / 2.0) END AS local_cc @@ -684,4 +687,4 @@ class MemgraphAdapter(GraphDBInterface): "diameter": -1, "avg_shortest_path_length": -1, "avg_clustering": -1, - } \ No newline at end of file + } diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 5e08eb9ac..0a08fbd00 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -146,7 +146,7 @@ async def brute_force_search( async def search_in_collection(collection_name: str): try: return await vector_engine.search( - collection_name=collection_name, query_text=query, limit=top_k + collection_name=collection_name, query_text=query, limit=0 ) except CollectionNotFoundError: return [] diff --git a/cognee/tests/test_memgraph.py b/cognee/tests/test_memgraph.py index f3363d2f1..49f3e08db 100644 --- a/cognee/tests/test_memgraph.py +++ b/cognee/tests/test_memgraph.py @@ -95,7 +95,7 @@ async def main(): await cognee.prune.prune_system(metadata=True) from cognee.infrastructure.databases.graph import get_graph_engine - + graph_engine = await get_graph_engine() nodes, edges = await graph_engine.get_graph_data() assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"