diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index ebbaec87..845f1377 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -31,9 +31,13 @@ class MissedEntities(BaseModel): class EntityClassification(BaseModel): - entity_classification: str = Field( + entities: list[str] = Field( ..., - description='Dictionary of entity classifications. Key is the entity name and value is the entity type', + description='List of entities', + ) + entity_classifications: list[str | None] = Field( + ..., + description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.', ) @@ -180,7 +184,8 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]: Guidelines: 1. Each entity must have exactly one type - 2. If none of the provided entity types accurately classify an extracted node, the type should be set to None + 2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities. + 3. If none of the provided entity types accurately classify an extracted node, the type should be set to None """ return [ Message(role='system', content=sys_prompt), diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 9bd963ab..425529d7 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import ast import logging from time import time @@ -163,8 +162,9 @@ async def extract_nodes( prompt_library.extract_nodes.classify_nodes(node_classification_context), response_model=EntityClassification, ) - response_string = llm_response.get('entity_classification', '{}') - node_classifications.update(ast.literal_eval(response_string)) + entities = llm_response.get('entities', []) + entity_classifications = llm_response.get('entity_classifications', []) + node_classifications.update(dict(zip(entities, entity_classifications))) end = time() logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms') @@ -173,7 +173,9 @@ async def extract_nodes( for name in extracted_node_names: entity_type = node_classifications.get(name) labels = ( - ['Entity'] if entity_type is None or entity_type == 'None' else ['Entity', entity_type] + ['Entity'] + if entity_type is None or entity_type == 'None' or entity_type == 'null' + else ['Entity', entity_type] ) new_node = EntityNode( diff --git a/pyproject.toml b/pyproject.toml index 14a22012..bbe09393 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.7.4" +version = "0.7.5" description = "A temporal graph building library" authors = [ "Paul Paliychuk ",