From 2864786dd986a10604bf38b3623a9171f947cfbe Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Sat, 4 Oct 2025 13:37:39 -0700 Subject: [PATCH] Refactor node extraction; remove summary from attribute extraction (#977) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor node extraction for better maintainability - Extract helper functions from extract_attributes_from_node to improve code organization - Add _extract_entity_attributes, _extract_entity_summary, and _build_episode_context helpers - Apply consistent formatting (double quotes per ruff configuration) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Apply consistent single quote style throughout node_operations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * cleanup * cleanup 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Bump version to 0.22.0pre0 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com> --- .../utils/maintenance/node_operations.py | 146 +++++++++++------- pyproject.toml | 2 +- uv.lock | 8 +- 3 files changed, 97 insertions(+), 59 deletions(-) diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 56f0a1e2..8db44218 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -74,7 +74,9 @@ async def extract_nodes_reflexion( } llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id + prompt_library.extract_nodes.reflexion(context), + MissedEntities, + group_id=group_id, ) missed_entities = llm_response.get('missed_entities', []) @@ -483,65 +485,95 @@ async def extract_attributes_from_node( entity_type: 
type[BaseModel] | None = None, should_summarize_node: NodeSummaryFilter | None = None, ) -> EntityNode: - node_context: dict[str, Any] = { - 'name': node.name, - 'summary': node.summary, - 'entity_types': node.labels, - 'attributes': node.attributes, - } - - attributes_context: dict[str, Any] = { - 'node': node_context, - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': ( - [ep.content for ep in previous_episodes] if previous_episodes is not None else [] - ), - } - - summary_context: dict[str, Any] = { - 'node': node_context, - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': ( - [ep.content for ep in previous_episodes] if previous_episodes is not None else [] - ), - } - - has_entity_attributes: bool = bool( - entity_type is not None and len(entity_type.model_fields) != 0 + # Extract attributes if entity type is defined and has attributes + llm_response = await _extract_entity_attributes( + llm_client, node, episode, previous_episodes, entity_type ) - llm_response = ( - ( - await llm_client.generate_response( - prompt_library.extract_nodes.extract_attributes(attributes_context), - response_model=entity_type, - model_size=ModelSize.small, - group_id=node.group_id, - ) - ) - if has_entity_attributes - else {} + # Extract summary if needed + await _extract_entity_summary( + llm_client, node, episode, previous_episodes, should_summarize_node ) - # Determine if summary should be generated - generate_summary = True - if should_summarize_node is not None: - generate_summary = await should_summarize_node(node) - - # Conditionally generate summary - if generate_summary: - summary_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_summary(summary_context), - response_model=EntitySummary, - model_size=ModelSize.small, - group_id=node.group_id, - ) - node.summary = summary_response.get('summary', '') - - if has_entity_attributes and entity_type is not None: - 
entity_type(**llm_response) - node_attributes = {key: value for key, value in llm_response.items()} - - node.attributes.update(node_attributes) + node.attributes.update(llm_response) return node + + +async def _extract_entity_attributes( + llm_client: LLMClient, + node: EntityNode, + episode: EpisodicNode | None, + previous_episodes: list[EpisodicNode] | None, + entity_type: type[BaseModel] | None, +) -> dict[str, Any]: + if entity_type is None or len(entity_type.model_fields) == 0: + return {} + + attributes_context = _build_episode_context( + # should not include summary + node_data={ + 'name': node.name, + 'entity_types': node.labels, + 'attributes': node.attributes, + }, + episode=episode, + previous_episodes=previous_episodes, + ) + + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_attributes(attributes_context), + response_model=entity_type, + model_size=ModelSize.small, + group_id=node.group_id, + ) + + # validate response + entity_type(**llm_response) + + return llm_response + + +async def _extract_entity_summary( + llm_client: LLMClient, + node: EntityNode, + episode: EpisodicNode | None, + previous_episodes: list[EpisodicNode] | None, + should_summarize_node: NodeSummaryFilter | None, +) -> None: + if should_summarize_node is not None and not await should_summarize_node(node): + return + + summary_context = _build_episode_context( + node_data={ + 'name': node.name, + 'summary': node.summary, + 'entity_types': node.labels, + 'attributes': node.attributes, + }, + episode=episode, + previous_episodes=previous_episodes, + ) + + summary_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_summary(summary_context), + response_model=EntitySummary, + model_size=ModelSize.small, + group_id=node.group_id, + ) + + node.summary = summary_response.get('summary', '') + + +def _build_episode_context( + node_data: dict[str, Any], + episode: EpisodicNode | None, + previous_episodes: 
list[EpisodicNode] | None, +) -> dict[str, Any]: + return { + 'node': node_data, + 'episode_content': episode.content if episode is not None else '', + 'previous_episodes': ( + [ep.content for ep in previous_episodes] if previous_episodes is not None else [] + ), + } diff --git a/pyproject.toml b/pyproject.toml index f3b9ce91..9dab4e71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.21.0" +version = "0.22.0pre0" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index c5f9e533..10127033 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <4" resolution-markers = [ "python_full_version >= '3.14'", @@ -803,6 +803,7 @@ anthropic = [ ] dev = [ { name = "anthropic" }, + { name = "boto3" }, { name = "diskcache-stubs" }, { name = "falkordb" }, { name = "google-genai" }, @@ -811,9 +812,11 @@ dev = [ { name = "jupyterlab" }, { name = "kuzu" }, { name = "langchain-anthropic" }, + { name = "langchain-aws" }, { name = "langchain-openai" }, { name = "langgraph" }, { name = "langsmith" }, + { name = "opensearch-py" }, { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -855,6 +858,7 @@ voyageai = [ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" }, + { name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" }, { name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" }, { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" }, { name = "diskcache", specifier = ">=5.6.3" }, @@ -870,6 +874,7 @@ requires-dist = [ { name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" }, { name = "kuzu", 
marker = "extra == 'kuzu'", specifier = ">=0.11.2" }, { name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" }, + { name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" }, { name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" }, { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" }, { name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" }, @@ -877,6 +882,7 @@ requires-dist = [ { name = "neo4j", specifier = ">=5.26.0" }, { name = "numpy", specifier = ">=1.0.0" }, { name = "openai", specifier = ">=1.91.0" }, + { name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" }, { name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" }, { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" }, { name = "posthog", specifier = ">=3.0.0" },