diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index c2c664a6..91879de3 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -61,6 +61,11 @@ jobs: NEO4J_AUTH: neo4j/testpass NEO4J_PLUGINS: '["apoc"]' options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10 + memgraph: + image: memgraph/memgraph:latest + ports: + - 7688:7687 + options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10 steps: - uses: actions/checkout@v4 - name: Set up Python @@ -98,3 +103,13 @@ jobs: tests/cross_encoder/test_bge_reranker_client_int.py \ tests/driver/test_falkordb_driver.py \ -m "not integration" + - name: Run Memgraph integration tests + env: + PYTHONPATH: ${{ github.workspace }} + MEMGRAPH_URI: bolt://localhost:7688 + MEMGRAPH_USER: + MEMGRAPH_PASSWORD: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + uv run pytest tests/test_*_int.py -k "memgraph" + diff --git a/README.md b/README.md index 57259675..3419fe4e 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ particularly suitable for applications requiring real-time interaction and preci Requirements: - Python 3.10 or higher -- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon +- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Memgraph 2.22+ / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend) - OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding) @@ -564,7 +564,7 @@ When you initialize a Graphiti instance, we collect: - **Graphiti version**: The version you're using - **Configuration choices**: - LLM provider type (OpenAI, Azure, Anthropic, etc.) - - Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics) + - Database backend (Neo4j, FalkorDB, Kuzu, Memgraph, Amazon Neptune Database or Neptune Analytics) - Embedder provider (OpenAI, Azure, Voyage, etc.) ### What We Don't Collect diff --git a/docker-compose.yml b/docker-compose.yml index 1b5ba06d..e69a9fcf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -87,6 +87,35 @@ services: - PORT=8001 - db_backend=falkordb + memgraph: + image: memgraph/memgraph:latest + healthcheck: + test: + [ + "CMD", + "mg_client", + "--host", + "localhost", + "--port", + "7687", + "--use-ssl=false", + "--query", + "RETURN 1;" + ] + interval: 5s + timeout: 10s + retries: 10 + start_period: 3s + ports: + - "7688:7687" # Bolt (using different port to avoid conflict) + volumes: + - memgraph_data:/var/lib/memgraph + environment: + - MEMGRAPH_USER=${MEMGRAPH_USER:-} + - MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-} + command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"] + volumes: neo4j_data: + memgraph_data: falkordb_data: diff --git a/examples/quickstart/README.md b/examples/quickstart/README.md index d254f2d6..19a45591 100644 --- a/examples/quickstart/README.md +++ b/examples/quickstart/README.md @@ -2,7 +2,7 @@ This example demonstrates the basic functionality of Graphiti, including: -1. Connecting to a Neo4j or FalkorDB database +1. Connecting to a Neo4j, FalkorDB, or Memgraph database 2. Initializing Graphiti indices and constraints 3. Adding episodes to the graph 4. Searching the graph with semantic and keyword matching @@ -17,6 +17,9 @@ This example demonstrates the basic functionality of Graphiti, including: - Neo4j Desktop installed and running - A local DBMS created and started in Neo4j Desktop - **For FalkorDB**: + - FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup) +- **For Memgraph**: + - Memgraph server running (see [Memgraph documentation](https://memgraph.com/docs/) for setup) - FalkorDB server running (see [FalkorDB documentation](https://docs.falkordb.com) for setup) - **For Amazon Neptune**: - Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup) @@ -44,6 +47,11 @@ export NEO4J_PASSWORD=password # Optional FalkorDB connection parameters (defaults shown) export FALKORDB_URI=falkor://localhost:6379 +# Optional Memgraph connection parameters (defaults shown) +export MEMGRAPH_URI=bolt://localhost:7687 +export MEMGRAPH_USER= # Optional - Memgraph doesn't require auth by default +export MEMGRAPH_PASSWORD= # Optional - Memgraph doesn't require auth by default + # Optional Amazon Neptune connection parameters NEPTUNE_HOST=your_neptune_host NEPTUNE_PORT=your_port_or_8182 @@ -65,13 +73,16 @@ python quickstart_neo4j.py # For FalkorDB python quickstart_falkordb.py +# For Memgraph +python quickstart_memgraph.py + # For Amazon Neptune python quickstart_neptune.py ``` ## What This Example Demonstrates -- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB +- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, FalkorDB, Memgraph, or Amazon Neptune - **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges - **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges) - **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance diff --git a/examples/quickstart/quickstart_memgraph.py b/examples/quickstart/quickstart_memgraph.py new file mode 100644 index 00000000..93fb4ace --- /dev/null +++ b/examples/quickstart/quickstart_memgraph.py @@ -0,0 +1,254 @@ +""" +Copyright 2025, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import json +import logging +import os +from datetime import datetime, timezone +from logging import INFO + +from dotenv import load_dotenv + +from graphiti_core import Graphiti +from graphiti_core.driver.memgraph_driver import MemgraphDriver +from graphiti_core.nodes import EpisodeType +from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF + +################################################# +# CONFIGURATION +################################################# +# Set up logging and environment variables for +# connecting to Memgraph database +################################################# + +# Configure logging +logging.basicConfig( + level=INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', +) +logger = logging.getLogger(__name__) + +load_dotenv() + +# Memgraph connection parameters +# Make sure Memgraph is running (default port 7687, same as Neo4j) +memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687') +memgraph_user = os.environ.get( + 'MEMGRAPH_USER', '' +) # Memgraph often doesn't require auth by default +memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '') + +if not memgraph_uri: + raise ValueError('MEMGRAPH_URI must be set') + + +async def main(): + ################################################# + # INITIALIZATION + ################################################# + # Connect to Memgraph and set up Graphiti indices + # This is required before using other Graphiti + # functionality + ################################################# + + # Initialize Memgraph driver + memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password) + + # Initialize Graphiti with Memgraph connection + graphiti = Graphiti(graph_driver=memgraph_driver) + + try: + # Initialize the graph database with graphiti's indices. This only needs to be done once. + await graphiti.build_indices_and_constraints() + + ################################################# + # ADDING EPISODES + ################################################# + # Episodes are the primary units of information + # in Graphiti. They can be text or structured JSON + # and are automatically processed to extract entities + # and relationships. + ################################################# + + # Example: Add Episodes + # Episodes list containing both text and JSON episodes + episodes = [ + { + 'content': 'Kamala Harris is the Attorney General of California. She was previously ' + 'the district attorney for San Francisco.', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + { + 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + { + 'content': { + 'name': 'Gavin Newsom', + 'position': 'Governor', + 'state': 'California', + 'predecessor': 'Jerry Brown', + }, + 'type': EpisodeType.json, + 'description': 'politician info', + }, + { + 'content': 'Jerry Brown was the predecessor of Gavin Newsom as Governor of California', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + ] + + # Add episodes to the graph + for i, episode in enumerate(episodes): + await graphiti.add_episode( + name=f'Memgraph Demo {i}', + episode_body=episode['content'] + if isinstance(episode['content'], str) + else json.dumps(episode['content']), + source=episode['type'], + source_description=episode['description'], + reference_time=datetime.now(timezone.utc), + ) + logger.info(f'Added episode: Memgraph Demo {i} ({episode["type"].value})') + + logger.info('Episodes added successfully!') + + ################################################# + # BASIC SEARCH + ################################################# + # The simplest way to retrieve relationships (edges) + # from Graphiti is using the search method, which + # performs a hybrid search combining semantic + # similarity and BM25 text retrieval. + ################################################# + + # Perform a hybrid search combining semantic similarity and BM25 retrieval + logger.info('\n=== BASIC SEARCH ===') + search_query = 'Who was the California Attorney General?' + logger.info(f'Search query: {search_query}') + + # Perform semantic search across edges (relationships) + results = await graphiti.search(search_query) + + logger.info('Search results:') + for result in results[:3]: # Show top 3 results + logger.info(f'UUID: {result.uuid}') + logger.info(f'Fact: {result.fact}') + if hasattr(result, 'valid_at') and result.valid_at: + logger.info(f'Valid from: {result.valid_at}') + if hasattr(result, 'invalid_at') and result.invalid_at: + logger.info(f'Valid until: {result.invalid_at}') + logger.info('---') + + ################################################# + # CENTER NODE SEARCH + ################################################# + # For more contextually relevant results, you can + # use a center node to rerank search results based + # on their graph distance to a specific node + ################################################# + + if results and len(results) > 0: + logger.info('\n=== CENTER NODE SEARCH ===') + # Get the source node UUID from the top result + center_node_uuid = results[0].source_node_uuid + logger.info(f'Using center node UUID: {center_node_uuid}') + + # Perform graph-based search using the source node + reranked_results = await graphiti.search( + search_query, center_node_uuid=center_node_uuid + ) + + logger.info('Reranked search results:') + for result in reranked_results[:3]: + logger.info(f'UUID: {result.uuid}') + logger.info(f'Fact: {result.fact}') + if hasattr(result, 'valid_at') and result.valid_at: + logger.info(f'Valid from: {result.valid_at}') + if hasattr(result, 'invalid_at') and result.invalid_at: + logger.info(f'Valid until: {result.invalid_at}') + logger.info('---') + else: + logger.info('No results found in the initial search to use as center node.') + + ################################################# + # NODE SEARCH WITH RECIPES + ################################################# + # Graphiti provides predefined search configurations + # (recipes) that optimize search for specific patterns + # and use cases. + ################################################# + + logger.info('\n=== NODE SEARCH WITH RECIPES ===') + recipe_query = 'California Governor' + logger.info(f'Recipe search query: {recipe_query}') + + # Use hybrid search recipe for balanced semantic and keyword matching + node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True) + node_search_config.limit = 5 # Limit to 5 results + + # Execute the node search + node_search_results = await graphiti._search( + query=recipe_query, + config=node_search_config, + ) + + logger.info('Node search results:') + for node in node_search_results.nodes: + logger.info(f'Node UUID: {node.uuid}') + logger.info(f'Node Name: {node.name}') + node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary + logger.info(f'Content Summary: {node_summary}') + logger.info(f'Node Labels: {", ".join(node.labels)}') + logger.info(f'Created At: {node.created_at}') + if hasattr(node, 'attributes') and node.attributes: + logger.info('Attributes:') + for key, value in node.attributes.items(): + logger.info(f' {key}: {value}') + ################################################# + # SUMMARY STATISTICS + ################################################# + # Get overall statistics about the knowledge graph + ################################################# + + logger.info('\n=== SUMMARY ===') + logger.info('Memgraph database populated successfully!') + logger.info('Knowledge graph is ready for queries and exploration.') + + except Exception as e: + logger.error(f'An error occurred: {e}') + raise + + finally: + ################################################# + # CLEANUP + ################################################# + # Always close the connection to Memgraph when + # finished to properly release resources + ################################################# + + # Close the connection + await graphiti.close() + logger.info('Connection closed.') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/graphiti_core/driver/__init__.py b/graphiti_core/driver/__init__.py index 39c00293..4babebd1 100644 --- a/graphiti_core/driver/__init__.py +++ b/graphiti_core/driver/__init__.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from neo4j import Neo4jDriver +from .memgraph_driver import MemgraphDriver +from .neo4j_driver import Neo4jDriver -__all__ = ['Neo4jDriver'] +__all__ = ['Neo4jDriver', 'MemgraphDriver'] diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index c1a355f3..73859b1d 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -44,6 +44,7 @@ class GraphProvider(Enum): FALKORDB = 'falkordb' KUZU = 'kuzu' NEPTUNE = 'neptune' + MEMGRAPH = 'memgraph' class GraphDriverSession(ABC): diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py new file mode 100644 index 00000000..2b23b165 --- /dev/null +++ b/graphiti_core/driver/memgraph_driver.py @@ -0,0 +1,78 @@ +""" +Copyright 2025, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import logging +from collections.abc import Coroutine +from typing import Any + +from neo4j import AsyncGraphDatabase +from typing_extensions import LiteralString + +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider + +logger = logging.getLogger(__name__) + + +class MemgraphDriver(GraphDriver): + provider = GraphProvider.MEMGRAPH + + def __init__( + self, uri: str, user: str | None, password: str | None, database: str = 'memgraph' + ): + super().__init__() + self.client = AsyncGraphDatabase.driver( + uri=uri, + auth=(user or '', password or ''), + ) + self._database = database + + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any: + """ + Execute a Cypher query against Memgraph using implicit transactions. + Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. + """ + # Extract parameters from kwargs + params = kwargs.pop('params', None) + if params is None: + database = kwargs.pop('database_', self._database) + kwargs.pop('parameters_', None) + params = kwargs + else: + database = kwargs.pop('database_', self._database) + kwargs.pop('parameters_', None) + + async with self.client.session(database=database) as session: + try: + result = await session.run(cypher_query_, params) + keys = result.keys() + records = [record async for record in result] + summary = await result.consume() + return (records, summary, keys) + except Exception as e: + logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') + raise + finally: + await session.close() + + def session(self, database: str | None = None) -> GraphDriverSession: + _database = database or self._database + return self.client.session(database=_database) # type: ignore + + async def close(self) -> None: + return await self.client.close() + + def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: + return self.client.execute_query('DROP ALL INDEXES') diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 8e4cca4e..8f8b91a3 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -45,6 +45,30 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]: if provider == GraphProvider.KUZU: return [] + if provider == GraphProvider.MEMGRAPH: + return [ + 'CREATE INDEX ON :Entity(uuid);', + 'CREATE INDEX ON :Entity(group_id);', + 'CREATE INDEX ON :Entity(name);', + 'CREATE INDEX ON :Entity(created_at);', + 'CREATE INDEX ON :Episodic(uuid);', + 'CREATE INDEX ON :Episodic(group_id);', + 'CREATE INDEX ON :Episodic(created_at);', + 'CREATE INDEX ON :Episodic(valid_at);', + 'CREATE INDEX ON :Community(uuid);', + 'CREATE INDEX ON :Community(group_id);', + 'CREATE INDEX ON :RELATES_TO(uuid);', + 'CREATE INDEX ON :RELATES_TO(group_id);', + 'CREATE INDEX ON :RELATES_TO(name);', + 'CREATE INDEX ON :RELATES_TO(created_at);', + 'CREATE INDEX ON :RELATES_TO(expired_at);', + 'CREATE INDEX ON :RELATES_TO(valid_at);', + 'CREATE INDEX ON :RELATES_TO(invalid_at);', + 'CREATE INDEX ON :MENTIONS(uuid);', + 'CREATE INDEX ON :MENTIONS(group_id);', + 'CREATE INDEX ON :HAS_MEMBER(uuid);', + ] + return [ 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', @@ -115,6 +139,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]: "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);", ] + if provider == GraphProvider.MEMGRAPH: + return [ + """CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id);""", + """CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id);""", + """CREATE TEXT INDEX community_name ON :Community(name, group_id);""", + """CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id);""", + ] + return [ """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", @@ -136,6 +168,9 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) label = INDEX_TO_LABEL_KUZU_MAPPING[name] return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" + if provider == GraphProvider.MEMGRAPH: + return f'CALL text_search.search_all("{name}", {query}, {limit})' + return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' @@ -147,6 +182,9 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: if provider == GraphProvider.KUZU: return f'array_cosine_similarity({vec1}, {vec2})' + if provider == GraphProvider.MEMGRAPH: + return f'vector_search.cosine_similarity({vec1}, {vec2})' + return f'vector.similarity.cosine({vec1}, {vec2})' @@ -159,4 +197,7 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s label = INDEX_TO_LABEL_KUZU_MAPPING[name] return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" + if provider == GraphProvider.MEMGRAPH: + return f'CALL text_search.search_all_edges("{name}", $query, {limit})' + return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index c57f3694..0081303a 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -99,6 +99,15 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) e.attributes = $attributes RETURN e.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + SET e.fact_embedding = $edge_data.fact_embedding + RETURN e.uuid AS uuid + """ case _: # Neo4j save_embedding_query = ( """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)""" @@ -163,6 +172,16 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = Fa e.attributes = $attributes RETURN e.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) + SET e = edge + SET e.fact_embedding = edge.fact_embedding + RETURN edge.uuid AS uuid; + """ case _: save_embedding_query = ( 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)' @@ -261,7 +280,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str: e.created_at = $created_at RETURN e.uuid AS uuid """ - case _: # Neo4j + case _: # Neo4j and Memgraph return """ MATCH (community:Community {uuid: $community_uuid}) MATCH (node:Entity | Community {uuid: $entity_uuid}) diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 34e3d8b8..3035c0f2 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -49,6 +49,17 @@ def get_episode_node_save_query(provider: GraphProvider) -> str: entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + MERGE (n:Episodic {uuid: $uuid}) + SET n = { + uuid: $uuid, name: $name, group_id: $group_id, + source_description: $source_description, source: $source, + content: $content, entity_edges: $entity_edges, + created_at: $created_at, valid_at: $valid_at + } + RETURN n.uuid AS uuid + """ case _: # Neo4j return """ MERGE (n:Episodic {uuid: $uuid}) @@ -91,6 +102,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + UNWIND $episodes AS episode + MERGE (n:Episodic {uuid: episode.uuid}) + SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, + entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} + RETURN n.uuid AS uuid + """ case _: # Neo4j return """ UNWIND $episodes AS episode @@ -161,6 +180,14 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",") RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return f""" + MERGE (n:Entity {{uuid: $entity_data.uuid}}) + SET n:{labels} + SET n = $entity_data + WITH n SET n.name_embedding = $entity_data.name_embedding + RETURN n.uuid AS uuid + """ case _: save_embedding_query = ( 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)' @@ -233,6 +260,15 @@ def get_entity_node_save_bulk_query( n.attributes = $attributes RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + UNWIND $nodes AS node + MERGE (n:Entity {uuid: node.uuid}) + SET n = node + WITH n, node + SET n.name_embedding = node.name_embedding + RETURN n.uuid AS uuid + """ case _: # Neo4j save_embedding_query = ( 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)' @@ -303,6 +339,13 @@ def get_community_node_save_query(provider: GraphProvider) -> str: n.summary = $summary RETURN n.uuid AS uuid """ + case GraphProvider.MEMGRAPH: + return """ + MERGE (n:Community {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at} + WITH n SET n.name_embedding = $name_embedding + RETURN n.uuid AS uuid + """ case _: # Neo4j return """ MERGE (n:Community {uuid: $uuid}) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index cd3d003d..a0998b47 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -320,6 +320,9 @@ class EpisodicNode(Node): 'source': self.source.value, } + if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH): + episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '') + result = await driver.execute_query( get_episode_node_save_query(driver.provider), **episode_args ) diff --git a/tests/helpers_test.py b/tests/helpers_test.py index 58ef0c3e..3b5d1e4e 100644 --- a/tests/helpers_test.py +++ b/tests/helpers_test.py @@ -55,6 +55,14 @@ if os.getenv('DISABLE_KUZU') is None: except ImportError: raise +if os.getenv('DISABLE_MEMGRAPH') is None: + try: + from graphiti_core.driver.memgraph_driver import MemgraphDriver + + drivers.append(GraphProvider.MEMGRAPH) + except ImportError: + raise + # Disable Neptune for now os.environ['DISABLE_NEPTUNE'] = 'True' if os.getenv('DISABLE_NEPTUNE') is None: @@ -74,6 +82,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') FALKORDB_USER = os.getenv('FALKORDB_USER', None) FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) +MEMGRAPH_URI = os.getenv('MEMGRAPH_URI', 'bolt://localhost:7687') +MEMGRAPH_USER = os.getenv('MEMGRAPH_USER', '') +MEMGRAPH_PASSWORD = os.getenv('MEMGRAPH_PASSWORD', '') + NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost') NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182) AOSS_HOST = os.getenv('AOSS_HOST', None) @@ -103,6 +115,12 @@ def get_driver(provider: GraphProvider) -> GraphDriver: db=KUZU_DB, ) return driver + elif provider == GraphProvider.MEMGRAPH: + return MemgraphDriver( + uri=MEMGRAPH_URI, + user=MEMGRAPH_USER, + password=MEMGRAPH_PASSWORD, + ) elif provider == GraphProvider.NEPTUNE: return NeptuneDriver( host=NEPTUNE_HOST,