* Prepare code * Fix tests * As -> AS, remove trailing spaces * Enable more tests for FalkorDB * Fix more cypher queries * Return all created nodes and edges * Add Neo4j service to unit tests workflow - Introduced Neo4j as a service in the GitHub Actions workflow for unit tests. - Configured Neo4j with appropriate ports, authentication, and health checks. - Updated test steps to include waiting for Neo4j and running integration tests against it. - Set environment variables for Neo4j connection in both non-integration and integration test steps. * Update Neo4j authentication in unit tests workflow - Changed Neo4j authentication password from 'test' to 'testpass' in the GitHub Actions workflow. - Updated health check command to reflect the new password. - Ensured consistency across all test steps that utilize Neo4j credentials. * fix health check * Fix Neo4j integration tests in CI workflow Remove reference to non-existent test_neo4j_driver.py file from test command. Integration tests now run via parametrized tests using the drivers list. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Add OPENAI_API_KEY to Neo4j integration tests Neo4j integration tests require OpenAI API access for LLM functionality. Add the secret environment variable to enable these tests to run properly. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix Neo4j Cypher syntax error in BFS search queries Replace parameter substitution in relationship pattern ranges (*1..$depth) with direct string interpolation (*1..{bfs_max_depth}). Neo4j doesn't allow parameter maps in MATCH pattern ranges - they must be literal values. Fixed in both node_bfs_search and edge_bfs_search functions. 
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix variable name mismatch in edge_bfs_search query Change relationship variable from 'r' to 'e' to match ENTITY_EDGE_RETURN constant expectations. The ENTITY_EDGE_RETURN constant references variable 'e' for relationships, but the query was using 'r', causing "Variable e not defined" errors. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Isolate database tests in CI workflow - FalkorDB tests: Add DISABLE_NEO4J=1 and remove Neo4j env vars - Neo4j tests: Keep current setup without DISABLE_NEO4J flag This ensures proper test isolation where each test suite only runs against its intended database backend. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Siddhartha Sahu <sid@kuzudb.com> Co-authored-by: Claude <noreply@anthropic.com>
356 lines
12 KiB
Python
356 lines
12 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
from pydantic import BaseModel, Field
|
|
|
|
from graphiti_core.graphiti import Graphiti
|
|
from graphiti_core.helpers import validate_excluded_entity_types
|
|
from tests.helpers_test import drivers, get_driver
|
|
|
|
pytestmark = pytest.mark.integration
|
|
pytest_plugins = ('pytest_asyncio',)
|
|
|
|
|
|
# Test entity type definitions
|
|
class Person(BaseModel):
    """A human person mentioned in the conversation."""

    # NOTE(review): the class docstring and the field descriptions may be read
    # at runtime (e.g. as schema text) — wording is kept verbatim; confirm
    # before rephrasing.
    # Every attribute is optional and defaults to None.
    first_name: str | None = Field(default=None, description='First name of the person')
    last_name: str | None = Field(default=None, description='Last name of the person')
    occupation: str | None = Field(default=None, description='Job or profession of the person')
|
|
|
|
|
|
class Organization(BaseModel):
    """A company, institution, or organized group."""

    # NOTE(review): the class docstring and the field descriptions may be read
    # at runtime (e.g. as schema text) — wording is kept verbatim; confirm
    # before rephrasing.
    # Both attributes are optional and default to None.
    organization_type: str | None = Field(
        default=None,
        description='Type of organization (company, NGO, etc.)',
    )
    industry: str | None = Field(
        default=None,
        description='Industry or sector the organization operates in',
    )
|
|
|
|
|
|
class Location(BaseModel):
    """A geographic location, place, or address."""

    # NOTE(review): the class docstring and the field descriptions may be read
    # at runtime (e.g. as schema text) — wording is kept verbatim; confirm
    # before rephrasing.
    # Both attributes are optional and default to None.
    location_type: str | None = Field(
        default=None,
        description='Type of location (city, country, building, etc.)',
    )
    coordinates: str | None = Field(
        default=None,
        description='Geographic coordinates if available',
    )
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_default_entity_type(driver):
    """Test excluding the default 'Entity' type while keeping custom types.

    Adds one episode with ``Person`` and ``Organization`` registered as custom
    types and ``'Entity'`` listed in ``excluded_entity_types``, then checks
    that every node returned by a search in the test group carries the
    ``'Entity'`` label plus at least one of the custom type labels.

    Args:
        driver: One entry of ``drivers``; passed to ``get_driver`` to build
            the graph driver under test.
    """
    group_id = 'test_exclude_default'
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Define entity types but exclude the default 'Entity' type
        entity_types = {
            'Person': Person,
            'Organization': Organization,
        }

        # Add an episode that would normally create both Entity and custom type entities
        episode_content = (
            'John Smith works at Acme Corporation in New York. The weather is nice today.'
        )

        result = await graphiti.add_episode(
            name='Business Meeting',
            episode_body=episode_content,
            source_description='Meeting notes',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=['Entity'],  # Exclude default type
            group_id=group_id,
        )

        # The call must hand back a result object
        assert result is not None

        # Search for nodes to verify only custom types were created
        search_results = await graphiti.search_(
            query='John Smith Acme Corporation', group_ids=[group_id]
        )

        # NOTE(review): this loop passes vacuously when the search returns no
        # nodes — consider asserting a non-empty result if that is intended.
        for node in search_results.nodes:
            assert 'Entity' in node.labels  # All nodes should have Entity label
            # But they should also have specific type labels
            assert any(label in ['Person', 'Organization'] for label in node.labels), (
                f'Node {node.name} should have a specific type label, got: {node.labels}'
            )

    finally:
        # Clean up here rather than at the end of the try body, so a failed
        # assertion does not leave this group's nodes behind for later runs.
        try:
            await _cleanup_test_nodes(graphiti, group_id)
        finally:
            await graphiti.close()
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_specific_custom_types(driver):
    """Test excluding specific custom entity types while keeping others.

    Registers ``Person``, ``Organization`` and ``Location``, excludes the
    latter two, and checks that no node returned by a search in the test
    group carries an excluded label while at least one ``Person`` node exists.

    Args:
        driver: One entry of ``drivers``; passed to ``get_driver`` to build
            the graph driver under test.
    """
    group_id = 'test_exclude_custom'
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Define multiple entity types
        entity_types = {
            'Person': Person,
            'Organization': Organization,
            'Location': Location,
        }

        # Add an episode with content that would create all types
        episode_content = (
            'Sarah Johnson from Google visited the San Francisco office to discuss the new project.'
        )

        result = await graphiti.add_episode(
            name='Office Visit',
            episode_body=episode_content,
            source_description='Visit report',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=['Organization', 'Location'],  # Exclude these types
            group_id=group_id,
        )

        assert result is not None

        # Search for nodes to verify only Person and Entity types were created
        search_results = await graphiti.search_(
            query='Sarah Johnson Google San Francisco', group_ids=[group_id]
        )
        found_nodes = search_results.nodes

        # Should have Person and Entity type nodes, but no Organization or Location
        for node in found_nodes:
            assert 'Entity' in node.labels
            # Should not have excluded types
            assert 'Organization' not in node.labels, (
                f'Found excluded Organization in node: {node.name}'
            )
            assert 'Location' not in node.labels, f'Found excluded Location in node: {node.name}'

        # Should find at least one Person entity (Sarah Johnson)
        person_nodes = [n for n in found_nodes if 'Person' in n.labels]
        assert len(person_nodes) > 0, 'Should have found at least one Person entity'

    finally:
        # Clean up here rather than at the end of the try body, so a failed
        # assertion does not leave this group's nodes behind for later runs.
        try:
            await _cleanup_test_nodes(graphiti, group_id)
        finally:
            await graphiti.close()
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_all_types(driver):
    """Test excluding all entity types (edge case).

    Excludes ``'Entity'`` together with both registered custom types and
    checks that a search in the test group returns no nodes at all.

    Args:
        driver: One entry of ``drivers``; passed to ``get_driver`` to build
            the graph driver under test.
    """
    group_id = 'test_exclude_all'
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        entity_types = {
            'Person': Person,
            'Organization': Organization,
        }

        # Exclude all types
        result = await graphiti.add_episode(
            name='No Entities',
            episode_body='This text mentions John and Microsoft but no entities should be created.',
            source_description='Test content',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=['Entity', 'Person', 'Organization'],  # Exclude everything
            group_id=group_id,
        )

        assert result is not None

        # Search for nodes - none are expected for this group
        search_results = await graphiti.search_(
            query='John Microsoft', group_ids=[group_id]
        )

        # The check is strict: exactly zero nodes may be returned
        found_nodes = search_results.nodes
        assert len(found_nodes) == 0, (
            f'Expected no entities, but found: {[n.name for n in found_nodes]}'
        )

    finally:
        # Clean up here rather than at the end of the try body: if the
        # zero-node assertion fails, the offending nodes would otherwise stay
        # in this group and keep failing the same assertion on later runs.
        try:
            await _cleanup_test_nodes(graphiti, group_id)
        finally:
            await graphiti.close()
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_exclude_no_types(driver):
    """Test normal behavior when no types are excluded (baseline test).

    Passes ``excluded_entity_types=None`` and checks that a search in the test
    group returns at least one ``Person`` node and one ``Organization`` node.

    Args:
        driver: One entry of ``drivers``; passed to ``get_driver`` to build
            the graph driver under test.
    """
    group_id = 'test_exclude_none'
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        entity_types = {
            'Person': Person,
            'Organization': Organization,
        }

        # Don't exclude any types
        result = await graphiti.add_episode(
            name='Normal Behavior',
            episode_body='Alice Smith works at TechCorp.',
            source_description='Normal test',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=None,  # No exclusions
            group_id=group_id,
        )

        assert result is not None

        # Search for nodes - should find entities of all types
        search_results = await graphiti.search_(
            query='Alice Smith TechCorp', group_ids=[group_id]
        )

        found_nodes = search_results.nodes
        assert len(found_nodes) > 0, 'Should have found some entities'

        # Should have both Person and Organization entities
        person_nodes = [n for n in found_nodes if 'Person' in n.labels]
        org_nodes = [n for n in found_nodes if 'Organization' in n.labels]

        assert len(person_nodes) > 0, 'Should have found Person entities'
        assert len(org_nodes) > 0, 'Should have found Organization entities'

    finally:
        # Clean up here rather than at the end of the try body, so a failed
        # assertion does not leave this group's nodes behind for later runs.
        try:
            await _cleanup_test_nodes(graphiti, group_id)
        finally:
            await graphiti.close()
|
|
|
|
|
|
def test_validation_valid_excluded_types():
    """Check that the validator returns True for every acceptable exclusion list."""
    known_types = {'Person': Person, 'Organization': Organization}

    # Inputs the validator is expected to accept, checked in this order:
    # the default type, a registered custom type, both together, and the two
    # "nothing excluded" spellings (None and an empty list).
    accepted_exclusions = (
        ['Entity'],
        ['Person'],
        ['Entity', 'Person'],
        None,
        [],
    )

    for excluded in accepted_exclusions:
        assert validate_excluded_entity_types(excluded, known_types) is True
|
|
|
|
|
|
def test_validation_invalid_excluded_types():
    """Check that the validator raises ValueError for unknown type names."""
    known_types = {'Person': Person, 'Organization': Organization}

    # An unknown name on its own, and an unknown name mixed with a known one.
    rejected_exclusions = (
        ['InvalidType'],
        ['Person', 'NonExistentType'],
    )

    for excluded in rejected_exclusions:
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            validate_excluded_entity_types(excluded, known_types)
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers, ids=drivers)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
    """Check that ``add_episode`` rejects an unknown name in ``excluded_entity_types``.

    Only ``Person`` is registered, so excluding ``'NonExistentType'`` is
    expected to raise ``ValueError`` with the validator's message.
    """
    client = Graphiti(graph_driver=get_driver(driver))
    expected_error = pytest.raises(ValueError, match='Invalid excluded entity types')

    try:
        # Should raise ValueError for invalid excluded type
        with expected_error:
            await client.add_episode(
                name='Invalid Test',
                episode_body='Test content',
                source_description='Test',
                reference_time=datetime.now(timezone.utc),
                entity_types={'Person': Person},
                excluded_entity_types=['NonExistentType'],
                group_id='test_validation',
            )
    finally:
        # The driver is closed whether or not the expected error was raised.
        await client.close()
|
|
|
|
|
|
async def _cleanup_test_nodes(graphiti: Graphiti, group_id: str):
    """Best-effort removal of the nodes a test left under ``group_id``.

    Runs a ``'*'`` search restricted to the group and deletes each node in the
    result. Any ``Exception`` raised along the way is reported with ``print``
    and swallowed, so a cleanup problem never fails the calling test.

    NOTE(review): only nodes returned by the search are deleted — whether that
    covers every node in the group depends on ``search_``; confirm.
    """
    try:
        matches = await graphiti.search_(query='*', group_ids=[group_id])
        stale_nodes = matches.nodes

        # Delete one at a time; the first failure aborts the remaining deletes.
        for stale in stale_nodes:
            await stale.delete(graphiti.driver)
    except Exception as exc:
        # Log but don't fail the test if cleanup fails
        print(f'Warning: Failed to clean up test nodes for group {group_id}: {exc}')
|