Excluded entity type filtering (#624)
* excluded entities filtering * Fix variable name casing in test_entity_exclusion_int.py for consistency
This commit is contained in:
parent
9c8a20e16f
commit
c29893d972
6 changed files with 395 additions and 5 deletions
|
|
@ -29,7 +29,12 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
|
||||
from graphiti_core.helpers import (
|
||||
DEFAULT_DATABASE,
|
||||
semaphore_gather,
|
||||
validate_excluded_entity_types,
|
||||
validate_group_id,
|
||||
)
|
||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import SearchConfig, search
|
||||
|
|
@ -293,6 +298,7 @@ class Graphiti:
|
|||
uuid: str | None = None,
|
||||
update_communities: bool = False,
|
||||
entity_types: dict[str, BaseModel] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
previous_episode_uuids: list[str] | None = None,
|
||||
edge_types: dict[str, BaseModel] | None = None,
|
||||
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||
|
|
@ -321,6 +327,12 @@ class Graphiti:
|
|||
Optional uuid of the episode.
|
||||
update_communities : bool
|
||||
Optional. Whether to update communities with new node information
|
||||
entity_types : dict[str, BaseModel] | None
|
||||
Optional. Dictionary mapping entity type names to their Pydantic model definitions.
|
||||
excluded_entity_types : list[str] | None
|
||||
Optional. List of entity type names to exclude from the graph. Entities classified
|
||||
into these types will not be added to the graph. Can include 'Entity' to exclude
|
||||
the default entity type.
|
||||
previous_episode_uuids : list[str] | None
|
||||
Optional. list of episode uuids to use as the previous episodes. If this is not provided,
|
||||
the most recent episodes by created_at date will be used.
|
||||
|
|
@ -351,6 +363,7 @@ class Graphiti:
|
|||
now = utc_now()
|
||||
|
||||
validate_entity_types(entity_types)
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
|
||||
previous_episodes = (
|
||||
|
|
@ -389,7 +402,7 @@ class Graphiti:
|
|||
# Extract entities as nodes
|
||||
|
||||
extracted_nodes = await extract_nodes(
|
||||
self.clients, episode, previous_episodes, entity_types
|
||||
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
||||
)
|
||||
|
||||
# Extract edges and resolve nodes
|
||||
|
|
@ -534,7 +547,7 @@ class Graphiti:
|
|||
extracted_nodes,
|
||||
extracted_edges,
|
||||
episodic_edges,
|
||||
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
|
||||
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
|
||||
|
||||
# Generate embeddings
|
||||
await semaphore_gather(
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import numpy as np
|
|||
from dotenv import load_dotenv
|
||||
from neo4j import time as neo4j_time
|
||||
from numpy._typing import NDArray
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.errors import GroupIdValidationError
|
||||
|
|
@ -132,3 +133,37 @@ def validate_group_id(group_id: str) -> bool:
|
|||
raise GroupIdValidationError(group_id)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_excluded_entity_types(
|
||||
excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that excluded entity types are valid type names.
|
||||
|
||||
Args:
|
||||
excluded_entity_types: List of entity type names to exclude
|
||||
entity_types: Dictionary of available custom entity types
|
||||
|
||||
Returns:
|
||||
True if valid
|
||||
|
||||
Raises:
|
||||
ValueError: If any excluded type names are invalid
|
||||
"""
|
||||
if not excluded_entity_types:
|
||||
return True
|
||||
|
||||
# Build set of available type names
|
||||
available_types = {'Entity'} # Default type is always available
|
||||
if entity_types:
|
||||
available_types.update(entity_types.keys())
|
||||
|
||||
# Check for invalid type names
|
||||
invalid_types = set(excluded_entity_types) - available_types
|
||||
if invalid_types:
|
||||
raise ValueError(
|
||||
f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -586,6 +586,8 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|||
|
||||
|
||||
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
||||
if not nodes: # Handle empty list case
|
||||
return
|
||||
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
||||
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||
node.name_embedding = name_embedding
|
||||
|
|
|
|||
|
|
@ -177,11 +177,14 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||
clients: GraphitiClients,
|
||||
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
||||
entity_types: dict[str, BaseModel] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
||||
extracted_nodes_bulk = await semaphore_gather(
|
||||
*[
|
||||
extract_nodes(clients, episode, previous_episodes)
|
||||
extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
|
||||
for episode, previous_episodes in episode_tuples
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ async def extract_nodes(
|
|||
episode: EpisodicNode,
|
||||
previous_episodes: list[EpisodicNode],
|
||||
entity_types: dict[str, BaseModel] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
) -> list[EntityNode]:
|
||||
start = time()
|
||||
llm_client = clients.llm_client
|
||||
|
|
@ -154,6 +155,11 @@ async def extract_nodes(
|
|||
'entity_type_name'
|
||||
)
|
||||
|
||||
# Check if this entity type should be excluded
|
||||
if excluded_entity_types and entity_type_name in excluded_entity_types:
|
||||
logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
|
||||
continue
|
||||
|
||||
labels: list[str] = list({'Entity', str(entity_type_name)})
|
||||
|
||||
new_node = EntityNode(
|
||||
|
|
|
|||
331
tests/test_entity_exclusion_int.py
Normal file
331
tests/test_entity_exclusion_int.py
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.helpers import validate_excluded_entity_types
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
NEO4J_URI = os.getenv('NEO4J_URI')
|
||||
NEO4J_USER = os.getenv('NEO4J_USER')
|
||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
|
||||
|
||||
|
||||
# Test entity type definitions
|
||||
class Person(BaseModel):
|
||||
"""A human person mentioned in the conversation."""
|
||||
|
||||
first_name: str | None = Field(None, description='First name of the person')
|
||||
last_name: str | None = Field(None, description='Last name of the person')
|
||||
occupation: str | None = Field(None, description='Job or profession of the person')
|
||||
|
||||
|
||||
class Organization(BaseModel):
|
||||
"""A company, institution, or organized group."""
|
||||
|
||||
organization_type: str | None = Field(None, description='Type of organization (company, NGO, etc.)')
|
||||
industry: str | None = Field(None, description='Industry or sector the organization operates in')
|
||||
|
||||
|
||||
class Location(BaseModel):
|
||||
"""A geographic location, place, or address."""
|
||||
|
||||
location_type: str | None = Field(None, description='Type of location (city, country, building, etc.)')
|
||||
coordinates: str | None = Field(None, description='Geographic coordinates if available')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_default_entity_type():
|
||||
"""Test excluding the default 'Entity' type while keeping custom types."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
# Define entity types but exclude the default 'Entity' type
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
'Organization': Organization,
|
||||
}
|
||||
|
||||
# Add an episode that would normally create both Entity and custom type entities
|
||||
episode_content = "John Smith works at Acme Corporation in New York. The weather is nice today."
|
||||
|
||||
result = await graphiti.add_episode(
|
||||
name="Business Meeting",
|
||||
episode_body=episode_content,
|
||||
source_description="Meeting notes",
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
entity_types=entity_types,
|
||||
excluded_entity_types=['Entity'], # Exclude default type
|
||||
group_id='test_exclude_default'
|
||||
)
|
||||
|
||||
# Verify that nodes were created (custom types should still work)
|
||||
assert result is not None
|
||||
|
||||
# Search for nodes to verify only custom types were created
|
||||
search_results = await graphiti.search_(
|
||||
query="John Smith Acme Corporation",
|
||||
group_ids=['test_exclude_default']
|
||||
)
|
||||
|
||||
# Check that entities were created but with specific types, not default 'Entity'
|
||||
found_nodes = search_results.nodes
|
||||
for node in found_nodes:
|
||||
assert 'Entity' in node.labels # All nodes should have Entity label
|
||||
# But they should also have specific type labels
|
||||
assert any(label in ['Person', 'Organization'] for label in node.labels), \
|
||||
f"Node {node.name} should have a specific type label, got: {node.labels}"
|
||||
|
||||
# Clean up
|
||||
await _cleanup_test_nodes(graphiti, 'test_exclude_default')
|
||||
|
||||
finally:
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_specific_custom_types():
|
||||
"""Test excluding specific custom entity types while keeping others."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
# Define multiple entity types
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
'Organization': Organization,
|
||||
'Location': Location,
|
||||
}
|
||||
|
||||
# Add an episode with content that would create all types
|
||||
episode_content = "Sarah Johnson from Google visited the San Francisco office to discuss the new project."
|
||||
|
||||
result = await graphiti.add_episode(
|
||||
name="Office Visit",
|
||||
episode_body=episode_content,
|
||||
source_description="Visit report",
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
entity_types=entity_types,
|
||||
excluded_entity_types=['Organization', 'Location'], # Exclude these types
|
||||
group_id='test_exclude_custom'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Search for nodes to verify only Person and Entity types were created
|
||||
search_results = await graphiti.search_(
|
||||
query="Sarah Johnson Google San Francisco",
|
||||
group_ids=['test_exclude_custom']
|
||||
)
|
||||
|
||||
found_nodes = search_results.nodes
|
||||
|
||||
# Should have Person and Entity type nodes, but no Organization or Location
|
||||
for node in found_nodes:
|
||||
assert 'Entity' in node.labels
|
||||
# Should not have excluded types
|
||||
assert 'Organization' not in node.labels, f"Found excluded Organization in node: {node.name}"
|
||||
assert 'Location' not in node.labels, f"Found excluded Location in node: {node.name}"
|
||||
|
||||
# Should find at least one Person entity (Sarah Johnson)
|
||||
person_nodes = [n for n in found_nodes if 'Person' in n.labels]
|
||||
assert len(person_nodes) > 0, "Should have found at least one Person entity"
|
||||
|
||||
# Clean up
|
||||
await _cleanup_test_nodes(graphiti, 'test_exclude_custom')
|
||||
|
||||
finally:
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_all_types():
|
||||
"""Test excluding all entity types (edge case)."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
'Organization': Organization,
|
||||
}
|
||||
|
||||
# Exclude all types
|
||||
result = await graphiti.add_episode(
|
||||
name="No Entities",
|
||||
episode_body="This text mentions John and Microsoft but no entities should be created.",
|
||||
source_description="Test content",
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
entity_types=entity_types,
|
||||
excluded_entity_types=['Entity', 'Person', 'Organization'], # Exclude everything
|
||||
group_id='test_exclude_all'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Search for nodes - should find very few or none from this episode
|
||||
search_results = await graphiti.search_(
|
||||
query="John Microsoft",
|
||||
group_ids=['test_exclude_all']
|
||||
)
|
||||
|
||||
# There should be minimal to no entities created
|
||||
found_nodes = search_results.nodes
|
||||
assert len(found_nodes) == 0, f"Expected no entities, but found: {[n.name for n in found_nodes]}"
|
||||
|
||||
# Clean up
|
||||
await _cleanup_test_nodes(graphiti, 'test_exclude_all')
|
||||
|
||||
finally:
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_no_types():
|
||||
"""Test normal behavior when no types are excluded (baseline test)."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
'Organization': Organization,
|
||||
}
|
||||
|
||||
# Don't exclude any types
|
||||
result = await graphiti.add_episode(
|
||||
name="Normal Behavior",
|
||||
episode_body="Alice Smith works at TechCorp.",
|
||||
source_description="Normal test",
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
entity_types=entity_types,
|
||||
excluded_entity_types=None, # No exclusions
|
||||
group_id='test_exclude_none'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Search for nodes - should find entities of all types
|
||||
search_results = await graphiti.search_(
|
||||
query="Alice Smith TechCorp",
|
||||
group_ids=['test_exclude_none']
|
||||
)
|
||||
|
||||
found_nodes = search_results.nodes
|
||||
assert len(found_nodes) > 0, "Should have found some entities"
|
||||
|
||||
# Should have both Person and Organization entities
|
||||
person_nodes = [n for n in found_nodes if 'Person' in n.labels]
|
||||
org_nodes = [n for n in found_nodes if 'Organization' in n.labels]
|
||||
|
||||
assert len(person_nodes) > 0, "Should have found Person entities"
|
||||
assert len(org_nodes) > 0, "Should have found Organization entities"
|
||||
|
||||
# Clean up
|
||||
await _cleanup_test_nodes(graphiti, 'test_exclude_none')
|
||||
|
||||
finally:
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
def test_validation_valid_excluded_types():
|
||||
"""Test validation function with valid excluded types."""
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
'Organization': Organization,
|
||||
}
|
||||
|
||||
# Valid exclusions
|
||||
assert validate_excluded_entity_types(['Entity'], entity_types) is True
|
||||
assert validate_excluded_entity_types(['Person'], entity_types) is True
|
||||
assert validate_excluded_entity_types(['Entity', 'Person'], entity_types) is True
|
||||
assert validate_excluded_entity_types(None, entity_types) is True
|
||||
assert validate_excluded_entity_types([], entity_types) is True
|
||||
|
||||
|
||||
def test_validation_invalid_excluded_types():
|
||||
"""Test validation function with invalid excluded types."""
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
'Organization': Organization,
|
||||
}
|
||||
|
||||
# Invalid exclusions should raise ValueError
|
||||
with pytest.raises(ValueError, match="Invalid excluded entity types"):
|
||||
validate_excluded_entity_types(['InvalidType'], entity_types)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid excluded entity types"):
|
||||
validate_excluded_entity_types(['Person', 'NonExistentType'], entity_types)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excluded_types_parameter_validation_in_add_episode():
|
||||
"""Test that add_episode validates excluded_entity_types parameter."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
|
||||
try:
|
||||
entity_types = {
|
||||
'Person': Person,
|
||||
}
|
||||
|
||||
# Should raise ValueError for invalid excluded type
|
||||
with pytest.raises(ValueError, match="Invalid excluded entity types"):
|
||||
await graphiti.add_episode(
|
||||
name="Invalid Test",
|
||||
episode_body="Test content",
|
||||
source_description="Test",
|
||||
reference_time=datetime.now(timezone.utc),
|
||||
entity_types=entity_types,
|
||||
excluded_entity_types=['NonExistentType'],
|
||||
group_id='test_validation'
|
||||
)
|
||||
|
||||
finally:
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
async def _cleanup_test_nodes(graphiti: Graphiti, group_id: str):
|
||||
"""Helper function to clean up test nodes."""
|
||||
try:
|
||||
# Get all nodes for this group
|
||||
search_results = await graphiti.search_(
|
||||
query="*",
|
||||
group_ids=[group_id]
|
||||
)
|
||||
|
||||
# Delete all found nodes
|
||||
for node in search_results.nodes:
|
||||
await node.delete(graphiti.driver)
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail the test if cleanup fails
|
||||
print(f"Warning: Failed to clean up test nodes for group {group_id}: {e}")
|
||||
Loading…
Add table
Reference in a new issue