From 1641b9c1c1af815d64f0ab8c91fa82ee9f873f07 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 9 Sep 2025 11:50:24 +0200 Subject: [PATCH] quickstart working with memgraph --- graphiti_core/driver/memgraph_driver.py | 18 +++++------ graphiti_core/graph_queries.py | 6 ++-- graphiti_core/models/edges/edge_db_queries.py | 4 +-- graphiti_core/models/nodes/node_db_queries.py | 31 ++++++++++++++++--- graphiti_core/search/search_utils.py | 2 +- graphiti_core/utils/bulk_utils.py | 1 + 6 files changed, 41 insertions(+), 21 deletions(-) diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index e3377994..d14c8d83 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -18,7 +18,7 @@ import logging from collections.abc import Coroutine from typing import Any -from neo4j import GraphDatabase +from neo4j import AsyncGraphDatabase from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider @@ -31,7 +31,7 @@ class MemgraphDriver(GraphDriver): def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): super().__init__() - self.client = GraphDatabase.driver( + self.client = AsyncGraphDatabase.driver( uri=uri, auth=(user or '', password or ''), ) @@ -57,15 +57,11 @@ class MemgraphDriver(GraphDriver): database = kwargs.pop('database_', self._database) kwargs.pop('parameters_', None) # Remove if present - with self.client.session(database=database) as session: + async with self.client.session(database=database) as session: try: - # Debug: Print the query and parameters - print(f"DEBUG - Memgraph Query: {cypher_query_}") - print(f"DEBUG - Memgraph Params: {params}") - - result = session.run(cypher_query_, params) - records = list(result) - summary = result.consume() + result = await session.run(cypher_query_, params) + records = [record async for record in result] + summary = await result.consume() keys = result.keys() return (records, summary, keys) except Exception as e: @@ -77,7 +73,7 @@ class MemgraphDriver(GraphDriver): return self.client.session(database=_database) # type: ignore async def close(self) -> None: - return self.client.close() + return await self.client.close() def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: # TODO: Implement index deletion for Memgraph diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 56c2fb6c..9f85b1c1 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -140,7 +140,7 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search("{name}", {query}) YIELD node' + return f'CALL text_search.search_all("{name}", {query})' return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' @@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: return f'array_cosine_similarity({vec1}, {vec2})' if provider == GraphProvider.MEMGRAPH: - return f'CALL vector_search.cosine_similarity({vec1}, {vec2}) YIELD similarity RETURN similarity AS score' + return f'cosineSimilarity({vec1}, {vec2})' return f'vector.similarity.cosine({vec1}, {vec2})' @@ -169,6 +169,6 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search_edges("{name}", $query) YIELD node' + return f'CALL text_search.search_all_edges("{name}", $query)' return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 0ed6fc77..abb0edbb 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -167,8 +167,8 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: MATCH (target:Entity {uuid: edge.target_node_uuid}) MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) SET e = edge - WITH e, edge e.fact_embedding = edge.fact_embedding - RETURN edge.uuid AS uuid + SET e.fact_embedding = edge.fact_embedding + RETURN edge.uuid AS uuid; """ case _: return """ diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 081c0252..7cf075f0 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -92,7 +92,27 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j and Memgraph + case GraphProvider.MEMGRAPH: + return """ + UNWIND $episodes AS episode + MERGE (n:Episodic {uuid: episode.uuid}) + FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END | + SET n:`${episode.group_label}` + ) + SET n = { + uuid: episode.uuid, + name: episode.name, + group_id: episode.group_id, + source_description: episode.source_description, + source: episode.source, + content: episode.content, + entity_edges: episode.entity_edges, + created_at: episode.created_at, + valid_at: episode.valid_at + } + RETURN n.uuid AS uuid; + """ + case _: # Neo4j return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) @@ -235,10 +255,13 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) return """ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) - SET n:$(node.labels) + FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END | + SET n:`${label}` + ) SET n = node - WITH n, node SET n.name_embedding = node.name_embedding - RETURN n.uuid AS uuid + WITH n, node + SET n.name_embedding = node.name_embedding + RETURN n.uuid AS uuid; """ case _: # Neo4j return """ diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 14788816..ff08239d 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -563,7 +563,7 @@ async def node_fulltext_search( if driver.provider == GraphProvider.KUZU: yield_query = 'WITH node AS n, score' elif driver.provider == GraphProvider.MEMGRAPH: - yield_query = ' WITH node AS n, 1.0 AS score' # Memgraph: continue from YIELD node + yield_query = ' YIELD node AS n WITH n, 1.0 AS score' if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 426cbe90..ece4aea8 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -191,6 +191,7 @@ async def add_nodes_and_edges_bulk_tx( for edge in episodic_edges: await tx.run(episodic_edge_query, **edge.model_dump()) else: + print(episodes) await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) await tx.run(