diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index 3096524a..cf1053a1 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -49,6 +49,7 @@ jobs:
NEO4J_URI: bolt://localhost:7687
NEO4J_USER: neo4j
NEO4J_PASSWORD: testpass
+ DISABLE_NEPTUNE: 1
run: |
uv run pytest -m "not integration"
- name: Wait for FalkorDB
diff --git a/README.md b/README.md
index e929086e..db09e137 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ Use Graphiti to:
-
+
@@ -80,7 +80,7 @@ Traditional RAG approaches often rely on batch processing and static data summar
- **Scalability:** Efficiently manages large datasets with parallel processing, suitable for enterprise environments.
-
+
## Graphiti vs. GraphRAG
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
Requirements:
- Python 3.10 or higher
-- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
+- Neo4j 5.26 / FalkorDB 1.1.2 / Kuzu 0.11.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
> [!IMPORTANT]
@@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
uv add graphiti-core[falkordb]
```
+### Installing with Kuzu Support
+
+If you plan to use Kuzu as your graph database backend, install with the Kuzu extra:
+
+```bash
+pip install graphiti-core[kuzu]
+
+# or with uv
+uv add graphiti-core[kuzu]
+```
+
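+Kuzu is an embedded graph database, so no separate database server is required.
+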
### Installing with Amazon Neptune Support
If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
@@ -198,7 +209,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
-1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
+1. Connecting to a Neo4j, Amazon Neptune, FalkorDB, or Kuzu database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph (both text and structured JSON)
4. Searching for relationships (edges) using hybrid search
@@ -281,6 +292,19 @@ driver = FalkorDriver(
graphiti = Graphiti(graph_driver=driver)
```
+#### Kuzu
+
+```python
+from graphiti_core import Graphiti
+from graphiti_core.driver.kuzu_driver import KuzuDriver
+
+# Create a Kuzu driver
+driver = KuzuDriver(db="/tmp/graphiti.kuzu")
+
+# Pass the driver to Graphiti
+graphiti = Graphiti(graph_driver=driver)
+```
+
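+If no `db` argument is given, `KuzuDriver` defaults to an in-memory database (`db=':memory:'`), which is convenient for experimentation but is not persisted to disk.
+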
#### Amazon Neptune
```python
@@ -494,7 +518,7 @@ When you initialize a Graphiti instance, we collect:
- **Graphiti version**: The version you're using
- **Configuration choices**:
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
- - Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
+ - Database backend (Neo4j, FalkorDB, Kuzu, Amazon Neptune Database or Neptune Analytics)
- Embedder provider (OpenAI, Azure, Voyage, etc.)
### What We Don't Collect
diff --git a/conftest.py b/conftest.py
index 76019d9e..a2a31c34 100644
--- a/conftest.py
+++ b/conftest.py
@@ -4,3 +4,7 @@ import sys
# This code adds the project root directory to the Python path, allowing imports to work correctly when running tests.
# Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))
+
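+# Re-export shared pytest fixtures so they are available to every test module
+# without an explicit import in each file.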
+from tests.helpers_test import graph_driver, mock_embedder
+
+__all__ = ['graph_driver', 'mock_embedder']
diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py
index a9fec69c..670a7426 100644
--- a/graphiti_core/driver/driver.py
+++ b/graphiti_core/driver/driver.py
@@ -27,10 +27,13 @@ logger = logging.getLogger(__name__)
class GraphProvider(Enum):
NEO4J = 'neo4j'
FALKORDB = 'falkordb'
+ KUZU = 'kuzu'
NEPTUNE = 'neptune'
class GraphDriverSession(ABC):
+ provider: GraphProvider
+
async def __aenter__(self):
return self
diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py
index f121319b..00c39342 100644
--- a/graphiti_core/driver/falkordb_driver.py
+++ b/graphiti_core/driver/falkordb_driver.py
@@ -15,7 +15,6 @@ limitations under the License.
"""
import logging
-from datetime import datetime
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
@@ -33,11 +32,14 @@ else:
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
logger = logging.getLogger(__name__)
class FalkorDriverSession(GraphDriverSession):
+ provider = GraphProvider.FALKORDB
+
def __init__(self, graph: FalkorGraph):
self.graph = graph
@@ -164,16 +166,3 @@ class FalkorDriver(GraphDriver):
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned
-
-
-def convert_datetimes_to_strings(obj):
- if isinstance(obj, dict):
- return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [convert_datetimes_to_strings(item) for item in obj]
- elif isinstance(obj, tuple):
- return tuple(convert_datetimes_to_strings(item) for item in obj)
- elif isinstance(obj, datetime):
- return obj.isoformat()
- else:
- return obj
diff --git a/graphiti_core/driver/kuzu_driver.py b/graphiti_core/driver/kuzu_driver.py
new file mode 100644
index 00000000..af371b2f
--- /dev/null
+++ b/graphiti_core/driver/kuzu_driver.py
@@ -0,0 +1,175 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+from typing import Any
+
+import kuzu
+
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+
+logger = logging.getLogger(__name__)
+
+# Kuzu requires an explicit schema.
+# As Kuzu currently does not support creating full text indexes on edge properties,
+# we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
+# (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
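+# For example, an entity-to-entity fact is stored as
+#   (a:Entity)-[:RELATES_TO]->(r:RelatesToNode_ {name: ..., fact: ...})-[:RELATES_TO]->(b:Entity),
+# which lets the full text index cover `name` and `fact` as node properties.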
+SCHEMA_QUERIES = """
+ CREATE NODE TABLE IF NOT EXISTS Episodic (
+ uuid STRING PRIMARY KEY,
+ name STRING,
+ group_id STRING,
+ created_at TIMESTAMP,
+ source STRING,
+ source_description STRING,
+ content STRING,
+ valid_at TIMESTAMP,
+ entity_edges STRING[]
+ );
+ CREATE NODE TABLE IF NOT EXISTS Entity (
+ uuid STRING PRIMARY KEY,
+ name STRING,
+ group_id STRING,
+ labels STRING[],
+ created_at TIMESTAMP,
+ name_embedding FLOAT[],
+ summary STRING,
+ attributes STRING
+ );
+ CREATE NODE TABLE IF NOT EXISTS Community (
+ uuid STRING PRIMARY KEY,
+ name STRING,
+ group_id STRING,
+ created_at TIMESTAMP,
+ name_embedding FLOAT[],
+ summary STRING
+ );
+ CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
+ uuid STRING PRIMARY KEY,
+ group_id STRING,
+ created_at TIMESTAMP,
+ name STRING,
+ fact STRING,
+ fact_embedding FLOAT[],
+ episodes STRING[],
+ expired_at TIMESTAMP,
+ valid_at TIMESTAMP,
+ invalid_at TIMESTAMP,
+ attributes STRING
+ );
+ CREATE REL TABLE IF NOT EXISTS RELATES_TO(
+ FROM Entity TO RelatesToNode_,
+ FROM RelatesToNode_ TO Entity
+ );
+ CREATE REL TABLE IF NOT EXISTS MENTIONS(
+ FROM Episodic TO Entity,
+ uuid STRING PRIMARY KEY,
+ group_id STRING,
+ created_at TIMESTAMP
+ );
+ CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
+ FROM Community TO Entity,
+ FROM Community TO Community,
+ uuid STRING,
+ group_id STRING,
+ created_at TIMESTAMP
+ );
+"""
+
+
+class KuzuDriver(GraphDriver):
+ provider: GraphProvider = GraphProvider.KUZU
+
+ def __init__(
+ self,
+ db: str = ':memory:',
+ max_concurrent_queries: int = 1,
+ ):
+ super().__init__()
+ self.db = kuzu.Database(db)
+
+ self.setup_schema()
+
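+        # The async connection executes queries one at a time by default
+        # (max_concurrent_queries=1).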
+ self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)
+
+ async def execute_query(
+ self, cypher_query_: str, **kwargs: Any
+ ) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
+ params = {k: v for k, v in kwargs.items() if v is not None}
+ # Kuzu does not support these parameters.
+ params.pop('database_', None)
+ params.pop('routing_', None)
+
+ try:
+ results = await self.client.execute(cypher_query_, parameters=params)
+ except Exception as e:
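+            # Truncate list parameters (e.g. embeddings) so the log stays readable.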
+ params = {k: (v[:5] if isinstance(v, list) else v) for k, v in params.items()}
+ logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{params}')
+ raise
+
+ if not results:
+ return [], None, None
+
+ if isinstance(results, list):
+ dict_results = [list(result.rows_as_dict()) for result in results]
+ else:
+ dict_results = list(results.rows_as_dict())
+ return dict_results, None, None # type: ignore
+
+ def session(self, _database: str | None = None) -> GraphDriverSession:
+ return KuzuDriverSession(self)
+
+ async def close(self):
+        # Do not explicitly close the connection; instead, rely on GC.
+ pass
+
+ def delete_all_indexes(self, database_: str):
+ pass
+
+ def setup_schema(self):
+ conn = kuzu.Connection(self.db)
+ conn.execute(SCHEMA_QUERIES)
+ conn.close()
+
+
+class KuzuDriverSession(GraphDriverSession):
+ provider = GraphProvider.KUZU
+
+ def __init__(self, driver: KuzuDriver):
+ self.driver = driver
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ # No cleanup needed for Kuzu, but method must exist.
+ pass
+
+ async def close(self):
+ # Do not close the session here, as we're reusing the driver connection.
+ pass
+
+ async def execute_write(self, func, *args, **kwargs):
+ # Directly await the provided async function with `self` as the transaction/session
+ return await func(self, *args, **kwargs)
+
+ async def run(self, query: str | list, **kwargs: Any) -> Any:
+ if isinstance(query, list):
+ for cypher, params in query:
+ await self.driver.execute_query(cypher, **params)
+ else:
+ await self.driver.execute_query(query, **kwargs)
+ return None
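+
+
+# Minimal usage sketch (illustrative only, assuming an async context):
+#
+#   driver = KuzuDriver()  # in-memory by default; pass a file path to persist
+#   records, _, _ = await driver.execute_query('MATCH (n:Entity) RETURN n.uuid AS uuid')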
diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py
index 65ca59a5..25aa12c3 100644
--- a/graphiti_core/driver/neptune_driver.py
+++ b/graphiti_core/driver/neptune_driver.py
@@ -271,6 +271,8 @@ class NeptuneDriver(GraphDriver):
class NeptuneDriverSession(GraphDriverSession):
+ provider = GraphProvider.NEPTUNE
+
def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
self.driver = driver
diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py
index fa2eefd4..a427d65e 100644
--- a/graphiti_core/edges.py
+++ b/graphiti_core/edges.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
@@ -30,11 +31,10 @@ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_RETURN,
- ENTITY_EDGE_RETURN,
- ENTITY_EDGE_RETURN_NEPTUNE,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
+ get_entity_edge_return_query,
get_entity_edge_save_query,
)
from graphiti_core.nodes import Node
@@ -53,33 +53,63 @@ class Edge(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
- result = await driver.execute_query(
- """
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
- DELETE e
- """,
- uuid=self.uuid,
- )
+ if driver.provider == GraphProvider.KUZU:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
+ DELETE e
+ """,
+ uuid=self.uuid,
+ )
+ await driver.execute_query(
+ """
+ MATCH (e:RelatesToNode_ {uuid: $uuid})
+ DETACH DELETE e
+ """,
+ uuid=self.uuid,
+ )
+ else:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
+ DELETE e
+ """,
+ uuid=self.uuid,
+ )
logger.debug(f'Deleted Edge: {self.uuid}')
- return result
-
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
- result = await driver.execute_query(
- """
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
- WHERE e.uuid IN $uuids
- DELETE e
- """,
- uuids=uuids,
- )
+ if driver.provider == GraphProvider.KUZU:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
+ WHERE e.uuid IN $uuids
+ DELETE e
+ """,
+ uuids=uuids,
+ )
+ await driver.execute_query(
+ """
+ MATCH (e:RelatesToNode_)
+ WHERE e.uuid IN $uuids
+ DETACH DELETE e
+ """,
+ uuids=uuids,
+ )
+ else:
+ await driver.execute_query(
+ """
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
+ WHERE e.uuid IN $uuids
+ DELETE e
+ """,
+ uuids=uuids,
+ )
logger.debug(f'Deleted Edges: {uuids}')
- return result
-
def __hash__(self):
return hash(self.uuid)
@@ -166,7 +196,7 @@ class EpisodicEdge(Edge):
"""
+ EPISODIC_EDGE_RETURN
+ """
- ORDER BY e.uuid DESC
+ ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@@ -215,15 +245,21 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
- if driver.provider == GraphProvider.NEPTUNE:
- query: LiteralString = """
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
- RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
- """
- else:
- query: LiteralString = """
+ query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
+ """
+
+ if driver.provider == GraphProvider.NEPTUNE:
+ query = """
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+ RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
+ """
+
+ if driver.provider == GraphProvider.KUZU:
+ query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
+ RETURN e.fact_embedding AS fact_embedding
"""
records, _, _ = await driver.execute_query(
@@ -253,15 +289,22 @@ class EntityEdge(Edge):
'invalid_at': self.invalid_at,
}
- edge_data.update(self.attributes or {})
+ if driver.provider == GraphProvider.KUZU:
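+            # Kuzu has no map-typed properties, so attributes are serialized to a
+            # JSON string (the schema column is STRING).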
+ edge_data['attributes'] = json.dumps(self.attributes)
+ result = await driver.execute_query(
+ get_entity_edge_save_query(driver.provider),
+ **edge_data,
+ )
+ else:
+ edge_data.update(self.attributes or {})
- if driver.provider == GraphProvider.NEPTUNE:
- driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
+ if driver.provider == GraphProvider.NEPTUNE:
+ driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
- result = await driver.execute_query(
- get_entity_edge_save_query(driver.provider),
- edge_data=edge_data,
- )
+ result = await driver.execute_query(
+ get_entity_edge_save_query(driver.provider),
+ edge_data=edge_data,
+ )
logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -269,21 +312,25 @@ class EntityEdge(Edge):
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
+ """
+
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- ),
+ + get_entity_edge_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
@@ -294,22 +341,26 @@ class EntityEdge(Edge):
if len(uuids) == 0:
return []
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
+
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
WHERE e.uuid IN $uuids
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- ),
+ + get_entity_edge_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -332,23 +383,27 @@ class EntityEdge(Edge):
else ''
)
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
+
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- )
+ + get_entity_edge_return_query(driver.provider)
+ with_embeddings_query
+ """
- ORDER BY e.uuid DESC
+ ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@@ -357,7 +412,7 @@ class EntityEdge(Edge):
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise GroupsEdgesNotFoundError(group_ids)
@@ -365,21 +420,25 @@ class EntityEdge(Edge):
@classmethod
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
- records, _, _ = await driver.execute_query(
- """
+ match_query = """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
+
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
RETURN
"""
- + (
- ENTITY_EDGE_RETURN_NEPTUNE
- if driver.provider == GraphProvider.NEPTUNE
- else ENTITY_EDGE_RETURN
- ),
+ + get_entity_edge_return_query(driver.provider),
node_uuid=node_uuid,
routing_='r',
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -479,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
)
-def get_entity_edge_from_record(record: Any) -> EntityEdge:
+def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
+ episodes = record['episodes']
+ if provider == GraphProvider.KUZU:
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
+ else:
+ attributes = record['attributes']
+ attributes.pop('uuid', None)
+ attributes.pop('source_node_uuid', None)
+ attributes.pop('target_node_uuid', None)
+ attributes.pop('fact', None)
+ attributes.pop('fact_embedding', None)
+ attributes.pop('name', None)
+ attributes.pop('group_id', None)
+ attributes.pop('episodes', None)
+ attributes.pop('created_at', None)
+ attributes.pop('expired_at', None)
+ attributes.pop('valid_at', None)
+ attributes.pop('invalid_at', None)
+
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
@@ -488,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
fact_embedding=record.get('fact_embedding'),
name=record['name'],
group_id=record['group_id'],
- episodes=record['episodes'],
+ episodes=episodes,
created_at=parse_db_date(record['created_at']), # type: ignore
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
- attributes=record['attributes'],
+ attributes=attributes,
)
- edge.attributes.pop('uuid', None)
- edge.attributes.pop('source_node_uuid', None)
- edge.attributes.pop('target_node_uuid', None)
- edge.attributes.pop('fact', None)
- edge.attributes.pop('name', None)
- edge.attributes.pop('group_id', None)
- edge.attributes.pop('episodes', None)
- edge.attributes.pop('created_at', None)
- edge.attributes.pop('expired_at', None)
- edge.attributes.pop('valid_at', None)
- edge.attributes.pop('invalid_at', None)
-
return edge
diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py
index ec1872fe..71fa0547 100644
--- a/graphiti_core/graph_queries.py
+++ b/graphiti_core/graph_queries.py
@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
'episode_content': 'Episodic',
'edge_name_and_fact': 'RELATES_TO',
}
+# Mapping from fulltext index names to Kuzu node labels
+INDEX_TO_LABEL_KUZU_MAPPING = {
+ 'node_name_and_summary': 'Entity',
+ 'community_name': 'Community',
+ 'episode_content': 'Episodic',
+ 'edge_name_and_fact': 'RelatesToNode_',
+}
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
]
+ if provider == GraphProvider.KUZU:
+ return []
+
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
]
+ if provider == GraphProvider.KUZU:
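+        # Kuzu full text search is provided by its FTS extension; edge facts are
+        # indexed via the RelatesToNode_ node table (see INDEX_TO_LABEL_KUZU_MAPPING).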
+ return [
+ "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
+ "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
+ "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
+ "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
+ ]
+
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
]
-def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
+def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
+ if provider == GraphProvider.KUZU:
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
+
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
+ if provider == GraphProvider.KUZU:
+ return f'array_cosine_similarity({vec1}, {vec2})'
+
return f'vector.similarity.cosine({vec1}, {vec2})'
-def get_relationships_query(name: str, provider: GraphProvider) -> str:
+def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
+ if provider == GraphProvider.KUZU:
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
+
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 75d64f19..3dae32a0 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -1070,7 +1070,7 @@ class Graphiti:
if record['episode_count'] == 1:
nodes_to_delete.append(node)
+ await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
- await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
await episode.delete(self.driver)
diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index 9feb3073..b1de852a 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -43,14 +43,14 @@ RUNTIME_QUERY: LiteralString = (
)
-def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
- return (
- neo_date.to_native()
- if isinstance(neo_date, neo4j_time.DateTime)
- else datetime.fromisoformat(neo_date)
- if neo_date
- else None
- )
+def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+ if isinstance(input_date, neo4j_time.DateTime):
+ return input_date.to_native()
+
+ if isinstance(input_date, str):
+ return datetime.fromisoformat(input_date)
+
+ return input_date
def get_default_group_id(provider: GraphProvider) -> str:
diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py
index d9b68405..5b8d5402 100644
--- a/graphiti_core/models/edges/edge_db_queries.py
+++ b/graphiti_core/models/edges/edge_db_queries.py
@@ -20,18 +20,36 @@ EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
- SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
RETURN e.uuid AS uuid
"""
-EPISODIC_EDGE_SAVE_BULK = """
- UNWIND $episodic_edges AS edge
- MATCH (episode:Episodic {uuid: edge.source_node_uuid})
- MATCH (node:Entity {uuid: edge.target_node_uuid})
- MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
- SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
- RETURN e.uuid AS uuid
-"""
+
+def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
+ if provider == GraphProvider.KUZU:
+ return """
+ MATCH (episode:Episodic {uuid: $source_node_uuid})
+ MATCH (node:Entity {uuid: $target_node_uuid})
+ MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
+ RETURN e.uuid AS uuid
+ """
+
+ return """
+ UNWIND $episodic_edges AS edge
+ MATCH (episode:Episodic {uuid: edge.source_node_uuid})
+ MATCH (node:Entity {uuid: edge.target_node_uuid})
+ MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
+ SET
+ e.group_id = edge.group_id,
+ e.created_at = edge.created_at
+ RETURN e.uuid AS uuid
+ """
+
EPISODIC_EDGE_RETURN = """
e.uuid AS uuid,
@@ -54,14 +72,32 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
"""
case GraphProvider.NEPTUNE:
return """
- MATCH (source:Entity {uuid: $edge_data.source_uuid})
- MATCH (target:Entity {uuid: $edge_data.target_uuid})
+ MATCH (source:Entity {uuid: $edge_data.source_uuid})
+ MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
SET e.episodes = join($edge_data.episodes, ",")
RETURN $edge_data.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MATCH (source:Entity {uuid: $source_uuid})
+ MATCH (target:Entity {uuid: $target_uuid})
+ MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at,
+ e.name = $name,
+ e.fact = $fact,
+ e.fact_embedding = $fact_embedding,
+ e.episodes = $episodes,
+ e.expired_at = $expired_at,
+ e.valid_at = $valid_at,
+ e.invalid_at = $invalid_at,
+ e.attributes = $attributes
+ RETURN e.uuid AS uuid
+ """
case _: # Neo4j
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
@@ -89,14 +125,32 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
UNWIND $entity_edges AS edge
- MATCH (source:Entity {uuid: edge.source_node_uuid})
- MATCH (target:Entity {uuid: edge.target_node_uuid})
+ MATCH (source:Entity {uuid: edge.source_node_uuid})
+ MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
SET r.episodes = join(edge.episodes, ",")
RETURN edge.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MATCH (source:Entity {uuid: $source_node_uuid})
+ MATCH (target:Entity {uuid: $target_node_uuid})
+ MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at,
+ e.name = $name,
+ e.fact = $fact,
+ e.fact_embedding = $fact_embedding,
+ e.episodes = $episodes,
+ e.expired_at = $expired_at,
+ e.valid_at = $valid_at,
+ e.invalid_at = $invalid_at,
+ e.attributes = $attributes
+ RETURN e.uuid AS uuid
+ """
case _:
return """
UNWIND $entity_edges AS edge
@@ -109,35 +163,42 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
"""
-ENTITY_EDGE_RETURN = """
- e.uuid AS uuid,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- e.group_id AS group_id,
- e.name AS name,
- e.fact AS fact,
- e.episodes AS episodes,
- e.created_at AS created_at,
- e.expired_at AS expired_at,
- e.valid_at AS valid_at,
- e.invalid_at AS invalid_at,
- properties(e) AS attributes
-"""
+def get_entity_edge_return_query(provider: GraphProvider) -> str:
+ # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
-ENTITY_EDGE_RETURN_NEPTUNE = """
- e.uuid AS uuid,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- e.group_id AS group_id,
- e.name AS name,
- e.fact AS fact,
- split(e.episodes, ',') AS episodes,
- e.created_at AS created_at,
- e.expired_at AS expired_at,
- e.valid_at AS valid_at,
- e.invalid_at AS invalid_at,
- properties(e) AS attributes
-"""
+ if provider == GraphProvider.NEPTUNE:
+ return """
+ e.uuid AS uuid,
+ n.uuid AS source_node_uuid,
+ m.uuid AS target_node_uuid,
+ e.group_id AS group_id,
+ e.name AS name,
+ e.fact AS fact,
+ split(e.episodes, ',') AS episodes,
+ e.created_at AS created_at,
+ e.expired_at AS expired_at,
+ e.valid_at AS valid_at,
+ e.invalid_at AS invalid_at,
+ properties(e) AS attributes
+ """
+
+ return """
+ e.uuid AS uuid,
+ n.uuid AS source_node_uuid,
+ m.uuid AS target_node_uuid,
+ e.group_id AS group_id,
+ e.created_at AS created_at,
+ e.name AS name,
+ e.fact AS fact,
+ e.episodes AS episodes,
+ e.expired_at AS expired_at,
+ e.valid_at AS valid_at,
+ e.invalid_at AS invalid_at,
+ """ + (
+ 'e.attributes AS attributes'
+ if provider == GraphProvider.KUZU
+ else 'properties(e) AS attributes'
+ )
def get_community_edge_save_query(provider: GraphProvider) -> str:
@@ -152,7 +213,7 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
"""
case GraphProvider.NEPTUNE:
return """
- MATCH (community:Community {uuid: $community_uuid})
+ MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
WHERE node:Entity OR node:Community
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
@@ -161,6 +222,24 @@ def get_community_edge_save_query(provider: GraphProvider) -> str:
SET r.created_at= $created_at
RETURN r.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MATCH (community:Community {uuid: $community_uuid})
+ MATCH (node:Entity {uuid: $entity_uuid})
+ MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
+ RETURN e.uuid AS uuid
+ UNION
+ MATCH (community:Community {uuid: $community_uuid})
+ MATCH (node:Community {uuid: $entity_uuid})
+ MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+ SET
+ e.group_id = $group_id,
+ e.created_at = $created_at
+ RETURN e.uuid AS uuid
+ """
case _: # Neo4j
return """
MATCH (community:Community {uuid: $community_uuid})
diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py
index 16f4031e..9627e566 100644
--- a/graphiti_core/models/nodes/node_db_queries.py
+++ b/graphiti_core/models/nodes/node_db_queries.py
@@ -24,10 +24,24 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Episodic {uuid: $uuid})
- SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Episodic {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.created_at = $created_at,
+ n.source = $source,
+ n.source_description = $source_description,
+ n.content = $content,
+ n.valid_at = $valid_at,
+ n.entity_edges = $entity_edges
+ RETURN n.uuid AS uuid
+ """
case GraphProvider.FALKORDB:
return """
MERGE (n:Episodic {uuid: $uuid})
@@ -51,11 +65,25 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
- SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
- source: episode.source, content: episode.content,
+ SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
+ source: episode.source, content: episode.content,
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Episodic {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.created_at = $created_at,
+ n.source = $source,
+ n.source_description = $source_description,
+ n.content = $content,
+ n.valid_at = $valid_at,
+ n.entity_edges = $entity_edges
+ RETURN n.uuid AS uuid
+ """
case GraphProvider.FALKORDB:
return """
UNWIND $episodes AS episode
@@ -76,14 +104,14 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
EPISODIC_NODE_RETURN = """
- e.content AS content,
- e.created_at AS created_at,
- e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
- e.source_description AS source_description,
+ e.created_at AS created_at,
e.source AS source,
+ e.source_description AS source_description,
+ e.content AS content,
+ e.valid_at AS valid_at,
e.entity_edges AS entity_edges
"""
@@ -109,6 +137,20 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
SET n = $entity_data
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Entity {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.labels = $labels,
+ n.created_at = $created_at,
+ n.name_embedding = $name_embedding,
+ n.summary = $summary,
+ n.attributes = $attributes
+ WITH n
+ RETURN n.uuid AS uuid
+ """
case GraphProvider.NEPTUNE:
label_subquery = ''
for label in labels.split(':'):
@@ -168,6 +210,19 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
"""
)
return queries
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Entity {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.labels = $labels,
+ n.created_at = $created_at,
+ n.name_embedding = $name_embedding,
+ n.summary = $summary,
+ n.attributes = $attributes
+ RETURN n.uuid AS uuid
+ """
case _: # Neo4j
return """
UNWIND $nodes AS node
@@ -179,15 +234,28 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
"""
-ENTITY_NODE_RETURN = """
- n.uuid AS uuid,
- n.name AS name,
- n.group_id AS group_id,
- n.created_at AS created_at,
- n.summary AS summary,
- labels(n) AS labels,
- properties(n) AS attributes
-"""
+def get_entity_node_return_query(provider: GraphProvider) -> str:
+ # `name_embedding` is not returned by default and must be loaded manually using `load_name_embedding()`.
+ if provider == GraphProvider.KUZU:
+ return """
+ n.uuid AS uuid,
+ n.name AS name,
+ n.group_id AS group_id,
+ n.labels AS labels,
+ n.created_at AS created_at,
+ n.summary AS summary,
+ n.attributes AS attributes
+ """
+
+ return """
+ n.uuid AS uuid,
+ n.name AS name,
+ n.group_id AS group_id,
+ n.created_at AS created_at,
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
+ """
def get_community_node_save_query(provider: GraphProvider) -> str:
@@ -201,10 +269,21 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Community {uuid: $uuid})
- SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
+ case GraphProvider.KUZU:
+ return """
+ MERGE (n:Community {uuid: $uuid})
+ SET
+ n.name = $name,
+ n.group_id = $group_id,
+ n.created_at = $created_at,
+ n.name_embedding = $name_embedding,
+ n.summary = $summary
+ RETURN n.uuid AS uuid
+ """
case _: # Neo4j
return """
MERGE (n:Community {uuid: $uuid})
@@ -215,12 +294,12 @@ def get_community_node_save_query(provider: GraphProvider) -> str:
COMMUNITY_NODE_RETURN = """
- n.uuid AS uuid,
- n.name AS name,
- n.name_embedding AS name_embedding,
- n.group_id AS group_id,
- n.summary AS summary,
- n.created_at AS created_at
+ c.uuid AS uuid,
+ c.name AS name,
+ c.group_id AS group_id,
+ c.created_at AS created_at,
+ c.name_embedding AS name_embedding,
+ c.summary AS summary
"""
COMMUNITY_NODE_RETURN_NEPTUNE = """
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index 98bafab6..4080fcc6 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
@@ -32,10 +33,10 @@ from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
COMMUNITY_NODE_RETURN_NEPTUNE,
- ENTITY_NODE_RETURN,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
get_community_node_save_query,
+ get_entity_node_return_query,
get_entity_node_save_query,
get_episode_node_save_query,
)
@@ -95,12 +96,37 @@ class Node(BaseModel, ABC):
case GraphProvider.NEO4J:
await driver.execute_query(
"""
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
- DETACH DELETE n
- """,
+ MATCH (n:Entity|Episodic|Community {uuid: $uuid})
+ DETACH DELETE n
+ """,
uuid=self.uuid,
)
- case _: # FalkorDB and Neptune
+ case GraphProvider.KUZU:
+ for label in ['Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label} {{uuid: $uuid}})
+ DETACH DELETE n
+ """,
+ uuid=self.uuid,
+ )
+                # Entity edges are actually nodes in Kuzu, so a simple `DETACH DELETE` will not work.
+                # Explicitly delete the "edge" nodes first, then the entity node.
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
+ DETACH DELETE e
+ """,
+ uuid=self.uuid,
+ )
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {uuid: $uuid})
+ DETACH DELETE n
+ """,
+ uuid=self.uuid,
+ )
+ case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
@@ -136,8 +162,32 @@ class Node(BaseModel, ABC):
group_id=group_id,
batch_size=batch_size,
)
-
- case _: # FalkorDB and Neptune
+ case GraphProvider.KUZU:
+ for label in ['Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label} {{group_id: $group_id}})
+ DETACH DELETE n
+ """,
+ group_id=group_id,
+ )
+                # Entity edges are actually nodes in Kuzu, so a simple `DETACH DELETE` will not work.
+                # Explicitly delete the "edge" nodes first, then the entity nodes.
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
+ DETACH DELETE e
+ """,
+ group_id=group_id,
+ )
+ await driver.execute_query(
+ """
+ MATCH (n:Entity {group_id: $group_id})
+ DETACH DELETE n
+ """,
+ group_id=group_id,
+ )
+ case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
@@ -149,30 +199,59 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
- if driver.provider == GraphProvider.FALKORDB:
- for label in ['Entity', 'Episodic', 'Community']:
- await driver.execute_query(
- f"""
- MATCH (n:{label})
- WHERE n.uuid IN $uuids
- DETACH DELETE n
- """,
- uuids=uuids,
- )
- else:
- async with driver.session() as session:
- await session.run(
- """
- MATCH (n:Entity|Episodic|Community)
- WHERE n.uuid IN $uuids
- CALL {
- WITH n
+ match driver.provider:
+ case GraphProvider.FALKORDB:
+ for label in ['Entity', 'Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label})
+ WHERE n.uuid IN $uuids
DETACH DELETE n
- } IN TRANSACTIONS OF $batch_size ROWS
+ """,
+ uuids=uuids,
+ )
+ case GraphProvider.KUZU:
+ for label in ['Episodic', 'Community']:
+ await driver.execute_query(
+ f"""
+ MATCH (n:{label})
+ WHERE n.uuid IN $uuids
+ DETACH DELETE n
+ """,
+ uuids=uuids,
+ )
+                # Entity edges are actually nodes in Kuzu, so a simple `DETACH DELETE` will not work.
+                # Explicitly delete the "edge" nodes first, then the entity nodes.
+ await driver.execute_query(
+ """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
+ WHERE n.uuid IN $uuids
+ DETACH DELETE e
""",
uuids=uuids,
- batch_size=batch_size,
)
+ await driver.execute_query(
+ """
+ MATCH (n:Entity)
+ WHERE n.uuid IN $uuids
+ DETACH DELETE n
+ """,
+ uuids=uuids,
+ )
+ case _: # Neo4J, Neptune
+ async with driver.session() as session:
+ await session.run(
+ """
+ MATCH (n:Entity|Episodic|Community)
+ WHERE n.uuid IN $uuids
+ CALL {
+ WITH n
+ DETACH DELETE n
+ } IN TRANSACTIONS OF $batch_size ROWS
+ """,
+ uuids=uuids,
+ batch_size=batch_size,
+ )
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -376,17 +455,25 @@ class EntityNode(Node):
'summary': self.summary,
'created_at': self.created_at,
}
- entity_data.update(self.attributes or {})
- if driver.provider == GraphProvider.NEPTUNE:
- driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
+ if driver.provider == GraphProvider.KUZU:
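+            # Kuzu: attributes are serialized to a JSON string and labels are stored
+            # in a STRING[] column (see the Entity table schema in kuzu_driver.py).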
+ entity_data['attributes'] = json.dumps(self.attributes)
+ entity_data['labels'] = list(set(self.labels + ['Entity']))
+ result = await driver.execute_query(
+ get_entity_node_save_query(driver.provider, labels=''),
+ **entity_data,
+ )
+ else:
+ entity_data.update(self.attributes or {})
+ labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
- labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
+ if driver.provider == GraphProvider.NEPTUNE:
+ driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
- result = await driver.execute_query(
- get_entity_node_save_query(driver.provider, labels),
- entity_data=entity_data,
- )
+ result = await driver.execute_query(
+ get_entity_node_save_query(driver.provider, labels),
+ entity_data=entity_data,
+ )
logger.debug(f'Saved Node to Graph: {self.uuid}')
@@ -399,12 +486,12 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid})
RETURN
"""
- + ENTITY_NODE_RETURN,
+ + get_entity_node_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
@@ -419,12 +506,12 @@ class EntityNode(Node):
WHERE n.uuid IN $uuids
RETURN
"""
- + ENTITY_NODE_RETURN,
+ + get_entity_node_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -456,7 +543,7 @@ class EntityNode(Node):
+ """
RETURN
"""
- + ENTITY_NODE_RETURN
+ + get_entity_node_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY n.uuid DESC
@@ -468,7 +555,7 @@ class EntityNode(Node):
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -533,7 +620,7 @@ class CommunityNode(Node):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community {uuid: $uuid})
+ MATCH (c:Community {uuid: $uuid})
RETURN
"""
+ (
@@ -556,8 +643,8 @@ class CommunityNode(Node):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community)
- WHERE n.uuid IN $uuids
+ MATCH (c:Community)
+ WHERE c.uuid IN $uuids
RETURN
"""
+ (
@@ -581,13 +668,13 @@ class CommunityNode(Node):
limit: int | None = None,
uuid_cursor: str | None = None,
):
- cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
+ cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community)
- WHERE n.group_id IN $group_ids
+ MATCH (c:Community)
+ WHERE c.group_id IN $group_ids
"""
+ cursor_query
+ """
@@ -599,7 +686,7 @@ class CommunityNode(Node):
else COMMUNITY_NODE_RETURN
)
+ """
- ORDER BY n.uuid DESC
+ ORDER BY c.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
@@ -636,7 +723,19 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
)
-def get_entity_node_from_record(record: Any) -> EntityNode:
+def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
+ if provider == GraphProvider.KUZU:
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
+ else:
+ attributes = record['attributes']
+ attributes.pop('uuid', None)
+ attributes.pop('name', None)
+ attributes.pop('group_id', None)
+ attributes.pop('name_embedding', None)
+ attributes.pop('summary', None)
+ attributes.pop('created_at', None)
+ attributes.pop('labels', None)
+
entity_node = EntityNode(
uuid=record['uuid'],
name=record['name'],
@@ -645,16 +744,9 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
labels=record['labels'],
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
- attributes=record['attributes'],
+ attributes=attributes,
)
- entity_node.attributes.pop('uuid', None)
- entity_node.attributes.pop('name', None)
- entity_node.attributes.pop('group_id', None)
- entity_node.attributes.pop('name_embedding', None)
- entity_node.attributes.pop('summary', None)
- entity_node.attributes.pop('created_at', None)
-
return entity_node
diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py
index 2213688b..93cab5ba 100644
--- a/graphiti_core/search/search_filters.py
+++ b/graphiti_core/search/search_filters.py
@@ -20,6 +20,8 @@ from typing import Any
from pydantic import BaseModel, Field
+from graphiti_core.driver.driver import GraphProvider
+
class ComparisonOperator(Enum):
equals = '='
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):
def node_search_filter_query_constructor(
filters: SearchFilters,
-) -> tuple[str, dict[str, Any]]:
- filter_query: str = ''
+ provider: GraphProvider,
+) -> tuple[list[str], dict[str, Any]]:
+ filter_queries: list[str] = []
filter_params: dict[str, Any] = {}
if filters.node_labels is not None:
- node_labels = '|'.join(filters.node_labels)
- node_label_filter = ' AND n:' + node_labels
- filter_query += node_label_filter
+ if provider == GraphProvider.KUZU:
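+            # Kuzu stores entity labels in a STRING[] property rather than as node
+            # labels, so filter with list_has_all() instead of a label match.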
+ node_label_filter = 'list_has_all(n.labels, $labels)'
+ filter_params['labels'] = filters.node_labels
+ else:
+ node_labels = '|'.join(filters.node_labels)
+ node_label_filter = 'n:' + node_labels
+ filter_queries.append(node_label_filter)
- return filter_query, filter_params
+ return filter_queries, filter_params
def date_filter_query_constructor(
@@ -81,23 +88,29 @@ def date_filter_query_constructor(
def edge_search_filter_query_constructor(
filters: SearchFilters,
-) -> tuple[str, dict[str, Any]]:
- filter_query: str = ''
+ provider: GraphProvider,
+) -> tuple[list[str], dict[str, Any]]:
+ filter_queries: list[str] = []
filter_params: dict[str, Any] = {}
if filters.edge_types is not None:
edge_types = filters.edge_types
- edge_types_filter = '\nAND e.name in $edge_types'
- filter_query += edge_types_filter
+ filter_queries.append('e.name in $edge_types')
filter_params['edge_types'] = edge_types
if filters.node_labels is not None:
- node_labels = '|'.join(filters.node_labels)
- node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
- filter_query += node_label_filter
+ if provider == GraphProvider.KUZU:
+ node_label_filter = (
+ 'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
+ )
+ filter_params['labels'] = filters.node_labels
+ else:
+ node_labels = '|'.join(filters.node_labels)
+ node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
+ filter_queries.append(node_label_filter)
if filters.valid_at is not None:
- valid_at_filter = '\nAND ('
+ valid_at_filter = '('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
else:
valid_at_filter += ' OR '
- filter_query += valid_at_filter
+ filter_queries.append(valid_at_filter)
if filters.invalid_at is not None:
- invalid_at_filter = ' AND ('
+ invalid_at_filter = '('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
else:
invalid_at_filter += ' OR '
- filter_query += invalid_at_filter
+ filter_queries.append(invalid_at_filter)
if filters.created_at is not None:
- created_at_filter = ' AND ('
+ created_at_filter = '('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
else:
created_at_filter += ' OR '
- filter_query += created_at_filter
+ filter_queries.append(created_at_filter)
if filters.expired_at is not None:
- expired_at_filter = ' AND ('
+ expired_at_filter = '('
for i, or_list in enumerate(filters.expired_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
else:
expired_at_filter += ' OR '
- filter_query += expired_at_filter
+ filter_queries.append(expired_at_filter)
- return filter_query, filter_params
+ return filter_queries, filter_params
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index a24d36b7..6c61ab24 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -37,10 +37,13 @@ from graphiti_core.helpers import (
normalize_l2,
semaphore_gather,
)
-from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
-from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
+from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
+from graphiti_core.models.nodes.node_db_queries import (
+ COMMUNITY_NODE_RETURN,
+ EPISODIC_NODE_RETURN,
+ get_entity_node_return_query,
+)
from graphiti_core.nodes import (
- ENTITY_NODE_RETURN,
CommunityNode,
EntityNode,
EpisodicNode,
@@ -78,9 +81,16 @@ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> f
return dot_product / (norm_vector1 * norm_vector2)
-def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
+def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
+ if driver.provider == GraphProvider.KUZU:
+        # Kuzu's FTS only accepts plain query strings, so pass the text through
+        # as-is and skip overly long queries.
+ if len(query.split(' ')) > MAX_QUERY_LENGTH:
+ return ''
+ return query
group_ids_filter_list = (
- [fulltext_syntax + f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
+ [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
+ if group_ids is not None
+ else []
)
group_ids_filter = ''
for f in group_ids_filter_list:
@@ -124,12 +134,12 @@ async def get_mentioned_nodes(
WHERE episode.uuid IN $uuids
RETURN DISTINCT
"""
- + ENTITY_NODE_RETURN,
+ + get_entity_node_return_query(driver.provider),
uuids=episode_uuids,
routing_='r',
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -141,7 +151,7 @@ async def get_communities_by_nodes(
records, _, _ = await driver.execute_query(
"""
- MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
WHERE m.uuid IN $uuids
RETURN DISTINCT
"""
@@ -163,11 +173,32 @@ async def edge_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
+ fuzzy_query = fulltext_query(query, group_ids, driver)
+
if fuzzy_query == '':
return []
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
+ match_query = """
+ YIELD relationship AS rel, score
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ YIELD node, score
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
+ """
+
+ filter_queries, filter_params = edge_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
+
+ if group_ids is not None:
+ filter_queries.append('e.group_id IN $group_ids')
+ filter_params['group_ids'] = group_ids
+
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
@@ -187,6 +218,7 @@ async def edge_fulltext_search(
"""
+ filter_query
+ """
+ AND id(e)=id
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
RETURN
e.uuid AS uuid,
@@ -208,7 +240,6 @@ async def edge_fulltext_search(
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
- group_ids=group_ids,
ids=input_ids,
limit=limit,
routing_='r',
@@ -218,17 +249,14 @@ async def edge_fulltext_search(
return []
else:
query = (
- get_relationships_query('edge_name_and_fact', provider=driver.provider)
- + """
- YIELD relationship AS rel, score
- MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
- WHERE e.group_id IN $group_ids """
+ get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
+ + match_query
+ filter_query
+ """
WITH e, score, n, m
RETURN
"""
- + ENTITY_EDGE_RETURN
+ + get_entity_edge_return_query(driver.provider)
+ """
ORDER BY score DESC
LIMIT $limit
@@ -238,13 +266,12 @@ async def edge_fulltext_search(
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
- group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -259,32 +286,44 @@ async def edge_similarity_search(
limit: int = RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
- # vector similarity search over embedded facts
- query_params: dict[str, Any] = {}
+ match_query = """
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
+ """
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
- query_params.update(filter_params)
+ filter_queries, filter_params = edge_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
- group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
if group_ids is not None:
- group_filter_query += '\nAND e.group_id IN $group_ids'
- query_params['group_ids'] = group_ids
+ filter_queries.append('e.group_id IN $group_ids')
+ filter_params['group_ids'] = group_ids
if source_node_uuid is not None:
- query_params['source_uuid'] = source_node_uuid
- group_filter_query += '\nAND (n.uuid = $source_uuid)'
+ filter_params['source_uuid'] = source_node_uuid
+ filter_queries.append('n.uuid = $source_uuid')
if target_node_uuid is not None:
- query_params['target_uuid'] = target_node_uuid
- group_filter_query += '\nAND (m.uuid = $target_uuid)'
+ filter_params['target_uuid'] = target_node_uuid
+ filter_queries.append('m.uuid = $target_uuid')
+
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
+
+ search_vector_var = '$search_vector'
+ if driver.provider == GraphProvider.KUZU:
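+ # Kuzu's vector functions expect a fixed-size FLOAT[] array, so cast the raw
+ # parameter to the query vector's dimensionality.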
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
if driver.provider == GraphProvider.NEPTUNE:
query = (
RUNTIME_QUERY
+ """
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
- + group_filter_query
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@@ -296,7 +335,7 @@ async def edge_similarity_search(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
if len(resp) > 0:
@@ -338,26 +377,23 @@ async def edge_similarity_search(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
else:
return []
else:
query = (
RUNTIME_QUERY
- + """
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
- """
- + group_filter_query
+ + match_query
+ filter_query
+ """
WITH DISTINCT e, n, m, """
- + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
+ + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
"""
- + ENTITY_EDGE_RETURN
+ + get_entity_edge_return_query(driver.provider)
+ """
ORDER BY score DESC
LIMIT $limit
@@ -370,10 +406,10 @@ async def edge_similarity_search(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -387,70 +423,116 @@ async def edge_bfs_search(
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
- if bfs_origin_node_uuids is None:
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
return []
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
-
- if driver.provider == GraphProvider.NEPTUNE:
- query = (
- f"""
- UNWIND $bfs_origin_node_uuids AS origin_uuid
- MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
- WHERE origin:Entity OR origin:Episodic
- UNWIND relationships(path) AS rel
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
- WHERE e.uuid = rel.uuid
- """
- + filter_query
- + """
- RETURN DISTINCT
- e.uuid AS uuid,
- e.group_id AS group_id,
- startNode(e).uuid AS source_node_uuid,
- endNode(e).uuid AS target_node_uuid,
- e.created_at AS created_at,
- e.name AS name,
- e.fact AS fact,
- split(e.episodes, ',') AS episodes,
- e.expired_at AS expired_at,
- e.valid_at AS valid_at,
- e.invalid_at AS invalid_at,
- properties(e) AS attributes
- LIMIT $limit
- """
- )
- else:
- query = (
- f"""
- UNWIND $bfs_origin_node_uuids AS origin_uuid
- MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
- UNWIND relationships(path) AS rel
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
- WHERE e.uuid = rel.uuid
- AND e.group_id IN $group_ids
- """
- + filter_query
- + """
- RETURN DISTINCT
- """
- + ENTITY_EDGE_RETURN
- + """
- LIMIT $limit
- """
- )
-
- records, _, _ = await driver.execute_query(
- query,
- bfs_origin_node_uuids=bfs_origin_node_uuids,
- depth=bfs_max_depth,
- group_ids=group_ids,
- limit=limit,
- routing_='r',
- **filter_params,
+ filter_queries, filter_params = edge_search_filter_query_constructor(
+ search_filter, driver.provider
)
- edges = [get_entity_edge_from_record(record) for record in records]
+ if group_ids is not None:
+ filter_queries.append('e.group_id IN $group_ids')
+ filter_params['group_ids'] = group_ids
+
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
+
+ if driver.provider == GraphProvider.KUZU:
+ # Kuzu materializes each entity edge as an intermediate RelatesToNode_ connected
+ # by two RELATES_TO hops, so paths are matched separately and hop counts doubled
+ # to get the correct BFS depth.
+ depth = bfs_max_depth * 2 - 1
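+ # Paths terminate on the intermediate RelatesToNode_, so N logical hops cost
+ # 2 * N - 1 RELATES_TO hops.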
+ match_queries = [
+ f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
+ UNWIND nodes(path) AS relNode
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
+ """,
+ ]
+ if bfs_max_depth > 1:
+ depth = (bfs_max_depth - 1) * 2 - 1
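+ # The initial MENTIONS hop from the episode consumes one logical hop.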
+ match_queries.append(f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
+ UNWIND nodes(path) AS relNode
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
+ """)
+
+ records = []
+ for match_query in match_queries:
+ sub_records, _, _ = await driver.execute_query(
+ match_query
+ + filter_query
+ + """
+ RETURN DISTINCT
+ """
+ + get_entity_edge_return_query(driver.provider)
+ + """
+ LIMIT $limit
+ """,
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
+ limit=limit,
+ routing_='r',
+ **filter_params,
+ )
+ records.extend(sub_records)
+ else:
+ if driver.provider == GraphProvider.NEPTUNE:
+ query = (
+ f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
+ WHERE origin:Entity OR origin:Episodic
+ UNWIND relationships(path) AS rel
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
+ """
+ + filter_query
+ + """
+ RETURN DISTINCT
+ e.uuid AS uuid,
+ e.group_id AS group_id,
+ startNode(e).uuid AS source_node_uuid,
+ endNode(e).uuid AS target_node_uuid,
+ e.created_at AS created_at,
+ e.name AS name,
+ e.fact AS fact,
+ split(e.episodes, ',') AS episodes,
+ e.expired_at AS expired_at,
+ e.valid_at AS valid_at,
+ e.invalid_at AS invalid_at,
+ properties(e) AS attributes
+ LIMIT $limit
+ """
+ )
+ else:
+ query = (
+ f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
+ UNWIND relationships(path) AS rel
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
+ """
+ + filter_query
+ + """
+ RETURN DISTINCT
+ """
+ + get_entity_edge_return_query(driver.provider)
+ + """
+ LIMIT $limit
+ """
+ )
+
+ records, _, _ = await driver.execute_query(
+ query,
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
+ depth=bfs_max_depth,
+ limit=limit,
+ routing_='r',
+ **filter_params,
+ )
+
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@@ -463,10 +545,25 @@ async def node_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# BM25 search to get top nodes
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
+ fuzzy_query = fulltext_query(query, group_ids, driver)
if fuzzy_query == '':
return []
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+
+ filter_queries, filter_params = node_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
+
+ if group_ids is not None:
+ filter_queries.append('n.group_id IN $group_ids')
+ filter_params['group_ids'] = group_ids
+
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
+
+ yield_query = 'YIELD node AS n, score'
+ if driver.provider == GraphProvider.KUZU:
+ yield_query = 'WITH node AS n, score'
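+ # Kuzu's FTS call returns plain columns instead of a YIELD clause, so alias
+ # them with WITH.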
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
@@ -483,8 +580,8 @@ async def node_fulltext_search(
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
- """
- + ENTITY_NODE_RETURN
+ """
+ + get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
LIMIT $limit
@@ -494,7 +591,6 @@ async def node_fulltext_search(
query,
ids=input_ids,
query=fuzzy_query,
- group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
@@ -509,31 +605,29 @@ async def node_fulltext_search(
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
- get_nodes_query(driver.provider, index_name, '$query')
- + """
- YIELD node AS n, score
- WHERE n:Entity AND n.group_id IN $group_ids
- """
+ get_nodes_query(
+ index_name, '$query', limit=limit, provider=driver.provider
+ )
+ + yield_query
+ filter_query
+ """
- WITH n, score
- ORDER BY score DESC
- LIMIT $limit
- RETURN
- """
- + ENTITY_NODE_RETURN
+ WITH n, score
+ ORDER BY score DESC
+ LIMIT $limit
+ RETURN
+ """
+ + get_entity_node_return_query(driver.provider)
)
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
- group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -546,16 +640,21 @@ async def node_similarity_search(
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
- # vector similarity search over entity names
- query_params: dict[str, Any] = {}
+ filter_queries, filter_params = node_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
- group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
if group_ids is not None:
- group_filter_query += ' AND n.group_id IN $group_ids'
- query_params['group_ids'] = group_ids
+ filter_queries.append('n.group_id IN $group_ids')
+ filter_params['group_ids'] = group_ids
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
- query_params.update(filter_params)
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
+
+ search_vector_var = '$search_vector'
+ if driver.provider == GraphProvider.KUZU:
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
if driver.provider == GraphProvider.NEPTUNE:
query = (
@@ -563,7 +662,6 @@ async def node_similarity_search(
+ """
MATCH (n:Entity)
"""
- + group_filter_query
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -571,9 +669,8 @@ async def node_similarity_search(
)
resp, header, _ = await driver.execute_query(
query,
- params=query_params,
+ params=filter_params,
search_vector=search_vector,
- group_ids=group_ids,
limit=limit,
min_score=min_score,
routing_='r',
@@ -598,7 +695,7 @@ async def node_similarity_search(
WHERE id(n)=i.id
RETURN
"""
- + ENTITY_NODE_RETURN
+ + get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
LIMIT $limit
@@ -611,7 +708,7 @@ async def node_similarity_search(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
else:
return []
@@ -623,13 +720,12 @@ async def node_similarity_search(
f"""
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
"""
- + group_filter_query
+ filter_query
+ """
AND score > $min_score
RETURN
"""
- + ENTITY_NODE_RETURN
+ + get_entity_node_return_query(driver.provider)
+ """
ORDER BY score DESC
LIMIT $limit
@@ -642,7 +738,7 @@ async def node_similarity_search(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
else:
@@ -651,16 +747,15 @@ async def node_similarity_search(
+ """
MATCH (n:Entity)
"""
- + group_filter_query
+ filter_query
+ """
WITH n, """
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+ + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
"""
- + ENTITY_NODE_RETURN
+ + get_entity_node_return_query(driver.provider)
+ """
ORDER BY score DESC
LIMIT $limit
@@ -673,10 +768,10 @@ async def node_similarity_search(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
- nodes = [get_entity_node_from_record(record) for record in records]
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -689,56 +784,82 @@ async def node_bfs_search(
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
- # vector similarity search over entity names
- if bfs_origin_node_uuids is None:
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
return []
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+ filter_queries, filter_params = node_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
+
+ if group_ids is not None:
+ filter_queries.append('n.group_id IN $group_ids')
+ filter_queries.append('origin.group_id IN $group_ids')
+ filter_params['group_ids'] = group_ids
+
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
+
+ match_queries = [
+ f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
+ WHERE n.group_id = origin.group_id
+ """
+ ]
if driver.provider == GraphProvider.NEPTUNE:
- query = (
+ match_queries = [
f"""
- UNWIND $bfs_origin_node_uuids AS origin_uuid
- MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
- WHERE origin:Entity OR origin.Episode
- AND n.group_id = origin.group_id
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
+ WHERE origin:Entity OR origin:Episodic
+ AND n.group_id = origin.group_id
"""
- + filter_query
- + """
- RETURN
- """
- + ENTITY_NODE_RETURN
- + """
- LIMIT $limit
- """
- )
- else:
- query = (
- f"""
- UNWIND $bfs_origin_node_uuids AS origin_uuid
- MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
- WHERE n.group_id = origin.group_id
- AND origin.group_id IN $group_ids
- """
- + filter_query
- + """
- RETURN
- """
- + ENTITY_NODE_RETURN
- + """
- LIMIT $limit
- """
- )
+ ]
- records, _, _ = await driver.execute_query(
- query,
- bfs_origin_node_uuids=bfs_origin_node_uuids,
- group_ids=group_ids,
- limit=limit,
- routing_='r',
- **filter_params,
- )
- nodes = [get_entity_node_from_record(record) for record in records]
+ if driver.provider == GraphProvider.KUZU:
+ depth = bfs_max_depth * 2
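+ # Entity-to-entity paths have even length on Kuzu (two RELATES_TO hops per
+ # logical edge), so the range runs from 2 up to twice the requested depth.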
+ match_queries = [
+ """
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
+ WHERE n.group_id = origin.group_id
+ """,
+ f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
+ WHERE n.group_id = origin.group_id
+ """,
+ ]
+ if bfs_max_depth > 1:
+ depth = (bfs_max_depth - 1) * 2
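+ # The leading MENTIONS hop consumes one level of the requested depth.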
+ match_queries.append(f"""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
+ WHERE n.group_id = origin.group_id
+ """)
+
+ records = []
+ for match_query in match_queries:
+ sub_records, _, _ = await driver.execute_query(
+ match_query
+ + filter_query
+ + """
+ RETURN
+ """
+ + get_entity_node_return_query(driver.provider)
+ + """
+ LIMIT $limit
+ """,
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
+ limit=limit,
+ routing_='r',
+ **filter_params,
+ )
+ records.extend(sub_records)
+
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@@ -751,10 +872,16 @@ async def episode_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
# BM25 search to get top episodes
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
+ fuzzy_query = fulltext_query(query, group_ids, driver)
if fuzzy_query == '':
return []
+ filter_params: dict[str, Any] = {}
+ group_filter_query: LiteralString = ''
+ if group_ids is not None:
+ group_filter_query += '\nAND e.group_id IN $group_ids'
+ filter_params['group_ids'] = group_ids
+
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
@@ -768,7 +895,7 @@ async def episode_fulltext_search(
UNWIND $ids as i
MATCH (e:Episodic)
WHERE e.uuid=i.id
- RETURN
+ RETURN
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
@@ -785,9 +912,9 @@ async def episode_fulltext_search(
query,
ids=input_ids,
query=fuzzy_query,
- group_ids=group_ids,
limit=limit,
routing_='r',
+ **filter_params,
)
else:
return []
@@ -799,12 +926,14 @@ async def episode_fulltext_search(
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
- get_nodes_query(driver.provider, index_name, '$query')
+ get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
+ """
YIELD node AS episode, score
MATCH (e:Episodic)
WHERE e.uuid = episode.uuid
- AND e.group_id IN $group_ids
+ """
+ + group_filter_query
+ + """
RETURN
"""
+ EPISODIC_NODE_RETURN
@@ -815,12 +944,9 @@ async def episode_fulltext_search(
)
records, _, _ = await driver.execute_query(
- query,
- query=fuzzy_query,
- group_ids=group_ids,
- limit=limit,
- routing_='r',
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
)
+
episodes = [get_episodic_node_from_record(record) for record in records]
return episodes
@@ -833,10 +959,20 @@ async def community_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
# BM25 search to get top communities
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
+ fuzzy_query = fulltext_query(query, group_ids, driver)
if fuzzy_query == '':
return []
+ filter_params: dict[str, Any] = {}
+ group_filter_query: LiteralString = ''
+ if group_ids is not None:
+ group_filter_query = 'WHERE c.group_id IN $group_ids'
+ filter_params['group_ids'] = group_ids
+
+ yield_query = 'YIELD node AS c, score'
+ if driver.provider == GraphProvider.KUZU:
+ yield_query = 'WITH node AS c, score'
+
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
@@ -852,9 +988,9 @@ async def community_fulltext_search(
WHERE comm.uuid=i.id
RETURN
comm.uuid AS uuid,
- comm.group_id AS group_id,
- comm.name AS name,
- comm.created_at AS created_at,
+ comm.group_id AS group_id,
+ comm.name AS name,
+ comm.created_at AS created_at,
comm.summary AS summary,
[x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
ORDER BY i.score DESC
@@ -864,18 +1000,21 @@ async def community_fulltext_search(
query,
ids=input_ids,
query=fuzzy_query,
- group_ids=group_ids,
limit=limit,
routing_='r',
+ **filter_params,
)
else:
return []
else:
query = (
- get_nodes_query(driver.provider, 'community_name', '$query')
+ get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
+ + yield_query
+ + """
+ WITH c, score
+ """
+ + group_filter_query
+ """
- YIELD node AS n, score
- WHERE n.group_id IN $group_ids
RETURN
"""
+ COMMUNITY_NODE_RETURN
@@ -886,12 +1025,9 @@ async def community_fulltext_search(
)
records, _, _ = await driver.execute_query(
- query,
- query=fuzzy_query,
- group_ids=group_ids,
- limit=limit,
- routing_='r',
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
)
+
communities = [get_community_node_from_record(record) for record in records]
return communities
@@ -909,7 +1045,7 @@ async def community_similarity_search(
group_filter_query: LiteralString = ''
if group_ids is not None:
- group_filter_query += 'WHERE n.group_id IN $group_ids'
+ group_filter_query += ' WHERE c.group_id IN $group_ids'
query_params['group_ids'] = group_ids
if driver.provider == GraphProvider.NEPTUNE:
@@ -951,8 +1087,8 @@ async def community_similarity_search(
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
- comm.name AS name,
- comm.created_at AS created_at,
+ comm.name AS name,
+ comm.created_at AS created_at,
comm.summary AS summary,
comm.name_embedding AS name_embedding
ORDER BY i.score DESC
@@ -970,16 +1106,20 @@ async def community_similarity_search(
else:
return []
else:
+ search_vector_var = '$search_vector'
+ if driver.provider == GraphProvider.KUZU:
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
+
query = (
RUNTIME_QUERY
+ """
- MATCH (n:Community)
+ MATCH (c:Community)
"""
+ group_filter_query
+ """
- WITH n,
+ WITH c,
"""
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+ + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
@@ -999,6 +1139,7 @@ async def community_similarity_search(
routing_='r',
**query_params,
)
+
communities = [get_community_node_from_record(record) for record in records]
return communities
@@ -1089,67 +1230,129 @@ async def get_relevant_nodes(
return []
group_id = nodes[0].group_id
-
- # vector similarity search over entity names
- query_params: dict[str, Any] = {}
-
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
- query_params.update(filter_params)
-
- query = (
- RUNTIME_QUERY
- + """
- UNWIND $nodes AS node
- MATCH (n:Entity {group_id: $group_id})
- """
- + filter_query
- + """
- WITH node, n, """
- + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
- + """ AS score
- WHERE score > $min_score
- WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
- """
- + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
- + """
- YIELD node AS m
- WHERE m.group_id = $group_id
- WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
-
- WITH node,
- top_vector_nodes,
- [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
-
- WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
-
- UNWIND combined_nodes AS combined_node
- WITH node, collect(DISTINCT combined_node) AS deduped_nodes
-
- RETURN
- node.uuid AS search_node_uuid,
- [x IN deduped_nodes | {
- uuid: x.uuid,
- name: x.name,
- name_embedding: x.name_embedding,
- group_id: x.group_id,
- created_at: x.created_at,
- summary: x.summary,
- labels: labels(x),
- attributes: properties(x)
- }] AS matches
- """
- )
-
query_nodes = [
{
'uuid': node.uuid,
'name': node.name,
'name_embedding': node.name_embedding,
- 'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
+ 'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
}
for node in nodes
]
+ filter_queries, filter_params = node_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
+
+ filter_query = ''
+ if filter_queries:
+ filter_query = 'WHERE ' + (' AND '.join(filter_queries))
+
+ if driver.provider == GraphProvider.KUZU:
+ embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
+ if embedding_size == 0:
+ return []
+
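+ # The vector CAST below needs a static array length, taken from the first
+ # query node's embedding.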
+ # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query`
+ # as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as
+ # the graph driver.
+ query = (
+ RUNTIME_QUERY
+ + """
+ UNWIND $nodes AS node
+ MATCH (n:Entity {group_id: $group_id})
+ """
+ + filter_query
+ + """
+ WITH node, n, """
+ + get_vector_cosine_func_query(
+ 'n.name_embedding',
+ f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
+ driver.provider,
+ )
+ + """ AS score
+ WHERE score > $min_score
+ WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
+ """
+ + get_nodes_query(
+ 'node_name_and_summary',
+ 'node.fulltext_query',
+ limit=limit,
+ provider=driver.provider,
+ )
+ + """
+ WITH node AS m
+ WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
+ WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
+
+ WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
+
+ UNWIND combined_nodes AS x
+ WITH node, collect(DISTINCT {
+ uuid: x.uuid,
+ name: x.name,
+ name_embedding: x.name_embedding,
+ group_id: x.group_id,
+ created_at: x.created_at,
+ summary: x.summary,
+ labels: x.labels,
+ attributes: x.attributes
+ }) AS matches
+
+ RETURN
+ node.uuid AS search_node_uuid, matches
+ """
+ )
+ else:
+ query = (
+ RUNTIME_QUERY
+ + """
+ UNWIND $nodes AS node
+ MATCH (n:Entity {group_id: $group_id})
+ """
+ + filter_query
+ + """
+ WITH node, n, """
+ + get_vector_cosine_func_query(
+ 'n.name_embedding', 'node.name_embedding', driver.provider
+ )
+ + """ AS score
+ WHERE score > $min_score
+ WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
+ """
+ + get_nodes_query(
+ 'node_name_and_summary',
+ 'node.fulltext_query',
+ limit=limit,
+ provider=driver.provider,
+ )
+ + """
+ YIELD node AS m
+ WHERE m.group_id = $group_id
+ WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
+
+ WITH node,
+ top_vector_nodes,
+ [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
+
+ WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
+
+ UNWIND combined_nodes AS combined_node
+ WITH node, collect(DISTINCT combined_node) AS deduped_nodes
+
+ RETURN
+ node.uuid AS search_node_uuid,
+ [x IN deduped_nodes | {
+ uuid: x.uuid,
+ name: x.name,
+ name_embedding: x.name_embedding,
+ group_id: x.group_id,
+ created_at: x.created_at,
+ summary: x.summary,
+ labels: labels(x),
+ attributes: properties(x)
+ }] AS matches
+ """
+ )
+
results, _, _ = await driver.execute_query(
query,
nodes=query_nodes,
@@ -1157,12 +1360,12 @@ async def get_relevant_nodes(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
relevant_nodes_dict: dict[str, list[EntityNode]] = {
result['search_node_uuid']: [
- get_entity_node_from_record(record) for record in result['matches']
+ get_entity_node_from_record(record, driver.provider) for record in result['matches']
]
for result in results
}
@@ -1182,10 +1385,13 @@ async def get_relevant_edges(
if len(edges) == 0:
return []
- query_params: dict[str, Any] = {}
+ filter_queries, filter_params = edge_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
- query_params.update(filter_params)
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
if driver.provider == GraphProvider.NEPTUNE:
query = (
@@ -1197,7 +1403,7 @@ async def get_relevant_edges(
+ filter_query
+ """
WITH e, edge
- RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
edge.fact_embedding as target_embedding
"""
)
@@ -1207,7 +1413,7 @@ async def get_relevant_edges(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
# Calculate Cosine similarity then return the edge ids
@@ -1220,7 +1426,7 @@ async def get_relevant_edges(
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
# Match the edge ids and return the values
- query = """
+ query = """
UNWIND $ids AS edge
MATCH ()-[e]->()
WHERE id(e) = edge.id
@@ -1246,49 +1452,95 @@ async def get_relevant_edges(
results, _, _ = await driver.execute_query(
query,
- params=query_params,
ids=input_ids,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
else:
- query = (
- RUNTIME_QUERY
- + """
- UNWIND $edges AS edge
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
- """
- + filter_query
- + """
- WITH e, edge, """
- + get_vector_cosine_func_query(
- 'e.fact_embedding', 'edge.fact_embedding', driver.provider
+ if driver.provider == GraphProvider.KUZU:
+ embedding_size = (
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
+ )
+ if embedding_size == 0:
+ return []
+
+ query = (
+ RUNTIME_QUERY
+ + """
+ UNWIND $edges AS edge
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
+ """
+ + filter_query
+ + """
+ WITH e, edge, n, m, """
+ + get_vector_cosine_func_query(
+ 'e.fact_embedding',
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
+ driver.provider,
+ )
+ + """ AS score
+ WHERE score > $min_score
+ WITH e, edge, n, m, score
+ ORDER BY score DESC
+ LIMIT $limit
+ RETURN
+ edge.uuid AS search_edge_uuid,
+ collect({
+ uuid: e.uuid,
+ source_node_uuid: n.uuid,
+ target_node_uuid: m.uuid,
+ created_at: e.created_at,
+ name: e.name,
+ group_id: e.group_id,
+ fact: e.fact,
+ fact_embedding: e.fact_embedding,
+ episodes: e.episodes,
+ expired_at: e.expired_at,
+ valid_at: e.valid_at,
+ invalid_at: e.invalid_at,
+ attributes: e.attributes
+ }) AS matches
+ """
+ )
+ else:
+ query = (
+ RUNTIME_QUERY
+ + """
+ UNWIND $edges AS edge
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
+ """
+ + filter_query
+ + """
+ WITH e, edge, """
+ + get_vector_cosine_func_query(
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
+ )
+ + """ AS score
+ WHERE score > $min_score
+ WITH edge, e, score
+ ORDER BY score DESC
+ RETURN
+ edge.uuid AS search_edge_uuid,
+ collect({
+ uuid: e.uuid,
+ source_node_uuid: startNode(e).uuid,
+ target_node_uuid: endNode(e).uuid,
+ created_at: e.created_at,
+ name: e.name,
+ group_id: e.group_id,
+ fact: e.fact,
+ fact_embedding: e.fact_embedding,
+ episodes: e.episodes,
+ expired_at: e.expired_at,
+ valid_at: e.valid_at,
+ invalid_at: e.invalid_at,
+ attributes: properties(e)
+ })[..$limit] AS matches
+ """
)
- + """ AS score
- WHERE score > $min_score
- WITH edge, e, score
- ORDER BY score DESC
- RETURN edge.uuid AS search_edge_uuid,
- collect({
- uuid: e.uuid,
- source_node_uuid: startNode(e).uuid,
- target_node_uuid: endNode(e).uuid,
- created_at: e.created_at,
- name: e.name,
- group_id: e.group_id,
- fact: e.fact,
- fact_embedding: e.fact_embedding,
- episodes: e.episodes,
- expired_at: e.expired_at,
- valid_at: e.valid_at,
- invalid_at: e.invalid_at,
- attributes: properties(e)
- })[..$limit] AS matches
- """
- )
results, _, _ = await driver.execute_query(
query,
@@ -1296,12 +1548,12 @@ async def get_relevant_edges(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
relevant_edges_dict: dict[str, list[EntityEdge]] = {
result['search_edge_uuid']: [
- get_entity_edge_from_record(record) for record in result['matches']
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
]
for result in results
}
@@ -1321,10 +1573,13 @@ async def get_edge_invalidation_candidates(
if len(edges) == 0:
return []
- query_params: dict[str, Any] = {}
+ filter_queries, filter_params = edge_search_filter_query_constructor(
+ search_filter, driver.provider
+ )
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
- query_params.update(filter_params)
+ filter_query = ''
+ if filter_queries:
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
if driver.provider == GraphProvider.NEPTUNE:
query = (
@@ -1348,7 +1603,7 @@ async def get_edge_invalidation_candidates(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
# Calculate Cosine similarity then return the edge ids
@@ -1361,7 +1616,7 @@ async def get_edge_invalidation_candidates(
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
# Match the edge ids and return the values
- query = """
+ query = """
UNWIND $ids AS edge
MATCH ()-[e]->()
WHERE id(e) = edge.id
@@ -1391,44 +1646,92 @@ async def get_edge_invalidation_candidates(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
else:
- query = (
- RUNTIME_QUERY
- + """
- UNWIND $edges AS edge
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
- """
- + filter_query
- + """
- WITH edge, e, """
- + get_vector_cosine_func_query(
- 'e.fact_embedding', 'edge.fact_embedding', driver.provider
+ if driver.provider == GraphProvider.KUZU:
+ embedding_size = (
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
+ )
+ if embedding_size == 0:
+ return []
+
+ query = (
+ RUNTIME_QUERY
+ + """
+ UNWIND $edges AS edge
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
+ """
+ + filter_query
+ + """
+ WITH edge, e, n, m, """
+ + get_vector_cosine_func_query(
+ 'e.fact_embedding',
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
+ driver.provider,
+ )
+ + """ AS score
+ WHERE score > $min_score
+ WITH edge, e, n, m, score
+ ORDER BY score DESC
+ LIMIT $limit
+ RETURN
+ edge.uuid AS search_edge_uuid,
+ collect({
+ uuid: e.uuid,
+ source_node_uuid: n.uuid,
+ target_node_uuid: m.uuid,
+ created_at: e.created_at,
+ name: e.name,
+ group_id: e.group_id,
+ fact: e.fact,
+ fact_embedding: e.fact_embedding,
+ episodes: e.episodes,
+ expired_at: e.expired_at,
+ valid_at: e.valid_at,
+ invalid_at: e.invalid_at,
+ attributes: e.attributes
+ }) AS matches
+ """
+ )
+ else:
+ query = (
+ RUNTIME_QUERY
+ + """
+ UNWIND $edges AS edge
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
+ """
+ + filter_query
+ + """
+ WITH edge, e, """
+ + get_vector_cosine_func_query(
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
+ )
+ + """ AS score
+ WHERE score > $min_score
+ WITH edge, e, score
+ ORDER BY score DESC
+ RETURN
+ edge.uuid AS search_edge_uuid,
+ collect({
+ uuid: e.uuid,
+ source_node_uuid: startNode(e).uuid,
+ target_node_uuid: endNode(e).uuid,
+ created_at: e.created_at,
+ name: e.name,
+ group_id: e.group_id,
+ fact: e.fact,
+ fact_embedding: e.fact_embedding,
+ episodes: e.episodes,
+ expired_at: e.expired_at,
+ valid_at: e.valid_at,
+ invalid_at: e.invalid_at,
+ attributes: properties(e)
+ })[..$limit] AS matches
+ """
)
- + """ AS score
- WHERE score > $min_score
- WITH edge, e, score
- ORDER BY score DESC
- RETURN edge.uuid AS search_edge_uuid,
- collect({
- uuid: e.uuid,
- source_node_uuid: startNode(e).uuid,
- target_node_uuid: endNode(e).uuid,
- created_at: e.created_at,
- name: e.name,
- group_id: e.group_id,
- fact: e.fact,
- fact_embedding: e.fact_embedding,
- episodes: e.episodes,
- expired_at: e.expired_at,
- valid_at: e.valid_at,
- invalid_at: e.invalid_at,
- attributes: properties(e)
- })[..$limit] AS matches
- """
- )
results, _, _ = await driver.execute_query(
query,
@@ -1436,11 +1739,11 @@ async def get_edge_invalidation_candidates(
limit=limit,
min_score=min_score,
routing_='r',
- **query_params,
+ **filter_params,
)
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
result['search_edge_uuid']: [
- get_entity_edge_from_record(record) for record in result['matches']
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
]
for result in results
}
@@ -1479,13 +1782,21 @@ async def node_distance_reranker(
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {center_node_uuid: 0.0}
+ query = """
+ UNWIND $node_uuids AS node_uuid
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
+ RETURN 1 AS score, node_uuid AS uuid
+ """
+ if driver.provider == GraphProvider.KUZU:
+ query = """
+ UNWIND $node_uuids AS node_uuid
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
+ RETURN 1 AS score, node_uuid AS uuid
+ """
+
# Find the shortest path to center node
results, header, _ = await driver.execute_query(
- """
- UNWIND $node_uuids AS node_uuid
- MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
- RETURN 1 AS score, node_uuid AS uuid
- """,
+ query,
node_uuids=filtered_uuids,
center_uuid=center_node_uuid,
routing_='r',
@@ -1536,6 +1847,10 @@ async def episode_mentions_reranker(
for result in results:
scores[result['uuid']] = result['score']
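+ # UUIDs with no direct connection to the center node get an infinite
+ # distance so they sort to the end.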
+ for uuid in sorted_uuids:
+ if uuid not in scores:
+ scores[uuid] = float('inf')
+
# rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
@@ -1667,13 +1982,23 @@ async def get_embeddings_for_edges(
split(e.fact_embedding, ",") AS fact_embedding
"""
else:
- query = """
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+ match_query = """
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
+ """
+
+ query = (
+ match_query
+ + """
WHERE e.uuid IN $edge_uuids
RETURN DISTINCT
e.uuid AS uuid,
e.fact_embedding AS fact_embedding
"""
+ )
results, _, _ = await driver.execute_query(
query,
edge_uuids=[edge.uuid for edge in edges],
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index e20b20b0..14be80a2 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+import json
import logging
import typing
from datetime import datetime
@@ -22,20 +23,21 @@ import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any
-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
- EPISODIC_EDGE_SAVE_BULK,
get_entity_edge_save_bulk_query,
+ get_episodic_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edge,
@@ -116,11 +118,15 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
+ episode.pop('labels', None)
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
- nodes: list[dict[str, Any]] = []
+
+ nodes = []
+
for node in entity_nodes:
if node.name_embedding is None:
await node.generate_name_embedding(embedder)
+
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
@@ -130,13 +136,19 @@ async def add_nodes_and_edges_bulk_tx(
'created_at': node.created_at,
}
- entity_data.update(node.attributes or {})
- entity_data['labels'] = list(
- set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
- )
+ entity_data['labels'] = list(set(node.labels + ['Entity']))
+ if driver.provider == GraphProvider.KUZU:
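+ # Kuzu tables have a fixed schema, so dynamic attributes are serialized into
+ # a single JSON string column (datetimes converted to ISO strings first).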
+ attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
+ entity_data['attributes'] = json.dumps(attributes)
+ else:
+ entity_data.update(node.attributes or {})
+ entity_data['labels'] = list(
+ set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
+ )
+
nodes.append(entity_data)
- edges: list[dict[str, Any]] = []
+ edges = []
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_embedding(embedder)
@@ -155,17 +167,36 @@ async def add_nodes_and_edges_bulk_tx(
'invalid_at': edge.invalid_at,
}
- edge_data.update(edge.attributes or {})
+ if driver.provider == GraphProvider.KUZU:
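+ # Edge attributes get the same JSON-string encoding as node attributes above.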
+ attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
+ edge_data['attributes'] = json.dumps(attributes)
+ else:
+ edge_data.update(edge.attributes or {})
+
edges.append(edge_data)
- await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
- entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
- await tx.run(entity_node_save_bulk, nodes=nodes)
- await tx.run(
- EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
- )
- entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
- await tx.run(entity_edge_save_bulk, entity_edges=edges)
+ if driver.provider == GraphProvider.KUZU:
+ # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly,
+ # so we insert the data one by one instead for now.
+ episode_query = get_episode_node_save_bulk_query(driver.provider)
+ for episode in episodes:
+ await tx.run(episode_query, **episode)
+ entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
+ for node in nodes:
+ await tx.run(entity_node_query, **node)
+ entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
+ for edge in edges:
+ await tx.run(entity_edge_query, **edge)
+ episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
+ for edge in episodic_edges:
+ await tx.run(episodic_edge_query, **edge.model_dump())
+ else:
+ await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
+ await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
+ await tx.run(
+ get_episodic_edge_save_bulk_query(driver.provider),
+ episodic_edges=[edge.model_dump() for edge in episodic_edges],
+ )
+ await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
async def extract_nodes_and_edges_bulk(
diff --git a/graphiti_core/utils/datetime_utils.py b/graphiti_core/utils/datetime_utils.py
index 71550108..7ef11dc6 100644
--- a/graphiti_core/utils/datetime_utils.py
+++ b/graphiti_core/utils/datetime_utils.py
@@ -40,3 +40,16 @@ def ensure_utc(dt: datetime | None) -> datetime | None:
return dt.astimezone(timezone.utc)
return dt
+
+
+def convert_datetimes_to_strings(obj):
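+ """Recursively convert datetime values nested in dicts, lists, and tuples
+ to ISO-8601 strings so the structure can be JSON-serialized."""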
+ if isinstance(obj, dict):
+ return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [convert_datetimes_to_strings(item) for item in obj]
+ elif isinstance(obj, tuple):
+ return tuple(convert_datetimes_to_strings(item) for item in obj)
+ elif isinstance(obj, datetime):
+ return obj.isoformat()
+ else:
+ return obj
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index bdcb0e1d..260870bc 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -4,11 +4,12 @@ from collections import defaultdict
from pydantic import BaseModel
-from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
+from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
@@ -33,11 +34,11 @@ async def get_community_clusters(
if group_ids is None:
group_id_values, _, _ = await driver.execute_query(
"""
- MATCH (n:Entity)
- WHERE n.group_id IS NOT NULL
- RETURN
- collect(DISTINCT n.group_id) AS group_ids
- """,
+ MATCH (n:Entity)
+ WHERE n.group_id IS NOT NULL
+ RETURN
+ collect(DISTINCT n.group_id) AS group_ids
+ """
)
group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@@ -46,14 +47,21 @@ async def get_community_clusters(
projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
for node in nodes:
- records, _, _ = await driver.execute_query(
+ match_query = """
+ MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
"""
- MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
- WITH count(r) AS count, m.uuid AS uuid
- RETURN
- uuid,
- count
- """,
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
+ WITH count(e) AS count, m.uuid AS uuid
+ RETURN
+ uuid,
+ count
+ """,
uuid=node.uuid,
group_id=group_id,
)
@@ -235,9 +243,9 @@ async def build_communities(
async def remove_communities(driver: GraphDriver):
await driver.execute_query(
"""
- MATCH (c:Community)
- DETACH DELETE c
- """,
+ MATCH (c:Community)
+ DETACH DELETE c
+ """
)
@@ -247,14 +255,10 @@ async def determine_entity_community(
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
"""
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
- RETURN
- c.uuid AS uuid,
- c.name AS name,
- c.group_id AS group_id,
- c.created_at AS created_at,
- c.summary AS summary
- """,
+ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
+ RETURN
+ """
+ + COMMUNITY_NODE_RETURN,
entity_uuid=entity.uuid,
)
@@ -262,16 +266,19 @@ async def determine_entity_community(
return get_community_node_from_record(records[0]), False
# If the node has no community, add it to the mode community of surrounding entities
- records, _, _ = await driver.execute_query(
+ match_query = """
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
+ """
+ if driver.provider == GraphProvider.KUZU:
+ match_query = """
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
"""
- MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
- RETURN
- c.uuid AS uuid,
- c.name AS name,
- c.group_id AS group_id,
- c.created_at AS created_at,
- c.summary AS summary
- """,
+ records, _, _ = await driver.execute_query(
+ match_query
+ + """
+ RETURN
+ """
+ + COMMUNITY_NODE_RETURN,
entity_uuid=entity.uuid,
)
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index eb08fa7d..55cea243 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -531,17 +531,28 @@ async def filter_existing_duplicate_of_edges(
routing_='r',
)
else:
- query: LiteralString = """
- UNWIND $duplicate_node_uuids AS duplicate_tuple
- MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
- RETURN DISTINCT
- n.uuid AS source_uuid,
- m.uuid AS target_uuid
- """
+ if driver.provider == GraphProvider.KUZU:
+ query = """
+ UNWIND $duplicate_node_uuids AS duplicate
+ MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
+ RETURN DISTINCT
+ n.uuid AS source_uuid,
+ m.uuid AS target_uuid
+ """
+ duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
+ else:
+ query: LiteralString = """
+ UNWIND $duplicate_node_uuids AS duplicate_tuple
+ MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+ RETURN DISTINCT
+ n.uuid AS source_uuid,
+ m.uuid AS target_uuid
+ """
+ duplicate_node_uuids = list(duplicate_nodes_map.keys())
records, _, _ = await driver.execute_query(
query,
- duplicate_node_uuids=list(duplicate_nodes_map.keys()),
+ duplicate_node_uuids=duplicate_node_uuids,
routing_='r',
)
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index 4cf97784..66dc55e4 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -53,10 +53,29 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
for name in index_names
]
)
+
range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
+ if driver.provider == GraphProvider.KUZU:
+ # Skip creating fulltext indices if they already exist. This has to be checked
+ # manually until Kuzu supports `IF NOT EXISTS` for indices.
+ result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
+ if len(result) > 0:
+ fulltext_indices = []
+
+ # Only load the `fts` extension if it's not already loaded; loading it a second time raises an error.
+ result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
+ if len(result) == 0:
+ fulltext_indices.insert(
+ 0,
+ """
+ INSTALL fts;
+ LOAD fts;
+ """,
+ )
+
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
@@ -76,10 +95,19 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
await tx.run('MATCH (n) DETACH DELETE n')
async def delete_group_ids(tx):
- await tx.run(
- 'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
- group_ids=group_ids,
- )
+ labels = ['Entity', 'Episodic', 'Community']
+ if driver.provider == GraphProvider.KUZU:
+ labels.append('RelatesToNode_')
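+ # Kuzu materializes entity edges as RelatesToNode_ nodes, which carry a
+ # group_id and must be deleted along with the other labels.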
+
+ for label in labels:
+ await tx.run(
+ f"""
+ MATCH (n:{label})
+ WHERE n.group_id IN $group_ids
+ DETACH DELETE n
+ """,
+ group_ids=group_ids,
+ )
if group_ids is None:
await session.execute_write(delete_all)
@@ -108,18 +136,23 @@ async def retrieve_episodes(
Returns:
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
"""
- group_id_filter: LiteralString = (
- '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
- )
- source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
+
+ query_params: dict = {}
+ query_filter = ''
+ if group_ids and len(group_ids) > 0:
+ query_filter += '\nAND e.group_id IN $group_ids'
+ query_params['group_ids'] = group_ids
+
+ if source is not None:
+ query_filter += '\nAND e.source = $source'
+ query_params['source'] = source.name
query: LiteralString = (
"""
- MATCH (e:Episodic)
- WHERE e.valid_at <= $reference_time
- """
- + group_id_filter
- + source_filter
+ MATCH (e:Episodic)
+ WHERE e.valid_at <= $reference_time
+ """
+ + query_filter
+ """
RETURN
"""
@@ -136,9 +169,8 @@ async def retrieve_episodes(
result, _, _ = await driver.execute_query(
query,
reference_time=reference_time,
- source=source.name if source is not None else None,
num_episodes=last_n,
- group_ids=group_ids,
+ **query_params,
)
episodes = [get_episodic_node_from_record(record) for record in result]
diff --git a/pyproject.toml b/pyproject.toml
index 0301f20c..4dd8db74 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ Repository = "https://github.com/getzep/graphiti"
anthropic = ["anthropic>=0.49.0"]
groq = ["groq>=0.2.0"]
google-genai = ["google-genai>=1.8.0"]
+kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
@@ -39,6 +40,7 @@ dev = [
"anthropic>=0.49.0",
"google-genai>=1.8.0",
"falkordb>=1.1.2,<2.0.0",
+ "kuzu>=0.11.2",
"ipykernel>=6.29.5",
"jupyterlab>=4.2.4",
"diskcache-stubs>=5.6.3.6.20240818",
@@ -91,7 +93,3 @@ docstring-code-format = true
include = ["graphiti_core"]
pythonVersion = "3.10"
typeCheckingMode = "basic"
-
-[[tool.pyright.overrides]]
-include = ["**/falkordb*"]
-reportMissingImports = false
diff --git a/pytest.ini b/pytest.ini
index 9d26782b..7699537e 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,4 +1,5 @@
[pytest]
markers =
integration: marks tests as integration tests
-asyncio_default_fixture_loop_scope = function
\ No newline at end of file
+asyncio_default_fixture_loop_scope = function
+asyncio_mode = auto
diff --git a/tests/helpers_test.py b/tests/helpers_test.py
index 3614eea7..58ef0c3e 100644
--- a/tests/helpers_test.py
+++ b/tests/helpers_test.py
@@ -15,42 +15,55 @@ limitations under the License.
"""
import os
+from unittest.mock import Mock
+import numpy as np
import pytest
from dotenv import load_dotenv
-from graphiti_core.driver.driver import GraphDriver
-from graphiti_core.driver.neptune_driver import NeptuneDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
+from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.helpers import lucene_sanitize
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
+from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
-HAS_NEO4J = False
-HAS_FALKORDB = False
-HAS_NEPTUNE = False
+drivers: list[GraphProvider] = []
if os.getenv('DISABLE_NEO4J') is None:
try:
from graphiti_core.driver.neo4j_driver import Neo4jDriver
- HAS_NEO4J = True
+ drivers.append(GraphProvider.NEO4J)
except ImportError:
- pass
+ raise
if os.getenv('DISABLE_FALKORDB') is None:
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver
- HAS_FALKORDB = True
+ drivers.append(GraphProvider.FALKORDB)
except ImportError:
- pass
+ raise
+if os.getenv('DISABLE_KUZU') is None:
+ try:
+ from graphiti_core.driver.kuzu_driver import KuzuDriver
+
+ drivers.append(GraphProvider.KUZU)
+ except ImportError:
+ raise
+
+# Force-disable Neptune for now, regardless of the environment.
+os.environ['DISABLE_NEPTUNE'] = 'True'
if os.getenv('DISABLE_NEPTUNE') is None:
try:
from graphiti_core.driver.neptune_driver import NeptuneDriver
- HAS_NEPTUNE = False
+ drivers.append(GraphProvider.NEPTUNE)
except ImportError:
- pass
+ raise
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
@@ -65,38 +78,100 @@ NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)
+KUZU_DB = os.getenv('KUZU_DB', ':memory:')
-def get_driver(driver_name: str) -> GraphDriver:
- if driver_name == 'neo4j':
+group_id = 'graphiti_test_group'
+group_id_2 = 'graphiti_test_group_2'
+
+
+def get_driver(provider: GraphProvider) -> GraphDriver:
+ if provider == GraphProvider.NEO4J:
return Neo4jDriver(
uri=NEO4J_URI,
user=NEO4J_USER,
password=NEO4J_PASSWORD,
)
- elif driver_name == 'falkordb':
+ elif provider == GraphProvider.FALKORDB:
return FalkorDriver(
host=FALKORDB_HOST,
port=int(FALKORDB_PORT),
username=FALKORDB_USER,
password=FALKORDB_PASSWORD,
)
- elif driver_name == 'neptune':
+    elif provider == GraphProvider.KUZU:
+        return KuzuDriver(
+            db=KUZU_DB,
+        )
+ elif provider == GraphProvider.NEPTUNE:
return NeptuneDriver(
host=NEPTUNE_HOST,
port=int(NEPTUNE_PORT),
aoss_host=AOSS_HOST,
)
else:
- raise ValueError(f'Driver {driver_name} not available')
+ raise ValueError(f'Driver {provider} not available')
-drivers: list[str] = []
-if HAS_NEO4J:
- drivers.append('neo4j')
-if HAS_FALKORDB:
- drivers.append('falkordb')
-if HAS_NEPTUNE:
- drivers.append('neptune')
+@pytest.fixture(params=drivers)
+async def graph_driver(request):
+    provider = request.param
+    graph_driver = get_driver(provider)
+ await clear_data(graph_driver, [group_id, group_id_2])
+ try:
+ yield graph_driver # provide driver to the test
+ finally:
+        # Always runs, even if the test fails or raises.
+ await graph_driver.close()
+
+
+embedding_dim = 384
+embeddings = {
+ key: np.random.uniform(0.0, 0.9, embedding_dim).tolist()
+ for key in [
+ 'Alice',
+ 'Bob',
+ 'Alice likes Bob',
+ 'test_entity_1',
+ 'test_entity_2',
+ 'test_entity_3',
+ 'test_entity_4',
+ 'test_entity_alice',
+ 'test_entity_bob',
+ 'test_entity_1 is a duplicate of test_entity_2',
+ 'test_entity_3 is a duplicate of test_entity_4',
+ 'test_entity_1 relates to test_entity_2',
+ 'test_entity_1 relates to test_entity_3',
+ 'test_entity_2 relates to test_entity_3',
+ 'test_entity_1 relates to test_entity_4',
+ 'test_entity_2 relates to test_entity_4',
+ 'test_entity_3 relates to test_entity_4',
+ 'test_community_1',
+ 'test_community_2',
+ ]
+}
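+# Alias: 'Alice Smith' shares Alice's vector so both names embed identically.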
+embeddings['Alice Smith'] = embeddings['Alice']
+
+
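+# Deterministic embedder stub: EmbedderClient.create is redirected to a lookup
+# of the pre-generated vectors above, so similarity comparisons are stable
+# across runs and no real embedding API is called.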
+@pytest.fixture
+def mock_embedder():
+ mock_model = Mock(spec=EmbedderClient)
+
+ def mock_embed(input_data):
+ if isinstance(input_data, str):
+ return embeddings[input_data]
+ elif isinstance(input_data, list):
+ combined_input = ' '.join(input_data)
+ return embeddings[combined_input]
+ else:
+ raise ValueError(f'Unsupported input type: {type(input_data)}')
+
+ mock_model.create.side_effect = mock_embed
+ return mock_model
def test_lucene_sanitize():
@@ -114,5 +189,125 @@ def test_lucene_sanitize():
assert assert_result == result
+async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
+ results, _, _ = await driver.execute_query(
+ """
+ MATCH (n)
+ WHERE n.uuid IN $uuids
+ RETURN COUNT(n) as count
+ """,
+ uuids=uuids,
+ )
+ return int(results[0]['count'])
+
+
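+# Entity edges are plain relationships on most backends, but some drivers
+# materialize them as intermediate RelatesToNode_ nodes, so both patterns
+# are counted here.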
+async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
+ results, _, _ = await driver.execute_query(
+ """
+ MATCH (n)-[e]->(m)
+ WHERE e.uuid IN $uuids
+ RETURN COUNT(e) as count
+ UNION ALL
+ MATCH (e:RelatesToNode_)
+ WHERE e.uuid IN $uuids
+ RETURN COUNT(e) as count
+ """,
+ uuids=uuids,
+ )
+ return sum(int(result['count']) for result in results)
+
+
+async def print_graph(graph_driver: GraphDriver):
+ nodes, _, _ = await graph_driver.execute_query(
+ """
+ MATCH (n)
+ RETURN n.uuid, n.name
+ """,
+ )
+ print('Nodes:')
+ for node in nodes:
+ print(' ', node)
+ edges, _, _ = await graph_driver.execute_query(
+ """
+ MATCH (n)-[e]->(m)
+ RETURN n.name, e.uuid, m.name
+ """,
+ )
+ print('Edges:')
+ for edge in edges:
+ print(' ', edge)
+
+
+async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
+ assert retrieved.uuid == sample.uuid
+ assert retrieved.name == sample.name
+ assert retrieved.group_id == group_id
+ assert retrieved.created_at == sample.created_at
+ assert retrieved.source == sample.source
+ assert retrieved.source_description == sample.source_description
+ assert retrieved.content == sample.content
+ assert retrieved.valid_at == sample.valid_at
+ assert set(retrieved.entity_edges) == set(sample.entity_edges)
+
+
+async def assert_entity_node_equals(
+ graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
+):
+ await retrieved.load_name_embedding(graph_driver)
+ assert retrieved.uuid == sample.uuid
+ assert retrieved.name == sample.name
+ assert retrieved.group_id == sample.group_id
+ assert set(retrieved.labels) == set(sample.labels)
+ assert retrieved.created_at == sample.created_at
+ assert retrieved.name_embedding is not None
+ assert sample.name_embedding is not None
+ assert np.allclose(retrieved.name_embedding, sample.name_embedding)
+ assert retrieved.summary == sample.summary
+ assert retrieved.attributes == sample.attributes
+
+
+async def assert_community_node_equals(
+ graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
+):
+ await retrieved.load_name_embedding(graph_driver)
+ assert retrieved.uuid == sample.uuid
+ assert retrieved.name == sample.name
+ assert retrieved.group_id == group_id
+ assert retrieved.created_at == sample.created_at
+ assert retrieved.name_embedding is not None
+ assert sample.name_embedding is not None
+ assert np.allclose(retrieved.name_embedding, sample.name_embedding)
+ assert retrieved.summary == sample.summary
+
+
+async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
+ assert retrieved.uuid == sample.uuid
+ assert retrieved.group_id == sample.group_id
+ assert retrieved.created_at == sample.created_at
+ assert retrieved.source_node_uuid == sample.source_node_uuid
+ assert retrieved.target_node_uuid == sample.target_node_uuid
+
+
+async def assert_entity_edge_equals(
+ graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
+):
+ await retrieved.load_fact_embedding(graph_driver)
+ assert retrieved.uuid == sample.uuid
+ assert retrieved.group_id == sample.group_id
+ assert retrieved.created_at == sample.created_at
+ assert retrieved.source_node_uuid == sample.source_node_uuid
+ assert retrieved.target_node_uuid == sample.target_node_uuid
+ assert retrieved.name == sample.name
+ assert retrieved.fact == sample.fact
+ assert retrieved.fact_embedding is not None
+ assert sample.fact_embedding is not None
+ assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
+ assert retrieved.episodes == sample.episodes
+ assert retrieved.expired_at == sample.expired_at
+ assert retrieved.valid_at == sample.valid_at
+ assert retrieved.invalid_at == sample.invalid_at
+ assert retrieved.attributes == sample.attributes
+
+
if __name__ == '__main__':
pytest.main([__file__])
diff --git a/tests/test_edge_int.py b/tests/test_edge_int.py
index 6eb769a4..15555d72 100644
--- a/tests/test_edge_int.py
+++ b/tests/test_edge_int.py
@@ -17,23 +17,16 @@ limitations under the License.
import logging
import sys
from datetime import datetime
-from uuid import uuid4
import numpy as np
import pytest
-from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
-from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
-from tests.helpers_test import drivers, get_driver
-
-pytestmark = pytest.mark.integration
+from tests.helpers_test import get_edge_count, get_node_count, group_id
pytest_plugins = ('pytest_asyncio',)
-group_id = f'test_group_{str(uuid4())}'
-
def setup_logging():
# Create a logger
@@ -57,17 +50,10 @@ def setup_logging():
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_episodic_edge(driver):
- graph_driver = get_driver(driver)
- embedder = OpenAIEmbedder()
-
+async def test_episodic_edge(graph_driver, mock_embedder):
now = datetime.now()
+ # Create episodic node
episode_node = EpisodicNode(
name='test_episode',
labels=[],
@@ -79,13 +65,13 @@ async def test_episodic_edge(driver):
entity_edges=[],
group_id=group_id,
)
-
- node_count = await get_node_count(graph_driver, episode_node.uuid)
+ node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await episode_node.save(graph_driver)
- node_count = await get_node_count(graph_driver, episode_node.uuid)
+ node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 1
+ # Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
@@ -93,27 +79,27 @@ async def test_episodic_edge(driver):
summary='Alice summary',
group_id=group_id,
)
- await alice_node.generate_name_embedding(embedder)
-
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ await alice_node.generate_name_embedding(mock_embedder)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
+ # Create episodic to entity edge
episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid,
created_at=now,
group_id=group_id,
)
-
- edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
await episodic_edge.save(graph_driver)
- edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 1
+ # Get edge by uuid
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid
@@ -121,6 +107,7 @@ async def test_episodic_edge(driver):
assert retrieved.created_at == now
assert retrieved.group_id == group_id
+ # Get edge by uuids
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
@@ -129,6 +116,7 @@ async def test_episodic_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Get edge by group ids
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
@@ -137,33 +125,41 @@ async def test_episodic_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Get episodic node by entity node uuid
+ retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
+ assert len(retrieved) == 1
+ assert retrieved[0].uuid == episode_node.uuid
+ assert retrieved[0].name == 'test_episode'
+ assert retrieved[0].created_at == now
+ assert retrieved[0].group_id == group_id
+
+ # Delete edge by uuid
await episodic_edge.delete(graph_driver)
- edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
- await episode_node.delete(graph_driver)
- node_count = await get_node_count(graph_driver, episode_node.uuid)
- assert node_count == 0
+ # Delete edge by uuids
+ await episodic_edge.save(graph_driver)
+ await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
+ edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
+ assert edge_count == 0
+ # Cleanup nodes
+ await episode_node.delete(graph_driver)
+ node_count = await get_node_count(graph_driver, [episode_node.uuid])
+ assert node_count == 0
await alice_node.delete(graph_driver)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_entity_edge(driver):
- graph_driver = get_driver(driver)
- embedder = OpenAIEmbedder()
-
+async def test_entity_edge(graph_driver, mock_embedder):
now = datetime.now()
+ # Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
@@ -171,25 +167,25 @@ async def test_entity_edge(driver):
summary='Alice summary',
group_id=group_id,
)
- await alice_node.generate_name_embedding(embedder)
-
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ await alice_node.generate_name_embedding(mock_embedder)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
+ # Create entity node
bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
)
- await bob_node.generate_name_embedding(embedder)
-
- node_count = await get_node_count(graph_driver, bob_node.uuid)
+ await bob_node.generate_name_embedding(mock_embedder)
+ node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await bob_node.save(graph_driver)
- node_count = await get_node_count(graph_driver, bob_node.uuid)
+ node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 1
+ # Create entity to entity edge
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
@@ -202,14 +198,14 @@ async def test_entity_edge(driver):
invalid_at=now,
group_id=group_id,
)
- edge_embedding = await entity_edge.generate_embedding(embedder)
-
- edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
+ edge_embedding = await entity_edge.generate_embedding(mock_embedder)
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
await entity_edge.save(graph_driver)
- edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 1
+ # Get edge by uuid
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid
@@ -217,6 +213,7 @@ async def test_entity_edge(driver):
assert retrieved.created_at == now
assert retrieved.group_id == group_id
+ # Get edge by uuids
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
@@ -225,6 +222,7 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Get edge by group ids
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
@@ -233,6 +231,7 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Get edge by node uuid
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
@@ -241,82 +240,113 @@ async def test_entity_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Get fact embedding
await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
+ # Delete edge by uuid
await entity_edge.delete(graph_driver)
- edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
- await alice_node.delete(graph_driver)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
- assert node_count == 0
+ # Delete edge by uuids
+ await entity_edge.save(graph_driver)
+ await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+ assert edge_count == 0
+ # Deleting node should delete the edge
+ await entity_edge.save(graph_driver)
+ await alice_node.delete(graph_driver)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
+ assert node_count == 0
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+ assert edge_count == 0
+
+ # Deleting node by uuids should delete the edge
+ await alice_node.save(graph_driver)
+ await entity_edge.save(graph_driver)
+ await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
+ assert node_count == 0
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+ assert edge_count == 0
+
+ # Deleting node by group id should delete the edge
+ await alice_node.save(graph_driver)
+ await entity_edge.save(graph_driver)
+ await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
+ assert node_count == 0
+ edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
+ assert edge_count == 0
+
+ # Cleanup nodes
+ await alice_node.delete(graph_driver)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
+ assert node_count == 0
await bob_node.delete(graph_driver)
- node_count = await get_node_count(graph_driver, bob_node.uuid)
+ node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_community_edge(driver):
- graph_driver = get_driver(driver)
- embedder = OpenAIEmbedder()
-
+async def test_community_edge(graph_driver, mock_embedder):
now = datetime.now()
+ # Create community node
community_node_1 = CommunityNode(
- name='Community A',
+ name='test_community_1',
group_id=group_id,
summary='Community A summary',
)
- await community_node_1.generate_name_embedding(embedder)
- node_count = await get_node_count(graph_driver, community_node_1.uuid)
+ await community_node_1.generate_name_embedding(mock_embedder)
+ node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_1.save(graph_driver)
- node_count = await get_node_count(graph_driver, community_node_1.uuid)
+ node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 1
+ # Create community node
community_node_2 = CommunityNode(
- name='Community B',
+ name='test_community_2',
group_id=group_id,
summary='Community B summary',
)
- await community_node_2.generate_name_embedding(embedder)
- node_count = await get_node_count(graph_driver, community_node_2.uuid)
+ await community_node_2.generate_name_embedding(mock_embedder)
+ node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await community_node_2.save(graph_driver)
- node_count = await get_node_count(graph_driver, community_node_2.uuid)
+ node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 1
+ # Create entity node
alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
)
- await alice_node.generate_name_embedding(embedder)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ await alice_node.generate_name_embedding(mock_embedder)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
+ # Create community to community edge
community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid,
created_at=now,
group_id=group_id,
)
- edge_count = await get_edge_count(graph_driver, community_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
await community_edge.save(graph_driver)
- edge_count = await get_edge_count(graph_driver, community_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 1
+ # Get edge by uuid
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid
@@ -324,6 +354,7 @@ async def test_community_edge(driver):
assert retrieved.created_at == now
assert retrieved.group_id == group_id
+ # Get edge by uuids
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
@@ -332,6 +363,7 @@ async def test_community_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Get edge by group ids
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
@@ -340,45 +372,26 @@ async def test_community_edge(driver):
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
+ # Delete edge by uuid
await community_edge.delete(graph_driver)
- edge_count = await get_edge_count(graph_driver, community_edge.uuid)
+ edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
+ # Delete edge by uuids
+ await community_edge.save(graph_driver)
+ await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
+ edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
+ assert edge_count == 0
+
+ # Cleanup nodes
await alice_node.delete(graph_driver)
- node_count = await get_node_count(graph_driver, alice_node.uuid)
+ node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
-
await community_node_1.delete(graph_driver)
- node_count = await get_node_count(graph_driver, community_node_1.uuid)
+ node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
-
await community_node_2.delete(graph_driver)
- node_count = await get_node_count(graph_driver, community_node_2.uuid)
+ node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await graph_driver.close()
-
-
-async def get_node_count(driver: GraphDriver, uuid: str):
- results, _, _ = await driver.execute_query(
- """
- MATCH (n {uuid: $uuid})
- RETURN COUNT(n) as count
- """,
- uuid=uuid,
- )
- return int(results[0]['count'])
-
-
-async def get_edge_count(driver: GraphDriver, uuid: str):
- results, _, _ = await driver.execute_query(
- """
- MATCH (n)-[e {uuid: $uuid}]->(m)
- RETURN COUNT(e) as count
- UNION ALL
- MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
- RETURN COUNT(m) as count
- """,
- uuid=uuid,
- )
- return sum(int(result['count']) for result in results)
diff --git a/tests/test_entity_exclusion_int.py b/tests/test_entity_exclusion_int.py
index 473177b1..0ac9897c 100644
--- a/tests/test_entity_exclusion_int.py
+++ b/tests/test_entity_exclusion_int.py
@@ -60,7 +60,6 @@ class Location(BaseModel):
@pytest.mark.parametrize(
'driver',
drivers,
- ids=drivers,
)
async def test_exclude_default_entity_type(driver):
"""Test excluding the default 'Entity' type while keeping custom types."""
@@ -118,7 +117,6 @@ async def test_exclude_default_entity_type(driver):
@pytest.mark.parametrize(
'driver',
drivers,
- ids=drivers,
)
async def test_exclude_specific_custom_types(driver):
"""Test excluding specific custom entity types while keeping others."""
@@ -182,7 +180,6 @@ async def test_exclude_specific_custom_types(driver):
@pytest.mark.parametrize(
'driver',
drivers,
- ids=drivers,
)
async def test_exclude_all_types(driver):
"""Test excluding all entity types (edge case)."""
@@ -231,7 +228,6 @@ async def test_exclude_all_types(driver):
@pytest.mark.parametrize(
'driver',
drivers,
- ids=drivers,
)
async def test_exclude_no_types(driver):
"""Test normal behavior when no types are excluded (baseline test)."""
@@ -314,7 +310,6 @@ def test_validation_invalid_excluded_types():
@pytest.mark.parametrize(
'driver',
drivers,
- ids=drivers,
)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
"""Test that add_episode validates excluded_entity_types parameter."""
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 276191d0..90033f35 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -23,7 +23,7 @@ from graphiti_core.graphiti import Graphiti
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_helpers import search_results_to_context_string
from graphiti_core.utils.datetime_utils import utc_now
-from tests.helpers_test import drivers, get_driver
+from tests.helpers_test import GraphProvider
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
@@ -51,15 +51,12 @@ def setup_logging():
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_graphiti_init(driver):
+async def test_graphiti_init(graph_driver):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+        pytest.skip('Skipping as tests fail on FalkorDB')
+
logger = setup_logging()
- driver = get_driver(driver)
- graphiti = Graphiti(graph_driver=driver)
+ graphiti = Graphiti(graph_driver=graph_driver)
await graphiti.build_indices_and_constraints()
diff --git a/tests/test_graphiti_mock.py b/tests/test_graphiti_mock.py
new file mode 100644
index 00000000..9426dc9f
--- /dev/null
+++ b/tests/test_graphiti_mock.py
@@ -0,0 +1,2056 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from datetime import datetime, timedelta
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+
+from graphiti_core.cross_encoder.client import CrossEncoderClient
+from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
+from graphiti_core.graphiti import Graphiti
+from graphiti_core.llm_client import LLMClient
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
+from graphiti_core.search.search_utils import (
+ community_fulltext_search,
+ community_similarity_search,
+ edge_bfs_search,
+ edge_fulltext_search,
+ edge_similarity_search,
+ episode_fulltext_search,
+ episode_mentions_reranker,
+ get_communities_by_nodes,
+ get_edge_invalidation_candidates,
+ get_embeddings_for_communities,
+ get_embeddings_for_edges,
+ get_embeddings_for_nodes,
+ get_mentioned_nodes,
+ get_relevant_edges,
+ get_relevant_nodes,
+ node_bfs_search,
+ node_distance_reranker,
+ node_fulltext_search,
+ node_similarity_search,
+)
+from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
+from graphiti_core.utils.maintenance.community_operations import (
+ determine_entity_community,
+ get_community_clusters,
+ remove_communities,
+)
+from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
+from tests.helpers_test import (
+ GraphProvider,
+ assert_entity_edge_equals,
+ assert_entity_node_equals,
+ assert_episodic_edge_equals,
+ assert_episodic_node_equals,
+ get_edge_count,
+ get_node_count,
+ group_id,
+ group_id_2,
+)
+
+pytest_plugins = ('pytest_asyncio',)
+
+
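+# These tests drive graph persistence and retrieval directly, so the LLM stub
+# only needs to return a fixed, well-formed payload; its content is never
+# asserted on.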
+@pytest.fixture
+def mock_llm_client():
+ """Create a mock LLM"""
+ mock_llm = Mock(spec=LLMClient)
+ mock_llm.config = Mock()
+ mock_llm.model = 'test-model'
+ mock_llm.small_model = 'test-small-model'
+ mock_llm.temperature = 0.0
+ mock_llm.max_tokens = 1000
+ mock_llm.cache_enabled = False
+ mock_llm.cache_dir = None
+
+ # Mock the public method that's actually called
+ mock_llm.generate_response = Mock()
+ mock_llm.generate_response.return_value = {
+ 'tool_calls': [
+ {
+ 'name': 'extract_entities',
+ 'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
+ }
+ ]
+ }
+
+ return mock_llm
+
+
+@pytest.fixture
+def mock_cross_encoder_client():
+    """Create a mock cross-encoder"""
+    mock_cross_encoder = Mock(spec=CrossEncoderClient)
+    mock_cross_encoder.config = Mock()
+
+    # Mock the public method that's actually called; a cross-encoder reranks
+    # passages, so an empty result list is a safe, well-typed placeholder here.
+    mock_cross_encoder.rerank = Mock()
+    mock_cross_encoder.rerank.return_value = []
+
+    return mock_cross_encoder
+
+
+@pytest.mark.asyncio
+async def test_add_bulk(
+ graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as test fails on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+
+ await graphiti.build_indices_and_constraints()
+
+ now = datetime.now()
+
+ # Create episodic nodes
+ episode_node_1 = EpisodicNode(
+ name='test_episode',
+ group_id=group_id,
+ labels=[],
+ created_at=now,
+ source=EpisodeType.message,
+ source_description='conversation message',
+ content='Alice likes Bob',
+ valid_at=now,
+ entity_edges=[], # Filled in later
+ )
+ episode_node_2 = EpisodicNode(
+ name='test_episode_2',
+ group_id=group_id,
+ labels=[],
+ created_at=now,
+ source=EpisodeType.message,
+ source_description='conversation message',
+ content='Bob adores Alice',
+ valid_at=now,
+ entity_edges=[], # Filled in later
+ )
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ group_id=group_id,
+ labels=['Entity', 'Person'],
+ created_at=now,
+ summary='test_entity_1 summary',
+ attributes={'age': 30, 'location': 'New York'},
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ group_id=group_id,
+ labels=['Entity', 'Person2'],
+ created_at=now,
+ summary='test_entity_2 summary',
+ attributes={'age': 25, 'location': 'Los Angeles'},
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ group_id=group_id,
+ labels=['Entity', 'City', 'Location'],
+ created_at=now,
+ summary='test_entity_3 summary',
+ attributes={'age': 25, 'location': 'Los Angeles'},
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+
+ entity_node_4 = EntityNode(
+ name='test_entity_4',
+ group_id=group_id,
+ labels=['Entity'],
+ created_at=now,
+ summary='test_entity_4 summary',
+ attributes={'age': 25, 'location': 'Los Angeles'},
+ )
+ await entity_node_4.generate_name_embedding(mock_embedder)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ created_at=now,
+ name='likes',
+ fact='test_entity_1 relates to test_entity_2',
+ episodes=[],
+ expired_at=now,
+ valid_at=now,
+ invalid_at=now,
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+
+ entity_edge_2 = EntityEdge(
+ source_node_uuid=entity_node_3.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ created_at=now,
+ name='relates_to',
+ fact='test_entity_3 relates to test_entity_4',
+ episodes=[],
+ expired_at=now,
+ valid_at=now,
+ invalid_at=now,
+ group_id=group_id,
+ )
+ await entity_edge_2.generate_embedding(mock_embedder)
+
+ # Create episodic to entity edges
+ episodic_edge_1 = EpisodicEdge(
+ source_node_uuid=episode_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=now,
+ group_id=group_id,
+ )
+ episodic_edge_2 = EpisodicEdge(
+ source_node_uuid=episode_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ created_at=now,
+ group_id=group_id,
+ )
+ episodic_edge_3 = EpisodicEdge(
+ source_node_uuid=episode_node_2.uuid,
+ target_node_uuid=entity_node_3.uuid,
+ created_at=now,
+ group_id=group_id,
+ )
+ episodic_edge_4 = EpisodicEdge(
+ source_node_uuid=episode_node_2.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ created_at=now,
+ group_id=group_id,
+ )
+
+ # Cross reference the ids
+ episode_node_1.entity_edges = [entity_edge_1.uuid]
+ episode_node_2.entity_edges = [entity_edge_2.uuid]
+ entity_edge_1.episodes = [episode_node_1.uuid, episode_node_2.uuid]
+ entity_edge_2.episodes = [episode_node_2.uuid]
+
+ # Test add bulk
+ await add_nodes_and_edges_bulk(
+ graph_driver,
+ [episode_node_1, episode_node_2],
+ [episodic_edge_1, episodic_edge_2, episodic_edge_3, episodic_edge_4],
+ [entity_node_1, entity_node_2, entity_node_3, entity_node_4],
+ [entity_edge_1, entity_edge_2],
+ mock_embedder,
+ )
+
+    node_ids = [
+        episode_node_1.uuid,
+        episode_node_2.uuid,
+        entity_node_1.uuid,
+        entity_node_2.uuid,
+        entity_node_3.uuid,
+        entity_node_4.uuid,
+    ]
+    edge_ids = [
+        episodic_edge_1.uuid,
+        episodic_edge_2.uuid,
+        episodic_edge_3.uuid,
+        episodic_edge_4.uuid,
+        entity_edge_1.uuid,
+        entity_edge_2.uuid,
+    ]
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == len(node_ids)
+ edge_count = await get_edge_count(graph_driver, edge_ids)
+ assert edge_count == len(edge_ids)
+
+ # Test episodic nodes
+ retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_1.uuid)
+ await assert_episodic_node_equals(retrieved_episode, episode_node_1)
+
+ retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_2.uuid)
+ await assert_episodic_node_equals(retrieved_episode, episode_node_2)
+
+ # Test entity nodes
+ retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_1.uuid)
+ await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_1)
+
+ retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
+ await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)
+
+ retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
+ await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)
+
+ retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_4.uuid)
+ await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_4)
+
+ # Test episodic edges
+ retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_1.uuid)
+ await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_1)
+
+ retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_2.uuid)
+ await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_2)
+
+ retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_3.uuid)
+ await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_3)
+
+ retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_4.uuid)
+ await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_4)
+
+ # Test entity edges
+ retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_1.uuid)
+ await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_1)
+
+ retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
+ await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)
+
+
+@pytest.mark.asyncio
+async def test_remove_episode(
+ graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
+):
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+
+ await graphiti.build_indices_and_constraints()
+
+ now = datetime.now()
+
+ # Create episodic nodes
+ episode_node = EpisodicNode(
+ name='test_episode',
+ group_id=group_id,
+ labels=[],
+ created_at=now,
+ source=EpisodeType.message,
+ source_description='conversation message',
+ content='Alice likes Bob',
+ valid_at=now,
+ entity_edges=[], # Filled in later
+ )
+
+ # Create entity nodes
+ alice_node = EntityNode(
+ name='Alice',
+ group_id=group_id,
+ labels=['Entity', 'Person'],
+ created_at=now,
+ summary='Alice summary',
+ attributes={'age': 30, 'location': 'New York'},
+ )
+ await alice_node.generate_name_embedding(mock_embedder)
+
+ bob_node = EntityNode(
+ name='Bob',
+ group_id=group_id,
+ labels=['Entity', 'Person2'],
+ created_at=now,
+ summary='Bob summary',
+ attributes={'age': 25, 'location': 'Los Angeles'},
+ )
+ await bob_node.generate_name_embedding(mock_embedder)
+
+ # Create entity to entity edge
+ entity_edge = EntityEdge(
+ source_node_uuid=alice_node.uuid,
+ target_node_uuid=bob_node.uuid,
+ created_at=now,
+ name='likes',
+ fact='Alice likes Bob',
+ episodes=[],
+ expired_at=now,
+ valid_at=now,
+ invalid_at=now,
+ group_id=group_id,
+ )
+ await entity_edge.generate_embedding(mock_embedder)
+
+ # Create episodic to entity edges
+ episodic_alice_edge = EpisodicEdge(
+ source_node_uuid=episode_node.uuid,
+ target_node_uuid=alice_node.uuid,
+ created_at=now,
+ group_id=group_id,
+ )
+ episodic_bob_edge = EpisodicEdge(
+ source_node_uuid=episode_node.uuid,
+ target_node_uuid=bob_node.uuid,
+ created_at=now,
+ group_id=group_id,
+ )
+
+ # Cross reference the ids
+ episode_node.entity_edges = [entity_edge.uuid]
+ entity_edge.episodes = [episode_node.uuid]
+
+ # Test add bulk
+ await add_nodes_and_edges_bulk(
+ graph_driver,
+ [episode_node],
+ [episodic_alice_edge, episodic_bob_edge],
+ [alice_node, bob_node],
+ [entity_edge],
+ mock_embedder,
+ )
+
+ node_ids = [episode_node.uuid, alice_node.uuid, bob_node.uuid]
+ edge_ids = [episodic_alice_edge.uuid, episodic_bob_edge.uuid, entity_edge.uuid]
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 3
+ edge_count = await get_edge_count(graph_driver, edge_ids)
+ assert edge_count == 3
+
+ # Test remove episode
+ await graphiti.remove_episode(episode_node.uuid)
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 0
+ edge_count = await get_edge_count(graph_driver, edge_ids)
+ assert edge_count == 0
+
+ # Test add bulk again
+ await add_nodes_and_edges_bulk(
+ graph_driver,
+ [episode_node],
+ [episodic_alice_edge, episodic_bob_edge],
+ [alice_node, bob_node],
+ [entity_edge],
+ mock_embedder,
+ )
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 3
+ edge_count = await get_edge_count(graph_driver, edge_ids)
+ assert edge_count == 3
+
+
+@pytest.mark.asyncio
+async def test_graphiti_retrieve_episodes(
+ graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as test fails on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+
+ await graphiti.build_indices_and_constraints()
+
+ now = datetime.now()
+ valid_at_1 = now - timedelta(days=2)
+ valid_at_2 = now - timedelta(days=4)
+ valid_at_3 = now - timedelta(days=6)
+
+ # Create episodic nodes
+ episode_node_1 = EpisodicNode(
+ name='test_episode_1',
+ labels=[],
+ created_at=now,
+ valid_at=valid_at_1,
+ source=EpisodeType.message,
+ source_description='conversation message',
+ content='Test message 1',
+ entity_edges=[],
+ group_id=group_id,
+ )
+ episode_node_2 = EpisodicNode(
+ name='test_episode_2',
+ labels=[],
+ created_at=now,
+ valid_at=valid_at_2,
+ source=EpisodeType.message,
+ source_description='conversation message',
+ content='Test message 2',
+ entity_edges=[],
+ group_id=group_id,
+ )
+ episode_node_3 = EpisodicNode(
+ name='test_episode_3',
+ labels=[],
+ created_at=now,
+ valid_at=valid_at_3,
+ source=EpisodeType.message,
+ source_description='conversation message',
+ content='Test message 3',
+ entity_edges=[],
+ group_id=group_id,
+ )
+
+ # Save the nodes
+ await episode_node_1.save(graph_driver)
+ await episode_node_2.save(graph_driver)
+ await episode_node_3.save(graph_driver)
+
+ node_ids = [episode_node_1.uuid, episode_node_2.uuid, episode_node_3.uuid]
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 3
+
+ # Retrieve episodes
+ query_time = now - timedelta(days=3)
+ episodes = await graphiti.retrieve_episodes(
+ query_time, last_n=5, group_ids=[group_id], source=EpisodeType.message
+ )
+ assert len(episodes) == 2
+ assert episodes[0].name == episode_node_3.name
+ assert episodes[1].name == episode_node_2.name
+
+
+@pytest.mark.asyncio
+async def test_filter_existing_duplicate_of_edges(graph_driver, mock_embedder):
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+ entity_node_4 = EntityNode(
+ name='test_entity_4',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_4.generate_name_embedding(mock_embedder)
+
+ # Save the nodes
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_node_4.save(graph_driver)
+
+ node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 4
+
+ # Create duplicate entity edge
+ entity_edge = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='IS_DUPLICATE_OF',
+ fact='test_entity_1 is a duplicate of test_entity_2',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge.generate_embedding(mock_embedder)
+ await entity_edge.save(graph_driver)
+
+ # Filter duplicate entity edges
+ duplicate_node_tuples = [
+ (entity_node_1, entity_node_2),
+ (entity_node_3, entity_node_4),
+ ]
+ node_tuples = await filter_existing_duplicate_of_edges(graph_driver, duplicate_node_tuples)
+ assert len(node_tuples) == 1
+ assert [node.name for node in node_tuples[0]] == [entity_node_3.name, entity_node_4.name]
+
+
+@pytest.mark.asyncio
+async def test_determine_entity_community(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as test fails on FalkorDB')
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+ entity_node_4 = EntityNode(
+ name='test_entity_4',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_4.generate_name_embedding(mock_embedder)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_4',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+ entity_edge_2 = EntityEdge(
+ source_node_uuid=entity_node_2.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ name='RELATES_TO',
+ fact='test_entity_2 relates to test_entity_4',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_2.generate_embedding(mock_embedder)
+ entity_edge_3 = EntityEdge(
+ source_node_uuid=entity_node_3.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ name='RELATES_TO',
+ fact='test_entity_3 relates to test_entity_4',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_3.generate_embedding(mock_embedder)
+
+ # Create community nodes
+ community_node_1 = CommunityNode(
+ name='test_community_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_1.generate_name_embedding(mock_embedder)
+ community_node_2 = CommunityNode(
+ name='test_community_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_2.generate_name_embedding(mock_embedder)
+
+ # Create community to entity edges
+ community_edge_1 = CommunityEdge(
+ source_node_uuid=community_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ community_edge_2 = CommunityEdge(
+ source_node_uuid=community_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ community_edge_3 = CommunityEdge(
+ source_node_uuid=community_node_2.uuid,
+ target_node_uuid=entity_node_3.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_node_4.save(graph_driver)
+ await community_node_1.save(graph_driver)
+ await community_node_2.save(graph_driver)
+
+ await entity_edge_1.save(graph_driver)
+ await entity_edge_2.save(graph_driver)
+ await entity_edge_3.save(graph_driver)
+ await community_edge_1.save(graph_driver)
+ await community_edge_2.save(graph_driver)
+ await community_edge_3.save(graph_driver)
+
+ node_ids = [
+ entity_node_1.uuid,
+ entity_node_2.uuid,
+ entity_node_3.uuid,
+ entity_node_4.uuid,
+ community_node_1.uuid,
+ community_node_2.uuid,
+ ]
+ edge_ids = [
+ entity_edge_1.uuid,
+ entity_edge_2.uuid,
+ entity_edge_3.uuid,
+ community_edge_1.uuid,
+ community_edge_2.uuid,
+ community_edge_3.uuid,
+ ]
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 6
+ edge_count = await get_edge_count(graph_driver, edge_ids)
+ assert edge_count == 6
+
+ # Determine entity community
+ community, is_new = await determine_entity_community(graph_driver, entity_node_4)
+ assert community.name == community_node_1.name
+ assert is_new
+
+ # Add entity to community edge
+ community_edge_4 = CommunityEdge(
+ source_node_uuid=community_node_1.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_edge_4.save(graph_driver)
+
+ # Determine entity community again
+ community, is_new = await determine_entity_community(graph_driver, entity_node_4)
+ assert community.name == community_node_1.name
+ assert not is_new
+
+ await remove_communities(graph_driver)
+ node_count = await get_node_count(graph_driver, [community_node_1.uuid, community_node_2.uuid])
+ assert node_count == 0
+
+
+@pytest.mark.asyncio
+async def test_get_community_clusters(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as test fails on FalkorDB')
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id_2,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+ entity_node_4 = EntityNode(
+ name='test_entity_4',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id_2,
+ )
+ await entity_node_4.generate_name_embedding(mock_embedder)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+ entity_edge_2 = EntityEdge(
+ source_node_uuid=entity_node_3.uuid,
+ target_node_uuid=entity_node_4.uuid,
+ name='RELATES_TO',
+ fact='test_entity_3 relates to test_entity_4',
+ created_at=datetime.now(),
+ group_id=group_id_2,
+ )
+ await entity_edge_2.generate_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_node_4.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+ await entity_edge_2.save(graph_driver)
+
+ node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
+ edge_ids = [entity_edge_1.uuid, entity_edge_2.uuid]
+ node_count = await get_node_count(graph_driver, node_ids)
+ assert node_count == 4
+ edge_count = await get_edge_count(graph_driver, edge_ids)
+ assert edge_count == 2
+
+ # Get community clusters
+ clusters = await get_community_clusters(graph_driver, group_ids=None)
+ assert len(clusters) == 2
+ assert len(clusters[0]) == 2
+ assert len(clusters[1]) == 2
+    entities_1 = {node.name for node in clusters[0]}
+    entities_2 = {node.name for node in clusters[1]}
+    assert {'test_entity_1', 'test_entity_2'} in (entities_1, entities_2)
+    assert {'test_entity_3', 'test_entity_4'} in (entities_1, entities_2)
+
+
+@pytest.mark.asyncio
+async def test_get_mentioned_nodes(graph_driver, mock_embedder):
+ # Create episodic nodes
+ episodic_node_1 = EpisodicNode(
+ name='test_episodic_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ source=EpisodeType.message,
+ source_description='test_source_description',
+ content='test_content',
+ valid_at=datetime.now(),
+ )
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+
+ # Create episodic to entity edges
+ episodic_edge_1 = EpisodicEdge(
+ source_node_uuid=episodic_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+
+ # Save the graph
+ await episodic_node_1.save(graph_driver)
+ await entity_node_1.save(graph_driver)
+ await episodic_edge_1.save(graph_driver)
+
+ # Get mentioned nodes
+ mentioned_nodes = await get_mentioned_nodes(graph_driver, [episodic_node_1])
+ assert len(mentioned_nodes) == 1
+ assert mentioned_nodes[0].name == entity_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_get_communities_by_nodes(graph_driver, mock_embedder):
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+
+ # Create community nodes
+ community_node_1 = CommunityNode(
+ name='test_community_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_1.generate_name_embedding(mock_embedder)
+
+ # Create community to entity edges
+ community_edge_1 = CommunityEdge(
+ source_node_uuid=community_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await community_node_1.save(graph_driver)
+ await community_edge_1.save(graph_driver)
+
+ # Get communities by nodes
+ communities = await get_communities_by_nodes(graph_driver, [entity_node_1])
+ assert len(communities) == 1
+ assert communities[0].name == community_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_edge_fulltext_search(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+        pytest.skip('Skipping as tests fail on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ now = datetime.now()
+ created_at = now
+ expired_at = now + timedelta(days=6)
+ valid_at = now + timedelta(days=2)
+ invalid_at = now + timedelta(days=4)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=created_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ expired_at=expired_at,
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+
+ # Search for entity edges
+ search_filters = SearchFilters(
+ node_labels=['Entity'],
+ edge_types=['RELATES_TO'],
+ created_at=[
+ [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
+ ],
+ expired_at=[
+ [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
+ ],
+ valid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=1),
+ comparison_operator=ComparisonOperator.greater_than_equal,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.less_than_equal,
+ )
+ ],
+ ],
+ invalid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.greater_than,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
+ )
+ ],
+ ],
+ )
+ edges = await edge_fulltext_search(
+ graph_driver, 'test_entity_1 relates to test_entity_2', search_filters, group_ids=[group_id]
+ )
+ assert len(edges) == 1
+ assert edges[0].name == entity_edge_1.name
+
+
+@pytest.mark.asyncio
+async def test_edge_similarity_search(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+        pytest.skip('Skipping as tests fail on FalkorDB')
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ now = datetime.now()
+ created_at = now
+ expired_at = now + timedelta(days=6)
+ valid_at = now + timedelta(days=2)
+ invalid_at = now + timedelta(days=4)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=created_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ expired_at=expired_at,
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+
+ # Search for entity edges
+ search_filters = SearchFilters(
+ node_labels=['Entity'],
+ edge_types=['RELATES_TO'],
+ created_at=[
+ [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
+ ],
+ expired_at=[
+ [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
+ ],
+ valid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=1),
+ comparison_operator=ComparisonOperator.greater_than_equal,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.less_than_equal,
+ )
+ ],
+ ],
+ invalid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.greater_than,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
+ )
+ ],
+ ],
+ )
+ edges = await edge_similarity_search(
+ graph_driver,
+ entity_edge_1.fact_embedding,
+ entity_node_1.uuid,
+ entity_node_2.uuid,
+ search_filters,
+ group_ids=[group_id],
+ )
+ assert len(edges) == 1
+ assert edges[0].name == entity_edge_1.name
+
+
+@pytest.mark.asyncio
+async def test_edge_bfs_search(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ # Create episodic nodes
+ episodic_node_1 = EpisodicNode(
+ name='test_episodic_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ source=EpisodeType.message,
+ source_description='test_source_description',
+ content='test_content',
+ valid_at=datetime.now(),
+ )
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+
+ now = datetime.now()
+ created_at = now
+ expired_at = now + timedelta(days=6)
+ valid_at = now + timedelta(days=2)
+ invalid_at = now + timedelta(days=4)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=created_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ expired_at=expired_at,
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+ entity_edge_2 = EntityEdge(
+ source_node_uuid=entity_node_2.uuid,
+ target_node_uuid=entity_node_3.uuid,
+ name='RELATES_TO',
+ fact='test_entity_2 relates to test_entity_3',
+ created_at=created_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ expired_at=expired_at,
+ group_id=group_id,
+ )
+ await entity_edge_2.generate_embedding(mock_embedder)
+
+ # Create episodic to entity edges
+ episodic_edge_1 = EpisodicEdge(
+ source_node_uuid=episodic_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+
+ # Save the graph
+ await episodic_node_1.save(graph_driver)
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+ await entity_edge_2.save(graph_driver)
+ await episodic_edge_1.save(graph_driver)
+
+ # Search for entity edges
+ search_filters = SearchFilters(
+ node_labels=['Entity'],
+ edge_types=['RELATES_TO'],
+ created_at=[
+ [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
+ ],
+ expired_at=[
+ [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
+ ],
+ valid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=1),
+ comparison_operator=ComparisonOperator.greater_than_equal,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.less_than_equal,
+ )
+ ],
+ ],
+ invalid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.greater_than,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
+ )
+ ],
+ ],
+ )
+
+ # Test bfs from episodic node
+
+ edges = await edge_bfs_search(
+ graph_driver,
+ [episodic_node_1.uuid],
+ 1,
+ search_filters,
+ group_ids=[group_id],
+ )
+ assert len(edges) == 0
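+ # At depth 1 the only hop from the episode is the episodic edge to
+ # entity_node_1, which is not an entity edge, so nothing matches yet.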
+
+ edges = await edge_bfs_search(
+ graph_driver,
+ [episodic_node_1.uuid],
+ 2,
+ search_filters,
+ group_ids=[group_id],
+ )
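+ # BFS can surface the same edge along multiple paths, so dedupe by uuid
+ # before comparing the set of facts.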
+ edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
+ assert len(edges_deduplicated) == 1
+ assert edges_deduplicated == {'test_entity_1 relates to test_entity_2'}
+
+ edges = await edge_bfs_search(
+ graph_driver,
+ [episodic_node_1.uuid],
+ 3,
+ search_filters,
+ group_ids=[group_id],
+ )
+ edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
+ assert len(edges_deduplicated) == 2
+ assert edges_deduplicated == {
+ 'test_entity_1 relates to test_entity_2',
+ 'test_entity_2 relates to test_entity_3',
+ }
+
+ # Test bfs from entity node
+
+ edges = await edge_bfs_search(
+ graph_driver,
+ [entity_node_1.uuid],
+ 1,
+ search_filters,
+ group_ids=[group_id],
+ )
+ edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
+ assert len(edges_deduplicated) == 1
+ assert edges_deduplicated == {'test_entity_1 relates to test_entity_2'}
+
+ edges = await edge_bfs_search(
+ graph_driver,
+ [entity_node_1.uuid],
+ 2,
+ search_filters,
+ group_ids=[group_id],
+ )
+ edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
+ assert len(edges_deduplicated) == 2
+ assert edges_deduplicated == {
+ 'test_entity_1 relates to test_entity_2',
+ 'test_entity_2 relates to test_entity_3',
+ }
+
+
+@pytest.mark.asyncio
+async def test_node_fulltext_search(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ summary='Summary about Alice',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ summary='Summary about Bob',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+
+ # Search for entity nodes
+ search_filters = SearchFilters(node_labels=['Entity'])
+ nodes = await node_fulltext_search(
+ graph_driver,
+ 'Alice',
+ search_filters,
+ group_ids=[group_id],
+ )
+ assert len(nodes) == 1
+ assert nodes[0].name == entity_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_node_similarity_search(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_alice',
+ summary='Summary about Alice',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_bob',
+ summary='Summary about Bob',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+
+ # Search for entity nodes
+ search_filters = SearchFilters(node_labels=['Entity'])
+ nodes = await node_similarity_search(
+ graph_driver,
+ entity_node_1.name_embedding,
+ search_filters,
+ group_ids=[group_id],
+ min_score=0.9,
+ )
+ assert len(nodes) == 1
+ assert nodes[0].name == entity_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_node_bfs_search(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ # Create episodic nodes
+ episodic_node_1 = EpisodicNode(
+ name='test_episodic_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ source=EpisodeType.message,
+ source_description='test_source_description',
+ content='test_content',
+ valid_at=datetime.now(),
+ )
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+ entity_edge_2 = EntityEdge(
+ source_node_uuid=entity_node_2.uuid,
+ target_node_uuid=entity_node_3.uuid,
+ name='RELATES_TO',
+ fact='test_entity_2 relates to test_entity_3',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_2.generate_embedding(mock_embedder)
+
+ # Create episodic to entity edges
+ episodic_edge_1 = EpisodicEdge(
+ source_node_uuid=episodic_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+
+ # Save the graph
+ await episodic_node_1.save(graph_driver)
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+ await entity_edge_2.save(graph_driver)
+ await episodic_edge_1.save(graph_driver)
+
+ # Search for entity nodes
+ search_filters = SearchFilters(
+ node_labels=['Entity'],
+ )
+
+ # Test bfs from episodic node
+
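+ # Depth 1 from the episode reaches only entity_node_1 via the episodic edge;
+ # depth 2 additionally reaches entity_node_2 over entity_edge_1.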
+ nodes = await node_bfs_search(
+ graph_driver,
+ [episodic_node_1.uuid],
+ search_filters,
+ 1,
+ group_ids=[group_id],
+ )
+ nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
+ assert len(nodes_deduplicated) == 1
+ assert nodes_deduplicated == {'test_entity_1'}
+
+ nodes = await node_bfs_search(
+ graph_driver,
+ [episodic_node_1.uuid],
+ search_filters,
+ 2,
+ group_ids=[group_id],
+ )
+ nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
+ assert len(nodes_deduplicated) == 2
+ assert nodes_deduplicated == {'test_entity_1', 'test_entity_2'}
+
+ # Test bfs from entity node
+
+ nodes = await node_bfs_search(
+ graph_driver,
+ [entity_node_1.uuid],
+ search_filters,
+ 1,
+ group_ids=[group_id],
+ )
+ nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
+ assert len(nodes_deduplicated) == 1
+ assert nodes_deduplicated == {'test_entity_2'}
+
+
+@pytest.mark.asyncio
+async def test_episode_fulltext_search(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create episodic nodes
+ episodic_node_1 = EpisodicNode(
+ name='test_episodic_1',
+ content='test_content',
+ created_at=datetime.now(),
+ valid_at=datetime.now(),
+ group_id=group_id,
+ source=EpisodeType.message,
+ source_description='Description about Alice',
+ )
+ episodic_node_2 = EpisodicNode(
+ name='test_episodic_2',
+ content='test_content_2',
+ created_at=datetime.now(),
+ valid_at=datetime.now(),
+ group_id=group_id,
+ source=EpisodeType.message,
+ source_description='Description about Bob',
+ )
+
+ # Save the graph
+ await episodic_node_1.save(graph_driver)
+ await episodic_node_2.save(graph_driver)
+
+ # Search for episodic nodes
+ search_filters = SearchFilters(node_labels=['Episodic'])
+ nodes = await episode_fulltext_search(
+ graph_driver,
+ 'Alice',
+ search_filters,
+ group_ids=[group_id],
+ )
+ assert len(nodes) == 1
+ assert nodes[0].name == episodic_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_community_fulltext_search(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create community nodes
+ community_node_1 = CommunityNode(
+ name='Alice',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_1.generate_name_embedding(mock_embedder)
+ community_node_2 = CommunityNode(
+ name='Bob',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_2.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await community_node_1.save(graph_driver)
+ await community_node_2.save(graph_driver)
+
+ # Search for community nodes
+ nodes = await community_fulltext_search(
+ graph_driver,
+ 'Alice',
+ group_ids=[group_id],
+ )
+ assert len(nodes) == 1
+ assert nodes[0].name == community_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_community_similarity_search(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create community nodes
+ community_node_1 = CommunityNode(
+ name='Alice',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_1.generate_name_embedding(mock_embedder)
+ community_node_2 = CommunityNode(
+ name='Bob',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_2.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await community_node_1.save(graph_driver)
+ await community_node_2.save(graph_driver)
+
+ # Search for community nodes
+ nodes = await community_similarity_search(
+ graph_driver,
+ community_node_1.name_embedding,
+ group_ids=[group_id],
+ min_score=0.9,
+ )
+ assert len(nodes) == 1
+ assert nodes[0].name == community_node_1.name
+
+
+@pytest.mark.asyncio
+async def test_get_relevant_nodes(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ if graph_driver.provider == GraphProvider.KUZU:
+ pytest.skip('Skipping as tests fail on Kuzu')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='Alice',
+ summary='Alice',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='Bob',
+ summary='Bob',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='Alice Smith',
+ summary='Alice Smith',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+
+ # Search for entity nodes
+ search_filters = SearchFilters(node_labels=['Entity'])
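+ # Under the mock embedder, both Alice nodes should clear the 0.9 min_score
+ # against entity_node_1, while Bob should not.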
+ nodes = (
+ await get_relevant_nodes(
+ graph_driver,
+ [entity_node_1],
+ search_filters,
+ min_score=0.9,
+ )
+ )[0]
+ assert len(nodes) == 2
+ assert {node.name for node in nodes} == {entity_node_1.name, entity_node_3.name}
+
+
+@pytest.mark.asyncio
+async def test_get_relevant_edges_and_invalidation_candidates(
+ graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
+):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ graphiti = Graphiti(
+ graph_driver=graph_driver,
+ llm_client=mock_llm_client,
+ embedder=mock_embedder,
+ cross_encoder=mock_cross_encoder_client,
+ )
+ await graphiti.build_indices_and_constraints()
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ summary='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ summary='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ summary='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+
+ now = datetime.now()
+ created_at = now
+ expired_at = now + timedelta(days=6)
+ valid_at = now + timedelta(days=2)
+ invalid_at = now + timedelta(days=4)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='Alice',
+ created_at=created_at,
+ expired_at=expired_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+ entity_edge_2 = EntityEdge(
+ source_node_uuid=entity_node_2.uuid,
+ target_node_uuid=entity_node_3.uuid,
+ name='RELATES_TO',
+ fact='Bob',
+ created_at=created_at,
+ expired_at=expired_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ group_id=group_id,
+ )
+ await entity_edge_2.generate_embedding(mock_embedder)
+ entity_edge_3 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_3.uuid,
+ name='RELATES_TO',
+ fact='Alice',
+ created_at=created_at,
+ expired_at=expired_at,
+ valid_at=valid_at,
+ invalid_at=invalid_at,
+ group_id=group_id,
+ )
+ await entity_edge_3.generate_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+ await entity_edge_2.save(graph_driver)
+ await entity_edge_3.save(graph_driver)
+
+ # Search for entity edges
+ search_filters = SearchFilters(
+ node_labels=['Entity'],
+ edge_types=['RELATES_TO'],
+ created_at=[
+ [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
+ ],
+ expired_at=[
+ [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
+ ],
+ valid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=1),
+ comparison_operator=ComparisonOperator.greater_than_equal,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.less_than_equal,
+ )
+ ],
+ ],
+ invalid_at=[
+ [
+ DateFilter(
+ date=now + timedelta(days=3),
+ comparison_operator=ComparisonOperator.greater_than,
+ )
+ ],
+ [
+ DateFilter(
+ date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
+ )
+ ],
+ ],
+ )
+ edges = (
+ await get_relevant_edges(
+ graph_driver,
+ [entity_edge_1],
+ search_filters,
+ min_score=0.9,
+ )
+ )[0]
+ assert len(edges) == 1
+ assert {edge.name for edge in edges} == {entity_edge_1.name}
+
+ edges = (
+ await get_edge_invalidation_candidates(
+ graph_driver,
+ [entity_edge_1],
+ search_filters,
+ min_score=0.9,
+ )
+ )[0]
+ assert len(edges) == 2
+ assert {edge.name for edge in edges} == {entity_edge_1.name, entity_edge_3.name}
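+ # Only entity_edge_1 is relevant (same endpoints), while the invalidation
+ # candidates also include entity_edge_3, which shares an endpoint and an
+ # identical 'Alice' fact; entity_edge_2's 'Bob' fact presumably falls below
+ # the 0.9 min_score.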
+
+
+@pytest.mark.asyncio
+async def test_node_distance_reranker(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+ entity_node_3 = EntityNode(
+ name='test_entity_3',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_3.generate_name_embedding(mock_embedder)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_node_3.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+
+ # Test reranker
+ reranked_uuids, reranked_scores = await node_distance_reranker(
+ graph_driver,
+ [entity_node_2.uuid, entity_node_3.uuid],
+ entity_node_1.uuid,
+ )
+ uuid_to_name = {
+ entity_node_1.uuid: entity_node_1.name,
+ entity_node_2.uuid: entity_node_2.name,
+ entity_node_3.uuid: entity_node_3.name,
+ }
+ names = [uuid_to_name[uuid] for uuid in reranked_uuids]
+ assert names == [entity_node_2.name, entity_node_3.name]
+ assert np.allclose(reranked_scores, [1.0, 0.0])
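+ # entity_node_2 sits one hop from the center node and scores 1.0;
+ # entity_node_3 is unreachable from it and scores 0.0.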
+
+
+@pytest.mark.asyncio
+async def test_episode_mentions_reranker(graph_driver, mock_embedder):
+ if graph_driver.provider == GraphProvider.FALKORDB:
+ pytest.skip('Skipping as tests fail on FalkorDB')
+
+ # Create episodic nodes
+ episodic_node_1 = EpisodicNode(
+ name='test_episodic_1',
+ content='test_content',
+ created_at=datetime.now(),
+ valid_at=datetime.now(),
+ group_id=group_id,
+ source=EpisodeType.message,
+ source_description='Description about Alice',
+ )
+
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ # Create episodic to entity edges
+ episodic_edge_1 = EpisodicEdge(
+ source_node_uuid=episodic_node_1.uuid,
+ target_node_uuid=entity_node_1.uuid,
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await episodic_node_1.save(graph_driver)
+ await episodic_edge_1.save(graph_driver)
+
+ # Test reranker
+ reranked_uuids, reranked_scores = await episode_mentions_reranker(
+ graph_driver,
+ [[entity_node_1.uuid, entity_node_2.uuid]],
+ )
+ uuid_to_name = {entity_node_1.uuid: entity_node_1.name, entity_node_2.uuid: entity_node_2.name}
+ names = [uuid_to_name[uuid] for uuid in reranked_uuids]
+ assert names == [entity_node_1.name, entity_node_2.name]
+ assert np.allclose(reranked_scores, [1.0, float('inf')])
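+ # entity_node_1 is mentioned by one episode and scores 1.0; entity_node_2
+ # has no mentions, for which the reranker evidently returns an inf sentinel.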
+
+
+@pytest.mark.asyncio
+async def test_get_embeddings_for_edges(graph_driver, mock_embedder):
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+ entity_node_2 = EntityNode(
+ name='test_entity_2',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_2.generate_name_embedding(mock_embedder)
+
+ # Create entity edges
+ entity_edge_1 = EntityEdge(
+ source_node_uuid=entity_node_1.uuid,
+ target_node_uuid=entity_node_2.uuid,
+ name='RELATES_TO',
+ fact='test_entity_1 relates to test_entity_2',
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_edge_1.generate_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+ await entity_node_2.save(graph_driver)
+ await entity_edge_1.save(graph_driver)
+
+ # Get embeddings for edges
+ embeddings = await get_embeddings_for_edges(graph_driver, [entity_edge_1])
+ assert len(embeddings) == 1
+ assert entity_edge_1.uuid in embeddings
+ assert np.allclose(embeddings[entity_edge_1.uuid], entity_edge_1.fact_embedding)
+
+
+@pytest.mark.asyncio
+async def test_get_embeddings_for_nodes(graph_driver, mock_embedder):
+ # Create entity nodes
+ entity_node_1 = EntityNode(
+ name='test_entity_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await entity_node_1.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await entity_node_1.save(graph_driver)
+
+ # Get embeddings for nodes
+ embeddings = await get_embeddings_for_nodes(graph_driver, [entity_node_1])
+ assert len(embeddings) == 1
+ assert entity_node_1.uuid in embeddings
+ assert np.allclose(embeddings[entity_node_1.uuid], entity_node_1.name_embedding)
+
+
+@pytest.mark.asyncio
+async def test_get_embeddings_for_communities(graph_driver, mock_embedder):
+ # Create community nodes
+ community_node_1 = CommunityNode(
+ name='test_community_1',
+ labels=[],
+ created_at=datetime.now(),
+ group_id=group_id,
+ )
+ await community_node_1.generate_name_embedding(mock_embedder)
+
+ # Save the graph
+ await community_node_1.save(graph_driver)
+
+ # Get embeddings for communities
+ embeddings = await get_embeddings_for_communities(graph_driver, [community_node_1])
+ assert len(embeddings) == 1
+ assert community_node_1.uuid in embeddings
+ assert np.allclose(embeddings[community_node_1.uuid], community_node_1.name_embedding)
diff --git a/tests/test_node_int.py b/tests/test_node_int.py
index b4aa0709..edaa017a 100644
--- a/tests/test_node_int.py
+++ b/tests/test_node_int.py
@@ -14,22 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
-from datetime import datetime
+from datetime import datetime, timedelta
from uuid import uuid4
-import numpy as np
import pytest
-from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
-from tests.helpers_test import drivers, get_driver
+from tests.helpers_test import (
+ assert_community_node_equals,
+ assert_entity_node_equals,
+ assert_episodic_node_equals,
+ get_node_count,
+ group_id,
+)
-group_id = f'test_group_{str(uuid4())}'
+created_at = datetime.now()
+deleted_at = created_at + timedelta(days=3)
+valid_at = created_at + timedelta(days=1)
+invalid_at = created_at + timedelta(days=2)
@pytest.fixture
@@ -38,9 +45,14 @@ def sample_entity_node():
uuid=str(uuid4()),
name='Test Entity',
group_id=group_id,
- labels=[],
+ labels=['Entity', 'Person'],
+ created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Entity Summary',
+ attributes={
+ 'age': 30,
+ 'location': 'New York',
+ },
)
@@ -50,10 +62,12 @@ def sample_episodic_node():
uuid=str(uuid4()),
name='Episode 1',
group_id=group_id,
+ created_at=created_at,
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
- valid_at=datetime.now(),
+ valid_at=valid_at,
+ entity_edges=[],
)
@@ -62,182 +76,152 @@ def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
- name_embedding=[0.5] * 1024,
group_id=group_id,
+ created_at=created_at,
+ name_embedding=[0.5] * 1024,
summary='Community summary',
)
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_entity_node(sample_entity_node, driver):
- driver = get_driver(driver)
+async def test_entity_node(sample_entity_node, graph_driver):
uuid = sample_entity_node.uuid
# Create node
- node_count = await get_node_count(driver, uuid)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
- await sample_entity_node.save(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_entity_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
- retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
- assert retrieved.uuid == sample_entity_node.uuid
- assert retrieved.name == 'Test Entity'
- assert retrieved.group_id == group_id
+ # Get node by uuid
+ retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
+ await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)
- retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
- assert retrieved[0].uuid == sample_entity_node.uuid
- assert retrieved[0].name == 'Test Entity'
- assert retrieved[0].group_id == group_id
+ # Get node by uuids
+ retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
+ await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
- retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
+ # Get node by group ids
+ retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True)
assert len(retrieved) == 1
- assert retrieved[0].uuid == sample_entity_node.uuid
- assert retrieved[0].name == 'Test Entity'
- assert retrieved[0].group_id == group_id
-
- await sample_entity_node.load_name_embedding(driver)
- assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)
+ await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
# Delete node by uuid
- await sample_entity_node.delete(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_entity_node.delete(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
+ assert node_count == 0
+
+ # Delete node by uuids
+ await sample_entity_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
+ assert node_count == 1
+ await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
- await sample_entity_node.save(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_entity_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
- await sample_entity_node.delete_by_group_id(driver, group_id)
- node_count = await get_node_count(driver, uuid)
+ await sample_entity_node.delete_by_group_id(graph_driver, group_id)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
- await driver.close()
+ await graph_driver.close()
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_community_node(sample_community_node, driver):
- driver = get_driver(driver)
+async def test_community_node(sample_community_node, graph_driver):
uuid = sample_community_node.uuid
# Create node
- node_count = await get_node_count(driver, uuid)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
- await sample_community_node.save(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_community_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
- retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
- assert retrieved.uuid == sample_community_node.uuid
- assert retrieved.name == 'Community A'
- assert retrieved.group_id == group_id
- assert retrieved.summary == 'Community summary'
+ # Get node by uuid
+ retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
+ await assert_community_node_equals(graph_driver, retrieved, sample_community_node)
- retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
- assert retrieved[0].uuid == sample_community_node.uuid
- assert retrieved[0].name == 'Community A'
- assert retrieved[0].group_id == group_id
- assert retrieved[0].summary == 'Community summary'
+ # Get node by uuids
+ retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
+ await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
- retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
+ # Get node by group ids
+ retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
- assert retrieved[0].uuid == sample_community_node.uuid
- assert retrieved[0].name == 'Community A'
- assert retrieved[0].group_id == group_id
+ await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
# Delete node by uuid
- await sample_community_node.delete(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_community_node.delete(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
+ assert node_count == 0
+
+ # Delete node by uuids
+ await sample_community_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
+ assert node_count == 1
+ await sample_community_node.delete_by_uuids(graph_driver, [uuid])
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
- await sample_community_node.save(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_community_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
- await sample_community_node.delete_by_group_id(driver, group_id)
- node_count = await get_node_count(driver, uuid)
+ await sample_community_node.delete_by_group_id(graph_driver, group_id)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
- await driver.close()
+ await graph_driver.close()
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- 'driver',
- drivers,
- ids=drivers,
-)
-async def test_episodic_node(sample_episodic_node, driver):
- driver = get_driver(driver)
+async def test_episodic_node(sample_episodic_node, graph_driver):
uuid = sample_episodic_node.uuid
# Create node
- node_count = await get_node_count(driver, uuid)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
- await sample_episodic_node.save(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_episodic_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
- retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
- assert retrieved.uuid == sample_episodic_node.uuid
- assert retrieved.name == 'Episode 1'
- assert retrieved.group_id == group_id
- assert retrieved.source == EpisodeType.text
- assert retrieved.source_description == 'Test source'
- assert retrieved.content == 'Some content here'
- assert retrieved.valid_at == sample_episodic_node.valid_at
+ # Get node by uuid
+ retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
+ await assert_episodic_node_equals(retrieved, sample_episodic_node)
- retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
- assert retrieved[0].uuid == sample_episodic_node.uuid
- assert retrieved[0].name == 'Episode 1'
- assert retrieved[0].group_id == group_id
- assert retrieved[0].source == EpisodeType.text
- assert retrieved[0].source_description == 'Test source'
- assert retrieved[0].content == 'Some content here'
- assert retrieved[0].valid_at == sample_episodic_node.valid_at
+ # Get node by uuids
+ retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
+ await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
- retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
+ # Get node by group ids
+ retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
- assert retrieved[0].uuid == sample_episodic_node.uuid
- assert retrieved[0].name == 'Episode 1'
- assert retrieved[0].group_id == group_id
- assert retrieved[0].source == EpisodeType.text
- assert retrieved[0].source_description == 'Test source'
- assert retrieved[0].content == 'Some content here'
- assert retrieved[0].valid_at == sample_episodic_node.valid_at
+ await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
# Delete node by uuid
- await sample_episodic_node.delete(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_episodic_node.delete(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
+ assert node_count == 0
+
+ # Delete node by uuids
+ await sample_episodic_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
+ assert node_count == 1
+ await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
- await sample_episodic_node.save(driver)
- node_count = await get_node_count(driver, uuid)
+ await sample_episodic_node.save(graph_driver)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
- await sample_episodic_node.delete_by_group_id(driver, group_id)
- node_count = await get_node_count(driver, uuid)
+ await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
+ node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
- await driver.close()
-
-
-async def get_node_count(driver: GraphDriver, uuid: str):
- result, _, _ = await driver.execute_query(
- """
- MATCH (n {uuid: $uuid})
- RETURN COUNT(n) as count
- """,
- uuid=uuid,
- )
- return int(result[0]['count'])
+ await graph_driver.close()
diff --git a/uv.lock b/uv.lock
index fecce3bc..6a731d48 100644
--- a/uv.lock
+++ b/uv.lock
@@ -809,6 +809,7 @@ dev = [
{ name = "groq" },
{ name = "ipykernel" },
{ name = "jupyterlab" },
+ { name = "kuzu" },
{ name = "langchain-anthropic" },
{ name = "langchain-openai" },
{ name = "langgraph" },
@@ -831,6 +832,9 @@ google-genai = [
groq = [
{ name = "groq" },
]
+kuzu = [
+ { name = "kuzu" },
+]
neptune = [
{ name = "boto3" },
{ name = "langchain-aws" },
@@ -858,6 +862,8 @@ requires-dist = [
{ name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" },
{ name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" },
{ name = "jupyterlab", marker = "extra == 'dev'", specifier = ">=4.2.4" },
+ { name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
+ { name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
@@ -882,7 +888,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
]
-provides-extras = ["anthropic", "groq", "google-genai", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
+provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
[[package]]
name = "groq"
@@ -1387,6 +1393,40 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700, upload-time = "2024-07-16T17:02:01.115Z" },
]
+[[package]]
+name = "kuzu"
+version = "0.11.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/66/fd/adbd05ccf81e6ad2674fcd3849d5d6ffeaf2141a9b8d1c1c4e282e923e1f/kuzu-0.11.2.tar.gz", hash = "sha256:9f224ec218ab165a18acaea903695779780d70335baf402d9b7f59ba389db0bd", size = 4902887, upload-time = "2025-08-21T05:17:00.152Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0e/91/bed837f5f49220a9f869da8a078b34a3484f210f7b57b267177821545c03/kuzu-0.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b25174cdb721aae47896ed62842d3859679607b493a9a6bbbcd9fb7fb3707", size = 3702618, upload-time = "2025-08-21T05:15:53.726Z" },
+ { url = "https://files.pythonhosted.org/packages/72/8a/fd5e053b0055718afe00b6a99393a835c6254354128fbb7f66a35fd76089/kuzu-0.11.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:9a8567c53bfe282f4727782471ff718842ffead8c48c1762c1df9197408fc986", size = 4101371, upload-time = "2025-08-21T05:15:55.889Z" },
+ { url = "https://files.pythonhosted.org/packages/ad/4b/e45cadc85bdc5079f432675bbe8d557600f0d4ab46fe24ef218374419902/kuzu-0.11.2-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d793bb5a0a14ada730a697eccac2a4c68b434b82692d985942900ef2003e099e", size = 6211974, upload-time = "2025-08-21T05:15:57.505Z" },
+ { url = "https://files.pythonhosted.org/packages/10/ca/92d6a1e6452fcf06bfc423ce2cde819ace6b6e47921921cc8fae87c27780/kuzu-0.11.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1be4e9b6c93ca8591b1fb165f9b9a27d70a56af061831afcdfe7aebb89ee6ff", size = 6992196, upload-time = "2025-08-21T05:15:59.006Z" },
+ { url = "https://files.pythonhosted.org/packages/49/6c/983fc6265dfc1169c87c4a0722f36ee665c5688e1166faeb4cd85e6af078/kuzu-0.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0ec7a304c746a2a98ecfd7e7c3f6fe92c4dfee2e2565c0b7cb4cffd0c2e374a", size = 4303517, upload-time = "2025-08-21T05:16:00.814Z" },
+ { url = "https://files.pythonhosted.org/packages/b5/14/8ae2c52657b93715052ecf47d70232f2c8d9ffe2d1ec3527c8e9c3cb2df5/kuzu-0.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf53b4f321a4c05882b14cef96d39a1e90fa993bab88a1554fb1565367553b8c", size = 3704177, upload-time = "2025-08-21T05:16:02.354Z" },
+ { url = "https://files.pythonhosted.org/packages/2d/7a/bce7bb755e16f9ca855f76a3acc6cfa9fae88c4d6af9df3784c50b2120a5/kuzu-0.11.2-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:2d749883b74f5da5ff4a4b0635a98f6cc5165743995828924321d2ca797317cb", size = 4102372, upload-time = "2025-08-21T05:16:04.249Z" },
+ { url = "https://files.pythonhosted.org/packages/c8/12/f5b1d51fcb78a86c078fb85cc53184ce962a3e86852d47d30e287a932e3f/kuzu-0.11.2-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:632507e5982928ed24fbb5e70ad143d7970bc4059046e77e0522707efbad303b", size = 6212492, upload-time = "2025-08-21T05:16:05.99Z" },
+ { url = "https://files.pythonhosted.org/packages/81/96/d6e57af6ccf9e0697812ad3039c80b87b768cf2674833b0b23d317ea3427/kuzu-0.11.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5211884601f8f08ae97ba25006d0facde24077c5333411d944282b8a2068ab4", size = 6992888, upload-time = "2025-08-21T05:16:07.896Z" },
+ { url = "https://files.pythonhosted.org/packages/40/ee/1f275ac5679a3f615ce0d9cf8c79001fdb535ccc8bc344e49b14484c7cd7/kuzu-0.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:82a6c8bfe1278dc1010790e398bf772683797ef5c16052fa0f6f78bacbc59aa3", size = 4304064, upload-time = "2025-08-21T05:16:10.163Z" },
+ { url = "https://files.pythonhosted.org/packages/73/ba/9f20d9e83681a0ddae8ec13046b116c34745fa0e66862d4e2d8414734ce0/kuzu-0.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aed88ffa695d07289a3d8557bd8f9e743298a4f4349208a60bbb06f4ebf15c26", size = 3703781, upload-time = "2025-08-21T05:16:12.232Z" },
+ { url = "https://files.pythonhosted.org/packages/53/a0/bb815c0490f3d4d30389156369b9fe641e154f0d4b1e8340f09a76021922/kuzu-0.11.2-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:595824b03248af928e3faee57f6825d3a46920f2d3b9bd0c0bb7fc3fa097fce9", size = 4103990, upload-time = "2025-08-21T05:16:14.139Z" },
+ { url = "https://files.pythonhosted.org/packages/a5/6f/97b647c0547a634a669055ff4cfd21a92ea3999aedc6a7fe9004f03f25e3/kuzu-0.11.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5674c6d9d26f5caa0c7ce6f34c02e4411894879aa5b2ce174fad576fa898523", size = 6211947, upload-time = "2025-08-21T05:16:16.48Z" },
+ { url = "https://files.pythonhosted.org/packages/42/74/c7f1a1cfb08c05c91c5a94483be387e80fafab8923c4243a22e9cced5c1b/kuzu-0.11.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c61daf02da35b671f4c6f3c17105725c399a5e14b7349b00eafbcd24ac90034a", size = 6991879, upload-time = "2025-08-21T05:16:18.402Z" },
+ { url = "https://files.pythonhosted.org/packages/54/9e/50d67d7bc08faed95ede6de1a6aa0d81079c98028ca99e32d09c2ab1aead/kuzu-0.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:682096cd87dcbb8257f933ea4172d9dc5617a8d0a5bdd19cd66cf05b68881afd", size = 4305706, upload-time = "2025-08-21T05:16:20.244Z" },
+ { url = "https://files.pythonhosted.org/packages/65/f0/5649a01af37def50293cd7c194afc19f09b343fd2b7f2b28e021a207f8ce/kuzu-0.11.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:17a11b67652e8b331c85cd1a1a30b32ee6783750084473abbab2aa1963ee2a3b", size = 3703740, upload-time = "2025-08-21T05:16:21.896Z" },
+ { url = "https://files.pythonhosted.org/packages/24/e2/e0beb9080911fc1689899a42da0f83534949f43169fb80197def3ec1223f/kuzu-0.11.2-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:bdded35426210faeca8da11e8b4a54e60ccc0c1a832660d76587b5be133b0f55", size = 4104073, upload-time = "2025-08-21T05:16:23.819Z" },
+ { url = "https://files.pythonhosted.org/packages/f2/4c/7a831c9c6e609692953db677f54788bd1dde4c9d34e6ba91f1e153d2e7fe/kuzu-0.11.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6116b609aac153f3523130b31295643d34a6c9509914c0fa9d804b26b23eee73", size = 6212263, upload-time = "2025-08-21T05:16:25.351Z" },
+ { url = "https://files.pythonhosted.org/packages/47/95/615ef10b46b22ec1d33fdbba795e6e79733d9a244aabdeeb910f267ab36c/kuzu-0.11.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:09da5b8cb24dc6b281a6e4ac0f7f24226eb9909803b187e02d014da13ba57bcf", size = 6992492, upload-time = "2025-08-21T05:16:27.518Z" },
+ { url = "https://files.pythonhosted.org/packages/a7/dd/2c905575913c743e6c67a5ca89a6b4ea9d9737238966d85d7e710f0d3e60/kuzu-0.11.2-cp313-cp313-win_amd64.whl", hash = "sha256:c663fb84682f8ebffbe7447a4e552a0e03bd29097d319084a2c53c2e032a780e", size = 4305267, upload-time = "2025-08-21T05:16:29.307Z" },
+ { url = "https://files.pythonhosted.org/packages/89/05/44fbfc9055dba3f472ea4aaa8110635864d3441eede987526ef401680765/kuzu-0.11.2-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5c03fb95ffb9185c1519333f8ee92b7a9695aa7aa9a179e868a7d7bd13d10a16", size = 6216795, upload-time = "2025-08-21T05:16:30.944Z" },
+ { url = "https://files.pythonhosted.org/packages/4f/ca/16c81dc68cc1e8918f8481e7ee89c28aa665c5cb36be7ad0fc1d0d295760/kuzu-0.11.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d857f0efddf26d5e2dc189facb84bf04a997e395972486669b418a470cc76034", size = 6996333, upload-time = "2025-08-21T05:16:32.568Z" },
+ { url = "https://files.pythonhosted.org/packages/48/d8/9275c7e6312bd76dc670e8e2da68639757c22cf2c366e96527595a1d881c/kuzu-0.11.2-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb9e4641867c35b98ceaa604aa79832c0eeed41f5fd1b6da22b1c217b2f1b8ea", size = 6212202, upload-time = "2025-08-21T05:16:34.571Z" },
+ { url = "https://files.pythonhosted.org/packages/88/89/67a977493c60bca3610845df13020711f357a5d80bf91549e4b48d877c2f/kuzu-0.11.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:553408d76a0b4fdecc1338b69b71d7bde42f6936d3b99d9852b30d33bda15978", size = 6992264, upload-time = "2025-08-21T05:16:36.316Z" },
+ { url = "https://files.pythonhosted.org/packages/b6/49/869ceceb1d8a5ea032a35c734e55cfee919340889973623096da7eb94f6b/kuzu-0.11.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:989a87fa13ffa39ab7773d968fe739ac4f8faf9ddb5dad72ced2eeef12180293", size = 6216814, upload-time = "2025-08-21T05:16:38.348Z" },
+ { url = "https://files.pythonhosted.org/packages/bc/cd/933b34a246edb882a042eb402747167719222c05149b73b48ba7d310d554/kuzu-0.11.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e67420d04a9643fd6376a23b17b398a3e32bb0c2bd8abbf8d1e4697056596c7e", size = 6996343, upload-time = "2025-08-21T05:16:39.973Z" },
+]
+
[[package]]
name = "langchain-anthropic"
version = "0.3.17"