diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index cf1053a1..e7bd93ad 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -29,6 +29,11 @@ jobs: NEO4J_AUTH: neo4j/testpass NEO4J_PLUGINS: '["apoc"]' options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10 + memgraph: + image: memgraph/memgraph:latest + ports: + - 7688:7687 + options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10 steps: - uses: actions/checkout@v4 - name: Set up Python @@ -77,3 +82,13 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | uv run pytest tests/test_*_int.py -k "neo4j" + - name: Run Memgraph integration tests + env: + PYTHONPATH: ${{ github.workspace }} + MEMGRAPH_URI: bolt://localhost:7688 + MEMGRAPH_USER: + MEMGRAPH_PASSWORD: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + uv run pytest tests/test_*_int.py -k "memgraph" + \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 0400692c..38cdd9bf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -44,5 +44,34 @@ services: environment: - NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD} + memgraph: + image: memgraph/memgraph:latest + healthcheck: + test: + [ + "CMD", + "mg_client", + "--host", + "localhost", + "--port", + "7687", + "--use-ssl=false", + "--query", + "RETURN 1;" + ] + interval: 5s + timeout: 10s + retries: 10 + start_period: 3s + ports: + - "7688:7687" # Bolt (using different port to avoid conflict) + volumes: + - memgraph_data:/var/lib/memgraph + environment: + - MEMGRAPH_USER=${MEMGRAPH_USER:-} + - MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-} + command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"] + volumes: neo4j_data: + memgraph_data: diff --git a/examples/quickstart/quickstart_memgraph.py b/examples/quickstart/quickstart_memgraph.py index 911991ec..93fb4ace 100644 --- a/examples/quickstart/quickstart_memgraph.py +++ b/examples/quickstart/quickstart_memgraph.py @@ -48,7 +48,9 @@ load_dotenv() # Memgraph connection parameters # Make sure Memgraph is running (default port 7687, same as Neo4j) memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687') -memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default +memgraph_user = os.environ.get( + 'MEMGRAPH_USER', '' +) # Memgraph often doesn't require auth by default memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '') if not memgraph_uri: @@ -66,7 +68,7 @@ async def main(): # Initialize Memgraph driver memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password) - + # Initialize Graphiti with Memgraph connection graphiti = Graphiti(graph_driver=memgraph_driver) diff --git a/graphiti_core/driver/__init__.py b/graphiti_core/driver/__init__.py index 02804df2..4babebd1 100644 --- a/graphiti_core/driver/__init__.py +++ b/graphiti_core/driver/__init__.py @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from .neo4j_driver import Neo4jDriver from .memgraph_driver import MemgraphDriver +from .neo4j_driver import Neo4jDriver __all__ = ['Neo4jDriver', 'MemgraphDriver'] diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index d14c8d83..d419bee9 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -29,7 +29,9 @@ logger = logging.getLogger(__name__) class MemgraphDriver(GraphDriver): provider = GraphProvider.MEMGRAPH - def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): + def __init__( + self, uri: str, user: str | None, password: str | None, database: str = 'memgraph' + ): super().__init__() self.client = AsyncGraphDatabase.driver( uri=uri, @@ -37,7 +39,9 @@ class MemgraphDriver(GraphDriver): ) self._database = database - async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]: + async def execute_query( + self, cypher_query_: LiteralString, **kwargs: Any + ) -> tuple[list, Any, Any]: """ Execute a Cypher query against Memgraph using implicit transactions. Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. @@ -49,20 +53,20 @@ class MemgraphDriver(GraphDriver): # but first extract database-specific parameters database = kwargs.pop('database_', self._database) kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param) - + # All remaining kwargs are query parameters params = kwargs else: # Extract database parameter if params was provided separately database = kwargs.pop('database_', self._database) kwargs.pop('parameters_', None) # Remove if present - + async with self.client.session(database=database) as session: try: result = await session.run(cypher_query_, params) records = [record async for record in result] summary = await result.consume() - keys = result.keys() + keys = result.keys() return (records, summary, keys) except Exception as e: logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') @@ -77,4 +81,4 @@ class MemgraphDriver(GraphDriver): def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: # TODO: Implement index deletion for Memgraph - raise NotImplementedError("Index deletion not implemented for MemgraphDriver") \ No newline at end of file + raise NotImplementedError('Index deletion not implemented for MemgraphDriver') diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 9f85b1c1..ffee09f2 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: return f'array_cosine_similarity({vec1}, {vec2})' if provider == GraphProvider.MEMGRAPH: - return f'cosineSimilarity({vec1}, {vec2})' + return f'vector_search.cosine_similarity({vec1}, {vec2})' return f'vector.similarity.cosine({vec1}, {vec2})' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index abb0edbb..3554701e 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -104,7 +104,7 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str: MATCH (target:Entity {uuid: $edge_data.target_uuid}) MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) SET e = $edge_data - WITH e e.fact_embedding = $edge_data.fact_embedding + SET e.fact_embedding = $edge_data.fact_embedding RETURN e.uuid AS uuid """ case _: # Neo4j diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 7cf075f0..867b96c8 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -49,7 +49,18 @@ def get_episode_node_save_query(provider: GraphProvider) -> str: entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid """ - case _: # Neo4j and Memgraph + case GraphProvider.MEMGRAPH: + return """ + MERGE (n:Episodic {uuid: $uuid}) + SET n = { + uuid: $uuid, name: $name, group_id: $group_id, + source_description: $source_description, source: $source, + content: $content, entity_edges: $entity_edges, + created_at: $created_at, valid_at: $valid_at + } + RETURN n.uuid AS uuid + """ + case _: # Neo4j return """ MERGE (n:Episodic {uuid: $uuid}) SET n:$($group_label) @@ -96,23 +107,11 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str: return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) - FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END | - SET n:`${episode.group_label}` - ) - SET n = { - uuid: episode.uuid, - name: episode.name, - group_id: episode.group_id, - source_description: episode.source_description, - source: episode.source, - content: episode.content, - entity_edges: episode.entity_edges, - created_at: episode.created_at, - valid_at: episode.valid_at - } - RETURN n.uuid AS uuid; + SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, + entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} + RETURN n.uuid AS uuid """ - case _: # Neo4j + case _: # Neo4j return """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) @@ -183,7 +182,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: RETURN n.uuid AS uuid """ case GraphProvider.MEMGRAPH: - return """ + return f""" MERGE (n:Entity {{uuid: $entity_data.uuid}}) SET n:{labels} SET n = $entity_data @@ -255,13 +254,10 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) return """ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) - FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END | - SET n:`${label}` - ) SET n = node WITH n, node SET n.name_embedding = node.name_embedding - RETURN n.uuid AS uuid; + RETURN n.uuid AS uuid """ case _: # Neo4j return """ diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index ef28a94d..ab9b2dbc 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -299,7 +299,7 @@ class EpisodicNode(Node): 'source': self.source.value, } - if driver.provider == GraphProvider.NEO4J: + if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH): episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '') result = await driver.execute_query(