Add support for Kuzu as the graph driver (#799)

* Fix FalkorDB tests

* Add support for graph memory using Kuzu

* Fix lints

* Fix queries

* Add tests

* Add comments

* Add more test coverage

* Add mocked tests

* Format

* Add mocked tests II

* Refactor community queries

* Add more mocked tests

* Refactor tests to always cleanup

* Add more mocked tests

* Update kuzu

* Refactor how filters are built

* Add more mocked tests

* Refactor and cleanup

* Fix tests

* Fix lints

* Refactor tests

* Disable neptune

* Fix

* Update kuzu version

* Update kuzu to latest release

* Fix filter

* Fix query

* Fix Neptune query

* Fix bulk queries

* Fix lints

* Fix deletes

* Comments and format

* Add Kuzu to the README

* Fix bulk queries

* Test all fields of nodes and edges

* Fix lints

* Update search_utils.py

---------

Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Siddhartha Sahu 2025-08-27 11:45:21 -04:00 committed by GitHub
parent 309159bccb
commit 8802b7db13
30 changed files with 4219 additions and 966 deletions


@@ -49,6 +49,7 @@ jobs:
NEO4J_URI: bolt://localhost:7687
NEO4J_USER: neo4j
NEO4J_PASSWORD: testpass
+ DISABLE_NEPTUNE: 1
run: |
uv run pytest -m "not integration"
- name: Wait for FalkorDB


@@ -44,7 +44,7 @@ Use Graphiti to:
<br />
<p align="center">
<img src="images/graphiti-graph-intro.gif" alt="Graphiti temporal walkthrough" width="700px">
</p>
<br />
@@ -80,7 +80,7 @@ Traditional RAG approaches often rely on batch processing and static data summar
- **Scalability:** Efficiently manages large datasets with parallel processing, suitable for enterprise environments.
<p align="center">
<img src="/images/graphiti-intro-slides-stock-2.gif" alt="Graphiti structured + unstructured demo" width="700px">
</p>
## Graphiti vs. GraphRAG
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
Requirements:
- Python 3.10 or higher
- - Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
+ - Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
> [!IMPORTANT]
@@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
uv add graphiti-core[falkordb]
```
+ ### Installing with Kuzu Support
+ If you plan to use Kuzu as your graph database backend, install with the Kuzu extra:
+ ```bash
+ pip install graphiti-core[kuzu]
+ # or with uv
+ uv add graphiti-core[kuzu]
+ ```
### Installing with Amazon Neptune Support
If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
@@ -198,7 +209,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT`
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
- 1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
+ 1. Connecting to a Neo4j, Amazon Neptune, FalkorDB, or Kuzu database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph (both text and structured JSON)
4. Searching for relationships (edges) using hybrid search
@@ -281,6 +292,19 @@ driver = FalkorDriver(
graphiti = Graphiti(graph_driver=driver)
```
+ #### Kuzu
+ ```python
+ from graphiti_core import Graphiti
+ from graphiti_core.driver.kuzu_driver import KuzuDriver
+ # Create a Kuzu driver
+ driver = KuzuDriver(db="/tmp/graphiti.kuzu")
+ # Pass the driver to Graphiti
+ graphiti = Graphiti(graph_driver=driver)
+ ```
#### Amazon Neptune
```python
@@ -494,7 +518,7 @@ When you initialize a Graphiti instance, we collect:
- **Graphiti version**: The version you're using
- **Configuration choices**:
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
- - Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
+ - Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
- Embedder provider (OpenAI, Azure, Voyage, etc.)
### What We Don't Collect
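Taken together, the README additions above amount to the following minimal setup sketch. This is illustrative only and not part of the commit; it assumes the `graphiti-core[kuzu]` extra is installed, an OpenAI API key is configured, and that `build_indices_and_constraints()` and `close()` remain the usual Graphiti entry points.

```python
import asyncio

from graphiti_core import Graphiti
from graphiti_core.driver.kuzu_driver import KuzuDriver


async def main():
    # Kuzu is embedded: a file path (or ":memory:") is all the driver needs.
    driver = KuzuDriver(db='/tmp/graphiti.kuzu')
    graphiti = Graphiti(graph_driver=driver)
    try:
        # Create the Kuzu tables and full text indexes before adding episodes.
        await graphiti.build_indices_and_constraints()
    finally:
        await graphiti.close()


asyncio.run(main())
```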


@@ -4,3 +4,7 @@ import sys
# This code adds the project root directory to the Python path, allowing imports to work correctly when running tests.
# Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))
+ from tests.helpers_test import graph_driver, mock_embedder
+ __all__ = ['graph_driver', 'mock_embedder']


@@ -27,10 +27,13 @@ logger = logging.getLogger(__name__)
class GraphProvider(Enum):
NEO4J = 'neo4j'
FALKORDB = 'falkordb'
+ KUZU = 'kuzu'
NEPTUNE = 'neptune'
class GraphDriverSession(ABC):
+ provider: GraphProvider
async def __aenter__(self):
return self


@@ -15,7 +15,6 @@ limitations under the License.
"""
import logging
- from datetime import datetime
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
@@ -33,11 +32,14 @@ else:
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
logger = logging.getLogger(__name__)
class FalkorDriverSession(GraphDriverSession):
+ provider = GraphProvider.FALKORDB
def __init__(self, graph: FalkorGraph):
self.graph = graph
@@ -164,16 +166,3 @@ class FalkorDriver(GraphDriver):
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned
- def convert_datetimes_to_strings(obj):
- if isinstance(obj, dict):
- return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [convert_datetimes_to_strings(item) for item in obj]
- elif isinstance(obj, tuple):
- return tuple(convert_datetimes_to_strings(item) for item in obj)
- elif isinstance(obj, datetime):
- return obj.isoformat()
- else:
- return obj


@@ -0,0 +1,175 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from typing import Any
import kuzu
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
# Kuzu requires an explicit schema.
# As Kuzu currently does not support creating full text indexes on edge properties,
# we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
# (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
SCHEMA_QUERIES = """
CREATE NODE TABLE IF NOT EXISTS Episodic (
uuid STRING PRIMARY KEY,
name STRING,
group_id STRING,
created_at TIMESTAMP,
source STRING,
source_description STRING,
content STRING,
valid_at TIMESTAMP,
entity_edges STRING[]
);
CREATE NODE TABLE IF NOT EXISTS Entity (
uuid STRING PRIMARY KEY,
name STRING,
group_id STRING,
labels STRING[],
created_at TIMESTAMP,
name_embedding FLOAT[],
summary STRING,
attributes STRING
);
CREATE NODE TABLE IF NOT EXISTS Community (
uuid STRING PRIMARY KEY,
name STRING,
group_id STRING,
created_at TIMESTAMP,
name_embedding FLOAT[],
summary STRING
);
CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
uuid STRING PRIMARY KEY,
group_id STRING,
created_at TIMESTAMP,
name STRING,
fact STRING,
fact_embedding FLOAT[],
episodes STRING[],
expired_at TIMESTAMP,
valid_at TIMESTAMP,
invalid_at TIMESTAMP,
attributes STRING
);
CREATE REL TABLE IF NOT EXISTS RELATES_TO(
FROM Entity TO RelatesToNode_,
FROM RelatesToNode_ TO Entity
);
CREATE REL TABLE IF NOT EXISTS MENTIONS(
FROM Episodic TO Entity,
uuid STRING PRIMARY KEY,
group_id STRING,
created_at TIMESTAMP
);
CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
FROM Community TO Entity,
FROM Community TO Community,
uuid STRING,
group_id STRING,
created_at TIMESTAMP
);
"""
class KuzuDriver(GraphDriver):
provider: GraphProvider = GraphProvider.KUZU
def __init__(
self,
db: str = ':memory:',
max_concurrent_queries: int = 1,
):
super().__init__()
self.db = kuzu.Database(db)
self.setup_schema()
self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)
async def execute_query(
self, cypher_query_: str, **kwargs: Any
) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
params = {k: v for k, v in kwargs.items() if v is not None}
# Kuzu does not support these parameters.
params.pop('database_', None)
params.pop('routing_', None)
try:
results = await self.client.execute(cypher_query_, parameters=params)
except Exception as e:
params = {k: (v[:5] if isinstance(v, list) else v) for k, v in params.items()}
logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{params}')
raise
if not results:
return [], None, None
if isinstance(results, list):
dict_results = [list(result.rows_as_dict()) for result in results]
else:
dict_results = list(results.rows_as_dict())
return dict_results, None, None # type: ignore
def session(self, _database: str | None = None) -> GraphDriverSession:
return KuzuDriverSession(self)
async def close(self):
# Do not explicitly close the connection; instead rely on GC.
pass
def delete_all_indexes(self, database_: str):
pass
def setup_schema(self):
conn = kuzu.Connection(self.db)
conn.execute(SCHEMA_QUERIES)
conn.close()
class KuzuDriverSession(GraphDriverSession):
provider = GraphProvider.KUZU
def __init__(self, driver: KuzuDriver):
self.driver = driver
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
# No cleanup needed for Kuzu, but method must exist.
pass
async def close(self):
# Do not close the session here, as we're reusing the driver connection.
pass
async def execute_write(self, func, *args, **kwargs):
# Directly await the provided async function with `self` as the transaction/session
return await func(self, *args, **kwargs)
async def run(self, query: str | list, **kwargs: Any) -> Any:
if isinstance(query, list):
for cypher, params in query:
await self.driver.execute_query(cypher, **params)
else:
await self.driver.execute_query(query, **kwargs)
return None
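To make the RelatesToNode_ workaround described at the top of the new driver concrete, here is a small illustrative sketch (not part of the diff) contrasting the Cypher shape used on Neo4j with the one used on Kuzu; the patterns mirror the `load_fact_embedding` queries further down in this commit.

```python
# Illustrative only: the same fact lookup under the two data models.
# On Neo4j the fact lives on the RELATES_TO relationship itself; on Kuzu it lives on an
# intermediate RelatesToNode_ node, because Kuzu cannot build full text indexes over
# relationship properties.
NEO4J_FACT_QUERY = """
    MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
    RETURN e.fact AS fact
"""

KUZU_FACT_QUERY = """
    MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
    RETURN e.fact AS fact
"""

# Either string can be run through the driver shown above, e.g.:
#     records, _, _ = await driver.execute_query(KUZU_FACT_QUERY, uuid=edge_uuid)
```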


@@ -271,6 +271,8 @@ class NeptuneDriver(GraphDriver):
class NeptuneDriverSession(GraphDriverSession):
+ provider = GraphProvider.NEPTUNE
def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
self.driver = driver


@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+ import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
@@ -30,11 +31,10 @@ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_RETURN,
- ENTITY_EDGE_RETURN,
- ENTITY_EDGE_RETURN_NEPTUNE,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
+ get_entity_edge_return_query,
get_entity_edge_save_query,
)
from graphiti_core.nodes import Node
@@ -53,33 +53,63 @@ class Edge(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
- result = await driver.execute_query(
- """
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
- DELETE e
- """,
- uuid=self.uuid,
- )
+ if driver.provider == GraphProvider.KUZU:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
+ DELETE e
+ """,
+ uuid=self.uuid,
+ )
+ await driver.execute_query(
+ """
+ MATCH (e:RelatesToNode_ {uuid: $uuid})
+ DETACH DELETE e
+ """,
+ uuid=self.uuid,
+ )
+ else:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
+ DELETE e
+ """,
+ uuid=self.uuid,
+ )
logger.debug(f'Deleted Edge: {self.uuid}')
- return result
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
- result = await driver.execute_query(
- """
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
- WHERE e.uuid IN $uuids
- DELETE e
- """,
- uuids=uuids,
- )
+ if driver.provider == GraphProvider.KUZU:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
+ WHERE e.uuid IN $uuids
+ DELETE e
+ """,
+ uuids=uuids,
+ )
+ await driver.execute_query(
+ """
+ MATCH (e:RelatesToNode_)
+ WHERE e.uuid IN $uuids
+ DETACH DELETE e
+ """,
+ uuids=uuids,
+ )
+ else:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
+ WHERE e.uuid IN $uuids
+ DELETE e
+ """,
+ uuids=uuids,
+ )
logger.debug(f'Deleted Edges: {uuids}')
- return result
def __hash__(self):
return hash(self.uuid)
@@ -166,7 +196,7 @@ class EpisodicEdge(Edge):
"""
+ EPISODIC_EDGE_RETURN
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@@ -215,15 +245,21 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
- if driver.provider == GraphProvider.NEPTUNE:
- query: LiteralString = """
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
- RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
- """
- else:
- query: LiteralString = """
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
- RETURN e.fact_embedding AS fact_embedding
- """
+ query = """
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+ RETURN e.fact_embedding AS fact_embedding
+ """
+ if driver.provider == GraphProvider.NEPTUNE:
+ query = """
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+ RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
+ """
+ if driver.provider == GraphProvider.KUZU:
+ query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
+ RETURN e.fact_embedding AS fact_embedding
+ """
records, _, _ = await driver.execute_query(
@@ -253,15 +289,22 @@ class EntityEdge(Edge):
'invalid_at': self.invalid_at,
}
- edge_data.update(self.attributes or {})
+ if driver.provider == GraphProvider.KUZU:
+ edge_data['attributes'] = json.dumps(self.attributes)
+ result = await driver.execute_query(
+ get_entity_edge_save_query(driver.provider),
+ **edge_data,
+ )
+ else:
+ edge_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -269,21 +312,25 @@ class EntityEdge(Edge):
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
+ """
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- ),
+ + get_entity_edge_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
@@ -294,22 +341,26 @@ class EntityEdge(Edge):
if len(uuids) == 0:
return []
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
WHERE e.uuid IN $uuids
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- ),
+ + get_entity_edge_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -332,23 +383,27 @@ class EntityEdge(Edge):
else ''
)
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- )
+ + get_entity_edge_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@@ -357,7 +412,7 @@ class EntityEdge(Edge):
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise GroupsEdgesNotFoundError(group_ids)
@@ -365,21 +420,25 @@ class EntityEdge(Edge):
@classmethod
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- ),
+ + get_entity_edge_return_query(driver.provider),
node_uuid=node_uuid,
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -479,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
)
- def get_entity_edge_from_record(record: Any) -> EntityEdge:
+ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
+ episodes = record['episodes']
+ if provider == GraphProvider.KUZU:
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
+ else:
+ attributes = record['attributes']
+ attributes.pop('uuid', None)
+ attributes.pop('source_node_uuid', None)
+ attributes.pop('target_node_uuid', None)
+ attributes.pop('fact', None)
+ attributes.pop('fact_embedding', None)
+ attributes.pop('name', None)
+ attributes.pop('group_id', None)
+ attributes.pop('episodes', None)
+ attributes.pop('created_at', None)
+ attributes.pop('expired_at', None)
+ attributes.pop('valid_at', None)
+ attributes.pop('invalid_at', None)
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
@@ -488,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
fact_embedding=record.get('fact_embedding'),
name=record['name'],
group_id=record['group_id'],
- episodes=record['episodes'],
+ episodes=episodes,
created_at=parse_db_date(record['created_at']), # type: ignore
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
- attributes=record['attributes'],
+ attributes=attributes,
)
- edge.attributes.pop('uuid', None)
- edge.attributes.pop('source_node_uuid', None)
- edge.attributes.pop('target_node_uuid', None)
- edge.attributes.pop('fact', None)
- edge.attributes.pop('name', None)
- edge.attributes.pop('group_id', None)
- edge.attributes.pop('episodes', None)
- edge.attributes.pop('created_at', None)
- edge.attributes.pop('expired_at', None)
- edge.attributes.pop('valid_at', None)
- edge.attributes.pop('invalid_at', None)
return edge
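Because Kuzu stores the free-form attributes as a single JSON string column rather than as native properties, the save/load pair above round-trips them roughly as follows (illustrative sketch only, not part of the diff):

```python
import json

# On save (EntityEdge.save for Kuzu): attributes are serialized into one STRING column.
attributes = {'strength': 0.9, 'source': 'conversation'}
stored = json.dumps(attributes)

# On load (get_entity_edge_from_record for Kuzu): the column is parsed back and the
# reserved keys (uuid, fact, group_id, ...) are stripped before building the EntityEdge.
loaded = json.loads(stored) if stored else {}
for reserved in ('uuid', 'fact', 'fact_embedding', 'group_id', 'episodes'):
    loaded.pop(reserved, None)
assert loaded == attributes
```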


@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
'episode_content': 'Episodic',
'edge_name_and_fact': 'RELATES_TO',
}
+ # Mapping from fulltext index names to Kuzu node labels
+ INDEX_TO_LABEL_KUZU_MAPPING = {
+ 'node_name_and_summary': 'Entity',
+ 'community_name': 'Community',
+ 'episode_content': 'Episodic',
+ 'edge_name_and_fact': 'RelatesToNode_',
+ }
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
]
+ if provider == GraphProvider.KUZU:
+ return []
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
]
+ if provider == GraphProvider.KUZU:
+ return [
+ "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
+ "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
+ "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
+ "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
+ ]
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
]
- def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
+ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
+ if provider == GraphProvider.KUZU:
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
+ if provider == GraphProvider.KUZU:
+ return f'array_cosine_similarity({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})'
- def get_relationships_query(name: str, provider: GraphProvider) -> str:
+ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
+ if provider == GraphProvider.KUZU:
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
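For orientation, a small sketch (illustrative, not part of the diff) of the full text search calls the node-query helper above produces for the same index, based on the branches shown; the FalkorDB label is assumed to map `node_name_and_summary` to `Entity` in the mapping whose tail appears at the top of this hunk.

```python
# What get_nodes_query('node_name_and_summary', '$query', 10, provider) expands to,
# per the branches above (strings shown for illustration only).

# Neo4j: native fulltext index procedure, limit passed in an options map.
NEO4J_CALL = 'CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})'

# FalkorDB: index addressed by the mapped node label (assumed to be Entity).
FALKORDB_CALL = "CALL db.idx.fulltext.queryNodes('Entity', $query)"

# Kuzu: FTS extension, addressed by node table plus index name, limit via TOP.
KUZU_CALL = "CALL QUERY_FTS_INDEX('Entity', 'node_name_and_summary', $query, TOP := $limit)"
```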


@@ -1070,7 +1070,7 @@ class Graphiti:
if record['episode_count'] == 1:
nodes_to_delete.append(node)
+ await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
- await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
await episode.delete(self.driver)


@@ -43,14 +43,14 @@ RUNTIME_QUERY: LiteralString = (
)
- def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
- return (
- neo_date.to_native()
- if isinstance(neo_date, neo4j_time.DateTime)
- else datetime.fromisoformat(neo_date)
- if neo_date
- else None
- )
+ def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+ if isinstance(input_date, neo4j_time.DateTime):
+ return input_date.to_native()
+ if isinstance(input_date, str):
+ return datetime.fromisoformat(input_date)
+ return input_date
def get_default_group_id(provider: GraphProvider) -> str:
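The reworked `parse_db_date` now simply dispatches on the input type. A behavior sketch (illustrative assertions only, not part of the diff; the pass-through branch is assumed to cover drivers that already return native `datetime` values):

```python
from datetime import datetime

from neo4j import time as neo4j_time

from graphiti_core.helpers import parse_db_date

# Neo4j driver values are converted via .to_native().
assert parse_db_date(neo4j_time.DateTime(2024, 1, 1, 0, 0, 0)) == datetime(2024, 1, 1)
# ISO strings (as returned by FalkorDB) are parsed.
assert parse_db_date('2024-01-01T00:00:00') == datetime(2024, 1, 1)
# Anything else, including None, passes through unchanged.
assert parse_db_date(None) is None
```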


@@ -20,18 +20,36 @@ EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
- SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
RETURN e.uuid AS uuid
"""
- EPISODIC_EDGE_SAVE_BULK = """
- UNWIND $episodic_edges AS edge
- MATCH (episode:Episodic {uuid: edge.source_node_uuid})
- MATCH (node:Entity {uuid: edge.target_node_uuid})
- MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
- SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
- RETURN e.uuid AS uuid
- """
+ def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
+ if provider == GraphProvider.KUZU:
+ return """
+ MATCH (episode:Episodic {uuid: $source_node_uuid})
+ MATCH (node:Entity {uuid: $target_node_uuid})
+ MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
+ RETURN e.uuid AS uuid
+ """
+ return """
+ UNWIND $episodic_edges AS edge
+ MATCH (episode:Episodic {uuid: edge.source_node_uuid})
+ MATCH (node:Entity {uuid: edge.target_node_uuid})
+ MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
+ SET
+ e.group_id = edge.group_id,
+ e.created_at = edge.created_at
+ RETURN e.uuid AS uuid
+ """
EPISODIC_EDGE_RETURN = """
e.uuid AS uuid,
@@ -54,14 +72,32 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
"""
case GraphProvider.NEPTUNE:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
SET e.episodes = join($edge_data.episodes, ",")
RETURN $edge_data.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MATCH (source:Entity {uuid: $source_uuid})
+ MATCH (target:Entity {uuid: $target_uuid})
+ MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at,
+ e.name = $name,
+ e.fact = $fact,
+ e.fact_embedding = $fact_embedding,
+ e.episodes = $episodes,
+ e.expired_at = $expired_at,
+ e.valid_at = $valid_at,
+ e.invalid_at = $invalid_at,
+ e.attributes = $attributes
+ RETURN e.uuid AS uuid
+ """
case _: # Neo4j
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
@@ -89,14 +125,32 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
SET r.episodes = join(edge.episodes, ",")
RETURN edge.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MATCH (source:Entity {uuid: $source_node_uuid})
+ MATCH (target:Entity {uuid: $target_node_uuid})
+ MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at,
+ e.name = $name,
+ e.fact = $fact,
+ e.fact_embedding = $fact_embedding,
+ e.episodes = $episodes,
+ e.expired_at = $expired_at,
+ e.valid_at = $valid_at,
+ e.invalid_at = $invalid_at,
+ e.attributes = $attributes
+ RETURN e.uuid AS uuid
+ """
case _:
return """
UNWIND $entity_edges AS edge
@@ -109,35 +163,42 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
"""
- ENTITY_EDGE_RETURN = """
- e.uuid AS uuid,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- e.group_id AS group_id,
- e.name AS name,
- e.fact AS fact,
- e.episodes AS episodes,
- e.created_at AS created_at,
- e.expired_at AS expired_at,
- e.valid_at AS valid_at,
- e.invalid_at AS invalid_at,
- properties(e) AS attributes
- """
- ENTITY_EDGE_RETURN_NEPTUNE = """
- e.uuid AS uuid,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- e.group_id AS group_id,
- e.name AS name,
- e.fact AS fact,
- split(e.episodes, ',') AS episodes,
- e.created_at AS created_at,
- e.expired_at AS expired_at,
- e.valid_at AS valid_at,
- e.invalid_at AS invalid_at,
- properties(e) AS attributes
- """
+ def get_entity_edge_return_query(provider: GraphProvider) -> str:
+ # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
+ if provider == GraphProvider.NEPTUNE:
+ return """
+ e.uuid AS uuid,
+ n.uuid AS source_node_uuid,
+ m.uuid AS target_node_uuid,
+ e.group_id AS group_id,
+ e.name AS name,
+ e.fact AS fact,
+ split(e.episodes, ',') AS episodes,
+ e.created_at AS created_at,
+ e.expired_at AS expired_at,
+ e.valid_at AS valid_at,
+ e.invalid_at AS invalid_at,
+ properties(e) AS attributes
+ """
+ return """
+ e.uuid AS uuid,
+ n.uuid AS source_node_uuid,
+ m.uuid AS target_node_uuid,
+ e.group_id AS group_id,
+ e.created_at AS created_at,
+ e.name AS name,
+ e.fact AS fact,
+ e.episodes AS episodes,
+ e.expired_at AS expired_at,
+ e.valid_at AS valid_at,
+ e.invalid_at AS invalid_at,
+ """ + (
+ 'e.attributes AS attributes'
+ if provider == GraphProvider.KUZU
+ else 'properties(e) AS attributes'
+ )
def get_community_edge_save_query(provider: GraphProvider) -> str:
@@ -152,7 +213,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
"""
case GraphProvider.NEPTUNE:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
WHERE node:Entity OR node:Community
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
@@ -161,6 +222,24 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
SET r.created_at= $created_at
RETURN r.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MATCH (community:Community {uuid: $community_uuid})
+ MATCH (node:Entity {uuid: $entity_uuid})
+ MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
+ RETURN e.uuid AS uuid
+ UNION
+ MATCH (community:Community {uuid: $community_uuid})
+ MATCH (node:Community {uuid: $entity_uuid})
+ MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
+ RETURN e.uuid AS uuid
+ """
case _: # Neo4j
return """
MATCH (community:Community {uuid: $community_uuid})


@@ -24,10 +24,24 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Episodic {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.created_at = $created_at,
+ n.source = $source,
+ n.source_description = $source_description,
+ n.content = $content,
+ n.valid_at = $valid_at,
+ n.entity_edges = $entity_edges
+ RETURN n.uuid AS uuid
+ """
case GraphProvider.FALKORDB:
return """
MERGE (n:Episodic {uuid: $uuid})
@@ -51,11 +65,25 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Episodic {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.created_at = $created_at,
+ n.source = $source,
+ n.source_description = $source_description,
+ n.content = $content,
+ n.valid_at = $valid_at,
+ n.entity_edges = $entity_edges
+ RETURN n.uuid AS uuid
+ """
case GraphProvider.FALKORDB:
return """
UNWIND $episodes AS episode
@@ -76,14 +104,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
EPISODIC_NODE_RETURN = """
- e.content AS content,
- e.created_at AS created_at,
- e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
- e.source_description AS source_description,
+ e.created_at AS created_at,
e.source AS source,
+ e.source_description AS source_description,
+ e.content AS content,
+ e.valid_at AS valid_at,
e.entity_edges AS entity_edges
"""
@@ -109,6 +137,20 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
SET n = $entity_data
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Entity {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.labels = $labels,
+ n.created_at = $created_at,
+ n.name_embedding = $name_embedding,
+ n.summary = $summary,
+ n.attributes = $attributes
+ WITH n
+ RETURN n.uuid AS uuid
+ """
case GraphProvider.NEPTUNE:
label_subquery = ''
for label in labels.split(':'):
@@ -168,6 +210,19 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
"""
)
return queries
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Entity {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.labels = $labels,
+ n.created_at = $created_at,
+ n.name_embedding = $name_embedding,
+ n.summary = $summary,
+ n.attributes = $attributes
+ RETURN n.uuid AS uuid
+ """
case _: # Neo4j
return """
UNWIND $nodes AS node
@@ -179,15 +234,28 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
"""
- ENTITY_NODE_RETURN = """
- n.uuid AS uuid,
- n.name AS name,
- n.group_id AS group_id,
- n.created_at AS created_at,
- n.summary AS summary,
- labels(n) AS labels,
- properties(n) AS attributes
- """
+ def get_entity_node_return_query(provider: GraphProvider) -> str:
+ # `name_embedding` is not returned by default and must be loaded manually using `load_name_embedding()`.
+ if provider == GraphProvider.KUZU:
+ return """
+ n.uuid AS uuid,
+ n.name AS name,
+ n.group_id AS group_id,
+ n.labels AS labels,
+ n.created_at AS created_at,
+ n.summary AS summary,
+ n.attributes AS attributes
+ """
+ return """
+ n.uuid AS uuid,
+ n.name AS name,
+ n.group_id AS group_id,
+ n.created_at AS created_at,
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
+ """
def get_community_node_save_query(provider: GraphProvider) -> str:
@@ -201,10 +269,21 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Community {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.created_at = $created_at,
+ n.name_embedding = $name_embedding,
+ n.summary = $summary
+ RETURN n.uuid AS uuid
+ """
case _: # Neo4j
return """
MERGE (n:Community {uuid: $uuid})
@@ -215,12 +294,12 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
COMMUNITY_NODE_RETURN = """
- n.uuid AS uuid,
- n.name AS name,
- n.name_embedding AS name_embedding,
- n.group_id AS group_id,
- n.summary AS summary,
- n.created_at AS created_at
+ c.uuid AS uuid,
+ c.name AS name,
+ c.group_id AS group_id,
+ c.created_at AS created_at,
+ c.name_embedding AS name_embedding,
+ c.summary AS summary
"""
COMMUNITY_NODE_RETURN_NEPTUNE = """


@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+ import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
@@ -32,10 +33,10 @@ from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
COMMUNITY_NODE_RETURN_NEPTUNE,
- ENTITY_NODE_RETURN,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
get_community_node_save_query,
+ get_entity_node_return_query,
get_entity_node_save_query,
get_episode_node_save_query,
)
@@ -95,12 +96,37 @@ class Node(BaseModel, ABC):
case GraphProvider.NEO4J:
await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
- case _: # FalkorDB and Neptune
+ case GraphProvider.KUZU:
+ for label in ['Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label} {{uuid: $uuid}})
+ DETACH DELETE n
+ """,
+ uuid=self.uuid,
+ )
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
+ # Explicitly delete the "edge" nodes first, then the entity node.
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
+ DETACH DELETE e
+ """,
+ uuid=self.uuid,
+ )
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {uuid: $uuid})
+ DETACH DELETE n
+ """,
+ uuid=self.uuid,
+ )
+ case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
@@ -136,8 +162,32 @@ class Node(BaseModel, ABC):
group_id=group_id,
batch_size=batch_size,
)
- case _: # FalkorDB and Neptune
+ case GraphProvider.KUZU:
+ for label in ['Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label} {{group_id: $group_id}})
+ DETACH DELETE n
+ """,
+ group_id=group_id,
+ )
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
+ # Explicitly delete the "edge" nodes first, then the entity node.
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
+ DETACH DELETE e
+ """,
+ group_id=group_id,
+ )
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {group_id: $group_id})
+ DETACH DELETE n
+ """,
+ group_id=group_id,
+ )
+ case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
@@ -149,30 +199,59 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
- if driver.provider == GraphProvider.FALKORDB:
- for label in ['Entity', 'Episodic', 'Community']:
- await driver.execute_query(
- f"""
- MATCH (n:{label})
- WHERE n.uuid IN $uuids
- DETACH DELETE n
- """,
- uuids=uuids,
- )
- else:
- async with driver.session() as session:
- await session.run(
- """
- MATCH (n:Entity|Episodic|Community)
- WHERE n.uuid IN $uuids
- CALL {
- WITH n
- DETACH DELETE n
- } IN TRANSACTIONS OF $batch_size ROWS
- """,
- uuids=uuids,
- batch_size=batch_size,
- )
+ match driver.provider:
+ case GraphProvider.FALKORDB:
+ for label in ['Entity', 'Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label})
+ WHERE n.uuid IN $uuids
+ DETACH DELETE n
+ """,
+ uuids=uuids,
+ )
+ case GraphProvider.KUZU:
+ for label in ['Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label})
+ WHERE n.uuid IN $uuids
+ DETACH DELETE n
+ """,
+ uuids=uuids,
+ )
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
+ # Explicitly delete the "edge" nodes first, then the entity node.
+ await driver.execute_query(
+ """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
+ WHERE n.uuid IN $uuids
+ DETACH DELETE e
+ """,
+ uuids=uuids,
+ )
+ await driver.execute_query(
+ """
+ MATCH (n:Entity)
+ WHERE n.uuid IN $uuids
+ DETACH DELETE n
+ """,
+ uuids=uuids,
+ )
+ case _: # Neo4J, Neptune
+ async with driver.session() as session:
+ await session.run(
+ """
+ MATCH (n:Entity|Episodic|Community)
+ WHERE n.uuid IN $uuids
+ CALL {
+ WITH n
+ DETACH DELETE n
+ } IN TRANSACTIONS OF $batch_size ROWS
+ """,
+ uuids=uuids,
+ batch_size=batch_size,
+ )
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -376,17 +455,25 @@ class EntityNode(Node):
'summary': self.summary,
'created_at': self.created_at,
}
- entity_data.update(self.attributes or {})
- labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
- if driver.provider == GraphProvider.NEPTUNE:
- driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
+ if driver.provider == GraphProvider.KUZU:
+ entity_data['attributes'] = json.dumps(self.attributes)
+ entity_data['labels'] = list(set(self.labels + ['Entity']))
+ result = await driver.execute_query(
+ get_entity_node_save_query(driver.provider, labels=''),
+ **entity_data,
+ )
+ else:
+ entity_data.update(self.attributes or {})
+ labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
+ if driver.provider == GraphProvider.NEPTUNE:
+ driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
entity_data=entity_data,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
@@ -399,12 +486,12 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid})
RETURN
"""
- + ENTITY_NODE_RETURN,
+ + get_entity_node_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
@@ -419,12 +506,12 @@ class EntityNode(Node):
WHERE n.uuid IN $uuids
RETURN
"""
- + ENTITY_NODE_RETURN,
+ + get_entity_node_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -456,7 +543,7 @@ class EntityNode(Node):
+ """
RETURN
"""
- + ENTITY_NODE_RETURN
+ + get_entity_node_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY n.uuid DESC
@@ -468,7 +555,7 @@ class EntityNode(Node):
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -533,7 +620,7 @@ class CommunityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community {uuid: $uuid})
+ MATCH (c:Community {uuid: $uuid})
RETURN
"""
+ (
@@ -556,8 +643,8 @@ class CommunityNode(Node):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community)
- WHERE n.uuid IN $uuids
+ MATCH (c:Community)
+ WHERE c.uuid IN $uuids
RETURN
"""
+ (
@@ -581,13 +668,13 @@ class CommunityNode(Node):
limit: int | None = None,
uuid_cursor: str | None = None,
):
- cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
+ cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community)
- WHERE n.group_id IN $group_ids
+ MATCH (c:Community)
+ WHERE c.group_id IN $group_ids
"""
+ cursor_query
+ """
@@ -599,7 +686,7 @@ class CommunityNode(Node):
else COMMUNITY_NODE_RETURN
)
+ """
- ORDER BY n.uuid DESC
+ ORDER BY c.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@@ -636,7 +723,19 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
)
- def get_entity_node_from_record(record: Any) -> EntityNode:
+ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
+ if provider == GraphProvider.KUZU:
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
+ else:
+ attributes = record['attributes']
+ attributes.pop('uuid', None)
+ attributes.pop('name', None)
+ attributes.pop('group_id', None)
+ attributes.pop('name_embedding', None)
+ attributes.pop('summary', None)
+ attributes.pop('created_at', None)
+ attributes.pop('labels', None)
entity_node = EntityNode(
uuid=record['uuid'],
name=record['name'],
@@ -645,16 +744,9 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
labels=record['labels'],
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
- attributes=record['attributes'],
+ attributes=attributes,
)
- entity_node.attributes.pop('uuid', None)
- entity_node.attributes.pop('name', None)
- entity_node.attributes.pop('group_id', None)
- entity_node.attributes.pop('name_embedding', None)
- entity_node.attributes.pop('summary', None)
- entity_node.attributes.pop('created_at', None)
return entity_node


@@ -20,6 +20,8 @@ from typing import Any

 from pydantic import BaseModel, Field

+from graphiti_core.driver.driver import GraphProvider
+

 class ComparisonOperator(Enum):
     equals = '='
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):

 def node_search_filter_query_constructor(
     filters: SearchFilters,
-) -> tuple[str, dict[str, Any]]:
-    filter_query: str = ''
+    provider: GraphProvider,
+) -> tuple[list[str], dict[str, Any]]:
+    filter_queries: list[str] = []
     filter_params: dict[str, Any] = {}

     if filters.node_labels is not None:
-        node_labels = '|'.join(filters.node_labels)
-        node_label_filter = ' AND n:' + node_labels
-        filter_query += node_label_filter
+        if provider == GraphProvider.KUZU:
+            node_label_filter = 'list_has_all(n.labels, $labels)'
+            filter_params['labels'] = filters.node_labels
+        else:
+            node_labels = '|'.join(filters.node_labels)
+            node_label_filter = 'n:' + node_labels
+        filter_queries.append(node_label_filter)

-    return filter_query, filter_params
+    return filter_queries, filter_params
 def date_filter_query_constructor(
@@ -81,23 +88,29 @@ def date_filter_query_constructor(

 def edge_search_filter_query_constructor(
     filters: SearchFilters,
-) -> tuple[str, dict[str, Any]]:
-    filter_query: str = ''
+    provider: GraphProvider,
+) -> tuple[list[str], dict[str, Any]]:
+    filter_queries: list[str] = []
     filter_params: dict[str, Any] = {}

     if filters.edge_types is not None:
         edge_types = filters.edge_types
-        edge_types_filter = '\nAND e.name in $edge_types'
-        filter_query += edge_types_filter
+        filter_queries.append('e.name in $edge_types')
         filter_params['edge_types'] = edge_types

     if filters.node_labels is not None:
-        node_labels = '|'.join(filters.node_labels)
-        node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
-        filter_query += node_label_filter
+        if provider == GraphProvider.KUZU:
+            node_label_filter = (
+                'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
+            )
+            filter_params['labels'] = filters.node_labels
+        else:
+            node_labels = '|'.join(filters.node_labels)
+            node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
+        filter_queries.append(node_label_filter)

     if filters.valid_at is not None:
-        valid_at_filter = '\nAND ('
+        valid_at_filter = '('
         for i, or_list in enumerate(filters.valid_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
             else:
                 valid_at_filter += ' OR '

-        filter_query += valid_at_filter
+        filter_queries.append(valid_at_filter)

     if filters.invalid_at is not None:
-        invalid_at_filter = ' AND ('
+        invalid_at_filter = '('
         for i, or_list in enumerate(filters.invalid_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
             else:
                 invalid_at_filter += ' OR '

-        filter_query += invalid_at_filter
+        filter_queries.append(invalid_at_filter)

     if filters.created_at is not None:
-        created_at_filter = ' AND ('
+        created_at_filter = '('
         for i, or_list in enumerate(filters.created_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
             else:
                 created_at_filter += ' OR '

-        filter_query += created_at_filter
+        filter_queries.append(created_at_filter)

     if filters.expired_at is not None:
-        expired_at_filter = ' AND ('
+        expired_at_filter = '('
         for i, or_list in enumerate(filters.expired_at):
             for j, date_filter in enumerate(or_list):
                 if date_filter.comparison_operator not in [
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
             else:
                 expired_at_filter += ' OR '

-        filter_query += expired_at_filter
+        filter_queries.append(expired_at_filter)

-    return filter_query, filter_params
+    return filter_queries, filter_params
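Returning a list of fragments instead of a pre-concatenated string leaves the joining to the caller, which can pick the connective and placement its query needs. A hedged sketch of how such fragments might be stitched together (the `build_where` helper is illustrative, not library API):

```python
from typing import Any


def build_where(filter_queries: list[str], base_condition: str | None = None) -> str:
    """Join pre-built filter fragments into a single WHERE clause."""
    conditions = ([base_condition] if base_condition else []) + filter_queries
    return ('WHERE ' + ' AND '.join(conditions)) if conditions else ''


# e.g. a Kuzu label-filter fragment plus its bound parameter
filter_queries = ['list_has_all(n.labels, $labels)']
filter_params: dict[str, Any] = {'labels': ['Entity', 'Person']}
print(build_where(filter_queries, 'n.group_id IN $group_ids'))
# WHERE n.group_id IN $group_ids AND list_has_all(n.labels, $labels)
```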

File diff suppressed because it is too large

View file

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import json
import logging import logging
import typing import typing
from datetime import datetime from datetime import datetime
@ -22,20 +23,21 @@ import numpy as np
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Any from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import ( from graphiti_core.models.edges.edge_db_queries import (
EPISODIC_EDGE_SAVE_BULK,
get_entity_edge_save_bulk_query, get_entity_edge_save_bulk_query,
get_episodic_edge_save_bulk_query,
) )
from graphiti_core.models.nodes.node_db_queries import ( from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query, get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query, get_episode_node_save_bulk_query,
) )
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.edge_operations import ( from graphiti_core.utils.maintenance.edge_operations import (
extract_edges, extract_edges,
resolve_extracted_edge, resolve_extracted_edge,
@ -116,11 +118,15 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes] episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes: for episode in episodes:
episode['source'] = str(episode['source'].value) episode['source'] = str(episode['source'].value)
episode.pop('labels', None)
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '') episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes: list[dict[str, Any]] = []
nodes = []
     for node in entity_nodes:
         if node.name_embedding is None:
             await node.generate_name_embedding(embedder)

         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
@@ -130,13 +136,19 @@ async def add_nodes_and_edges_bulk_tx(
             'created_at': node.created_at,
         }

-        entity_data.update(node.attributes or {})
-        entity_data['labels'] = list(
-            set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
-        )
+        entity_data['labels'] = list(set(node.labels + ['Entity']))
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
+            entity_data['attributes'] = json.dumps(attributes)
+        else:
+            entity_data.update(node.attributes or {})
+            entity_data['labels'] = list(
+                set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
+            )
         nodes.append(entity_data)

-    edges: list[dict[str, Any]] = []
+    edges = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
@@ -155,17 +167,36 @@ async def add_nodes_and_edges_bulk_tx(
             'invalid_at': edge.invalid_at,
         }

-        edge_data.update(edge.attributes or {})
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
+            edge_data['attributes'] = json.dumps(attributes)
+        else:
+            edge_data.update(edge.attributes or {})
         edges.append(edge_data)

-    await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
-    entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
-    await tx.run(entity_node_save_bulk, nodes=nodes)
-    await tx.run(
-        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
-    )
-    entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
-    await tx.run(entity_edge_save_bulk, entity_edges=edges)
+    if driver.provider == GraphProvider.KUZU:
+        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
+        episode_query = get_episode_node_save_bulk_query(driver.provider)
+        for episode in episodes:
+            await tx.run(episode_query, **episode)
+        entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
+        for node in nodes:
+            await tx.run(entity_node_query, **node)
+        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
+        for edge in edges:
+            await tx.run(entity_edge_query, **edge)
+        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
+        for edge in episodic_edges:
+            await tx.run(episodic_edge_query, **edge.model_dump())
+    else:
+        await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
+        await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
+        await tx.run(
+            get_episodic_edge_save_bulk_query(driver.provider),
+            episodic_edges=[edge.model_dump() for edge in episodic_edges],
+        )
+        await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
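The Kuzu branch above exists because Kuzu's UNWIND cannot yet consume a list of STRUCTs as a single parameter, so the bulk save degrades to one `tx.run` per row. The general fallback pattern, sketched with an illustrative `run_bulk` helper that is not part of the library:

```python
from typing import Any, Awaitable, Callable


async def run_bulk(
    run: Callable[..., Awaitable[Any]],
    query: str,
    rows: list[dict[str, Any]],
    supports_unwind_structs: bool,
) -> None:
    """Prefer one UNWIND round trip; fall back to per-row execution when the backend can't."""
    if supports_unwind_structs:
        await run(query, rows=rows)   # rows bound as a single list parameter
    else:
        for row in rows:              # e.g. Kuzu, until UNWIND over STRUCT[] works
            await run(query, **row)
```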
async def extract_nodes_and_edges_bulk( async def extract_nodes_and_edges_bulk(

View file

@ -40,3 +40,16 @@ def ensure_utc(dt: datetime | None) -> datetime | None:
return dt.astimezone(timezone.utc) return dt.astimezone(timezone.utc)
return dt return dt
def convert_datetimes_to_strings(obj):
if isinstance(obj, dict):
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_datetimes_to_strings(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(convert_datetimes_to_strings(item) for item in obj)
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj
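This helper is what makes `json.dumps` safe for attribute payloads that still contain `datetime` values, for example when serializing attributes for Kuzu:

```python
import json
from datetime import datetime, timezone

from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings

attributes = {'joined': datetime(2024, 1, 1, tzinfo=timezone.utc), 'tags': ['a', 'b']}
print(json.dumps(convert_datetimes_to_strings(attributes)))
# {"joined": "2024-01-01T00:00:00+00:00", "tags": ["a", "b"]}
```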

View file

@ -4,11 +4,12 @@ from collections import defaultdict
from pydantic import BaseModel from pydantic import BaseModel
from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import CommunityEdge from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import semaphore_gather from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
@ -33,11 +34,11 @@ async def get_community_clusters(
if group_ids is None: if group_ids is None:
group_id_values, _, _ = await driver.execute_query( group_id_values, _, _ = await driver.execute_query(
""" """
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.group_id IS NOT NULL WHERE n.group_id IS NOT NULL
RETURN RETURN
collect(DISTINCT n.group_id) AS group_ids collect(DISTINCT n.group_id) AS group_ids
""", """
) )
group_ids = group_id_values[0]['group_ids'] if group_id_values else [] group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@@ -46,14 +47,21 @@ async def get_community_clusters(
         projection: dict[str, list[Neighbor]] = {}
         nodes = await EntityNode.get_by_group_ids(driver, [group_id])
         for node in nodes:
-            records, _, _ = await driver.execute_query(
-                """
-                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
-                WITH count(r) AS count, m.uuid AS uuid
-                RETURN
-                    uuid,
-                    count
-                """,
+            match_query = """
+                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
+            """
+            if driver.provider == GraphProvider.KUZU:
+                match_query = """
+                    MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
+                """
+            records, _, _ = await driver.execute_query(
+                match_query
+                + """
+                WITH count(e) AS count, m.uuid AS uuid
+                RETURN
+                    uuid,
+                    count
+                """,
                 uuid=node.uuid,
                 group_id=group_id,
             )
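The second MATCH shape exists because the Kuzu driver reifies each RELATES_TO fact as an intermediate `RelatesToNode_` node, so one logical edge becomes a two-hop path. A small illustrative helper (not library API) that picks the right pattern per provider:

```python
from graphiti_core.driver.driver import GraphProvider


def relates_to_pattern(provider: GraphProvider) -> str:
    """Return the MATCH pattern for one logical n -[RELATES_TO]-> m fact."""
    if provider == GraphProvider.KUZU:
        # the edge is reified as a RelatesToNode_ node, so traverse two hops
        return '(n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)'
    return '(n:Entity)-[e:RELATES_TO]->(m:Entity)'


query = f'MATCH {relates_to_pattern(GraphProvider.KUZU)}\nRETURN n.uuid, e.uuid, m.uuid'
```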
@ -235,9 +243,9 @@ async def build_communities(
async def remove_communities(driver: GraphDriver): async def remove_communities(driver: GraphDriver):
await driver.execute_query( await driver.execute_query(
""" """
MATCH (c:Community) MATCH (c:Community)
DETACH DELETE c DETACH DELETE c
""", """
) )
@@ -247,14 +255,10 @@ async def determine_entity_community(
     # Check if the node is already part of a community
     records, _, _ = await driver.execute_query(
         """
         MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
         RETURN
-            c.uuid AS uuid,
-            c.name AS name,
-            c.group_id AS group_id,
-            c.created_at AS created_at,
-            c.summary AS summary
-        """,
+        """
+        + COMMUNITY_NODE_RETURN,
         entity_uuid=entity.uuid,
     )
@@ -262,16 +266,19 @@ async def determine_entity_community(
         return get_community_node_from_record(records[0]), False

     # If the node has no community, add it to the mode community of surrounding entities
-    records, _, _ = await driver.execute_query(
-        """
-        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
-        RETURN
-            c.uuid AS uuid,
-            c.name AS name,
-            c.group_id AS group_id,
-            c.created_at AS created_at,
-            c.summary AS summary
-        """,
+    match_query = """
+        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
+    """
+    if driver.provider == GraphProvider.KUZU:
+        match_query = """
+            MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
+        """
+    records, _, _ = await driver.execute_query(
+        match_query
+        + """
+        RETURN
+        """
+        + COMMUNITY_NODE_RETURN,
         entity_uuid=entity.uuid,
     )

View file

@@ -531,17 +531,28 @@ async def filter_existing_duplicate_of_edges(
             routing_='r',
         )
     else:
-        query: LiteralString = """
-            UNWIND $duplicate_node_uuids AS duplicate_tuple
-            MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
-            RETURN DISTINCT
-                n.uuid AS source_uuid,
-                m.uuid AS target_uuid
-        """
+        if driver.provider == GraphProvider.KUZU:
+            query = """
+                UNWIND $duplicate_node_uuids AS duplicate
+                MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
+                RETURN DISTINCT
+                    n.uuid AS source_uuid,
+                    m.uuid AS target_uuid
+            """
+            duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
+        else:
+            query: LiteralString = """
+                UNWIND $duplicate_node_uuids AS duplicate_tuple
+                MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+                RETURN DISTINCT
+                    n.uuid AS source_uuid,
+                    m.uuid AS target_uuid
+            """
+            duplicate_node_uuids = list(duplicate_nodes_map.keys())

         records, _, _ = await driver.execute_query(
             query,
-            duplicate_node_uuids=list(duplicate_nodes_map.keys()),
+            duplicate_node_uuids=duplicate_node_uuids,
             routing_='r',
         )
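The two branches also bind the UNWIND parameter differently: the Kuzu query addresses `duplicate.src` / `duplicate.dst` structs, while the Cypher query indexes into plain tuples. Assuming `duplicate_nodes_map` is keyed by `(source_uuid, target_uuid)` pairs, as the code above suggests, the two payload shapes look like this:

```python
duplicate_nodes_map = {
    ('uuid-a', 'uuid-b'): 'resolution-1',   # value type is an assumption here
    ('uuid-c', 'uuid-d'): 'resolution-2',
}

# Kuzu: structured parameters, read as duplicate.src / duplicate.dst
kuzu_params = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]

# Neo4j / FalkorDB: positional tuples, read as duplicate_tuple[0] / duplicate_tuple[1]
cypher_params = list(duplicate_nodes_map.keys())

print(kuzu_params)    # [{'src': 'uuid-a', 'dst': 'uuid-b'}, {'src': 'uuid-c', 'dst': 'uuid-d'}]
print(cypher_params)  # [('uuid-a', 'uuid-b'), ('uuid-c', 'uuid-d')]
```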

View file

@ -53,10 +53,29 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
for name in index_names for name in index_names
] ]
) )
range_indices: list[LiteralString] = get_range_indices(driver.provider) range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
if driver.provider == GraphProvider.KUZU:
# Skip creating fulltext indices if they already exist. Need to do this manually
# until Kuzu supports `IF NOT EXISTS` for indices.
result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
if len(result) > 0:
fulltext_indices = []
# Only load the `fts` extension if it's not already loaded; loading it a second time raises an error.
result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
if len(result) == 0:
fulltext_indices.insert(
0,
"""
INSTALL fts;
LOAD fts;
""",
)
index_queries: list[LiteralString] = range_indices + fulltext_indices index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather( await semaphore_gather(
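Because Kuzu has neither `CREATE INDEX ... IF NOT EXISTS` nor an idempotent `LOAD fts`, the block above probes the catalog first. The same bootstrap logic in isolation, as a hedged sketch (`execute` stands in for `driver.execute_query`):

```python
async def ensure_fulltext_indices(execute, fulltext_index_queries: list[str]) -> None:
    """Create Kuzu full-text indices only once per database."""
    existing, _, _ = await execute('CALL SHOW_INDEXES() RETURN *;')
    if len(existing) > 0:
        return  # indices already exist; recreating them would raise

    loaded, _, _ = await execute('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
    if len(loaded) == 0:
        await execute('INSTALL fts; LOAD fts;')  # load the extension exactly once

    for query in fulltext_index_queries:
        await execute(query)
```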
@ -76,10 +95,19 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
             await tx.run('MATCH (n) DETACH DELETE n')

     async def delete_group_ids(tx):
-        await tx.run(
-            'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
-            group_ids=group_ids,
-        )
+        labels = ['Entity', 'Episodic', 'Community']
+        if driver.provider == GraphProvider.KUZU:
+            labels.append('RelatesToNode_')
+
+        for label in labels:
+            await tx.run(
+                f"""
+                MATCH (n:{label})
+                WHERE n.group_id IN $group_ids
+                DETACH DELETE n
+                """,
+                group_ids=group_ids,
+            )

     if group_ids is None:
         await session.execute_write(delete_all)
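In the test helpers this per-group cleanup is what keeps runs isolated from each other. A sketch of the call, assuming the in-memory Kuzu setup the fixtures use:

```python
import asyncio

from graphiti_core.driver.kuzu_driver import KuzuDriver
from graphiti_core.utils.maintenance.graph_data_operations import clear_data


async def reset_test_groups() -> None:
    driver = KuzuDriver(db=':memory:')
    # Removes only these groups' data, across Entity, Episodic, Community
    # and (on Kuzu) the reified RelatesToNode_ label.
    await clear_data(driver, ['graphiti_test_group', 'graphiti_test_group_2'])
    await driver.close()


asyncio.run(reset_test_groups())
```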
@@ -108,18 +136,23 @@ async def retrieve_episodes(
     Returns:
         list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
     """
-    group_id_filter: LiteralString = (
-        '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
-    )
-    source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
+    query_params: dict = {}
+    query_filter = ''
+    if group_ids and len(group_ids) > 0:
+        query_filter += '\nAND e.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+    if source is not None:
+        query_filter += '\nAND e.source = $source'
+        query_params['source'] = source.name

     query: LiteralString = (
         """
         MATCH (e:Episodic)
         WHERE e.valid_at <= $reference_time
         """
-        + group_id_filter
-        + source_filter
+        + query_filter
         + """
         RETURN
         """
@@ -136,9 +169,8 @@ async def retrieve_episodes(
     result, _, _ = await driver.execute_query(
         query,
         reference_time=reference_time,
-        source=source.name if source is not None else None,
         num_episodes=last_n,
-        group_ids=group_ids,
+        **query_params,
     )
     episodes = [get_episodic_node_from_record(record) for record in result]
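The refactor folds the optional clauses and their bound values into one accumulating filter string plus a parameter dict, so a parameter is only ever sent when its clause is present. The pattern in isolation, with illustrative names:

```python
from typing import Any


def build_episode_filter(
    group_ids: list[str] | None, source_name: str | None
) -> tuple[str, dict[str, Any]]:
    """Accumulate optional WHERE clauses and bind only the parameters they use."""
    query_filter = ''
    query_params: dict[str, Any] = {}
    if group_ids:
        query_filter += '\nAND e.group_id IN $group_ids'
        query_params['group_ids'] = group_ids
    if source_name is not None:
        query_filter += '\nAND e.source = $source'
        query_params['source'] = source_name
    return query_filter, query_params


print(build_episode_filter(['graphiti_test_group'], 'text'))
```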

View file

@ -29,6 +29,7 @@ Repository = "https://github.com/getzep/graphiti"
anthropic = ["anthropic>=0.49.0"] anthropic = ["anthropic>=0.49.0"]
groq = ["groq>=0.2.0"] groq = ["groq>=0.2.0"]
google-genai = ["google-genai>=1.8.0"] google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"] falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"] voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"] sentence-transformers = ["sentence-transformers>=3.2.1"]
@ -39,6 +40,7 @@ dev = [
"anthropic>=0.49.0", "anthropic>=0.49.0",
"google-genai>=1.8.0", "google-genai>=1.8.0",
"falkordb>=1.1.2,<2.0.0", "falkordb>=1.1.2,<2.0.0",
"kuzu>=0.11.2",
"ipykernel>=6.29.5", "ipykernel>=6.29.5",
"jupyterlab>=4.2.4", "jupyterlab>=4.2.4",
"diskcache-stubs>=5.6.3.6.20240818", "diskcache-stubs>=5.6.3.6.20240818",
@@ -91,7 +93,3 @@ docstring-code-format = true
 include = ["graphiti_core"]
 pythonVersion = "3.10"
 typeCheckingMode = "basic"
-
-[[tool.pyright.overrides]]
-include = ["**/falkordb*"]
-reportMissingImports = false

View file

@ -1,4 +1,5 @@
[pytest] [pytest]
markers = markers =
integration: marks tests as integration tests integration: marks tests as integration tests
asyncio_default_fixture_loop_scope = function asyncio_default_fixture_loop_scope = function
asyncio_mode = auto

View file

@ -15,42 +15,55 @@ limitations under the License.
""" """
import os import os
from unittest.mock import Mock
import numpy as np
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.driver.neptune_driver import NeptuneDriver from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.helpers import lucene_sanitize from graphiti_core.helpers import lucene_sanitize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv() load_dotenv()
HAS_NEO4J = False drivers: list[GraphProvider] = []
HAS_FALKORDB = False
HAS_NEPTUNE = False
if os.getenv('DISABLE_NEO4J') is None: if os.getenv('DISABLE_NEO4J') is None:
try: try:
from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver
HAS_NEO4J = True drivers.append(GraphProvider.NEO4J)
except ImportError: except ImportError:
pass raise
if os.getenv('DISABLE_FALKORDB') is None: if os.getenv('DISABLE_FALKORDB') is None:
try: try:
from graphiti_core.driver.falkordb_driver import FalkorDriver from graphiti_core.driver.falkordb_driver import FalkorDriver
HAS_FALKORDB = True drivers.append(GraphProvider.FALKORDB)
except ImportError: except ImportError:
pass raise
if os.getenv('DISABLE_KUZU') is None:
try:
from graphiti_core.driver.kuzu_driver import KuzuDriver
drivers.append(GraphProvider.KUZU)
except ImportError:
raise
# Disable Neptune for now
os.environ['DISABLE_NEPTUNE'] = 'True'
if os.getenv('DISABLE_NEPTUNE') is None: if os.getenv('DISABLE_NEPTUNE') is None:
try: try:
from graphiti_core.driver.neptune_driver import NeptuneDriver from graphiti_core.driver.neptune_driver import NeptuneDriver
HAS_NEPTUNE = False drivers.append(GraphProvider.NEPTUNE)
except ImportError: except ImportError:
pass raise
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687') NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j') NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
@ -65,38 +78,100 @@ NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182) NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None) AOSS_HOST = os.getenv('AOSS_HOST', None)
KUZU_DB = os.getenv('KUZU_DB', ':memory:')
-def get_driver(driver_name: str) -> GraphDriver:
-    if driver_name == 'neo4j':
+group_id = 'graphiti_test_group'
+group_id_2 = 'graphiti_test_group_2'
+
+
+def get_driver(provider: GraphProvider) -> GraphDriver:
+    if provider == GraphProvider.NEO4J:
         return Neo4jDriver(
             uri=NEO4J_URI,
             user=NEO4J_USER,
             password=NEO4J_PASSWORD,
         )
-    elif driver_name == 'falkordb':
+    elif provider == GraphProvider.FALKORDB:
         return FalkorDriver(
             host=FALKORDB_HOST,
             port=int(FALKORDB_PORT),
             username=FALKORDB_USER,
             password=FALKORDB_PASSWORD,
         )
-    elif driver_name == 'neptune':
+    elif provider == GraphProvider.KUZU:
+        driver = KuzuDriver(
+            db=KUZU_DB,
+        )
+        return driver
+    elif provider == GraphProvider.NEPTUNE:
         return NeptuneDriver(
             host=NEPTUNE_HOST,
             port=int(NEPTUNE_PORT),
             aoss_host=AOSS_HOST,
         )
     else:
-        raise ValueError(f'Driver {driver_name} not available')
+        raise ValueError(f'Driver {provider} not available')


-drivers: list[str] = []
-if HAS_NEO4J:
-    drivers.append('neo4j')
-if HAS_FALKORDB:
-    drivers.append('falkordb')
-if HAS_NEPTUNE:
-    drivers.append('neptune')
+@pytest.fixture(params=drivers)
+async def graph_driver(request):
+    driver = request.param
+    graph_driver = get_driver(driver)
+    await clear_data(graph_driver, [group_id, group_id_2])
+    try:
+        yield graph_driver  # provide driver to the test
+    finally:
+        # always called, even if the test fails or raises
+        # await clean_up(graph_driver)
+        await graph_driver.close()
embedding_dim = 384
embeddings = {
key: np.random.uniform(0.0, 0.9, embedding_dim).tolist()
for key in [
'Alice',
'Bob',
'Alice likes Bob',
'test_entity_1',
'test_entity_2',
'test_entity_3',
'test_entity_4',
'test_entity_alice',
'test_entity_bob',
'test_entity_1 is a duplicate of test_entity_2',
'test_entity_3 is a duplicate of test_entity_4',
'test_entity_1 relates to test_entity_2',
'test_entity_1 relates to test_entity_3',
'test_entity_2 relates to test_entity_3',
'test_entity_1 relates to test_entity_4',
'test_entity_2 relates to test_entity_4',
'test_entity_3 relates to test_entity_4',
'test_entity_1 relates to test_entity_2',
'test_entity_3 relates to test_entity_4',
'test_entity_2 relates to test_entity_3',
'test_community_1',
'test_community_2',
]
}
embeddings['Alice Smith'] = embeddings['Alice']
@pytest.fixture
def mock_embedder():
mock_model = Mock(spec=EmbedderClient)
def mock_embed(input_data):
if isinstance(input_data, str):
return embeddings[input_data]
elif isinstance(input_data, list):
combined_input = ' '.join(input_data)
return embeddings[combined_input]
else:
raise ValueError(f'Unsupported input type: {type(input_data)}')
mock_model.create.side_effect = mock_embed
return mock_model
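Together with the parametrized `graph_driver` fixture above, this mocked embedder lets the node and edge tests run against every configured backend without calling a real embedding service. An illustrative test using both fixtures (not one from the suite):

```python
import pytest

from graphiti_core.nodes import EntityNode
from tests.helpers_test import group_id


@pytest.mark.asyncio
async def test_entity_round_trip(graph_driver, mock_embedder):
    node = EntityNode(name='Alice', labels=[], summary='Alice summary', group_id=group_id)
    await node.generate_name_embedding(mock_embedder)  # returns the canned 'Alice' vector
    await node.save(graph_driver)

    retrieved = await EntityNode.get_by_uuid(graph_driver, node.uuid)
    assert retrieved.name == 'Alice'
    await node.delete(graph_driver)
```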
def test_lucene_sanitize(): def test_lucene_sanitize():
@ -114,5 +189,125 @@ def test_lucene_sanitize():
assert assert_result == result assert assert_result == result
async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
results, _, _ = await driver.execute_query(
"""
MATCH (n)
WHERE n.uuid IN $uuids
RETURN COUNT(n) as count
""",
uuids=uuids,
)
return int(results[0]['count'])
async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
results, _, _ = await driver.execute_query(
"""
MATCH (n)-[e]->(m)
WHERE e.uuid IN $uuids
RETURN COUNT(e) as count
UNION ALL
MATCH (e:RelatesToNode_)
WHERE e.uuid IN $uuids
RETURN COUNT(e) as count
""",
uuids=uuids,
)
return sum(int(result['count']) for result in results)
async def print_graph(graph_driver: GraphDriver):
nodes, _, _ = await graph_driver.execute_query(
"""
MATCH (n)
RETURN n.uuid, n.name
""",
)
print('Nodes:')
for node in nodes:
print(' ', node)
edges, _, _ = await graph_driver.execute_query(
"""
MATCH (n)-[e]->(m)
RETURN n.name, e.uuid, m.name
""",
)
print('Edges:')
for edge in edges:
print(' ', edge)
async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
assert retrieved.uuid == sample.uuid
assert retrieved.name == sample.name
assert retrieved.group_id == group_id
assert retrieved.created_at == sample.created_at
assert retrieved.source == sample.source
assert retrieved.source_description == sample.source_description
assert retrieved.content == sample.content
assert retrieved.valid_at == sample.valid_at
assert set(retrieved.entity_edges) == set(sample.entity_edges)
async def assert_entity_node_equals(
graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
):
await retrieved.load_name_embedding(graph_driver)
assert retrieved.uuid == sample.uuid
assert retrieved.name == sample.name
assert retrieved.group_id == sample.group_id
assert set(retrieved.labels) == set(sample.labels)
assert retrieved.created_at == sample.created_at
assert retrieved.name_embedding is not None
assert sample.name_embedding is not None
assert np.allclose(retrieved.name_embedding, sample.name_embedding)
assert retrieved.summary == sample.summary
assert retrieved.attributes == sample.attributes
async def assert_community_node_equals(
graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
):
await retrieved.load_name_embedding(graph_driver)
assert retrieved.uuid == sample.uuid
assert retrieved.name == sample.name
assert retrieved.group_id == group_id
assert retrieved.created_at == sample.created_at
assert retrieved.name_embedding is not None
assert sample.name_embedding is not None
assert np.allclose(retrieved.name_embedding, sample.name_embedding)
assert retrieved.summary == sample.summary
async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
assert retrieved.uuid == sample.uuid
assert retrieved.group_id == sample.group_id
assert retrieved.created_at == sample.created_at
assert retrieved.source_node_uuid == sample.source_node_uuid
assert retrieved.target_node_uuid == sample.target_node_uuid
async def assert_entity_edge_equals(
graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
):
await retrieved.load_fact_embedding(graph_driver)
assert retrieved.uuid == sample.uuid
assert retrieved.group_id == sample.group_id
assert retrieved.created_at == sample.created_at
assert retrieved.source_node_uuid == sample.source_node_uuid
assert retrieved.target_node_uuid == sample.target_node_uuid
assert retrieved.name == sample.name
assert retrieved.fact == sample.fact
assert retrieved.fact_embedding is not None
assert sample.fact_embedding is not None
assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
assert retrieved.episodes == sample.episodes
assert retrieved.expired_at == sample.expired_at
assert retrieved.valid_at == sample.valid_at
assert retrieved.invalid_at == sample.invalid_at
assert retrieved.attributes == sample.attributes
if __name__ == '__main__': if __name__ == '__main__':
pytest.main([__file__]) pytest.main([__file__])

View file

@ -17,23 +17,16 @@ limitations under the License.
import logging import logging
import sys import sys
from datetime import datetime from datetime import datetime
from uuid import uuid4
import numpy as np import numpy as np
import pytest import pytest
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import drivers, get_driver from tests.helpers_test import get_edge_count, get_node_count, group_id
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',) pytest_plugins = ('pytest_asyncio',)
group_id = f'test_group_{str(uuid4())}'
def setup_logging(): def setup_logging():
# Create a logger # Create a logger
@ -57,17 +50,10 @@ def setup_logging():
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_episodic_edge(driver):
-    graph_driver = get_driver(driver)
-    embedder = OpenAIEmbedder()
+async def test_episodic_edge(graph_driver, mock_embedder):
     now = datetime.now()
# Create episodic node
episode_node = EpisodicNode( episode_node = EpisodicNode(
name='test_episode', name='test_episode',
labels=[], labels=[],
@ -79,13 +65,13 @@ async def test_episodic_edge(driver):
entity_edges=[], entity_edges=[],
group_id=group_id, group_id=group_id,
) )
node_count = await get_node_count(graph_driver, [episode_node.uuid])
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 0 assert node_count == 0
await episode_node.save(graph_driver) await episode_node.save(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid) node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 1 assert node_count == 1
# Create entity node
alice_node = EntityNode( alice_node = EntityNode(
name='Alice', name='Alice',
labels=[], labels=[],
@ -93,27 +79,27 @@ async def test_episodic_edge(driver):
summary='Alice summary', summary='Alice summary',
group_id=group_id, group_id=group_id,
) )
await alice_node.generate_name_embedding(embedder) await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0 assert node_count == 0
await alice_node.save(graph_driver) await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid) node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1 assert node_count == 1
# Create episodic to entity edge
episodic_edge = EpisodicEdge( episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid, source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid, target_node_uuid=alice_node.uuid,
created_at=now, created_at=now,
group_id=group_id, group_id=group_id,
) )
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 0 assert edge_count == 0
await episodic_edge.save(graph_driver) await episodic_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 1 assert edge_count == 1
# Get edge by uuid
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid) retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid assert retrieved.source_node_uuid == episode_node.uuid
@ -121,6 +107,7 @@ async def test_episodic_edge(driver):
assert retrieved.created_at == now assert retrieved.created_at == now
assert retrieved.group_id == group_id assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid]) retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid assert retrieved[0].uuid == episodic_edge.uuid
@ -129,6 +116,7 @@ async def test_episodic_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2) retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid assert retrieved[0].uuid == episodic_edge.uuid
@ -137,33 +125,41 @@ async def test_episodic_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Get episodic node by entity node uuid
retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == episode_node.uuid
assert retrieved[0].name == 'test_episode'
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await episodic_edge.delete(graph_driver) await episodic_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0 assert edge_count == 0
await episode_node.delete(graph_driver) # Delete edge by uuids
node_count = await get_node_count(graph_driver, episode_node.uuid) await episodic_edge.save(graph_driver)
assert node_count == 0 await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await alice_node.delete(graph_driver) await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid) node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0 assert node_count == 0
await graph_driver.close() await graph_driver.close()
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_entity_edge(driver):
-    graph_driver = get_driver(driver)
-    embedder = OpenAIEmbedder()
+async def test_entity_edge(graph_driver, mock_embedder):
     now = datetime.now()
# Create entity node
alice_node = EntityNode( alice_node = EntityNode(
name='Alice', name='Alice',
labels=[], labels=[],
@ -171,25 +167,25 @@ async def test_entity_edge(driver):
summary='Alice summary', summary='Alice summary',
group_id=group_id, group_id=group_id,
) )
await alice_node.generate_name_embedding(embedder) await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0 assert node_count == 0
await alice_node.save(graph_driver) await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid) node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1 assert node_count == 1
# Create entity node
bob_node = EntityNode( bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
) )
await bob_node.generate_name_embedding(embedder) await bob_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 0 assert node_count == 0
await bob_node.save(graph_driver) await bob_node.save(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid) node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 1 assert node_count == 1
# Create entity to entity edge
entity_edge = EntityEdge( entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid, source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid, target_node_uuid=bob_node.uuid,
@ -202,14 +198,14 @@ async def test_entity_edge(driver):
invalid_at=now, invalid_at=now,
group_id=group_id, group_id=group_id,
) )
edge_embedding = await entity_edge.generate_embedding(embedder) edge_embedding = await entity_edge.generate_embedding(mock_embedder)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 0 assert edge_count == 0
await entity_edge.save(graph_driver) await entity_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid) edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 1 assert edge_count == 1
# Get edge by uuid
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid) retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid assert retrieved.source_node_uuid == alice_node.uuid
@ -217,6 +213,7 @@ async def test_entity_edge(driver):
assert retrieved.created_at == now assert retrieved.created_at == now
assert retrieved.group_id == group_id assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid]) retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid assert retrieved[0].uuid == entity_edge.uuid
@ -225,6 +222,7 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2) retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid assert retrieved[0].uuid == entity_edge.uuid
@ -233,6 +231,7 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Get edge by node uuid
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid) retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid assert retrieved[0].uuid == entity_edge.uuid
@ -241,82 +240,113 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Get fact embedding
await entity_edge.load_fact_embedding(graph_driver) await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding) assert np.allclose(entity_edge.fact_embedding, edge_embedding)
# Delete edge by uuid
await entity_edge.delete(graph_driver) await entity_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid) edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0 assert edge_count == 0
await alice_node.delete(graph_driver) # Delete edge by uuids
node_count = await get_node_count(graph_driver, alice_node.uuid) await entity_edge.save(graph_driver)
assert node_count == 0 await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node should delete the edge
await entity_edge.save(graph_driver)
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by uuids should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by group id should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await bob_node.delete(graph_driver) await bob_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid) node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0 assert node_count == 0
await graph_driver.close() await graph_driver.close()
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_community_edge(driver):
-    graph_driver = get_driver(driver)
-    embedder = OpenAIEmbedder()
+async def test_community_edge(graph_driver, mock_embedder):
     now = datetime.now()
# Create community node
community_node_1 = CommunityNode( community_node_1 = CommunityNode(
name='Community A', name='test_community_1',
group_id=group_id, group_id=group_id,
summary='Community A summary', summary='Community A summary',
) )
await community_node_1.generate_name_embedding(embedder) await community_node_1.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, community_node_1.uuid) node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0 assert node_count == 0
await community_node_1.save(graph_driver) await community_node_1.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid) node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 1 assert node_count == 1
# Create community node
community_node_2 = CommunityNode( community_node_2 = CommunityNode(
name='Community B', name='test_community_2',
group_id=group_id, group_id=group_id,
summary='Community B summary', summary='Community B summary',
) )
await community_node_2.generate_name_embedding(embedder) await community_node_2.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, community_node_2.uuid) node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0 assert node_count == 0
await community_node_2.save(graph_driver) await community_node_2.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid) node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 1 assert node_count == 1
# Create entity node
alice_node = EntityNode( alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
) )
await alice_node.generate_name_embedding(embedder) await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid) node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0 assert node_count == 0
await alice_node.save(graph_driver) await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid) node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1 assert node_count == 1
# Create community to community edge
community_edge = CommunityEdge( community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid, source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid, target_node_uuid=community_node_2.uuid,
created_at=now, created_at=now,
group_id=group_id, group_id=group_id,
) )
edge_count = await get_edge_count(graph_driver, community_edge.uuid) edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0 assert edge_count == 0
await community_edge.save(graph_driver) await community_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid) edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 1 assert edge_count == 1
# Get edge by uuid
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid) retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid assert retrieved.source_node_uuid == community_node_1.uuid
@ -324,6 +354,7 @@ async def test_community_edge(driver):
assert retrieved.created_at == now assert retrieved.created_at == now
assert retrieved.group_id == group_id assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid]) retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid assert retrieved[0].uuid == community_edge.uuid
@ -332,6 +363,7 @@ async def test_community_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1) retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid assert retrieved[0].uuid == community_edge.uuid
@ -340,45 +372,26 @@ async def test_community_edge(driver):
assert retrieved[0].created_at == now assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id assert retrieved[0].group_id == group_id
# Delete edge by uuid
await community_edge.delete(graph_driver) await community_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid) edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0 assert edge_count == 0
# Delete edge by uuids
await community_edge.save(graph_driver)
await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver) await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid) node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0 assert node_count == 0
await community_node_1.delete(graph_driver) await community_node_1.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid) node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0 assert node_count == 0
await community_node_2.delete(graph_driver) await community_node_2.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid) node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0 assert node_count == 0
await graph_driver.close() await graph_driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(results[0]['count'])
async def get_edge_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n)-[e {uuid: $uuid}]->(m)
RETURN COUNT(e) as count
UNION ALL
MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
RETURN COUNT(m) as count
""",
uuid=uuid,
)
return sum(int(result['count']) for result in results)

View file

@ -60,7 +60,6 @@ class Location(BaseModel):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'driver', 'driver',
drivers, drivers,
ids=drivers,
) )
async def test_exclude_default_entity_type(driver): async def test_exclude_default_entity_type(driver):
"""Test excluding the default 'Entity' type while keeping custom types.""" """Test excluding the default 'Entity' type while keeping custom types."""
@ -118,7 +117,6 @@ async def test_exclude_default_entity_type(driver):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'driver', 'driver',
drivers, drivers,
ids=drivers,
) )
async def test_exclude_specific_custom_types(driver): async def test_exclude_specific_custom_types(driver):
"""Test excluding specific custom entity types while keeping others.""" """Test excluding specific custom entity types while keeping others."""
@ -182,7 +180,6 @@ async def test_exclude_specific_custom_types(driver):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'driver', 'driver',
drivers, drivers,
ids=drivers,
) )
async def test_exclude_all_types(driver): async def test_exclude_all_types(driver):
"""Test excluding all entity types (edge case).""" """Test excluding all entity types (edge case)."""
@ -231,7 +228,6 @@ async def test_exclude_all_types(driver):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'driver', 'driver',
drivers, drivers,
ids=drivers,
) )
async def test_exclude_no_types(driver): async def test_exclude_no_types(driver):
"""Test normal behavior when no types are excluded (baseline test).""" """Test normal behavior when no types are excluded (baseline test)."""
@ -314,7 +310,6 @@ def test_validation_invalid_excluded_types():
@pytest.mark.parametrize( @pytest.mark.parametrize(
'driver', 'driver',
drivers, drivers,
ids=drivers,
) )
async def test_excluded_types_parameter_validation_in_add_episode(driver): async def test_excluded_types_parameter_validation_in_add_episode(driver):
"""Test that add_episode validates excluded_entity_types parameter.""" """Test that add_episode validates excluded_entity_types parameter."""

View file

@ -23,7 +23,7 @@ from graphiti_core.graphiti import Graphiti
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_helpers import search_results_to_context_string from graphiti_core.search.search_helpers import search_results_to_context_string
from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.datetime_utils import utc_now
from tests.helpers_test import drivers, get_driver from tests.helpers_test import GraphProvider
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',) pytest_plugins = ('pytest_asyncio',)
@ -51,15 +51,12 @@ def setup_logging():
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    'driver',
-    drivers,
-    ids=drivers,
-)
-async def test_graphiti_init(driver):
+async def test_graphiti_init(graph_driver):
+    if graph_driver.provider == GraphProvider.FALKORDB:
+        pytest.skip('Skipping as tests fail on Falkordb')
+
     logger = setup_logging()
-    driver = get_driver(driver)
-    graphiti = Graphiti(graph_driver=driver)
+    graphiti = Graphiti(graph_driver=graph_driver)
     await graphiti.build_indices_and_constraints()

tests/test_graphiti_mock.py (new file, 2056 lines)

File diff suppressed because it is too large

View file

@ -14,22 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from datetime import datetime from datetime import datetime, timedelta
from uuid import uuid4 from uuid import uuid4
import numpy as np
import pytest import pytest
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import ( from graphiti_core.nodes import (
CommunityNode, CommunityNode,
EntityNode, EntityNode,
EpisodeType, EpisodeType,
EpisodicNode, EpisodicNode,
) )
from tests.helpers_test import drivers, get_driver from tests.helpers_test import (
assert_community_node_equals,
assert_entity_node_equals,
assert_episodic_node_equals,
get_node_count,
group_id,
)
group_id = f'test_group_{str(uuid4())}' created_at = datetime.now()
deleted_at = created_at + timedelta(days=3)
valid_at = created_at + timedelta(days=1)
invalid_at = created_at + timedelta(days=2)
@pytest.fixture @pytest.fixture
@ -38,9 +45,14 @@ def sample_entity_node():
uuid=str(uuid4()), uuid=str(uuid4()),
name='Test Entity', name='Test Entity',
group_id=group_id, group_id=group_id,
labels=[], labels=['Entity', 'Person'],
created_at=created_at,
name_embedding=[0.5] * 1024, name_embedding=[0.5] * 1024,
summary='Entity Summary', summary='Entity Summary',
attributes={
'age': 30,
'location': 'New York',
},
) )
@ -50,10 +62,12 @@ def sample_episodic_node():
uuid=str(uuid4()), uuid=str(uuid4()),
name='Episode 1', name='Episode 1',
group_id=group_id, group_id=group_id,
created_at=created_at,
source=EpisodeType.text, source=EpisodeType.text,
source_description='Test source', source_description='Test source',
content='Some content here', content='Some content here',
valid_at=datetime.now(), valid_at=valid_at,
entity_edges=[],
) )
@ -62,182 +76,152 @@ def sample_community_node():
return CommunityNode( return CommunityNode(
uuid=str(uuid4()), uuid=str(uuid4()),
name='Community A', name='Community A',
name_embedding=[0.5] * 1024,
group_id=group_id, group_id=group_id,
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Community summary', summary='Community summary',
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( async def test_entity_node(sample_entity_node, graph_driver):
'driver',
drivers,
ids=drivers,
)
async def test_entity_node(sample_entity_node, driver):
driver = get_driver(driver)
uuid = sample_entity_node.uuid uuid = sample_entity_node.uuid
# Create node # Create node
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
await sample_entity_node.save(driver) await sample_entity_node.save(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1 assert node_count == 1
retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid) # Get node by uuid
assert retrieved.uuid == sample_entity_node.uuid retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
assert retrieved.name == 'Test Entity' await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)
assert retrieved.group_id == group_id
retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid]) # Get node by uuids
assert retrieved[0].uuid == sample_entity_node.uuid retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
assert retrieved[0].name == 'Test Entity' await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
assert retrieved[0].group_id == group_id
retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2) # Get node by group ids
retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == sample_entity_node.uuid await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
assert retrieved[0].name == 'Test Entity'
assert retrieved[0].group_id == group_id
await sample_entity_node.load_name_embedding(driver)
assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)
# Delete node by uuid # Delete node by uuid
await sample_entity_node.delete(driver) await sample_entity_node.delete(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
# Delete node by group id # Delete node by group id
await sample_entity_node.save(driver) await sample_entity_node.save(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1 assert node_count == 1
await sample_entity_node.delete_by_group_id(driver, group_id) await sample_entity_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
await driver.close() await graph_driver.close()
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( async def test_community_node(sample_community_node, graph_driver):
'driver',
drivers,
ids=drivers,
)
async def test_community_node(sample_community_node, driver):
driver = get_driver(driver)
uuid = sample_community_node.uuid uuid = sample_community_node.uuid
# Create node # Create node
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
await sample_community_node.save(driver) await sample_community_node.save(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1 assert node_count == 1
retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid) # Get node by uuid
assert retrieved.uuid == sample_community_node.uuid retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
assert retrieved.name == 'Community A' await assert_community_node_equals(graph_driver, retrieved, sample_community_node)
assert retrieved.group_id == group_id
assert retrieved.summary == 'Community summary'
retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid]) # Get node by uuids
assert retrieved[0].uuid == sample_community_node.uuid retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
assert retrieved[0].name == 'Community A' await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
assert retrieved[0].group_id == group_id
assert retrieved[0].summary == 'Community summary'
retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2) # Get node by group ids
retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == sample_community_node.uuid await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
assert retrieved[0].name == 'Community A'
assert retrieved[0].group_id == group_id
# Delete node by uuid # Delete node by uuid
await sample_community_node.delete(driver) await sample_community_node.delete(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_community_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
# Delete node by group id # Delete node by group id
await sample_community_node.save(driver) await sample_community_node.save(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1 assert node_count == 1
await sample_community_node.delete_by_group_id(driver, group_id) await sample_community_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
await driver.close() await graph_driver.close()
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( async def test_episodic_node(sample_episodic_node, graph_driver):
'driver',
drivers,
ids=drivers,
)
async def test_episodic_node(sample_episodic_node, driver):
driver = get_driver(driver)
uuid = sample_episodic_node.uuid uuid = sample_episodic_node.uuid
# Create node # Create node
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
await sample_episodic_node.save(driver) await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1 assert node_count == 1
retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid) # Get node by uuid
assert retrieved.uuid == sample_episodic_node.uuid retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
assert retrieved.name == 'Episode 1' await assert_episodic_node_equals(retrieved, sample_episodic_node)
assert retrieved.group_id == group_id
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
assert retrieved.valid_at == sample_episodic_node.valid_at
retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid]) # Get node by uuids
assert retrieved[0].uuid == sample_episodic_node.uuid retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
assert retrieved[0].name == 'Episode 1' await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
assert retrieved[0].group_id == group_id
assert retrieved[0].source == EpisodeType.text
assert retrieved[0].source_description == 'Test source'
assert retrieved[0].content == 'Some content here'
assert retrieved[0].valid_at == sample_episodic_node.valid_at
retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2) # Get node by group ids
retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1 assert len(retrieved) == 1
assert retrieved[0].uuid == sample_episodic_node.uuid await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
assert retrieved[0].name == 'Episode 1'
assert retrieved[0].group_id == group_id
assert retrieved[0].source == EpisodeType.text
assert retrieved[0].source_description == 'Test source'
assert retrieved[0].content == 'Some content here'
assert retrieved[0].valid_at == sample_episodic_node.valid_at
# Delete node by uuid # Delete node by uuid
await sample_episodic_node.delete(driver) await sample_episodic_node.delete(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
# Delete node by group id # Delete node by group id
await sample_episodic_node.save(driver) await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1 assert node_count == 1
await sample_episodic_node.delete_by_group_id(driver, group_id) await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(driver, uuid) node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0 assert node_count == 0
await driver.close() await graph_driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
result, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(result[0]['count'])
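
Context, not part of the diff: the file-local `get_node_count` above is removed in favour of shared helpers imported from `tests/helpers_test.py`, and the inline per-field assertions are folded into `assert_*_node_equals` helpers. The sketch below shows what those helpers might look like, inferred from the call sites and the assertions they replace; the actual implementations in the commit may differ.

```python
# Sketch of shared helpers assumed to live in tests/helpers_test.py.
# Signatures mirror the call sites above; query text and field lists are inferred.
import numpy as np

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import EntityNode


async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
    # Counts nodes whose uuid is in the given list (the old helper took a single uuid).
    results, _, _ = await driver.execute_query(
        """
        MATCH (n)
        WHERE n.uuid IN $uuids
        RETURN COUNT(n) AS count
        """,
        uuids=uuids,
    )
    return int(results[0]['count'])


async def assert_entity_node_equals(
    driver: GraphDriver, actual: EntityNode, expected: EntityNode
) -> None:
    # Loads the stored embedding, then compares the fields the old inline asserts covered.
    await actual.load_name_embedding(driver)
    assert actual.uuid == expected.uuid
    assert actual.name == expected.name
    assert actual.group_id == expected.group_id
    assert set(actual.labels) == set(expected.labels)  # assumed comparison
    assert actual.created_at == expected.created_at
    assert np.allclose(actual.name_embedding, expected.name_embedding)
    assert actual.summary == expected.summary
    assert actual.attributes == expected.attributes  # may hold extra keys in practice
```
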

uv.lock (generated, 42 lines changed)
View file

@ -809,6 +809,7 @@ dev = [
{ name = "groq" }, { name = "groq" },
{ name = "ipykernel" }, { name = "ipykernel" },
{ name = "jupyterlab" }, { name = "jupyterlab" },
{ name = "kuzu" },
{ name = "langchain-anthropic" }, { name = "langchain-anthropic" },
{ name = "langchain-openai" }, { name = "langchain-openai" },
{ name = "langgraph" }, { name = "langgraph" },
@ -831,6 +832,9 @@ google-genai = [
groq = [ groq = [
{ name = "groq" }, { name = "groq" },
] ]
kuzu = [
{ name = "kuzu" },
]
neptune = [ neptune = [
{ name = "boto3" }, { name = "boto3" },
{ name = "langchain-aws" }, { name = "langchain-aws" },
@ -858,6 +862,8 @@ requires-dist = [
{ name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" },
{ name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" },
{ name = "jupyterlab", marker = "extra == 'dev'", specifier = ">=4.2.4" }, { name = "jupyterlab", marker = "extra == 'dev'", specifier = ">=4.2.4" },
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" }, { name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" }, { name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" }, { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
@ -882,7 +888,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
] ]
provides-extras = ["anthropic", "groq", "google-genai", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"] provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
[[package]] [[package]]
name = "groq" name = "groq"
@ -1387,6 +1393,40 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700, upload-time = "2024-07-16T17:02:01.115Z" }, { url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700, upload-time = "2024-07-16T17:02:01.115Z" },
] ]
[[package]]
name = "kuzu"
version = "0.11.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/66/fd/adbd05ccf81e6ad2674fcd3849d5d6ffeaf2141a9b8d1c1c4e282e923e1f/kuzu-0.11.2.tar.gz", hash = "sha256:9f224ec218ab165a18acaea903695779780d70335baf402d9b7f59ba389db0bd", size = 4902887, upload-time = "2025-08-21T05:17:00.152Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/91/bed837f5f49220a9f869da8a078b34a3484f210f7b57b267177821545c03/kuzu-0.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b25174cdb721aae47896ed62842d3859679607b493a9a6bbbcd9fb7fb3707", size = 3702618, upload-time = "2025-08-21T05:15:53.726Z" },
{ url = "https://files.pythonhosted.org/packages/72/8a/fd5e053b0055718afe00b6a99393a835c6254354128fbb7f66a35fd76089/kuzu-0.11.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:9a8567c53bfe282f4727782471ff718842ffead8c48c1762c1df9197408fc986", size = 4101371, upload-time = "2025-08-21T05:15:55.889Z" },
{ url = "https://files.pythonhosted.org/packages/ad/4b/e45cadc85bdc5079f432675bbe8d557600f0d4ab46fe24ef218374419902/kuzu-0.11.2-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d793bb5a0a14ada730a697eccac2a4c68b434b82692d985942900ef2003e099e", size = 6211974, upload-time = "2025-08-21T05:15:57.505Z" },
{ url = "https://files.pythonhosted.org/packages/10/ca/92d6a1e6452fcf06bfc423ce2cde819ace6b6e47921921cc8fae87c27780/kuzu-0.11.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1be4e9b6c93ca8591b1fb165f9b9a27d70a56af061831afcdfe7aebb89ee6ff", size = 6992196, upload-time = "2025-08-21T05:15:59.006Z" },
{ url = "https://files.pythonhosted.org/packages/49/6c/983fc6265dfc1169c87c4a0722f36ee665c5688e1166faeb4cd85e6af078/kuzu-0.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0ec7a304c746a2a98ecfd7e7c3f6fe92c4dfee2e2565c0b7cb4cffd0c2e374a", size = 4303517, upload-time = "2025-08-21T05:16:00.814Z" },
{ url = "https://files.pythonhosted.org/packages/b5/14/8ae2c52657b93715052ecf47d70232f2c8d9ffe2d1ec3527c8e9c3cb2df5/kuzu-0.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf53b4f321a4c05882b14cef96d39a1e90fa993bab88a1554fb1565367553b8c", size = 3704177, upload-time = "2025-08-21T05:16:02.354Z" },
{ url = "https://files.pythonhosted.org/packages/2d/7a/bce7bb755e16f9ca855f76a3acc6cfa9fae88c4d6af9df3784c50b2120a5/kuzu-0.11.2-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:2d749883b74f5da5ff4a4b0635a98f6cc5165743995828924321d2ca797317cb", size = 4102372, upload-time = "2025-08-21T05:16:04.249Z" },
{ url = "https://files.pythonhosted.org/packages/c8/12/f5b1d51fcb78a86c078fb85cc53184ce962a3e86852d47d30e287a932e3f/kuzu-0.11.2-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:632507e5982928ed24fbb5e70ad143d7970bc4059046e77e0522707efbad303b", size = 6212492, upload-time = "2025-08-21T05:16:05.99Z" },
{ url = "https://files.pythonhosted.org/packages/81/96/d6e57af6ccf9e0697812ad3039c80b87b768cf2674833b0b23d317ea3427/kuzu-0.11.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5211884601f8f08ae97ba25006d0facde24077c5333411d944282b8a2068ab4", size = 6992888, upload-time = "2025-08-21T05:16:07.896Z" },
{ url = "https://files.pythonhosted.org/packages/40/ee/1f275ac5679a3f615ce0d9cf8c79001fdb535ccc8bc344e49b14484c7cd7/kuzu-0.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:82a6c8bfe1278dc1010790e398bf772683797ef5c16052fa0f6f78bacbc59aa3", size = 4304064, upload-time = "2025-08-21T05:16:10.163Z" },
{ url = "https://files.pythonhosted.org/packages/73/ba/9f20d9e83681a0ddae8ec13046b116c34745fa0e66862d4e2d8414734ce0/kuzu-0.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aed88ffa695d07289a3d8557bd8f9e743298a4f4349208a60bbb06f4ebf15c26", size = 3703781, upload-time = "2025-08-21T05:16:12.232Z" },
{ url = "https://files.pythonhosted.org/packages/53/a0/bb815c0490f3d4d30389156369b9fe641e154f0d4b1e8340f09a76021922/kuzu-0.11.2-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:595824b03248af928e3faee57f6825d3a46920f2d3b9bd0c0bb7fc3fa097fce9", size = 4103990, upload-time = "2025-08-21T05:16:14.139Z" },
{ url = "https://files.pythonhosted.org/packages/a5/6f/97b647c0547a634a669055ff4cfd21a92ea3999aedc6a7fe9004f03f25e3/kuzu-0.11.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5674c6d9d26f5caa0c7ce6f34c02e4411894879aa5b2ce174fad576fa898523", size = 6211947, upload-time = "2025-08-21T05:16:16.48Z" },
{ url = "https://files.pythonhosted.org/packages/42/74/c7f1a1cfb08c05c91c5a94483be387e80fafab8923c4243a22e9cced5c1b/kuzu-0.11.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c61daf02da35b671f4c6f3c17105725c399a5e14b7349b00eafbcd24ac90034a", size = 6991879, upload-time = "2025-08-21T05:16:18.402Z" },
{ url = "https://files.pythonhosted.org/packages/54/9e/50d67d7bc08faed95ede6de1a6aa0d81079c98028ca99e32d09c2ab1aead/kuzu-0.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:682096cd87dcbb8257f933ea4172d9dc5617a8d0a5bdd19cd66cf05b68881afd", size = 4305706, upload-time = "2025-08-21T05:16:20.244Z" },
{ url = "https://files.pythonhosted.org/packages/65/f0/5649a01af37def50293cd7c194afc19f09b343fd2b7f2b28e021a207f8ce/kuzu-0.11.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:17a11b67652e8b331c85cd1a1a30b32ee6783750084473abbab2aa1963ee2a3b", size = 3703740, upload-time = "2025-08-21T05:16:21.896Z" },
{ url = "https://files.pythonhosted.org/packages/24/e2/e0beb9080911fc1689899a42da0f83534949f43169fb80197def3ec1223f/kuzu-0.11.2-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:bdded35426210faeca8da11e8b4a54e60ccc0c1a832660d76587b5be133b0f55", size = 4104073, upload-time = "2025-08-21T05:16:23.819Z" },
{ url = "https://files.pythonhosted.org/packages/f2/4c/7a831c9c6e609692953db677f54788bd1dde4c9d34e6ba91f1e153d2e7fe/kuzu-0.11.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6116b609aac153f3523130b31295643d34a6c9509914c0fa9d804b26b23eee73", size = 6212263, upload-time = "2025-08-21T05:16:25.351Z" },
{ url = "https://files.pythonhosted.org/packages/47/95/615ef10b46b22ec1d33fdbba795e6e79733d9a244aabdeeb910f267ab36c/kuzu-0.11.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:09da5b8cb24dc6b281a6e4ac0f7f24226eb9909803b187e02d014da13ba57bcf", size = 6992492, upload-time = "2025-08-21T05:16:27.518Z" },
{ url = "https://files.pythonhosted.org/packages/a7/dd/2c905575913c743e6c67a5ca89a6b4ea9d9737238966d85d7e710f0d3e60/kuzu-0.11.2-cp313-cp313-win_amd64.whl", hash = "sha256:c663fb84682f8ebffbe7447a4e552a0e03bd29097d319084a2c53c2e032a780e", size = 4305267, upload-time = "2025-08-21T05:16:29.307Z" },
{ url = "https://files.pythonhosted.org/packages/89/05/44fbfc9055dba3f472ea4aaa8110635864d3441eede987526ef401680765/kuzu-0.11.2-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5c03fb95ffb9185c1519333f8ee92b7a9695aa7aa9a179e868a7d7bd13d10a16", size = 6216795, upload-time = "2025-08-21T05:16:30.944Z" },
{ url = "https://files.pythonhosted.org/packages/4f/ca/16c81dc68cc1e8918f8481e7ee89c28aa665c5cb36be7ad0fc1d0d295760/kuzu-0.11.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d857f0efddf26d5e2dc189facb84bf04a997e395972486669b418a470cc76034", size = 6996333, upload-time = "2025-08-21T05:16:32.568Z" },
{ url = "https://files.pythonhosted.org/packages/48/d8/9275c7e6312bd76dc670e8e2da68639757c22cf2c366e96527595a1d881c/kuzu-0.11.2-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb9e4641867c35b98ceaa604aa79832c0eeed41f5fd1b6da22b1c217b2f1b8ea", size = 6212202, upload-time = "2025-08-21T05:16:34.571Z" },
{ url = "https://files.pythonhosted.org/packages/88/89/67a977493c60bca3610845df13020711f357a5d80bf91549e4b48d877c2f/kuzu-0.11.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:553408d76a0b4fdecc1338b69b71d7bde42f6936d3b99d9852b30d33bda15978", size = 6992264, upload-time = "2025-08-21T05:16:36.316Z" },
{ url = "https://files.pythonhosted.org/packages/b6/49/869ceceb1d8a5ea032a35c734e55cfee919340889973623096da7eb94f6b/kuzu-0.11.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:989a87fa13ffa39ab7773d968fe739ac4f8faf9ddb5dad72ced2eeef12180293", size = 6216814, upload-time = "2025-08-21T05:16:38.348Z" },
{ url = "https://files.pythonhosted.org/packages/bc/cd/933b34a246edb882a042eb402747167719222c05149b73b48ba7d310d554/kuzu-0.11.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e67420d04a9643fd6376a23b17b398a3e32bb0c2bd8abbf8d1e4697056596c7e", size = 6996343, upload-time = "2025-08-21T05:16:39.973Z" },
]
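
Aside, not part of the lockfile: with the `kuzu` package resolved, the new backend plugs into `Graphiti` the same way a driver is passed in the integration test above. A hedged usage sketch follows; the `KuzuDriver` module path and constructor argument are assumptions based on the naming pattern of the existing drivers, not confirmed by this commit.

```python
# Hedged usage sketch for the new Kuzu backend; class and module names are assumed.
import asyncio

from graphiti_core.driver.kuzu_driver import KuzuDriver  # assumed module path
from graphiti_core.graphiti import Graphiti


async def main() -> None:
    driver = KuzuDriver(db='./graphiti_kuzu_db')  # assumed constructor argument
    graphiti = Graphiti(graph_driver=driver)  # same pattern as test_graphiti_init above
    await graphiti.build_indices_and_constraints()
    await graphiti.close()


if __name__ == '__main__':
    asyncio.run(main())
```
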
[[package]] [[package]]
name = "langchain-anthropic" name = "langchain-anthropic"
version = "0.3.17" version = "0.3.17"