Merge b534850d1e into d6ff7bb78c
This commit is contained in:
commit
3e2026250d
13 changed files with 520 additions and 7 deletions
15
.github/workflows/unit_tests.yml
vendored
15
.github/workflows/unit_tests.yml
vendored
|
|
@ -61,6 +61,11 @@ jobs:
|
|||
NEO4J_AUTH: neo4j/testpass
|
||||
NEO4J_PLUGINS: '["apoc"]'
|
||||
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||
memgraph:
|
||||
image: memgraph/memgraph:latest
|
||||
ports:
|
||||
- 7688:7687
|
||||
options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
|
|
@ -98,3 +103,13 @@ jobs:
|
|||
tests/cross_encoder/test_bge_reranker_client_int.py \
|
||||
tests/driver/test_falkordb_driver.py \
|
||||
-m "not integration"
|
||||
- name: Run Memgraph integration tests
|
||||
env:
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
MEMGRAPH_URI: bolt://localhost:7688
|
||||
MEMGRAPH_USER:
|
||||
MEMGRAPH_PASSWORD:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
uv run pytest tests/test_*_int.py -k "memgraph"
|
||||
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ particularly suitable for applications requiring real-time interaction and preci
|
|||
Requirements:
|
||||
|
||||
- Python 3.10 or higher
|
||||
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon
|
||||
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Memgraph 2.22+ / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon
|
||||
OpenSearch Serverless collection (serves as the full text search backend)
|
||||
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
|
||||
|
||||
|
|
@ -564,7 +564,7 @@ When you initialize a Graphiti instance, we collect:
|
|||
- **Graphiti version**: The version you're using
|
||||
- **Configuration choices**:
|
||||
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
|
||||
- Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
|
||||
- Database backend (Neo4j, FalkorDB, Kuzu, Memgraph, Amazon Neptune Database or Neptune Analytics)
|
||||
- Embedder provider (OpenAI, Azure, Voyage, etc.)
|
||||
|
||||
### What We Don't Collect
|
||||
|
|
|
|||
|
|
@ -87,6 +87,35 @@ services:
|
|||
- PORT=8001
|
||||
- db_backend=falkordb
|
||||
|
||||
memgraph:
|
||||
image: memgraph/memgraph:latest
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"mg_client",
|
||||
"--host",
|
||||
"localhost",
|
||||
"--port",
|
||||
"7687",
|
||||
"--use-ssl=false",
|
||||
"--query",
|
||||
"RETURN 1;"
|
||||
]
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 3s
|
||||
ports:
|
||||
- "7688:7687" # Bolt (using different port to avoid conflict)
|
||||
volumes:
|
||||
- memgraph_data:/var/lib/memgraph
|
||||
environment:
|
||||
- MEMGRAPH_USER=${MEMGRAPH_USER:-}
|
||||
- MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-}
|
||||
command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"]
|
||||
|
||||
volumes:
|
||||
neo4j_data:
|
||||
memgraph_data:
|
||||
falkordb_data:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
This example demonstrates the basic functionality of Graphiti, including:
|
||||
|
||||
1. Connecting to a Neo4j or FalkorDB database
|
||||
1. Connecting to a Neo4j, FalkorDB, or Memgraph database
|
||||
2. Initializing Graphiti indices and constraints
|
||||
3. Adding episodes to the graph
|
||||
4. Searching the graph with semantic and keyword matching
|
||||
|
|
@ -17,6 +17,9 @@ This example demonstrates the basic functionality of Graphiti, including:
|
|||
- Neo4j Desktop installed and running
|
||||
- A local DBMS created and started in Neo4j Desktop
|
||||
- **For FalkorDB**:
|
||||
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
|
||||
- **For Memgraph**:
|
||||
- Memgraph server running (see [Memgraph documentation](https://memgraph.com/docs/) for setup)
|
||||
- FalkorDB server running (see [FalkorDB documentation](https://docs.falkordb.com) for setup)
|
||||
- **For Amazon Neptune**:
|
||||
- Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
|
||||
|
|
@ -44,6 +47,11 @@ export NEO4J_PASSWORD=password
|
|||
# Optional FalkorDB connection parameters (defaults shown)
|
||||
export FALKORDB_URI=falkor://localhost:6379
|
||||
|
||||
# Optional Memgraph connection parameters (defaults shown)
|
||||
export MEMGRAPH_URI=bolt://localhost:7687
|
||||
export MEMGRAPH_USER= # Optional - Memgraph doesn't require auth by default
|
||||
export MEMGRAPH_PASSWORD= # Optional - Memgraph doesn't require auth by default
|
||||
|
||||
# Optional Amazon Neptune connection parameters
|
||||
NEPTUNE_HOST=your_neptune_host
|
||||
NEPTUNE_PORT=your_port_or_8182
|
||||
|
|
@ -65,13 +73,16 @@ python quickstart_neo4j.py
|
|||
# For FalkorDB
|
||||
python quickstart_falkordb.py
|
||||
|
||||
# For Memgraph
|
||||
python quickstart_memgraph.py
|
||||
|
||||
# For Amazon Neptune
|
||||
python quickstart_neptune.py
|
||||
```
|
||||
|
||||
## What This Example Demonstrates
|
||||
|
||||
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
|
||||
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, FalkorDB, Memgraph, or Amazon Neptune
|
||||
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
|
||||
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
|
||||
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
|
||||
|
|
|
|||
254
examples/quickstart/quickstart_memgraph.py
Normal file
254
examples/quickstart/quickstart_memgraph.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""
|
||||
Copyright 2025, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from logging import INFO
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.driver.memgraph_driver import MemgraphDriver
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
||||
|
||||
#################################################
|
||||
# CONFIGURATION
|
||||
#################################################
|
||||
# Set up logging and environment variables for
|
||||
# connecting to Memgraph database
|
||||
#################################################
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Memgraph connection parameters
|
||||
# Make sure Memgraph is running (default port 7687, same as Neo4j)
|
||||
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
|
||||
memgraph_user = os.environ.get(
|
||||
'MEMGRAPH_USER', ''
|
||||
) # Memgraph often doesn't require auth by default
|
||||
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
|
||||
|
||||
if not memgraph_uri:
|
||||
raise ValueError('MEMGRAPH_URI must be set')
|
||||
|
||||
|
||||
async def main():
|
||||
#################################################
|
||||
# INITIALIZATION
|
||||
#################################################
|
||||
# Connect to Memgraph and set up Graphiti indices
|
||||
# This is required before using other Graphiti
|
||||
# functionality
|
||||
#################################################
|
||||
|
||||
# Initialize Memgraph driver
|
||||
memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password)
|
||||
|
||||
# Initialize Graphiti with Memgraph connection
|
||||
graphiti = Graphiti(graph_driver=memgraph_driver)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
#################################################
|
||||
# Episodes are the primary units of information
|
||||
# in Graphiti. They can be text or structured JSON
|
||||
# and are automatically processed to extract entities
|
||||
# and relationships.
|
||||
#################################################
|
||||
|
||||
# Example: Add Episodes
|
||||
# Episodes list containing both text and JSON episodes
|
||||
episodes = [
|
||||
{
|
||||
'content': 'Kamala Harris is the Attorney General of California. She was previously '
|
||||
'the district attorney for San Francisco.',
|
||||
'type': EpisodeType.text,
|
||||
'description': 'podcast transcript',
|
||||
},
|
||||
{
|
||||
'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
|
||||
'type': EpisodeType.text,
|
||||
'description': 'podcast transcript',
|
||||
},
|
||||
{
|
||||
'content': {
|
||||
'name': 'Gavin Newsom',
|
||||
'position': 'Governor',
|
||||
'state': 'California',
|
||||
'predecessor': 'Jerry Brown',
|
||||
},
|
||||
'type': EpisodeType.json,
|
||||
'description': 'politician info',
|
||||
},
|
||||
{
|
||||
'content': 'Jerry Brown was the predecessor of Gavin Newsom as Governor of California',
|
||||
'type': EpisodeType.text,
|
||||
'description': 'podcast transcript',
|
||||
},
|
||||
]
|
||||
|
||||
# Add episodes to the graph
|
||||
for i, episode in enumerate(episodes):
|
||||
await graphiti.add_episode(
|
||||
name=f'Memgraph Demo {i}',
|
||||
episode_body=episode['content']
|
||||
if isinstance(episode['content'], str)
|
||||
else json.dumps(episode['content']),
|
||||
source=episode['type'],
|
||||
source_description=episode['description'],
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
)
|
||||
logger.info(f'Added episode: Memgraph Demo {i} ({episode["type"].value})')
|
||||
|
||||
logger.info('Episodes added successfully!')
|
||||
|
||||
#################################################
|
||||
# BASIC SEARCH
|
||||
#################################################
|
||||
# The simplest way to retrieve relationships (edges)
|
||||
# from Graphiti is using the search method, which
|
||||
# performs a hybrid search combining semantic
|
||||
# similarity and BM25 text retrieval.
|
||||
#################################################
|
||||
|
||||
# Perform a hybrid search combining semantic similarity and BM25 retrieval
|
||||
logger.info('\n=== BASIC SEARCH ===')
|
||||
search_query = 'Who was the California Attorney General?'
|
||||
logger.info(f'Search query: {search_query}')
|
||||
|
||||
# Perform semantic search across edges (relationships)
|
||||
results = await graphiti.search(search_query)
|
||||
|
||||
logger.info('Search results:')
|
||||
for result in results[:3]: # Show top 3 results
|
||||
logger.info(f'UUID: {result.uuid}')
|
||||
logger.info(f'Fact: {result.fact}')
|
||||
if hasattr(result, 'valid_at') and result.valid_at:
|
||||
logger.info(f'Valid from: {result.valid_at}')
|
||||
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||
logger.info(f'Valid until: {result.invalid_at}')
|
||||
logger.info('---')
|
||||
|
||||
#################################################
|
||||
# CENTER NODE SEARCH
|
||||
#################################################
|
||||
# For more contextually relevant results, you can
|
||||
# use a center node to rerank search results based
|
||||
# on their graph distance to a specific node
|
||||
#################################################
|
||||
|
||||
if results and len(results) > 0:
|
||||
logger.info('\n=== CENTER NODE SEARCH ===')
|
||||
# Get the source node UUID from the top result
|
||||
center_node_uuid = results[0].source_node_uuid
|
||||
logger.info(f'Using center node UUID: {center_node_uuid}')
|
||||
|
||||
# Perform graph-based search using the source node
|
||||
reranked_results = await graphiti.search(
|
||||
search_query, center_node_uuid=center_node_uuid
|
||||
)
|
||||
|
||||
logger.info('Reranked search results:')
|
||||
for result in reranked_results[:3]:
|
||||
logger.info(f'UUID: {result.uuid}')
|
||||
logger.info(f'Fact: {result.fact}')
|
||||
if hasattr(result, 'valid_at') and result.valid_at:
|
||||
logger.info(f'Valid from: {result.valid_at}')
|
||||
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||
logger.info(f'Valid until: {result.invalid_at}')
|
||||
logger.info('---')
|
||||
else:
|
||||
logger.info('No results found in the initial search to use as center node.')
|
||||
|
||||
#################################################
|
||||
# NODE SEARCH WITH RECIPES
|
||||
#################################################
|
||||
# Graphiti provides predefined search configurations
|
||||
# (recipes) that optimize search for specific patterns
|
||||
# and use cases.
|
||||
#################################################
|
||||
|
||||
logger.info('\n=== NODE SEARCH WITH RECIPES ===')
|
||||
recipe_query = 'California Governor'
|
||||
logger.info(f'Recipe search query: {recipe_query}')
|
||||
|
||||
# Use hybrid search recipe for balanced semantic and keyword matching
|
||||
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
|
||||
node_search_config.limit = 5 # Limit to 5 results
|
||||
|
||||
# Execute the node search
|
||||
node_search_results = await graphiti._search(
|
||||
query=recipe_query,
|
||||
config=node_search_config,
|
||||
)
|
||||
|
||||
logger.info('Node search results:')
|
||||
for node in node_search_results.nodes:
|
||||
logger.info(f'Node UUID: {node.uuid}')
|
||||
logger.info(f'Node Name: {node.name}')
|
||||
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
|
||||
logger.info(f'Content Summary: {node_summary}')
|
||||
logger.info(f'Node Labels: {", ".join(node.labels)}')
|
||||
logger.info(f'Created At: {node.created_at}')
|
||||
if hasattr(node, 'attributes') and node.attributes:
|
||||
logger.info('Attributes:')
|
||||
for key, value in node.attributes.items():
|
||||
logger.info(f' {key}: {value}')
|
||||
#################################################
|
||||
# SUMMARY STATISTICS
|
||||
#################################################
|
||||
# Get overall statistics about the knowledge graph
|
||||
#################################################
|
||||
|
||||
logger.info('\n=== SUMMARY ===')
|
||||
logger.info('Memgraph database populated successfully!')
|
||||
logger.info('Knowledge graph is ready for queries and exploration.')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'An error occurred: {e}')
|
||||
raise
|
||||
|
||||
finally:
|
||||
#################################################
|
||||
# CLEANUP
|
||||
#################################################
|
||||
# Always close the connection to Memgraph when
|
||||
# finished to properly release resources
|
||||
#################################################
|
||||
|
||||
# Close the connection
|
||||
await graphiti.close()
|
||||
logger.info('Connection closed.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
from neo4j import Neo4jDriver
|
||||
from .memgraph_driver import MemgraphDriver
|
||||
from .neo4j_driver import Neo4jDriver
|
||||
|
||||
__all__ = ['Neo4jDriver']
|
||||
__all__ = ['Neo4jDriver', 'MemgraphDriver']
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class GraphProvider(Enum):
|
|||
FALKORDB = 'falkordb'
|
||||
KUZU = 'kuzu'
|
||||
NEPTUNE = 'neptune'
|
||||
MEMGRAPH = 'memgraph'
|
||||
|
||||
|
||||
class GraphDriverSession(ABC):
|
||||
|
|
|
|||
78
graphiti_core/driver/memgraph_driver.py
Normal file
78
graphiti_core/driver/memgraph_driver.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
Copyright 2025, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemgraphDriver(GraphDriver):
|
||||
provider = GraphProvider.MEMGRAPH
|
||||
|
||||
def __init__(
|
||||
self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'
|
||||
):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
auth=(user or '', password or ''),
|
||||
)
|
||||
self._database = database
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Execute a Cypher query against Memgraph using implicit transactions.
|
||||
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
|
||||
"""
|
||||
# Extract parameters from kwargs
|
||||
params = kwargs.pop('params', None)
|
||||
if params is None:
|
||||
database = kwargs.pop('database_', self._database)
|
||||
kwargs.pop('parameters_', None)
|
||||
params = kwargs
|
||||
else:
|
||||
database = kwargs.pop('database_', self._database)
|
||||
kwargs.pop('parameters_', None)
|
||||
|
||||
async with self.client.session(database=database) as session:
|
||||
try:
|
||||
result = await session.run(cypher_query_, params)
|
||||
keys = result.keys()
|
||||
records = [record async for record in result]
|
||||
summary = await result.consume()
|
||||
return (records, summary, keys)
|
||||
except Exception as e:
|
||||
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||
_database = database or self._database
|
||||
return self.client.session(database=_database) # type: ignore
|
||||
|
||||
async def close(self) -> None:
|
||||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
return self.client.execute_query('DROP ALL INDEXES')
|
||||
|
|
@ -45,6 +45,30 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|||
if provider == GraphProvider.KUZU:
|
||||
return []
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return [
|
||||
'CREATE INDEX ON :Entity(uuid);',
|
||||
'CREATE INDEX ON :Entity(group_id);',
|
||||
'CREATE INDEX ON :Entity(name);',
|
||||
'CREATE INDEX ON :Entity(created_at);',
|
||||
'CREATE INDEX ON :Episodic(uuid);',
|
||||
'CREATE INDEX ON :Episodic(group_id);',
|
||||
'CREATE INDEX ON :Episodic(created_at);',
|
||||
'CREATE INDEX ON :Episodic(valid_at);',
|
||||
'CREATE INDEX ON :Community(uuid);',
|
||||
'CREATE INDEX ON :Community(group_id);',
|
||||
'CREATE INDEX ON :RELATES_TO(uuid);',
|
||||
'CREATE INDEX ON :RELATES_TO(group_id);',
|
||||
'CREATE INDEX ON :RELATES_TO(name);',
|
||||
'CREATE INDEX ON :RELATES_TO(created_at);',
|
||||
'CREATE INDEX ON :RELATES_TO(expired_at);',
|
||||
'CREATE INDEX ON :RELATES_TO(valid_at);',
|
||||
'CREATE INDEX ON :RELATES_TO(invalid_at);',
|
||||
'CREATE INDEX ON :MENTIONS(uuid);',
|
||||
'CREATE INDEX ON :MENTIONS(group_id);',
|
||||
'CREATE INDEX ON :HAS_MEMBER(uuid);',
|
||||
]
|
||||
|
||||
return [
|
||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||
|
|
@ -115,6 +139,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|||
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
|
||||
]
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return [
|
||||
"""CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id);""",
|
||||
"""CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id);""",
|
||||
"""CREATE TEXT INDEX community_name ON :Community(name, group_id);""",
|
||||
"""CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id);""",
|
||||
]
|
||||
|
||||
return [
|
||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
||||
|
|
@ -136,6 +168,9 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider)
|
|||
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return f'CALL text_search.search_all("{name}", {query}, {limit})'
|
||||
|
||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||
|
||||
|
||||
|
|
@ -147,6 +182,9 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|||
if provider == GraphProvider.KUZU:
|
||||
return f'array_cosine_similarity({vec1}, {vec2})'
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return f'vector_search.cosine_similarity({vec1}, {vec2})'
|
||||
|
||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||
|
||||
|
||||
|
|
@ -159,4 +197,7 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s
|
|||
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return f'CALL text_search.search_all_edges("{name}", $query, {limit})'
|
||||
|
||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||
|
|
|
|||
|
|
@ -99,6 +99,15 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False)
|
|||
e.attributes = $attributes
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
SET e.fact_embedding = $edge_data.fact_embedding
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
save_embedding_query = (
|
||||
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
|
||||
|
|
@ -163,6 +172,16 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = Fa
|
|||
e.attributes = $attributes
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET e = edge
|
||||
SET e.fact_embedding = edge.fact_embedding
|
||||
RETURN edge.uuid AS uuid;
|
||||
"""
|
||||
case _:
|
||||
save_embedding_query = (
|
||||
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
|
||||
|
|
@ -261,7 +280,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
|
|||
e.created_at = $created_at
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
case _: # Neo4j and Memgraph
|
||||
return """
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||
|
|
|
|||
|
|
@ -49,6 +49,17 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
|||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {
|
||||
uuid: $uuid, name: $name, group_id: $group_id,
|
||||
source_description: $source_description, source: $source,
|
||||
content: $content, entity_edges: $entity_edges,
|
||||
created_at: $created_at, valid_at: $valid_at
|
||||
}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
|
|
@ -91,6 +102,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
|||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
|
|
@ -161,6 +180,14 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
|
|||
SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
WITH n SET n.name_embedding = $entity_data.name_embedding
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
save_embedding_query = (
|
||||
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
|
||||
|
|
@ -233,6 +260,15 @@ def get_entity_node_save_bulk_query(
|
|||
n.attributes = $attributes
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n = node
|
||||
WITH n, node
|
||||
SET n.name_embedding = node.name_embedding
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
save_embedding_query = (
|
||||
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
|
||||
|
|
@ -303,6 +339,13 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
|
|||
n.summary = $summary
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n SET n.name_embedding = $name_embedding
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
|
|
|
|||
|
|
@ -320,6 +320,9 @@ class EpisodicNode(Node):
|
|||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH):
|
||||
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_episode_node_save_query(driver.provider), **episode_args
|
||||
)
|
||||
|
|
|
|||
|
|
@ -55,6 +55,14 @@ if os.getenv('DISABLE_KUZU') is None:
|
|||
except ImportError:
|
||||
raise
|
||||
|
||||
if os.getenv('DISABLE_MEMGRAPH') is None:
|
||||
try:
|
||||
from graphiti_core.driver.memgraph_driver import MemgraphDriver
|
||||
|
||||
drivers.append(GraphProvider.MEMGRAPH)
|
||||
except ImportError:
|
||||
raise
|
||||
|
||||
# Disable Neptune for now
|
||||
os.environ['DISABLE_NEPTUNE'] = 'True'
|
||||
if os.getenv('DISABLE_NEPTUNE') is None:
|
||||
|
|
@ -74,6 +82,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
|||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
MEMGRAPH_URI = os.getenv('MEMGRAPH_URI', 'bolt://localhost:7687')
|
||||
MEMGRAPH_USER = os.getenv('MEMGRAPH_USER', '')
|
||||
MEMGRAPH_PASSWORD = os.getenv('MEMGRAPH_PASSWORD', '')
|
||||
|
||||
NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
|
||||
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
|
||||
AOSS_HOST = os.getenv('AOSS_HOST', None)
|
||||
|
|
@ -103,6 +115,12 @@ def get_driver(provider: GraphProvider) -> GraphDriver:
|
|||
db=KUZU_DB,
|
||||
)
|
||||
return driver
|
||||
elif provider == GraphProvider.MEMGRAPH:
|
||||
return MemgraphDriver(
|
||||
uri=MEMGRAPH_URI,
|
||||
user=MEMGRAPH_USER,
|
||||
password=MEMGRAPH_PASSWORD,
|
||||
)
|
||||
elif provider == GraphProvider.NEPTUNE:
|
||||
return NeptuneDriver(
|
||||
host=NEPTUNE_HOST,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue