diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 18a39657..5618dcd6 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -29,7 +29,12 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
+from graphiti_core.helpers import (
+    DEFAULT_DATABASE,
+    semaphore_gather,
+    validate_excluded_entity_types,
+    validate_group_id,
+)
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -293,6 +298,7 @@ class Graphiti:
         uuid: str | None = None,
         update_communities: bool = False,
         entity_types: dict[str, BaseModel] | None = None,
+        excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
         edge_types: dict[str, BaseModel] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
@@ -321,6 +327,12 @@ class Graphiti:
             Optional uuid of the episode.
         update_communities : bool
             Optional. Whether to update communities with new node information
+        entity_types : dict[str, BaseModel] | None
+            Optional. Dictionary mapping entity type names to their Pydantic model definitions.
+        excluded_entity_types : list[str] | None
+            Optional. List of entity type names to exclude from the graph. Entities classified
+            into these types will not be added to the graph. Can include 'Entity' to exclude
+            the default entity type.
         previous_episode_uuids : list[str] | None
             Optional. list of episode uuids to use as the previous episodes.
             If this is not provided, the most recent episodes by created_at date will be used.
@@ -351,6 +363,7 @@ class Graphiti:
             now = utc_now()
 
             validate_entity_types(entity_types)
+            validate_excluded_entity_types(excluded_entity_types, entity_types)
             validate_group_id(group_id)
 
             previous_episodes = (
@@ -389,7 +402,7 @@ class Graphiti:
 
             # Extract entities as nodes
             extracted_nodes = await extract_nodes(
-                self.clients, episode, previous_episodes, entity_types
+                self.clients, episode, previous_episodes, entity_types, excluded_entity_types
             )
 
             # Extract edges and resolve nodes
@@ -534,7 +547,7 @@ class Graphiti:
                 extracted_nodes,
                 extracted_edges,
                 episodic_edges,
-            ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
+            ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
 
             # Generate embeddings
             await semaphore_gather(
diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index a115f075..460b4883 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -24,6 +24,7 @@ import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
 from numpy._typing import NDArray
+from pydantic import BaseModel
 from typing_extensions import LiteralString
 
 from graphiti_core.errors import GroupIdValidationError
@@ -132,3 +133,44 @@ def validate_group_id(group_id: str) -> bool:
         raise GroupIdValidationError(group_id)
 
     return True
+
+
+def validate_excluded_entity_types(
+    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+) -> bool:
+    """
+    Validate that excluded entity types are valid type names.
+
+    Args:
+        excluded_entity_types: List of entity type names to exclude
+        entity_types: Dictionary of available custom entity types
+
+    Returns:
+        True if valid
+
+    Raises:
+        ValueError: If excluded_entity_types is a bare string or names an unknown type
+    """
+    if not excluded_entity_types:
+        return True
+
+    # A bare string is iterable, so set('Person') would be validated character
+    # by character; reject it explicitly with a clear message
+    if isinstance(excluded_entity_types, str):
+        raise ValueError(
+            f'excluded_entity_types must be a list of names, got str: {excluded_entity_types!r}'
+        )
+
+    # Build set of available type names
+    available_types = {'Entity'}  # Default type is always available
+    if entity_types:
+        available_types.update(entity_types.keys())
+
+    # Check for invalid type names
+    invalid_types = set(excluded_entity_types) - available_types
+    if invalid_types:
+        raise ValueError(
+            f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
+        )
+
+    return True
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index fd15499c..c1af0a6f 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -586,6 +586,8 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
 
 
 async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
+    if not nodes:  # Handle empty list case
+        return
     name_embeddings = await embedder.create_batch([node.name for node in nodes])
     for node, name_embedding in zip(nodes, name_embeddings, strict=True):
         node.name_embedding = name_embedding
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 88fd5639..1538515f 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -177,11 +177,14 @@ async def add_nodes_and_edges_bulk_tx(
 
 
 async def extract_nodes_and_edges_bulk(
-    clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+    clients: GraphitiClients,
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    entity_types: dict[str, BaseModel] | None = None,
+    excluded_entity_types: list[str] | None = None,
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
     extracted_nodes_bulk = await semaphore_gather(
         *[
-            extract_nodes(clients, episode, previous_episodes)
+            extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
             for episode, previous_episodes in episode_tuples
         ]
     )
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index de5a1fa0..7c31f295 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -71,6 +71,7 @@ async def extract_nodes(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     entity_types: dict[str, BaseModel] | None = None,
+    excluded_entity_types: list[str] | None = None,
 ) -> list[EntityNode]:
     start = time()
     llm_client = clients.llm_client
@@ -154,6 +155,11 @@ async def extract_nodes(
             'entity_type_name'
         )
 
+        # Check if this entity type should be excluded
+        if excluded_entity_types and entity_type_name in excluded_entity_types:
+            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
+            continue
+
         labels: list[str] = list({'Entity', str(entity_type_name)})
 
         new_node = EntityNode(
diff --git a/tests/test_entity_exclusion_int.py b/tests/test_entity_exclusion_int.py
new file mode 100644
index 00000000..d5dec06e
--- /dev/null
+++ b/tests/test_entity_exclusion_int.py
@@ -0,0 +1,331 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import os
+from datetime import datetime, timezone
+
+import pytest
+from dotenv import load_dotenv
+from pydantic import BaseModel, Field
+
+from graphiti_core.graphiti import Graphiti
+from graphiti_core.helpers import validate_excluded_entity_types
+
+pytestmark = pytest.mark.integration
+
+pytest_plugins = ('pytest_asyncio',)
+
+load_dotenv()
+
+NEO4J_URI = os.getenv('NEO4J_URI')
+NEO4J_USER = os.getenv('NEO4J_USER')
+NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
+
+
+# Test entity type definitions
+class Person(BaseModel):
+    """A human person mentioned in the conversation."""
+
+    first_name: str | None = Field(None, description='First name of the person')
+    last_name: str | None = Field(None, description='Last name of the person')
+    occupation: str | None = Field(None, description='Job or profession of the person')
+
+
+class Organization(BaseModel):
+    """A company, institution, or organized group."""
+
+    organization_type: str | None = Field(None, description='Type of organization (company, NGO, etc.)')
+    industry: str | None = Field(None, description='Industry or sector the organization operates in')
+
+
+class Location(BaseModel):
+    """A geographic location, place, or address."""
+
+    location_type: str | None = Field(None, description='Type of location (city, country, building, etc.)')
+    coordinates: str | None = Field(None, description='Geographic coordinates if available')
+
+
+@pytest.mark.asyncio
+async def test_exclude_default_entity_type():
+    """Test excluding the default 'Entity' type while keeping custom types."""
+    graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
+
+    try:
+        await graphiti.build_indices_and_constraints()
+
+        # Define entity types but exclude the default 'Entity' type
+        entity_types = {
+            'Person': Person,
+            'Organization': Organization,
+        }
+
+        # Add an episode that would normally create both Entity and custom type entities
+        episode_content = "John Smith works at Acme Corporation in New York. The weather is nice today."
+
+        result = await graphiti.add_episode(
+            name="Business Meeting",
+            episode_body=episode_content,
+            source_description="Meeting notes",
+            reference_time=datetime.now(timezone.utc),
+            entity_types=entity_types,
+            excluded_entity_types=['Entity'],  # Exclude default type
+            group_id='test_exclude_default'
+        )
+
+        # Verify that nodes were created (custom types should still work)
+        assert result is not None
+
+        # Search for nodes to verify only custom types were created
+        search_results = await graphiti.search_(
+            query="John Smith Acme Corporation",
+            group_ids=['test_exclude_default']
+        )
+
+        # Check that entities were created but with specific types, not default 'Entity'
+        found_nodes = search_results.nodes
+        for node in found_nodes:
+            assert 'Entity' in node.labels  # All nodes should have Entity label
+            # But they should also have specific type labels
+            assert any(label in ['Person', 'Organization'] for label in node.labels), \
+                f"Node {node.name} should have a specific type label, got: {node.labels}"
+
+    finally:
+        # Clean up here so nodes are removed even when an assertion above fails
+        # (leftovers would otherwise pollute the next run of this group_id)
+        await _cleanup_test_nodes(graphiti, 'test_exclude_default')
+        await graphiti.close()
+
+
+@pytest.mark.asyncio
+async def test_exclude_specific_custom_types():
+    """Test excluding specific custom entity types while keeping others."""
+    graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
+
+    try:
+        await graphiti.build_indices_and_constraints()
+
+        # Define multiple entity types
+        entity_types = {
+            'Person': Person,
+            'Organization': Organization,
+            'Location': Location,
+        }
+
+        # Add an episode with content that would create all types
+        episode_content = "Sarah Johnson from Google visited the San Francisco office to discuss the new project."
+
+        result = await graphiti.add_episode(
+            name="Office Visit",
+            episode_body=episode_content,
+            source_description="Visit report",
+            reference_time=datetime.now(timezone.utc),
+            entity_types=entity_types,
+            excluded_entity_types=['Organization', 'Location'],  # Exclude these types
+            group_id='test_exclude_custom'
+        )
+
+        assert result is not None
+
+        # Search for nodes to verify only Person and Entity types were created
+        search_results = await graphiti.search_(
+            query="Sarah Johnson Google San Francisco",
+            group_ids=['test_exclude_custom']
+        )
+
+        found_nodes = search_results.nodes
+
+        # Should have Person and Entity type nodes, but no Organization or Location
+        for node in found_nodes:
+            assert 'Entity' in node.labels
+            # Should not have excluded types
+            assert 'Organization' not in node.labels, f"Found excluded Organization in node: {node.name}"
+            assert 'Location' not in node.labels, f"Found excluded Location in node: {node.name}"
+
+        # Should find at least one Person entity (Sarah Johnson)
+        person_nodes = [n for n in found_nodes if 'Person' in n.labels]
+        assert len(person_nodes) > 0, "Should have found at least one Person entity"
+
+    finally:
+        # Clean up here so nodes are removed even when an assertion above fails
+        # (leftovers would otherwise pollute the next run of this group_id)
+        await _cleanup_test_nodes(graphiti, 'test_exclude_custom')
+        await graphiti.close()
+
+
+@pytest.mark.asyncio
+async def test_exclude_all_types():
+    """Test excluding all entity types (edge case)."""
+    graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
+
+    try:
+        await graphiti.build_indices_and_constraints()
+
+        entity_types = {
+            'Person': Person,
+            'Organization': Organization,
+        }
+
+        # Exclude all types
+        result = await graphiti.add_episode(
+            name="No Entities",
+            episode_body="This text mentions John and Microsoft but no entities should be created.",
+            source_description="Test content",
+            reference_time=datetime.now(timezone.utc),
+            entity_types=entity_types,
+            excluded_entity_types=['Entity', 'Person', 'Organization'],  # Exclude everything
+            group_id='test_exclude_all'
+        )
+
+        assert result is not None
+
+        # Search for nodes - should find very few or none from this episode
+        search_results = await graphiti.search_(
+            query="John Microsoft",
+            group_ids=['test_exclude_all']
+        )
+
+        # There should be minimal to no entities created
+        found_nodes = search_results.nodes
+        assert len(found_nodes) == 0, f"Expected no entities, but found: {[n.name for n in found_nodes]}"
+
+    finally:
+        # Clean up here so nodes are removed even when an assertion above fails
+        # (leftovers would otherwise pollute the next run of this group_id)
+        await _cleanup_test_nodes(graphiti, 'test_exclude_all')
+        await graphiti.close()
+
+
+@pytest.mark.asyncio
+async def test_exclude_no_types():
+    """Test normal behavior when no types are excluded (baseline test)."""
+    graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
+
+    try:
+        await graphiti.build_indices_and_constraints()
+
+        entity_types = {
+            'Person': Person,
+            'Organization': Organization,
+        }
+
+        # Don't exclude any types
+        result = await graphiti.add_episode(
+            name="Normal Behavior",
+            episode_body="Alice Smith works at TechCorp.",
+            source_description="Normal test",
+            reference_time=datetime.now(timezone.utc),
+            entity_types=entity_types,
+            excluded_entity_types=None,  # No exclusions
+            group_id='test_exclude_none'
+        )
+
+        assert result is not None
+
+        # Search for nodes - should find entities of all types
+        search_results = await graphiti.search_(
+            query="Alice Smith TechCorp",
+            group_ids=['test_exclude_none']
+        )
+
+        found_nodes = search_results.nodes
+        assert len(found_nodes) > 0, "Should have found some entities"
+
+        # Should have both Person and Organization entities
+        person_nodes = [n for n in found_nodes if 'Person' in n.labels]
+        org_nodes = [n for n in found_nodes if 'Organization' in n.labels]
+
+        assert len(person_nodes) > 0, "Should have found Person entities"
+        assert len(org_nodes) > 0, "Should have found Organization entities"
+
+    finally:
+        # Clean up here so nodes are removed even when an assertion above fails
+        # (leftovers would otherwise pollute the next run of this group_id)
+        await _cleanup_test_nodes(graphiti, 'test_exclude_none')
+        await graphiti.close()
+
+
+def test_validation_valid_excluded_types():
+    """Test validation function with valid excluded types."""
+    entity_types = {
+        'Person': Person,
+        'Organization': Organization,
+    }
+
+    # Valid exclusions
+    assert validate_excluded_entity_types(['Entity'], entity_types) is True
+    assert validate_excluded_entity_types(['Person'], entity_types) is True
+    assert validate_excluded_entity_types(['Entity', 'Person'], entity_types) is True
+    assert validate_excluded_entity_types(None, entity_types) is True
+    assert validate_excluded_entity_types([], entity_types) is True
+
+
+def test_validation_invalid_excluded_types():
+    """Test validation function with invalid excluded types."""
+    entity_types = {
+        'Person': Person,
+        'Organization': Organization,
+    }
+
+    # Invalid exclusions should raise ValueError
+    with pytest.raises(ValueError, match="Invalid excluded entity types"):
+        validate_excluded_entity_types(['InvalidType'], entity_types)
+
+    with pytest.raises(ValueError, match="Invalid excluded entity types"):
+        validate_excluded_entity_types(['Person', 'NonExistentType'], entity_types)
+
+
+@pytest.mark.asyncio
+async def test_excluded_types_parameter_validation_in_add_episode():
+    """Test that add_episode validates excluded_entity_types parameter."""
+    graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
+
+    try:
+        entity_types = {
+            'Person': Person,
+        }
+
+        # Should raise ValueError for invalid excluded type
+        with pytest.raises(ValueError, match="Invalid excluded entity types"):
+            await graphiti.add_episode(
+                name="Invalid Test",
+                episode_body="Test content",
+                source_description="Test",
+                reference_time=datetime.now(timezone.utc),
+                entity_types=entity_types,
+                excluded_entity_types=['NonExistentType'],
+                group_id='test_validation'
+            )
+
+    finally:
+        await graphiti.close()
+
+
+async def _cleanup_test_nodes(graphiti: Graphiti, group_id: str):
+    """Helper function to clean up test nodes."""
+    try:
+        # Get all nodes for this group
+        search_results = await graphiti.search_(
+            query="*",
+            group_ids=[group_id]
+        )
+
+        # Delete all found nodes
+        for node in search_results.nodes:
+            await node.delete(graphiti.driver)
+
+    except Exception as e:
+        # Log but don't fail the test if cleanup fails
+        print(f"Warning: Failed to clean up test nodes for group {group_id}: {e}")