diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index b357f3b1..03b1e402 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -359,10 +359,10 @@ class Graphiti: group_id: str | None = None, uuid: str | None = None, update_communities: bool = False, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, previous_episode_uuids: list[str] | None = None, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None, ) -> AddEpisodeResults: """ @@ -555,9 +555,9 @@ class Graphiti: self, bulk_episodes: list[RawEpisode], group_id: str | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None, ): """ diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 2bf68b81..9feb3073 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -148,7 +148,7 @@ def validate_group_id(group_id: str) -> bool: def validate_excluded_entity_types( - excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None + excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None ) -> bool: """ Validate that excluded entity types are valid type names. diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 59d07c88..57a2a577 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -52,6 +52,13 @@ class EntityClassification(BaseModel): ) +class EntitySummary(BaseModel): + summary: str = Field( + ..., + description='Summary containing the important information about the entity. Under 250 words', + ) + + class Prompt(Protocol): extract_message: PromptVersion extract_json: PromptVersion @@ -59,6 +66,7 @@ class Prompt(Protocol): reflexion: PromptVersion classify_nodes: PromptVersion extract_attributes: PromptVersion + extract_summary: PromptVersion class Versions(TypedDict): @@ -68,6 +76,7 @@ class Versions(TypedDict): reflexion: PromptFunction classify_nodes: PromptFunction extract_attributes: PromptFunction + extract_summary: PromptFunction def extract_message(context: dict[str, Any]) -> list[Message]: @@ -259,9 +268,39 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]: Guidelines: 1. Do not hallucinate entity property values if they cannot be found in the current context. 2. Only use the provided MESSAGES and ENTITY to set attribute values. + + + {context['node']} + + """, + ), + ] + + +def extract_summary(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that extracts entity summaries from the provided text.', + ), + Message( + role='user', + content=f""" + + + {json.dumps(context['previous_episodes'], indent=2)} + {json.dumps(context['episode_content'], indent=2)} + + + Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity + from the messages and relevant information from the existing summary. + + Guidelines: + 1. Do not hallucinate entity summary information if they cannot be found in the current context. + 2. Only use the provided MESSAGES and ENTITY to set attribute values. 3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES. Summaries must be no longer than 250 words. - + {context['node']} @@ -275,6 +314,7 @@ versions: Versions = { 'extract_json': extract_json, 'extract_text': extract_text, 'reflexion': reflexion, + 'extract_summary': extract_summary, 'classify_nodes': classify_nodes, 'extract_attributes': extract_attributes, } diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index af9dce67..7c266c94 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -314,17 +314,15 @@ async def node_fulltext_search( + """ YIELD node AS n, score WHERE n:Entity AND n.group_id IN $group_ids - WITH n, score - LIMIT $limit """ + filter_query + """ + WITH n, score + ORDER BY score DESC + LIMIT $limit RETURN """ + ENTITY_NODE_RETURN - + """ - ORDER BY score DESC - """ ) records, _, _ = await driver.execute_query( diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 9263edf3..b80c4f3b 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -169,9 +169,9 @@ async def extract_nodes_and_edges_bulk( clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]], edge_type_map: dict[tuple[str, str], list[str]], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]: extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather( *[ @@ -202,7 +202,7 @@ async def dedupe_nodes_bulk( clients: GraphitiClients, extracted_nodes: list[list[EntityNode]], episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]: embedder = clients.embedder min_score = 0.8 @@ -290,7 +290,7 @@ async def dedupe_edges_bulk( extracted_edges: list[list[EntityEdge]], episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]], _entities: list[EntityNode], - edge_types: dict[str, BaseModel], + edge_types: dict[str, type[BaseModel]], _edge_type_map: dict[tuple[str, str], list[str]], ) -> dict[str, list[EntityEdge]]: embedder = clients.embedder diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 1455653a..ef78db43 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -114,7 +114,7 @@ async def extract_edges( previous_episodes: list[EpisodicNode], edge_type_map: dict[tuple[str, str], list[str]], group_id: str = '', - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, ) -> list[EntityEdge]: start = time() @@ -249,7 +249,7 @@ async def resolve_extracted_edges( extracted_edges: list[EntityEdge], episode: EpisodicNode, entities: list[EntityNode], - edge_types: dict[str, BaseModel], + edge_types: dict[str, type[BaseModel]], edge_type_map: dict[tuple[str, str], list[str]], ) -> tuple[list[EntityEdge], list[EntityEdge]]: driver = clients.driver @@ -272,7 +272,7 @@ async def resolve_extracted_edges( uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities} # Determine which edge types are relevant for each edge - edge_types_lst: list[dict[str, BaseModel]] = [] + edge_types_lst: list[dict[str, type[BaseModel]]] = [] for extracted_edge in extracted_edges: source_node = uuid_entity_map.get(extracted_edge.source_node_uuid) target_node = uuid_entity_map.get(extracted_edge.target_node_uuid) @@ -381,7 +381,7 @@ async def resolve_extracted_edge( related_edges: list[EntityEdge], existing_edges: list[EntityEdge], episode: EpisodicNode, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]: if len(related_edges) == 0 and len(existing_edges) == 0: return extracted_edge, [], [] diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 9c3256bf..51a7be4b 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -15,13 +15,10 @@ limitations under the License. """ import logging -from contextlib import suppress from time import time from typing import Any -from uuid import uuid4 -import pydantic -from pydantic import BaseModel, Field +from pydantic import BaseModel from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather @@ -31,6 +28,7 @@ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_en from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions from graphiti_core.prompts.extract_nodes import ( + EntitySummary, ExtractedEntities, ExtractedEntity, MissedEntities, @@ -70,7 +68,7 @@ async def extract_nodes( clients: GraphitiClients, episode: EpisodicNode, previous_episodes: list[EpisodicNode], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, ) -> list[EntityNode]: start = time() @@ -180,7 +178,7 @@ async def resolve_extracted_nodes( extracted_nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, existing_nodes_override: list[EntityNode] | None = None, ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: llm_client = clients.llm_client @@ -223,7 +221,7 @@ async def resolve_extracted_nodes( ], ) - entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {} + entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {} # Prepare context for LLM extracted_nodes_context = [ @@ -297,7 +295,7 @@ async def extract_attributes_from_nodes( nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, ) -> list[EntityNode]: llm_client = clients.llm_client embedder = clients.embedder @@ -326,7 +324,7 @@ async def extract_attributes_from_node( node: EntityNode, episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_type: BaseModel | None = None, + entity_type: type[BaseModel] | None = None, ) -> EntityNode: node_context: dict[str, Any] = { 'name': node.name, @@ -335,25 +333,14 @@ async def extract_attributes_from_node( 'attributes': node.attributes, } - attributes_definitions: dict[str, Any] = { - 'summary': ( - str, - Field( - description='Summary containing the important information about the entity. Under 250 words', - ), - ) + attributes_context: dict[str, Any] = { + 'node': node_context, + 'episode_content': episode.content if episode is not None else '', + 'previous_episodes': [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [], } - if entity_type is not None: - for field_name, field_info in entity_type.model_fields.items(): - attributes_definitions[field_name] = ( - field_info.annotation, - Field(description=field_info.description), - ) - - unique_model_name = f'EntityAttributes_{uuid4().hex}' - entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions) - summary_context: dict[str, Any] = { 'node': node_context, 'episode_content': episode.content if episode is not None else '', @@ -362,20 +349,30 @@ async def extract_attributes_from_node( else [], } - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_attributes(summary_context), - response_model=entity_attributes_model, + llm_response = ( + ( + await llm_client.generate_response( + prompt_library.extract_nodes.extract_attributes(attributes_context), + response_model=entity_type, + model_size=ModelSize.small, + ) + ) + if entity_type is not None + else {} + ) + + summary_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_summary(summary_context), + response_model=EntitySummary, model_size=ModelSize.small, ) - entity_attributes_model(**llm_response) + if entity_type is not None: + entity_type(**llm_response) - node.summary = llm_response.get('summary', '') + node.summary = summary_response.get('summary', '') node_attributes = {key: value for key, value in llm_response.items()} - with suppress(KeyError): - del node_attributes['summary'] - node.attributes.update(node_attributes) return node diff --git a/graphiti_core/utils/ontology_utils/entity_types_utils.py b/graphiti_core/utils/ontology_utils/entity_types_utils.py index f6cb08fb..bbc07af7 100644 --- a/graphiti_core/utils/ontology_utils/entity_types_utils.py +++ b/graphiti_core/utils/ontology_utils/entity_types_utils.py @@ -21,7 +21,7 @@ from graphiti_core.nodes import EntityNode def validate_entity_types( - entity_types: dict[str, BaseModel] | None, + entity_types: dict[str, type[BaseModel]] | None, ) -> bool: if entity_types is None: return True diff --git a/pyproject.toml b/pyproject.toml index 6d33a866..41639b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.18.1" +version = "0.18.2" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 9d98f4de..237e2bfc 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -64,7 +64,8 @@ async def test_graphiti_init(driver): await graphiti.build_indices_and_constraints() search_filter = SearchFilters( - created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]] + node_labels=['Person'], + created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]], ) results = await graphiti.search_( diff --git a/uv.lock b/uv.lock index 6fcf17f6..cc771368 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.18.1" +version = "0.18.2" source = { editable = "." } dependencies = [ { name = "diskcache" },