Add support for Kuzu as the graph driver (#799)

* Fix FalkorDB tests

* Add support for graph memory using Kuzu

* Fix lints

* Fix queries

* Add tests

* Add comments

* Add more test coverage

* Add mocked tests

* Format

* Add mocked tests II

* Refactor community queries

* Add more mocked tests

* Refactor tests to always cleanup

* Add more mocked tests

* Update kuzu

* Refactor how filters are built

* Add more mocked tests

* Refactor and cleanup

* Fix tests

* Fix lints

* Refactor tests

* Disable neptune

* Fix

* Update kuzu version

* Update kuzu to latest release

* Fix filter

* Fix query

* Fix Neptune query

* Fix bulk queries

* Fix lints

* Fix deletes

* Comments and format

* Add Kuzu to the README

* Fix bulk queries

* Test all fields of nodes and edges

* Fix lints

* Update search_utils.py

---------

Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Siddhartha Sahu 2025-08-27 11:45:21 -04:00 committed by GitHub
parent 309159bccb
commit 8802b7db13
30 changed files with 4219 additions and 966 deletions


@ -49,6 +49,7 @@ jobs:
NEO4J_URI: bolt://localhost:7687
NEO4J_USER: neo4j
NEO4J_PASSWORD: testpass
DISABLE_NEPTUNE: 1
run: |
uv run pytest -m "not integration"
- name: Wait for FalkorDB


@ -44,7 +44,7 @@ Use Graphiti to:
<br />
<p align="center">
<img src="images/graphiti-graph-intro.gif" alt="Graphiti temporal walkthrough" width="700px">
<img src="images/graphiti-graph-intro.gif" alt="Graphiti temporal walkthrough" width="700px">
</p>
<br />
@ -80,7 +80,7 @@ Traditional RAG approaches often rely on batch processing and static data summar
- **Scalability:** Efficiently manages large datasets with parallel processing, suitable for enterprise environments.
<p align="center">
<img src="/images/graphiti-intro-slides-stock-2.gif" alt="Graphiti structured + unstructured demo" width="700px">
<img src="/images/graphiti-intro-slides-stock-2.gif" alt="Graphiti structured + unstructured demo" width="700px">
</p>
## Graphiti vs. GraphRAG
@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
Requirements:
- Python 3.10 or higher
- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
> [!IMPORTANT]
@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
uv add graphiti-core[falkordb]
```
### Installing with Kuzu Support
If you plan to use Kuzu as your graph database backend, install with the Kuzu extra:
```bash
pip install graphiti-core[kuzu]
# or with uv
uv add graphiti-core[kuzu]
```
### Installing with Amazon Neptune Support
If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
@ -198,7 +209,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
1. Connecting to a Neo4j, Amazon Neptune, FalkorDB, or Kuzu database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph (both text and structured JSON)
4. Searching for relationships (edges) using hybrid search
@ -281,6 +292,19 @@ driver = FalkorDriver(
graphiti = Graphiti(graph_driver=driver)
```
#### Kuzu
```python
from graphiti_core import Graphiti
from graphiti_core.driver.kuzu_driver import KuzuDriver
# Create a Kuzu driver
driver = KuzuDriver(db="/tmp/graphiti.kuzu")
# Pass the driver to Graphiti
graphiti = Graphiti(graph_driver=driver)
```
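The `KuzuDriver` added in this PR defaults to Kuzu's in-memory database (`db=':memory:'`), which can be handy for quick experiments; a minimal sketch, assuming the same imports as above:

```python
from graphiti_core import Graphiti
from graphiti_core.driver.kuzu_driver import KuzuDriver

# In-memory Kuzu database; data is lost when the process exits.
# Omitting `db` falls back to the driver's default of ':memory:'.
driver = KuzuDriver(db=':memory:')

graphiti = Graphiti(graph_driver=driver)
```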
#### Amazon Neptune
```python
@ -494,7 +518,7 @@ When you initialize a Graphiti instance, we collect:
- **Graphiti version**: The version you're using
- **Configuration choices**:
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
- Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
- Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
- Embedder provider (OpenAI, Azure, Voyage, etc.)
### What We Don't Collect


@ -4,3 +4,7 @@ import sys
# This code adds the project root directory to the Python path, allowing imports to work correctly when running tests.
# Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))
from tests.helpers_test import graph_driver, mock_embedder
__all__ = ['graph_driver', 'mock_embedder']


@ -27,10 +27,13 @@ logger = logging.getLogger(__name__)
class GraphProvider(Enum):
NEO4J = 'neo4j'
FALKORDB = 'falkordb'
KUZU = 'kuzu'
NEPTUNE = 'neptune'
class GraphDriverSession(ABC):
provider: GraphProvider
async def __aenter__(self):
return self


@ -15,7 +15,6 @@ limitations under the License.
"""
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
@ -33,11 +32,14 @@ else:
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
logger = logging.getLogger(__name__)
class FalkorDriverSession(GraphDriverSession):
provider = GraphProvider.FALKORDB
def __init__(self, graph: FalkorGraph):
self.graph = graph
@ -164,16 +166,3 @@ class FalkorDriver(GraphDriver):
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned
def convert_datetimes_to_strings(obj):
if isinstance(obj, dict):
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_datetimes_to_strings(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(convert_datetimes_to_strings(item) for item in obj)
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj


@ -0,0 +1,175 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from typing import Any
import kuzu
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
# Kuzu requires an explicit schema.
# As Kuzu currently does not support creating full text indexes on edge properties,
# we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
# (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
SCHEMA_QUERIES = """
CREATE NODE TABLE IF NOT EXISTS Episodic (
uuid STRING PRIMARY KEY,
name STRING,
group_id STRING,
created_at TIMESTAMP,
source STRING,
source_description STRING,
content STRING,
valid_at TIMESTAMP,
entity_edges STRING[]
);
CREATE NODE TABLE IF NOT EXISTS Entity (
uuid STRING PRIMARY KEY,
name STRING,
group_id STRING,
labels STRING[],
created_at TIMESTAMP,
name_embedding FLOAT[],
summary STRING,
attributes STRING
);
CREATE NODE TABLE IF NOT EXISTS Community (
uuid STRING PRIMARY KEY,
name STRING,
group_id STRING,
created_at TIMESTAMP,
name_embedding FLOAT[],
summary STRING
);
CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
uuid STRING PRIMARY KEY,
group_id STRING,
created_at TIMESTAMP,
name STRING,
fact STRING,
fact_embedding FLOAT[],
episodes STRING[],
expired_at TIMESTAMP,
valid_at TIMESTAMP,
invalid_at TIMESTAMP,
attributes STRING
);
CREATE REL TABLE IF NOT EXISTS RELATES_TO(
FROM Entity TO RelatesToNode_,
FROM RelatesToNode_ TO Entity
);
CREATE REL TABLE IF NOT EXISTS MENTIONS(
FROM Episodic TO Entity,
uuid STRING PRIMARY KEY,
group_id STRING,
created_at TIMESTAMP
);
CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
FROM Community TO Entity,
FROM Community TO Community,
uuid STRING,
group_id STRING,
created_at TIMESTAMP
);
"""
class KuzuDriver(GraphDriver):
provider: GraphProvider = GraphProvider.KUZU
def __init__(
self,
db: str = ':memory:',
max_concurrent_queries: int = 1,
):
super().__init__()
self.db = kuzu.Database(db)
self.setup_schema()
self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)
async def execute_query(
self, cypher_query_: str, **kwargs: Any
) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
params = {k: v for k, v in kwargs.items() if v is not None}
# Kuzu does not support these parameters.
params.pop('database_', None)
params.pop('routing_', None)
try:
results = await self.client.execute(cypher_query_, parameters=params)
except Exception as e:
params = {k: (v[:5] if isinstance(v, list) else v) for k, v in params.items()}
logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{params}')
raise
if not results:
return [], None, None
if isinstance(results, list):
dict_results = [list(result.rows_as_dict()) for result in results]
else:
dict_results = list(results.rows_as_dict())
return dict_results, None, None # type: ignore
def session(self, _database: str | None = None) -> GraphDriverSession:
return KuzuDriverSession(self)
async def close(self):
# Do not explicitly close the connection; instead rely on GC.
pass
def delete_all_indexes(self, database_: str):
pass
def setup_schema(self):
conn = kuzu.Connection(self.db)
conn.execute(SCHEMA_QUERIES)
conn.close()
class KuzuDriverSession(GraphDriverSession):
provider = GraphProvider.KUZU
def __init__(self, driver: KuzuDriver):
self.driver = driver
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
# No cleanup needed for Kuzu, but method must exist.
pass
async def close(self):
# Do not close the session here, as we're reusing the driver connection.
pass
async def execute_write(self, func, *args, **kwargs):
# Directly await the provided async function with `self` as the transaction/session
return await func(self, *args, **kwargs)
async def run(self, query: str | list, **kwargs: Any) -> Any:
if isinstance(query, list):
for cypher, params in query:
await self.driver.execute_query(cypher, **params)
else:
await self.driver.execute_query(query, **kwargs)
return None
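As the schema comment notes, Kuzu cannot build full-text indexes on relationship properties, so each `RELATES_TO` fact is stored on an intermediate `RelatesToNode_` node. In practice, the single-hop edge pattern used on Neo4j becomes a two-hop pattern on Kuzu; a minimal sketch of reading a fact back through the driver (the uuid is illustrative):

```python
import asyncio

from graphiti_core.driver.kuzu_driver import KuzuDriver


async def fetch_fact(driver: KuzuDriver, edge_uuid: str):
    # On Kuzu, the "edge" carrying the fact is the RelatesToNode_ node in the middle.
    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
        RETURN n.uuid AS source_uuid, e.fact AS fact, m.uuid AS target_uuid
        """,
        uuid=edge_uuid,
    )
    return records


# Uses the default in-memory database; returns [] unless the uuid exists.
print(asyncio.run(fetch_fact(KuzuDriver(), 'edge-uuid-goes-here')))
```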


@ -271,6 +271,8 @@ class NeptuneDriver(GraphDriver):
class NeptuneDriverSession(GraphDriverSession):
provider = GraphProvider.NEPTUNE
def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
self.driver = driver


@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
@ -30,11 +31,10 @@ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_RETURN,
ENTITY_EDGE_RETURN,
ENTITY_EDGE_RETURN_NEPTUNE,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
get_entity_edge_return_query,
get_entity_edge_save_query,
)
from graphiti_core.nodes import Node
@ -53,33 +53,63 @@ class Edge(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
result = await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
)
if driver.provider == GraphProvider.KUZU:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
)
await driver.execute_query(
"""
MATCH (e:RelatesToNode_ {uuid: $uuid})
DETACH DELETE e
""",
uuid=self.uuid,
)
else:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
)
logger.debug(f'Deleted Edge: {self.uuid}')
return result
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
result = await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
DELETE e
""",
uuids=uuids,
)
if driver.provider == GraphProvider.KUZU:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
DELETE e
""",
uuids=uuids,
)
await driver.execute_query(
"""
MATCH (e:RelatesToNode_)
WHERE e.uuid IN $uuids
DETACH DELETE e
""",
uuids=uuids,
)
else:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
DELETE e
""",
uuids=uuids,
)
logger.debug(f'Deleted Edges: {uuids}')
return result
def __hash__(self):
return hash(self.uuid)
@ -166,7 +196,7 @@ class EpisodicEdge(Edge):
"""
+ EPISODIC_EDGE_RETURN
+ """
ORDER BY e.uuid DESC
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@ -215,15 +245,21 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
else:
query: LiteralString = """
query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
if driver.provider == GraphProvider.NEPTUNE:
query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
if driver.provider == GraphProvider.KUZU:
query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
records, _, _ = await driver.execute_query(
@ -253,15 +289,22 @@ class EntityEdge(Edge):
'invalid_at': self.invalid_at,
}
edge_data.update(self.attributes or {})
if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
**edge_data,
)
else:
edge_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,
)
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
@ -269,21 +312,25 @@ class EntityEdge(Edge):
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
match_query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
),
+ get_entity_edge_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
@ -294,22 +341,26 @@ class EntityEdge(Edge):
if len(uuids) == 0:
return []
records, _, _ = await driver.execute_query(
"""
match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
WHERE e.uuid IN $uuids
RETURN
"""
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
),
+ get_entity_edge_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@ -332,23 +383,27 @@ class EntityEdge(Edge):
else ''
)
records, _, _ = await driver.execute_query(
"""
match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
)
+ get_entity_edge_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY e.uuid DESC
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@ -357,7 +412,7 @@ class EntityEdge(Edge):
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise GroupsEdgesNotFoundError(group_ids)
@ -365,21 +420,25 @@ class EntityEdge(Edge):
@classmethod
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
records, _, _ = await driver.execute_query(
"""
match_query = """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
),
+ get_entity_edge_return_query(driver.provider),
node_uuid=node_uuid,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@ -479,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
)
def get_entity_edge_from_record(record: Any) -> EntityEdge:
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
episodes = record['episodes']
if provider == GraphProvider.KUZU:
attributes = json.loads(record['attributes']) if record['attributes'] else {}
else:
attributes = record['attributes']
attributes.pop('uuid', None)
attributes.pop('source_node_uuid', None)
attributes.pop('target_node_uuid', None)
attributes.pop('fact', None)
attributes.pop('fact_embedding', None)
attributes.pop('name', None)
attributes.pop('group_id', None)
attributes.pop('episodes', None)
attributes.pop('created_at', None)
attributes.pop('expired_at', None)
attributes.pop('valid_at', None)
attributes.pop('invalid_at', None)
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
@ -488,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
fact_embedding=record.get('fact_embedding'),
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],
episodes=episodes,
created_at=parse_db_date(record['created_at']), # type: ignore
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
attributes=record['attributes'],
attributes=attributes,
)
edge.attributes.pop('uuid', None)
edge.attributes.pop('source_node_uuid', None)
edge.attributes.pop('target_node_uuid', None)
edge.attributes.pop('fact', None)
edge.attributes.pop('name', None)
edge.attributes.pop('group_id', None)
edge.attributes.pop('episodes', None)
edge.attributes.pop('created_at', None)
edge.attributes.pop('expired_at', None)
edge.attributes.pop('valid_at', None)
edge.attributes.pop('invalid_at', None)
return edge
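A small illustration of the Kuzu branch in `get_entity_edge_from_record`: attributes come back as a JSON string and are parsed and stripped of reserved keys before the `EntityEdge` is built. Record values below are made up, and import paths follow the ones used in this diff:

```python
import json

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.edges import get_entity_edge_from_record

record = {
    'uuid': 'edge-1',
    'source_node_uuid': 'node-1',
    'target_node_uuid': 'node-2',
    'name': 'WORKS_AT',
    'fact': 'Alice works at Acme.',
    'group_id': 'demo',
    'episodes': ['episode-1'],
    'created_at': '2025-08-27T00:00:00+00:00',
    'expired_at': None,
    'valid_at': None,
    'invalid_at': None,
    # Kuzu stores attributes as a JSON string; reserved keys are stripped on load.
    'attributes': json.dumps({'confidence': 0.9, 'uuid': 'edge-1'}),
}

edge = get_entity_edge_from_record(record, GraphProvider.KUZU)
print(edge.attributes)  # {'confidence': 0.9}
```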


@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
'episode_content': 'Episodic',
'edge_name_and_fact': 'RELATES_TO',
}
# Mapping from fulltext index names to Kuzu node labels
INDEX_TO_LABEL_KUZU_MAPPING = {
'node_name_and_summary': 'Entity',
'community_name': 'Community',
'episode_content': 'Episodic',
'edge_name_and_fact': 'RelatesToNode_',
}
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
]
if provider == GraphProvider.KUZU:
return []
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
]
if provider == GraphProvider.KUZU:
return [
"CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
"CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
"CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
]
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
]
def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
if provider == GraphProvider.KUZU:
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
if provider == GraphProvider.KUZU:
return f'array_cosine_similarity({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})'
def get_relationships_query(name: str, provider: GraphProvider) -> str:
def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
if provider == GraphProvider.KUZU:
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
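A sketch of what these helpers produce for Kuzu (assuming they live in `graphiti_core.graph_queries`): fulltext search goes through `QUERY_FTS_INDEX` on the mapped node table instead of Neo4j's `db.index.fulltext` procedures.

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.graph_queries import get_nodes_query

# Passing the literal placeholder '$query' mirrors how callers bind the search
# text as a parameter at execution time.
print(get_nodes_query('node_name_and_summary', '$query', limit=10, provider=GraphProvider.KUZU))
# CALL QUERY_FTS_INDEX('Entity', 'node_name_and_summary', $query, TOP := $limit)

print(get_nodes_query('node_name_and_summary', '$query', limit=10, provider=GraphProvider.NEO4J))
# CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
```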


@ -1070,7 +1070,7 @@ class Graphiti:
if record['episode_count'] == 1:
nodes_to_delete.append(node)
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
await episode.delete(self.driver)


@ -43,14 +43,14 @@ RUNTIME_QUERY: LiteralString = (
)
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
return (
neo_date.to_native()
if isinstance(neo_date, neo4j_time.DateTime)
else datetime.fromisoformat(neo_date)
if neo_date
else None
)
def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
if isinstance(input_date, neo4j_time.DateTime):
return input_date.to_native()
if isinstance(input_date, str):
return datetime.fromisoformat(input_date)
return input_date
def get_default_group_id(provider: GraphProvider) -> str:


@ -20,18 +20,36 @@ EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
SET
e.group_id = $group_id,
e.created_at = $created_at
RETURN e.uuid AS uuid
"""
EPISODIC_EDGE_SAVE_BULK = """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN e.uuid AS uuid
"""
def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
if provider == GraphProvider.KUZU:
return """
MATCH (episode:Episodic {uuid: $source_node_uuid})
MATCH (node:Entity {uuid: $target_node_uuid})
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
SET
e.group_id = $group_id,
e.created_at = $created_at
RETURN e.uuid AS uuid
"""
return """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
SET
e.group_id = edge.group_id,
e.created_at = edge.created_at
RETURN e.uuid AS uuid
"""
EPISODIC_EDGE_RETURN = """
e.uuid AS uuid,
@ -54,14 +72,32 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
"""
case GraphProvider.NEPTUNE:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
SET e.episodes = join($edge_data.episodes, ",")
RETURN $edge_data.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
SET
e.group_id = $group_id,
e.created_at = $created_at,
e.name = $name,
e.fact = $fact,
e.fact_embedding = $fact_embedding,
e.episodes = $episodes,
e.expired_at = $expired_at,
e.valid_at = $valid_at,
e.invalid_at = $invalid_at,
e.attributes = $attributes
RETURN e.uuid AS uuid
"""
case _: # Neo4j
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
@ -89,14 +125,32 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
SET r.episodes = join(edge.episodes, ",")
RETURN edge.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MATCH (source:Entity {uuid: $source_node_uuid})
MATCH (target:Entity {uuid: $target_node_uuid})
MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
SET
e.group_id = $group_id,
e.created_at = $created_at,
e.name = $name,
e.fact = $fact,
e.fact_embedding = $fact_embedding,
e.episodes = $episodes,
e.expired_at = $expired_at,
e.valid_at = $valid_at,
e.invalid_at = $invalid_at,
e.attributes = $attributes
RETURN e.uuid AS uuid
"""
case _:
return """
UNWIND $entity_edges AS edge
@ -109,35 +163,42 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
"""
ENTITY_EDGE_RETURN = """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.name AS name,
e.fact AS fact,
e.episodes AS episodes,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
def get_entity_edge_return_query(provider: GraphProvider) -> str:
# `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
ENTITY_EDGE_RETURN_NEPTUNE = """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.name AS name,
e.fact AS fact,
split(e.episodes, ',') AS episodes,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
if provider == GraphProvider.NEPTUNE:
return """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.name AS name,
e.fact AS fact,
split(e.episodes, ',') AS episodes,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
return """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.created_at AS created_at,
e.name AS name,
e.fact AS fact,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
""" + (
'e.attributes AS attributes'
if provider == GraphProvider.KUZU
else 'properties(e) AS attributes'
)
def get_community_edge_save_query(provider: GraphProvider) -> str:
@ -152,7 +213,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
"""
case GraphProvider.NEPTUNE:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
WHERE node:Entity OR node:Community
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
@ -161,6 +222,24 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
SET r.created_at= $created_at
RETURN r.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET
e.group_id = $group_id,
e.created_at = $created_at
RETURN e.uuid AS uuid
UNION
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Community {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET
e.group_id = $group_id,
e.created_at = $created_at
RETURN e.uuid AS uuid
"""
case _: # Neo4j
return """
MATCH (community:Community {uuid: $community_uuid})
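For comparison, the Kuzu entity-edge save query binds flat parameters and MERGEs the intermediate `RelatesToNode_`, while the Neo4j and Neptune variants take a single `$edge_data` map. A quick way to see both (module path as imported elsewhere in this diff):

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_save_query

print(get_entity_edge_save_query(GraphProvider.KUZU))   # flat $uuid, $fact, ... parameters
print(get_entity_edge_save_query(GraphProvider.NEO4J))  # single $edge_data map
```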


@ -24,10 +24,24 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Episodic {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.created_at = $created_at,
n.source = $source,
n.source_description = $source_description,
n.content = $content,
n.valid_at = $valid_at,
n.entity_edges = $entity_edges
RETURN n.uuid AS uuid
"""
case GraphProvider.FALKORDB:
return """
MERGE (n:Episodic {uuid: $uuid})
@ -51,11 +65,25 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Episodic {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.created_at = $created_at,
n.source = $source,
n.source_description = $source_description,
n.content = $content,
n.valid_at = $valid_at,
n.entity_edges = $entity_edges
RETURN n.uuid AS uuid
"""
case GraphProvider.FALKORDB:
return """
UNWIND $episodes AS episode
@ -76,14 +104,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
EPISODIC_NODE_RETURN = """
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.created_at AS created_at,
e.source AS source,
e.source_description AS source_description,
e.content AS content,
e.valid_at AS valid_at,
e.entity_edges AS entity_edges
"""
@ -109,6 +137,20 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
SET n = $entity_data
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Entity {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.labels = $labels,
n.created_at = $created_at,
n.name_embedding = $name_embedding,
n.summary = $summary,
n.attributes = $attributes
WITH n
RETURN n.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
label_subquery = ''
for label in labels.split(':'):
@ -168,6 +210,19 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
"""
)
return queries
case GraphProvider.KUZU:
return """
MERGE (n:Entity {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.labels = $labels,
n.created_at = $created_at,
n.name_embedding = $name_embedding,
n.summary = $summary,
n.attributes = $attributes
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
UNWIND $nodes AS node
@ -179,15 +234,28 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
"""
ENTITY_NODE_RETURN = """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
"""
def get_entity_node_return_query(provider: GraphProvider) -> str:
# `name_embedding` is not returned by default and must be loaded manually using `load_name_embedding()`.
if provider == GraphProvider.KUZU:
return """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.labels AS labels,
n.created_at AS created_at,
n.summary AS summary,
n.attributes AS attributes
"""
return """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
"""
def get_community_node_save_query(provider: GraphProvider) -> str:
@ -201,10 +269,21 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Community {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.created_at = $created_at,
n.name_embedding = $name_embedding,
n.summary = $summary
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Community {uuid: $uuid})
@ -215,12 +294,12 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
COMMUNITY_NODE_RETURN = """
n.uuid AS uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.summary AS summary,
n.created_at AS created_at
c.uuid AS uuid,
c.name AS name,
c.group_id AS group_id,
c.created_at AS created_at,
c.name_embedding AS name_embedding,
c.summary AS summary
"""
COMMUNITY_NODE_RETURN_NEPTUNE = """


@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
@ -32,10 +33,10 @@ from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
COMMUNITY_NODE_RETURN_NEPTUNE,
ENTITY_NODE_RETURN,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
get_community_node_save_query,
get_entity_node_return_query,
get_entity_node_save_query,
get_episode_node_save_query,
)
@ -95,12 +96,37 @@ class Node(BaseModel, ABC):
case GraphProvider.NEO4J:
await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
DETACH DELETE n
""",
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
case _: # FalkorDB and Neptune
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
DETACH DELETE e
""",
uuid=self.uuid,
)
await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
@ -136,8 +162,32 @@ class Node(BaseModel, ABC):
group_id=group_id,
batch_size=batch_size,
)
case _: # FalkorDB and Neptune
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
DETACH DELETE e
""",
group_id=group_id,
)
await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id})
DETACH DELETE n
""",
group_id=group_id,
)
case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
@ -149,30 +199,59 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
if driver.provider == GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
else:
async with driver.session() as session:
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
CALL {
WITH n
match driver.provider:
case GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
uuids=uuids,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
WHERE n.uuid IN $uuids
DETACH DELETE e
""",
uuids=uuids,
batch_size=batch_size,
)
await driver.execute_query(
"""
MATCH (n:Entity)
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
case _: # Neo4J, Neptune
async with driver.session() as session:
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
CALL {
WITH n
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
uuids=uuids,
batch_size=batch_size,
)
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@ -376,17 +455,25 @@ class EntityNode(Node):
'summary': self.summary,
'created_at': self.created_at,
}
entity_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
if driver.provider == GraphProvider.KUZU:
entity_data['attributes'] = json.dumps(self.attributes)
entity_data['labels'] = list(set(self.labels + ['Entity']))
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels=''),
**entity_data,
)
else:
entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
entity_data=entity_data,
)
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
entity_data=entity_data,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
@ -399,12 +486,12 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid})
RETURN
"""
+ ENTITY_NODE_RETURN,
+ get_entity_node_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
@ -419,12 +506,12 @@ class EntityNode(Node):
WHERE n.uuid IN $uuids
RETURN
"""
+ ENTITY_NODE_RETURN,
+ get_entity_node_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@ -456,7 +543,7 @@ class EntityNode(Node):
+ """
RETURN
"""
+ ENTITY_NODE_RETURN
+ get_entity_node_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY n.uuid DESC
@ -468,7 +555,7 @@ class EntityNode(Node):
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@ -533,7 +620,7 @@ class CommunityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community {uuid: $uuid})
MATCH (c:Community {uuid: $uuid})
RETURN
"""
+ (
@ -556,8 +643,8 @@ class CommunityNode(Node):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)
WHERE n.uuid IN $uuids
MATCH (c:Community)
WHERE c.uuid IN $uuids
RETURN
"""
+ (
@ -581,13 +668,13 @@ class CommunityNode(Node):
limit: int | None = None,
uuid_cursor: str | None = None,
):
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)
WHERE n.group_id IN $group_ids
MATCH (c:Community)
WHERE c.group_id IN $group_ids
"""
+ cursor_query
+ """
@ -599,7 +686,7 @@ class CommunityNode(Node):
else COMMUNITY_NODE_RETURN
)
+ """
ORDER BY n.uuid DESC
ORDER BY c.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@ -636,7 +723,19 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
)
def get_entity_node_from_record(record: Any) -> EntityNode:
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
if provider == GraphProvider.KUZU:
attributes = json.loads(record['attributes']) if record['attributes'] else {}
else:
attributes = record['attributes']
attributes.pop('uuid', None)
attributes.pop('name', None)
attributes.pop('group_id', None)
attributes.pop('name_embedding', None)
attributes.pop('summary', None)
attributes.pop('created_at', None)
attributes.pop('labels', None)
entity_node = EntityNode(
uuid=record['uuid'],
name=record['name'],
@ -645,16 +744,9 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
labels=record['labels'],
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
attributes=record['attributes'],
attributes=attributes,
)
entity_node.attributes.pop('uuid', None)
entity_node.attributes.pop('name', None)
entity_node.attributes.pop('group_id', None)
entity_node.attributes.pop('name_embedding', None)
entity_node.attributes.pop('summary', None)
entity_node.attributes.pop('created_at', None)
return entity_node
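The Kuzu branches above delete in two steps because fact edges are `RelatesToNode_` nodes: a plain `DETACH DELETE` on the entity alone would leave those nodes orphaned. A minimal sketch of the same pattern issued directly through a driver (the function name is illustrative):

```python
from graphiti_core.driver.driver import GraphDriver


async def delete_entity(driver: GraphDriver, entity_uuid: str) -> None:
    # Remove the "edge" nodes hanging off the entity first...
    await driver.execute_query(
        """
        MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
        DETACH DELETE e
        """,
        uuid=entity_uuid,
    )
    # ...then the entity node itself.
    await driver.execute_query(
        """
        MATCH (n:Entity {uuid: $uuid})
        DETACH DELETE n
        """,
        uuid=entity_uuid,
    )
```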


@ -20,6 +20,8 @@ from typing import Any
from pydantic import BaseModel, Field
from graphiti_core.driver.driver import GraphProvider
class ComparisonOperator(Enum):
equals = '='
@ -54,16 +56,21 @@ class SearchFilters(BaseModel):
def node_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[str, dict[str, Any]]:
filter_query: str = ''
provider: GraphProvider,
) -> tuple[list[str], dict[str, Any]]:
filter_queries: list[str] = []
filter_params: dict[str, Any] = {}
if filters.node_labels is not None:
node_labels = '|'.join(filters.node_labels)
node_label_filter = ' AND n:' + node_labels
filter_query += node_label_filter
if provider == GraphProvider.KUZU:
node_label_filter = 'list_has_all(n.labels, $labels)'
filter_params['labels'] = filters.node_labels
else:
node_labels = '|'.join(filters.node_labels)
node_label_filter = 'n:' + node_labels
filter_queries.append(node_label_filter)
return filter_query, filter_params
return filter_queries, filter_params
def date_filter_query_constructor(
@ -81,23 +88,29 @@ def date_filter_query_constructor(
def edge_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[str, dict[str, Any]]:
filter_query: str = ''
provider: GraphProvider,
) -> tuple[list[str], dict[str, Any]]:
filter_queries: list[str] = []
filter_params: dict[str, Any] = {}
if filters.edge_types is not None:
edge_types = filters.edge_types
edge_types_filter = '\nAND e.name in $edge_types'
filter_query += edge_types_filter
filter_queries.append('e.name in $edge_types')
filter_params['edge_types'] = edge_types
if filters.node_labels is not None:
node_labels = '|'.join(filters.node_labels)
node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
filter_query += node_label_filter
if provider == GraphProvider.KUZU:
node_label_filter = (
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
)
filter_params['labels'] = filters.node_labels
else:
node_labels = '|'.join(filters.node_labels)
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
filter_queries.append(node_label_filter)
if filters.valid_at is not None:
valid_at_filter = '\nAND ('
valid_at_filter = '('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
else:
valid_at_filter += ' OR '
filter_query += valid_at_filter
filter_queries.append(valid_at_filter)
if filters.invalid_at is not None:
invalid_at_filter = ' AND ('
invalid_at_filter = '('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
else:
invalid_at_filter += ' OR '
filter_query += invalid_at_filter
filter_queries.append(invalid_at_filter)
if filters.created_at is not None:
created_at_filter = ' AND ('
created_at_filter = '('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
else:
created_at_filter += ' OR '
filter_query += created_at_filter
filter_queries.append(created_at_filter)
if filters.expired_at is not None:
expired_at_filter = ' AND ('
expired_at_filter = '('
for i, or_list in enumerate(filters.expired_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
else:
expired_at_filter += ' OR '
filter_query += expired_at_filter
filter_queries.append(expired_at_filter)
return filter_query, filter_params
return filter_queries, filter_params
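The filter constructors now return a list of clause fragments plus parameters, and the Kuzu path matches labels stored in the `n.labels` array rather than Cypher label predicates; a small sketch, assuming the module lives at `graphiti_core.search.search_filters`:

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.search.search_filters import (
    SearchFilters,
    node_search_filter_query_constructor,
)

filters = SearchFilters(node_labels=['Person', 'Employee'])

queries, params = node_search_filter_query_constructor(filters, GraphProvider.KUZU)
print(queries)  # ['list_has_all(n.labels, $labels)']
print(params)   # {'labels': ['Person', 'Employee']}

queries, params = node_search_filter_query_constructor(filters, GraphProvider.NEO4J)
print(queries)  # ['n:Person|Employee']
print(params)   # {}
```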

File diff suppressed because it is too large.


@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from datetime import datetime
@ -22,20 +23,21 @@ import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
EPISODIC_EDGE_SAVE_BULK,
get_entity_edge_save_bulk_query,
get_episodic_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edge,
@ -116,11 +118,15 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
episode.pop('labels', None)
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes: list[dict[str, Any]] = []
nodes = []
for node in entity_nodes:
if node.name_embedding is None:
await node.generate_name_embedding(embedder)
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
@ -130,13 +136,19 @@ async def add_nodes_and_edges_bulk_tx(
'created_at': node.created_at,
}
entity_data.update(node.attributes or {})
entity_data['labels'] = list(
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
)
entity_data['labels'] = list(set(node.labels + ['Entity']))
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
entity_data['attributes'] = json.dumps(attributes)
else:
entity_data.update(node.attributes or {})
entity_data['labels'] = list(
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
)
nodes.append(entity_data)
edges: list[dict[str, Any]] = []
edges = []
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_embedding(embedder)
@ -155,17 +167,36 @@ async def add_nodes_and_edges_bulk_tx(
'invalid_at': edge.invalid_at,
}
edge_data.update(edge.attributes or {})
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
else:
edge_data.update(edge.attributes or {})
edges.append(edge_data)
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
await tx.run(entity_node_save_bulk, nodes=nodes)
await tx.run(
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
)
entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
await tx.run(entity_edge_save_bulk, entity_edges=edges)
if driver.provider == GraphProvider.KUZU:
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
episode_query = get_episode_node_save_bulk_query(driver.provider)
for episode in episodes:
await tx.run(episode_query, **episode)
entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
for node in nodes:
await tx.run(entity_node_query, **node)
entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
for edge in edges:
await tx.run(entity_edge_query, **edge)
episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
for edge in episodic_edges:
await tx.run(episodic_edge_query, **edge.model_dump())
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
await tx.run(
get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
async def extract_nodes_and_edges_bulk(


@ -40,3 +40,16 @@ def ensure_utc(dt: datetime | None) -> datetime | None:
return dt.astimezone(timezone.utc)
return dt
def convert_datetimes_to_strings(obj):
if isinstance(obj, dict):
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_datetimes_to_strings(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(convert_datetimes_to_strings(item) for item in obj)
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj
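`convert_datetimes_to_strings` (moved here from the FalkorDB driver) is what lets the Kuzu path JSON-serialize attribute dicts that contain datetimes; a small sketch:

```python
import json
from datetime import datetime, timezone

from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings

# json.dumps cannot serialize datetime objects, so they are converted to ISO
# strings before the attributes are stored in Kuzu's STRING column.
attributes = {'hired_at': datetime(2024, 1, 2, tzinfo=timezone.utc), 'level': 3}
print(json.dumps(convert_datetimes_to_strings(attributes)))
# {"hired_at": "2024-01-02T00:00:00+00:00", "level": 3}
```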


@ -4,11 +4,12 @@ from collections import defaultdict
from pydantic import BaseModel
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
@ -33,11 +34,11 @@ async def get_community_clusters(
if group_ids is None:
group_id_values, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)
WHERE n.group_id IS NOT NULL
RETURN
collect(DISTINCT n.group_id) AS group_ids
""",
MATCH (n:Entity)
WHERE n.group_id IS NOT NULL
RETURN
collect(DISTINCT n.group_id) AS group_ids
"""
)
group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@ -46,14 +47,21 @@ async def get_community_clusters(
projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
for node in nodes:
records, _, _ = await driver.execute_query(
match_query = """
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
"""
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
WITH count(r) AS count, m.uuid AS uuid
RETURN
uuid,
count
""",
records, _, _ = await driver.execute_query(
match_query
+ """
WITH count(e) AS count, m.uuid AS uuid
RETURN
uuid,
count
""",
uuid=node.uuid,
group_id=group_id,
)
@ -235,9 +243,9 @@ async def build_communities(
async def remove_communities(driver: GraphDriver):
await driver.execute_query(
"""
MATCH (c:Community)
DETACH DELETE c
""",
MATCH (c:Community)
DETACH DELETE c
"""
)
@ -247,14 +255,10 @@ async def determine_entity_community(
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN
c.uuid AS uuid,
c.name AS name,
c.group_id AS group_id,
c.created_at AS created_at,
c.summary AS summary
""",
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN
"""
+ COMMUNITY_NODE_RETURN,
entity_uuid=entity.uuid,
)
@ -262,16 +266,19 @@ async def determine_entity_community(
return get_community_node_from_record(records[0]), False
# If the node has no community, add it to the mode community of surrounding entities
records, _, _ = await driver.execute_query(
match_query = """
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
"""
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
RETURN
c.uuid AS uuid,
c.name AS name,
c.group_id AS group_id,
c.created_at AS created_at,
c.summary AS summary
""",
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ COMMUNITY_NODE_RETURN,
entity_uuid=entity.uuid,
)
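
Kuzu represents each `RELATES_TO` fact as an intermediate `RelatesToNode_` node connected to both entities, which is why the community queries above branch on `driver.provider`. A condensed, hedged sketch of that branching (the helper name is hypothetical and the `group_id` filters are omitted for brevity):

```python
from graphiti_core.driver.driver import GraphProvider


def relates_to_match(provider: GraphProvider) -> str:
    """Hypothetical helper, condensed from the branches above: Kuzu stores each
    RELATES_TO fact on an intermediate RelatesToNode_ node, so the edge
    variable `e` binds to that node rather than to a relationship."""
    if provider == GraphProvider.KUZU:
        return (
            'MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]-'
            '(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)'
        )
    return 'MATCH (n:Entity {uuid: $uuid})-[e:RELATES_TO]-(m:Entity)'
```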


@ -531,17 +531,28 @@ async def filter_existing_duplicate_of_edges(
routing_='r',
)
else:
query: LiteralString = """
UNWIND $duplicate_node_uuids AS duplicate_tuple
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
RETURN DISTINCT
n.uuid AS source_uuid,
m.uuid AS target_uuid
"""
if driver.provider == GraphProvider.KUZU:
query = """
UNWIND $duplicate_node_uuids AS duplicate
MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
RETURN DISTINCT
n.uuid AS source_uuid,
m.uuid AS target_uuid
"""
duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
else:
query: LiteralString = """
UNWIND $duplicate_node_uuids AS duplicate_tuple
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
RETURN DISTINCT
n.uuid AS source_uuid,
m.uuid AS target_uuid
"""
duplicate_node_uuids = list(duplicate_nodes_map.keys())
records, _, _ = await driver.execute_query(
query,
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
duplicate_node_uuids=duplicate_node_uuids,
routing_='r',
)
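
Because the Kuzu branch unwinds maps and reads fields by name (`duplicate.src` / `duplicate.dst`) rather than by tuple index, the parameter list is reshaped before the query runs. A small sketch of the two shapes, with hypothetical UUIDs:

```python
# Hypothetical pairs, shaped like the keys of duplicate_nodes_map above.
duplicate_pairs = [('uuid-a', 'uuid-b'), ('uuid-c', 'uuid-d')]

# Kuzu branch: a list of maps, read in the query as duplicate.src / duplicate.dst.
kuzu_params = [{'src': src, 'dst': dst} for src, dst in duplicate_pairs]

# Other providers: the (src, dst) tuples themselves, read as duplicate_tuple[0] / [1].
default_params = list(duplicate_pairs)
```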


@ -53,10 +53,29 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
for name in index_names
]
)
range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
if driver.provider == GraphProvider.KUZU:
# Skip creating fulltext indices if they already exist. Need to do this manually
# until Kuzu supports `IF NOT EXISTS` for indices.
result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
if len(result) > 0:
fulltext_indices = []
# Only load the `fts` extension if it is not already loaded; loading it a second time raises an error.
result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
if len(result) == 0:
fulltext_indices.insert(
0,
"""
INSTALL fts;
LOAD fts;
""",
)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
@ -76,10 +95,19 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
await tx.run('MATCH (n) DETACH DELETE n')
async def delete_group_ids(tx):
await tx.run(
'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
group_ids=group_ids,
)
labels = ['Entity', 'Episodic', 'Community']
if driver.provider == GraphProvider.KUZU:
labels.append('RelatesToNode_')
for label in labels:
await tx.run(
f"""
MATCH (n:{label})
WHERE n.group_id IN $group_ids
DETACH DELETE n
""",
group_ids=group_ids,
)
if group_ids is None:
await session.execute_write(delete_all)
@ -108,18 +136,23 @@ async def retrieve_episodes(
Returns:
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
"""
group_id_filter: LiteralString = (
'\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
)
source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
query_params: dict = {}
query_filter = ''
if group_ids and len(group_ids) > 0:
query_filter += '\nAND e.group_id IN $group_ids'
query_params['group_ids'] = group_ids
if source is not None:
query_filter += '\nAND e.source = $source'
query_params['source'] = source.name
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ group_id_filter
+ source_filter
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ query_filter
+ """
RETURN
"""
@ -136,9 +169,8 @@ async def retrieve_episodes(
result, _, _ = await driver.execute_query(
query,
reference_time=reference_time,
source=source.name if source is not None else None,
num_episodes=last_n,
group_ids=group_ids,
**query_params,
)
episodes = [get_episodic_node_from_record(record) for record in result]
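
Earlier in this file, the Kuzu branch of `build_indices_and_constraints` has to guard full-text index creation, since Kuzu's `fts` extension must be installed and loaded first and there is no `IF NOT EXISTS` for indices yet. A condensed sketch of that guard (the helper name is hypothetical; the `execute_query` interface is the one used throughout this commit):

```python
async def prepare_kuzu_fulltext_indices(driver, fulltext_indices: list[str]) -> list[str]:
    """Hypothetical helper mirroring the guard in build_indices_and_constraints."""
    # Kuzu has no `IF NOT EXISTS` for indices yet: skip creation if any index exists.
    existing, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
    if existing:
        fulltext_indices = []
    # Install/load the fts extension only when it is not already loaded.
    loaded, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
    if not loaded:
        fulltext_indices = ['INSTALL fts;\nLOAD fts;', *fulltext_indices]
    return fulltext_indices
```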


@ -29,6 +29,7 @@ Repository = "https://github.com/getzep/graphiti"
anthropic = ["anthropic>=0.49.0"]
groq = ["groq>=0.2.0"]
google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
@ -39,6 +40,7 @@ dev = [
"anthropic>=0.49.0",
"google-genai>=1.8.0",
"falkordb>=1.1.2,<2.0.0",
"kuzu>=0.11.2",
"ipykernel>=6.29.5",
"jupyterlab>=4.2.4",
"diskcache-stubs>=5.6.3.6.20240818",
@ -91,7 +93,3 @@ docstring-code-format = true
include = ["graphiti_core"]
pythonVersion = "3.10"
typeCheckingMode = "basic"
[[tool.pyright.overrides]]
include = ["**/falkordb*"]
reportMissingImports = false


@ -1,4 +1,5 @@
[pytest]
markers =
integration: marks tests as integration tests
asyncio_default_fixture_loop_scope = function
asyncio_default_fixture_loop_scope = function
asyncio_mode = auto


@ -15,42 +15,55 @@ limitations under the License.
"""
import os
from unittest.mock import Mock
import numpy as np
import pytest
from dotenv import load_dotenv
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neptune_driver import NeptuneDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.helpers import lucene_sanitize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
HAS_NEO4J = False
HAS_FALKORDB = False
HAS_NEPTUNE = False
drivers: list[GraphProvider] = []
if os.getenv('DISABLE_NEO4J') is None:
try:
from graphiti_core.driver.neo4j_driver import Neo4jDriver
HAS_NEO4J = True
drivers.append(GraphProvider.NEO4J)
except ImportError:
pass
raise
if os.getenv('DISABLE_FALKORDB') is None:
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver
HAS_FALKORDB = True
drivers.append(GraphProvider.FALKORDB)
except ImportError:
pass
raise
if os.getenv('DISABLE_KUZU') is None:
try:
from graphiti_core.driver.kuzu_driver import KuzuDriver
drivers.append(GraphProvider.KUZU)
except ImportError:
raise
# Disable Neptune for now
os.environ['DISABLE_NEPTUNE'] = 'True'
if os.getenv('DISABLE_NEPTUNE') is None:
try:
from graphiti_core.driver.neptune_driver import NeptuneDriver
HAS_NEPTUNE = False
drivers.append(GraphProvider.NEPTUNE)
except ImportError:
pass
raise
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
@ -65,38 +78,100 @@ NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)
KUZU_DB = os.getenv('KUZU_DB', ':memory:')
def get_driver(driver_name: str) -> GraphDriver:
if driver_name == 'neo4j':
group_id = 'graphiti_test_group'
group_id_2 = 'graphiti_test_group_2'
def get_driver(provider: GraphProvider) -> GraphDriver:
if provider == GraphProvider.NEO4J:
return Neo4jDriver(
uri=NEO4J_URI,
user=NEO4J_USER,
password=NEO4J_PASSWORD,
)
elif driver_name == 'falkordb':
elif provider == GraphProvider.FALKORDB:
return FalkorDriver(
host=FALKORDB_HOST,
port=int(FALKORDB_PORT),
username=FALKORDB_USER,
password=FALKORDB_PASSWORD,
)
elif driver_name == 'neptune':
elif provider == GraphProvider.KUZU:
driver = KuzuDriver(
db=KUZU_DB,
)
return driver
elif provider == GraphProvider.NEPTUNE:
return NeptuneDriver(
host=NEPTUNE_HOST,
port=int(NEPTUNE_PORT),
aoss_host=AOSS_HOST,
)
else:
raise ValueError(f'Driver {driver_name} not available')
raise ValueError(f'Driver {provider} not available')
drivers: list[str] = []
if HAS_NEO4J:
drivers.append('neo4j')
if HAS_FALKORDB:
drivers.append('falkordb')
if HAS_NEPTUNE:
drivers.append('neptune')
@pytest.fixture(params=drivers)
async def graph_driver(request):
driver = request.param
graph_driver = get_driver(driver)
await clear_data(graph_driver, [group_id, group_id_2])
try:
yield graph_driver # provide driver to the test
finally:
# always called, even if the test fails or raises
# await clean_up(graph_driver)
await graph_driver.close()
embedding_dim = 384
embeddings = {
key: np.random.uniform(0.0, 0.9, embedding_dim).tolist()
for key in [
'Alice',
'Bob',
'Alice likes Bob',
'test_entity_1',
'test_entity_2',
'test_entity_3',
'test_entity_4',
'test_entity_alice',
'test_entity_bob',
'test_entity_1 is a duplicate of test_entity_2',
'test_entity_3 is a duplicate of test_entity_4',
'test_entity_1 relates to test_entity_2',
'test_entity_1 relates to test_entity_3',
'test_entity_2 relates to test_entity_3',
'test_entity_1 relates to test_entity_4',
'test_entity_2 relates to test_entity_4',
'test_entity_3 relates to test_entity_4',
'test_entity_1 relates to test_entity_2',
'test_entity_3 relates to test_entity_4',
'test_entity_2 relates to test_entity_3',
'test_community_1',
'test_community_2',
]
}
embeddings['Alice Smith'] = embeddings['Alice']
@pytest.fixture
def mock_embedder():
mock_model = Mock(spec=EmbedderClient)
def mock_embed(input_data):
if isinstance(input_data, str):
return embeddings[input_data]
elif isinstance(input_data, list):
combined_input = ' '.join(input_data)
return embeddings[combined_input]
else:
raise ValueError(f'Unsupported input type: {type(input_data)}')
mock_model.create.side_effect = mock_embed
return mock_model
def test_lucene_sanitize():
@ -114,5 +189,125 @@ def test_lucene_sanitize():
assert assert_result == result
async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
results, _, _ = await driver.execute_query(
"""
MATCH (n)
WHERE n.uuid IN $uuids
RETURN COUNT(n) as count
""",
uuids=uuids,
)
return int(results[0]['count'])
async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
results, _, _ = await driver.execute_query(
"""
MATCH (n)-[e]->(m)
WHERE e.uuid IN $uuids
RETURN COUNT(e) as count
UNION ALL
MATCH (e:RelatesToNode_)
WHERE e.uuid IN $uuids
RETURN COUNT(e) as count
""",
uuids=uuids,
)
return sum(int(result['count']) for result in results)
async def print_graph(graph_driver: GraphDriver):
nodes, _, _ = await graph_driver.execute_query(
"""
MATCH (n)
RETURN n.uuid, n.name
""",
)
print('Nodes:')
for node in nodes:
print(' ', node)
edges, _, _ = await graph_driver.execute_query(
"""
MATCH (n)-[e]->(m)
RETURN n.name, e.uuid, m.name
""",
)
print('Edges:')
for edge in edges:
print(' ', edge)
async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
assert retrieved.uuid == sample.uuid
assert retrieved.name == sample.name
assert retrieved.group_id == group_id
assert retrieved.created_at == sample.created_at
assert retrieved.source == sample.source
assert retrieved.source_description == sample.source_description
assert retrieved.content == sample.content
assert retrieved.valid_at == sample.valid_at
assert set(retrieved.entity_edges) == set(sample.entity_edges)
async def assert_entity_node_equals(
graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
):
await retrieved.load_name_embedding(graph_driver)
assert retrieved.uuid == sample.uuid
assert retrieved.name == sample.name
assert retrieved.group_id == sample.group_id
assert set(retrieved.labels) == set(sample.labels)
assert retrieved.created_at == sample.created_at
assert retrieved.name_embedding is not None
assert sample.name_embedding is not None
assert np.allclose(retrieved.name_embedding, sample.name_embedding)
assert retrieved.summary == sample.summary
assert retrieved.attributes == sample.attributes
async def assert_community_node_equals(
graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
):
await retrieved.load_name_embedding(graph_driver)
assert retrieved.uuid == sample.uuid
assert retrieved.name == sample.name
assert retrieved.group_id == group_id
assert retrieved.created_at == sample.created_at
assert retrieved.name_embedding is not None
assert sample.name_embedding is not None
assert np.allclose(retrieved.name_embedding, sample.name_embedding)
assert retrieved.summary == sample.summary
async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
assert retrieved.uuid == sample.uuid
assert retrieved.group_id == sample.group_id
assert retrieved.created_at == sample.created_at
assert retrieved.source_node_uuid == sample.source_node_uuid
assert retrieved.target_node_uuid == sample.target_node_uuid
async def assert_entity_edge_equals(
graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
):
await retrieved.load_fact_embedding(graph_driver)
assert retrieved.uuid == sample.uuid
assert retrieved.group_id == sample.group_id
assert retrieved.created_at == sample.created_at
assert retrieved.source_node_uuid == sample.source_node_uuid
assert retrieved.target_node_uuid == sample.target_node_uuid
assert retrieved.name == sample.name
assert retrieved.fact == sample.fact
assert retrieved.fact_embedding is not None
assert sample.fact_embedding is not None
assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
assert retrieved.episodes == sample.episodes
assert retrieved.expired_at == sample.expired_at
assert retrieved.valid_at == sample.valid_at
assert retrieved.invalid_at == sample.invalid_at
assert retrieved.attributes == sample.attributes
if __name__ == '__main__':
pytest.main([__file__])
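
The new fixtures let the integration tests run against any detected backend without real embedding calls: `graph_driver` is parametrized over the available providers and clears test data around each test, while `mock_embedder` returns fixed vectors for a known set of test strings. An illustrative test showing how they compose (this is a sketch, not one of the committed tests):

```python
from datetime import datetime

import pytest

from graphiti_core.nodes import EntityNode
from tests.helpers_test import get_node_count, group_id


@pytest.mark.asyncio
async def test_entity_round_trip(graph_driver, mock_embedder):
    # 'Alice' is one of the strings the mock embedder knows about.
    node = EntityNode(
        name='Alice',
        labels=[],
        created_at=datetime.now(),
        summary='Alice summary',
        group_id=group_id,
    )
    await node.generate_name_embedding(mock_embedder)  # deterministic vector, no API key needed
    await node.save(graph_driver)
    assert await get_node_count(graph_driver, [node.uuid]) == 1
    await node.delete(graph_driver)
    assert await get_node_count(graph_driver, [node.uuid]) == 0
```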


@ -17,23 +17,16 @@ limitations under the License.
import logging
import sys
from datetime import datetime
from uuid import uuid4
import numpy as np
import pytest
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import drivers, get_driver
pytestmark = pytest.mark.integration
from tests.helpers_test import get_edge_count, get_node_count, group_id
pytest_plugins = ('pytest_asyncio',)
group_id = f'test_group_{str(uuid4())}'
def setup_logging():
# Create a logger
@ -57,17 +50,10 @@ def setup_logging():
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_episodic_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
async def test_episodic_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create episodic node
episode_node = EpisodicNode(
name='test_episode',
labels=[],
@ -79,13 +65,13 @@ async def test_episodic_edge(driver):
entity_edges=[],
group_id=group_id,
)
node_count = await get_node_count(graph_driver, episode_node.uuid)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await episode_node.save(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 1
# Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
@ -93,27 +79,27 @@ async def test_episodic_edge(driver):
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create episodic to entity edge
episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
await episodic_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid
@ -121,6 +107,7 @@ async def test_episodic_edge(driver):
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
@ -129,6 +116,7 @@ async def test_episodic_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
@ -137,33 +125,41 @@ async def test_episodic_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get episodic node by entity node uuid
retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == episode_node.uuid
assert retrieved[0].name == 'test_episode'
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await episodic_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 0
# Delete edge by uuids
await episodic_edge.save(graph_driver)
await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_entity_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
async def test_entity_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
@ -171,25 +167,25 @@ async def test_entity_edge(driver):
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create entity node
bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
)
await bob_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, bob_node.uuid)
await bob_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await bob_node.save(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 1
# Create entity to entity edge
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
@ -202,14 +198,14 @@ async def test_entity_edge(driver):
invalid_at=now,
group_id=group_id,
)
edge_embedding = await entity_edge.generate_embedding(embedder)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
edge_embedding = await entity_edge.generate_embedding(mock_embedder)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
await entity_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid
@ -217,6 +213,7 @@ async def test_entity_edge(driver):
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
@ -225,6 +222,7 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
@ -233,6 +231,7 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by node uuid
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
@ -241,82 +240,113 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get fact embedding
await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
# Delete edge by uuid
await entity_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
# Delete edge by uuids
await entity_edge.save(graph_driver)
await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node should delete the edge
await entity_edge.save(graph_driver)
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by uuids should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by group id should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await bob_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_community_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
async def test_community_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create community node
community_node_1 = CommunityNode(
name='Community A',
name='test_community_1',
group_id=group_id,
summary='Community A summary',
)
await community_node_1.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
await community_node_1.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_1.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 1
# Create community node
community_node_2 = CommunityNode(
name='Community B',
name='test_community_2',
group_id=group_id,
summary='Community B summary',
)
await community_node_2.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
await community_node_2.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await community_node_2.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 1
# Create entity node
alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create community to community edge
community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
await community_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid
@ -324,6 +354,7 @@ async def test_community_edge(driver):
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
@ -332,6 +363,7 @@ async def test_community_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
@ -340,45 +372,26 @@ async def test_community_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await community_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await community_edge.save(graph_driver)
await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await community_node_1.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_2.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await graph_driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(results[0]['count'])
async def get_edge_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n)-[e {uuid: $uuid}]->(m)
RETURN COUNT(e) as count
UNION ALL
MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
RETURN COUNT(m) as count
""",
uuid=uuid,
)
return sum(int(result['count']) for result in results)


@ -60,7 +60,6 @@ class Location(BaseModel):
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_default_entity_type(driver):
"""Test excluding the default 'Entity' type while keeping custom types."""
@ -118,7 +117,6 @@ async def test_exclude_default_entity_type(driver):
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_specific_custom_types(driver):
"""Test excluding specific custom entity types while keeping others."""
@ -182,7 +180,6 @@ async def test_exclude_specific_custom_types(driver):
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_all_types(driver):
"""Test excluding all entity types (edge case)."""
@ -231,7 +228,6 @@ async def test_exclude_all_types(driver):
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_exclude_no_types(driver):
"""Test normal behavior when no types are excluded (baseline test)."""
@ -314,7 +310,6 @@ def test_validation_invalid_excluded_types():
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
"""Test that add_episode validates excluded_entity_types parameter."""


@ -23,7 +23,7 @@ from graphiti_core.graphiti import Graphiti
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_helpers import search_results_to_context_string
from graphiti_core.utils.datetime_utils import utc_now
from tests.helpers_test import drivers, get_driver
from tests.helpers_test import GraphProvider
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
@ -51,15 +51,12 @@ def setup_logging():
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_graphiti_init(driver):
async def test_graphiti_init(graph_driver):
if graph_driver.provider == GraphProvider.FALKORDB:
pytest.skip('Skipping as tests fail on Falkordb')
logger = setup_logging()
driver = get_driver(driver)
graphiti = Graphiti(graph_driver=driver)
graphiti = Graphiti(graph_driver=graph_driver)
await graphiti.build_indices_and_constraints()
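
Outside the test suite, wiring Graphiti to the new backend follows the same shape as the fixture above. A hedged sketch using an in-memory Kuzu database (constructor arguments as shown in this commit's tests; a configured LLM/embedder, e.g. an OpenAI key, is still assumed for anything beyond index setup):

```python
import asyncio

from graphiti_core.driver.kuzu_driver import KuzuDriver
from graphiti_core.graphiti import Graphiti


async def main():
    # In-memory database, as used by the test helpers (KUZU_DB defaults to ':memory:').
    driver = KuzuDriver(db=':memory:')
    graphiti = Graphiti(graph_driver=driver)
    await graphiti.build_indices_and_constraints()
    # ... add episodes / run searches here ...
    await driver.close()


if __name__ == '__main__':
    asyncio.run(main())
```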

tests/test_graphiti_mock.py (new file, 2056 lines): diff suppressed because it is too large.


@ -14,22 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime
from datetime import datetime, timedelta
from uuid import uuid4
import numpy as np
import pytest
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
from tests.helpers_test import drivers, get_driver
from tests.helpers_test import (
assert_community_node_equals,
assert_entity_node_equals,
assert_episodic_node_equals,
get_node_count,
group_id,
)
group_id = f'test_group_{str(uuid4())}'
created_at = datetime.now()
deleted_at = created_at + timedelta(days=3)
valid_at = created_at + timedelta(days=1)
invalid_at = created_at + timedelta(days=2)
@pytest.fixture
@ -38,9 +45,14 @@ def sample_entity_node():
uuid=str(uuid4()),
name='Test Entity',
group_id=group_id,
labels=[],
labels=['Entity', 'Person'],
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Entity Summary',
attributes={
'age': 30,
'location': 'New York',
},
)
@ -50,10 +62,12 @@ def sample_episodic_node():
uuid=str(uuid4()),
name='Episode 1',
group_id=group_id,
created_at=created_at,
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=datetime.now(),
valid_at=valid_at,
entity_edges=[],
)
@ -62,182 +76,152 @@ def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
name_embedding=[0.5] * 1024,
group_id=group_id,
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Community summary',
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_entity_node(sample_entity_node, driver):
driver = get_driver(driver)
async def test_entity_node(sample_entity_node, graph_driver):
uuid = sample_entity_node.uuid
# Create node
node_count = await get_node_count(driver, uuid)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await sample_entity_node.save(driver)
node_count = await get_node_count(driver, uuid)
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
assert retrieved.uuid == sample_entity_node.uuid
assert retrieved.name == 'Test Entity'
assert retrieved.group_id == group_id
# Get node by uuid
retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)
retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
assert retrieved[0].uuid == sample_entity_node.uuid
assert retrieved[0].name == 'Test Entity'
assert retrieved[0].group_id == group_id
# Get node by uuids
retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
# Get node by group ids
retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True)
assert len(retrieved) == 1
assert retrieved[0].uuid == sample_entity_node.uuid
assert retrieved[0].name == 'Test Entity'
assert retrieved[0].group_id == group_id
await sample_entity_node.load_name_embedding(driver)
assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
# Delete node by uuid
await sample_entity_node.delete(driver)
node_count = await get_node_count(driver, uuid)
await sample_entity_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
await sample_entity_node.save(driver)
node_count = await get_node_count(driver, uuid)
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_entity_node.delete_by_group_id(driver, group_id)
node_count = await get_node_count(driver, uuid)
await sample_entity_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await driver.close()
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_community_node(sample_community_node, driver):
driver = get_driver(driver)
async def test_community_node(sample_community_node, graph_driver):
uuid = sample_community_node.uuid
# Create node
node_count = await get_node_count(driver, uuid)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await sample_community_node.save(driver)
node_count = await get_node_count(driver, uuid)
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
assert retrieved.uuid == sample_community_node.uuid
assert retrieved.name == 'Community A'
assert retrieved.group_id == group_id
assert retrieved.summary == 'Community summary'
# Get node by uuid
retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
await assert_community_node_equals(graph_driver, retrieved, sample_community_node)
retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
assert retrieved[0].uuid == sample_community_node.uuid
assert retrieved[0].name == 'Community A'
assert retrieved[0].group_id == group_id
assert retrieved[0].summary == 'Community summary'
# Get node by uuids
retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
# Get node by group ids
retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == sample_community_node.uuid
assert retrieved[0].name == 'Community A'
assert retrieved[0].group_id == group_id
await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
# Delete node by uuid
await sample_community_node.delete(driver)
node_count = await get_node_count(driver, uuid)
await sample_community_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_community_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
await sample_community_node.save(driver)
node_count = await get_node_count(driver, uuid)
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_community_node.delete_by_group_id(driver, group_id)
node_count = await get_node_count(driver, uuid)
await sample_community_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await driver.close()
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_episodic_node(sample_episodic_node, driver):
driver = get_driver(driver)
async def test_episodic_node(sample_episodic_node, graph_driver):
uuid = sample_episodic_node.uuid
# Create node
node_count = await get_node_count(driver, uuid)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await sample_episodic_node.save(driver)
node_count = await get_node_count(driver, uuid)
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
assert retrieved.uuid == sample_episodic_node.uuid
assert retrieved.name == 'Episode 1'
assert retrieved.group_id == group_id
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
assert retrieved.valid_at == sample_episodic_node.valid_at
# Get node by uuid
retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
await assert_episodic_node_equals(retrieved, sample_episodic_node)
retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
assert retrieved[0].uuid == sample_episodic_node.uuid
assert retrieved[0].name == 'Episode 1'
assert retrieved[0].group_id == group_id
assert retrieved[0].source == EpisodeType.text
assert retrieved[0].source_description == 'Test source'
assert retrieved[0].content == 'Some content here'
assert retrieved[0].valid_at == sample_episodic_node.valid_at
# Get node by uuids
retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
# Get node by group ids
retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == sample_episodic_node.uuid
assert retrieved[0].name == 'Episode 1'
assert retrieved[0].group_id == group_id
assert retrieved[0].source == EpisodeType.text
assert retrieved[0].source_description == 'Test source'
assert retrieved[0].content == 'Some content here'
assert retrieved[0].valid_at == sample_episodic_node.valid_at
await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
# Delete node by uuid
await sample_episodic_node.delete(driver)
node_count = await get_node_count(driver, uuid)
await sample_episodic_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
await sample_episodic_node.save(driver)
node_count = await get_node_count(driver, uuid)
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_episodic_node.delete_by_group_id(driver, group_id)
node_count = await get_node_count(driver, uuid)
await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
result, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(result[0]['count'])
await graph_driver.close()

uv.lock (generated, 42 changed lines shown)

@ -809,6 +809,7 @@ dev = [
{ name = "groq" },
{ name = "ipykernel" },
{ name = "jupyterlab" },
{ name = "kuzu" },
{ name = "langchain-anthropic" },
{ name = "langchain-openai" },
{ name = "langgraph" },
@ -831,6 +832,9 @@ google-genai = [
groq = [
{ name = "groq" },
]
kuzu = [
{ name = "kuzu" },
]
neptune = [
{ name = "boto3" },
{ name = "langchain-aws" },
@ -858,6 +862,8 @@ requires-dist = [
{ name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" },
{ name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" },
{ name = "jupyterlab", marker = "extra == 'dev'", specifier = ">=4.2.4" },
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
@ -882,7 +888,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
]
provides-extras = ["anthropic", "groq", "google-genai", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
[[package]]
name = "groq"
@ -1387,6 +1393,40 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700, upload-time = "2024-07-16T17:02:01.115Z" },
]
[[package]]
name = "kuzu"
version = "0.11.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/66/fd/adbd05ccf81e6ad2674fcd3849d5d6ffeaf2141a9b8d1c1c4e282e923e1f/kuzu-0.11.2.tar.gz", hash = "sha256:9f224ec218ab165a18acaea903695779780d70335baf402d9b7f59ba389db0bd", size = 4902887, upload-time = "2025-08-21T05:17:00.152Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/91/bed837f5f49220a9f869da8a078b34a3484f210f7b57b267177821545c03/kuzu-0.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b25174cdb721aae47896ed62842d3859679607b493a9a6bbbcd9fb7fb3707", size = 3702618, upload-time = "2025-08-21T05:15:53.726Z" },
{ url = "https://files.pythonhosted.org/packages/72/8a/fd5e053b0055718afe00b6a99393a835c6254354128fbb7f66a35fd76089/kuzu-0.11.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:9a8567c53bfe282f4727782471ff718842ffead8c48c1762c1df9197408fc986", size = 4101371, upload-time = "2025-08-21T05:15:55.889Z" },
{ url = "https://files.pythonhosted.org/packages/ad/4b/e45cadc85bdc5079f432675bbe8d557600f0d4ab46fe24ef218374419902/kuzu-0.11.2-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d793bb5a0a14ada730a697eccac2a4c68b434b82692d985942900ef2003e099e", size = 6211974, upload-time = "2025-08-21T05:15:57.505Z" },
{ url = "https://files.pythonhosted.org/packages/10/ca/92d6a1e6452fcf06bfc423ce2cde819ace6b6e47921921cc8fae87c27780/kuzu-0.11.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1be4e9b6c93ca8591b1fb165f9b9a27d70a56af061831afcdfe7aebb89ee6ff", size = 6992196, upload-time = "2025-08-21T05:15:59.006Z" },
{ url = "https://files.pythonhosted.org/packages/49/6c/983fc6265dfc1169c87c4a0722f36ee665c5688e1166faeb4cd85e6af078/kuzu-0.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0ec7a304c746a2a98ecfd7e7c3f6fe92c4dfee2e2565c0b7cb4cffd0c2e374a", size = 4303517, upload-time = "2025-08-21T05:16:00.814Z" },
{ url = "https://files.pythonhosted.org/packages/b5/14/8ae2c52657b93715052ecf47d70232f2c8d9ffe2d1ec3527c8e9c3cb2df5/kuzu-0.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf53b4f321a4c05882b14cef96d39a1e90fa993bab88a1554fb1565367553b8c", size = 3704177, upload-time = "2025-08-21T05:16:02.354Z" },
{ url = "https://files.pythonhosted.org/packages/2d/7a/bce7bb755e16f9ca855f76a3acc6cfa9fae88c4d6af9df3784c50b2120a5/kuzu-0.11.2-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:2d749883b74f5da5ff4a4b0635a98f6cc5165743995828924321d2ca797317cb", size = 4102372, upload-time = "2025-08-21T05:16:04.249Z" },
{ url = "https://files.pythonhosted.org/packages/c8/12/f5b1d51fcb78a86c078fb85cc53184ce962a3e86852d47d30e287a932e3f/kuzu-0.11.2-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:632507e5982928ed24fbb5e70ad143d7970bc4059046e77e0522707efbad303b", size = 6212492, upload-time = "2025-08-21T05:16:05.99Z" },
{ url = "https://files.pythonhosted.org/packages/81/96/d6e57af6ccf9e0697812ad3039c80b87b768cf2674833b0b23d317ea3427/kuzu-0.11.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5211884601f8f08ae97ba25006d0facde24077c5333411d944282b8a2068ab4", size = 6992888, upload-time = "2025-08-21T05:16:07.896Z" },
{ url = "https://files.pythonhosted.org/packages/40/ee/1f275ac5679a3f615ce0d9cf8c79001fdb535ccc8bc344e49b14484c7cd7/kuzu-0.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:82a6c8bfe1278dc1010790e398bf772683797ef5c16052fa0f6f78bacbc59aa3", size = 4304064, upload-time = "2025-08-21T05:16:10.163Z" },
{ url = "https://files.pythonhosted.org/packages/73/ba/9f20d9e83681a0ddae8ec13046b116c34745fa0e66862d4e2d8414734ce0/kuzu-0.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aed88ffa695d07289a3d8557bd8f9e743298a4f4349208a60bbb06f4ebf15c26", size = 3703781, upload-time = "2025-08-21T05:16:12.232Z" },
{ url = "https://files.pythonhosted.org/packages/53/a0/bb815c0490f3d4d30389156369b9fe641e154f0d4b1e8340f09a76021922/kuzu-0.11.2-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:595824b03248af928e3faee57f6825d3a46920f2d3b9bd0c0bb7fc3fa097fce9", size = 4103990, upload-time = "2025-08-21T05:16:14.139Z" },
{ url = "https://files.pythonhosted.org/packages/a5/6f/97b647c0547a634a669055ff4cfd21a92ea3999aedc6a7fe9004f03f25e3/kuzu-0.11.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5674c6d9d26f5caa0c7ce6f34c02e4411894879aa5b2ce174fad576fa898523", size = 6211947, upload-time = "2025-08-21T05:16:16.48Z" },
{ url = "https://files.pythonhosted.org/packages/42/74/c7f1a1cfb08c05c91c5a94483be387e80fafab8923c4243a22e9cced5c1b/kuzu-0.11.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c61daf02da35b671f4c6f3c17105725c399a5e14b7349b00eafbcd24ac90034a", size = 6991879, upload-time = "2025-08-21T05:16:18.402Z" },
{ url = "https://files.pythonhosted.org/packages/54/9e/50d67d7bc08faed95ede6de1a6aa0d81079c98028ca99e32d09c2ab1aead/kuzu-0.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:682096cd87dcbb8257f933ea4172d9dc5617a8d0a5bdd19cd66cf05b68881afd", size = 4305706, upload-time = "2025-08-21T05:16:20.244Z" },
{ url = "https://files.pythonhosted.org/packages/65/f0/5649a01af37def50293cd7c194afc19f09b343fd2b7f2b28e021a207f8ce/kuzu-0.11.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:17a11b67652e8b331c85cd1a1a30b32ee6783750084473abbab2aa1963ee2a3b", size = 3703740, upload-time = "2025-08-21T05:16:21.896Z" },
{ url = "https://files.pythonhosted.org/packages/24/e2/e0beb9080911fc1689899a42da0f83534949f43169fb80197def3ec1223f/kuzu-0.11.2-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:bdded35426210faeca8da11e8b4a54e60ccc0c1a832660d76587b5be133b0f55", size = 4104073, upload-time = "2025-08-21T05:16:23.819Z" },
{ url = "https://files.pythonhosted.org/packages/f2/4c/7a831c9c6e609692953db677f54788bd1dde4c9d34e6ba91f1e153d2e7fe/kuzu-0.11.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6116b609aac153f3523130b31295643d34a6c9509914c0fa9d804b26b23eee73", size = 6212263, upload-time = "2025-08-21T05:16:25.351Z" },
{ url = "https://files.pythonhosted.org/packages/47/95/615ef10b46b22ec1d33fdbba795e6e79733d9a244aabdeeb910f267ab36c/kuzu-0.11.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:09da5b8cb24dc6b281a6e4ac0f7f24226eb9909803b187e02d014da13ba57bcf", size = 6992492, upload-time = "2025-08-21T05:16:27.518Z" },
{ url = "https://files.pythonhosted.org/packages/a7/dd/2c905575913c743e6c67a5ca89a6b4ea9d9737238966d85d7e710f0d3e60/kuzu-0.11.2-cp313-cp313-win_amd64.whl", hash = "sha256:c663fb84682f8ebffbe7447a4e552a0e03bd29097d319084a2c53c2e032a780e", size = 4305267, upload-time = "2025-08-21T05:16:29.307Z" },
{ url = "https://files.pythonhosted.org/packages/89/05/44fbfc9055dba3f472ea4aaa8110635864d3441eede987526ef401680765/kuzu-0.11.2-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5c03fb95ffb9185c1519333f8ee92b7a9695aa7aa9a179e868a7d7bd13d10a16", size = 6216795, upload-time = "2025-08-21T05:16:30.944Z" },
{ url = "https://files.pythonhosted.org/packages/4f/ca/16c81dc68cc1e8918f8481e7ee89c28aa665c5cb36be7ad0fc1d0d295760/kuzu-0.11.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d857f0efddf26d5e2dc189facb84bf04a997e395972486669b418a470cc76034", size = 6996333, upload-time = "2025-08-21T05:16:32.568Z" },
{ url = "https://files.pythonhosted.org/packages/48/d8/9275c7e6312bd76dc670e8e2da68639757c22cf2c366e96527595a1d881c/kuzu-0.11.2-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb9e4641867c35b98ceaa604aa79832c0eeed41f5fd1b6da22b1c217b2f1b8ea", size = 6212202, upload-time = "2025-08-21T05:16:34.571Z" },
{ url = "https://files.pythonhosted.org/packages/88/89/67a977493c60bca3610845df13020711f357a5d80bf91549e4b48d877c2f/kuzu-0.11.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:553408d76a0b4fdecc1338b69b71d7bde42f6936d3b99d9852b30d33bda15978", size = 6992264, upload-time = "2025-08-21T05:16:36.316Z" },
{ url = "https://files.pythonhosted.org/packages/b6/49/869ceceb1d8a5ea032a35c734e55cfee919340889973623096da7eb94f6b/kuzu-0.11.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:989a87fa13ffa39ab7773d968fe739ac4f8faf9ddb5dad72ced2eeef12180293", size = 6216814, upload-time = "2025-08-21T05:16:38.348Z" },
{ url = "https://files.pythonhosted.org/packages/bc/cd/933b34a246edb882a042eb402747167719222c05149b73b48ba7d310d554/kuzu-0.11.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e67420d04a9643fd6376a23b17b398a3e32bb0c2bd8abbf8d1e4697056596c7e", size = 6996343, upload-time = "2025-08-21T05:16:39.973Z" },
]
[[package]]
name = "langchain-anthropic"
version = "0.3.17"