From 0f9dd11cc85f52ef9bb27400f09ab39c19c0e3d7 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 4 Sep 2025 12:28:53 +0200 Subject: [PATCH 1/7] add memgraph as graphdb vendor wip --- README.md | 4 +- examples/quickstart/README.md | 14 +- examples/quickstart/quickstart_memgraph.py | 252 ++++++++++++++++++ graphiti_core/driver/__init__.py | 5 +- graphiti_core/driver/driver.py | 1 + graphiti_core/driver/memgraph_driver.py | 67 +++++ graphiti_core/graph_queries.py | 41 +++ graphiti_core/models/edges/edge_db_queries.py | 21 +- graphiti_core/models/nodes/node_db_queries.py | 28 +- tests/helpers_test.py | 18 ++ 10 files changed, 442 insertions(+), 9 deletions(-) create mode 100644 examples/quickstart/quickstart_memgraph.py create mode 100644 graphiti_core/driver/memgraph_driver.py diff --git a/README.md b/README.md index 5dfe57fe..b0b10f73 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ particularly suitable for applications requiring real-time interaction and preci Requirements: - Python 3.10 or higher -- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon +- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Memgraph 2.22+ / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend) - OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding) @@ -556,7 +556,7 @@ When you initialize a Graphiti instance, we collect: - **Graphiti version**: The version you're using - **Configuration choices**: - LLM provider type (OpenAI, Azure, Anthropic, etc.) - - Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics) + - Database backend (Neo4j, FalkorDB, Kuzu, Memgraph, Amazon Neptune Database or Neptune Analytics) - Embedder provider (OpenAI, Azure, Voyage, etc.) ### What We Don't Collect diff --git a/examples/quickstart/README.md b/examples/quickstart/README.md index a172d00c..7306fff2 100644 --- a/examples/quickstart/README.md +++ b/examples/quickstart/README.md @@ -2,7 +2,7 @@ This example demonstrates the basic functionality of Graphiti, including: -1. Connecting to a Neo4j or FalkorDB database +1. Connecting to a Neo4j, FalkorDB, or Memgraph database 2. Initializing Graphiti indices and constraints 3. Adding episodes to the graph 4. Searching the graph with semantic and keyword matching @@ -18,6 +18,8 @@ This example demonstrates the basic functionality of Graphiti, including: - A local DBMS created and started in Neo4j Desktop - **For FalkorDB**: - FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup) +- **For Memgraph**: + - Memgraph server running (see [Memgraph documentation](https://memgraph.com/docs/) for setup) - **For Amazon Neptune**: - Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup) @@ -44,6 +46,11 @@ export NEO4J_PASSWORD=password # Optional FalkorDB connection parameters (defaults shown) export FALKORDB_URI=falkor://localhost:6379 +# Optional Memgraph connection parameters (defaults shown) +export MEMGRAPH_URI=bolt://localhost:7687 +export MEMGRAPH_USER= # Optional - Memgraph doesn't require auth by default +export MEMGRAPH_PASSWORD= # Optional - Memgraph doesn't require auth by default + # Optional Amazon Neptune connection parameters NEPTUNE_HOST=your_neptune_host NEPTUNE_PORT=your_port_or_8182 @@ -65,13 +72,16 @@ python quickstart_neo4j.py # For FalkorDB python quickstart_falkordb.py +# For Memgraph +python quickstart_memgraph.py + # For Amazon Neptune python quickstart_neptune.py ``` ## What This Example Demonstrates -- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB +- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, FalkorDB, Memgraph, or Amazon Neptune - **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges - **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges) - **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance diff --git a/examples/quickstart/quickstart_memgraph.py b/examples/quickstart/quickstart_memgraph.py new file mode 100644 index 00000000..911991ec --- /dev/null +++ b/examples/quickstart/quickstart_memgraph.py @@ -0,0 +1,252 @@ +""" +Copyright 2025, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import json +import logging +import os +from datetime import datetime, timezone +from logging import INFO + +from dotenv import load_dotenv + +from graphiti_core import Graphiti +from graphiti_core.driver.memgraph_driver import MemgraphDriver +from graphiti_core.nodes import EpisodeType +from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF + +################################################# +# CONFIGURATION +################################################# +# Set up logging and environment variables for +# connecting to Memgraph database +################################################# + +# Configure logging +logging.basicConfig( + level=INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', +) +logger = logging.getLogger(__name__) + +load_dotenv() + +# Memgraph connection parameters +# Make sure Memgraph is running (default port 7687, same as Neo4j) +memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687') +memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default +memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '') + +if not memgraph_uri: + raise ValueError('MEMGRAPH_URI must be set') + + +async def main(): + ################################################# + # INITIALIZATION + ################################################# + # Connect to Memgraph and set up Graphiti indices + # This is required before using other Graphiti + # functionality + ################################################# + + # Initialize Memgraph driver + memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password) + + # Initialize Graphiti with Memgraph connection + graphiti = Graphiti(graph_driver=memgraph_driver) + + try: + # Initialize the graph database with graphiti's indices. This only needs to be done once. + await graphiti.build_indices_and_constraints() + + ################################################# + # ADDING EPISODES + ################################################# + # Episodes are the primary units of information + # in Graphiti. They can be text or structured JSON + # and are automatically processed to extract entities + # and relationships. + ################################################# + + # Example: Add Episodes + # Episodes list containing both text and JSON episodes + episodes = [ + { + 'content': 'Kamala Harris is the Attorney General of California. She was previously ' + 'the district attorney for San Francisco.', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + { + 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + { + 'content': { + 'name': 'Gavin Newsom', + 'position': 'Governor', + 'state': 'California', + 'predecessor': 'Jerry Brown', + }, + 'type': EpisodeType.json, + 'description': 'politician info', + }, + { + 'content': 'Jerry Brown was the predecessor of Gavin Newsom as Governor of California', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + ] + + # Add episodes to the graph + for i, episode in enumerate(episodes): + await graphiti.add_episode( + name=f'Memgraph Demo {i}', + episode_body=episode['content'] + if isinstance(episode['content'], str) + else json.dumps(episode['content']), + source=episode['type'], + source_description=episode['description'], + reference_time=datetime.now(timezone.utc), + ) + logger.info(f'Added episode: Memgraph Demo {i} ({episode["type"].value})') + + logger.info('Episodes added successfully!') + + ################################################# + # BASIC SEARCH + ################################################# + # The simplest way to retrieve relationships (edges) + # from Graphiti is using the search method, which + # performs a hybrid search combining semantic + # similarity and BM25 text retrieval. + ################################################# + + # Perform a hybrid search combining semantic similarity and BM25 retrieval + logger.info('\n=== BASIC SEARCH ===') + search_query = 'Who was the California Attorney General?' + logger.info(f'Search query: {search_query}') + + # Perform semantic search across edges (relationships) + results = await graphiti.search(search_query) + + logger.info('Search results:') + for result in results[:3]: # Show top 3 results + logger.info(f'UUID: {result.uuid}') + logger.info(f'Fact: {result.fact}') + if hasattr(result, 'valid_at') and result.valid_at: + logger.info(f'Valid from: {result.valid_at}') + if hasattr(result, 'invalid_at') and result.invalid_at: + logger.info(f'Valid until: {result.invalid_at}') + logger.info('---') + + ################################################# + # CENTER NODE SEARCH + ################################################# + # For more contextually relevant results, you can + # use a center node to rerank search results based + # on their graph distance to a specific node + ################################################# + + if results and len(results) > 0: + logger.info('\n=== CENTER NODE SEARCH ===') + # Get the source node UUID from the top result + center_node_uuid = results[0].source_node_uuid + logger.info(f'Using center node UUID: {center_node_uuid}') + + # Perform graph-based search using the source node + reranked_results = await graphiti.search( + search_query, center_node_uuid=center_node_uuid + ) + + logger.info('Reranked search results:') + for result in reranked_results[:3]: + logger.info(f'UUID: {result.uuid}') + logger.info(f'Fact: {result.fact}') + if hasattr(result, 'valid_at') and result.valid_at: + logger.info(f'Valid from: {result.valid_at}') + if hasattr(result, 'invalid_at') and result.invalid_at: + logger.info(f'Valid until: {result.invalid_at}') + logger.info('---') + else: + logger.info('No results found in the initial search to use as center node.') + + ################################################# + # NODE SEARCH WITH RECIPES + ################################################# + # Graphiti provides predefined search configurations + # (recipes) that optimize search for specific patterns + # and use cases. + ################################################# + + logger.info('\n=== NODE SEARCH WITH RECIPES ===') + recipe_query = 'California Governor' + logger.info(f'Recipe search query: {recipe_query}') + + # Use hybrid search recipe for balanced semantic and keyword matching + node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True) + node_search_config.limit = 5 # Limit to 5 results + + # Execute the node search + node_search_results = await graphiti._search( + query=recipe_query, + config=node_search_config, + ) + + logger.info('Node search results:') + for node in node_search_results.nodes: + logger.info(f'Node UUID: {node.uuid}') + logger.info(f'Node Name: {node.name}') + node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary + logger.info(f'Content Summary: {node_summary}') + logger.info(f'Node Labels: {", ".join(node.labels)}') + logger.info(f'Created At: {node.created_at}') + if hasattr(node, 'attributes') and node.attributes: + logger.info('Attributes:') + for key, value in node.attributes.items(): + logger.info(f' {key}: {value}') + ################################################# + # SUMMARY STATISTICS + ################################################# + # Get overall statistics about the knowledge graph + ################################################# + + logger.info('\n=== SUMMARY ===') + logger.info('Memgraph database populated successfully!') + logger.info('Knowledge graph is ready for queries and exploration.') + + except Exception as e: + logger.error(f'An error occurred: {e}') + raise + + finally: + ################################################# + # CLEANUP + ################################################# + # Always close the connection to Memgraph when + # finished to properly release resources + ################################################# + + # Close the connection + await graphiti.close() + logger.info('Connection closed.') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/graphiti_core/driver/__init__.py b/graphiti_core/driver/__init__.py index 39c00293..02804df2 100644 --- a/graphiti_core/driver/__init__.py +++ b/graphiti_core/driver/__init__.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from neo4j import Neo4jDriver +from .neo4j_driver import Neo4jDriver +from .memgraph_driver import MemgraphDriver -__all__ = ['Neo4jDriver'] +__all__ = ['Neo4jDriver', 'MemgraphDriver'] diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 670a7426..e2e81b33 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -29,6 +29,7 @@ class GraphProvider(Enum): FALKORDB = 'falkordb' KUZU = 'kuzu' NEPTUNE = 'neptune' + MEMGRAPH = 'memgraph' class GraphDriverSession(ABC): diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py new file mode 100644 index 00000000..2c331982 --- /dev/null +++ b/graphiti_core/driver/memgraph_driver.py @@ -0,0 +1,67 @@ +""" +Copyright 2025, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import logging +from collections.abc import Coroutine +from typing import Any + +from neo4j import AsyncGraphDatabase, EagerResult +from typing_extensions import LiteralString + +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider + +logger = logging.getLogger(__name__) + + +class MemgraphDriver(GraphDriver): + provider = GraphProvider.MEMGRAPH + + def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): + super().__init__() + self.client = AsyncGraphDatabase.driver( + uri=uri, + auth=(user or '', password or ''), + ) + self._database = database + + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: + # Check if database_ is provided in kwargs. + # If not populated, set the value to retain backwards compatibility + params = kwargs.pop('params', None) + if params is None: + params = {} + params.setdefault('database_', self._database) + + try: + result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) + except Exception as e: + logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') + raise + + return result + + def session(self, database: str | None = None) -> GraphDriverSession: + _database = database or self._database + return self.client.session(database=_database) # type: ignore + + async def close(self) -> None: + return await self.client.close() + + def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: + # TODO + return self.client.execute_query( + 'SHOW INDEX INFO;', + ) diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 71fa0547..e4d86e92 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -45,6 +45,30 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]: if provider == GraphProvider.KUZU: return [] + if provider == GraphProvider.MEMGRAPH: + return [ + 'CREATE INDEX ON :Entity(uuid)', + 'CREATE INDEX ON :Entity(group_id)', + 'CREATE INDEX ON :Entity(name)', + 'CREATE INDEX ON :Entity(created_at)', + 'CREATE INDEX ON :Episodic(uuid)', + 'CREATE INDEX ON :Episodic(group_id)', + 'CREATE INDEX ON :Episodic(created_at)', + 'CREATE INDEX ON :Episodic(valid_at)', + 'CREATE INDEX ON :Community(uuid)', + 'CREATE INDEX ON :Community(group_id)', + 'CREATE INDEX ON :RELATES_TO(uuid)', + 'CREATE INDEX ON :RELATES_TO(group_id)', + 'CREATE INDEX ON :RELATES_TO(name)', + 'CREATE INDEX ON :RELATES_TO(created_at)', + 'CREATE INDEX ON :RELATES_TO(expired_at)', + 'CREATE INDEX ON :RELATES_TO(valid_at)', + 'CREATE INDEX ON :RELATES_TO(invalid_at)', + 'CREATE INDEX ON :MENTIONS(uuid)', + 'CREATE INDEX ON :MENTIONS(group_id)', + 'CREATE INDEX ON :HAS_MEMBER(uuid)', + ] + return [ 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', @@ -86,6 +110,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]: "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);", ] + if provider == GraphProvider.MEMGRAPH: + return [ + """CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id)""", + """CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id)""", + """CREATE TEXT INDEX community_name ON :Community(name, group_id)""", + """CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id)""", + ] + return [ """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", @@ -107,6 +139,9 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) label = INDEX_TO_LABEL_KUZU_MAPPING[name] return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" + if provider == GraphProvider.MEMGRAPH: + return f'CALL text_search.search("{name}", {query}) YIELD node RETURN node LIMIT $limit' + return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' @@ -118,6 +153,9 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: if provider == GraphProvider.KUZU: return f'array_cosine_similarity({vec1}, {vec2})' + if provider == GraphProvider.MEMGRAPH: + return "TODO" + return f'vector.similarity.cosine({vec1}, {vec2})' @@ -130,4 +168,7 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s label = INDEX_TO_LABEL_KUZU_MAPPING[name] return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" + if provider == GraphProvider.MEMGRAPH: + return f'CALL text_search.search("{name}", $query) YIELD node RETURN node LIMIT $limit' + return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 5b8d5402..0ed6fc77 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -98,6 +98,15 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str: e.attributes = $attributes RETURN e.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + WITH e e.fact_embedding = $edge_data.fact_embedding + RETURN e.uuid AS uuid + """ case _: # Neo4j return """ MATCH (source:Entity {uuid: $edge_data.source_uuid}) @@ -151,6 +160,16 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: e.attributes = $attributes RETURN e.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) + SET e = edge + WITH e, edge e.fact_embedding = edge.fact_embedding + RETURN edge.uuid AS uuid + """ case _: return """ UNWIND $entity_edges AS edge @@ -240,7 +259,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str: e.created_at = $created_at RETURN e.uuid AS uuid """ - case _: # Neo4j + case _: # Neo4j and Memgraph return """ MATCH (community:Community {uuid: $community_uuid}) MATCH (node:Entity | Community {uuid: $entity_uuid}) diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 9627e566..081c0252 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -49,7 +49,7 @@ def get_episode_node_save_query(provider: GraphProvider) -> str: entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j + case _: # Neo4j and Memgraph return """ MERGE (n:Episodic {uuid: $uuid}) SET n:$($group_label) @@ -92,7 +92,7 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j + case _: # Neo4j and Memgraph return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) @@ -162,6 +162,14 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",") RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + MERGE (n:Entity {{uuid: $entity_data.uuid}}) + SET n:{labels} + SET n = $entity_data + WITH n SET n.name_embedding = $entity_data.name_embedding + RETURN n.uuid AS uuid + """ case _: return f""" MERGE (n:Entity {{uuid: $entity_data.uuid}}) @@ -223,6 +231,15 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) n.attributes = $attributes RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + UNWIND $nodes AS node + MERGE (n:Entity {uuid: node.uuid}) + SET n:$(node.labels) + SET n = node + WITH n, node SET n.name_embedding = node.name_embedding + RETURN n.uuid AS uuid + """ case _: # Neo4j return """ UNWIND $nodes AS node @@ -284,6 +301,13 @@ def get_community_node_save_query(provider: GraphProvider) -> str: n.summary = $summary RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + MERGE (n:Community {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at} + WITH n SET n.name_embedding = $name_embedding + RETURN n.uuid AS uuid + """ case _: # Neo4j return """ MERGE (n:Community {uuid: $uuid}) diff --git a/tests/helpers_test.py b/tests/helpers_test.py index 58ef0c3e..3b5d1e4e 100644 --- a/tests/helpers_test.py +++ b/tests/helpers_test.py @@ -55,6 +55,14 @@ if os.getenv('DISABLE_KUZU') is None: except ImportError: raise +if os.getenv('DISABLE_MEMGRAPH') is None: + try: + from graphiti_core.driver.memgraph_driver import MemgraphDriver + + drivers.append(GraphProvider.MEMGRAPH) + except ImportError: + raise + # Disable Neptune for now os.environ['DISABLE_NEPTUNE'] = 'True' if os.getenv('DISABLE_NEPTUNE') is None: @@ -74,6 +82,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') FALKORDB_USER = os.getenv('FALKORDB_USER', None) FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) +MEMGRAPH_URI = os.getenv('MEMGRAPH_URI', 'bolt://localhost:7687') +MEMGRAPH_USER = os.getenv('MEMGRAPH_USER', '') +MEMGRAPH_PASSWORD = os.getenv('MEMGRAPH_PASSWORD', '') + NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost') NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182) AOSS_HOST = os.getenv('AOSS_HOST', None) @@ -103,6 +115,12 @@ def get_driver(provider: GraphProvider) -> GraphDriver: db=KUZU_DB, ) return driver + elif provider == GraphProvider.MEMGRAPH: + return MemgraphDriver( + uri=MEMGRAPH_URI, + user=MEMGRAPH_USER, + password=MEMGRAPH_PASSWORD, + ) elif provider == GraphProvider.NEPTUNE: return NeptuneDriver( host=NEPTUNE_HOST, From b0d004142911b3185ffa32dd52142582c4e46e61 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Mon, 8 Sep 2025 16:26:22 +0200 Subject: [PATCH 2/7] wip fix quick start issues --- graphiti_core/driver/memgraph_driver.py | 59 ++++++++++++++++--------- graphiti_core/graph_queries.py | 54 +++++++++++----------- graphiti_core/search/search_utils.py | 4 ++ 3 files changed, 69 insertions(+), 48 deletions(-) diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index 2c331982..e3377994 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -18,7 +18,7 @@ import logging from collections.abc import Coroutine from typing import Any -from neo4j import AsyncGraphDatabase, EagerResult +from neo4j import GraphDatabase from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider @@ -31,37 +31,54 @@ class MemgraphDriver(GraphDriver): def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): super().__init__() - self.client = AsyncGraphDatabase.driver( + self.client = GraphDatabase.driver( uri=uri, auth=(user or '', password or ''), ) self._database = database - async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: - # Check if database_ is provided in kwargs. - # If not populated, set the value to retain backwards compatibility + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]: + """ + Execute a Cypher query against Memgraph using implicit transactions. + Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. + """ + # Extract parameters from kwargs params = kwargs.pop('params', None) if params is None: - params = {} - params.setdefault('database_', self._database) - - try: - result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) - except Exception as e: - logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') - raise - - return result + # If no 'params' key, use the remaining kwargs as parameters + # but first extract database-specific parameters + database = kwargs.pop('database_', self._database) + kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param) + + # All remaining kwargs are query parameters + params = kwargs + else: + # Extract database parameter if params was provided separately + database = kwargs.pop('database_', self._database) + kwargs.pop('parameters_', None) # Remove if present + + with self.client.session(database=database) as session: + try: + # Debug: Print the query and parameters + print(f"DEBUG - Memgraph Query: {cypher_query_}") + print(f"DEBUG - Memgraph Params: {params}") + + result = session.run(cypher_query_, params) + records = list(result) + summary = result.consume() + keys = result.keys() + return (records, summary, keys) + except Exception as e: + logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') + raise def session(self, database: str | None = None) -> GraphDriverSession: _database = database or self._database return self.client.session(database=_database) # type: ignore async def close(self) -> None: - return await self.client.close() + return self.client.close() - def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: - # TODO - return self.client.execute_query( - 'SHOW INDEX INFO;', - ) + def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: + # TODO: Implement index deletion for Memgraph + raise NotImplementedError("Index deletion not implemented for MemgraphDriver") \ No newline at end of file diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index e4d86e92..56c2fb6c 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -47,26 +47,26 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]: if provider == GraphProvider.MEMGRAPH: return [ - 'CREATE INDEX ON :Entity(uuid)', - 'CREATE INDEX ON :Entity(group_id)', - 'CREATE INDEX ON :Entity(name)', - 'CREATE INDEX ON :Entity(created_at)', - 'CREATE INDEX ON :Episodic(uuid)', - 'CREATE INDEX ON :Episodic(group_id)', - 'CREATE INDEX ON :Episodic(created_at)', - 'CREATE INDEX ON :Episodic(valid_at)', - 'CREATE INDEX ON :Community(uuid)', - 'CREATE INDEX ON :Community(group_id)', - 'CREATE INDEX ON :RELATES_TO(uuid)', - 'CREATE INDEX ON :RELATES_TO(group_id)', - 'CREATE INDEX ON :RELATES_TO(name)', - 'CREATE INDEX ON :RELATES_TO(created_at)', - 'CREATE INDEX ON :RELATES_TO(expired_at)', - 'CREATE INDEX ON :RELATES_TO(valid_at)', - 'CREATE INDEX ON :RELATES_TO(invalid_at)', - 'CREATE INDEX ON :MENTIONS(uuid)', - 'CREATE INDEX ON :MENTIONS(group_id)', - 'CREATE INDEX ON :HAS_MEMBER(uuid)', + 'CREATE INDEX ON :Entity(uuid);', + 'CREATE INDEX ON :Entity(group_id);', + 'CREATE INDEX ON :Entity(name);', + 'CREATE INDEX ON :Entity(created_at);', + 'CREATE INDEX ON :Episodic(uuid);', + 'CREATE INDEX ON :Episodic(group_id);', + 'CREATE INDEX ON :Episodic(created_at);', + 'CREATE INDEX ON :Episodic(valid_at);', + 'CREATE INDEX ON :Community(uuid);', + 'CREATE INDEX ON :Community(group_id);', + 'CREATE INDEX ON :RELATES_TO(uuid);', + 'CREATE INDEX ON :RELATES_TO(group_id);', + 'CREATE INDEX ON :RELATES_TO(name);', + 'CREATE INDEX ON :RELATES_TO(created_at);', + 'CREATE INDEX ON :RELATES_TO(expired_at);', + 'CREATE INDEX ON :RELATES_TO(valid_at);', + 'CREATE INDEX ON :RELATES_TO(invalid_at);', + 'CREATE INDEX ON :MENTIONS(uuid);', + 'CREATE INDEX ON :MENTIONS(group_id);', + 'CREATE INDEX ON :HAS_MEMBER(uuid);', ] return [ @@ -112,10 +112,10 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]: if provider == GraphProvider.MEMGRAPH: return [ - """CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id)""", - """CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id)""", - """CREATE TEXT INDEX community_name ON :Community(name, group_id)""", - """CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id)""", + """CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id);""", + """CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id);""", + """CREATE TEXT INDEX community_name ON :Community(name, group_id);""", + """CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id);""", ] return [ @@ -140,7 +140,7 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search("{name}", {query}) YIELD node RETURN node LIMIT $limit' + return f'CALL text_search.search("{name}", {query}) YIELD node' return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' @@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: return f'array_cosine_similarity({vec1}, {vec2})' if provider == GraphProvider.MEMGRAPH: - return "TODO" + return f'CALL vector_search.cosine_similarity({vec1}, {vec2}) YIELD similarity RETURN similarity AS score' return f'vector.similarity.cosine({vec1}, {vec2})' @@ -169,6 +169,6 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search("{name}", $query) YIELD node RETURN node LIMIT $limit' + return f'CALL text_search.search_edges("{name}", $query) YIELD node' return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 9bce2e3f..14788816 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -562,6 +562,8 @@ async def node_fulltext_search( yield_query = 'YIELD node AS n, score' if driver.provider == GraphProvider.KUZU: yield_query = 'WITH node AS n, score' + elif driver.provider == GraphProvider.MEMGRAPH: + yield_query = ' WITH node AS n, 1.0 AS score' # Memgraph: continue from YIELD node if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue @@ -968,6 +970,8 @@ async def community_fulltext_search( yield_query = 'YIELD node AS c, score' if driver.provider == GraphProvider.KUZU: yield_query = 'WITH node AS c, score' + elif driver.provider == GraphProvider.MEMGRAPH: + yield_query = ' WITH node AS c, 1.0 AS score' # Memgraph: continue from YIELD node if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue From 1641b9c1c1af815d64f0ab8c91fa82ee9f873f07 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 9 Sep 2025 11:50:24 +0200 Subject: [PATCH 3/7] quickstart working with memgraph --- graphiti_core/driver/memgraph_driver.py | 18 +++++------ graphiti_core/graph_queries.py | 6 ++-- graphiti_core/models/edges/edge_db_queries.py | 4 +-- graphiti_core/models/nodes/node_db_queries.py | 31 ++++++++++++++++--- graphiti_core/search/search_utils.py | 2 +- graphiti_core/utils/bulk_utils.py | 1 + 6 files changed, 41 insertions(+), 21 deletions(-) diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index e3377994..d14c8d83 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -18,7 +18,7 @@ import logging from collections.abc import Coroutine from typing import Any -from neo4j import GraphDatabase +from neo4j import AsyncGraphDatabase from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider @@ -31,7 +31,7 @@ class MemgraphDriver(GraphDriver): def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): super().__init__() - self.client = GraphDatabase.driver( + self.client = AsyncGraphDatabase.driver( uri=uri, auth=(user or '', password or ''), ) @@ -57,15 +57,11 @@ class MemgraphDriver(GraphDriver): database = kwargs.pop('database_', self._database) kwargs.pop('parameters_', None) # Remove if present - with self.client.session(database=database) as session: + async with self.client.session(database=database) as session: try: - # Debug: Print the query and parameters - print(f"DEBUG - Memgraph Query: {cypher_query_}") - print(f"DEBUG - Memgraph Params: {params}") - - result = session.run(cypher_query_, params) - records = list(result) - summary = result.consume() + result = await session.run(cypher_query_, params) + records = [record async for record in result] + summary = await result.consume() keys = result.keys() return (records, summary, keys) except Exception as e: @@ -77,7 +73,7 @@ class MemgraphDriver(GraphDriver): return self.client.session(database=_database) # type: ignore async def close(self) -> None: - return self.client.close() + return await self.client.close() def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: # TODO: Implement index deletion for Memgraph diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 56c2fb6c..9f85b1c1 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -140,7 +140,7 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search("{name}", {query}) YIELD node' + return f'CALL text_search.search_all("{name}", {query})' return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' @@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: return f'array_cosine_similarity({vec1}, {vec2})' if provider == GraphProvider.MEMGRAPH: - return f'CALL vector_search.cosine_similarity({vec1}, {vec2}) YIELD similarity RETURN similarity AS score' + return f'cosineSimilarity({vec1}, {vec2})' return f'vector.similarity.cosine({vec1}, {vec2})' @@ -169,6 +169,6 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search_edges("{name}", $query) YIELD node' + return f'CALL text_search.search_all_edges("{name}", $query)' return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 0ed6fc77..abb0edbb 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -167,8 +167,8 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: MATCH (target:Entity {uuid: edge.target_node_uuid}) MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) SET e = edge - WITH e, edge e.fact_embedding = edge.fact_embedding - RETURN edge.uuid AS uuid + SET e.fact_embedding = edge.fact_embedding + RETURN edge.uuid AS uuid; """ case _: return """ diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 081c0252..7cf075f0 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -92,7 +92,27 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j and Memgraph + case GraphProvider.MEMGRAPH: + return """ + UNWIND $episodes AS episode + MERGE (n:Episodic {uuid: episode.uuid}) + FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END | + SET n:`${episode.group_label}` + ) + SET n = { + uuid: episode.uuid, + name: episode.name, + group_id: episode.group_id, + source_description: episode.source_description, + source: episode.source, + content: episode.content, + entity_edges: episode.entity_edges, + created_at: episode.created_at, + valid_at: episode.valid_at + } + RETURN n.uuid AS uuid; + """ + case _: # Neo4j return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) @@ -235,10 +255,13 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) return """ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) - SET n:$(node.labels) + FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END | + SET n:`${label}` + ) SET n = node - WITH n, node SET n.name_embedding = node.name_embedding - RETURN n.uuid AS uuid + WITH n, node + SET n.name_embedding = node.name_embedding + RETURN n.uuid AS uuid; """ case _: # Neo4j return """ diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 14788816..ff08239d 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -563,7 +563,7 @@ async def node_fulltext_search( if driver.provider == GraphProvider.KUZU: yield_query = 'WITH node AS n, score' elif driver.provider == GraphProvider.MEMGRAPH: - yield_query = ' WITH node AS n, 1.0 AS score' # Memgraph: continue from YIELD node + yield_query = ' YIELD node AS n WITH n, 1.0 AS score' if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 426cbe90..ece4aea8 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -191,6 +191,7 @@ async def add_nodes_and_edges_bulk_tx( for edge in episodic_edges: await tx.run(episodic_edge_query, **edge.model_dump()) else: + print(episodes) await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) await tx.run( From 78cdab98d9b381bac636b679fb1143c02299f3db Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 9 Sep 2025 11:58:45 +0200 Subject: [PATCH 4/7] remove print statement --- graphiti_core/utils/bulk_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index ece4aea8..426cbe90 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -191,7 +191,6 @@ async def add_nodes_and_edges_bulk_tx( for edge in episodic_edges: await tx.run(episodic_edge_query, **edge.model_dump()) else: - print(episodes) await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) await tx.run( From e43a756ac1fd0a9e30ead8536045768f9d4ecc3c Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 9 Sep 2025 16:43:23 +0200 Subject: [PATCH 5/7] run unit tests and change api where needed --- .github/workflows/unit_tests.yml | 15 +++++++ docker-compose.yml | 29 ++++++++++++++ examples/quickstart/quickstart_memgraph.py | 6 ++- graphiti_core/driver/__init__.py | 2 +- graphiti_core/driver/memgraph_driver.py | 16 +++++--- graphiti_core/graph_queries.py | 2 +- graphiti_core/models/edges/edge_db_queries.py | 2 +- graphiti_core/models/nodes/node_db_queries.py | 40 +++++++++---------- graphiti_core/nodes.py | 2 +- 9 files changed, 80 insertions(+), 34 deletions(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index cf1053a1..e7bd93ad 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -29,6 +29,11 @@ jobs: NEO4J_AUTH: neo4j/testpass NEO4J_PLUGINS: '["apoc"]' options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10 + memgraph: + image: memgraph/memgraph:latest + ports: + - 7688:7687 + options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10 steps: - uses: actions/checkout@v4 - name: Set up Python @@ -77,3 +82,13 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | uv run pytest tests/test_*_int.py -k "neo4j" + - name: Run Memgraph integration tests + env: + PYTHONPATH: ${{ github.workspace }} + MEMGRAPH_URI: bolt://localhost:7688 + MEMGRAPH_USER: + MEMGRAPH_PASSWORD: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + uv run pytest tests/test_*_int.py -k "memgraph" + \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 0400692c..38cdd9bf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -44,5 +44,34 @@ services: environment: - NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD} + memgraph: + image: memgraph/memgraph:latest + healthcheck: + test: + [ + "CMD", + "mg_client", + "--host", + "localhost", + "--port", + "7687", + "--use-ssl=false", + "--query", + "RETURN 1;" + ] + interval: 5s + timeout: 10s + retries: 10 + start_period: 3s + ports: + - "7688:7687" # Bolt (using different port to avoid conflict) + volumes: + - memgraph_data:/var/lib/memgraph + environment: + - MEMGRAPH_USER=${MEMGRAPH_USER:-} + - MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-} + command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"] + volumes: neo4j_data: + memgraph_data: diff --git a/examples/quickstart/quickstart_memgraph.py b/examples/quickstart/quickstart_memgraph.py index 911991ec..93fb4ace 100644 --- a/examples/quickstart/quickstart_memgraph.py +++ b/examples/quickstart/quickstart_memgraph.py @@ -48,7 +48,9 @@ load_dotenv() # Memgraph connection parameters # Make sure Memgraph is running (default port 7687, same as Neo4j) memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687') -memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default +memgraph_user = os.environ.get( + 'MEMGRAPH_USER', '' +) # Memgraph often doesn't require auth by default memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '') if not memgraph_uri: @@ -66,7 +68,7 @@ async def main(): # Initialize Memgraph driver memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password) - + # Initialize Graphiti with Memgraph connection graphiti = Graphiti(graph_driver=memgraph_driver) diff --git a/graphiti_core/driver/__init__.py b/graphiti_core/driver/__init__.py index 02804df2..4babebd1 100644 --- a/graphiti_core/driver/__init__.py +++ b/graphiti_core/driver/__init__.py @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from .neo4j_driver import Neo4jDriver from .memgraph_driver import MemgraphDriver +from .neo4j_driver import Neo4jDriver __all__ = ['Neo4jDriver', 'MemgraphDriver'] diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index d14c8d83..d419bee9 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -29,7 +29,9 @@ logger = logging.getLogger(__name__) class MemgraphDriver(GraphDriver): provider = GraphProvider.MEMGRAPH - def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): + def __init__( + self, uri: str, user: str | None, password: str | None, database: str = 'memgraph' + ): super().__init__() self.client = AsyncGraphDatabase.driver( uri=uri, @@ -37,7 +39,9 @@ class MemgraphDriver(GraphDriver): ) self._database = database - async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]: + async def execute_query( + self, cypher_query_: LiteralString, **kwargs: Any + ) -> tuple[list, Any, Any]: """ Execute a Cypher query against Memgraph using implicit transactions. Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. @@ -49,20 +53,20 @@ class MemgraphDriver(GraphDriver): # but first extract database-specific parameters database = kwargs.pop('database_', self._database) kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param) - + # All remaining kwargs are query parameters params = kwargs else: # Extract database parameter if params was provided separately database = kwargs.pop('database_', self._database) kwargs.pop('parameters_', None) # Remove if present - + async with self.client.session(database=database) as session: try: result = await session.run(cypher_query_, params) records = [record async for record in result] summary = await result.consume() - keys = result.keys() + keys = result.keys() return (records, summary, keys) except Exception as e: logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') @@ -77,4 +81,4 @@ class MemgraphDriver(GraphDriver): def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: # TODO: Implement index deletion for Memgraph - raise NotImplementedError("Index deletion not implemented for MemgraphDriver") \ No newline at end of file + raise NotImplementedError('Index deletion not implemented for MemgraphDriver') diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 9f85b1c1..ffee09f2 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: return f'array_cosine_similarity({vec1}, {vec2})' if provider == GraphProvider.MEMGRAPH: - return f'cosineSimilarity({vec1}, {vec2})' + return f'vector_search.cosine_similarity({vec1}, {vec2})' return f'vector.similarity.cosine({vec1}, {vec2})' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index abb0edbb..3554701e 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -104,7 +104,7 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str: MATCH (target:Entity {uuid: $edge_data.target_uuid}) MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) SET e = $edge_data - WITH e e.fact_embedding = $edge_data.fact_embedding + SET e.fact_embedding = $edge_data.fact_embedding RETURN e.uuid AS uuid """ case _: # Neo4j diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 7cf075f0..867b96c8 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -49,7 +49,18 @@ def get_episode_node_save_query(provider: GraphProvider) -> str: entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j and Memgraph + case GraphProvider.MEMGRAPH: + return """ + MERGE (n:Episodic {uuid: $uuid}) + SET n = { + uuid: $uuid, name: $name, group_id: $group_id, + source_description: $source_description, source: $source, + content: $content, entity_edges: $entity_edges, + created_at: $created_at, valid_at: $valid_at + } + RETURN n.uuid AS uuid + """ + case _: # Neo4j return """ MERGE (n:Episodic {uuid: $uuid}) SET n:$($group_label) @@ -96,23 +107,11 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) - FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END | - SET n:`${episode.group_label}` - ) - SET n = { - uuid: episode.uuid, - name: episode.name, - group_id: episode.group_id, - source_description: episode.source_description, - source: episode.source, - content: episode.content, - entity_edges: episode.entity_edges, - created_at: episode.created_at, - valid_at: episode.valid_at - } - RETURN n.uuid AS uuid; + SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, + entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} + RETURN n.uuid AS uuid """ - case _: # Neo4j + case _: # Neo4j return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) @@ -183,7 +182,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: RETURN n.uuid AS uuid """ case GraphProvider.MEMGRAPH: - return """ + return f""" MERGE (n:Entity {{uuid: $entity_data.uuid}}) SET n:{labels} SET n = $entity_data @@ -255,13 +254,10 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) return """ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) - FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END | - SET n:`${label}` - ) SET n = node WITH n, node SET n.name_embedding = node.name_embedding - RETURN n.uuid AS uuid; + RETURN n.uuid AS uuid """ case _: # Neo4j return """ diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index ef28a94d..ab9b2dbc 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -299,7 +299,7 @@ class EpisodicNode(Node): 'source': self.source.value, } - if driver.provider == GraphProvider.NEO4J: + if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH): episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '') result = await driver.execute_query( From 489dffdc0c04336a8995e25cafd319d0ce796fce Mon Sep 17 00:00:00 2001 From: DavIvek Date: Wed, 10 Sep 2025 09:34:04 +0200 Subject: [PATCH 6/7] few improvements in mg driver --- graphiti_core/driver/memgraph_driver.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index d419bee9..2b23b165 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -39,9 +39,7 @@ class MemgraphDriver(GraphDriver): ) self._database = database - async def execute_query( - self, cypher_query_: LiteralString, **kwargs: Any - ) -> tuple[list, Any, Any]: + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any: """ Execute a Cypher query against Memgraph using implicit transactions. Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. @@ -49,28 +47,25 @@ class MemgraphDriver(GraphDriver): # Extract parameters from kwargs params = kwargs.pop('params', None) if params is None: - # If no 'params' key, use the remaining kwargs as parameters - # but first extract database-specific parameters database = kwargs.pop('database_', self._database) - kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param) - - # All remaining kwargs are query parameters + kwargs.pop('parameters_', None) params = kwargs else: - # Extract database parameter if params was provided separately database = kwargs.pop('database_', self._database) - kwargs.pop('parameters_', None) # Remove if present + kwargs.pop('parameters_', None) async with self.client.session(database=database) as session: try: result = await session.run(cypher_query_, params) + keys = result.keys() records = [record async for record in result] summary = await result.consume() - keys = result.keys() return (records, summary, keys) except Exception as e: logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') raise + finally: + await session.close() def session(self, database: str | None = None) -> GraphDriverSession: _database = database or self._database @@ -80,5 +75,4 @@ class MemgraphDriver(GraphDriver): return await self.client.close() def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: - # TODO: Implement index deletion for Memgraph - raise NotImplementedError('Index deletion not implemented for MemgraphDriver') + return self.client.execute_query('DROP ALL INDEXES') From 9cc43aeb5f982a8e963970f69e0b4c03049354ee Mon Sep 17 00:00:00 2001 From: DavIvek Date: Wed, 10 Sep 2025 17:32:15 +0200 Subject: [PATCH 7/7] add limit and score --- graphiti_core/graph_queries.py | 4 ++-- graphiti_core/search/search_utils.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index ffee09f2..7ecafeca 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -140,7 +140,7 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search_all("{name}", {query})' + return f'CALL text_search.search_all("{name}", {query}, {limit})' return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' @@ -169,6 +169,6 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" if provider == GraphProvider.MEMGRAPH: - return f'CALL text_search.search_all_edges("{name}", $query)' + return f'CALL text_search.search_all_edges("{name}", $query, {limit})' return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index ff08239d..9bce2e3f 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -562,8 +562,6 @@ async def node_fulltext_search( yield_query = 'YIELD node AS n, score' if driver.provider == GraphProvider.KUZU: yield_query = 'WITH node AS n, score' - elif driver.provider == GraphProvider.MEMGRAPH: - yield_query = ' YIELD node AS n WITH n, 1.0 AS score' if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue @@ -970,8 +968,6 @@ async def community_fulltext_search( yield_query = 'YIELD node AS c, score' if driver.provider == GraphProvider.KUZU: yield_query = 'WITH node AS c, score' - elif driver.provider == GraphProvider.MEMGRAPH: - yield_query = ' WITH node AS c, 1.0 AS score' # Memgraph: continue from YIELD node if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue