This commit is contained in:
David Ivekovic 2025-12-14 06:08:18 +01:00 committed by GitHub
commit 3e2026250d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 520 additions and 7 deletions

View file

@ -61,6 +61,11 @@ jobs:
NEO4J_AUTH: neo4j/testpass
NEO4J_PLUGINS: '["apoc"]'
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
memgraph:
image: memgraph/memgraph:latest
ports:
- 7688:7687
options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10
steps:
- uses: actions/checkout@v4
- name: Set up Python
@ -98,3 +103,13 @@ jobs:
tests/cross_encoder/test_bge_reranker_client_int.py \
tests/driver/test_falkordb_driver.py \
-m "not integration"
- name: Run Memgraph integration tests
env:
PYTHONPATH: ${{ github.workspace }}
MEMGRAPH_URI: bolt://localhost:7688
MEMGRAPH_USER:
MEMGRAPH_PASSWORD:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
uv run pytest tests/test_*_int.py -k "memgraph"

View file

@ -135,7 +135,7 @@ particularly suitable for applications requiring real-time interaction and preci
Requirements:
- Python 3.10 or higher
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Memgraph 2.22+ / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon
OpenSearch Serverless collection (serves as the full text search backend)
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
@ -564,7 +564,7 @@ When you initialize a Graphiti instance, we collect:
- **Graphiti version**: The version you're using
- **Configuration choices**:
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
- Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
- Database backend (Neo4j, FalkorDB, Kuzu, Memgraph, Amazon Neptune Database or Neptune Analytics)
- Embedder provider (OpenAI, Azure, Voyage, etc.)
### What We Don't Collect

View file

@ -87,6 +87,35 @@ services:
- PORT=8001
- db_backend=falkordb
memgraph:
image: memgraph/memgraph:latest
healthcheck:
test:
[
"CMD",
"mg_client",
"--host",
"localhost",
"--port",
"7687",
"--use-ssl=false",
"--query",
"RETURN 1;"
]
interval: 5s
timeout: 10s
retries: 10
start_period: 3s
ports:
- "7688:7687" # Bolt (using different port to avoid conflict)
volumes:
- memgraph_data:/var/lib/memgraph
environment:
- MEMGRAPH_USER=${MEMGRAPH_USER:-}
- MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-}
command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"]
volumes:
neo4j_data:
memgraph_data:
falkordb_data:

View file

@ -2,7 +2,7 @@
This example demonstrates the basic functionality of Graphiti, including:
1. Connecting to a Neo4j or FalkorDB database
1. Connecting to a Neo4j, FalkorDB, or Memgraph database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph
4. Searching the graph with semantic and keyword matching
@ -17,6 +17,9 @@ This example demonstrates the basic functionality of Graphiti, including:
- Neo4j Desktop installed and running
- A local DBMS created and started in Neo4j Desktop
- **For FalkorDB**:
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
- **For Memgraph**:
- Memgraph server running (see [Memgraph documentation](https://memgraph.com/docs/) for setup)
- FalkorDB server running (see [FalkorDB documentation](https://docs.falkordb.com) for setup)
- **For Amazon Neptune**:
- Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
@ -44,6 +47,11 @@ export NEO4J_PASSWORD=password
# Optional FalkorDB connection parameters (defaults shown)
export FALKORDB_URI=falkor://localhost:6379
# Optional Memgraph connection parameters (defaults shown)
export MEMGRAPH_URI=bolt://localhost:7687
export MEMGRAPH_USER= # Optional - Memgraph doesn't require auth by default
export MEMGRAPH_PASSWORD= # Optional - Memgraph doesn't require auth by default
# Optional Amazon Neptune connection parameters
NEPTUNE_HOST=your_neptune_host
NEPTUNE_PORT=your_port_or_8182
@ -65,13 +73,16 @@ python quickstart_neo4j.py
# For FalkorDB
python quickstart_falkordb.py
# For Memgraph
python quickstart_memgraph.py
# For Amazon Neptune
python quickstart_neptune.py
```
## What This Example Demonstrates
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, FalkorDB, Memgraph, or Amazon Neptune
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance

View file

@ -0,0 +1,254 @@
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.driver.memgraph_driver import MemgraphDriver
from graphiti_core.nodes import EpisodeType
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to Memgraph database
#################################################
# Configure logging
logging.basicConfig(
level=INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
load_dotenv()
# Memgraph connection parameters
# Make sure Memgraph is running (default port 7687, same as Neo4j)
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
memgraph_user = os.environ.get(
'MEMGRAPH_USER', ''
) # Memgraph often doesn't require auth by default
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
if not memgraph_uri:
raise ValueError('MEMGRAPH_URI must be set')
async def main():
#################################################
# INITIALIZATION
#################################################
# Connect to Memgraph and set up Graphiti indices
# This is required before using other Graphiti
# functionality
#################################################
# Initialize Memgraph driver
memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password)
# Initialize Graphiti with Memgraph connection
graphiti = Graphiti(graph_driver=memgraph_driver)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES
#################################################
# Episodes are the primary units of information
# in Graphiti. They can be text or structured JSON
# and are automatically processed to extract entities
# and relationships.
#################################################
# Example: Add Episodes
# Episodes list containing both text and JSON episodes
episodes = [
{
'content': 'Kamala Harris is the Attorney General of California. She was previously '
'the district attorney for San Francisco.',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': 'As AG, Harris was in office from January 3, 2011 January 3, 2017',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'state': 'California',
'predecessor': 'Jerry Brown',
},
'type': EpisodeType.json,
'description': 'politician info',
},
{
'content': 'Jerry Brown was the predecessor of Gavin Newsom as Governor of California',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
]
# Add episodes to the graph
for i, episode in enumerate(episodes):
await graphiti.add_episode(
name=f'Memgraph Demo {i}',
episode_body=episode['content']
if isinstance(episode['content'], str)
else json.dumps(episode['content']),
source=episode['type'],
source_description=episode['description'],
reference_time=datetime.now(timezone.utc),
)
logger.info(f'Added episode: Memgraph Demo {i} ({episode["type"].value})')
logger.info('Episodes added successfully!')
#################################################
# BASIC SEARCH
#################################################
# The simplest way to retrieve relationships (edges)
# from Graphiti is using the search method, which
# performs a hybrid search combining semantic
# similarity and BM25 text retrieval.
#################################################
# Perform a hybrid search combining semantic similarity and BM25 retrieval
logger.info('\n=== BASIC SEARCH ===')
search_query = 'Who was the California Attorney General?'
logger.info(f'Search query: {search_query}')
# Perform semantic search across edges (relationships)
results = await graphiti.search(search_query)
logger.info('Search results:')
for result in results[:3]: # Show top 3 results
logger.info(f'UUID: {result.uuid}')
logger.info(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
logger.info(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
logger.info(f'Valid until: {result.invalid_at}')
logger.info('---')
#################################################
# CENTER NODE SEARCH
#################################################
# For more contextually relevant results, you can
# use a center node to rerank search results based
# on their graph distance to a specific node
#################################################
if results and len(results) > 0:
logger.info('\n=== CENTER NODE SEARCH ===')
# Get the source node UUID from the top result
center_node_uuid = results[0].source_node_uuid
logger.info(f'Using center node UUID: {center_node_uuid}')
# Perform graph-based search using the source node
reranked_results = await graphiti.search(
search_query, center_node_uuid=center_node_uuid
)
logger.info('Reranked search results:')
for result in reranked_results[:3]:
logger.info(f'UUID: {result.uuid}')
logger.info(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
logger.info(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
logger.info(f'Valid until: {result.invalid_at}')
logger.info('---')
else:
logger.info('No results found in the initial search to use as center node.')
#################################################
# NODE SEARCH WITH RECIPES
#################################################
# Graphiti provides predefined search configurations
# (recipes) that optimize search for specific patterns
# and use cases.
#################################################
logger.info('\n=== NODE SEARCH WITH RECIPES ===')
recipe_query = 'California Governor'
logger.info(f'Recipe search query: {recipe_query}')
# Use hybrid search recipe for balanced semantic and keyword matching
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
node_search_config.limit = 5 # Limit to 5 results
# Execute the node search
node_search_results = await graphiti._search(
query=recipe_query,
config=node_search_config,
)
logger.info('Node search results:')
for node in node_search_results.nodes:
logger.info(f'Node UUID: {node.uuid}')
logger.info(f'Node Name: {node.name}')
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
logger.info(f'Content Summary: {node_summary}')
logger.info(f'Node Labels: {", ".join(node.labels)}')
logger.info(f'Created At: {node.created_at}')
if hasattr(node, 'attributes') and node.attributes:
logger.info('Attributes:')
for key, value in node.attributes.items():
logger.info(f' {key}: {value}')
#################################################
# SUMMARY STATISTICS
#################################################
# Get overall statistics about the knowledge graph
#################################################
logger.info('\n=== SUMMARY ===')
logger.info('Memgraph database populated successfully!')
logger.info('Knowledge graph is ready for queries and exploration.')
except Exception as e:
logger.error(f'An error occurred: {e}')
raise
finally:
#################################################
# CLEANUP
#################################################
# Always close the connection to Memgraph when
# finished to properly release resources
#################################################
# Close the connection
await graphiti.close()
logger.info('Connection closed.')
if __name__ == '__main__':
asyncio.run(main())

View file

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from neo4j import Neo4jDriver
from .memgraph_driver import MemgraphDriver
from .neo4j_driver import Neo4jDriver
__all__ = ['Neo4jDriver']
__all__ = ['Neo4jDriver', 'MemgraphDriver']

View file

@ -44,6 +44,7 @@ class GraphProvider(Enum):
FALKORDB = 'falkordb'
KUZU = 'kuzu'
NEPTUNE = 'neptune'
MEMGRAPH = 'memgraph'
class GraphDriverSession(ABC):

View file

@ -0,0 +1,78 @@
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections.abc import Coroutine
from typing import Any
from neo4j import AsyncGraphDatabase
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
class MemgraphDriver(GraphDriver):
provider = GraphProvider.MEMGRAPH
def __init__(
self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
auth=(user or '', password or ''),
)
self._database = database
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any:
"""
Execute a Cypher query against Memgraph using implicit transactions.
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
"""
# Extract parameters from kwargs
params = kwargs.pop('params', None)
if params is None:
database = kwargs.pop('database_', self._database)
kwargs.pop('parameters_', None)
params = kwargs
else:
database = kwargs.pop('database_', self._database)
kwargs.pop('parameters_', None)
async with self.client.session(database=database) as session:
try:
result = await session.run(cypher_query_, params)
keys = result.keys()
records = [record async for record in result]
summary = await result.consume()
return (records, summary, keys)
except Exception as e:
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
raise
finally:
await session.close()
def session(self, database: str | None = None) -> GraphDriverSession:
_database = database or self._database
return self.client.session(database=_database) # type: ignore
async def close(self) -> None:
return await self.client.close()
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
return self.client.execute_query('DROP ALL INDEXES')

View file

@ -45,6 +45,30 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
if provider == GraphProvider.KUZU:
return []
if provider == GraphProvider.MEMGRAPH:
return [
'CREATE INDEX ON :Entity(uuid);',
'CREATE INDEX ON :Entity(group_id);',
'CREATE INDEX ON :Entity(name);',
'CREATE INDEX ON :Entity(created_at);',
'CREATE INDEX ON :Episodic(uuid);',
'CREATE INDEX ON :Episodic(group_id);',
'CREATE INDEX ON :Episodic(created_at);',
'CREATE INDEX ON :Episodic(valid_at);',
'CREATE INDEX ON :Community(uuid);',
'CREATE INDEX ON :Community(group_id);',
'CREATE INDEX ON :RELATES_TO(uuid);',
'CREATE INDEX ON :RELATES_TO(group_id);',
'CREATE INDEX ON :RELATES_TO(name);',
'CREATE INDEX ON :RELATES_TO(created_at);',
'CREATE INDEX ON :RELATES_TO(expired_at);',
'CREATE INDEX ON :RELATES_TO(valid_at);',
'CREATE INDEX ON :RELATES_TO(invalid_at);',
'CREATE INDEX ON :MENTIONS(uuid);',
'CREATE INDEX ON :MENTIONS(group_id);',
'CREATE INDEX ON :HAS_MEMBER(uuid);',
]
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@ -115,6 +139,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
]
if provider == GraphProvider.MEMGRAPH:
return [
"""CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id);""",
"""CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id);""",
"""CREATE TEXT INDEX community_name ON :Community(name, group_id);""",
"""CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id);""",
]
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@ -136,6 +168,9 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider)
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
if provider == GraphProvider.MEMGRAPH:
return f'CALL text_search.search_all("{name}", {query}, {limit})'
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
@ -147,6 +182,9 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
if provider == GraphProvider.KUZU:
return f'array_cosine_similarity({vec1}, {vec2})'
if provider == GraphProvider.MEMGRAPH:
return f'vector_search.cosine_similarity({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})'
@ -159,4 +197,7 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
if provider == GraphProvider.MEMGRAPH:
return f'CALL text_search.search_all_edges("{name}", $query, {limit})'
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'

View file

@ -99,6 +99,15 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False)
e.attributes = $attributes
RETURN e.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
SET e.fact_embedding = $edge_data.fact_embedding
RETURN e.uuid AS uuid
"""
case _: # Neo4j
save_embedding_query = (
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
@ -163,6 +172,16 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = Fa
e.attributes = $attributes
RETURN e.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
SET e.fact_embedding = edge.fact_embedding
RETURN edge.uuid AS uuid;
"""
case _:
save_embedding_query = (
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
@ -261,7 +280,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
e.created_at = $created_at
RETURN e.uuid AS uuid
"""
case _: # Neo4j
case _: # Neo4j and Memgraph
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})

View file

@ -49,6 +49,17 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {
uuid: $uuid, name: $name, group_id: $group_id,
source_description: $source_description, source: $source,
content: $content, entity_edges: $entity_edges,
created_at: $created_at, valid_at: $valid_at
}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Episodic {uuid: $uuid})
@ -91,6 +102,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
UNWIND $episodes AS episode
@ -161,6 +180,14 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
WITH n SET n.name_embedding = $entity_data.name_embedding
RETURN n.uuid AS uuid
"""
case _:
save_embedding_query = (
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
@ -233,6 +260,15 @@ def get_entity_node_save_bulk_query(
n.attributes = $attributes
RETURN n.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = node
WITH n, node
SET n.name_embedding = node.name_embedding
RETURN n.uuid AS uuid
"""
case _: # Neo4j
save_embedding_query = (
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
@ -303,6 +339,13 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
n.summary = $summary
RETURN n.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n SET n.name_embedding = $name_embedding
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Community {uuid: $uuid})

View file

@ -320,6 +320,9 @@ class EpisodicNode(Node):
'source': self.source.value,
}
if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH):
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args
)

View file

@ -55,6 +55,14 @@ if os.getenv('DISABLE_KUZU') is None:
except ImportError:
raise
if os.getenv('DISABLE_MEMGRAPH') is None:
try:
from graphiti_core.driver.memgraph_driver import MemgraphDriver
drivers.append(GraphProvider.MEMGRAPH)
except ImportError:
raise
# Disable Neptune for now
os.environ['DISABLE_NEPTUNE'] = 'True'
if os.getenv('DISABLE_NEPTUNE') is None:
@ -74,6 +82,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
MEMGRAPH_URI = os.getenv('MEMGRAPH_URI', 'bolt://localhost:7687')
MEMGRAPH_USER = os.getenv('MEMGRAPH_USER', '')
MEMGRAPH_PASSWORD = os.getenv('MEMGRAPH_PASSWORD', '')
NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)
@ -103,6 +115,12 @@ def get_driver(provider: GraphProvider) -> GraphDriver:
db=KUZU_DB,
)
return driver
elif provider == GraphProvider.MEMGRAPH:
return MemgraphDriver(
uri=MEMGRAPH_URI,
user=MEMGRAPH_USER,
password=MEMGRAPH_PASSWORD,
)
elif provider == GraphProvider.NEPTUNE:
return NeptuneDriver(
host=NEPTUNE_HOST,