Add support for Kuzu as the graph driver (#799)
* Fix FalkorDB tests
* Add support for graph memory using Kuzu
* Fix lints
* Fix queries
* Add tests
* Add comments
* Add more test coverage
* Add mocked tests
* Format
* Add mocked tests II
* Refactor community queries
* Add more mocked tests
* Refactor tests to always cleanup
* Add more mocked tests
* Update kuzu
* Refactor how filters are built
* Add more mocked tests
* Refactor and cleanup
* Fix tests
* Fix lints
* Refactor tests
* Disable neptune
* Fix
* Update kuzu version
* Update kuzu to latest release
* Fix filter
* Fix query
* Fix Neptune query
* Fix bulk queries
* Fix lints
* Fix deletes
* Comments and format
* Add Kuzu to the README
* Fix bulk queries
* Test all fields of nodes and edges
* Fix lints
* Update search_utils.py

---------

Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
parent 309159bccb
commit 8802b7db13
30 changed files with 4219 additions and 966 deletions
1  .github/workflows/unit_tests.yml (vendored)
@@ -49,6 +49,7 @@ jobs:
   NEO4J_URI: bolt://localhost:7687
   NEO4J_USER: neo4j
   NEO4J_PASSWORD: testpass
+  DISABLE_NEPTUNE: 1
   run: |
     uv run pytest -m "not integration"
 - name: Wait for FalkorDB
34  README.md
@@ -44,7 +44,7 @@ Use Graphiti to:
 <br />

 <p align="center">
   <img src="images/graphiti-graph-intro.gif" alt="Graphiti temporal walkthrough" width="700px">
 </p>

 <br />
@@ -80,7 +80,7 @@ Traditional RAG approaches often rely on batch processing and static data summar
 - **Scalability:** Efficiently manages large datasets with parallel processing, suitable for enterprise environments.

 <p align="center">
   <img src="/images/graphiti-intro-slides-stock-2.gif" alt="Graphiti structured + unstructured demo" width="700px">
 </p>

 ## Graphiti vs. GraphRAG
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
 Requirements:

 - Python 3.10 or higher
-- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
+- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
 - OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)

 > [!IMPORTANT]
@@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
 uv add graphiti-core[falkordb]
 ```

+### Installing with Kuzu Support
+
+If you plan to use Kuzu as your graph database backend, install with the Kuzu extra:
+
+```bash
+pip install graphiti-core[kuzu]
+
+# or with uv
+uv add graphiti-core[kuzu]
+```
+
 ### Installing with Amazon Neptune Support

 If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
@@ -198,7 +209,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
 For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:

-1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
+1. Connecting to a Neo4j, Amazon Neptune, FalkorDB, or Kuzu database
 2. Initializing Graphiti indices and constraints
 3. Adding episodes to the graph (both text and structured JSON)
 4. Searching for relationships (edges) using hybrid search
@@ -281,6 +292,19 @@ driver = FalkorDriver(
 graphiti = Graphiti(graph_driver=driver)
 ```

+#### Kuzu
+
+```python
+from graphiti_core import Graphiti
+from graphiti_core.driver.kuzu_driver import KuzuDriver
+
+# Create a Kuzu driver
+driver = KuzuDriver(db="/tmp/graphiti.kuzu")
+
+# Pass the driver to Graphiti
+graphiti = Graphiti(graph_driver=driver)
+```
+
 #### Amazon Neptune

 ```python
@@ -494,7 +518,7 @@ When you initialize a Graphiti instance, we collect:
 - **Graphiti version**: The version you're using
 - **Configuration choices**:
   - LLM provider type (OpenAI, Azure, Anthropic, etc.)
-  - Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
+  - Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
   - Embedder provider (OpenAI, Azure, Voyage, etc.)

 ### What We Don't Collect
@@ -4,3 +4,7 @@ import sys
 # This code adds the project root directory to the Python path, allowing imports to work correctly when running tests.
 # Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests.
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))
+
+from tests.helpers_test import graph_driver, mock_embedder
+
+__all__ = ['graph_driver', 'mock_embedder']
@@ -27,10 +27,13 @@ logger = logging.getLogger(__name__)
 class GraphProvider(Enum):
     NEO4J = 'neo4j'
     FALKORDB = 'falkordb'
+    KUZU = 'kuzu'
     NEPTUNE = 'neptune'


 class GraphDriverSession(ABC):
+    provider: GraphProvider
+
     async def __aenter__(self):
         return self
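The `GraphProvider` enum above gates provider-specific behavior throughout the rest of this diff. A minimal standalone sketch of that dispatch pattern (the `build_delete_query` helper is illustrative, not part of the PR):

```python
from enum import Enum


class GraphProvider(Enum):
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'
    KUZU = 'kuzu'
    NEPTUNE = 'neptune'


def build_delete_query(provider: GraphProvider) -> str:
    # Kuzu stores RELATES_TO facts on an intermediate node, so the
    # generic relationship pattern does not apply there.
    if provider == GraphProvider.KUZU:
        return 'MATCH (e:RelatesToNode_ {uuid: $uuid}) DETACH DELETE e'
    return 'MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) DELETE e'
```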
@@ -15,7 +15,6 @@ limitations under the License.
 """

 import logging
-from datetime import datetime
 from typing import TYPE_CHECKING, Any

 if TYPE_CHECKING:
@@ -33,11 +32,14 @@ else:
     ) from None

 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings

 logger = logging.getLogger(__name__)


 class FalkorDriverSession(GraphDriverSession):
+    provider = GraphProvider.FALKORDB
+
     def __init__(self, graph: FalkorGraph):
         self.graph = graph
@@ -164,16 +166,3 @@ class FalkorDriver(GraphDriver):
         cloned = FalkorDriver(falkor_db=self.client, database=database)

         return cloned
-
-
-def convert_datetimes_to_strings(obj):
-    if isinstance(obj, dict):
-        return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
-    elif isinstance(obj, list):
-        return [convert_datetimes_to_strings(item) for item in obj]
-    elif isinstance(obj, tuple):
-        return tuple(convert_datetimes_to_strings(item) for item in obj)
-    elif isinstance(obj, datetime):
-        return obj.isoformat()
-    else:
-        return obj
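The helper deleted here is not lost: it now comes from `graphiti_core.utils.datetime_utils` (imported earlier in this file). Its behavior, reproduced as a standalone sketch:

```python
from datetime import datetime


def convert_datetimes_to_strings(obj):
    # Recursively replace datetime values with ISO-8601 strings,
    # preserving the container type at each level.
    if isinstance(obj, dict):
        return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_datetimes_to_strings(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(item) for item in obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()
    else:
        return obj
```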
175  graphiti_core/driver/kuzu_driver.py (new file)
@@ -0,0 +1,175 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from typing import Any

import kuzu

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

logger = logging.getLogger(__name__)

# Kuzu requires an explicit schema.
# As Kuzu currently does not support creating full text indexes on edge properties,
# we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
# (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
SCHEMA_QUERIES = """
    CREATE NODE TABLE IF NOT EXISTS Episodic (
        uuid STRING PRIMARY KEY,
        name STRING,
        group_id STRING,
        created_at TIMESTAMP,
        source STRING,
        source_description STRING,
        content STRING,
        valid_at TIMESTAMP,
        entity_edges STRING[]
    );
    CREATE NODE TABLE IF NOT EXISTS Entity (
        uuid STRING PRIMARY KEY,
        name STRING,
        group_id STRING,
        labels STRING[],
        created_at TIMESTAMP,
        name_embedding FLOAT[],
        summary STRING,
        attributes STRING
    );
    CREATE NODE TABLE IF NOT EXISTS Community (
        uuid STRING PRIMARY KEY,
        name STRING,
        group_id STRING,
        created_at TIMESTAMP,
        name_embedding FLOAT[],
        summary STRING
    );
    CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
        uuid STRING PRIMARY KEY,
        group_id STRING,
        created_at TIMESTAMP,
        name STRING,
        fact STRING,
        fact_embedding FLOAT[],
        episodes STRING[],
        expired_at TIMESTAMP,
        valid_at TIMESTAMP,
        invalid_at TIMESTAMP,
        attributes STRING
    );
    CREATE REL TABLE IF NOT EXISTS RELATES_TO(
        FROM Entity TO RelatesToNode_,
        FROM RelatesToNode_ TO Entity
    );
    CREATE REL TABLE IF NOT EXISTS MENTIONS(
        FROM Episodic TO Entity,
        uuid STRING PRIMARY KEY,
        group_id STRING,
        created_at TIMESTAMP
    );
    CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
        FROM Community TO Entity,
        FROM Community TO Community,
        uuid STRING,
        group_id STRING,
        created_at TIMESTAMP
    );
"""


class KuzuDriver(GraphDriver):
    provider: GraphProvider = GraphProvider.KUZU

    def __init__(
        self,
        db: str = ':memory:',
        max_concurrent_queries: int = 1,
    ):
        super().__init__()
        self.db = kuzu.Database(db)

        self.setup_schema()

        self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)

    async def execute_query(
        self, cypher_query_: str, **kwargs: Any
    ) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
        params = {k: v for k, v in kwargs.items() if v is not None}
        # Kuzu does not support these parameters.
        params.pop('database_', None)
        params.pop('routing_', None)

        try:
            results = await self.client.execute(cypher_query_, parameters=params)
        except Exception as e:
            params = {k: (v[:5] if isinstance(v, list) else v) for k, v in params.items()}
            logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{params}')
            raise

        if not results:
            return [], None, None

        if isinstance(results, list):
            dict_results = [list(result.rows_as_dict()) for result in results]
        else:
            dict_results = list(results.rows_as_dict())
        return dict_results, None, None  # type: ignore

    def session(self, _database: str | None = None) -> GraphDriverSession:
        return KuzuDriverSession(self)

    async def close(self):
        # Do not explicitly close the connection; instead rely on GC.
        pass

    def delete_all_indexes(self, database_: str):
        pass

    def setup_schema(self):
        conn = kuzu.Connection(self.db)
        conn.execute(SCHEMA_QUERIES)
        conn.close()


class KuzuDriverSession(GraphDriverSession):
    provider = GraphProvider.KUZU

    def __init__(self, driver: KuzuDriver):
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # No cleanup needed for Kuzu, but the method must exist.
        pass

    async def close(self):
        # Do not close the session here, as we're reusing the driver connection.
        pass

    async def execute_write(self, func, *args, **kwargs):
        # Directly await the provided async function with `self` as the transaction/session.
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        if isinstance(query, list):
            for cypher, params in query:
                await self.driver.execute_query(cypher, **params)
        else:
            await self.driver.execute_query(query, **kwargs)
        return None
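The `RelatesToNode_` workaround in the schema above can be illustrated without a database: an edge that needs indexable properties becomes a node plus two plain edges. A dict-based sketch (all names here are illustrative):

```python
def edge_to_node_form(source_uuid: str, target_uuid: str, edge_props: dict) -> dict:
    """Represent (source)-[RELATES_TO {props}]->(target) as
    (source)-[:RELATES_TO]->(props node)-[:RELATES_TO]->(target),
    so the properties live on a node that fulltext indexes can cover."""
    relates_node = {'label': 'RelatesToNode_', **edge_props}
    return {
        'nodes': [relates_node],
        'edges': [
            ('RELATES_TO', source_uuid, relates_node['uuid']),
            ('RELATES_TO', relates_node['uuid'], target_uuid),
        ],
    }
```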
@@ -271,6 +271,8 @@ class NeptuneDriver(GraphDriver):


 class NeptuneDriverSession(GraphDriverSession):
+    provider = GraphProvider.NEPTUNE
+
     def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
         self.driver = driver
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 from abc import ABC, abstractmethod
 from datetime import datetime
@@ -30,11 +31,10 @@ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import parse_db_date
 from graphiti_core.models.edges.edge_db_queries import (
     COMMUNITY_EDGE_RETURN,
     ENTITY_EDGE_RETURN,
     ENTITY_EDGE_RETURN_NEPTUNE,
     EPISODIC_EDGE_RETURN,
     EPISODIC_EDGE_SAVE,
+    get_community_edge_save_query,
+    get_entity_edge_return_query,
+    get_entity_edge_save_query,
 )
 from graphiti_core.nodes import Node
@@ -53,33 +53,63 @@ class Edge(BaseModel, ABC):
     async def save(self, driver: GraphDriver): ...

     async def delete(self, driver: GraphDriver):
-        result = await driver.execute_query(
-            """
-            MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
-            DELETE e
-            """,
-            uuid=self.uuid,
-        )
+        if driver.provider == GraphProvider.KUZU:
+            await driver.execute_query(
+                """
+                MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
+                DELETE e
+                """,
+                uuid=self.uuid,
+            )
+            await driver.execute_query(
+                """
+                MATCH (e:RelatesToNode_ {uuid: $uuid})
+                DETACH DELETE e
+                """,
+                uuid=self.uuid,
+            )
+        else:
+            await driver.execute_query(
+                """
+                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
+                DELETE e
+                """,
+                uuid=self.uuid,
+            )

         logger.debug(f'Deleted Edge: {self.uuid}')

-        return result
-
     @classmethod
     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
-        result = await driver.execute_query(
-            """
-            MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
-            WHERE e.uuid IN $uuids
-            DELETE e
-            """,
-            uuids=uuids,
-        )
+        if driver.provider == GraphProvider.KUZU:
+            await driver.execute_query(
+                """
+                MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
+                WHERE e.uuid IN $uuids
+                DELETE e
+                """,
+                uuids=uuids,
+            )
+            await driver.execute_query(
+                """
+                MATCH (e:RelatesToNode_)
+                WHERE e.uuid IN $uuids
+                DETACH DELETE e
+                """,
+                uuids=uuids,
+            )
+        else:
+            await driver.execute_query(
+                """
+                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
+                WHERE e.uuid IN $uuids
+                DELETE e
+                """,
+                uuids=uuids,
+            )

         logger.debug(f'Deleted Edges: {uuids}')

-        return result
-
     def __hash__(self):
         return hash(self.uuid)
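The two-phase Kuzu delete above (plain edges first, then the intermediate fact node) can be sketched as a pure function that returns the query plan rather than executing it. This is a sketch of the control flow only, not the PR's actual API:

```python
def delete_edge_queries(provider: str, uuid: str) -> list[tuple[str, dict]]:
    """Return the (cypher, params) statements needed to delete one edge."""
    if provider == 'kuzu':
        # Plain edges first, then the RelatesToNode_ carrying the fact.
        return [
            ('MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m) DELETE e', {'uuid': uuid}),
            ('MATCH (e:RelatesToNode_ {uuid: $uuid}) DETACH DELETE e', {'uuid': uuid}),
        ]
    # Other providers can delete all edge types in one statement.
    return [
        ('MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) DELETE e', {'uuid': uuid}),
    ]
```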
@@ -166,7 +196,7 @@ class EpisodicEdge(Edge):
             """
             + EPISODIC_EDGE_RETURN
             + """
             ORDER BY e.uuid DESC
             """
             + limit_query,
             group_ids=group_ids,
@@ -215,15 +245,21 @@ class EntityEdge(Edge):
         return self.fact_embedding

     async def load_fact_embedding(self, driver: GraphDriver):
-        if driver.provider == GraphProvider.NEPTUNE:
-            query: LiteralString = """
-                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
-                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
-            """
-        else:
-            query: LiteralString = """
-                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
-                RETURN e.fact_embedding AS fact_embedding
-            """
+        query = """
+            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+            RETURN e.fact_embedding AS fact_embedding
+        """
+
+        if driver.provider == GraphProvider.NEPTUNE:
+            query = """
+                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
+            """
+
+        if driver.provider == GraphProvider.KUZU:
+            query = """
+                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
+                RETURN e.fact_embedding AS fact_embedding
+            """

         records, _, _ = await driver.execute_query(
@@ -253,15 +289,22 @@ class EntityEdge(Edge):
             'invalid_at': self.invalid_at,
         }

-        edge_data.update(self.attributes or {})
-
-        if driver.provider == GraphProvider.NEPTUNE:
-            driver.save_to_aoss('edge_name_and_fact', [edge_data])  # pyright: ignore reportAttributeAccessIssue
-
-        result = await driver.execute_query(
-            get_entity_edge_save_query(driver.provider),
-            edge_data=edge_data,
-        )
+        if driver.provider == GraphProvider.KUZU:
+            edge_data['attributes'] = json.dumps(self.attributes)
+            result = await driver.execute_query(
+                get_entity_edge_save_query(driver.provider),
+                **edge_data,
+            )
+        else:
+            edge_data.update(self.attributes or {})
+
+            if driver.provider == GraphProvider.NEPTUNE:
+                driver.save_to_aoss('edge_name_and_fact', [edge_data])  # pyright: ignore reportAttributeAccessIssue
+
+            result = await driver.execute_query(
+                get_entity_edge_save_query(driver.provider),
+                edge_data=edge_data,
+            )

         logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -269,21 +312,25 @@ class EntityEdge(Edge):

     @classmethod
     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
-        records, _, _ = await driver.execute_query(
-            """
-            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+        match_query = """
+            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
             RETURN
             """
-            + (
-                ENTITY_EDGE_RETURN_NEPTUNE
-                if driver.provider == GraphProvider.NEPTUNE
-                else ENTITY_EDGE_RETURN
-            ),
+            + get_entity_edge_return_query(driver.provider),
             uuid=uuid,
             routing_='r',
         )

-        edges = [get_entity_edge_from_record(record) for record in records]
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

         if len(edges) == 0:
             raise EdgeNotFoundError(uuid)
@@ -294,22 +341,26 @@ class EntityEdge(Edge):
         if len(uuids) == 0:
             return []

-        records, _, _ = await driver.execute_query(
-            """
-            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+        match_query = """
+            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
             WHERE e.uuid IN $uuids
             RETURN
             """
-            + (
-                ENTITY_EDGE_RETURN_NEPTUNE
-                if driver.provider == GraphProvider.NEPTUNE
-                else ENTITY_EDGE_RETURN
-            ),
+            + get_entity_edge_return_query(driver.provider),
             uuids=uuids,
             routing_='r',
         )

-        edges = [get_entity_edge_from_record(record) for record in records]
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

         return edges
@@ -332,23 +383,27 @@ class EntityEdge(Edge):
             else ''
         )

-        records, _, _ = await driver.execute_query(
-            """
-            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+        match_query = """
+            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
             WHERE e.group_id IN $group_ids
             """
             + cursor_query
             + """
             RETURN
             """
-            + (
-                ENTITY_EDGE_RETURN_NEPTUNE
-                if driver.provider == GraphProvider.NEPTUNE
-                else ENTITY_EDGE_RETURN
-            )
+            + get_entity_edge_return_query(driver.provider)
             + with_embeddings_query
             + """
             ORDER BY e.uuid DESC
             """
             + limit_query,
             group_ids=group_ids,
@@ -357,7 +412,7 @@ class EntityEdge(Edge):
             routing_='r',
         )

-        edges = [get_entity_edge_from_record(record) for record in records]
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

         if len(edges) == 0:
             raise GroupsEdgesNotFoundError(group_ids)
@@ -365,21 +420,25 @@ class EntityEdge(Edge):

     @classmethod
     async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
-        records, _, _ = await driver.execute_query(
-            """
-            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+        match_query = """
+            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
             RETURN
             """
-            + (
-                ENTITY_EDGE_RETURN_NEPTUNE
-                if driver.provider == GraphProvider.NEPTUNE
-                else ENTITY_EDGE_RETURN
-            ),
+            + get_entity_edge_return_query(driver.provider),
             node_uuid=node_uuid,
             routing_='r',
         )

-        edges = [get_entity_edge_from_record(record) for record in records]
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

         return edges
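Each getter above follows the same refactored shape: pick a provider-specific MATCH fragment, then concatenate the shared RETURN clause. A condensed sketch of that composition (provider strings and clause text are illustrative):

```python
def build_edge_query(provider: str, return_clause: str) -> str:
    # Default pattern: properties live directly on the RELATES_TO edge.
    match_query = 'MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)'
    if provider == 'kuzu':
        # Kuzu pattern: properties live on the intermediate RelatesToNode_.
        match_query = (
            'MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)'
        )
    return match_query + ' RETURN ' + return_clause
```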
@@ -479,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
     )


-def get_entity_edge_from_record(record: Any) -> EntityEdge:
+def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
+    episodes = record['episodes']
+    if provider == GraphProvider.KUZU:
+        attributes = json.loads(record['attributes']) if record['attributes'] else {}
+    else:
+        attributes = record['attributes']
+        attributes.pop('uuid', None)
+        attributes.pop('source_node_uuid', None)
+        attributes.pop('target_node_uuid', None)
+        attributes.pop('fact', None)
+        attributes.pop('fact_embedding', None)
+        attributes.pop('name', None)
+        attributes.pop('group_id', None)
+        attributes.pop('episodes', None)
+        attributes.pop('created_at', None)
+        attributes.pop('expired_at', None)
+        attributes.pop('valid_at', None)
+        attributes.pop('invalid_at', None)
+
     edge = EntityEdge(
         uuid=record['uuid'],
         source_node_uuid=record['source_node_uuid'],
@@ -488,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
         fact_embedding=record.get('fact_embedding'),
         name=record['name'],
         group_id=record['group_id'],
-        episodes=record['episodes'],
+        episodes=episodes,
         created_at=parse_db_date(record['created_at']),  # type: ignore
         expired_at=parse_db_date(record['expired_at']),
         valid_at=parse_db_date(record['valid_at']),
         invalid_at=parse_db_date(record['invalid_at']),
-        attributes=record['attributes'],
+        attributes=attributes,
     )

-    edge.attributes.pop('uuid', None)
-    edge.attributes.pop('source_node_uuid', None)
-    edge.attributes.pop('target_node_uuid', None)
-    edge.attributes.pop('fact', None)
-    edge.attributes.pop('name', None)
-    edge.attributes.pop('group_id', None)
-    edge.attributes.pop('episodes', None)
-    edge.attributes.pop('created_at', None)
-    edge.attributes.pop('expired_at', None)
-    edge.attributes.pop('valid_at', None)
-    edge.attributes.pop('invalid_at', None)
-
     return edge
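Because Kuzu columns are typed, the free-form attributes are stored as a JSON string and decoded on read, while other providers keep them as edge properties and must strip the reserved field names afterward. A stdlib sketch of that decoding (provider strings are illustrative):

```python
import json

# Field names that are first-class edge properties, never user attributes.
RESERVED_KEYS = {
    'uuid', 'source_node_uuid', 'target_node_uuid', 'fact', 'fact_embedding',
    'name', 'group_id', 'episodes', 'created_at', 'expired_at', 'valid_at', 'invalid_at',
}


def decode_attributes(raw, provider: str) -> dict:
    if provider == 'kuzu':
        # Kuzu stores attributes as a single JSON string column.
        return json.loads(raw) if raw else {}
    # Other providers return a property map; drop the reserved keys.
    attributes = dict(raw)
    for key in RESERVED_KEYS:
        attributes.pop(key, None)
    return attributes
```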
@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
     'episode_content': 'Episodic',
     'edge_name_and_fact': 'RELATES_TO',
 }
+# Mapping from fulltext index names to Kuzu node labels
+INDEX_TO_LABEL_KUZU_MAPPING = {
+    'node_name_and_summary': 'Entity',
+    'community_name': 'Community',
+    'episode_content': 'Episodic',
+    'edge_name_and_fact': 'RelatesToNode_',
+}


 def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
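The mapping lets the same logical index names resolve to Kuzu table names (note that the edge index resolves to the `RelatesToNode_` node table). A sketch of how a fulltext call is assembled from it, mirroring the `QUERY_FTS_INDEX` usage later in this diff:

```python
INDEX_TO_LABEL_KUZU_MAPPING = {
    'node_name_and_summary': 'Entity',
    'community_name': 'Community',
    'episode_content': 'Episodic',
    'edge_name_and_fact': 'RelatesToNode_',
}


def kuzu_fts_call(index_name: str) -> str:
    # Resolve the logical index name to the Kuzu table it was created on.
    label = INDEX_TO_LABEL_KUZU_MAPPING[index_name]
    return f"CALL QUERY_FTS_INDEX('{label}', '{index_name}', $query, TOP := $limit)"
```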
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
             'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
         ]

+    if provider == GraphProvider.KUZU:
+        return []
+
     return [
         'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
         'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
             """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
         ]

+    if provider == GraphProvider.KUZU:
+        return [
+            "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
+            "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
+            "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
+            "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
+        ]
+
     return [
         """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
         FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
     ]


-def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
+def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
     if provider == GraphProvider.FALKORDB:
         label = NEO4J_TO_FALKORDB_MAPPING[name]
         return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"

+    if provider == GraphProvider.KUZU:
+        label = INDEX_TO_LABEL_KUZU_MAPPING[name]
+        return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
+
     return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
         # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
         return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'

+    if provider == GraphProvider.KUZU:
+        return f'array_cosine_similarity({vec1}, {vec2})'
+
     return f'vector.similarity.cosine({vec1}, {vec2})'


-def get_relationships_query(name: str, provider: GraphProvider) -> str:
+def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
     if provider == GraphProvider.FALKORDB:
         label = NEO4J_TO_FALKORDB_MAPPING[name]
         return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"

+    if provider == GraphProvider.KUZU:
+        label = INDEX_TO_LABEL_KUZU_MAPPING[name]
+        return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
+
     return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
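Kuzu's `array_cosine_similarity` returns plain cosine similarity, while the FalkorDB branch rescales cosine distance into [0, 1]. The two conventions, sketched with stdlib math:

```python
import math


def cosine_similarity(v1, v2):
    # Plain cosine similarity in [-1, 1], as Kuzu's array_cosine_similarity returns.
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    return dot / (norm1 * norm2)


def falkordb_style(v1, v2):
    # (2 - cosineDistance) / 2 == (1 + cosine_similarity) / 2, mapped into [0, 1].
    return (2 - (1 - cosine_similarity(v1, v2))) / 2
```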
@@ -1070,7 +1070,7 @@ class Graphiti:
             if record['episode_count'] == 1:
                 nodes_to_delete.append(node)

-        await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
         await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
+        await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
         await episode.delete(self.driver)
@@ -43,14 +43,14 @@ RUNTIME_QUERY: LiteralString = (
 )


-def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
-    return (
-        neo_date.to_native()
-        if isinstance(neo_date, neo4j_time.DateTime)
-        else datetime.fromisoformat(neo_date)
-        if neo_date
-        else None
-    )
+def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+    if isinstance(input_date, neo4j_time.DateTime):
+        return input_date.to_native()
+
+    if isinstance(input_date, str):
+        return datetime.fromisoformat(input_date)
+
+    return input_date


 def get_default_group_id(provider: GraphProvider) -> str:
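The rewritten helper accepts either a `neo4j.time.DateTime`, an ISO string (as string-typed backends return), or `None`/a native datetime. The string and passthrough paths can be sketched without the neo4j dependency:

```python
from datetime import datetime


def parse_db_date(input_date):
    # A neo4j.time.DateTime branch would call input_date.to_native() here;
    # this sketch covers only the string and passthrough cases.
    if isinstance(input_date, str):
        return datetime.fromisoformat(input_date)
    return input_date  # already a datetime, or None
```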
@@ -20,18 +20,36 @@ EPISODIC_EDGE_SAVE = """
     MATCH (episode:Episodic {uuid: $episode_uuid})
     MATCH (node:Entity {uuid: $entity_uuid})
     MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
-    SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+    SET
+        e.group_id = $group_id,
+        e.created_at = $created_at
     RETURN e.uuid AS uuid
 """

-EPISODIC_EDGE_SAVE_BULK = """
-    UNWIND $episodic_edges AS edge
-    MATCH (episode:Episodic {uuid: edge.source_node_uuid})
-    MATCH (node:Entity {uuid: edge.target_node_uuid})
-    MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
-    SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
-    RETURN e.uuid AS uuid
-"""
+
+def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
+    if provider == GraphProvider.KUZU:
+        return """
+            MATCH (episode:Episodic {uuid: $source_node_uuid})
+            MATCH (node:Entity {uuid: $target_node_uuid})
+            MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
+            SET
+                e.group_id = $group_id,
+                e.created_at = $created_at
+            RETURN e.uuid AS uuid
+        """
+
+    return """
+        UNWIND $episodic_edges AS edge
+        MATCH (episode:Episodic {uuid: edge.source_node_uuid})
+        MATCH (node:Entity {uuid: edge.target_node_uuid})
+        MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
+        SET
+            e.group_id = edge.group_id,
+            e.created_at = edge.created_at
+        RETURN e.uuid AS uuid
+    """


 EPISODIC_EDGE_RETURN = """
     e.uuid AS uuid,
@@ -54,14 +72,32 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
             """
         case GraphProvider.NEPTUNE:
             return """
                 MATCH (source:Entity {uuid: $edge_data.source_uuid})
                 MATCH (target:Entity {uuid: $edge_data.target_uuid})
                 MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                 SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
                 SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
                 SET e.episodes = join($edge_data.episodes, ",")
                 RETURN $edge_data.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MATCH (source:Entity {uuid: $source_uuid})
+                MATCH (target:Entity {uuid: $target_uuid})
+                MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at,
+                    e.name = $name,
+                    e.fact = $fact,
+                    e.fact_embedding = $fact_embedding,
+                    e.episodes = $episodes,
+                    e.expired_at = $expired_at,
+                    e.valid_at = $valid_at,
+                    e.invalid_at = $invalid_at,
+                    e.attributes = $attributes
+                RETURN e.uuid AS uuid
+            """
         case _:  # Neo4j
             return """
                 MATCH (source:Entity {uuid: $edge_data.source_uuid})
@@ -89,14 +125,32 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
         case GraphProvider.NEPTUNE:
             return """
                 UNWIND $entity_edges AS edge
                 MATCH (source:Entity {uuid: edge.source_node_uuid})
                 MATCH (target:Entity {uuid: edge.target_node_uuid})
                 MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
                 SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
                 SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
                 SET r.episodes = join(edge.episodes, ",")
                 RETURN edge.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MATCH (source:Entity {uuid: $source_node_uuid})
+                MATCH (target:Entity {uuid: $target_node_uuid})
+                MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at,
+                    e.name = $name,
+                    e.fact = $fact,
+                    e.fact_embedding = $fact_embedding,
+                    e.episodes = $episodes,
+                    e.expired_at = $expired_at,
+                    e.valid_at = $valid_at,
+                    e.invalid_at = $invalid_at,
+                    e.attributes = $attributes
+                RETURN e.uuid AS uuid
+            """
         case _:
             return """
                 UNWIND $entity_edges AS edge
@@ -109,35 +163,42 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
             """


-ENTITY_EDGE_RETURN = """
-    e.uuid AS uuid,
-    n.uuid AS source_node_uuid,
-    m.uuid AS target_node_uuid,
-    e.group_id AS group_id,
-    e.name AS name,
-    e.fact AS fact,
-    e.episodes AS episodes,
-    e.created_at AS created_at,
-    e.expired_at AS expired_at,
-    e.valid_at AS valid_at,
-    e.invalid_at AS invalid_at,
-    properties(e) AS attributes
-"""
-
-ENTITY_EDGE_RETURN_NEPTUNE = """
-    e.uuid AS uuid,
-    n.uuid AS source_node_uuid,
-    m.uuid AS target_node_uuid,
-    e.group_id AS group_id,
-    e.name AS name,
-    e.fact AS fact,
-    split(e.episodes, ',') AS episodes,
-    e.created_at AS created_at,
-    e.expired_at AS expired_at,
-    e.valid_at AS valid_at,
-    e.invalid_at AS invalid_at,
-    properties(e) AS attributes
-"""
+def get_entity_edge_return_query(provider: GraphProvider) -> str:
+    # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
+    if provider == GraphProvider.NEPTUNE:
+        return """
+            e.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            e.group_id AS group_id,
+            e.name AS name,
+            e.fact AS fact,
+            split(e.episodes, ',') AS episodes,
+            e.created_at AS created_at,
+            e.expired_at AS expired_at,
+            e.valid_at AS valid_at,
+            e.invalid_at AS invalid_at,
+            properties(e) AS attributes
+        """
+
+    return """
+        e.uuid AS uuid,
+        n.uuid AS source_node_uuid,
+        m.uuid AS target_node_uuid,
+        e.group_id AS group_id,
+        e.created_at AS created_at,
+        e.name AS name,
+        e.fact AS fact,
+        e.episodes AS episodes,
+        e.expired_at AS expired_at,
+        e.valid_at AS valid_at,
+        e.invalid_at AS invalid_at,
+    """ + (
+        'e.attributes AS attributes'
+        if provider == GraphProvider.KUZU
+        else 'properties(e) AS attributes'
+    )


 def get_community_edge_save_query(provider: GraphProvider) -> str:
@@ -152,7 +213,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
             """
         case GraphProvider.NEPTUNE:
             return """
                 MATCH (community:Community {uuid: $community_uuid})
                 MATCH (node {uuid: $entity_uuid})
                 WHERE node:Entity OR node:Community
                 MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
@@ -161,6 +222,24 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
                 SET r.created_at= $created_at
                 RETURN r.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node:Entity {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at
+                RETURN e.uuid AS uuid
+                UNION
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node:Community {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at
+                RETURN e.uuid AS uuid
+            """
         case _:  # Neo4j
             return """
                 MATCH (community:Community {uuid: $community_uuid})
@@ -24,10 +24,24 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
         case GraphProvider.NEPTUNE:
             return """
                 MERGE (n:Episodic {uuid: $uuid})
                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                 entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
                 RETURN n.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MERGE (n:Episodic {uuid: $uuid})
+                SET
+                    n.name = $name,
+                    n.group_id = $group_id,
+                    n.created_at = $created_at,
+                    n.source = $source,
+                    n.source_description = $source_description,
+                    n.content = $content,
+                    n.valid_at = $valid_at,
+                    n.entity_edges = $entity_edges
+                RETURN n.uuid AS uuid
+            """
         case GraphProvider.FALKORDB:
             return """
                 MERGE (n:Episodic {uuid: $uuid})
@@ -51,11 +65,25 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
             return """
                 UNWIND $episodes AS episode
                 MERGE (n:Episodic {uuid: episode.uuid})
                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
                 source: episode.source, content: episode.content,
                 entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
                 RETURN n.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MERGE (n:Episodic {uuid: $uuid})
+                SET
+                    n.name = $name,
+                    n.group_id = $group_id,
+                    n.created_at = $created_at,
+                    n.source = $source,
+                    n.source_description = $source_description,
+                    n.content = $content,
+                    n.valid_at = $valid_at,
+                    n.entity_edges = $entity_edges
+                RETURN n.uuid AS uuid
+            """
         case GraphProvider.FALKORDB:
             return """
                 UNWIND $episodes AS episode
@@ -76,14 +104,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:


 EPISODIC_NODE_RETURN = """
-    e.content AS content,
-    e.created_at AS created_at,
-    e.valid_at AS valid_at,
     e.uuid AS uuid,
     e.name AS name,
     e.group_id AS group_id,
-    e.source_description AS source_description,
+    e.created_at AS created_at,
     e.source AS source,
+    e.source_description AS source_description,
+    e.content AS content,
+    e.valid_at AS valid_at,
     e.entity_edges AS entity_edges
 """

@@ -109,6 +137,20 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
                 SET n = $entity_data
                 RETURN n.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MERGE (n:Entity {uuid: $uuid})
+                SET
+                    n.name = $name,
+                    n.group_id = $group_id,
+                    n.labels = $labels,
+                    n.created_at = $created_at,
+                    n.name_embedding = $name_embedding,
+                    n.summary = $summary,
+                    n.attributes = $attributes
+                WITH n
+                RETURN n.uuid AS uuid
+            """
         case GraphProvider.NEPTUNE:
             label_subquery = ''
             for label in labels.split(':'):
@@ -168,6 +210,19 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
                 """
                 )
             return queries
+        case GraphProvider.KUZU:
+            return """
+                MERGE (n:Entity {uuid: $uuid})
+                SET
+                    n.name = $name,
+                    n.group_id = $group_id,
+                    n.labels = $labels,
+                    n.created_at = $created_at,
+                    n.name_embedding = $name_embedding,
+                    n.summary = $summary,
+                    n.attributes = $attributes
+                RETURN n.uuid AS uuid
+            """
         case _:  # Neo4j
             return """
                 UNWIND $nodes AS node
@@ -179,15 +234,28 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
             """


-ENTITY_NODE_RETURN = """
-    n.uuid AS uuid,
-    n.name AS name,
-    n.group_id AS group_id,
-    n.created_at AS created_at,
-    n.summary AS summary,
-    labels(n) AS labels,
-    properties(n) AS attributes
-"""
+def get_entity_node_return_query(provider: GraphProvider) -> str:
+    # `name_embedding` is not returned by default and must be loaded manually using `load_name_embedding()`.
+    if provider == GraphProvider.KUZU:
+        return """
+            n.uuid AS uuid,
+            n.name AS name,
+            n.group_id AS group_id,
+            n.labels AS labels,
+            n.created_at AS created_at,
+            n.summary AS summary,
+            n.attributes AS attributes
+        """
+
+    return """
+        n.uuid AS uuid,
+        n.name AS name,
+        n.group_id AS group_id,
+        n.created_at AS created_at,
+        n.summary AS summary,
+        labels(n) AS labels,
+        properties(n) AS attributes
+    """


 def get_community_node_save_query(provider: GraphProvider) -> str:
@@ -201,10 +269,21 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
         case GraphProvider.NEPTUNE:
             return """
                 MERGE (n:Community {uuid: $uuid})
                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                 SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
                 RETURN n.uuid AS uuid
             """
+        case GraphProvider.KUZU:
+            return """
+                MERGE (n:Community {uuid: $uuid})
+                SET
+                    n.name = $name,
+                    n.group_id = $group_id,
+                    n.created_at = $created_at,
+                    n.name_embedding = $name_embedding,
+                    n.summary = $summary
+                RETURN n.uuid AS uuid
+            """
         case _:  # Neo4j
             return """
                 MERGE (n:Community {uuid: $uuid})
@@ -215,12 +294,12 @@ def get_community_node_save_query(provider: GraphProvider) -> str:


 COMMUNITY_NODE_RETURN = """
-    n.uuid AS uuid,
-    n.name AS name,
-    n.name_embedding AS name_embedding,
-    n.group_id AS group_id,
-    n.summary AS summary,
-    n.created_at AS created_at
+    c.uuid AS uuid,
+    c.name AS name,
+    c.group_id AS group_id,
+    c.created_at AS created_at,
+    c.name_embedding AS name_embedding,
+    c.summary AS summary
 """

 COMMUNITY_NODE_RETURN_NEPTUNE = """
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 from abc import ABC, abstractmethod
 from datetime import datetime
@@ -32,10 +33,10 @@ from graphiti_core.helpers import parse_db_date
 from graphiti_core.models.nodes.node_db_queries import (
     COMMUNITY_NODE_RETURN,
     COMMUNITY_NODE_RETURN_NEPTUNE,
-    ENTITY_NODE_RETURN,
     EPISODIC_NODE_RETURN,
     EPISODIC_NODE_RETURN_NEPTUNE,
     get_community_node_save_query,
+    get_entity_node_return_query,
     get_entity_node_save_query,
     get_episode_node_save_query,
 )
@@ -95,12 +96,37 @@ class Node(BaseModel, ABC):
             case GraphProvider.NEO4J:
                 await driver.execute_query(
                     """
                     MATCH (n:Entity|Episodic|Community {uuid: $uuid})
                     DETACH DELETE n
                     """,
                     uuid=self.uuid,
                 )
-            case _:  # FalkorDB and Neptune
+            case GraphProvider.KUZU:
+                for label in ['Episodic', 'Community']:
+                    await driver.execute_query(
+                        f"""
+                        MATCH (n:{label} {{uuid: $uuid}})
+                        DETACH DELETE n
+                        """,
+                        uuid=self.uuid,
+                    )
+                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
+                # Explicitly delete the "edge" nodes first, then the entity node.
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
+                    DETACH DELETE e
+                    """,
+                    uuid=self.uuid,
+                )
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity {uuid: $uuid})
+                    DETACH DELETE n
+                    """,
+                    uuid=self.uuid,
+                )
+            case _:  # FalkorDB, Neptune
                 for label in ['Entity', 'Episodic', 'Community']:
                     await driver.execute_query(
                         f"""
@@ -136,8 +162,32 @@ class Node(BaseModel, ABC):
                     group_id=group_id,
                     batch_size=batch_size,
                 )

-            case _:  # FalkorDB and Neptune
+            case GraphProvider.KUZU:
+                for label in ['Episodic', 'Community']:
+                    await driver.execute_query(
+                        f"""
+                        MATCH (n:{label} {{group_id: $group_id}})
+                        DETACH DELETE n
+                        """,
+                        group_id=group_id,
+                    )
+                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
+                # Explicitly delete the "edge" nodes first, then the entity node.
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
+                    DETACH DELETE e
+                    """,
+                    group_id=group_id,
+                )
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity {group_id: $group_id})
+                    DETACH DELETE n
+                    """,
+                    group_id=group_id,
+                )
+            case _:  # FalkorDB, Neptune
                 for label in ['Entity', 'Episodic', 'Community']:
                     await driver.execute_query(
                         f"""
@@ -149,30 +199,59 @@ class Node(BaseModel, ABC):

     @classmethod
     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
-        if driver.provider == GraphProvider.FALKORDB:
-            for label in ['Entity', 'Episodic', 'Community']:
-                await driver.execute_query(
-                    f"""
-                    MATCH (n:{label})
-                    WHERE n.uuid IN $uuids
-                    DETACH DELETE n
-                    """,
-                    uuids=uuids,
-                )
-        else:
-            async with driver.session() as session:
-                await session.run(
-                    """
-                    MATCH (n:Entity|Episodic|Community)
-                    WHERE n.uuid IN $uuids
-                    CALL {
-                        WITH n
-                        DETACH DELETE n
-                    } IN TRANSACTIONS OF $batch_size ROWS
-                    """,
-                    uuids=uuids,
-                    batch_size=batch_size,
-                )
+        match driver.provider:
+            case GraphProvider.FALKORDB:
+                for label in ['Entity', 'Episodic', 'Community']:
+                    await driver.execute_query(
+                        f"""
+                        MATCH (n:{label})
+                        WHERE n.uuid IN $uuids
+                        DETACH DELETE n
+                        """,
+                        uuids=uuids,
+                    )
+            case GraphProvider.KUZU:
+                for label in ['Episodic', 'Community']:
+                    await driver.execute_query(
+                        f"""
+                        MATCH (n:{label})
+                        WHERE n.uuid IN $uuids
+                        DETACH DELETE n
+                        """,
+                        uuids=uuids,
+                    )
+                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
+                # Explicitly delete the "edge" nodes first, then the entity node.
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
+                    WHERE n.uuid IN $uuids
+                    DETACH DELETE e
+                    """,
+                    uuids=uuids,
+                )
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity)
+                    WHERE n.uuid IN $uuids
+                    DETACH DELETE n
+                    """,
+                    uuids=uuids,
+                )
+            case _:  # Neo4J, Neptune
+                async with driver.session() as session:
+                    await session.run(
+                        """
+                        MATCH (n:Entity|Episodic|Community)
+                        WHERE n.uuid IN $uuids
+                        CALL {
+                            WITH n
+                            DETACH DELETE n
+                        } IN TRANSACTIONS OF $batch_size ROWS
+                        """,
+                        uuids=uuids,
+                        batch_size=batch_size,
+                    )

     @classmethod
     async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -376,17 +455,25 @@ class EntityNode(Node):
             'summary': self.summary,
             'created_at': self.created_at,
         }
-        entity_data.update(self.attributes or {})
-
-        if driver.provider == GraphProvider.NEPTUNE:
-            driver.save_to_aoss('node_name_and_summary', [entity_data])  # pyright: ignore reportAttributeAccessIssue
-
-        labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
-
-        result = await driver.execute_query(
-            get_entity_node_save_query(driver.provider, labels),
-            entity_data=entity_data,
-        )
+        if driver.provider == GraphProvider.KUZU:
+            entity_data['attributes'] = json.dumps(self.attributes)
+            entity_data['labels'] = list(set(self.labels + ['Entity']))
+            result = await driver.execute_query(
+                get_entity_node_save_query(driver.provider, labels=''),
+                **entity_data,
+            )
+        else:
+            entity_data.update(self.attributes or {})
+            labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
+            if driver.provider == GraphProvider.NEPTUNE:
+                driver.save_to_aoss('node_name_and_summary', [entity_data])  # pyright: ignore reportAttributeAccessIssue
+
+            result = await driver.execute_query(
+                get_entity_node_save_query(driver.provider, labels),
+                entity_data=entity_data,
+            )

         logger.debug(f'Saved Node to Graph: {self.uuid}')
@@ -399,12 +486,12 @@ class EntityNode(Node):
             MATCH (n:Entity {uuid: $uuid})
             RETURN
             """
-            + ENTITY_NODE_RETURN,
+            + get_entity_node_return_query(driver.provider),
             uuid=uuid,
             routing_='r',
         )

-        nodes = [get_entity_node_from_record(record) for record in records]
+        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

         if len(nodes) == 0:
             raise NodeNotFoundError(uuid)
@@ -419,12 +506,12 @@ class EntityNode(Node):
             WHERE n.uuid IN $uuids
             RETURN
             """
-            + ENTITY_NODE_RETURN,
+            + get_entity_node_return_query(driver.provider),
             uuids=uuids,
             routing_='r',
         )

-        nodes = [get_entity_node_from_record(record) for record in records]
+        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

         return nodes
@@ -456,7 +543,7 @@ class EntityNode(Node):
             + """
             RETURN
             """
-            + ENTITY_NODE_RETURN
+            + get_entity_node_return_query(driver.provider)
             + with_embeddings_query
             + """
             ORDER BY n.uuid DESC
@@ -468,7 +555,7 @@ class EntityNode(Node):
             routing_='r',
         )

-        nodes = [get_entity_node_from_record(record) for record in records]
+        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

         return nodes
@@ -533,7 +620,7 @@ class CommunityNode(Node):
     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
-            MATCH (n:Community {uuid: $uuid})
+            MATCH (c:Community {uuid: $uuid})
             RETURN
             """
             + (
@@ -556,8 +643,8 @@ class CommunityNode(Node):
     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
-            MATCH (n:Community)
-            WHERE n.uuid IN $uuids
+            MATCH (c:Community)
+            WHERE c.uuid IN $uuids
             RETURN
             """
             + (
|
@ -581,13 +668,13 @@ class CommunityNode(Node):
|
|||
limit: int | None = None,
|
||||
uuid_cursor: str | None = None,
|
||||
):
|
||||
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
||||
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
WHERE n.group_id IN $group_ids
|
||||
MATCH (c:Community)
|
||||
WHERE c.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
|
|
@@ -599,7 +686,7 @@ class CommunityNode(Node):
                 else COMMUNITY_NODE_RETURN
             )
             + """
-            ORDER BY n.uuid DESC
+            ORDER BY c.uuid DESC
             """
             + limit_query,
             group_ids=group_ids,
@@ -636,7 +723,19 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
     )


-def get_entity_node_from_record(record: Any) -> EntityNode:
+def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
+    if provider == GraphProvider.KUZU:
+        attributes = json.loads(record['attributes']) if record['attributes'] else {}
+    else:
+        attributes = record['attributes']
+        attributes.pop('uuid', None)
+        attributes.pop('name', None)
+        attributes.pop('group_id', None)
+        attributes.pop('name_embedding', None)
+        attributes.pop('summary', None)
+        attributes.pop('created_at', None)
+        attributes.pop('labels', None)
+
     entity_node = EntityNode(
         uuid=record['uuid'],
         name=record['name'],
@@ -645,16 +744,9 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
         labels=record['labels'],
         created_at=parse_db_date(record['created_at']),  # type: ignore
         summary=record['summary'],
-        attributes=record['attributes'],
+        attributes=attributes,
     )

-    entity_node.attributes.pop('uuid', None)
-    entity_node.attributes.pop('name', None)
-    entity_node.attributes.pop('group_id', None)
-    entity_node.attributes.pop('name_embedding', None)
-    entity_node.attributes.pop('summary', None)
-    entity_node.attributes.pop('created_at', None)
-
     return entity_node
@@ -20,6 +20,8 @@ from typing import Any

 from pydantic import BaseModel, Field

+from graphiti_core.driver.driver import GraphProvider
+

 class ComparisonOperator(Enum):
     equals = '='
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):

 def node_search_filter_query_constructor(
     filters: SearchFilters,
-) -> tuple[str, dict[str, Any]]:
-    filter_query: str = ''
+    provider: GraphProvider,
+) -> tuple[list[str], dict[str, Any]]:
+    filter_queries: list[str] = []
     filter_params: dict[str, Any] = {}

     if filters.node_labels is not None:
-        node_labels = '|'.join(filters.node_labels)
-        node_label_filter = ' AND n:' + node_labels
-        filter_query += node_label_filter
+        if provider == GraphProvider.KUZU:
+            node_label_filter = 'list_has_all(n.labels, $labels)'
+            filter_params['labels'] = filters.node_labels
+        else:
+            node_labels = '|'.join(filters.node_labels)
+            node_label_filter = 'n:' + node_labels
+        filter_queries.append(node_label_filter)

-    return filter_query, filter_params
+    return filter_queries, filter_params


 def date_filter_query_constructor(
@@ -81,23 +88,29 @@ def date_filter_query_constructor(

 def edge_search_filter_query_constructor(
     filters: SearchFilters,
-) -> tuple[str, dict[str, Any]]:
-    filter_query: str = ''
+    provider: GraphProvider,
+) -> tuple[list[str], dict[str, Any]]:
+    filter_queries: list[str] = []
     filter_params: dict[str, Any] = {}

     if filters.edge_types is not None:
         edge_types = filters.edge_types
-        edge_types_filter = '\nAND e.name in $edge_types'
-        filter_query += edge_types_filter
+        filter_queries.append('e.name in $edge_types')
         filter_params['edge_types'] = edge_types

     if filters.node_labels is not None:
-        node_labels = '|'.join(filters.node_labels)
-        node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
-        filter_query += node_label_filter
+        if provider == GraphProvider.KUZU:
+            node_label_filter = (
+                'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
+            )
+            filter_params['labels'] = filters.node_labels
+        else:
+            node_labels = '|'.join(filters.node_labels)
+            node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
+        filter_queries.append(node_label_filter)

     if filters.valid_at is not None:
-        valid_at_filter = '\nAND ('
+        valid_at_filter = '('
         for i, or_list in enumerate(filters.valid_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
             else:
                 valid_at_filter += ' OR '

-        filter_query += valid_at_filter
+        filter_queries.append(valid_at_filter)

     if filters.invalid_at is not None:
-        invalid_at_filter = ' AND ('
+        invalid_at_filter = '('
         for i, or_list in enumerate(filters.invalid_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
             else:
                 invalid_at_filter += ' OR '

-        filter_query += invalid_at_filter
+        filter_queries.append(invalid_at_filter)

     if filters.created_at is not None:
-        created_at_filter = ' AND ('
+        created_at_filter = '('
     for i, or_list in enumerate(filters.created_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
             else:
                 created_at_filter += ' OR '

-        filter_query += created_at_filter
+        filter_queries.append(created_at_filter)

     if filters.expired_at is not None:
-        expired_at_filter = ' AND ('
+        expired_at_filter = '('
         for i, or_list in enumerate(filters.expired_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
             else:
                 expired_at_filter += ' OR '

-        filter_query += expired_at_filter
+        filter_queries.append(expired_at_filter)

-    return filter_query, filter_params
+    return filter_queries, filter_params

File diff suppressed because it is too large
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 import typing
 from datetime import datetime
@@ -22,20 +23,21 @@ import numpy as np
 from pydantic import BaseModel, Field
 from typing_extensions import Any

-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import normalize_l2, semaphore_gather
 from graphiti_core.models.edges.edge_db_queries import (
-    EPISODIC_EDGE_SAVE_BULK,
     get_entity_edge_save_bulk_query,
+    get_episodic_edge_save_bulk_query,
 )
 from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -116,11 +118,15 @@ async def add_nodes_and_edges_bulk_tx(
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
+        episode.pop('labels', None)
+        episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')

-    nodes: list[dict[str, Any]] = []
+    nodes = []

     for node in entity_nodes:
         if node.name_embedding is None:
             await node.generate_name_embedding(embedder)

         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
@@ -130,13 +136,19 @@ async def add_nodes_and_edges_bulk_tx(
             'created_at': node.created_at,
         }

-        entity_data.update(node.attributes or {})
-        entity_data['labels'] = list(
-            set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
-        )
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
+            entity_data['attributes'] = json.dumps(attributes)
+            entity_data['labels'] = list(set(node.labels + ['Entity']))
+        else:
+            entity_data.update(node.attributes or {})
+            entity_data['labels'] = list(
+                set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
+            )

         nodes.append(entity_data)

-    edges: list[dict[str, Any]] = []
+    edges = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
@@ -155,17 +167,36 @@ async def add_nodes_and_edges_bulk_tx(
            'invalid_at': edge.invalid_at,
        }

        edge_data.update(edge.attributes or {})
        if driver.provider == GraphProvider.KUZU:
            attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
            edge_data['attributes'] = json.dumps(attributes)
        else:
            edge_data.update(edge.attributes or {})

        edges.append(edge_data)

    await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
    entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
    await tx.run(entity_node_save_bulk, nodes=nodes)
    await tx.run(
        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
    )
    entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
    await tx.run(entity_edge_save_bulk, entity_edges=edges)
    if driver.provider == GraphProvider.KUZU:
        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
        episode_query = get_episode_node_save_bulk_query(driver.provider)
        for episode in episodes:
            await tx.run(episode_query, **episode)
        entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
        for node in nodes:
            await tx.run(entity_node_query, **node)
        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
        for edge in edges:
            await tx.run(entity_edge_query, **edge)
        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
        for edge in episodic_edges:
            await tx.run(episodic_edge_query, **edge.model_dump())
    else:
        await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
        await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
        await tx.run(
            get_episodic_edge_save_bulk_query(driver.provider),
            episodic_edges=[edge.model_dump() for edge in episodic_edges],
        )
        await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)


async def extract_nodes_and_edges_bulk(
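The Kuzu branch above trades one batched `UNWIND` statement for a loop of single-row executions. A pure-Python sketch of that control flow (`FakeTx`, `save_nodes`, and both queries are invented stand-ins for illustration; only the per-row-versus-batch shape mirrors the diff):

```python
import asyncio

# Hypothetical stand-in for a driver transaction; it just records calls.
class FakeTx:
    def __init__(self):
        self.calls = []

    async def run(self, query, **params):
        self.calls.append((query, params))

async def save_nodes(tx, provider, nodes):
    if provider == 'kuzu':
        # Kuzu's UNWIND cannot take STRUCT[] parameters yet, so each row
        # becomes its own query execution.
        for node in nodes:
            await tx.run('MERGE (n:Entity {uuid: $uuid}) SET n.name = $name', **node)
    else:
        # Other providers accept the whole batch through a single UNWIND.
        await tx.run('UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid})', nodes=nodes)

tx = FakeTx()
asyncio.run(save_nodes(tx, 'kuzu', [{'uuid': 'u1', 'name': 'Alice'}, {'uuid': 'u2', 'name': 'Bob'}]))
print(len(tx.calls))  # 2: one execution per node
```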
@@ -40,3 +40,16 @@ def ensure_utc(dt: datetime | None) -> datetime | None:
        return dt.astimezone(timezone.utc)

    return dt


def convert_datetimes_to_strings(obj):
    if isinstance(obj, dict):
        return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_datetimes_to_strings(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(item) for item in obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()
    else:
        return obj
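The helper above is self-contained, so its behavior is easy to see in isolation: it walks nested containers and turns every `datetime` into an ISO-8601 string so the attributes can be JSON-serialized for Kuzu (the sample values here are illustrative):

```python
from datetime import datetime

def convert_datetimes_to_strings(obj):
    # Recursively walk dicts/lists/tuples, converting datetime values
    # to ISO-8601 strings and leaving everything else untouched.
    if isinstance(obj, dict):
        return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_datetimes_to_strings(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(item) for item in obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()
    else:
        return obj

attrs = {'seen_at': datetime(2024, 1, 2, 3, 4, 5), 'tags': ['a', 'b']}
print(convert_datetimes_to_strings(attrs))
# {'seen_at': '2024-01-02T03:04:05', 'tags': ['a', 'b']}
```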
@@ -4,11 +4,12 @@ from collections import defaultdict

from pydantic import BaseModel

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
@@ -33,11 +34,11 @@ async def get_community_clusters(
    if group_ids is None:
        group_id_values, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.group_id IS NOT NULL
            RETURN
                collect(DISTINCT n.group_id) AS group_ids
            """,
            MATCH (n:Entity)
            WHERE n.group_id IS NOT NULL
            RETURN
                collect(DISTINCT n.group_id) AS group_ids
            """
        )

        group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@@ -46,14 +47,21 @@ async def get_community_clusters(
        projection: dict[str, list[Neighbor]] = {}
        nodes = await EntityNode.get_by_group_ids(driver, [group_id])
        for node in nodes:
            records, _, _ = await driver.execute_query(
            match_query = """
                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
            """
            if driver.provider == GraphProvider.KUZU:
                match_query = """
                    MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
                """
                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
                WITH count(r) AS count, m.uuid AS uuid
                RETURN
                    uuid,
                    count
                """,
            records, _, _ = await driver.execute_query(
                match_query
                + """
                WITH count(e) AS count, m.uuid AS uuid
                RETURN
                    uuid,
                    count
                """,
                uuid=node.uuid,
                group_id=group_id,
            )
@@ -235,9 +243,9 @@ async def build_communities(
async def remove_communities(driver: GraphDriver):
    await driver.execute_query(
        """
        MATCH (c:Community)
        DETACH DELETE c
        """,
        MATCH (c:Community)
        DETACH DELETE c
        """
    )

@@ -247,14 +255,10 @@ async def determine_entity_community(
    # Check if the node is already part of a community
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
        RETURN
            c.uuid AS uuid,
            c.name AS name,
            c.group_id AS group_id,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
        RETURN
        """
        + COMMUNITY_NODE_RETURN,
        entity_uuid=entity.uuid,
    )

@@ -262,16 +266,19 @@ async def determine_entity_community(
        return get_community_node_from_record(records[0]), False

    # If the node has no community, add it to the mode community of surrounding entities
    records, _, _ = await driver.execute_query(
    match_query = """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
    """
    if driver.provider == GraphProvider.KUZU:
        match_query = """
            MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
        RETURN
            c.uuid AS uuid,
            c.name AS name,
            c.group_id AS group_id,
            c.created_at AS created_at,
            c.summary AS summary
        """,
    records, _, _ = await driver.execute_query(
        match_query
        + """
        RETURN
        """
        + COMMUNITY_NODE_RETURN,
        entity_uuid=entity.uuid,
    )

@@ -531,17 +531,28 @@ async def filter_existing_duplicate_of_edges(
            routing_='r',
        )
    else:
        query: LiteralString = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        if driver.provider == GraphProvider.KUZU:
            query = """
                UNWIND $duplicate_node_uuids AS duplicate
                MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
                RETURN DISTINCT
                    n.uuid AS source_uuid,
                    m.uuid AS target_uuid
            """
            duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
        else:
            query: LiteralString = """
                UNWIND $duplicate_node_uuids AS duplicate_tuple
                MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
                RETURN DISTINCT
                    n.uuid AS source_uuid,
                    m.uuid AS target_uuid
            """
            duplicate_node_uuids = list(duplicate_nodes_map.keys())

        records, _, _ = await driver.execute_query(
            query,
            duplicate_node_uuids=list(duplicate_nodes_map.keys()),
            duplicate_node_uuids=duplicate_node_uuids,
            routing_='r',
        )

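The hunk above reshapes the same duplicate pairs into two different parameter forms: Kuzu binds them as structs with named `src`/`dst` fields, while the other providers take the raw `(source, target)` tuples and index into them. A small sketch of that reshaping (the map contents are illustrative; in the real code the values are edge objects, not `None`):

```python
# Keys are (source_uuid, target_uuid) pairs, as in filter_existing_duplicate_of_edges.
duplicate_nodes_map = {('uuid-a', 'uuid-b'): None, ('uuid-c', 'uuid-d'): None}

# Kuzu variant: each pair becomes a struct with named fields.
kuzu_params = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]

# Other providers: the pairs are passed through as tuples and the query
# addresses them positionally (duplicate_tuple[0] / duplicate_tuple[1]).
other_params = list(duplicate_nodes_map.keys())

print(kuzu_params[0])   # {'src': 'uuid-a', 'dst': 'uuid-b'}
print(other_params[0])  # ('uuid-a', 'uuid-b')
```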
@@ -53,10 +53,29 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
            for name in index_names
        ]
    )

    range_indices: list[LiteralString] = get_range_indices(driver.provider)

    fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)

    if driver.provider == GraphProvider.KUZU:
        # Skip creating fulltext indices if they already exist. Need to do this manually
        # until Kuzu supports `IF NOT EXISTS` for indices.
        result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
        if len(result) > 0:
            fulltext_indices = []

        # Only load the `fts` extension if it's not already loaded, otherwise throw an error.
        result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
        if len(result) == 0:
            fulltext_indices.insert(
                0,
                """
                INSTALL fts;
                LOAD fts;
                """,
            )

    index_queries: list[LiteralString] = range_indices + fulltext_indices

    await semaphore_gather(
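The Kuzu branch above has to be idempotent by hand: there is no `IF NOT EXISTS` for its fulltext indices, and loading the `fts` extension twice raises an error. The guard logic can be condensed into a pure function (`plan_fulltext_queries` and its inputs are a hypothetical restatement for illustration, not an API in the diff):

```python
def plan_fulltext_queries(existing_indexes, loaded_extensions, fulltext_indices):
    # Drop the fulltext DDL when indexes already exist (no IF NOT EXISTS in Kuzu),
    # and prepend INSTALL/LOAD only when no extension is loaded yet.
    queries = list(fulltext_indices)
    if existing_indexes:
        queries = []
    if not loaded_extensions:
        queries.insert(0, 'INSTALL fts; LOAD fts;')
    return queries

# Fresh database: extension must be loaded first, then the index DDL runs.
print(plan_fulltext_queries([], [], ['CALL CREATE_FTS_INDEX(...)']))
# ['INSTALL fts; LOAD fts;', 'CALL CREATE_FTS_INDEX(...)']
```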
@@ -76,10 +95,19 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
        await tx.run('MATCH (n) DETACH DELETE n')

    async def delete_group_ids(tx):
        await tx.run(
            'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
            group_ids=group_ids,
        )
        labels = ['Entity', 'Episodic', 'Community']
        if driver.provider == GraphProvider.KUZU:
            labels.append('RelatesToNode_')

        for label in labels:
            await tx.run(
                f"""
                MATCH (n:{label})
                WHERE n.group_id IN $group_ids
                DETACH DELETE n
                """,
                group_ids=group_ids,
            )

    if group_ids is None:
        await session.execute_write(delete_all)
@@ -108,18 +136,23 @@ async def retrieve_episodes(
    Returns:
        list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
    """
    group_id_filter: LiteralString = (
        '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
    )
    source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''

    query_params: dict = {}
    query_filter = ''
    if group_ids and len(group_ids) > 0:
        query_filter += '\nAND e.group_id IN $group_ids'
        query_params['group_ids'] = group_ids

    if source is not None:
        query_filter += '\nAND e.source = $source'
        query_params['source'] = source.name

    query: LiteralString = (
        """
        MATCH (e:Episodic)
        WHERE e.valid_at <= $reference_time
        """
        + group_id_filter
        + source_filter
        MATCH (e:Episodic)
        WHERE e.valid_at <= $reference_time
        """
        + query_filter
        + """
        RETURN
        """
@@ -136,9 +169,8 @@ async def retrieve_episodes(
    result, _, _ = await driver.execute_query(
        query,
        reference_time=reference_time,
        source=source.name if source is not None else None,
        num_episodes=last_n,
        group_ids=group_ids,
        **query_params,
    )

    episodes = [get_episodic_node_from_record(record) for record in result]
@@ -29,6 +29,7 @@ Repository = "https://github.com/getzep/graphiti"
anthropic = ["anthropic>=0.49.0"]
groq = ["groq>=0.2.0"]
google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"]

@@ -39,6 +40,7 @@ dev = [
    "anthropic>=0.49.0",
    "google-genai>=1.8.0",
    "falkordb>=1.1.2,<2.0.0",
    "kuzu>=0.11.2",
    "ipykernel>=6.29.5",
    "jupyterlab>=4.2.4",
    "diskcache-stubs>=5.6.3.6.20240818",

@@ -91,7 +93,3 @@ docstring-code-format = true
include = ["graphiti_core"]
pythonVersion = "3.10"
typeCheckingMode = "basic"

[[tool.pyright.overrides]]
include = ["**/falkordb*"]
reportMissingImports = false
@@ -1,4 +1,5 @@
[pytest]
markers =
    integration: marks tests as integration tests
asyncio_default_fixture_loop_scope = function
asyncio_default_fixture_loop_scope = function
asyncio_mode = auto
@@ -15,42 +15,55 @@ limitations under the License.
"""

import os
from unittest.mock import Mock

import numpy as np
import pytest
from dotenv import load_dotenv

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neptune_driver import NeptuneDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.helpers import lucene_sanitize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

load_dotenv()

HAS_NEO4J = False
HAS_FALKORDB = False
HAS_NEPTUNE = False
drivers: list[GraphProvider] = []
if os.getenv('DISABLE_NEO4J') is None:
    try:
        from graphiti_core.driver.neo4j_driver import Neo4jDriver

        HAS_NEO4J = True
        drivers.append(GraphProvider.NEO4J)
    except ImportError:
        pass
        raise

if os.getenv('DISABLE_FALKORDB') is None:
    try:
        from graphiti_core.driver.falkordb_driver import FalkorDriver

        HAS_FALKORDB = True
        drivers.append(GraphProvider.FALKORDB)
    except ImportError:
        pass
        raise

if os.getenv('DISABLE_KUZU') is None:
    try:
        from graphiti_core.driver.kuzu_driver import KuzuDriver

        drivers.append(GraphProvider.KUZU)
    except ImportError:
        raise

# Disable Neptune for now
os.environ['DISABLE_NEPTUNE'] = 'True'
if os.getenv('DISABLE_NEPTUNE') is None:
    try:
        from graphiti_core.driver.neptune_driver import NeptuneDriver

        HAS_NEPTUNE = False
        drivers.append(GraphProvider.NEPTUNE)
    except ImportError:
        pass
        raise

NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
@@ -65,38 +78,100 @@ NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)

KUZU_DB = os.getenv('KUZU_DB', ':memory:')

def get_driver(driver_name: str) -> GraphDriver:
    if driver_name == 'neo4j':
group_id = 'graphiti_test_group'
group_id_2 = 'graphiti_test_group_2'


def get_driver(provider: GraphProvider) -> GraphDriver:
    if provider == GraphProvider.NEO4J:
        return Neo4jDriver(
            uri=NEO4J_URI,
            user=NEO4J_USER,
            password=NEO4J_PASSWORD,
        )
    elif driver_name == 'falkordb':
    elif provider == GraphProvider.FALKORDB:
        return FalkorDriver(
            host=FALKORDB_HOST,
            port=int(FALKORDB_PORT),
            username=FALKORDB_USER,
            password=FALKORDB_PASSWORD,
        )
    elif driver_name == 'neptune':
    elif provider == GraphProvider.KUZU:
        driver = KuzuDriver(
            db=KUZU_DB,
        )
        return driver
    elif provider == GraphProvider.NEPTUNE:
        return NeptuneDriver(
            host=NEPTUNE_HOST,
            port=int(NEPTUNE_PORT),
            aoss_host=AOSS_HOST,
        )
    else:
        raise ValueError(f'Driver {driver_name} not available')
        raise ValueError(f'Driver {provider} not available')


drivers: list[str] = []
if HAS_NEO4J:
    drivers.append('neo4j')
if HAS_FALKORDB:
    drivers.append('falkordb')
if HAS_NEPTUNE:
    drivers.append('neptune')
@pytest.fixture(params=drivers)
async def graph_driver(request):
    driver = request.param
    graph_driver = get_driver(driver)
    await clear_data(graph_driver, [group_id, group_id_2])
    try:
        yield graph_driver  # provide driver to the test
    finally:
        # always called, even if the test fails or raises
        # await clean_up(graph_driver)
        await graph_driver.close()


embedding_dim = 384
embeddings = {
    key: np.random.uniform(0.0, 0.9, embedding_dim).tolist()
    for key in [
        'Alice',
        'Bob',
        'Alice likes Bob',
        'test_entity_1',
        'test_entity_2',
        'test_entity_3',
        'test_entity_4',
        'test_entity_alice',
        'test_entity_bob',
        'test_entity_1 is a duplicate of test_entity_2',
        'test_entity_3 is a duplicate of test_entity_4',
        'test_entity_1 relates to test_entity_2',
        'test_entity_1 relates to test_entity_3',
        'test_entity_2 relates to test_entity_3',
        'test_entity_1 relates to test_entity_4',
        'test_entity_2 relates to test_entity_4',
        'test_entity_3 relates to test_entity_4',
        'test_entity_1 relates to test_entity_2',
        'test_entity_3 relates to test_entity_4',
        'test_entity_2 relates to test_entity_3',
        'test_community_1',
        'test_community_2',
    ]
}
embeddings['Alice Smith'] = embeddings['Alice']


@pytest.fixture
def mock_embedder():
    mock_model = Mock(spec=EmbedderClient)

    def mock_embed(input_data):
        if isinstance(input_data, str):
            return embeddings[input_data]
        elif isinstance(input_data, list):
            combined_input = ' '.join(input_data)
            return embeddings[combined_input]
        else:
            raise ValueError(f'Unsupported input type: {type(input_data)}')

    mock_model.create.side_effect = mock_embed
    return mock_model


def test_lucene_sanitize():
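The `mock_embedder` fixture above replaces real embedding calls with a lookup into a fixed table keyed by the input text, joining list inputs with spaces first. A minimal re-creation of that lookup (embedding values shortened to 3 dimensions for illustration):

```python
# Fixed embedding table, as in the fixture (values here are illustrative).
embeddings = {
    'Alice': [0.1, 0.2, 0.3],
    'Alice likes Bob': [0.4, 0.5, 0.6],
}

def mock_embed(input_data):
    # Strings are looked up directly; lists are joined with spaces first,
    # so ['Alice', 'likes', 'Bob'] hits the 'Alice likes Bob' entry.
    if isinstance(input_data, str):
        return embeddings[input_data]
    elif isinstance(input_data, list):
        return embeddings[' '.join(input_data)]
    raise ValueError(f'Unsupported input type: {type(input_data)}')

print(mock_embed(['Alice', 'likes', 'Bob']))  # [0.4, 0.5, 0.6]
```

This keeps the integration tests deterministic: the same text always maps to the same vector, so `np.allclose` comparisons against saved embeddings are stable across runs.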
@@ -114,5 +189,125 @@ def test_lucene_sanitize():
    assert assert_result == result


async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
    results, _, _ = await driver.execute_query(
        """
        MATCH (n)
        WHERE n.uuid IN $uuids
        RETURN COUNT(n) as count
        """,
        uuids=uuids,
    )
    return int(results[0]['count'])


async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
    results, _, _ = await driver.execute_query(
        """
        MATCH (n)-[e]->(m)
        WHERE e.uuid IN $uuids
        RETURN COUNT(e) as count
        UNION ALL
        MATCH (e:RelatesToNode_)
        WHERE e.uuid IN $uuids
        RETURN COUNT(e) as count
        """,
        uuids=uuids,
    )
    return sum(int(result['count']) for result in results)


async def print_graph(graph_driver: GraphDriver):
    nodes, _, _ = await graph_driver.execute_query(
        """
        MATCH (n)
        RETURN n.uuid, n.name
        """,
    )
    print('Nodes:')
    for node in nodes:
        print('  ', node)
    edges, _, _ = await graph_driver.execute_query(
        """
        MATCH (n)-[e]->(m)
        RETURN n.name, e.uuid, m.name
        """,
    )
    print('Edges:')
    for edge in edges:
        print('  ', edge)


async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
    assert retrieved.uuid == sample.uuid
    assert retrieved.name == sample.name
    assert retrieved.group_id == group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.source == sample.source
    assert retrieved.source_description == sample.source_description
    assert retrieved.content == sample.content
    assert retrieved.valid_at == sample.valid_at
    assert set(retrieved.entity_edges) == set(sample.entity_edges)


async def assert_entity_node_equals(
    graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
):
    await retrieved.load_name_embedding(graph_driver)
    assert retrieved.uuid == sample.uuid
    assert retrieved.name == sample.name
    assert retrieved.group_id == sample.group_id
    assert set(retrieved.labels) == set(sample.labels)
    assert retrieved.created_at == sample.created_at
    assert retrieved.name_embedding is not None
    assert sample.name_embedding is not None
    assert np.allclose(retrieved.name_embedding, sample.name_embedding)
    assert retrieved.summary == sample.summary
    assert retrieved.attributes == sample.attributes


async def assert_community_node_equals(
    graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
):
    await retrieved.load_name_embedding(graph_driver)
    assert retrieved.uuid == sample.uuid
    assert retrieved.name == sample.name
    assert retrieved.group_id == group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.name_embedding is not None
    assert sample.name_embedding is not None
    assert np.allclose(retrieved.name_embedding, sample.name_embedding)
    assert retrieved.summary == sample.summary


async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
    assert retrieved.uuid == sample.uuid
    assert retrieved.group_id == sample.group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.source_node_uuid == sample.source_node_uuid
    assert retrieved.target_node_uuid == sample.target_node_uuid


async def assert_entity_edge_equals(
    graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
):
    await retrieved.load_fact_embedding(graph_driver)
    assert retrieved.uuid == sample.uuid
    assert retrieved.group_id == sample.group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.source_node_uuid == sample.source_node_uuid
    assert retrieved.target_node_uuid == sample.target_node_uuid
    assert retrieved.name == sample.name
    assert retrieved.fact == sample.fact
    assert retrieved.fact_embedding is not None
    assert sample.fact_embedding is not None
    assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
    assert retrieved.episodes == sample.episodes
    assert retrieved.expired_at == sample.expired_at
    assert retrieved.valid_at == sample.valid_at
    assert retrieved.invalid_at == sample.invalid_at
    assert retrieved.attributes == sample.attributes


if __name__ == '__main__':
    pytest.main([__file__])
@@ -17,23 +17,16 @@ limitations under the License.
import logging
import sys
from datetime import datetime
from uuid import uuid4

import numpy as np
import pytest

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import drivers, get_driver

pytestmark = pytest.mark.integration
from tests.helpers_test import get_edge_count, get_node_count, group_id

pytest_plugins = ('pytest_asyncio',)

group_id = f'test_group_{str(uuid4())}'


def setup_logging():
    # Create a logger
@ -57,17 +50,10 @@ def setup_logging():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_episodic_edge(driver):
|
||||
graph_driver = get_driver(driver)
|
||||
embedder = OpenAIEmbedder()
|
||||
|
||||
async def test_episodic_edge(graph_driver, mock_embedder):
|
||||
now = datetime.now()
|
||||
|
||||
# Create episodic node
|
||||
episode_node = EpisodicNode(
|
||||
name='test_episode',
|
||||
labels=[],
|
||||
|
|
@ -79,13 +65,13 @@ async def test_episodic_edge(driver):
|
|||
entity_edges=[],
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||
node_count = await get_node_count(graph_driver, [episode_node.uuid])
|
||||
assert node_count == 0
|
||||
await episode_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||
node_count = await get_node_count(graph_driver, [episode_node.uuid])
|
||||
assert node_count == 1
|
||||
|
||||
# Create entity node
|
||||
alice_node = EntityNode(
|
||||
name='Alice',
|
||||
labels=[],
|
||||
|
|
@ -93,27 +79,27 @@ async def test_episodic_edge(driver):
|
|||
summary='Alice summary',
|
||||
group_id=group_id,
|
||||
)
|
||||
await alice_node.generate_name_embedding(embedder)
|
||||
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
await alice_node.generate_name_embedding(mock_embedder)
|
||||
node_count = await get_node_count(graph_driver, [alice_node.uuid])
|
||||
assert node_count == 0
|
||||
await alice_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
node_count = await get_node_count(graph_driver, [alice_node.uuid])
|
||||
assert node_count == 1
|
||||
|
||||
# Create episodic to entity edge
|
||||
episodic_edge = EpisodicEdge(
|
||||
source_node_uuid=episode_node.uuid,
|
||||
target_node_uuid=alice_node.uuid,
|
||||
created_at=now,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
|
||||
assert edge_count == 0
|
||||
await episodic_edge.save(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
|
||||
assert edge_count == 1
|
||||
|
||||
# Get edge by uuid
|
||||
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
|
||||
assert retrieved.uuid == episodic_edge.uuid
|
||||
assert retrieved.source_node_uuid == episode_node.uuid
|
||||
|
|
@ -121,6 +107,7 @@ async def test_episodic_edge(driver):
|
|||
assert retrieved.created_at == now
|
||||
assert retrieved.group_id == group_id
|
||||
|
||||
# Get edge by uuids
|
||||
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == episodic_edge.uuid
|
||||
|
|
@ -129,6 +116,7 @@ async def test_episodic_edge(driver):
|
|||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
# Get edge by group ids
|
||||
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == episodic_edge.uuid
|
||||
|
|
@ -137,33 +125,41 @@ async def test_episodic_edge(driver):
|
|||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
# Get episodic node by entity node uuid
|
||||
retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == episode_node.uuid
|
||||
assert retrieved[0].name == 'test_episode'
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
# Delete edge by uuid
|
||||
await episodic_edge.delete(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
|
||||
assert edge_count == 0
|
||||
|
||||
await episode_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||
assert node_count == 0
|
||||
# Delete edge by uuids
|
||||
await episodic_edge.save(graph_driver)
|
||||
await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
|
||||
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
|
||||
assert edge_count == 0
|
||||
|
||||
# Cleanup nodes
|
||||
await episode_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, [episode_node.uuid])
|
||||
assert node_count == 0
|
||||
await alice_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
node_count = await get_node_count(graph_driver, [alice_node.uuid])
|
||||
assert node_count == 0
|
||||
|
||||
await graph_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_entity_edge(driver):
|
||||
graph_driver = get_driver(driver)
|
||||
embedder = OpenAIEmbedder()
|
||||
|
||||
async def test_entity_edge(graph_driver, mock_embedder):
|
||||
now = datetime.now()
|
||||
|
||||
    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
@@ -171,25 +167,25 @@ async def test_entity_edge(driver):
        summary='Alice summary',
        group_id=group_id,
    )
-    await alice_node.generate_name_embedding(embedder)
-    node_count = await get_node_count(graph_driver, alice_node.uuid)
+    await alice_node.generate_name_embedding(mock_embedder)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
-    node_count = await get_node_count(graph_driver, alice_node.uuid)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create entity node
    bob_node = EntityNode(
        name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
    )
-    await bob_node.generate_name_embedding(embedder)
-    node_count = await get_node_count(graph_driver, bob_node.uuid)
+    await bob_node.generate_name_embedding(mock_embedder)
+    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0
    await bob_node.save(graph_driver)
-    node_count = await get_node_count(graph_driver, bob_node.uuid)
+    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 1

    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
@@ -202,14 +198,14 @@ async def test_entity_edge(driver):
        invalid_at=now,
        group_id=group_id,
    )
-    edge_embedding = await entity_edge.generate_embedding(embedder)
-    edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
+    edge_embedding = await entity_edge.generate_embedding(mock_embedder)
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0
    await entity_edge.save(graph_driver)
-    edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
    assert retrieved.uuid == entity_edge.uuid
    assert retrieved.source_node_uuid == alice_node.uuid
@@ -217,6 +213,7 @@ async def test_entity_edge(driver):
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
@@ -225,6 +222,7 @@ async def test_entity_edge(driver):
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
@@ -233,6 +231,7 @@ async def test_entity_edge(driver):
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by node uuid
    retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
@@ -241,82 +240,113 @@ async def test_entity_edge(driver):
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get fact embedding
    await entity_edge.load_fact_embedding(graph_driver)
    assert np.allclose(entity_edge.fact_embedding, edge_embedding)

    # Delete edge by uuid
    await entity_edge.delete(graph_driver)
-    edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

-    await alice_node.delete(graph_driver)
-    node_count = await get_node_count(graph_driver, alice_node.uuid)
-    assert node_count == 0
+    # Delete edge by uuids
+    await entity_edge.save(graph_driver)
+    await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+    assert edge_count == 0
+
+    # Deleting node should delete the edge
+    await entity_edge.save(graph_driver)
+    await alice_node.delete(graph_driver)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
+    assert node_count == 0
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+    assert edge_count == 0
+
+    # Deleting node by uuids should delete the edge
+    await alice_node.save(graph_driver)
+    await entity_edge.save(graph_driver)
+    await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
+    assert node_count == 0
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+    assert edge_count == 0
+
+    # Deleting node by group id should delete the edge
+    await alice_node.save(graph_driver)
+    await entity_edge.save(graph_driver)
+    await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
+    assert node_count == 0
+    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+    assert edge_count == 0
+
+    # Cleanup nodes
+    await alice_node.delete(graph_driver)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
+    assert node_count == 0
    await bob_node.delete(graph_driver)
-    node_count = await get_node_count(graph_driver, bob_node.uuid)
+    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0

    await graph_driver.close()
@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_community_edge(driver):
-    graph_driver = get_driver(driver)
-    embedder = OpenAIEmbedder()
+async def test_community_edge(graph_driver, mock_embedder):
    now = datetime.now()

    # Create community node
    community_node_1 = CommunityNode(
-        name='Community A',
+        name='test_community_1',
        group_id=group_id,
        summary='Community A summary',
    )
-    await community_node_1.generate_name_embedding(embedder)
-    node_count = await get_node_count(graph_driver, community_node_1.uuid)
+    await community_node_1.generate_name_embedding(mock_embedder)
+    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_1.save(graph_driver)
-    node_count = await get_node_count(graph_driver, community_node_1.uuid)
+    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 1

    # Create community node
    community_node_2 = CommunityNode(
-        name='Community B',
+        name='test_community_2',
        group_id=group_id,
        summary='Community B summary',
    )
-    await community_node_2.generate_name_embedding(embedder)
-    node_count = await get_node_count(graph_driver, community_node_2.uuid)
+    await community_node_2.generate_name_embedding(mock_embedder)
+    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0
    await community_node_2.save(graph_driver)
-    node_count = await get_node_count(graph_driver, community_node_2.uuid)
+    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
    )
-    await alice_node.generate_name_embedding(embedder)
-    node_count = await get_node_count(graph_driver, alice_node.uuid)
+    await alice_node.generate_name_embedding(mock_embedder)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
-    node_count = await get_node_count(graph_driver, alice_node.uuid)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create community to community edge
    community_edge = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=community_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
-    edge_count = await get_edge_count(graph_driver, community_edge.uuid)
+    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0
    await community_edge.save(graph_driver)
-    edge_count = await get_edge_count(graph_driver, community_edge.uuid)
+    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
    assert retrieved.uuid == community_edge.uuid
    assert retrieved.source_node_uuid == community_node_1.uuid
@@ -324,6 +354,7 @@ async def test_community_edge(driver):
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
@@ -332,6 +363,7 @@ async def test_community_edge(driver):
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
@@ -340,45 +372,26 @@ async def test_community_edge(driver):
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await community_edge.delete(graph_driver)
-    edge_count = await get_edge_count(graph_driver, community_edge.uuid)
+    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

+    # Delete edge by uuids
+    await community_edge.save(graph_driver)
+    await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
+    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
+    assert edge_count == 0
+
    # Cleanup nodes
    await alice_node.delete(graph_driver)
-    node_count = await get_node_count(graph_driver, alice_node.uuid)
+    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0

    await community_node_1.delete(graph_driver)
-    node_count = await get_node_count(graph_driver, community_node_1.uuid)
+    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0

    await community_node_2.delete(graph_driver)
-    node_count = await get_node_count(graph_driver, community_node_2.uuid)
+    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0

    await graph_driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
    results, _, _ = await driver.execute_query(
        """
        MATCH (n {uuid: $uuid})
        RETURN COUNT(n) as count
        """,
        uuid=uuid,
    )
    return int(results[0]['count'])


async def get_edge_count(driver: GraphDriver, uuid: str):
    results, _, _ = await driver.execute_query(
        """
        MATCH (n)-[e {uuid: $uuid}]->(m)
        RETURN COUNT(e) as count
        UNION ALL
        MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
        RETURN COUNT(m) as count
        """,
        uuid=uuid,
    )
    return sum(int(result['count']) for result in results)
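The two-branch `UNION ALL` in `get_edge_count` above counts both plain relationships carrying the uuid and, apparently, edges that some drivers reify as an intermediate node sitting between two `RELATES_TO` relationships. A minimal sketch of how that aggregation behaves, using a hypothetical `FakeDriver` stand-in for the real `GraphDriver` (the real drivers live in `graphiti_core.driver`):

```python
import asyncio


# Hypothetical in-memory stand-in for GraphDriver, used only to illustrate
# how the helper sums the rows coming back from the two UNION ALL branches.
class FakeDriver:
    def __init__(self, rows):
        self._rows = rows

    async def execute_query(self, query, **params):
        # Mimics the (records, summary, keys) triple the real drivers return.
        return self._rows, None, None


async def get_edge_count(driver, uuid: str) -> int:
    results, _, _ = await driver.execute_query(
        """
        MATCH (n)-[e {uuid: $uuid}]->(m)
        RETURN COUNT(e) as count
        UNION ALL
        MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
        RETURN COUNT(m) as count
        """,
        uuid=uuid,
    )
    # One row per UNION ALL branch: plain relationships plus reified edge nodes.
    return sum(int(result['count']) for result in results)


# One branch found the edge as a relationship, the other found nothing.
count = asyncio.run(get_edge_count(FakeDriver([{'count': 1}, {'count': 0}]), 'edge-uuid'))
print(count)  # 1
```

Summing across branches means an edge is counted regardless of whether the active driver stores it as a relationship or as a node.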
@@ -60,7 +60,6 @@ class Location(BaseModel):
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_default_entity_type(driver):
    """Test excluding the default 'Entity' type while keeping custom types."""

@@ -118,7 +117,6 @@ async def test_exclude_default_entity_type(driver):
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_specific_custom_types(driver):
    """Test excluding specific custom entity types while keeping others."""

@@ -182,7 +180,6 @@ async def test_exclude_specific_custom_types(driver):
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_all_types(driver):
    """Test excluding all entity types (edge case)."""

@@ -231,7 +228,6 @@ async def test_exclude_all_types(driver):
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_no_types(driver):
    """Test normal behavior when no types are excluded (baseline test)."""

@@ -314,7 +310,6 @@ def test_validation_invalid_excluded_types():
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
    """Test that add_episode validates excluded_entity_types parameter."""
@@ -23,7 +23,7 @@ from graphiti_core.graphiti import Graphiti
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_helpers import search_results_to_context_string
from graphiti_core.utils.datetime_utils import utc_now
-from tests.helpers_test import drivers, get_driver
+from tests.helpers_test import GraphProvider

pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)

@@ -51,15 +51,12 @@ def setup_logging():

@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_graphiti_init(driver):
+async def test_graphiti_init(graph_driver):
+    if graph_driver.provider == GraphProvider.FALKORDB:
+        pytest.skip('Skipping as tests fail on Falkordb')

    logger = setup_logging()
-    driver = get_driver(driver)
-    graphiti = Graphiti(graph_driver=driver)
+    graphiti = Graphiti(graph_driver=graph_driver)

    await graphiti.build_indices_and_constraints()
2056 tests/test_graphiti_mock.py (new file)
File diff suppressed because it is too large
@@ -14,22 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

-from datetime import datetime
+from datetime import datetime, timedelta
from uuid import uuid4

-import numpy as np
import pytest

-from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import (
    CommunityNode,
    EntityNode,
    EpisodeType,
    EpisodicNode,
)
-from tests.helpers_test import drivers, get_driver
+from tests.helpers_test import (
+    assert_community_node_equals,
+    assert_entity_node_equals,
+    assert_episodic_node_equals,
+    get_node_count,
+    group_id,
+)

-group_id = f'test_group_{str(uuid4())}'
created_at = datetime.now()
+deleted_at = created_at + timedelta(days=3)
+valid_at = created_at + timedelta(days=1)
+invalid_at = created_at + timedelta(days=2)


@pytest.fixture
@@ -38,9 +45,14 @@ def sample_entity_node():
        uuid=str(uuid4()),
        name='Test Entity',
        group_id=group_id,
-        labels=[],
+        labels=['Entity', 'Person'],
+        created_at=created_at,
        name_embedding=[0.5] * 1024,
        summary='Entity Summary',
+        attributes={
+            'age': 30,
+            'location': 'New York',
+        },
    )

@@ -50,10 +62,12 @@ def sample_episodic_node():
        uuid=str(uuid4()),
        name='Episode 1',
        group_id=group_id,
+        created_at=created_at,
        source=EpisodeType.text,
        source_description='Test source',
        content='Some content here',
-        valid_at=datetime.now(),
+        valid_at=valid_at,
+        entity_edges=[],
    )

@@ -62,182 +76,152 @@ def sample_community_node():
    return CommunityNode(
        uuid=str(uuid4()),
        name='Community A',
-        name_embedding=[0.5] * 1024,
        group_id=group_id,
+        created_at=created_at,
+        name_embedding=[0.5] * 1024,
        summary='Community summary',
    )


@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_entity_node(sample_entity_node, driver):
-    driver = get_driver(driver)
+async def test_entity_node(sample_entity_node, graph_driver):
    uuid = sample_entity_node.uuid

    # Create node
-    node_count = await get_node_count(driver, uuid)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0
-    await sample_entity_node.save(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_entity_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 1

-    retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
-    assert retrieved.uuid == sample_entity_node.uuid
-    assert retrieved.name == 'Test Entity'
-    assert retrieved.group_id == group_id
+    # Get node by uuid
+    retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
+    await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)

-    retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
-    assert retrieved[0].uuid == sample_entity_node.uuid
-    assert retrieved[0].name == 'Test Entity'
-    assert retrieved[0].group_id == group_id
+    # Get node by uuids
+    retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
+    await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)

-    retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
+    # Get node by group ids
+    retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True)
    assert len(retrieved) == 1
-    assert retrieved[0].uuid == sample_entity_node.uuid
-    assert retrieved[0].name == 'Test Entity'
-    assert retrieved[0].group_id == group_id
-
-    await sample_entity_node.load_name_embedding(driver)
-    assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)
+    await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)

    # Delete node by uuid
-    await sample_entity_node.delete(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_entity_node.delete(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0

+    # Delete node by uuids
+    await sample_entity_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
+    assert node_count == 1
+    await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
+    node_count = await get_node_count(graph_driver, [uuid])
+    assert node_count == 0
+
    # Delete node by group id
-    await sample_entity_node.save(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_entity_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 1
-    await sample_entity_node.delete_by_group_id(driver, group_id)
-    node_count = await get_node_count(driver, uuid)
+    await sample_entity_node.delete_by_group_id(graph_driver, group_id)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0

-    await driver.close()
+    await graph_driver.close()


@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_community_node(sample_community_node, driver):
-    driver = get_driver(driver)
+async def test_community_node(sample_community_node, graph_driver):
    uuid = sample_community_node.uuid

    # Create node
-    node_count = await get_node_count(driver, uuid)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0
-    await sample_community_node.save(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_community_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 1

-    retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
-    assert retrieved.uuid == sample_community_node.uuid
-    assert retrieved.name == 'Community A'
-    assert retrieved.group_id == group_id
-    assert retrieved.summary == 'Community summary'
+    # Get node by uuid
+    retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
+    await assert_community_node_equals(graph_driver, retrieved, sample_community_node)

-    retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
-    assert retrieved[0].uuid == sample_community_node.uuid
-    assert retrieved[0].name == 'Community A'
-    assert retrieved[0].group_id == group_id
-    assert retrieved[0].summary == 'Community summary'
+    # Get node by uuids
+    retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
+    await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)

-    retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
+    # Get node by group ids
+    retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
-    assert retrieved[0].uuid == sample_community_node.uuid
-    assert retrieved[0].name == 'Community A'
-    assert retrieved[0].group_id == group_id
+    await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)

    # Delete node by uuid
-    await sample_community_node.delete(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_community_node.delete(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0

+    # Delete node by uuids
+    await sample_community_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
+    assert node_count == 1
+    await sample_community_node.delete_by_uuids(graph_driver, [uuid])
+    node_count = await get_node_count(graph_driver, [uuid])
+    assert node_count == 0
+
    # Delete node by group id
-    await sample_community_node.save(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_community_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 1
-    await sample_community_node.delete_by_group_id(driver, group_id)
-    node_count = await get_node_count(driver, uuid)
+    await sample_community_node.delete_by_group_id(graph_driver, group_id)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0

-    await driver.close()
+    await graph_driver.close()


@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_episodic_node(sample_episodic_node, driver):
-    driver = get_driver(driver)
+async def test_episodic_node(sample_episodic_node, graph_driver):
    uuid = sample_episodic_node.uuid

    # Create node
-    node_count = await get_node_count(driver, uuid)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0
-    await sample_episodic_node.save(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_episodic_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 1

-    retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
-    assert retrieved.uuid == sample_episodic_node.uuid
-    assert retrieved.name == 'Episode 1'
-    assert retrieved.group_id == group_id
-    assert retrieved.source == EpisodeType.text
-    assert retrieved.source_description == 'Test source'
-    assert retrieved.content == 'Some content here'
-    assert retrieved.valid_at == sample_episodic_node.valid_at
+    # Get node by uuid
+    retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
+    await assert_episodic_node_equals(retrieved, sample_episodic_node)

-    retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
-    assert retrieved[0].uuid == sample_episodic_node.uuid
-    assert retrieved[0].name == 'Episode 1'
-    assert retrieved[0].group_id == group_id
-    assert retrieved[0].source == EpisodeType.text
-    assert retrieved[0].source_description == 'Test source'
-    assert retrieved[0].content == 'Some content here'
-    assert retrieved[0].valid_at == sample_episodic_node.valid_at
+    # Get node by uuids
+    retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
+    await assert_episodic_node_equals(retrieved[0], sample_episodic_node)

-    retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
+    # Get node by group ids
+    retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
-    assert retrieved[0].uuid == sample_episodic_node.uuid
-    assert retrieved[0].name == 'Episode 1'
-    assert retrieved[0].group_id == group_id
-    assert retrieved[0].source == EpisodeType.text
-    assert retrieved[0].source_description == 'Test source'
-    assert retrieved[0].content == 'Some content here'
-    assert retrieved[0].valid_at == sample_episodic_node.valid_at
+    await assert_episodic_node_equals(retrieved[0], sample_episodic_node)

    # Delete node by uuid
-    await sample_episodic_node.delete(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_episodic_node.delete(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0

+    # Delete node by uuids
+    await sample_episodic_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
+    assert node_count == 1
+    await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
+    node_count = await get_node_count(graph_driver, [uuid])
+    assert node_count == 0
+
    # Delete node by group id
-    await sample_episodic_node.save(driver)
-    node_count = await get_node_count(driver, uuid)
+    await sample_episodic_node.save(graph_driver)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 1
-    await sample_episodic_node.delete_by_group_id(driver, group_id)
-    node_count = await get_node_count(driver, uuid)
+    await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
+    node_count = await get_node_count(graph_driver, [uuid])
    assert node_count == 0

-    await driver.close()
-
-
-async def get_node_count(driver: GraphDriver, uuid: str):
-    result, _, _ = await driver.execute_query(
-        """
-        MATCH (n {uuid: $uuid})
-        RETURN COUNT(n) as count
-        """,
-        uuid=uuid,
-    )
-    return int(result[0]['count'])
+    await graph_driver.close()
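The tests above replace the live `OpenAIEmbedder` with a `mock_embedder` fixture so they run without API calls. A hypothetical sketch of such a deterministic embedder stub (the real fixture lives in `tests/helpers_test.py` and may differ in detail):

```python
import asyncio
import hashlib


# Hypothetical deterministic stand-in in the spirit of the mock_embedder
# fixture; not the project's actual implementation.
class MockEmbedder:
    DIM = 1024  # matches the [0.5] * 1024 embeddings used in the fixtures

    async def create(self, input_data) -> list:
        # Hash the input so the same text always yields the same vector,
        # which keeps equality assertions in the tests deterministic.
        text = input_data if isinstance(input_data, str) else str(input_data)
        seed = hashlib.sha256(text.encode()).digest()
        return [seed[i % len(seed)] / 255.0 for i in range(self.DIM)]


emb1 = asyncio.run(MockEmbedder().create('Alice'))
emb2 = asyncio.run(MockEmbedder().create('Alice'))
assert emb1 == emb2 and len(emb1) == MockEmbedder.DIM
```

Because the output depends only on the input text, a node embedded before `save` compares equal to the embedding loaded back from the store, without any network dependency.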
42
uv.lock
generated
42
uv.lock
generated
|
|
@ -809,6 +809,7 @@ dev = [
|
|||
{ name = "groq" },
|
||||
{ name = "ipykernel" },
|
||||
{ name = "jupyterlab" },
|
||||
{ name = "kuzu" },
|
||||
{ name = "langchain-anthropic" },
|
||||
{ name = "langchain-openai" },
|
||||
{ name = "langgraph" },
|
||||
|
|
@ -831,6 +832,9 @@ google-genai = [
|
|||
groq = [
|
||||
{ name = "groq" },
|
||||
]
|
||||
kuzu = [
|
||||
{ name = "kuzu" },
|
||||
]
|
||||
neptune = [
|
||||
{ name = "boto3" },
|
||||
{ name = "langchain-aws" },
|
||||
|
|
@ -858,6 +862,8 @@ requires-dist = [
|
|||
{ name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" },
|
||||
{ name = "jupyterlab", marker = "extra == 'dev'", specifier = ">=4.2.4" },
|
||||
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
|
||||
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
|
||||
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
|
||||
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
|
||||
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
|
||||
|
|
@ -882,7 +888,7 @@ requires-dist = [
|
|||
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
|
||||
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
|
||||
]
|
||||
provides-extras = ["anthropic", "groq", "google-genai", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
|
||||
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
|
||||
|
||||
[[package]]
|
||||
name = "groq"
|
||||
|
|
@ -1387,6 +1393,40 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700, upload-time = "2024-07-16T17:02:01.115Z" },
]

[[package]]
name = "kuzu"
version = "0.11.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/66/fd/adbd05ccf81e6ad2674fcd3849d5d6ffeaf2141a9b8d1c1c4e282e923e1f/kuzu-0.11.2.tar.gz", hash = "sha256:9f224ec218ab165a18acaea903695779780d70335baf402d9b7f59ba389db0bd", size = 4902887, upload-time = "2025-08-21T05:17:00.152Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/91/bed837f5f49220a9f869da8a078b34a3484f210f7b57b267177821545c03/kuzu-0.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b25174cdb721aae47896ed62842d3859679607b493a9a6bbbcd9fb7fb3707", size = 3702618, upload-time = "2025-08-21T05:15:53.726Z" },
{ url = "https://files.pythonhosted.org/packages/72/8a/fd5e053b0055718afe00b6a99393a835c6254354128fbb7f66a35fd76089/kuzu-0.11.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:9a8567c53bfe282f4727782471ff718842ffead8c48c1762c1df9197408fc986", size = 4101371, upload-time = "2025-08-21T05:15:55.889Z" },
{ url = "https://files.pythonhosted.org/packages/ad/4b/e45cadc85bdc5079f432675bbe8d557600f0d4ab46fe24ef218374419902/kuzu-0.11.2-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d793bb5a0a14ada730a697eccac2a4c68b434b82692d985942900ef2003e099e", size = 6211974, upload-time = "2025-08-21T05:15:57.505Z" },
{ url = "https://files.pythonhosted.org/packages/10/ca/92d6a1e6452fcf06bfc423ce2cde819ace6b6e47921921cc8fae87c27780/kuzu-0.11.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1be4e9b6c93ca8591b1fb165f9b9a27d70a56af061831afcdfe7aebb89ee6ff", size = 6992196, upload-time = "2025-08-21T05:15:59.006Z" },
{ url = "https://files.pythonhosted.org/packages/49/6c/983fc6265dfc1169c87c4a0722f36ee665c5688e1166faeb4cd85e6af078/kuzu-0.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0ec7a304c746a2a98ecfd7e7c3f6fe92c4dfee2e2565c0b7cb4cffd0c2e374a", size = 4303517, upload-time = "2025-08-21T05:16:00.814Z" },
{ url = "https://files.pythonhosted.org/packages/b5/14/8ae2c52657b93715052ecf47d70232f2c8d9ffe2d1ec3527c8e9c3cb2df5/kuzu-0.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf53b4f321a4c05882b14cef96d39a1e90fa993bab88a1554fb1565367553b8c", size = 3704177, upload-time = "2025-08-21T05:16:02.354Z" },
{ url = "https://files.pythonhosted.org/packages/2d/7a/bce7bb755e16f9ca855f76a3acc6cfa9fae88c4d6af9df3784c50b2120a5/kuzu-0.11.2-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:2d749883b74f5da5ff4a4b0635a98f6cc5165743995828924321d2ca797317cb", size = 4102372, upload-time = "2025-08-21T05:16:04.249Z" },
{ url = "https://files.pythonhosted.org/packages/c8/12/f5b1d51fcb78a86c078fb85cc53184ce962a3e86852d47d30e287a932e3f/kuzu-0.11.2-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:632507e5982928ed24fbb5e70ad143d7970bc4059046e77e0522707efbad303b", size = 6212492, upload-time = "2025-08-21T05:16:05.99Z" },
{ url = "https://files.pythonhosted.org/packages/81/96/d6e57af6ccf9e0697812ad3039c80b87b768cf2674833b0b23d317ea3427/kuzu-0.11.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5211884601f8f08ae97ba25006d0facde24077c5333411d944282b8a2068ab4", size = 6992888, upload-time = "2025-08-21T05:16:07.896Z" },
{ url = "https://files.pythonhosted.org/packages/40/ee/1f275ac5679a3f615ce0d9cf8c79001fdb535ccc8bc344e49b14484c7cd7/kuzu-0.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:82a6c8bfe1278dc1010790e398bf772683797ef5c16052fa0f6f78bacbc59aa3", size = 4304064, upload-time = "2025-08-21T05:16:10.163Z" },
{ url = "https://files.pythonhosted.org/packages/73/ba/9f20d9e83681a0ddae8ec13046b116c34745fa0e66862d4e2d8414734ce0/kuzu-0.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aed88ffa695d07289a3d8557bd8f9e743298a4f4349208a60bbb06f4ebf15c26", size = 3703781, upload-time = "2025-08-21T05:16:12.232Z" },
{ url = "https://files.pythonhosted.org/packages/53/a0/bb815c0490f3d4d30389156369b9fe641e154f0d4b1e8340f09a76021922/kuzu-0.11.2-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:595824b03248af928e3faee57f6825d3a46920f2d3b9bd0c0bb7fc3fa097fce9", size = 4103990, upload-time = "2025-08-21T05:16:14.139Z" },
{ url = "https://files.pythonhosted.org/packages/a5/6f/97b647c0547a634a669055ff4cfd21a92ea3999aedc6a7fe9004f03f25e3/kuzu-0.11.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5674c6d9d26f5caa0c7ce6f34c02e4411894879aa5b2ce174fad576fa898523", size = 6211947, upload-time = "2025-08-21T05:16:16.48Z" },
{ url = "https://files.pythonhosted.org/packages/42/74/c7f1a1cfb08c05c91c5a94483be387e80fafab8923c4243a22e9cced5c1b/kuzu-0.11.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c61daf02da35b671f4c6f3c17105725c399a5e14b7349b00eafbcd24ac90034a", size = 6991879, upload-time = "2025-08-21T05:16:18.402Z" },
{ url = "https://files.pythonhosted.org/packages/54/9e/50d67d7bc08faed95ede6de1a6aa0d81079c98028ca99e32d09c2ab1aead/kuzu-0.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:682096cd87dcbb8257f933ea4172d9dc5617a8d0a5bdd19cd66cf05b68881afd", size = 4305706, upload-time = "2025-08-21T05:16:20.244Z" },
{ url = "https://files.pythonhosted.org/packages/65/f0/5649a01af37def50293cd7c194afc19f09b343fd2b7f2b28e021a207f8ce/kuzu-0.11.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:17a11b67652e8b331c85cd1a1a30b32ee6783750084473abbab2aa1963ee2a3b", size = 3703740, upload-time = "2025-08-21T05:16:21.896Z" },
{ url = "https://files.pythonhosted.org/packages/24/e2/e0beb9080911fc1689899a42da0f83534949f43169fb80197def3ec1223f/kuzu-0.11.2-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:bdded35426210faeca8da11e8b4a54e60ccc0c1a832660d76587b5be133b0f55", size = 4104073, upload-time = "2025-08-21T05:16:23.819Z" },
{ url = "https://files.pythonhosted.org/packages/f2/4c/7a831c9c6e609692953db677f54788bd1dde4c9d34e6ba91f1e153d2e7fe/kuzu-0.11.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6116b609aac153f3523130b31295643d34a6c9509914c0fa9d804b26b23eee73", size = 6212263, upload-time = "2025-08-21T05:16:25.351Z" },
{ url = "https://files.pythonhosted.org/packages/47/95/615ef10b46b22ec1d33fdbba795e6e79733d9a244aabdeeb910f267ab36c/kuzu-0.11.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:09da5b8cb24dc6b281a6e4ac0f7f24226eb9909803b187e02d014da13ba57bcf", size = 6992492, upload-time = "2025-08-21T05:16:27.518Z" },
{ url = "https://files.pythonhosted.org/packages/a7/dd/2c905575913c743e6c67a5ca89a6b4ea9d9737238966d85d7e710f0d3e60/kuzu-0.11.2-cp313-cp313-win_amd64.whl", hash = "sha256:c663fb84682f8ebffbe7447a4e552a0e03bd29097d319084a2c53c2e032a780e", size = 4305267, upload-time = "2025-08-21T05:16:29.307Z" },
{ url = "https://files.pythonhosted.org/packages/89/05/44fbfc9055dba3f472ea4aaa8110635864d3441eede987526ef401680765/kuzu-0.11.2-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5c03fb95ffb9185c1519333f8ee92b7a9695aa7aa9a179e868a7d7bd13d10a16", size = 6216795, upload-time = "2025-08-21T05:16:30.944Z" },
{ url = "https://files.pythonhosted.org/packages/4f/ca/16c81dc68cc1e8918f8481e7ee89c28aa665c5cb36be7ad0fc1d0d295760/kuzu-0.11.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d857f0efddf26d5e2dc189facb84bf04a997e395972486669b418a470cc76034", size = 6996333, upload-time = "2025-08-21T05:16:32.568Z" },
{ url = "https://files.pythonhosted.org/packages/48/d8/9275c7e6312bd76dc670e8e2da68639757c22cf2c366e96527595a1d881c/kuzu-0.11.2-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb9e4641867c35b98ceaa604aa79832c0eeed41f5fd1b6da22b1c217b2f1b8ea", size = 6212202, upload-time = "2025-08-21T05:16:34.571Z" },
{ url = "https://files.pythonhosted.org/packages/88/89/67a977493c60bca3610845df13020711f357a5d80bf91549e4b48d877c2f/kuzu-0.11.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:553408d76a0b4fdecc1338b69b71d7bde42f6936d3b99d9852b30d33bda15978", size = 6992264, upload-time = "2025-08-21T05:16:36.316Z" },
{ url = "https://files.pythonhosted.org/packages/b6/49/869ceceb1d8a5ea032a35c734e55cfee919340889973623096da7eb94f6b/kuzu-0.11.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:989a87fa13ffa39ab7773d968fe739ac4f8faf9ddb5dad72ced2eeef12180293", size = 6216814, upload-time = "2025-08-21T05:16:38.348Z" },
{ url = "https://files.pythonhosted.org/packages/bc/cd/933b34a246edb882a042eb402747167719222c05149b73b48ba7d310d554/kuzu-0.11.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e67420d04a9643fd6376a23b17b398a3e32bb0c2bd8abbf8d1e4697056596c7e", size = 6996343, upload-time = "2025-08-21T05:16:39.973Z" },
]

[[package]]
name = "langchain-anthropic"
version = "0.3.17"