Refactor node extraction; remove summary from attribute extraction (#977)
* Refactor node extraction for better maintainability - Extract helper functions from extract_attributes_from_node to improve code organization - Add _extract_entity_attributes, _extract_entity_summary, and _build_episode_context helpers - Apply consistent formatting (double quotes per ruff configuration) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Apply consistent single quote style throughout node_operations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * cleanup * cleanup 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Bump version to 0.22.0pre0 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
5a67e660dc
commit
2864786dd9
3 changed files with 97 additions and 59 deletions
|
|
@@ -74,7 +74,9 @@ async def extract_nodes_reflexion(
|
|||
}
|
||||
|
||||
llm_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
|
||||
prompt_library.extract_nodes.reflexion(context),
|
||||
MissedEntities,
|
||||
group_id=group_id,
|
||||
)
|
||||
missed_entities = llm_response.get('missed_entities', [])
|
||||
|
||||
|
|
@@ -483,65 +485,95 @@ async def extract_attributes_from_node(
|
|||
entity_type: type[BaseModel] | None = None,
|
||||
should_summarize_node: NodeSummaryFilter | None = None,
|
||||
) -> EntityNode:
|
||||
node_context: dict[str, Any] = {
|
||||
'name': node.name,
|
||||
'summary': node.summary,
|
||||
'entity_types': node.labels,
|
||||
'attributes': node.attributes,
|
||||
}
|
||||
|
||||
attributes_context: dict[str, Any] = {
|
||||
'node': node_context,
|
||||
'episode_content': episode.content if episode is not None else '',
|
||||
'previous_episodes': (
|
||||
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
||||
),
|
||||
}
|
||||
|
||||
summary_context: dict[str, Any] = {
|
||||
'node': node_context,
|
||||
'episode_content': episode.content if episode is not None else '',
|
||||
'previous_episodes': (
|
||||
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
||||
),
|
||||
}
|
||||
|
||||
has_entity_attributes: bool = bool(
|
||||
entity_type is not None and len(entity_type.model_fields) != 0
|
||||
# Extract attributes if entity type is defined and has attributes
|
||||
llm_response = await _extract_entity_attributes(
|
||||
llm_client, node, episode, previous_episodes, entity_type
|
||||
)
|
||||
|
||||
llm_response = (
|
||||
(
|
||||
await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
||||
response_model=entity_type,
|
||||
model_size=ModelSize.small,
|
||||
group_id=node.group_id,
|
||||
)
|
||||
)
|
||||
if has_entity_attributes
|
||||
else {}
|
||||
# Extract summary if needed
|
||||
await _extract_entity_summary(
|
||||
llm_client, node, episode, previous_episodes, should_summarize_node
|
||||
)
|
||||
|
||||
# Determine if summary should be generated
|
||||
generate_summary = True
|
||||
if should_summarize_node is not None:
|
||||
generate_summary = await should_summarize_node(node)
|
||||
|
||||
# Conditionally generate summary
|
||||
if generate_summary:
|
||||
summary_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||
response_model=EntitySummary,
|
||||
model_size=ModelSize.small,
|
||||
group_id=node.group_id,
|
||||
)
|
||||
node.summary = summary_response.get('summary', '')
|
||||
|
||||
if has_entity_attributes and entity_type is not None:
|
||||
entity_type(**llm_response)
|
||||
node_attributes = {key: value for key, value in llm_response.items()}
|
||||
|
||||
node.attributes.update(node_attributes)
|
||||
node.attributes.update(llm_response)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def _extract_entity_attributes(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    entity_type: type[BaseModel] | None,
) -> dict[str, Any]:
    """Ask the LLM for structured attribute values for ``node``.

    Returns an empty dict when no entity type is supplied or the type declares
    no fields; otherwise returns the raw attribute mapping after validating it
    against ``entity_type``.
    """
    # Nothing to extract without a typed schema that has at least one field.
    if entity_type is None or not entity_type.model_fields:
        return {}

    # The node summary is deliberately omitted from this prompt context.
    node_payload = {
        'name': node.name,
        'entity_types': node.labels,
        'attributes': node.attributes,
    }
    context = _build_episode_context(
        node_data=node_payload,
        episode=episode,
        previous_episodes=previous_episodes,
    )

    response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_attributes(context),
        response_model=entity_type,
        model_size=ModelSize.small,
        group_id=node.group_id,
    )

    # Instantiate the model purely to validate the response shape; raises on
    # schema mismatch, the instance itself is discarded.
    entity_type(**response)

    return response
|
||||
|
||||
|
||||
async def _extract_entity_summary(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    should_summarize_node: NodeSummaryFilter | None,
) -> None:
    """Regenerate ``node.summary`` from the episode context, mutating in place.

    When a ``should_summarize_node`` filter is supplied and rejects the node,
    the existing summary is left untouched.
    """
    if should_summarize_node is not None:
        wants_summary = await should_summarize_node(node)
        if not wants_summary:
            return

    # Unlike attribute extraction, the current summary IS part of the prompt.
    context = _build_episode_context(
        node_data={
            'name': node.name,
            'summary': node.summary,
            'entity_types': node.labels,
            'attributes': node.attributes,
        },
        episode=episode,
        previous_episodes=previous_episodes,
    )

    response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_summary(context),
        response_model=EntitySummary,
        model_size=ModelSize.small,
        group_id=node.group_id,
    )

    node.summary = response.get('summary', '')
|
||||
|
||||
|
||||
def _build_episode_context(
|
||||
node_data: dict[str, Any],
|
||||
episode: EpisodicNode | None,
|
||||
previous_episodes: list[EpisodicNode] | None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
'node': node_data,
|
||||
'episode_content': episode.content if episode is not None else '',
|
||||
'previous_episodes': (
|
||||
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.21.0"
|
||||
version = "0.22.0pre0"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
8
uv.lock
generated
8
uv.lock
generated
|
|
@@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.10, <4"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14'",
|
||||
|
|
@@ -803,6 +803,7 @@ anthropic = [
|
|||
]
|
||||
dev = [
|
||||
{ name = "anthropic" },
|
||||
{ name = "boto3" },
|
||||
{ name = "diskcache-stubs" },
|
||||
{ name = "falkordb" },
|
||||
{ name = "google-genai" },
|
||||
|
|
@@ -811,9 +812,11 @@ dev = [
|
|||
{ name = "jupyterlab" },
|
||||
{ name = "kuzu" },
|
||||
{ name = "langchain-anthropic" },
|
||||
{ name = "langchain-aws" },
|
||||
{ name = "langchain-openai" },
|
||||
{ name = "langgraph" },
|
||||
{ name = "langsmith" },
|
||||
{ name = "opensearch-py" },
|
||||
{ name = "pyright" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
|
|
@@ -855,6 +858,7 @@ voyageai = [
|
|||
requires-dist = [
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
||||
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
||||
{ name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" },
|
||||
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
||||
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
||||
{ name = "diskcache", specifier = ">=5.6.3" },
|
||||
|
|
@@ -870,6 +874,7 @@ requires-dist = [
|
|||
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
|
||||
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
|
||||
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
|
||||
{ name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" },
|
||||
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
|
||||
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
|
||||
{ name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
|
||||
|
|
@@ -877,6 +882,7 @@ requires-dist = [
|
|||
{ name = "neo4j", specifier = ">=5.26.0" },
|
||||
{ name = "numpy", specifier = ">=1.0.0" },
|
||||
{ name = "openai", specifier = ">=1.91.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
||||
{ name = "posthog", specifier = ">=3.0.0" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue