diff --git a/graphiti_core/utils/ontology_utils/entity_types_utils.py b/graphiti_core/utils/ontology_utils/entity_types_utils.py index f6cb08fb..41a94128 100644 --- a/graphiti_core/utils/ontology_utils/entity_types_utils.py +++ b/graphiti_core/utils/ontology_utils/entity_types_utils.py @@ -26,12 +26,17 @@ def validate_entity_types( if entity_types is None: return True - entity_node_field_names = EntityNode.model_fields.keys() - + # Iterate through the provided entity types for entity_type_name, entity_type_model in entity_types.items(): - entity_type_field_names = entity_type_model.model_fields.keys() - for entity_type_field_name in entity_type_field_names: - if entity_type_field_name in entity_node_field_names: - raise EntityTypeValidationError(entity_type_name, entity_type_field_name) + # Convert model fields to set for fast intersection + entity_type_field_names = set(entity_type_model.model_fields.keys()) + # Intersect to find any clashing field + conflict_fields = _ENTITY_NODE_FIELD_NAMES & entity_type_field_names + if conflict_fields: + # Only raise for the first conflict found, as per original behavior + raise EntityTypeValidationError(entity_type_name, next(iter(conflict_fields))) - return True + return True # Preserve existing comment + + +_ENTITY_NODE_FIELD_NAMES = set(EntityNode.model_fields.keys())