Amazon Neptune Support (#793)
* Rebased Neptune changes based on significant rework done * Updated the README documentation * Fixed linting and formatting * Update README.md Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * Update graphiti_core/driver/neptune_driver.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * Update README.md Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * Addressed feedback from code review * Updated the README documentation for clarity * Updated the README and neptune_driver based on PR feedback * Update node_db_queries.py --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
This commit is contained in:
parent
9c1e1ad7ef
commit
ef56dc779a
15 changed files with 3805 additions and 2460 deletions
40
README.md
40
README.md
|
|
@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
|
||||||
Requirements:
|
Requirements:
|
||||||
|
|
||||||
- Python 3.10 or higher
|
- Python 3.10 or higher
|
||||||
- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
|
- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
|
||||||
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
|
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
|
|
@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
|
||||||
uv add graphiti-core[falkordb]
|
uv add graphiti-core[falkordb]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Installing with Amazon Neptune Support
|
||||||
|
|
||||||
|
If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install graphiti-core[neptune]
|
||||||
|
|
||||||
|
# or with uv
|
||||||
|
uv add graphiti-core[neptune]
|
||||||
|
```
|
||||||
|
|
||||||
### You can also install optional LLM providers as extras:
|
### You can also install optional LLM providers as extras:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
@ -165,6 +176,9 @@ pip install graphiti-core[anthropic,groq,google-genai]
|
||||||
|
|
||||||
# Install with FalkorDB and LLM providers
|
# Install with FalkorDB and LLM providers
|
||||||
pip install graphiti-core[falkordb,anthropic,google-genai]
|
pip install graphiti-core[falkordb,anthropic,google-genai]
|
||||||
|
|
||||||
|
# Install with Amazon Neptune
|
||||||
|
pip install graphiti-core[neptune]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Default to Low Concurrency; LLM Provider 429 Rate Limit Errors
|
## Default to Low Concurrency; LLM Provider 429 Rate Limit Errors
|
||||||
|
|
@ -184,7 +198,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
|
||||||
|
|
||||||
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
|
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
|
||||||
|
|
||||||
1. Connecting to a Neo4j or FalkorDB database
|
1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
|
||||||
2. Initializing Graphiti indices and constraints
|
2. Initializing Graphiti indices and constraints
|
||||||
3. Adding episodes to the graph (both text and structured JSON)
|
3. Adding episodes to the graph (both text and structured JSON)
|
||||||
4. Searching for relationships (edges) using hybrid search
|
4. Searching for relationships (edges) using hybrid search
|
||||||
|
|
@ -267,6 +281,26 @@ driver = FalkorDriver(
|
||||||
graphiti = Graphiti(graph_driver=driver)
|
graphiti = Graphiti(graph_driver=driver)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Amazon Neptune
|
||||||
|
|
||||||
|
```python
|
||||||
|
from graphiti_core import Graphiti
|
||||||
|
from graphiti_core.driver.neptune_driver import NeptuneDriver
|
||||||
|
|
||||||
|
# Create a FalkorDB driver with custom database name
|
||||||
|
driver = NeptuneDriver(
|
||||||
|
host=<NEPTUNE ENDPOINT>,
|
||||||
|
aoss_host=<Amazon OpenSearch Serverless Host>,
|
||||||
|
port=<PORT> # Optional, defaults to 8182,
|
||||||
|
aoss_port=<PORT> # Optional, defaults to 443
|
||||||
|
)
|
||||||
|
|
||||||
|
driver = NeptuneDriver(host=neptune_uri, aoss_host=aoss_host, port=neptune_port)
|
||||||
|
|
||||||
|
# Pass the driver to Graphiti
|
||||||
|
graphiti = Graphiti(graph_driver=driver)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Performance Configuration
|
### Performance Configuration
|
||||||
|
|
||||||
|
|
@ -458,7 +492,7 @@ When you initialize a Graphiti instance, we collect:
|
||||||
- **Graphiti version**: The version you're using
|
- **Graphiti version**: The version you're using
|
||||||
- **Configuration choices**:
|
- **Configuration choices**:
|
||||||
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
|
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
|
||||||
- Database backend (Neo4j, FalkorDB)
|
- Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
|
||||||
- Embedder provider (OpenAI, Azure, Voyage, etc.)
|
- Embedder provider (OpenAI, Azure, Voyage, etc.)
|
||||||
|
|
||||||
### What We Don't Collect
|
### What We Don't Collect
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ This example demonstrates the basic functionality of Graphiti, including:
|
||||||
- A local DBMS created and started in Neo4j Desktop
|
- A local DBMS created and started in Neo4j Desktop
|
||||||
- **For FalkorDB**:
|
- **For FalkorDB**:
|
||||||
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
|
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
|
||||||
|
- **For Amazon Neptune**:
|
||||||
|
- Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
|
||||||
|
|
||||||
|
|
||||||
## Setup Instructions
|
## Setup Instructions
|
||||||
|
|
@ -42,9 +44,19 @@ export NEO4J_PASSWORD=password
|
||||||
# Optional FalkorDB connection parameters (defaults shown)
|
# Optional FalkorDB connection parameters (defaults shown)
|
||||||
export FALKORDB_URI=falkor://localhost:6379
|
export FALKORDB_URI=falkor://localhost:6379
|
||||||
|
|
||||||
|
# Optional Amazon Neptune connection parameters
|
||||||
|
NEPTUNE_HOST=your_neptune_host
|
||||||
|
NEPTUNE_PORT=your_port_or_8182
|
||||||
|
AOSS_HOST=your_aoss_host
|
||||||
|
AOSS_PORT=your_port_or_443
|
||||||
|
|
||||||
# To use a different database, modify the driver constructor in the script
|
# To use a different database, modify the driver constructor in the script
|
||||||
```
|
```
|
||||||
|
|
||||||
|
TIP: For Amazon Neptune host string please use the following formats
|
||||||
|
* For Neptune Database: `neptune-db://<cluster endpoint>`
|
||||||
|
* For Neptune Analytics: `neptune-graph://<graph identifier>`
|
||||||
|
|
||||||
3. Run the example:
|
3. Run the example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
@ -52,11 +64,14 @@ python quickstart_neo4j.py
|
||||||
|
|
||||||
# For FalkorDB
|
# For FalkorDB
|
||||||
python quickstart_falkordb.py
|
python quickstart_falkordb.py
|
||||||
|
|
||||||
|
# For Amazon Neptune
|
||||||
|
python quickstart_neptune.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## What This Example Demonstrates
|
## What This Example Demonstrates
|
||||||
|
|
||||||
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
|
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
|
||||||
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
|
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
|
||||||
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
|
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
|
||||||
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
|
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
|
||||||
class GraphProvider(Enum):
|
class GraphProvider(Enum):
|
||||||
NEO4J = 'neo4j'
|
NEO4J = 'neo4j'
|
||||||
FALKORDB = 'falkordb'
|
FALKORDB = 'falkordb'
|
||||||
|
NEPTUNE = 'neptune'
|
||||||
|
|
||||||
|
|
||||||
class GraphDriverSession(ABC):
|
class GraphDriverSession(ABC):
|
||||||
|
|
|
||||||
299
graphiti_core/driver/neptune_driver.py
Normal file
299
graphiti_core/driver/neptune_driver.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from collections.abc import Coroutine
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||||
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
|
aoss_indices = [
|
||||||
|
{
|
||||||
|
'index_name': 'node_name_and_summary',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'name': {'type': 'text'},
|
||||||
|
'summary': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index_name': 'community_name',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'name': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index_name': 'episode_content',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'content': {'type': 'text'},
|
||||||
|
'source': {'type': 'text'},
|
||||||
|
'source_description': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {
|
||||||
|
'multi_match': {
|
||||||
|
'query': '',
|
||||||
|
'fields': ['content', 'source', 'source_description', 'group_id'],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index_name': 'edge_name_and_fact',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'name': {'type': 'text'},
|
||||||
|
'fact': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneDriver(GraphDriver):
|
||||||
|
provider: GraphProvider = GraphProvider.NEPTUNE
|
||||||
|
|
||||||
|
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
|
||||||
|
"""This initializes a NeptuneDriver for use with Neptune as a backend
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (str): The Neptune Database or Neptune Analytics host
|
||||||
|
aoss_host (str): The OpenSearch host value
|
||||||
|
port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
|
||||||
|
aoss_port (int, optional): The OpenSearch port. Defaults to 443.
|
||||||
|
"""
|
||||||
|
if not host:
|
||||||
|
raise ValueError('You must provide an endpoint to create a NeptuneDriver')
|
||||||
|
|
||||||
|
if host.startswith('neptune-db://'):
|
||||||
|
# This is a Neptune Database Cluster
|
||||||
|
endpoint = host.replace('neptune-db://', '')
|
||||||
|
self.client = NeptuneGraph(endpoint, port)
|
||||||
|
logger.debug('Creating Neptune Database session for %s', host)
|
||||||
|
elif host.startswith('neptune-graph://'):
|
||||||
|
# This is a Neptune Analytics Graph
|
||||||
|
graphId = host.replace('neptune-graph://', '')
|
||||||
|
self.client = NeptuneAnalyticsGraph(graphId)
|
||||||
|
logger.debug('Creating Neptune Graph session for %s', host)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
|
||||||
|
)
|
||||||
|
|
||||||
|
if not aoss_host:
|
||||||
|
raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
|
||||||
|
|
||||||
|
session = boto3.Session()
|
||||||
|
self.aoss_client = OpenSearch(
|
||||||
|
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||||
|
http_auth=Urllib3AWSV4SignerAuth(
|
||||||
|
session.get_credentials(), session.region_name, 'aoss'
|
||||||
|
),
|
||||||
|
use_ssl=True,
|
||||||
|
verify_certs=True,
|
||||||
|
connection_class=Urllib3HttpConnection,
|
||||||
|
pool_maxsize=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sanitize_parameters(self, query, params: dict):
|
||||||
|
if isinstance(query, list):
|
||||||
|
queries = []
|
||||||
|
for q in query:
|
||||||
|
queries.append(self._sanitize_parameters(q, params))
|
||||||
|
return queries
|
||||||
|
else:
|
||||||
|
for k, v in params.items():
|
||||||
|
if isinstance(v, datetime.datetime):
|
||||||
|
params[k] = v.isoformat()
|
||||||
|
elif isinstance(v, list):
|
||||||
|
# Handle lists that might contain datetime objects
|
||||||
|
for i, item in enumerate(v):
|
||||||
|
if isinstance(item, datetime.datetime):
|
||||||
|
v[i] = item.isoformat()
|
||||||
|
query = str(query).replace(f'${k}', f'datetime(${k})')
|
||||||
|
if isinstance(item, dict):
|
||||||
|
query = self._sanitize_parameters(query, v[i])
|
||||||
|
|
||||||
|
# If the list contains datetime objects, we need to wrap each element with datetime()
|
||||||
|
if any(isinstance(item, str) and 'T' in item for item in v):
|
||||||
|
# Create a new list expression with datetime() wrapped around each element
|
||||||
|
datetime_list = (
|
||||||
|
'['
|
||||||
|
+ ', '.join(
|
||||||
|
f'datetime("{item}")'
|
||||||
|
if isinstance(item, str) and 'T' in item
|
||||||
|
else repr(item)
|
||||||
|
for item in v
|
||||||
|
)
|
||||||
|
+ ']'
|
||||||
|
)
|
||||||
|
query = str(query).replace(f'${k}', datetime_list)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
query = self._sanitize_parameters(query, v)
|
||||||
|
return query
|
||||||
|
|
||||||
|
async def execute_query(
|
||||||
|
self, cypher_query_, **kwargs: Any
|
||||||
|
) -> tuple[dict[str, Any], None, None]:
|
||||||
|
params = dict(kwargs)
|
||||||
|
if isinstance(cypher_query_, list):
|
||||||
|
for q in cypher_query_:
|
||||||
|
result, _, _ = self._run_query(q[0], q[1])
|
||||||
|
return result, None, None
|
||||||
|
else:
|
||||||
|
return self._run_query(cypher_query_, params)
|
||||||
|
|
||||||
|
def _run_query(self, cypher_query_, params):
|
||||||
|
cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
|
||||||
|
try:
|
||||||
|
result = self.client.query(cypher_query_, params=params)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('Query: %s', cypher_query_)
|
||||||
|
logger.error('Parameters: %s', params)
|
||||||
|
logger.error('Error executing query: %s', e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return result, None, None
|
||||||
|
|
||||||
|
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||||
|
return NeptuneDriverSession(driver=self)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
return self.client.client.close()
|
||||||
|
|
||||||
|
async def _delete_all_data(self) -> Any:
|
||||||
|
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
||||||
|
|
||||||
|
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||||
|
return self.delete_all_indexes_impl()
|
||||||
|
|
||||||
|
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||||
|
# No matter what happens above, always return True
|
||||||
|
return self.delete_aoss_indices()
|
||||||
|
|
||||||
|
async def create_aoss_indices(self):
|
||||||
|
for index in aoss_indices:
|
||||||
|
index_name = index['index_name']
|
||||||
|
client = self.aoss_client
|
||||||
|
if not client.indices.exists(index=index_name):
|
||||||
|
client.indices.create(index=index_name, body=index['body'])
|
||||||
|
# Sleep for 1 minute to let the index creation complete
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
async def delete_aoss_indices(self):
|
||||||
|
for index in aoss_indices:
|
||||||
|
index_name = index['index_name']
|
||||||
|
client = self.aoss_client
|
||||||
|
if client.indices.exists(index=index_name):
|
||||||
|
client.indices.delete(index=index_name)
|
||||||
|
|
||||||
|
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
||||||
|
for index in aoss_indices:
|
||||||
|
if name.lower() == index['index_name']:
|
||||||
|
index['query']['query']['multi_match']['query'] = query_text
|
||||||
|
query = {'size': limit, 'query': index['query']}
|
||||||
|
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
||||||
|
return resp
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
|
for index in aoss_indices:
|
||||||
|
if name.lower() == index['index_name']:
|
||||||
|
to_index = []
|
||||||
|
for d in data:
|
||||||
|
item = {'_index': name}
|
||||||
|
for p in index['body']['mappings']['properties']:
|
||||||
|
item[p] = d[p]
|
||||||
|
to_index.append(item)
|
||||||
|
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||||
|
if failed > 0:
|
||||||
|
return success
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneDriverSession(GraphDriverSession):
|
||||||
|
def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
|
||||||
|
self.driver = driver
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
# No cleanup needed for Neptune, but method must exist
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
# No explicit close needed for Neptune, but method must exist
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_write(self, func, *args, **kwargs):
|
||||||
|
# Directly await the provided async function with `self` as the transaction/session
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
async def run(self, query: str | list, **kwargs: Any) -> Any:
|
||||||
|
if isinstance(query, list):
|
||||||
|
res = None
|
||||||
|
for q in query:
|
||||||
|
res = await self.driver.execute_query(q, **kwargs)
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
return await self.driver.execute_query(str(query), **kwargs)
|
||||||
|
|
@ -24,13 +24,14 @@ from uuid import uuid4
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
COMMUNITY_EDGE_RETURN,
|
COMMUNITY_EDGE_RETURN,
|
||||||
ENTITY_EDGE_RETURN,
|
ENTITY_EDGE_RETURN,
|
||||||
|
ENTITY_EDGE_RETURN_NEPTUNE,
|
||||||
EPISODIC_EDGE_RETURN,
|
EPISODIC_EDGE_RETURN,
|
||||||
EPISODIC_EDGE_SAVE,
|
EPISODIC_EDGE_SAVE,
|
||||||
get_community_edge_save_query,
|
get_community_edge_save_query,
|
||||||
|
|
@ -214,11 +215,19 @@ class EntityEdge(Edge):
|
||||||
return self.fact_embedding
|
return self.fact_embedding
|
||||||
|
|
||||||
async def load_fact_embedding(self, driver: GraphDriver):
|
async def load_fact_embedding(self, driver: GraphDriver):
|
||||||
records, _, _ = await driver.execute_query(
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
query: LiteralString = """
|
||||||
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
|
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
||||||
"""
|
"""
|
||||||
|
else:
|
||||||
|
query: LiteralString = """
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
RETURN e.fact_embedding AS fact_embedding
|
RETURN e.fact_embedding AS fact_embedding
|
||||||
""",
|
"""
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -246,6 +255,9 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
edge_data.update(self.attributes or {})
|
edge_data.update(self.attributes or {})
|
||||||
|
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_edge_save_query(driver.provider),
|
get_entity_edge_save_query(driver.provider),
|
||||||
edge_data=edge_data,
|
edge_data=edge_data,
|
||||||
|
|
@ -262,7 +274,11 @@ class EntityEdge(Edge):
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN,
|
+ (
|
||||||
|
ENTITY_EDGE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else ENTITY_EDGE_RETURN
|
||||||
|
),
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -284,7 +300,11 @@ class EntityEdge(Edge):
|
||||||
WHERE e.uuid IN $uuids
|
WHERE e.uuid IN $uuids
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN,
|
+ (
|
||||||
|
ENTITY_EDGE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else ENTITY_EDGE_RETURN
|
||||||
|
),
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -321,7 +341,11 @@ class EntityEdge(Edge):
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ (
|
||||||
|
ENTITY_EDGE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else ENTITY_EDGE_RETURN
|
||||||
|
)
|
||||||
+ with_embeddings_query
|
+ with_embeddings_query
|
||||||
+ """
|
+ """
|
||||||
ORDER BY e.uuid DESC
|
ORDER BY e.uuid DESC
|
||||||
|
|
@ -346,7 +370,11 @@ class EntityEdge(Edge):
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN,
|
+ (
|
||||||
|
ENTITY_EDGE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else ENTITY_EDGE_RETURN
|
||||||
|
),
|
||||||
node_uuid=node_uuid,
|
node_uuid=node_uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -43,47 +43,70 @@ EPISODIC_EDGE_RETURN = """
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||||
if provider == GraphProvider.FALKORDB:
|
match provider:
|
||||||
return """
|
case GraphProvider.FALKORDB:
|
||||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
return """
|
||||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
SET e = $edge_data
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
RETURN e.uuid AS uuid
|
SET e = $edge_data
|
||||||
"""
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
return """
|
case GraphProvider.NEPTUNE:
|
||||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
return """
|
||||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
SET e = $edge_data
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
|
||||||
RETURN e.uuid AS uuid
|
SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
|
||||||
"""
|
SET e.episodes = join($edge_data.episodes, ",")
|
||||||
|
RETURN $edge_data.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _: # Neo4j
|
||||||
|
return """
|
||||||
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
|
SET e = $edge_data
|
||||||
|
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
if provider == GraphProvider.FALKORDB:
|
match provider:
|
||||||
return """
|
case GraphProvider.FALKORDB:
|
||||||
UNWIND $entity_edges AS edge
|
return """
|
||||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
UNWIND $entity_edges AS edge
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||||
WITH r, edge
|
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
||||||
RETURN edge.uuid AS uuid
|
WITH r, edge
|
||||||
"""
|
RETURN edge.uuid AS uuid
|
||||||
|
"""
|
||||||
return """
|
case GraphProvider.NEPTUNE:
|
||||||
UNWIND $entity_edges AS edge
|
return """
|
||||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
UNWIND $entity_edges AS edge
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
SET e = edge
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
|
||||||
RETURN edge.uuid AS uuid
|
SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
|
||||||
"""
|
SET r.episodes = join(edge.episodes, ",")
|
||||||
|
RETURN edge.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _:
|
||||||
|
return """
|
||||||
|
UNWIND $entity_edges AS edge
|
||||||
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
|
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
|
SET e = edge
|
||||||
|
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
||||||
|
RETURN edge.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
ENTITY_EDGE_RETURN = """
|
ENTITY_EDGE_RETURN = """
|
||||||
|
|
@ -101,24 +124,51 @@ ENTITY_EDGE_RETURN = """
|
||||||
properties(e) AS attributes
|
properties(e) AS attributes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ENTITY_EDGE_RETURN_NEPTUNE = """
|
||||||
|
e.uuid AS uuid,
|
||||||
|
n.uuid AS source_node_uuid,
|
||||||
|
m.uuid AS target_node_uuid,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
e.name AS name,
|
||||||
|
e.fact AS fact,
|
||||||
|
split(e.episodes, ',') AS episodes,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.expired_at AS expired_at,
|
||||||
|
e.valid_at AS valid_at,
|
||||||
|
e.invalid_at AS invalid_at,
|
||||||
|
properties(e) AS attributes
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_community_edge_save_query(provider: GraphProvider) -> str:
|
def get_community_edge_save_query(provider: GraphProvider) -> str:
|
||||||
if provider == GraphProvider.FALKORDB:
|
match provider:
|
||||||
return """
|
case GraphProvider.FALKORDB:
|
||||||
MATCH (community:Community {uuid: $community_uuid})
|
return """
|
||||||
MATCH (node {uuid: $entity_uuid})
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
MATCH (node {uuid: $entity_uuid})
|
||||||
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
RETURN e.uuid AS uuid
|
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
"""
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
return """
|
case GraphProvider.NEPTUNE:
|
||||||
MATCH (community:Community {uuid: $community_uuid})
|
return """
|
||||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
MATCH (node {uuid: $entity_uuid})
|
||||||
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
WHERE node:Entity OR node:Community
|
||||||
RETURN e.uuid AS uuid
|
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
"""
|
SET r.uuid= $uuid
|
||||||
|
SET r.group_id= $group_id
|
||||||
|
SET r.created_at= $created_at
|
||||||
|
RETURN r.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _: # Neo4j
|
||||||
|
return """
|
||||||
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
|
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||||
|
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
|
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
COMMUNITY_EDGE_RETURN = """
|
COMMUNITY_EDGE_RETURN = """
|
||||||
|
|
|
||||||
|
|
@ -18,21 +18,45 @@ from typing import Any
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphProvider
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
|
||||||
EPISODIC_NODE_SAVE = """
|
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
|
||||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
|
||||||
RETURN n.uuid AS uuid
|
|
||||||
"""
|
|
||||||
|
|
||||||
EPISODIC_NODE_SAVE_BULK = """
|
def get_episode_node_save_query(provider: GraphProvider) -> str:
|
||||||
UNWIND $episodes AS episode
|
match provider:
|
||||||
MERGE (n:Episodic {uuid: episode.uuid})
|
case GraphProvider.NEPTUNE:
|
||||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
return """
|
||||||
source: episode.source, content: episode.content,
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||||
RETURN n.uuid AS uuid
|
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
|
||||||
"""
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _: # Neo4j and FalkorDB
|
||||||
|
return """
|
||||||
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||||
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
|
match provider:
|
||||||
|
case GraphProvider.NEPTUNE:
|
||||||
|
return """
|
||||||
|
UNWIND $episodes AS episode
|
||||||
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
|
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||||
|
source: episode.source, content: episode.content,
|
||||||
|
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _: # Neo4j and FalkorDB
|
||||||
|
return """
|
||||||
|
UNWIND $episodes AS episode
|
||||||
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
|
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||||
|
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
EPISODIC_NODE_RETURN = """
|
EPISODIC_NODE_RETURN = """
|
||||||
e.content AS content,
|
e.content AS content,
|
||||||
|
|
@ -46,54 +70,96 @@ EPISODIC_NODE_RETURN = """
|
||||||
e.entity_edges AS entity_edges
|
e.entity_edges AS entity_edges
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE = """
|
||||||
|
e.content AS content,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.valid_at AS valid_at,
|
||||||
|
e.uuid AS uuid,
|
||||||
|
e.name AS name,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
e.source_description AS source_description,
|
||||||
|
e.source AS source,
|
||||||
|
split(e.entity_edges, ",") AS entity_edges
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||||
if provider == GraphProvider.FALKORDB:
|
match provider:
|
||||||
return f"""
|
case GraphProvider.FALKORDB:
|
||||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
return f"""
|
||||||
SET n:{labels}
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
SET n = $entity_data
|
SET n:{labels}
|
||||||
RETURN n.uuid AS uuid
|
SET n = $entity_data
|
||||||
"""
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
return f"""
|
case GraphProvider.NEPTUNE:
|
||||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
label_subquery = ''
|
||||||
SET n:{labels}
|
for label in labels.split(':'):
|
||||||
SET n = $entity_data
|
label_subquery += f' SET n:{label}\n'
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
return f"""
|
||||||
RETURN n.uuid AS uuid
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
"""
|
{label_subquery}
|
||||||
|
SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
|
||||||
|
SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _:
|
||||||
|
return f"""
|
||||||
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
|
SET n:{labels}
|
||||||
|
SET n = $entity_data
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
||||||
if provider == GraphProvider.FALKORDB:
|
match provider:
|
||||||
queries = []
|
case GraphProvider.FALKORDB:
|
||||||
for node in nodes:
|
queries = []
|
||||||
for label in node['labels']:
|
for node in nodes:
|
||||||
|
for label in node['labels']:
|
||||||
|
queries.append(
|
||||||
|
(
|
||||||
|
f"""
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (n:Entity {{uuid: node.uuid}})
|
||||||
|
SET n:{label}
|
||||||
|
SET n = node
|
||||||
|
WITH n, node
|
||||||
|
SET n.name_embedding = vecf32(node.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
""",
|
||||||
|
{'nodes': [node]},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return queries
|
||||||
|
case GraphProvider.NEPTUNE:
|
||||||
|
queries = []
|
||||||
|
for node in nodes:
|
||||||
|
labels = ''
|
||||||
|
for label in node['labels']:
|
||||||
|
labels += f' SET n:{label}\n'
|
||||||
queries.append(
|
queries.append(
|
||||||
(
|
f"""
|
||||||
f"""
|
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MERGE (n:Entity {{uuid: node.uuid}})
|
MERGE (n:Entity {{uuid: node.uuid}})
|
||||||
SET n:{label}
|
{labels}
|
||||||
SET n = node
|
SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
|
||||||
WITH n, node
|
SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
|
||||||
SET n.name_embedding = vecf32(node.name_embedding)
|
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
""",
|
"""
|
||||||
{'nodes': [node]},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return queries
|
return queries
|
||||||
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MERGE (n:Entity {uuid: node.uuid})
|
MERGE (n:Entity {uuid: node.uuid})
|
||||||
SET n:$(node.labels)
|
SET n:$(node.labels)
|
||||||
SET n = node
|
SET n = node
|
||||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
ENTITY_NODE_RETURN = """
|
ENTITY_NODE_RETURN = """
|
||||||
|
|
@ -108,19 +174,27 @@ ENTITY_NODE_RETURN = """
|
||||||
|
|
||||||
|
|
||||||
def get_community_node_save_query(provider: GraphProvider) -> str:
|
def get_community_node_save_query(provider: GraphProvider) -> str:
|
||||||
if provider == GraphProvider.FALKORDB:
|
match provider:
|
||||||
return """
|
case GraphProvider.FALKORDB:
|
||||||
MERGE (n:Community {uuid: $uuid})
|
return """
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
|
MERGE (n:Community {uuid: $uuid})
|
||||||
RETURN n.uuid AS uuid
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
|
||||||
"""
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
return """
|
case GraphProvider.NEPTUNE:
|
||||||
MERGE (n:Community {uuid: $uuid})
|
return """
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
MERGE (n:Community {uuid: $uuid})
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
RETURN n.uuid AS uuid
|
SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
|
||||||
"""
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _: # Neo4j
|
||||||
|
return """
|
||||||
|
MERGE (n:Community {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
COMMUNITY_NODE_RETURN = """
|
COMMUNITY_NODE_RETURN = """
|
||||||
|
|
@ -131,3 +205,12 @@ COMMUNITY_NODE_RETURN = """
|
||||||
n.summary AS summary,
|
n.summary AS summary,
|
||||||
n.created_at AS created_at
|
n.created_at AS created_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
COMMUNITY_NODE_RETURN_NEPTUNE = """
|
||||||
|
n.uuid AS uuid,
|
||||||
|
n.name AS name,
|
||||||
|
[x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.summary AS summary,
|
||||||
|
n.created_at AS created_at
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -31,11 +31,13 @@ from graphiti_core.errors import NodeNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
COMMUNITY_NODE_RETURN,
|
COMMUNITY_NODE_RETURN,
|
||||||
|
COMMUNITY_NODE_RETURN_NEPTUNE,
|
||||||
ENTITY_NODE_RETURN,
|
ENTITY_NODE_RETURN,
|
||||||
EPISODIC_NODE_RETURN,
|
EPISODIC_NODE_RETURN,
|
||||||
EPISODIC_NODE_SAVE,
|
EPISODIC_NODE_RETURN_NEPTUNE,
|
||||||
get_community_node_save_query,
|
get_community_node_save_query,
|
||||||
get_entity_node_save_query,
|
get_entity_node_save_query,
|
||||||
|
get_episode_node_save_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
|
|
@ -89,23 +91,24 @@ class Node(BaseModel, ABC):
|
||||||
async def save(self, driver: GraphDriver): ...
|
async def save(self, driver: GraphDriver): ...
|
||||||
|
|
||||||
async def delete(self, driver: GraphDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
if driver.provider == GraphProvider.FALKORDB:
|
match driver.provider:
|
||||||
for label in ['Entity', 'Episodic', 'Community']:
|
case GraphProvider.NEO4J:
|
||||||
await driver.execute_query(
|
await driver.execute_query(
|
||||||
f"""
|
"""
|
||||||
MATCH (n:{label} {{uuid: $uuid}})
|
|
||||||
DETACH DELETE n
|
|
||||||
""",
|
|
||||||
uuid=self.uuid,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
)
|
)
|
||||||
|
case _: # FalkorDB and Neptune
|
||||||
|
for label in ['Entity', 'Episodic', 'Community']:
|
||||||
|
await driver.execute_query(
|
||||||
|
f"""
|
||||||
|
MATCH (n:{label} {{uuid: $uuid}})
|
||||||
|
DETACH DELETE n
|
||||||
|
""",
|
||||||
|
uuid=self.uuid,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Node: {self.uuid}')
|
logger.debug(f'Deleted Node: {self.uuid}')
|
||||||
|
|
||||||
|
|
@ -119,28 +122,30 @@ class Node(BaseModel, ABC):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
||||||
if driver.provider == GraphProvider.FALKORDB:
|
match driver.provider:
|
||||||
for label in ['Entity', 'Episodic', 'Community']:
|
case GraphProvider.NEO4J:
|
||||||
await driver.execute_query(
|
async with driver.session() as session:
|
||||||
f"""
|
await session.run(
|
||||||
MATCH (n:{label} {{group_id: $group_id}})
|
"""
|
||||||
DETACH DELETE n
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
||||||
""",
|
CALL {
|
||||||
group_id=group_id,
|
WITH n
|
||||||
)
|
DETACH DELETE n
|
||||||
else:
|
} IN TRANSACTIONS OF $batch_size ROWS
|
||||||
async with driver.session() as session:
|
""",
|
||||||
await session.run(
|
group_id=group_id,
|
||||||
"""
|
batch_size=batch_size,
|
||||||
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
)
|
||||||
CALL {
|
|
||||||
WITH n
|
case _: # FalkorDB and Neptune
|
||||||
|
for label in ['Entity', 'Episodic', 'Community']:
|
||||||
|
await driver.execute_query(
|
||||||
|
f"""
|
||||||
|
MATCH (n:{label} {{group_id: $group_id}})
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
} IN TRANSACTIONS OF $batch_size ROWS
|
""",
|
||||||
""",
|
group_id=group_id,
|
||||||
group_id=group_id,
|
)
|
||||||
batch_size=batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
||||||
|
|
@ -189,8 +194,21 @@ class EpisodicNode(Node):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def save(self, driver: GraphDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||||
|
'episode_content',
|
||||||
|
[
|
||||||
|
{
|
||||||
|
'uuid': self.uuid,
|
||||||
|
'group_id': self.group_id,
|
||||||
|
'source': self.source.value,
|
||||||
|
'content': self.content,
|
||||||
|
'source_description': self.source_description,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
EPISODIC_NODE_SAVE,
|
get_episode_node_save_query(driver.provider),
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
|
|
@ -213,7 +231,11 @@ class EpisodicNode(Node):
|
||||||
MATCH (e:Episodic {uuid: $uuid})
|
MATCH (e:Episodic {uuid: $uuid})
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ EPISODIC_NODE_RETURN,
|
+ (
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else EPISODIC_NODE_RETURN
|
||||||
|
),
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -233,7 +255,11 @@ class EpisodicNode(Node):
|
||||||
WHERE e.uuid IN $uuids
|
WHERE e.uuid IN $uuids
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
"""
|
"""
|
||||||
+ EPISODIC_NODE_RETURN,
|
+ (
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else EPISODIC_NODE_RETURN
|
||||||
|
),
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -262,7 +288,11 @@ class EpisodicNode(Node):
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
"""
|
"""
|
||||||
+ EPISODIC_NODE_RETURN
|
+ (
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else EPISODIC_NODE_RETURN
|
||||||
|
)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY uuid DESC
|
ORDER BY uuid DESC
|
||||||
"""
|
"""
|
||||||
|
|
@ -284,7 +314,11 @@ class EpisodicNode(Node):
|
||||||
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
"""
|
"""
|
||||||
+ EPISODIC_NODE_RETURN,
|
+ (
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else EPISODIC_NODE_RETURN
|
||||||
|
),
|
||||||
entity_node_uuid=entity_node_uuid,
|
entity_node_uuid=entity_node_uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -311,11 +345,18 @@ class EntityNode(Node):
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
async def load_name_embedding(self, driver: GraphDriver):
|
async def load_name_embedding(self, driver: GraphDriver):
|
||||||
records, _, _ = await driver.execute_query(
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
query: LiteralString = """
|
||||||
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
|
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
else:
|
||||||
RETURN n.name_embedding AS name_embedding
|
query: LiteralString = """
|
||||||
""",
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
|
RETURN n.name_embedding AS name_embedding
|
||||||
|
"""
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -336,6 +377,9 @@ class EntityNode(Node):
|
||||||
}
|
}
|
||||||
entity_data.update(self.attributes or {})
|
entity_data.update(self.attributes or {})
|
||||||
|
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
labels = ':'.join(self.labels + ['Entity'])
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
|
|
@ -433,8 +477,13 @@ class CommunityNode(Node):
|
||||||
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
||||||
|
|
||||||
async def save(self, driver: GraphDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||||
|
'community_name',
|
||||||
|
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
||||||
|
)
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_community_node_save_query(driver.provider),
|
get_community_node_save_query(driver.provider), # type: ignore
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
|
|
@ -457,11 +506,19 @@ class CommunityNode(Node):
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
async def load_name_embedding(self, driver: GraphDriver):
|
async def load_name_embedding(self, driver: GraphDriver):
|
||||||
records, _, _ = await driver.execute_query(
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
query: LiteralString = """
|
||||||
|
MATCH (c:Community {uuid: $uuid})
|
||||||
|
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
|
||||||
"""
|
"""
|
||||||
|
else:
|
||||||
|
query: LiteralString = """
|
||||||
MATCH (c:Community {uuid: $uuid})
|
MATCH (c:Community {uuid: $uuid})
|
||||||
RETURN c.name_embedding AS name_embedding
|
RETURN c.name_embedding AS name_embedding
|
||||||
""",
|
"""
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -478,7 +535,11 @@ class CommunityNode(Node):
|
||||||
MATCH (n:Community {uuid: $uuid})
|
MATCH (n:Community {uuid: $uuid})
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ COMMUNITY_NODE_RETURN,
|
+ (
|
||||||
|
COMMUNITY_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else COMMUNITY_NODE_RETURN
|
||||||
|
),
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -498,7 +559,11 @@ class CommunityNode(Node):
|
||||||
WHERE n.uuid IN $uuids
|
WHERE n.uuid IN $uuids
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ COMMUNITY_NODE_RETURN,
|
+ (
|
||||||
|
COMMUNITY_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else COMMUNITY_NODE_RETURN
|
||||||
|
),
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -527,7 +592,11 @@ class CommunityNode(Node):
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ COMMUNITY_NODE_RETURN
|
+ (
|
||||||
|
COMMUNITY_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else COMMUNITY_NODE_RETURN
|
||||||
|
)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY n.uuid DESC
|
ORDER BY n.uuid DESC
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -32,8 +32,8 @@ from graphiti_core.models.edges.edge_db_queries import (
|
||||||
get_entity_edge_save_bulk_query,
|
get_entity_edge_save_bulk_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
EPISODIC_NODE_SAVE_BULK,
|
|
||||||
get_entity_node_save_bulk_query,
|
get_entity_node_save_bulk_query,
|
||||||
|
get_episode_node_save_bulk_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
|
|
@ -155,7 +155,7 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
edge_data.update(edge.attributes or {})
|
edge_data.update(edge.attributes or {})
|
||||||
edges.append(edge_data)
|
edges.append(edge_data)
|
||||||
|
|
||||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||||
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
|
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
|
||||||
await tx.run(entity_node_save_bulk, nodes=nodes)
|
await tx.run(entity_node_save_bulk, nodes=nodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from time import time
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||||
from graphiti_core.edges import (
|
from graphiti_core.edges import (
|
||||||
CommunityEdge,
|
CommunityEdge,
|
||||||
EntityEdge,
|
EntityEdge,
|
||||||
|
|
@ -504,23 +504,46 @@ async def resolve_extracted_edge(
|
||||||
async def filter_existing_duplicate_of_edges(
|
async def filter_existing_duplicate_of_edges(
|
||||||
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
||||||
) -> list[tuple[EntityNode, EntityNode]]:
|
) -> list[tuple[EntityNode, EntityNode]]:
|
||||||
query: LiteralString = """
|
if not duplicates_node_tuples:
|
||||||
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
return []
|
||||||
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
|
||||||
RETURN DISTINCT
|
|
||||||
n.uuid AS source_uuid,
|
|
||||||
m.uuid AS target_uuid
|
|
||||||
"""
|
|
||||||
|
|
||||||
duplicate_nodes_map = {
|
duplicate_nodes_map = {
|
||||||
(source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
|
(source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
|
||||||
}
|
}
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query,
|
query: LiteralString = """
|
||||||
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
|
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
||||||
routing_='r',
|
MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
|
||||||
)
|
RETURN DISTINCT
|
||||||
|
n.uuid AS source_uuid,
|
||||||
|
m.uuid AS target_uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
duplicate_nodes = [
|
||||||
|
{'source': source.uuid, 'target': target.uuid}
|
||||||
|
for source, target in duplicates_node_tuples
|
||||||
|
]
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
duplicate_node_uuids=duplicate_nodes,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query: LiteralString = """
|
||||||
|
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
||||||
|
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
||||||
|
RETURN DISTINCT
|
||||||
|
n.uuid AS source_uuid,
|
||||||
|
m.uuid AS target_uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
# Remove duplicates that already have the IS_DUPLICATE_OF edge
|
# Remove duplicates that already have the IS_DUPLICATE_OF edge
|
||||||
for record in records:
|
for record in records:
|
||||||
|
|
|
||||||
|
|
@ -19,10 +19,13 @@ from datetime import datetime
|
||||||
|
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
from graphiti_core.helpers import semaphore_gather
|
from graphiti_core.helpers import semaphore_gather
|
||||||
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
|
EPISODIC_NODE_RETURN,
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE,
|
||||||
|
)
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
||||||
|
|
||||||
EPISODE_WINDOW_LEN = 3
|
EPISODE_WINDOW_LEN = 3
|
||||||
|
|
@ -31,6 +34,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
|
return # Neptune does not need indexes built
|
||||||
if delete_existing:
|
if delete_existing:
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
|
|
@ -71,7 +76,7 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||||
|
|
||||||
async def delete_group_ids(tx):
|
async def delete_group_ids(tx):
|
||||||
await tx.run(
|
await tx.run(
|
||||||
'MATCH (n:Entity|Episodic|Community) WHERE n.group_id IN $group_ids DETACH DELETE n',
|
'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -117,7 +122,11 @@ async def retrieve_episodes(
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ EPISODIC_NODE_RETURN
|
+ (
|
||||||
|
EPISODIC_NODE_RETURN_NEPTUNE
|
||||||
|
if driver.provider == GraphProvider.NEPTUNE
|
||||||
|
else EPISODIC_NODE_RETURN
|
||||||
|
)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY e.valid_at DESC
|
ORDER BY e.valid_at DESC
|
||||||
LIMIT $num_episodes
|
LIMIT $num_episodes
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ dependencies = [
|
||||||
"tenacity>=9.0.0",
|
"tenacity>=9.0.0",
|
||||||
"numpy>=1.0.0",
|
"numpy>=1.0.0",
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
"posthog>=3.0.0",
|
"posthog>=3.0.0"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|
@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
|
||||||
falkordb = ["falkordb>=1.1.2,<2.0.0"]
|
falkordb = ["falkordb>=1.1.2,<2.0.0"]
|
||||||
voyageai = ["voyageai>=0.2.3"]
|
voyageai = ["voyageai>=0.2.3"]
|
||||||
sentence-transformers = ["sentence-transformers>=3.2.1"]
|
sentence-transformers = ["sentence-transformers>=3.2.1"]
|
||||||
|
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
|
||||||
dev = [
|
dev = [
|
||||||
"pyright>=1.1.380",
|
"pyright>=1.1.380",
|
||||||
"groq>=0.2.0",
|
"groq>=0.2.0",
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,14 @@ import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
|
from graphiti_core.driver.neptune_driver import NeptuneDriver
|
||||||
from graphiti_core.helpers import lucene_sanitize
|
from graphiti_core.helpers import lucene_sanitize
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
HAS_NEO4J = False
|
HAS_NEO4J = False
|
||||||
HAS_FALKORDB = False
|
HAS_FALKORDB = False
|
||||||
|
HAS_NEPTUNE = False
|
||||||
if os.getenv('DISABLE_NEO4J') is None:
|
if os.getenv('DISABLE_NEO4J') is None:
|
||||||
try:
|
try:
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
|
|
@ -42,6 +44,14 @@ if os.getenv('DISABLE_FALKORDB') is None:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if os.getenv('DISABLE_NEPTUNE') is None:
|
||||||
|
try:
|
||||||
|
from graphiti_core.driver.neptune_driver import NeptuneDriver
|
||||||
|
|
||||||
|
HAS_NEPTUNE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
||||||
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
||||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
||||||
|
|
@ -51,6 +61,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||||
|
|
||||||
|
NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
|
||||||
|
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
|
||||||
|
AOSS_HOST = os.getenv('AOSS_HOST', None)
|
||||||
|
|
||||||
|
|
||||||
def get_driver(driver_name: str) -> GraphDriver:
|
def get_driver(driver_name: str) -> GraphDriver:
|
||||||
if driver_name == 'neo4j':
|
if driver_name == 'neo4j':
|
||||||
|
|
@ -66,6 +80,12 @@ def get_driver(driver_name: str) -> GraphDriver:
|
||||||
username=FALKORDB_USER,
|
username=FALKORDB_USER,
|
||||||
password=FALKORDB_PASSWORD,
|
password=FALKORDB_PASSWORD,
|
||||||
)
|
)
|
||||||
|
elif driver_name == 'neptune':
|
||||||
|
return NeptuneDriver(
|
||||||
|
host=NEPTUNE_HOST,
|
||||||
|
port=int(NEPTUNE_PORT),
|
||||||
|
aoss_host=AOSS_HOST,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Driver {driver_name} not available')
|
raise ValueError(f'Driver {driver_name} not available')
|
||||||
|
|
||||||
|
|
@ -75,6 +95,8 @@ if HAS_NEO4J:
|
||||||
drivers.append('neo4j')
|
drivers.append('neo4j')
|
||||||
if HAS_FALKORDB:
|
if HAS_FALKORDB:
|
||||||
drivers.append('falkordb')
|
drivers.append('falkordb')
|
||||||
|
if HAS_NEPTUNE:
|
||||||
|
drivers.append('neptune')
|
||||||
|
|
||||||
|
|
||||||
def test_lucene_sanitize():
|
def test_lucene_sanitize():
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue