graphiti/tests/test_edge_int.py
Daniel Chalef dcc9da3f68
chore/prepare kuzu integration (#762)
* Prepare code

* Fix tests

* As -> AS, remove trailing spaces

* Enable more tests for FalkorDB

* Fix more cypher queries

* Return all created nodes and edges

* Add Neo4j service to unit tests workflow

- Introduced Neo4j as a service in the GitHub Actions workflow for unit tests.
- Configured Neo4j with appropriate ports, authentication, and health checks.
- Updated test steps to include waiting for Neo4j and running integration tests against it.
- Set environment variables for Neo4j connection in both non-integration and integration test steps.

* Update Neo4j authentication in unit tests workflow

- Changed Neo4j authentication password from 'test' to 'testpass' in the GitHub Actions workflow.
- Updated health check command to reflect the new password.
- Ensured consistency across all test steps that utilize Neo4j credentials.

* fix health check

* Fix Neo4j integration tests in CI workflow

Remove reference to non-existent test_neo4j_driver.py file from test command.
Integration tests now run via parametrized tests using the drivers list.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Add OPENAI_API_KEY to Neo4j integration tests

Neo4j integration tests require OpenAI API access for LLM functionality.
Add the secret environment variable to enable these tests to run properly.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix Neo4j Cypher syntax error in BFS search queries

Replace parameter substitution in relationship pattern ranges (*1..$depth)
with direct string interpolation (*1..{bfs_max_depth}). Neo4j doesn't allow
parameter maps in MATCH pattern ranges - they must be literal values.

Fixed in both node_bfs_search and edge_bfs_search functions.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix variable name mismatch in edge_bfs_search query

Change relationship variable from 'r' to 'e' to match ENTITY_EDGE_RETURN
constant expectations. The ENTITY_EDGE_RETURN constant references variable
'e' for relationships, but the query was using 'r', causing "Variable e
not defined" errors.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Isolate database tests in CI workflow

- FalkorDB tests: Add DISABLE_NEO4J=1 and remove Neo4j env vars
- Neo4j tests: Keep current setup without DISABLE_NEO4J flag

This ensures proper test isolation where each test suite only runs
against its intended database backend.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Siddhartha Sahu <sid@kuzudb.com>
Co-authored-by: Claude <noreply@anthropic.com>
2025-07-29 09:07:34 -04:00

384 lines
13 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import sys
from datetime import datetime
from uuid import uuid4
import numpy as np
import pytest
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import drivers, get_driver
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
group_id = f'test_group_{str(uuid4())}'
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_episodic_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
now = datetime.now()
episode_node = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source=EpisodeType.message,
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
group_id=group_id,
)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 0
await episode_node.save(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 1
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 1
episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 0
await episodic_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 1
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid
assert retrieved.target_node_uuid == alice_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
await episodic_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
assert edge_count == 0
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, episode_node.uuid)
assert node_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_entity_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
now = datetime.now()
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 1
bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
)
await bob_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 0
await bob_node.save(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 1
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
group_id=group_id,
)
edge_embedding = await entity_edge.generate_embedding(embedder)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 0
await entity_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 1
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid
assert retrieved.target_node_uuid == bob_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
await entity_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
assert edge_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await bob_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, bob_node.uuid)
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'driver',
drivers,
ids=drivers,
)
async def test_community_edge(driver):
graph_driver = get_driver(driver)
embedder = OpenAIEmbedder()
now = datetime.now()
community_node_1 = CommunityNode(
name='Community A',
group_id=group_id,
summary='Community A summary',
)
await community_node_1.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
assert node_count == 0
await community_node_1.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
assert node_count == 1
community_node_2 = CommunityNode(
name='Community B',
group_id=group_id,
summary='Community B summary',
)
await community_node_2.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
assert node_count == 0
await community_node_2.save(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
assert node_count == 1
alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
)
await alice_node.generate_name_embedding(embedder)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 1
community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
assert edge_count == 0
await community_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
assert edge_count == 1
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid
assert retrieved.target_node_uuid == community_node_2.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
await community_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
assert edge_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, alice_node.uuid)
assert node_count == 0
await community_node_1.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_1.uuid)
assert node_count == 0
await community_node_2.delete(graph_driver)
node_count = await get_node_count(graph_driver, community_node_2.uuid)
assert node_count == 0
await graph_driver.close()
async def get_node_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
RETURN COUNT(n) as count
""",
uuid=uuid,
)
return int(results[0]['count'])
async def get_edge_count(driver: GraphDriver, uuid: str):
results, _, _ = await driver.execute_query(
"""
MATCH (n)-[e {uuid: $uuid}]->(m)
RETURN COUNT(e) as count
UNION ALL
MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
RETURN COUNT(m) as count
""",
uuid=uuid,
)
return sum(int(result['count']) for result in results)