From dcc9da3f6887b84830758d7e89974eb4f2af8f92 Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Tue, 29 Jul 2025 06:07:34 -0700
Subject: [PATCH] chore/prepare kuzu integration (#762)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Prepare code

* Fix tests

* As -> AS, remove trailing spaces

* Enable more tests for FalkorDB

* Fix more Cypher queries

* Return all created nodes and edges

* Add Neo4j service to unit tests workflow

  - Introduced Neo4j as a service in the GitHub Actions workflow for unit tests.
  - Configured Neo4j with appropriate ports, authentication, and health checks.
  - Updated test steps to wait for Neo4j and run integration tests against it.
  - Set environment variables for the Neo4j connection in both the non-integration and integration test steps.

* Update Neo4j authentication in unit tests workflow

  - Changed the Neo4j authentication password from 'test' to 'testpass' in the GitHub Actions workflow.
  - Updated the health check command to reflect the new password.
  - Ensured consistency across all test steps that use Neo4j credentials.

* Fix health check

* Fix Neo4j integration tests in CI workflow

  Remove the reference to the non-existent test_neo4j_driver.py file from the test command. Integration tests now run via parametrized tests using the drivers list.

  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude

* Add OPENAI_API_KEY to Neo4j integration tests

  Neo4j integration tests require OpenAI API access for LLM functionality. Add the secret environment variable so these tests can run properly.

  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude

* Fix Neo4j Cypher syntax error in BFS search queries

  Replace parameter substitution in relationship pattern ranges (*1..$depth) with direct string interpolation (*1..{bfs_max_depth}). Neo4j does not allow parameters in MATCH pattern ranges - they must be literal values. Fixed in both node_bfs_search and edge_bfs_search (see the sketch below).

  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude

* Fix variable name mismatch in edge_bfs_search query

  Change the relationship variable from 'r' to 'e' to match the ENTITY_EDGE_RETURN constant. ENTITY_EDGE_RETURN references variable 'e' for relationships, but the query used 'r', causing "Variable e not defined" errors.

  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude

* Isolate database tests in CI workflow

  - FalkorDB tests: add DISABLE_NEO4J=1 and remove the Neo4j env vars.
  - Neo4j tests: keep the current setup without the DISABLE_NEO4J flag.

  This ensures proper test isolation, where each test suite only runs against its intended database backend.
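For context on the BFS fix above, a minimal sketch of the pattern under discussion (illustrative only; bfs_max_depth = 3, origin_uuids, and the exact query shape are assumptions, not the project's actual code). The traversal depth is spliced into the query text with an f-string, since Neo4j rejects query parameters inside variable-length pattern bounds, while everything else remains a regular query parameter:

    # Depth must be a literal in the pattern bound; Neo4j rejects *1..$depth here.
    bfs_max_depth = 3  # assumed example value
    query = f"""
    UNWIND $origin_uuids AS origin_uuid
    MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{bfs_max_depth}]->(n:Entity)
    RETURN n.uuid AS uuid
    """
    # Usage (sketch): records, _, _ = await driver.execute_query(query, origin_uuids=origin_uuids, routing_='r')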
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --------- Co-authored-by: Siddhartha Sahu Co-authored-by: Claude --- .github/workflows/unit_tests.yml | 27 ++ graphiti_core/driver/driver.py | 8 +- graphiti_core/driver/falkordb_driver.py | 6 +- graphiti_core/driver/neo4j_driver.py | 10 +- graphiti_core/edges.py | 172 ++++---- graphiti_core/graph_queries.py | 144 +++---- graphiti_core/graphiti.py | 25 +- graphiti_core/helpers.py | 5 +- graphiti_core/models/edges/edge_db_queries.py | 138 +++++-- graphiti_core/models/nodes/node_db_queries.py | 125 +++++- graphiti_core/nodes.py | 241 ++++++----- graphiti_core/search/search_filters.py | 10 +- graphiti_core/search/search_utils.py | 325 +++++++-------- graphiti_core/utils/bulk_utils.py | 8 +- .../utils/maintenance/community_operations.py | 18 +- .../maintenance/graph_data_operations.py | 43 +- pyproject.toml | 6 +- tests/driver/test_falkordb_driver.py | 8 +- tests/helpers_test.py | 60 ++- tests/test_edge_int.py | 384 ++++++++++++++++++ tests/test_entity_exclusion_int.py | 55 ++- tests/test_graphiti_falkordb_int.py | 164 -------- tests/test_graphiti_int.py | 95 +---- tests/test_node_falkordb_int.py | 139 ------- tests/test_node_int.py | 191 +++++++-- 25 files changed, 1339 insertions(+), 1068 deletions(-) create mode 100644 tests/test_edge_int.py delete mode 100644 tests/test_graphiti_falkordb_int.py delete mode 100644 tests/test_node_falkordb_int.py diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 9f41d595..3096524a 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -20,6 +20,15 @@ jobs: ports: - 6379:6379 options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5 + neo4j: + image: neo4j:5.26-community + ports: + - 7687:7687 + - 7474:7474 + env: + NEO4J_AUTH: neo4j/testpass + NEO4J_PLUGINS: '["apoc"]' + options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10 steps: - uses: actions/checkout@v4 - name: Set up Python @@ -37,15 +46,33 @@ jobs: - name: Run non-integration tests env: PYTHONPATH: ${{ github.workspace }} + NEO4J_URI: bolt://localhost:7687 + NEO4J_USER: neo4j + NEO4J_PASSWORD: testpass run: | uv run pytest -m "not integration" - name: Wait for FalkorDB run: | timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done' + - name: Wait for Neo4j + run: | + timeout 60 bash -c 'until wget -O /dev/null http://localhost:7474 >/dev/null 2>&1; do sleep 1; done' - name: Run FalkorDB integration tests env: PYTHONPATH: ${{ github.workspace }} FALKORDB_HOST: localhost FALKORDB_PORT: 6379 + DISABLE_NEO4J: 1 run: | uv run pytest tests/driver/test_falkordb_driver.py + - name: Run Neo4j integration tests + env: + PYTHONPATH: ${{ github.workspace }} + NEO4J_URI: bolt://localhost:7687 + NEO4J_USER: neo4j + NEO4J_PASSWORD: testpass + FALKORDB_HOST: localhost + FALKORDB_PORT: 6379 + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + uv run pytest tests/test_*_int.py -k "neo4j" diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 4efe230a..6b85d80e 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -18,11 +18,17 @@ import copy import logging from abc import ABC, abstractmethod from collections.abc import Coroutine +from enum import Enum from typing import Any logger = logging.getLogger(__name__) +class GraphProvider(Enum): + NEO4J = 'neo4j' + FALKORDB = 
'falkordb' + + class GraphDriverSession(ABC): async def __aenter__(self): return self @@ -46,7 +52,7 @@ class GraphDriverSession(ABC): class GraphDriver(ABC): - provider: str + provider: GraphProvider fulltext_syntax: str = ( '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index acf2c66f..f121319b 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -32,7 +32,7 @@ else: 'Install it with: pip install graphiti-core[falkordb]' ) from None -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession): class FalkorDriver(GraphDriver): - provider: str = 'falkordb' + provider = GraphProvider.FALKORDB def __init__( self, @@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver): # check if index already exists logger.info(f'Index already exists: {e}') return None - logger.error(f'Error executing FalkorDB query: {e}') + logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}') raise # Convert the result header to a list of strings diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index bd82e8d9..7ac9a5a8 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -21,13 +21,13 @@ from typing import Any from neo4j import AsyncGraphDatabase, EagerResult from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider logger = logging.getLogger(__name__) class Neo4jDriver(GraphDriver): - provider: str = 'neo4j' + provider = GraphProvider.NEO4J def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): super().__init__() @@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver): params = {} params.setdefault('database_', self._database) - result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) + try: + result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) + except Exception as e: + logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}') + raise return result diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index cee5b870..d1c4f1d2 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date from graphiti_core.models.edges.edge_db_queries import ( - COMMUNITY_EDGE_SAVE, - ENTITY_EDGE_SAVE, + COMMUNITY_EDGE_RETURN, + ENTITY_EDGE_RETURN, + EPISODIC_EDGE_RETURN, EPISODIC_EDGE_SAVE, + get_community_edge_save_query, + get_entity_edge_save_query, ) from graphiti_core.nodes import Node logger = logging.getLogger(__name__) -ENTITY_EDGE_RETURN: LiteralString = """ - RETURN - e.uuid AS uuid, - startNode(e).uuid AS source_node_uuid, - endNode(e).uuid AS target_node_uuid, - e.created_at AS created_at, - e.name AS name, - e.group_id AS group_id, - e.fact AS fact, - e.episodes AS episodes, - e.expired_at AS expired_at, - e.valid_at AS valid_at, - e.invalid_at AS invalid_at, - properties(e) AS attributes""" - 
class Edge(BaseModel, ABC): uuid: str = Field(default_factory=lambda: str(uuid4())) @@ -66,9 +54,9 @@ class Edge(BaseModel, ABC): async def delete(self, driver: GraphDriver): result = await driver.execute_query( """ - MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) - DELETE e - """, + MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) + DELETE e + """, uuid=self.uuid, ) @@ -107,14 +95,10 @@ class EpisodicEdge(Edge): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) + RETURN + """ + + EPISODIC_EDGE_RETURN, uuid=uuid, routing_='r', ) @@ -129,15 +113,11 @@ class EpisodicEdge(Edge): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) - WHERE e.uuid IN $uuids - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) + WHERE e.uuid IN $uuids + RETURN + """ + + EPISODIC_EDGE_RETURN, uuids=uuids, routing_='r', ) @@ -161,19 +141,17 @@ class EpisodicEdge(Edge): records, _, _ = await driver.execute_query( """ - MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) - WHERE e.group_id IN $group_ids - """ + MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) + WHERE e.group_id IN $group_ids + """ + cursor_query + """ - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - ORDER BY e.uuid DESC - """ + RETURN + """ + + EPISODIC_EDGE_RETURN + + """ + ORDER BY e.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -221,11 +199,14 @@ class EntityEdge(Edge): return self.fact_embedding async def load_fact_embedding(self, driver: GraphDriver): - query: LiteralString = """ + records, _, _ = await driver.execute_query( + """ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN e.fact_embedding AS fact_embedding - """ - records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r') + """, + uuid=self.uuid, + routing_='r', + ) if len(records) == 0: raise EdgeNotFoundError(self.uuid) @@ -251,7 +232,7 @@ class EntityEdge(Edge): edge_data.update(self.attributes or {}) result = await driver.execute_query( - ENTITY_EDGE_SAVE, + get_entity_edge_save_query(driver.provider), edge_data=edge_data, ) @@ -263,8 +244,9 @@ class EntityEdge(Edge): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) + RETURN + """ + ENTITY_EDGE_RETURN, uuid=uuid, routing_='r', @@ -283,9 +265,10 @@ class EntityEdge(Edge): records, _, _ = await driver.execute_query( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.uuid IN $uuids - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.uuid IN $uuids + RETURN + """ + ENTITY_EDGE_RETURN, uuids=uuids, routing_='r', @@ -314,22 +297,21 @@ class EntityEdge(Edge): else '' ) - query: LiteralString = ( + records, _, _ = await driver.execute_query( """ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) WHERE 
e.group_id IN $group_ids """ + cursor_query + + """ + RETURN + """ + ENTITY_EDGE_RETURN + with_embeddings_query + """ - ORDER BY e.uuid DESC - """ - + limit_query - ) - - records, _, _ = await driver.execute_query( - query, + ORDER BY e.uuid DESC + """ + + limit_query, group_ids=group_ids, uuid=uuid_cursor, limit=limit, @@ -344,13 +326,15 @@ class EntityEdge(Edge): @classmethod async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): - query: LiteralString = ( + records, _, _ = await driver.execute_query( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ - + ENTITY_EDGE_RETURN + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + RETURN + """ + + ENTITY_EDGE_RETURN, + node_uuid=node_uuid, + routing_='r', ) - records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r') edges = [get_entity_edge_from_record(record) for record in records] @@ -360,7 +344,7 @@ class EntityEdge(Edge): class CommunityEdge(Edge): async def save(self, driver: GraphDriver): result = await driver.execute_query( - COMMUNITY_EDGE_SAVE, + get_community_edge_save_query(driver.provider), community_uuid=self.source_node_uuid, entity_uuid=self.target_node_uuid, uuid=self.uuid, @@ -376,14 +360,10 @@ class CommunityEdge(Edge): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community) - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m) + RETURN + """ + + COMMUNITY_EDGE_RETURN, uuid=uuid, routing_='r', ) @@ -396,15 +376,11 @@ class CommunityEdge(Edge): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) - WHERE e.uuid IN $uuids - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Community)-[e:HAS_MEMBER]->(m) + WHERE e.uuid IN $uuids + RETURN + """ + + COMMUNITY_EDGE_RETURN, uuids=uuids, routing_='r', ) @@ -426,19 +402,17 @@ class CommunityEdge(Edge): records, _, _ = await driver.execute_query( """ - MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) - WHERE e.group_id IN $group_ids - """ + MATCH (n:Community)-[e:HAS_MEMBER]->(m) + WHERE e.group_id IN $group_ids + """ + cursor_query + """ - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - ORDER BY e.uuid DESC - """ + RETURN + """ + + COMMUNITY_EDGE_RETURN + + """ + ORDER BY e.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 10e396c0..06f4e8ab 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB, supporting index creation, fulltext search, and bulk operations. 
""" -from typing import Any - from typing_extensions import LiteralString -from graphiti_core.models.edges.edge_db_queries import ( - ENTITY_EDGE_SAVE_BULK, -) -from graphiti_core.models.nodes.node_db_queries import ( - ENTITY_NODE_SAVE_BULK, -) +from graphiti_core.driver.driver import GraphProvider # Mapping from Neo4j fulltext index names to FalkorDB node labels NEO4J_TO_FALKORDB_MAPPING = { @@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = { } -def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]: - if db_type == 'falkordb': +def get_range_indices(provider: GraphProvider) -> list[LiteralString]: + if provider == GraphProvider.FALKORDB: return [ # Entity node 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)', @@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]: # HAS_MEMBER edge 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', ] - else: - return [ - 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', - 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', - 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)', - 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', - 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', - 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', - 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', - 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', - 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', - 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', - 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', - 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', - 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', - 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', - 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', - 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', - 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', - 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', - 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', - ] + + return [ + 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', + 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', + 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)', + 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', + 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', + 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', + 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', + 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', + 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', + 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', + 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', + 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON 
(n.created_at)', + 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', + 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', + 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', + 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', + 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', + 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', + 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', + ] -def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]: - if db_type == 'falkordb': +def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]: + if provider == GraphProvider.FALKORDB: return [ """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""", """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""", """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""", """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""", ] - else: - return [ - """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS - FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", - """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS - FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", - """CREATE FULLTEXT INDEX community_name IF NOT EXISTS - FOR (n:Community) ON EACH [n.name, n.group_id]""", - """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS - FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", - ] + + return [ + """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS + FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", + """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", + """CREATE FULLTEXT INDEX community_name IF NOT EXISTS + FOR (n:Community) ON EACH [n.name, n.group_id]""", + """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS + FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", + ] -def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str: - if db_type == 'falkordb': +def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str: + if provider == GraphProvider.FALKORDB: label = NEO4J_TO_FALKORDB_MAPPING[name] return f"CALL db.idx.fulltext.queryNodes('{label}', {query})" - else: - return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' + + return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' -def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str: - if db_type == 'falkordb': +def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2' - else: - return f'vector.similarity.cosine({vec1}, {vec2})' + + return f'vector.similarity.cosine({vec1}, {vec2})' -def get_relationships_query(name: str, db_type: str = 'neo4j') -> str: - if db_type == 'falkordb': +def get_relationships_query(name: str, provider: GraphProvider) -> str: + if provider == 
GraphProvider.FALKORDB: label = NEO4J_TO_FALKORDB_MAPPING[name] return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)" - else: - return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' - -def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any: - if db_type == 'falkordb': - queries = [] - for node in nodes: - for label in node['labels']: - queries.append( - ( - f""" - UNWIND $nodes AS node - MERGE (n:Entity {{uuid: node.uuid}}) - SET n:{label} - SET n = node - WITH n, node - SET n.name_embedding = vecf32(node.name_embedding) - RETURN n.uuid AS uuid - """, - {'nodes': [node]}, - ) - ) - return queries - else: - return ENTITY_NODE_SAVE_BULK - - -def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str: - if db_type == 'falkordb': - return """ - UNWIND $entity_edges AS edge - MATCH (source:Entity {uuid: edge.source_node_uuid}) - MATCH (target:Entity {uuid: edge.target_node_uuid}) - MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) - SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, - created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)} - WITH r, edge - RETURN edge.uuid AS uuid""" - else: - return ENTITY_EDGE_SAVE_BULK + return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 3459f8ae..b357f3b1 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver -from graphiti_core.edges import EntityEdge, EpisodicEdge +from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( @@ -93,8 +93,11 @@ load_dotenv() class AddEpisodeResults(BaseModel): episode: EpisodicNode + episodic_edges: list[EpisodicEdge] nodes: list[EntityNode] edges: list[EntityEdge] + communities: list[CommunityNode] + community_edges: list[CommunityEdge] class Graphiti: @@ -520,9 +523,12 @@ class Graphiti: self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder ) + communities = [] + community_edges = [] + # Update any communities if update_communities: - await semaphore_gather( + communities, community_edges = await semaphore_gather( *[ update_community(self.driver, self.llm_client, self.embedder, node) for node in nodes @@ -532,7 +538,14 @@ class Graphiti: end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') - return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges) + return AddEpisodeResults( + episode=episode, + episodic_edges=episodic_edges, + nodes=hydrated_nodes, + edges=entity_edges, + communities=communities, + community_edges=community_edges, + ) except Exception as e: raise e @@ -817,7 +830,9 @@ class Graphiti: except Exception as e: raise e - async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]: + async def build_communities( + self, group_ids: list[str] | None = None + ) -> tuple[list[CommunityNode], 
list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising the content of these communities. @@ -846,7 +861,7 @@ class Graphiti: max_coroutines=self.max_coroutines, ) - return community_nodes + return community_nodes, community_edges async def search( self, diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index ae311a08..2bf68b81 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -28,6 +28,7 @@ from numpy._typing import NDArray from pydantic import BaseModel from typing_extensions import LiteralString +from graphiti_core.driver.driver import GraphProvider from graphiti_core.errors import GroupIdValidationError load_dotenv() @@ -52,12 +53,12 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None ) -def get_default_group_id(db_type: str) -> str: +def get_default_group_id(provider: GraphProvider) -> str: """ This function differentiates the default group id based on the database type. For most databases, the default group id is an empty string, while there are database types that require a specific default group id. """ - if db_type == 'falkordb': + if provider == GraphProvider.FALKORDB: return '_' else: return '' diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 47b6f04f..ed1a19e5 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -14,43 +14,117 @@ See the License for the specific language governing permissions and limitations under the License. """ +from graphiti_core.driver.driver import GraphProvider + EPISODIC_EDGE_SAVE = """ - MATCH (episode:Episodic {uuid: $episode_uuid}) - MATCH (node:Entity {uuid: $entity_uuid}) - MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at} - RETURN r.uuid AS uuid""" + MATCH (episode:Episodic {uuid: $episode_uuid}) + MATCH (node:Entity {uuid: $entity_uuid}) + MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node) + SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at} + RETURN e.uuid AS uuid +""" EPISODIC_EDGE_SAVE_BULK = """ UNWIND $episodic_edges AS edge - MATCH (episode:Episodic {uuid: edge.source_node_uuid}) - MATCH (node:Entity {uuid: edge.target_node_uuid}) - MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node) - SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at} - RETURN r.uuid AS uuid + MATCH (episode:Episodic {uuid: edge.source_node_uuid}) + MATCH (node:Entity {uuid: edge.target_node_uuid}) + MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node) + SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at} + RETURN e.uuid AS uuid """ -ENTITY_EDGE_SAVE = """ - MATCH (source:Entity {uuid: $edge_data.source_uuid}) - MATCH (target:Entity {uuid: $edge_data.target_uuid}) - MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target) - SET r = $edge_data - WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding) - RETURN r.uuid AS uuid""" - -ENTITY_EDGE_SAVE_BULK = """ - UNWIND $entity_edges AS edge - MATCH (source:Entity {uuid: edge.source_node_uuid}) - MATCH (target:Entity {uuid: edge.target_node_uuid}) - MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) - SET r = edge - WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding) - RETURN edge.uuid AS uuid +EPISODIC_EDGE_RETURN 
= """ + e.uuid AS uuid, + e.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at """ -COMMUNITY_EDGE_SAVE = """ - MATCH (community:Community {uuid: $community_uuid}) - MATCH (node:Entity | Community {uuid: $entity_uuid}) - MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at} - RETURN r.uuid AS uuid""" + +def get_entity_edge_save_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + RETURN e.uuid AS uuid + """ + + return """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding) + RETURN e.uuid AS uuid + """ + + +def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) + SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, + created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)} + WITH r, edge + RETURN edge.uuid AS uuid + """ + + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) + SET e = edge + WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding) + RETURN edge.uuid AS uuid + """ + + +ENTITY_EDGE_RETURN = """ + e.uuid AS uuid, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.group_id AS group_id, + e.name AS name, + e.fact AS fact, + e.episodes AS episodes, + e.created_at AS created_at, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at, + properties(e) AS attributes +""" + + +def get_community_edge_save_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + MATCH (community:Community {uuid: $community_uuid}) + MATCH (node {uuid: $entity_uuid}) + MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node) + SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at} + RETURN e.uuid AS uuid + """ + + return """ + MATCH (community:Community {uuid: $community_uuid}) + MATCH (node:Entity | Community {uuid: $entity_uuid}) + MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node) + SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at} + RETURN e.uuid AS uuid + """ + + +COMMUNITY_EDGE_RETURN = """ + e.uuid AS uuid, + e.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at +""" diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 5f361e07..fdcf48f1 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ 
-14,39 +14,120 @@ See the License for the specific language governing permissions and limitations under the License. """ +from typing import Any + +from graphiti_core.driver.driver import GraphProvider + EPISODIC_NODE_SAVE = """ - MERGE (n:Episodic {uuid: $uuid}) - SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, - entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} - RETURN n.uuid AS uuid""" + MERGE (n:Episodic {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, + entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} + RETURN n.uuid AS uuid +""" EPISODIC_NODE_SAVE_BULK = """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) - SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, - source: episode.source, content: episode.content, + SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, + source: episode.source, content: episode.content, entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ -ENTITY_NODE_SAVE = """ - MERGE (n:Entity {uuid: $entity_data.uuid}) - SET n:$($labels) - SET n = $entity_data - WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) - RETURN n.uuid AS uuid""" - -ENTITY_NODE_SAVE_BULK = """ - UNWIND $nodes AS node - MERGE (n:Entity {uuid: node.uuid}) - SET n:$(node.labels) - SET n = node - WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) - RETURN n.uuid AS uuid +EPISODIC_NODE_RETURN = """ + e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.name AS name, + e.group_id AS group_id, + e.source_description AS source_description, + e.source AS source, + e.entity_edges AS entity_edges """ -COMMUNITY_NODE_SAVE = """ + +def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: + if provider == GraphProvider.FALKORDB: + return f""" + MERGE (n:Entity {{uuid: $entity_data.uuid}}) + SET n:{labels} + SET n = $entity_data + RETURN n.uuid AS uuid + """ + + return f""" + MERGE (n:Entity {{uuid: $entity_data.uuid}}) + SET n:{labels} + SET n = $entity_data + WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) + RETURN n.uuid AS uuid + """ + + +def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any: + if provider == GraphProvider.FALKORDB: + queries = [] + for node in nodes: + for label in node['labels']: + queries.append( + ( + f""" + UNWIND $nodes AS node + MERGE (n:Entity {{uuid: node.uuid}}) + SET n:{label} + SET n = node + WITH n, node + SET n.name_embedding = vecf32(node.name_embedding) + RETURN n.uuid AS uuid + """, + {'nodes': [node]}, + ) + ) + return queries + + return """ + UNWIND $nodes AS node + MERGE (n:Entity {uuid: node.uuid}) + SET n:$(node.labels) + SET n = node + WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) + RETURN n.uuid AS uuid + """ + + +ENTITY_NODE_RETURN = """ + n.uuid AS uuid, + n.name AS name, + n.group_id AS group_id, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + properties(n) AS attributes +""" + + +def 
get_community_node_save_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + MERGE (n:Community {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: $name_embedding} + RETURN n.uuid AS uuid + """ + + return """ MERGE (n:Community {uuid: $uuid}) SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at} WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding) - RETURN n.uuid AS uuid""" + RETURN n.uuid AS uuid + """ + + +COMMUNITY_NODE_RETURN = """ + n.uuid AS uuid, + n.name AS name, + n.name_embedding AS name_embedding, + n.group_id AS group_id, + n.summary AS summary, + n.created_at AS created_at +""" diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index d2fb79ca..bad8fefa 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -25,29 +25,22 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver +from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError from graphiti_core.helpers import parse_db_date from graphiti_core.models.nodes.node_db_queries import ( - COMMUNITY_NODE_SAVE, - ENTITY_NODE_SAVE, + COMMUNITY_NODE_RETURN, + ENTITY_NODE_RETURN, + EPISODIC_NODE_RETURN, EPISODIC_NODE_SAVE, + get_community_node_save_query, + get_entity_node_save_query, ) from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) -ENTITY_NODE_RETURN: LiteralString = """ - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes""" - class EpisodeType(Enum): """ @@ -96,18 +89,26 @@ class Node(BaseModel, ABC): async def save(self, driver: GraphDriver): ... async def delete(self, driver: GraphDriver): - result = await driver.execute_query( - """ - MATCH (n:Entity|Episodic|Community {uuid: $uuid}) - DETACH DELETE n - """, - uuid=self.uuid, - ) + if driver.provider == GraphProvider.FALKORDB: + for label in ['Entity', 'Episodic', 'Community']: + await driver.execute_query( + f""" + MATCH (n:{label} {{uuid: $uuid}}) + DETACH DELETE n + """, + uuid=self.uuid, + ) + else: + await driver.execute_query( + """ + MATCH (n:Entity|Episodic|Community {uuid: $uuid}) + DETACH DELETE n + """, + uuid=self.uuid, + ) logger.debug(f'Deleted Node: {self.uuid}') - return result - def __hash__(self): return hash(self.uuid) @@ -118,15 +119,23 @@ class Node(BaseModel, ABC): @classmethod async def delete_by_group_id(cls, driver: GraphDriver, group_id: str): - await driver.execute_query( - """ - MATCH (n:Entity|Episodic|Community {group_id: $group_id}) - DETACH DELETE n - """, - group_id=group_id, - ) - - return 'SUCCESS' + if driver.provider == GraphProvider.FALKORDB: + for label in ['Entity', 'Episodic', 'Community']: + await driver.execute_query( + f""" + MATCH (n:{label} {{group_id: $group_id}}) + DETACH DELETE n + """, + group_id=group_id, + ) + else: + await driver.execute_query( + """ + MATCH (n:Entity|Episodic|Community {group_id: $group_id}) + DETACH DELETE n + """, + group_id=group_id, + ) @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... 
@@ -169,17 +178,10 @@ class EpisodicNode(Node): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic {uuid: $uuid}) - RETURN e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - """, + MATCH (e:Episodic {uuid: $uuid}) + RETURN + """ + + EPISODIC_NODE_RETURN, uuid=uuid, routing_='r', ) @@ -195,18 +197,11 @@ class EpisodicNode(Node): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic) WHERE e.uuid IN $uuids + MATCH (e:Episodic) + WHERE e.uuid IN $uuids RETURN DISTINCT - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - """, + """ + + EPISODIC_NODE_RETURN, uuids=uuids, routing_='r', ) @@ -228,22 +223,17 @@ class EpisodicNode(Node): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic) WHERE e.group_id IN $group_ids - """ + MATCH (e:Episodic) + WHERE e.group_id IN $group_ids + """ + cursor_query + """ RETURN DISTINCT - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - ORDER BY e.uuid DESC - """ + """ + + EPISODIC_NODE_RETURN + + """ + ORDER BY uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -259,18 +249,10 @@ class EpisodicNode(Node): async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid}) + MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid}) RETURN DISTINCT - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - """, + """ + + EPISODIC_NODE_RETURN, entity_node_uuid=entity_node_uuid, routing_='r', ) @@ -297,11 +279,14 @@ class EntityNode(Node): return self.name_embedding async def load_name_embedding(self, driver: GraphDriver): - query: LiteralString = """ + records, _, _ = await driver.execute_query( + """ MATCH (n:Entity {uuid: $uuid}) RETURN n.name_embedding AS name_embedding - """ - records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r') + """, + uuid=self.uuid, + routing_='r', + ) if len(records) == 0: raise NodeNotFoundError(self.uuid) @@ -317,12 +302,12 @@ class EntityNode(Node): 'summary': self.summary, 'created_at': self.created_at, } - entity_data.update(self.attributes or {}) + labels = ':'.join(self.labels + ['Entity']) + result = await driver.execute_query( - ENTITY_NODE_SAVE, - labels=self.labels + ['Entity'], + get_entity_node_save_query(driver.provider, labels), entity_data=entity_data, ) @@ -332,14 +317,12 @@ class EntityNode(Node): @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): - query = ( - """ - MATCH (n:Entity {uuid: $uuid}) - """ - + ENTITY_NODE_RETURN - ) records, _, _ = await driver.execute_query( - query, + """ 
+ MATCH (n:Entity {uuid: $uuid}) + RETURN + """ + + ENTITY_NODE_RETURN, uuid=uuid, routing_='r', ) @@ -355,8 +338,10 @@ class EntityNode(Node): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Entity) WHERE n.uuid IN $uuids - """ + MATCH (n:Entity) + WHERE n.uuid IN $uuids + RETURN + """ + ENTITY_NODE_RETURN, uuids=uuids, routing_='r', @@ -379,22 +364,26 @@ class EntityNode(Node): limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' with_embeddings_query: LiteralString = ( """, - n.name_embedding AS name_embedding - """ + n.name_embedding AS name_embedding + """ if with_embeddings else '' ) records, _, _ = await driver.execute_query( """ - MATCH (n:Entity) WHERE n.group_id IN $group_ids - """ + MATCH (n:Entity) + WHERE n.group_id IN $group_ids + """ + cursor_query + + """ + RETURN + """ + ENTITY_NODE_RETURN + with_embeddings_query + """ - ORDER BY n.uuid DESC - """ + ORDER BY n.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -413,7 +402,7 @@ class CommunityNode(Node): async def save(self, driver: GraphDriver): result = await driver.execute_query( - COMMUNITY_NODE_SAVE, + get_community_node_save_query(driver.provider), uuid=self.uuid, name=self.name, group_id=self.group_id, @@ -436,11 +425,14 @@ class CommunityNode(Node): return self.name_embedding async def load_name_embedding(self, driver: GraphDriver): - query: LiteralString = """ + records, _, _ = await driver.execute_query( + """ MATCH (c:Community {uuid: $uuid}) RETURN c.name_embedding AS name_embedding - """ - records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r') + """, + uuid=self.uuid, + routing_='r', + ) if len(records) == 0: raise NodeNotFoundError(self.uuid) @@ -451,14 +443,10 @@ class CommunityNode(Node): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Community {uuid: $uuid}) - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary - """, + MATCH (n:Community {uuid: $uuid}) + RETURN + """ + + COMMUNITY_NODE_RETURN, uuid=uuid, routing_='r', ) @@ -474,14 +462,11 @@ class CommunityNode(Node): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Community) WHERE n.uuid IN $uuids - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary - """, + MATCH (n:Community) + WHERE n.uuid IN $uuids + RETURN + """ + + COMMUNITY_NODE_RETURN, uuids=uuids, routing_='r', ) @@ -503,18 +488,17 @@ class CommunityNode(Node): records, _, _ = await driver.execute_query( """ - MATCH (n:Community) WHERE n.group_id IN $group_ids - """ + MATCH (n:Community) + WHERE n.group_id IN $group_ids + """ + cursor_query + """ - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary - ORDER BY n.uuid DESC - """ + RETURN + """ + + COMMUNITY_NODE_RETURN + + """ + ORDER BY n.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode: async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]): if not nodes: # Handle empty list case return + name_embeddings = await embedder.create_batch([node.name for node in nodes]) for node, name_embedding in zip(nodes, 
name_embeddings, strict=True): node.name_embedding = name_embedding diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 40c12576..90abb766 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -72,7 +72,7 @@ def edge_search_filter_query_constructor( if filters.edge_types is not None: edge_types = filters.edge_types - edge_types_filter = '\nAND r.name in $edge_types' + edge_types_filter = '\nAND e.name in $edge_types' filter_query += edge_types_filter filter_params['edge_types'] = edge_types @@ -88,7 +88,7 @@ def edge_search_filter_query_constructor( filter_params['valid_at_' + str(j)] = date_filter.date and_filters = [ - '(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})' + '(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -113,7 +113,7 @@ def edge_search_filter_query_constructor( filter_params['invalid_at_' + str(j)] = date_filter.date and_filters = [ - '(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})' + '(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -138,7 +138,7 @@ def edge_search_filter_query_constructor( filter_params['created_at_' + str(j)] = date_filter.date and_filters = [ - '(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})' + '(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -163,7 +163,7 @@ def edge_search_filter_query_constructor( filter_params['expired_at_' + str(j)] = date_filter.date and_filters = [ - '(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})' + '(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 5d6828f7..af9dce67 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -23,7 +23,7 @@ import numpy as np from numpy._typing import NDArray from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver +from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.graph_queries import ( get_nodes_query, @@ -36,6 +36,8 @@ from graphiti_core.helpers import ( normalize_l2, semaphore_gather, ) +from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN +from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN from graphiti_core.nodes import ( ENTITY_NODE_RETURN, CommunityNode, @@ -100,20 +102,13 @@ async def get_mentioned_nodes( ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] - query = """ - MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids - RETURN DISTINCT - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes - """ - records, _, _ = await driver.execute_query( - query, + """ + MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) + WHERE episode.uuid IN $uuids + RETURN DISTINCT + """ + + ENTITY_NODE_RETURN, 
uuids=episode_uuids, routing_='r', ) @@ -128,18 +123,13 @@ async def get_communities_by_nodes( ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] - query = """ - MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids - RETURN DISTINCT - c.uuid As uuid, - c.group_id AS group_id, - c.name AS name, - c.created_at AS created_at, - c.summary AS summary - """ - records, _, _ = await driver.execute_query( - query, + """ + MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity) + WHERE m.uuid IN $uuids + RETURN DISTINCT + """ + + COMMUNITY_NODE_RETURN, uuids=node_uuids, routing_='r', ) @@ -164,38 +154,30 @@ async def edge_fulltext_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query = ( - get_relationships_query('edge_name_and_fact', db_type=driver.provider) + get_relationships_query('edge_name_and_fact', provider=driver.provider) + """ YIELD relationship AS rel, score - MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity) - WHERE r.group_id IN $group_ids """ + MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity) + WHERE e.group_id IN $group_ids """ + filter_query + """ - WITH r, score, startNode(r) AS n, endNode(r) AS m + WITH e, score, n, m RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at, - properties(r) AS attributes - ORDER BY score DESC LIMIT $limit + """ + + ENTITY_EDGE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit """ ) records, _, _ = await driver.execute_query( query, - params=filter_params, query=fuzzy_query, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -219,58 +201,47 @@ async def edge_similarity_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query_params.update(filter_params) - group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL' + group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL' if group_ids is not None: - group_filter_query += '\nAND r.group_id IN $group_ids' + group_filter_query += '\nAND e.group_id IN $group_ids' query_params['group_ids'] = group_ids - query_params['source_node_uuid'] = source_node_uuid - query_params['target_node_uuid'] = target_node_uuid if source_node_uuid is not None: - group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])' + query_params['source_uuid'] = source_node_uuid + group_filter_query += '\nAND (n.uuid = $source_uuid)' if target_node_uuid is not None: - group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])' + query_params['target_uuid'] = target_node_uuid + group_filter_query += '\nAND (m.uuid = $target_uuid)' query = ( RUNTIME_QUERY + """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) """ + group_filter_query + filter_query + """ - WITH DISTINCT r, """ - + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider) + WITH DISTINCT e, n, m, """ + + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider) + """ AS score WHERE score > $min_score RETURN - r.uuid AS uuid, - r.group_id AS group_id, - startNode(r).uuid AS source_node_uuid, - endNode(r).uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - 
r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at, - properties(r) AS attributes + """ + + ENTITY_EDGE_RETURN + + """ ORDER BY score DESC LIMIT $limit """ ) - records, header, _ = await driver.execute_query( + + records, _, _ = await driver.execute_query( query, - params=query_params, search_vector=search_vector, - source_uuid=source_node_uuid, - target_uuid=target_node_uuid, - group_ids=group_ids, limit=limit, min_score=min_score, routing_='r', + **query_params, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -293,41 +264,31 @@ async def edge_bfs_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query = ( + f""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity) + UNWIND relationships(path) AS rel + MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) + WHERE e.uuid = rel.uuid + AND e.group_id IN $group_ids """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - UNWIND relationships(path) AS rel - MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) - WHERE r.uuid = rel.uuid - AND r.group_id IN $group_ids - """ + filter_query - + """ - RETURN DISTINCT - r.uuid AS uuid, - r.group_id AS group_id, - startNode(r).uuid AS source_node_uuid, - endNode(r).uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at, - properties(r) AS attributes - LIMIT $limit + + """ + RETURN DISTINCT + """ + + ENTITY_EDGE_RETURN + + """ + LIMIT $limit """ ) records, _, _ = await driver.execute_query( query, - params=filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, - depth=bfs_max_depth, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -352,23 +313,27 @@ async def node_fulltext_search( get_nodes_query(driver.provider, 'node_name_and_summary', '$query') + """ YIELD node AS n, score - WITH n, score - LIMIT $limit - WHERE n:Entity AND n.group_id IN $group_ids + WHERE n:Entity AND n.group_id IN $group_ids + WITH n, score + LIMIT $limit """ + filter_query + + """ + RETURN + """ + ENTITY_NODE_RETURN + """ ORDER BY score DESC """ ) - records, header, _ = await driver.execute_query( + + records, _, _ = await driver.execute_query( query, - params=filter_params, query=fuzzy_query, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -406,22 +371,23 @@ async def node_similarity_search( WITH n, """ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider) + """ AS score - WHERE score > $min_score""" + WHERE score > $min_score + RETURN + """ + ENTITY_NODE_RETURN + """ ORDER BY score DESC LIMIT $limit - """ + """ ) - records, header, _ = await driver.execute_query( + records, _, _ = await driver.execute_query( query, - params=query_params, search_vector=search_vector, - group_ids=group_ids, limit=limit, min_score=min_score, routing_='r', + **query_params, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -444,26 +410,29 @@ async def node_bfs_search( filter_query, filter_params = node_search_filter_query_constructor(search_filter) query = ( + 
f""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity) + WHERE n.group_id = origin.group_id + AND origin.group_id IN $group_ids """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - WHERE n.group_id = origin.group_id - AND origin.group_id IN $group_ids - """ + filter_query + + """ + RETURN + """ + ENTITY_NODE_RETURN + """ LIMIT $limit """ ) + records, _, _ = await driver.execute_query( query, - params=filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, - depth=bfs_max_depth, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -489,16 +458,10 @@ async def episode_fulltext_search( MATCH (e:Episodic) WHERE e.uuid = episode.uuid AND e.group_id IN $group_ids - RETURN - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges + RETURN + """ + + EPISODIC_NODE_RETURN + + """ ORDER BY score DESC LIMIT $limit """ @@ -530,15 +493,12 @@ async def community_fulltext_search( query = ( get_nodes_query(driver.provider, 'community_name', '$query') + """ - YIELD node AS comm, score - WHERE comm.group_id IN $group_ids + YIELD node AS n, score + WHERE n.group_id IN $group_ids RETURN - comm.uuid AS uuid, - comm.group_id AS group_id, - comm.name AS name, - comm.created_at AS created_at, - comm.summary AS summary, - comm.name_embedding AS name_embedding + """ + + COMMUNITY_NODE_RETURN + + """ ORDER BY score DESC LIMIT $limit """ @@ -568,39 +528,37 @@ async def community_similarity_search( group_filter_query: LiteralString = '' if group_ids is not None: - group_filter_query += 'WHERE comm.group_id IN $group_ids' + group_filter_query += 'WHERE n.group_id IN $group_ids' query_params['group_ids'] = group_ids query = ( RUNTIME_QUERY + """ - MATCH (comm:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ - WITH comm, """ - + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider) + WITH n, + """ + + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider) + """ AS score - WHERE score > $min_score - RETURN - comm.uuid As uuid, - comm.group_id AS group_id, - comm.name AS name, - comm.created_at AS created_at, - comm.summary AS summary, - comm.name_embedding AS name_embedding - ORDER BY score DESC - LIMIT $limit + WHERE score > $min_score + RETURN + """ + + COMMUNITY_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit """ ) records, _, _ = await driver.execute_query( query, search_vector=search_vector, - group_ids=group_ids, limit=limit, min_score=min_score, routing_='r', + **query_params, ) communities = [get_community_node_from_record(record) for record in records] @@ -719,8 +677,8 @@ async def get_relevant_nodes( WHERE m.group_id = $group_id WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes - WITH node, - top_vector_nodes, + WITH node, + top_vector_nodes, [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes @@ -728,10 +686,10 @@ async def get_relevant_nodes( UNWIND combined_nodes AS combined_node WITH node, collect(DISTINCT 
combined_node) AS deduped_nodes - RETURN + RETURN node.uuid AS search_node_uuid, [x IN deduped_nodes | { - uuid: x.uuid, + uuid: x.uuid, name: x.name, name_embedding: x.name_embedding, group_id: x.group_id, @@ -755,12 +713,12 @@ async def get_relevant_nodes( results, _, _ = await driver.execute_query( query, - params=query_params, nodes=query_nodes, group_id=group_id, limit=limit, min_score=min_score, routing_='r', + **query_params, ) relevant_nodes_dict: dict[str, list[EntityNode]] = { @@ -825,11 +783,11 @@ async def get_relevant_edges( results, _, _ = await driver.execute_query( query, - params=query_params, edges=[edge.model_dump() for edge in edges], limit=limit, min_score=min_score, routing_='r', + **query_params, ) relevant_edges_dict: dict[str, list[EntityEdge]] = { @@ -895,11 +853,11 @@ async def get_edge_invalidation_candidates( results, _, _ = await driver.execute_query( query, - params=query_params, edges=[edge.model_dump() for edge in edges], limit=limit, min_score=min_score, routing_='r', + **query_params, ) invalidation_edges_dict: dict[str, list[EntityEdge]] = { result['search_edge_uuid']: [ @@ -943,18 +901,17 @@ async def node_distance_reranker( scores: dict[str, float] = {center_node_uuid: 0.0} # Find the shortest path to center node - query = """ + results, header, _ = await driver.execute_query( + """ UNWIND $node_uuids AS node_uuid MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid}) RETURN 1 AS score, node_uuid AS uuid - """ - results, header, _ = await driver.execute_query( - query, + """, node_uuids=filtered_uuids, center_uuid=center_node_uuid, routing_='r', ) - if driver.provider == 'falkordb': + if driver.provider == GraphProvider.FALKORDB: results = [dict(zip(header, row, strict=True)) for row in results] for result in results: @@ -987,13 +944,12 @@ async def episode_mentions_reranker( scores: dict[str, float] = {} # Find the shortest path to center node - query = """ - UNWIND $node_uuids AS node_uuid + results, _, _ = await driver.execute_query( + """ + UNWIND $node_uuids AS node_uuid MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid}) RETURN count(*) AS score, n.uuid AS uuid - """ - results, _, _ = await driver.execute_query( - query, + """, node_uuids=sorted_uuids, routing_='r', ) @@ -1053,15 +1009,16 @@ def maximal_marginal_relevance( async def get_embeddings_for_nodes( driver: GraphDriver, nodes: list[EntityNode] ) -> dict[str, list[float]]: - query: LiteralString = """MATCH (n:Entity) - WHERE n.uuid IN $node_uuids - RETURN DISTINCT - n.uuid AS uuid, - n.name_embedding AS name_embedding - """ - results, _, _ = await driver.execute_query( - query, node_uuids=[node.uuid for node in nodes], routing_='r' + """ + MATCH (n:Entity) + WHERE n.uuid IN $node_uuids + RETURN DISTINCT + n.uuid AS uuid, + n.name_embedding AS name_embedding + """, + node_uuids=[node.uuid for node in nodes], + routing_='r', ) embeddings_dict: dict[str, list[float]] = {} @@ -1077,15 +1034,14 @@ async def get_embeddings_for_nodes( async def get_embeddings_for_communities( driver: GraphDriver, communities: list[CommunityNode] ) -> dict[str, list[float]]: - query: LiteralString = """MATCH (c:Community) - WHERE c.uuid IN $community_uuids - RETURN DISTINCT - c.uuid AS uuid, - c.name_embedding AS name_embedding - """ - results, _, _ = await driver.execute_query( - query, + """ + MATCH (c:Community) + WHERE c.uuid IN $community_uuids + RETURN DISTINCT + c.uuid AS uuid, + c.name_embedding AS name_embedding + """, community_uuids=[community.uuid for 
community in communities], routing_='r', ) @@ -1103,15 +1059,14 @@ async def get_embeddings_for_communities( async def get_embeddings_for_edges( driver: GraphDriver, edges: list[EntityEdge] ) -> dict[str, list[float]]: - query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) - WHERE e.uuid IN $edge_uuids - RETURN DISTINCT - e.uuid AS uuid, - e.fact_embedding AS fact_embedding - """ - results, _, _ = await driver.execute_query( - query, + """ + MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) + WHERE e.uuid IN $edge_uuids + RETURN DISTINCT + e.uuid AS uuid, + e.fact_embedding AS fact_embedding + """, edge_uuids=[edge.uuid for edge in edges], routing_='r', ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 8d5e8a68..9263edf3 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -25,17 +25,15 @@ from typing_extensions import Any from graphiti_core.driver.driver import GraphDriver, GraphDriverSession from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings from graphiti_core.embedder import EmbedderClient -from graphiti_core.graph_queries import ( - get_entity_edge_save_bulk_query, - get_entity_node_save_bulk_query, -) from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import normalize_l2, semaphore_gather from graphiti_core.models.edges.edge_db_queries import ( EPISODIC_EDGE_SAVE_BULK, + get_entity_edge_save_bulk_query, ) from graphiti_core.models.nodes.node_db_queries import ( EPISODIC_NODE_SAVE_BULK, + get_entity_node_save_bulk_query, ) from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.utils.maintenance.edge_operations import ( @@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx( edges.append(edge_data) await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) - entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider) + entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes) await tx.run(entity_node_save_bulk, nodes=nodes) await tx.run( EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges] diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index de74d56b..243a5079 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -34,7 +34,7 @@ async def get_community_clusters( group_id_values, _, _ = await driver.execute_query( """ MATCH (n:Entity WHERE n.group_id IS NOT NULL) - RETURN + RETURN collect(DISTINCT n.group_id) AS group_ids """, ) @@ -233,10 +233,10 @@ async def determine_entity_community( """ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid}) RETURN - c.uuid As uuid, + c.uuid AS uuid, c.name AS name, c.group_id AS group_id, - c.created_at AS created_at, + c.created_at AS created_at, c.summary AS summary """, entity_uuid=entity.uuid, @@ -250,10 +250,10 @@ async def determine_entity_community( """ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid}) RETURN - c.uuid As uuid, + c.uuid AS uuid, c.name AS name, c.group_id AS group_id, - c.created_at AS created_at, + c.created_at AS created_at, c.summary AS summary """, entity_uuid=entity.uuid, @@ -286,11 +286,11 @@ async def determine_entity_community( async def update_community( driver: GraphDriver, llm_client: LLMClient, embedder: 
EmbedderClient, entity: EntityNode -): +) -> tuple[list[CommunityNode], list[CommunityEdge]]: community, is_new = await determine_entity_community(driver, entity) if community is None: - return + return [], [] new_summary = await summarize_pair(llm_client, (entity.summary, community.summary)) new_name = await generate_summary_description(llm_client, new_summary) @@ -298,10 +298,14 @@ async def update_community( community.summary = new_summary community.name = new_name + community_edges = [] if is_new: community_edge = (build_community_edges([entity], community, utc_now()))[0] await community_edge.save(driver) + community_edges.append(community_edge) await community.generate_name_embedding(embedder) await community.save(driver) + + return [community], community_edges diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index f6975053..b866607f 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -15,14 +15,15 @@ limitations under the License. """ import logging -from datetime import datetime, timezone +from datetime import datetime from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices -from graphiti_core.helpers import parse_db_date, semaphore_gather -from graphiti_core.nodes import EpisodeType, EpisodicNode +from graphiti_core.helpers import semaphore_gather +from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN +from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record EPISODE_WINDOW_LEN = 3 @@ -33,8 +34,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo if delete_existing: records, _, _ = await driver.execute_query( """ - SHOW INDEXES YIELD name - """, + SHOW INDEXES YIELD name + """, ) index_names = [record['name'] for record in records] await semaphore_gather( @@ -108,19 +109,16 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + group_id_filter + source_filter + """ - RETURN e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.group_id AS group_id, - e.name AS name, - e.source_description AS source_description, - e.source AS source + RETURN + """ + + EPISODIC_NODE_RETURN + + """ ORDER BY e.valid_at DESC LIMIT $num_episodes """ @@ -133,18 +131,5 @@ async def retrieve_episodes( group_ids=group_ids, ) - episodes = [ - EpisodicNode( - content=record['content'], - created_at=parse_db_date(record['created_at']) - or datetime.min.replace(tzinfo=timezone.utc), - valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc), - uuid=record['uuid'], - group_id=record['group_id'], - source=EpisodeType.from_str(record['source']), - name=record['name'], - source_description=record['source_description'], - ) - for record in result - ] + episodes = [get_episodic_node_from_record(record) for record in result] return list(reversed(episodes)) # Return in chronological order diff --git a/pyproject.toml b/pyproject.toml index d2e49acc..8c57be09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,9 +3,9 @@ name = "graphiti-core" description = "A temporal graph building library" version = "0.18.0" authors = [ - { "name" = "Paul 
Paliychuk", "email" = "paul@getzep.com" }, - { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, - { "name" = "Daniel Chalef", "email" = "daniel@getzep.com" }, + { name = "Paul Paliychuk", email = "paul@getzep.com" }, + { name = "Preston Rasmussen", email = "preston@getzep.com" }, + { name = "Daniel Chalef", email = "daniel@getzep.com" }, ] readme = "README.md" license = "Apache-2.0" diff --git a/tests/driver/test_falkordb_driver.py b/tests/driver/test_falkordb_driver.py index 260e24d2..6cca9e74 100644 --- a/tests/driver/test_falkordb_driver.py +++ b/tests/driver/test_falkordb_driver.py @@ -21,6 +21,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from graphiti_core.driver.driver import GraphProvider + try: from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession @@ -48,7 +50,7 @@ class TestFalkorDriver: driver = FalkorDriver( host='test-host', port='1234', username='test-user', password='test-pass' ) - assert driver.provider == 'falkordb' + assert driver.provider == GraphProvider.FALKORDB mock_falkor_db.assert_called_once_with( host='test-host', port='1234', username='test-user', password='test-pass' ) @@ -59,14 +61,14 @@ class TestFalkorDriver: with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class: mock_falkor_db = MagicMock() driver = FalkorDriver(falkor_db=mock_falkor_db) - assert driver.provider == 'falkordb' + assert driver.provider == GraphProvider.FALKORDB assert driver.client is mock_falkor_db mock_falkor_db_class.assert_not_called() @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') def test_provider(self): """Test driver provider identification.""" - assert self.driver.provider == 'falkordb' + assert self.driver.provider == GraphProvider.FALKORDB @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') def test_get_graph_with_name(self): diff --git a/tests/helpers_test.py b/tests/helpers_test.py index 006a9cd3..95878825 100644 --- a/tests/helpers_test.py +++ b/tests/helpers_test.py @@ -14,10 +14,68 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -import pytest +import os +import pytest +from dotenv import load_dotenv + +from graphiti_core.driver.driver import GraphDriver from graphiti_core.helpers import lucene_sanitize +load_dotenv() + +HAS_NEO4J = False +HAS_FALKORDB = False +if os.getenv('DISABLE_NEO4J') is None: + try: + from graphiti_core.driver.neo4j_driver import Neo4jDriver + + HAS_NEO4J = True + except ImportError: + pass + +if os.getenv('DISABLE_FALKORDB') is None: + try: + from graphiti_core.driver.falkordb_driver import FalkorDriver + + HAS_FALKORDB = True + except ImportError: + pass + +NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687') +NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j') +NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test') + +FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost') +FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') +FALKORDB_USER = os.getenv('FALKORDB_USER', None) +FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) + + +def get_driver(driver_name: str) -> GraphDriver: + if driver_name == 'neo4j': + return Neo4jDriver( + uri=NEO4J_URI, + user=NEO4J_USER, + password=NEO4J_PASSWORD, + ) + elif driver_name == 'falkordb': + return FalkorDriver( + host=FALKORDB_HOST, + port=int(FALKORDB_PORT), + username=FALKORDB_USER, + password=FALKORDB_PASSWORD, + ) + else: + raise ValueError(f'Driver {driver_name} not available') + + +drivers: list[str] = [] +if HAS_NEO4J: + drivers.append('neo4j') +if HAS_FALKORDB: + drivers.append('falkordb') + def test_lucene_sanitize(): # Call the function with test data diff --git a/tests/test_edge_int.py b/tests/test_edge_int.py new file mode 100644 index 00000000..6eb769a4 --- /dev/null +++ b/tests/test_edge_int.py @@ -0,0 +1,384 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import logging +import sys +from datetime import datetime +from uuid import uuid4 + +import numpy as np +import pytest + +from graphiti_core.driver.driver import GraphDriver +from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphiti_core.embedder.openai import OpenAIEmbedder +from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode +from tests.helpers_test import drivers, get_driver + +pytestmark = pytest.mark.integration + +pytest_plugins = ('pytest_asyncio',) + +group_id = f'test_group_{str(uuid4())}' + + +def setup_logging(): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO + + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Add formatter to console handler + console_handler.setFormatter(formatter) + + # Add console handler to logger + logger.addHandler(console_handler) + + return logger + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_episodic_edge(driver): + graph_driver = get_driver(driver) + embedder = OpenAIEmbedder() + + now = datetime.now() + + episode_node = EpisodicNode( + name='test_episode', + labels=[], + created_at=now, + valid_at=now, + source=EpisodeType.message, + source_description='conversation message', + content='Alice likes Bob', + entity_edges=[], + group_id=group_id, + ) + + node_count = await get_node_count(graph_driver, episode_node.uuid) + assert node_count == 0 + await episode_node.save(graph_driver) + node_count = await get_node_count(graph_driver, episode_node.uuid) + assert node_count == 1 + + alice_node = EntityNode( + name='Alice', + labels=[], + created_at=now, + summary='Alice summary', + group_id=group_id, + ) + await alice_node.generate_name_embedding(embedder) + + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + await alice_node.save(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 1 + + episodic_edge = EpisodicEdge( + source_node_uuid=episode_node.uuid, + target_node_uuid=alice_node.uuid, + created_at=now, + group_id=group_id, + ) + + edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) + assert edge_count == 0 + await episodic_edge.save(graph_driver) + edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) + assert edge_count == 1 + + retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid) + assert retrieved.uuid == episodic_edge.uuid + assert retrieved.source_node_uuid == episode_node.uuid + assert retrieved.target_node_uuid == alice_node.uuid + assert retrieved.created_at == now + assert retrieved.group_id == group_id + + retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid]) + assert len(retrieved) == 1 + assert retrieved[0].uuid == episodic_edge.uuid + assert retrieved[0].source_node_uuid == episode_node.uuid + assert retrieved[0].target_node_uuid == alice_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == episodic_edge.uuid + assert retrieved[0].source_node_uuid == episode_node.uuid + assert 
retrieved[0].target_node_uuid == alice_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + await episodic_edge.delete(graph_driver) + edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) + assert edge_count == 0 + + await episode_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, episode_node.uuid) + assert node_count == 0 + + await alice_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + + await graph_driver.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_entity_edge(driver): + graph_driver = get_driver(driver) + embedder = OpenAIEmbedder() + + now = datetime.now() + + alice_node = EntityNode( + name='Alice', + labels=[], + created_at=now, + summary='Alice summary', + group_id=group_id, + ) + await alice_node.generate_name_embedding(embedder) + + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + await alice_node.save(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 1 + + bob_node = EntityNode( + name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id + ) + await bob_node.generate_name_embedding(embedder) + + node_count = await get_node_count(graph_driver, bob_node.uuid) + assert node_count == 0 + await bob_node.save(graph_driver) + node_count = await get_node_count(graph_driver, bob_node.uuid) + assert node_count == 1 + + entity_edge = EntityEdge( + source_node_uuid=alice_node.uuid, + target_node_uuid=bob_node.uuid, + created_at=now, + name='likes', + fact='Alice likes Bob', + episodes=[], + expired_at=now, + valid_at=now, + invalid_at=now, + group_id=group_id, + ) + edge_embedding = await entity_edge.generate_embedding(embedder) + + edge_count = await get_edge_count(graph_driver, entity_edge.uuid) + assert edge_count == 0 + await entity_edge.save(graph_driver) + edge_count = await get_edge_count(graph_driver, entity_edge.uuid) + assert edge_count == 1 + + retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid) + assert retrieved.uuid == entity_edge.uuid + assert retrieved.source_node_uuid == alice_node.uuid + assert retrieved.target_node_uuid == bob_node.uuid + assert retrieved.created_at == now + assert retrieved.group_id == group_id + + retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid]) + assert len(retrieved) == 1 + assert retrieved[0].uuid == entity_edge.uuid + assert retrieved[0].source_node_uuid == alice_node.uuid + assert retrieved[0].target_node_uuid == bob_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == entity_edge.uuid + assert retrieved[0].source_node_uuid == alice_node.uuid + assert retrieved[0].target_node_uuid == bob_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid) + assert len(retrieved) == 1 + assert retrieved[0].uuid == entity_edge.uuid + assert retrieved[0].source_node_uuid == alice_node.uuid + assert retrieved[0].target_node_uuid == bob_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + await 
entity_edge.load_fact_embedding(graph_driver) + assert np.allclose(entity_edge.fact_embedding, edge_embedding) + + await entity_edge.delete(graph_driver) + edge_count = await get_edge_count(graph_driver, entity_edge.uuid) + assert edge_count == 0 + + await alice_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + + await bob_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, bob_node.uuid) + assert node_count == 0 + + await graph_driver.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_community_edge(driver): + graph_driver = get_driver(driver) + embedder = OpenAIEmbedder() + + now = datetime.now() + + community_node_1 = CommunityNode( + name='Community A', + group_id=group_id, + summary='Community A summary', + ) + await community_node_1.generate_name_embedding(embedder) + node_count = await get_node_count(graph_driver, community_node_1.uuid) + assert node_count == 0 + await community_node_1.save(graph_driver) + node_count = await get_node_count(graph_driver, community_node_1.uuid) + assert node_count == 1 + + community_node_2 = CommunityNode( + name='Community B', + group_id=group_id, + summary='Community B summary', + ) + await community_node_2.generate_name_embedding(embedder) + node_count = await get_node_count(graph_driver, community_node_2.uuid) + assert node_count == 0 + await community_node_2.save(graph_driver) + node_count = await get_node_count(graph_driver, community_node_2.uuid) + assert node_count == 1 + + alice_node = EntityNode( + name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id + ) + await alice_node.generate_name_embedding(embedder) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + await alice_node.save(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 1 + + community_edge = CommunityEdge( + source_node_uuid=community_node_1.uuid, + target_node_uuid=community_node_2.uuid, + created_at=now, + group_id=group_id, + ) + edge_count = await get_edge_count(graph_driver, community_edge.uuid) + assert edge_count == 0 + await community_edge.save(graph_driver) + edge_count = await get_edge_count(graph_driver, community_edge.uuid) + assert edge_count == 1 + + retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid) + assert retrieved.uuid == community_edge.uuid + assert retrieved.source_node_uuid == community_node_1.uuid + assert retrieved.target_node_uuid == community_node_2.uuid + assert retrieved.created_at == now + assert retrieved.group_id == group_id + + retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid]) + assert len(retrieved) == 1 + assert retrieved[0].uuid == community_edge.uuid + assert retrieved[0].source_node_uuid == community_node_1.uuid + assert retrieved[0].target_node_uuid == community_node_2.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1) + assert len(retrieved) == 1 + assert retrieved[0].uuid == community_edge.uuid + assert retrieved[0].source_node_uuid == community_node_1.uuid + assert retrieved[0].target_node_uuid == community_node_2.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + await community_edge.delete(graph_driver) + edge_count = await 
get_edge_count(graph_driver, community_edge.uuid) + assert edge_count == 0 + + await alice_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + + await community_node_1.delete(graph_driver) + node_count = await get_node_count(graph_driver, community_node_1.uuid) + assert node_count == 0 + + await community_node_2.delete(graph_driver) + node_count = await get_node_count(graph_driver, community_node_2.uuid) + assert node_count == 0 + + await graph_driver.close() + + +async def get_node_count(driver: GraphDriver, uuid: str): + results, _, _ = await driver.execute_query( + """ + MATCH (n {uuid: $uuid}) + RETURN COUNT(n) as count + """, + uuid=uuid, + ) + return int(results[0]['count']) + + +async def get_edge_count(driver: GraphDriver, uuid: str): + results, _, _ = await driver.execute_query( + """ + MATCH (n)-[e {uuid: $uuid}]->(m) + RETURN COUNT(e) as count + UNION ALL + MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2) + RETURN COUNT(m) as count + """, + uuid=uuid, + ) + return sum(int(result['count']) for result in results) diff --git a/tests/test_entity_exclusion_int.py b/tests/test_entity_exclusion_int.py index 715e5ada..473177b1 100644 --- a/tests/test_entity_exclusion_int.py +++ b/tests/test_entity_exclusion_int.py @@ -14,26 +14,18 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os from datetime import datetime, timezone import pytest -from dotenv import load_dotenv from pydantic import BaseModel, Field from graphiti_core.graphiti import Graphiti from graphiti_core.helpers import validate_excluded_entity_types +from tests.helpers_test import drivers, get_driver pytestmark = pytest.mark.integration - pytest_plugins = ('pytest_asyncio',) -load_dotenv() - -NEO4J_URI = os.getenv('NEO4J_URI') -NEO4J_USER = os.getenv('NEO4J_USER') -NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD') - # Test entity type definitions class Person(BaseModel): @@ -65,9 +57,14 @@ class Location(BaseModel): @pytest.mark.asyncio -async def test_exclude_default_entity_type(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_default_entity_type(driver): """Test excluding the default 'Entity' type while keeping custom types.""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -118,9 +115,14 @@ async def test_exclude_default_entity_type(): @pytest.mark.asyncio -async def test_exclude_specific_custom_types(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_specific_custom_types(driver): """Test excluding specific custom entity types while keeping others.""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -177,9 +179,14 @@ async def test_exclude_specific_custom_types(): @pytest.mark.asyncio -async def test_exclude_all_types(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_all_types(driver): """Test excluding all entity types (edge case).""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -221,9 +228,14 @@ async def test_exclude_all_types(): @pytest.mark.asyncio -async def 
test_exclude_no_types(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_no_types(driver): """Test normal behavior when no types are excluded (baseline test).""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -299,9 +311,14 @@ def test_validation_invalid_excluded_types(): @pytest.mark.asyncio -async def test_excluded_types_parameter_validation_in_add_episode(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_excluded_types_parameter_validation_in_add_episode(driver): """Test that add_episode validates excluded_entity_types parameter.""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: entity_types = { diff --git a/tests/test_graphiti_falkordb_int.py b/tests/test_graphiti_falkordb_int.py deleted file mode 100644 index c41ef47e..00000000 --- a/tests/test_graphiti_falkordb_int.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import logging -import os -import sys -import unittest -from datetime import datetime, timezone - -import pytest -from dotenv import load_dotenv - -from graphiti_core.edges import EntityEdge, EpisodicEdge -from graphiti_core.graphiti import Graphiti -from graphiti_core.helpers import semaphore_gather -from graphiti_core.nodes import EntityNode, EpisodicNode -from graphiti_core.search.search_helpers import search_results_to_context_string - -try: - from graphiti_core.driver.falkordb_driver import FalkorDriver - - HAS_FALKORDB = True -except ImportError: - FalkorDriver = None - HAS_FALKORDB = False - -pytestmark = pytest.mark.integration - -pytest_plugins = ('pytest_asyncio',) - -load_dotenv() - -FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost') -FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') -FALKORDB_USER = os.getenv('FALKORDB_USER', None) -FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) - - -def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO - - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Add formatter to console handler - console_handler.setFormatter(formatter) - - # Add console handler to logger - logger.addHandler(console_handler) - - return logger - - -@pytest.mark.asyncio -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_graphiti_falkordb_init(): - logger = setup_logging() - - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - graphiti = Graphiti(graph_driver=falkor_driver) - - results = await 
graphiti.search_(query='Who is the user?') - - pretty_results = search_results_to_context_string(results) - - logger.info(pretty_results) - - await graphiti.close() - - -@pytest.mark.asyncio -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_graph_falkordb_integration(): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - client = Graphiti(graph_driver=falkor_driver) - embedder = client.embedder - driver = client.driver - - now = datetime.now(timezone.utc) - episode = EpisodicNode( - name='test_episode', - labels=[], - created_at=now, - valid_at=now, - source='message', - source_description='conversation message', - content='Alice likes Bob', - entity_edges=[], - ) - - alice_node = EntityNode( - name='Alice', - labels=[], - created_at=now, - summary='Alice summary', - ) - - bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary') - - episodic_edge_1 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now - ) - - episodic_edge_2 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now - ) - - entity_edge = EntityEdge( - source_node_uuid=alice_node.uuid, - target_node_uuid=bob_node.uuid, - created_at=now, - name='likes', - fact='Alice likes Bob', - episodes=[], - expired_at=now, - valid_at=now, - invalid_at=now, - ) - - await entity_edge.generate_embedding(embedder) - - nodes = [episode, alice_node, bob_node] - edges = [episodic_edge_1, episodic_edge_2, entity_edge] - - # test save - await semaphore_gather(*[node.save(driver) for node in nodes]) - await semaphore_gather(*[edge.save(driver) for edge in edges]) - - # test get - assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None - assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None - assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None - assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None - - # test delete - await semaphore_gather(*[node.delete(driver) for node in nodes]) - await semaphore_gather(*[edge.delete(driver) for edge in edges]) - - await client.close() diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index acec74cb..9d98f4de 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -15,31 +15,19 @@ limitations under the License. 
""" import logging -import os import sys -from datetime import datetime, timezone import pytest -from dotenv import load_dotenv -from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.graphiti import Graphiti -from graphiti_core.helpers import semaphore_gather -from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters from graphiti_core.search.search_helpers import search_results_to_context_string from graphiti_core.utils.datetime_utils import utc_now +from tests.helpers_test import drivers, get_driver pytestmark = pytest.mark.integration - pytest_plugins = ('pytest_asyncio',) -load_dotenv() - -NEO4J_URI = os.getenv('NEO4J_URI') -NEO4j_USER = os.getenv('NEO4J_USER') -NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD') - def setup_logging(): # Create a logger @@ -63,9 +51,18 @@ def setup_logging(): @pytest.mark.asyncio -async def test_graphiti_init(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_graphiti_init(driver): logger = setup_logging() - graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) + driver = get_driver(driver) + graphiti = Graphiti(graph_driver=driver) + + await graphiti.build_indices_and_constraints() + search_filter = SearchFilters( created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]] ) @@ -76,74 +73,6 @@ async def test_graphiti_init(): ) pretty_results = search_results_to_context_string(results) - logger.info(pretty_results) await graphiti.close() - - -@pytest.mark.asyncio -async def test_graph_integration(): - client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - embedder = client.embedder - driver = client.driver - - now = datetime.now(timezone.utc) - episode = EpisodicNode( - name='test_episode', - labels=[], - created_at=now, - valid_at=now, - source='message', - source_description='conversation message', - content='Alice likes Bob', - entity_edges=[], - ) - - alice_node = EntityNode( - name='Alice', - labels=[], - created_at=now, - summary='Alice summary', - ) - - bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary') - - episodic_edge_1 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now - ) - - episodic_edge_2 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now - ) - - entity_edge = EntityEdge( - source_node_uuid=alice_node.uuid, - target_node_uuid=bob_node.uuid, - created_at=now, - name='likes', - fact='Alice likes Bob', - episodes=[], - expired_at=now, - valid_at=now, - invalid_at=now, - ) - - await entity_edge.generate_embedding(embedder) - - nodes = [episode, alice_node, bob_node] - edges = [episodic_edge_1, episodic_edge_2, entity_edge] - - # test save - await semaphore_gather(*[node.save(driver) for node in nodes]) - await semaphore_gather(*[edge.save(driver) for edge in edges]) - - # test get - assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None - assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None - assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None - assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None - - # test delete - await semaphore_gather(*[node.delete(driver) for node in nodes]) - await semaphore_gather(*[edge.delete(driver) for edge in edges]) diff --git a/tests/test_node_falkordb_int.py b/tests/test_node_falkordb_int.py deleted file mode 100644 
index 35eaa578..00000000 --- a/tests/test_node_falkordb_int.py +++ /dev/null @@ -1,139 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import os -import unittest -from datetime import datetime, timezone -from uuid import uuid4 - -import pytest - -from graphiti_core.nodes import ( - CommunityNode, - EntityNode, - EpisodeType, - EpisodicNode, -) - -FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost') -FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') -FALKORDB_USER = os.getenv('FALKORDB_USER', None) -FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) - -try: - from graphiti_core.driver.falkordb_driver import FalkorDriver - - HAS_FALKORDB = True -except ImportError: - FalkorDriver = None - HAS_FALKORDB = False - - -@pytest.fixture -def sample_entity_node(): - return EntityNode( - uuid=str(uuid4()), - name='Test Entity', - group_id='test_group', - labels=['Entity'], - name_embedding=[0.5] * 1024, - summary='Entity Summary', - ) - - -@pytest.fixture -def sample_episodic_node(): - return EpisodicNode( - uuid=str(uuid4()), - name='Episode 1', - group_id='test_group', - source=EpisodeType.text, - source_description='Test source', - content='Some content here', - valid_at=datetime.now(timezone.utc), - ) - - -@pytest.fixture -def sample_community_node(): - return CommunityNode( - uuid=str(uuid4()), - name='Community A', - name_embedding=[0.5] * 1024, - group_id='test_group', - summary='Community summary', - ) - - -@pytest.mark.asyncio -@pytest.mark.integration -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_entity_node_save_get_and_delete(sample_entity_node): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - await sample_entity_node.save(falkor_driver) - - retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid) - assert retrieved.uuid == sample_entity_node.uuid - assert retrieved.name == 'Test Entity' - assert retrieved.group_id == 'test_group' - - await sample_entity_node.delete(falkor_driver) - await falkor_driver.close() - - -@pytest.mark.asyncio -@pytest.mark.integration -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_community_node_save_get_and_delete(sample_community_node): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - await sample_community_node.save(falkor_driver) - - retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid) - assert retrieved.uuid == sample_community_node.uuid - assert retrieved.name == 'Community A' - assert retrieved.group_id == 'test_group' - assert retrieved.summary == 'Community summary' - - await sample_community_node.delete(falkor_driver) - await falkor_driver.close() - - -@pytest.mark.asyncio -@pytest.mark.integration -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def 
test_episodic_node_save_get_and_delete(sample_episodic_node): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - await sample_episodic_node.save(falkor_driver) - - retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid) - assert retrieved.uuid == sample_episodic_node.uuid - assert retrieved.name == 'Episode 1' - assert retrieved.group_id == 'test_group' - assert retrieved.source == EpisodeType.text - assert retrieved.source_description == 'Test source' - assert retrieved.content == 'Some content here' - - await sample_episodic_node.delete(falkor_driver) - await falkor_driver.close() diff --git a/tests/test_node_int.py b/tests/test_node_int.py index 9f50f18a..b4aa0709 100644 --- a/tests/test_node_int.py +++ b/tests/test_node_int.py @@ -14,23 +14,22 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os -from datetime import datetime, timezone +from datetime import datetime from uuid import uuid4 +import numpy as np import pytest -from neo4j import AsyncGraphDatabase +from graphiti_core.driver.driver import GraphDriver from graphiti_core.nodes import ( CommunityNode, EntityNode, EpisodeType, EpisodicNode, ) +from tests.helpers_test import drivers, get_driver -NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687') -NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j') -NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test') +group_id = f'test_group_{str(uuid4())}' @pytest.fixture @@ -38,8 +37,8 @@ def sample_entity_node(): return EntityNode( uuid=str(uuid4()), name='Test Entity', - group_id='test_group', - labels=['Entity'], + group_id=group_id, + labels=[], name_embedding=[0.5] * 1024, summary='Entity Summary', ) @@ -50,11 +49,11 @@ def sample_episodic_node(): return EpisodicNode( uuid=str(uuid4()), name='Episode 1', - group_id='test_group', + group_id=group_id, source=EpisodeType.text, source_description='Test source', content='Some content here', - valid_at=datetime.now(timezone.utc), + valid_at=datetime.now(), ) @@ -64,59 +63,181 @@ def sample_community_node(): uuid=str(uuid4()), name='Community A', name_embedding=[0.5] * 1024, - group_id='test_group', + group_id=group_id, summary='Community summary', ) @pytest.mark.asyncio -@pytest.mark.integration -async def test_entity_node_save_get_and_delete(sample_entity_node): - neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) - await sample_entity_node.save(neo4j_driver) - retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid) +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_entity_node(sample_entity_node, driver): + driver = get_driver(driver) + uuid = sample_entity_node.uuid + + # Create node + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + await sample_entity_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + + retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid) assert retrieved.uuid == sample_entity_node.uuid assert retrieved.name == 'Test Entity' - assert retrieved.group_id == 'test_group' + assert retrieved.group_id == group_id - await sample_entity_node.delete(neo4j_driver) + retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid]) + assert retrieved[0].uuid == sample_entity_node.uuid + assert retrieved[0].name == 'Test Entity' + assert retrieved[0].group_id == 
group_id - await neo4j_driver.close() + retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == sample_entity_node.uuid + assert retrieved[0].name == 'Test Entity' + assert retrieved[0].group_id == group_id + + await sample_entity_node.load_name_embedding(driver) + assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024) + + # Delete node by uuid + await sample_entity_node.delete(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + # Delete node by group id + await sample_entity_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + await sample_entity_node.delete_by_group_id(driver, group_id) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + await driver.close() @pytest.mark.asyncio -@pytest.mark.integration -async def test_community_node_save_get_and_delete(sample_community_node): - neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_community_node(sample_community_node, driver): + driver = get_driver(driver) + uuid = sample_community_node.uuid - await sample_community_node.save(neo4j_driver) + # Create node + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + await sample_community_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 - retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid) + retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid) assert retrieved.uuid == sample_community_node.uuid assert retrieved.name == 'Community A' - assert retrieved.group_id == 'test_group' + assert retrieved.group_id == group_id assert retrieved.summary == 'Community summary' - await sample_community_node.delete(neo4j_driver) + retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid]) + assert retrieved[0].uuid == sample_community_node.uuid + assert retrieved[0].name == 'Community A' + assert retrieved[0].group_id == group_id + assert retrieved[0].summary == 'Community summary' - await neo4j_driver.close() + retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == sample_community_node.uuid + assert retrieved[0].name == 'Community A' + assert retrieved[0].group_id == group_id + + # Delete node by uuid + await sample_community_node.delete(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + # Delete node by group id + await sample_community_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + await sample_community_node.delete_by_group_id(driver, group_id) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + await driver.close() @pytest.mark.asyncio -@pytest.mark.integration -async def test_episodic_node_save_get_and_delete(sample_episodic_node): - neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_episodic_node(sample_episodic_node, driver): + driver = get_driver(driver) + uuid = sample_episodic_node.uuid - await sample_episodic_node.save(neo4j_driver) + # Create node + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + await 
sample_episodic_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 - retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid) + retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid) assert retrieved.uuid == sample_episodic_node.uuid assert retrieved.name == 'Episode 1' - assert retrieved.group_id == 'test_group' + assert retrieved.group_id == group_id assert retrieved.source == EpisodeType.text assert retrieved.source_description == 'Test source' assert retrieved.content == 'Some content here' + assert retrieved.valid_at == sample_episodic_node.valid_at - await sample_episodic_node.delete(neo4j_driver) + retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid]) + assert retrieved[0].uuid == sample_episodic_node.uuid + assert retrieved[0].name == 'Episode 1' + assert retrieved[0].group_id == group_id + assert retrieved[0].source == EpisodeType.text + assert retrieved[0].source_description == 'Test source' + assert retrieved[0].content == 'Some content here' + assert retrieved[0].valid_at == sample_episodic_node.valid_at - await neo4j_driver.close() + retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == sample_episodic_node.uuid + assert retrieved[0].name == 'Episode 1' + assert retrieved[0].group_id == group_id + assert retrieved[0].source == EpisodeType.text + assert retrieved[0].source_description == 'Test source' + assert retrieved[0].content == 'Some content here' + assert retrieved[0].valid_at == sample_episodic_node.valid_at + + # Delete node by uuid + await sample_episodic_node.delete(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + # Delete node by group id + await sample_episodic_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + await sample_episodic_node.delete_by_group_id(driver, group_id) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + await driver.close() + + +async def get_node_count(driver: GraphDriver, uuid: str): + result, _, _ = await driver.execute_query( + """ + MATCH (n {uuid: $uuid}) + RETURN COUNT(n) as count + """, + uuid=uuid, + ) + return int(result[0]['count'])
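Note on usage (illustrative sketch, not part of the patch): the new tests/helpers_test.py above centralizes backend selection by reading DISABLE_NEO4J / DISABLE_FALKORDB and the connection environment variables, and exposes `drivers` plus `get_driver()`. The rewritten integration tests consume these through `@pytest.mark.parametrize`, so each test runs once per configured backend. A minimal sketch of an additional test following the same pattern; everything other than the imported helpers and the EntityNode fields shown in the tests above (e.g. the test name and sample values) is hypothetical:

    import pytest

    from graphiti_core.nodes import EntityNode
    from tests.helpers_test import drivers, get_driver

    pytestmark = pytest.mark.integration
    pytest_plugins = ('pytest_asyncio',)


    @pytest.mark.asyncio
    @pytest.mark.parametrize('driver', drivers, ids=drivers)
    async def test_entity_round_trip(driver):
        # get_driver returns a Neo4j or FalkorDB GraphDriver for the parametrized backend name.
        graph_driver = get_driver(driver)

        node = EntityNode(
            name='Example Entity',
            labels=[],
            group_id='example_group',
            name_embedding=[0.5] * 1024,
            summary='Example summary',
        )
        # Save, read back by uuid, then clean up, mirroring the save/get/delete cycle used above.
        await node.save(graph_driver)
        retrieved = await EntityNode.get_by_uuid(graph_driver, node.uuid)
        assert retrieved.uuid == node.uuid

        await node.delete(graph_driver)
        await graph_driver.close()

Locally, setting DISABLE_NEO4J=1 or DISABLE_FALKORDB=1 removes that backend from `drivers`, so the parametrized integration tests only execute against the database that is actually available.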