add memgraph as graphdb vendor wip
This commit is contained in:
parent
eeb0d877de
commit
0f9dd11cc8
10 changed files with 442 additions and 9 deletions
|
|
@ -118,7 +118,7 @@ particularly suitable for applications requiring real-time interaction and preci
|
|||
Requirements:
|
||||
|
||||
- Python 3.10 or higher
|
||||
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon
|
||||
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Memgraph 2.22+ / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon
|
||||
OpenSearch Serverless collection (serves as the full text search backend)
|
||||
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
|
||||
|
||||
|
|
@ -556,7 +556,7 @@ When you initialize a Graphiti instance, we collect:
|
|||
- **Graphiti version**: The version you're using
|
||||
- **Configuration choices**:
|
||||
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
|
||||
- Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
|
||||
- Database backend (Neo4j, FalkorDB, Kuzu, Memgraph, Amazon Neptune Database or Neptune Analytics)
|
||||
- Embedder provider (OpenAI, Azure, Voyage, etc.)
|
||||
|
||||
### What We Don't Collect
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
This example demonstrates the basic functionality of Graphiti, including:
|
||||
|
||||
1. Connecting to a Neo4j or FalkorDB database
|
||||
1. Connecting to a Neo4j, FalkorDB, or Memgraph database
|
||||
2. Initializing Graphiti indices and constraints
|
||||
3. Adding episodes to the graph
|
||||
4. Searching the graph with semantic and keyword matching
|
||||
|
|
@ -18,6 +18,8 @@ This example demonstrates the basic functionality of Graphiti, including:
|
|||
- A local DBMS created and started in Neo4j Desktop
|
||||
- **For FalkorDB**:
|
||||
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
|
||||
- **For Memgraph**:
|
||||
- Memgraph server running (see [Memgraph documentation](https://memgraph.com/docs/) for setup)
|
||||
- **For Amazon Neptune**:
|
||||
- Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
|
||||
|
||||
|
|
@ -44,6 +46,11 @@ export NEO4J_PASSWORD=password
|
|||
# Optional FalkorDB connection parameters (defaults shown)
|
||||
export FALKORDB_URI=falkor://localhost:6379
|
||||
|
||||
# Optional Memgraph connection parameters (defaults shown)
|
||||
export MEMGRAPH_URI=bolt://localhost:7687
|
||||
export MEMGRAPH_USER= # Optional - Memgraph doesn't require auth by default
|
||||
export MEMGRAPH_PASSWORD= # Optional - Memgraph doesn't require auth by default
|
||||
|
||||
# Optional Amazon Neptune connection parameters
|
||||
NEPTUNE_HOST=your_neptune_host
|
||||
NEPTUNE_PORT=your_port_or_8182
|
||||
|
|
@ -65,13 +72,16 @@ python quickstart_neo4j.py
|
|||
# For FalkorDB
|
||||
python quickstart_falkordb.py
|
||||
|
||||
# For Memgraph
|
||||
python quickstart_memgraph.py
|
||||
|
||||
# For Amazon Neptune
|
||||
python quickstart_neptune.py
|
||||
```
|
||||
|
||||
## What This Example Demonstrates
|
||||
|
||||
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
|
||||
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, FalkorDB, Memgraph, or Amazon Neptune
|
||||
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
|
||||
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
|
||||
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
|
||||
|
|
|
|||
252
examples/quickstart/quickstart_memgraph.py
Normal file
252
examples/quickstart/quickstart_memgraph.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
"""
|
||||
Copyright 2025, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from logging import INFO
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.driver.memgraph_driver import MemgraphDriver
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
||||
|
||||
#################################################
|
||||
# CONFIGURATION
|
||||
#################################################
|
||||
# Set up logging and environment variables for
|
||||
# connecting to Memgraph database
|
||||
#################################################
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Memgraph connection parameters
|
||||
# Make sure Memgraph is running (default port 7687, same as Neo4j)
|
||||
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
|
||||
memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default
|
||||
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
|
||||
|
||||
if not memgraph_uri:
|
||||
raise ValueError('MEMGRAPH_URI must be set')
|
||||
|
||||
|
||||
async def main():
|
||||
#################################################
|
||||
# INITIALIZATION
|
||||
#################################################
|
||||
# Connect to Memgraph and set up Graphiti indices
|
||||
# This is required before using other Graphiti
|
||||
# functionality
|
||||
#################################################
|
||||
|
||||
# Initialize Memgraph driver
|
||||
memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password)
|
||||
|
||||
# Initialize Graphiti with Memgraph connection
|
||||
graphiti = Graphiti(graph_driver=memgraph_driver)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
#################################################
|
||||
# Episodes are the primary units of information
|
||||
# in Graphiti. They can be text or structured JSON
|
||||
# and are automatically processed to extract entities
|
||||
# and relationships.
|
||||
#################################################
|
||||
|
||||
# Example: Add Episodes
|
||||
# Episodes list containing both text and JSON episodes
|
||||
episodes = [
|
||||
{
|
||||
'content': 'Kamala Harris is the Attorney General of California. She was previously '
|
||||
'the district attorney for San Francisco.',
|
||||
'type': EpisodeType.text,
|
||||
'description': 'podcast transcript',
|
||||
},
|
||||
{
|
||||
'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
|
||||
'type': EpisodeType.text,
|
||||
'description': 'podcast transcript',
|
||||
},
|
||||
{
|
||||
'content': {
|
||||
'name': 'Gavin Newsom',
|
||||
'position': 'Governor',
|
||||
'state': 'California',
|
||||
'predecessor': 'Jerry Brown',
|
||||
},
|
||||
'type': EpisodeType.json,
|
||||
'description': 'politician info',
|
||||
},
|
||||
{
|
||||
'content': 'Jerry Brown was the predecessor of Gavin Newsom as Governor of California',
|
||||
'type': EpisodeType.text,
|
||||
'description': 'podcast transcript',
|
||||
},
|
||||
]
|
||||
|
||||
# Add episodes to the graph
|
||||
for i, episode in enumerate(episodes):
|
||||
await graphiti.add_episode(
|
||||
name=f'Memgraph Demo {i}',
|
||||
episode_body=episode['content']
|
||||
if isinstance(episode['content'], str)
|
||||
else json.dumps(episode['content']),
|
||||
source=episode['type'],
|
||||
source_description=episode['description'],
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
)
|
||||
logger.info(f'Added episode: Memgraph Demo {i} ({episode["type"].value})')
|
||||
|
||||
logger.info('Episodes added successfully!')
|
||||
|
||||
#################################################
|
||||
# BASIC SEARCH
|
||||
#################################################
|
||||
# The simplest way to retrieve relationships (edges)
|
||||
# from Graphiti is using the search method, which
|
||||
# performs a hybrid search combining semantic
|
||||
# similarity and BM25 text retrieval.
|
||||
#################################################
|
||||
|
||||
# Perform a hybrid search combining semantic similarity and BM25 retrieval
|
||||
logger.info('\n=== BASIC SEARCH ===')
|
||||
search_query = 'Who was the California Attorney General?'
|
||||
logger.info(f'Search query: {search_query}')
|
||||
|
||||
# Perform semantic search across edges (relationships)
|
||||
results = await graphiti.search(search_query)
|
||||
|
||||
logger.info('Search results:')
|
||||
for result in results[:3]: # Show top 3 results
|
||||
logger.info(f'UUID: {result.uuid}')
|
||||
logger.info(f'Fact: {result.fact}')
|
||||
if hasattr(result, 'valid_at') and result.valid_at:
|
||||
logger.info(f'Valid from: {result.valid_at}')
|
||||
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||
logger.info(f'Valid until: {result.invalid_at}')
|
||||
logger.info('---')
|
||||
|
||||
#################################################
|
||||
# CENTER NODE SEARCH
|
||||
#################################################
|
||||
# For more contextually relevant results, you can
|
||||
# use a center node to rerank search results based
|
||||
# on their graph distance to a specific node
|
||||
#################################################
|
||||
|
||||
if results and len(results) > 0:
|
||||
logger.info('\n=== CENTER NODE SEARCH ===')
|
||||
# Get the source node UUID from the top result
|
||||
center_node_uuid = results[0].source_node_uuid
|
||||
logger.info(f'Using center node UUID: {center_node_uuid}')
|
||||
|
||||
# Perform graph-based search using the source node
|
||||
reranked_results = await graphiti.search(
|
||||
search_query, center_node_uuid=center_node_uuid
|
||||
)
|
||||
|
||||
logger.info('Reranked search results:')
|
||||
for result in reranked_results[:3]:
|
||||
logger.info(f'UUID: {result.uuid}')
|
||||
logger.info(f'Fact: {result.fact}')
|
||||
if hasattr(result, 'valid_at') and result.valid_at:
|
||||
logger.info(f'Valid from: {result.valid_at}')
|
||||
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||
logger.info(f'Valid until: {result.invalid_at}')
|
||||
logger.info('---')
|
||||
else:
|
||||
logger.info('No results found in the initial search to use as center node.')
|
||||
|
||||
#################################################
|
||||
# NODE SEARCH WITH RECIPES
|
||||
#################################################
|
||||
# Graphiti provides predefined search configurations
|
||||
# (recipes) that optimize search for specific patterns
|
||||
# and use cases.
|
||||
#################################################
|
||||
|
||||
logger.info('\n=== NODE SEARCH WITH RECIPES ===')
|
||||
recipe_query = 'California Governor'
|
||||
logger.info(f'Recipe search query: {recipe_query}')
|
||||
|
||||
# Use hybrid search recipe for balanced semantic and keyword matching
|
||||
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
|
||||
node_search_config.limit = 5 # Limit to 5 results
|
||||
|
||||
# Execute the node search
|
||||
node_search_results = await graphiti._search(
|
||||
query=recipe_query,
|
||||
config=node_search_config,
|
||||
)
|
||||
|
||||
logger.info('Node search results:')
|
||||
for node in node_search_results.nodes:
|
||||
logger.info(f'Node UUID: {node.uuid}')
|
||||
logger.info(f'Node Name: {node.name}')
|
||||
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
|
||||
logger.info(f'Content Summary: {node_summary}')
|
||||
logger.info(f'Node Labels: {", ".join(node.labels)}')
|
||||
logger.info(f'Created At: {node.created_at}')
|
||||
if hasattr(node, 'attributes') and node.attributes:
|
||||
logger.info('Attributes:')
|
||||
for key, value in node.attributes.items():
|
||||
logger.info(f' {key}: {value}')
|
||||
#################################################
|
||||
# SUMMARY STATISTICS
|
||||
#################################################
|
||||
# Get overall statistics about the knowledge graph
|
||||
#################################################
|
||||
|
||||
logger.info('\n=== SUMMARY ===')
|
||||
logger.info('Memgraph database populated successfully!')
|
||||
logger.info('Knowledge graph is ready for queries and exploration.')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'An error occurred: {e}')
|
||||
raise
|
||||
|
||||
finally:
|
||||
#################################################
|
||||
# CLEANUP
|
||||
#################################################
|
||||
# Always close the connection to Memgraph when
|
||||
# finished to properly release resources
|
||||
#################################################
|
||||
|
||||
# Close the connection
|
||||
await graphiti.close()
|
||||
logger.info('Connection closed.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
from neo4j import Neo4jDriver
|
||||
from .neo4j_driver import Neo4jDriver
|
||||
from .memgraph_driver import MemgraphDriver
|
||||
|
||||
__all__ = ['Neo4jDriver']
|
||||
__all__ = ['Neo4jDriver', 'MemgraphDriver']
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class GraphProvider(Enum):
|
|||
FALKORDB = 'falkordb'
|
||||
KUZU = 'kuzu'
|
||||
NEPTUNE = 'neptune'
|
||||
MEMGRAPH = 'memgraph'
|
||||
|
||||
|
||||
class GraphDriverSession(ABC):
|
||||
|
|
|
|||
67
graphiti_core/driver/memgraph_driver.py
Normal file
67
graphiti_core/driver/memgraph_driver.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
Copyright 2025, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from neo4j import AsyncGraphDatabase, EagerResult
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemgraphDriver(GraphDriver):
|
||||
provider = GraphProvider.MEMGRAPH
|
||||
|
||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
auth=(user or '', password or ''),
|
||||
)
|
||||
self._database = database
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
# If not populated, set the value to retain backwards compatibility
|
||||
params = kwargs.pop('params', None)
|
||||
if params is None:
|
||||
params = {}
|
||||
params.setdefault('database_', self._database)
|
||||
|
||||
try:
|
||||
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
|
||||
raise
|
||||
|
||||
return result
|
||||
|
||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||
_database = database or self._database
|
||||
return self.client.session(database=_database) # type: ignore
|
||||
|
||||
async def close(self) -> None:
|
||||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
|
||||
# TODO
|
||||
return self.client.execute_query(
|
||||
'SHOW INDEX INFO;',
|
||||
)
|
||||
|
|
@ -45,6 +45,30 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|||
if provider == GraphProvider.KUZU:
|
||||
return []
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return [
|
||||
'CREATE INDEX ON :Entity(uuid)',
|
||||
'CREATE INDEX ON :Entity(group_id)',
|
||||
'CREATE INDEX ON :Entity(name)',
|
||||
'CREATE INDEX ON :Entity(created_at)',
|
||||
'CREATE INDEX ON :Episodic(uuid)',
|
||||
'CREATE INDEX ON :Episodic(group_id)',
|
||||
'CREATE INDEX ON :Episodic(created_at)',
|
||||
'CREATE INDEX ON :Episodic(valid_at)',
|
||||
'CREATE INDEX ON :Community(uuid)',
|
||||
'CREATE INDEX ON :Community(group_id)',
|
||||
'CREATE INDEX ON :RELATES_TO(uuid)',
|
||||
'CREATE INDEX ON :RELATES_TO(group_id)',
|
||||
'CREATE INDEX ON :RELATES_TO(name)',
|
||||
'CREATE INDEX ON :RELATES_TO(created_at)',
|
||||
'CREATE INDEX ON :RELATES_TO(expired_at)',
|
||||
'CREATE INDEX ON :RELATES_TO(valid_at)',
|
||||
'CREATE INDEX ON :RELATES_TO(invalid_at)',
|
||||
'CREATE INDEX ON :MENTIONS(uuid)',
|
||||
'CREATE INDEX ON :MENTIONS(group_id)',
|
||||
'CREATE INDEX ON :HAS_MEMBER(uuid)',
|
||||
]
|
||||
|
||||
return [
|
||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||
|
|
@ -86,6 +110,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|||
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
|
||||
]
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return [
|
||||
"""CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id)""",
|
||||
"""CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id)""",
|
||||
"""CREATE TEXT INDEX community_name ON :Community(name, group_id)""",
|
||||
"""CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id)""",
|
||||
]
|
||||
|
||||
return [
|
||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
||||
|
|
@ -107,6 +139,9 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider)
|
|||
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return f'CALL text_search.search("{name}", {query}) YIELD node RETURN node LIMIT $limit'
|
||||
|
||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||
|
||||
|
||||
|
|
@ -118,6 +153,9 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|||
if provider == GraphProvider.KUZU:
|
||||
return f'array_cosine_similarity({vec1}, {vec2})'
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return "TODO"
|
||||
|
||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||
|
||||
|
||||
|
|
@ -130,4 +168,7 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s
|
|||
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return f'CALL text_search.search("{name}", $query) YIELD node RETURN node LIMIT $limit'
|
||||
|
||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||
|
|
|
|||
|
|
@ -98,6 +98,15 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
|||
e.attributes = $attributes
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
WITH e e.fact_embedding = $edge_data.fact_embedding
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
|
|
@ -151,6 +160,16 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
|||
e.attributes = $attributes
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET e = edge
|
||||
WITH e, edge e.fact_embedding = edge.fact_embedding
|
||||
RETURN edge.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
|
|
@ -240,7 +259,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
|
|||
e.created_at = $created_at
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
case _: # Neo4j and Memgraph
|
||||
return """
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
|||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
case _: # Neo4j and Memgraph
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n:$($group_label)
|
||||
|
|
@ -92,7 +92,7 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
|||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
case _: # Neo4j and Memgraph
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
|
|
@ -162,6 +162,14 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
|||
SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
WITH n SET n.name_embedding = $entity_data.name_embedding
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
return f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
|
|
@ -223,6 +231,15 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
|||
n.attributes = $attributes
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n:$(node.labels)
|
||||
SET n = node
|
||||
WITH n, node SET n.name_embedding = node.name_embedding
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
UNWIND $nodes AS node
|
||||
|
|
@ -284,6 +301,13 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
|
|||
n.summary = $summary
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n SET n.name_embedding = $name_embedding
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
|
|
|
|||
|
|
@ -55,6 +55,14 @@ if os.getenv('DISABLE_KUZU') is None:
|
|||
except ImportError:
|
||||
raise
|
||||
|
||||
if os.getenv('DISABLE_MEMGRAPH') is None:
|
||||
try:
|
||||
from graphiti_core.driver.memgraph_driver import MemgraphDriver
|
||||
|
||||
drivers.append(GraphProvider.MEMGRAPH)
|
||||
except ImportError:
|
||||
raise
|
||||
|
||||
# Disable Neptune for now
|
||||
os.environ['DISABLE_NEPTUNE'] = 'True'
|
||||
if os.getenv('DISABLE_NEPTUNE') is None:
|
||||
|
|
@ -74,6 +82,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
|||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
MEMGRAPH_URI = os.getenv('MEMGRAPH_URI', 'bolt://localhost:7687')
|
||||
MEMGRAPH_USER = os.getenv('MEMGRAPH_USER', '')
|
||||
MEMGRAPH_PASSWORD = os.getenv('MEMGRAPH_PASSWORD', '')
|
||||
|
||||
NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
|
||||
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
|
||||
AOSS_HOST = os.getenv('AOSS_HOST', None)
|
||||
|
|
@ -103,6 +115,12 @@ def get_driver(provider: GraphProvider) -> GraphDriver:
|
|||
db=KUZU_DB,
|
||||
)
|
||||
return driver
|
||||
elif provider == GraphProvider.MEMGRAPH:
|
||||
return MemgraphDriver(
|
||||
uri=MEMGRAPH_URI,
|
||||
user=MEMGRAPH_USER,
|
||||
password=MEMGRAPH_PASSWORD,
|
||||
)
|
||||
elif provider == GraphProvider.NEPTUNE:
|
||||
return NeptuneDriver(
|
||||
host=NEPTUNE_HOST,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue