Refactor node extraction; remove summary from attribute extraction (#977)
* Refactor node extraction for better maintainability

  - Extract helper functions from extract_attributes_from_node to improve code organization
  - Add _extract_entity_attributes, _extract_entity_summary, and _build_episode_context helpers
  - Apply consistent formatting (double quotes per ruff configuration)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Apply consistent single quote style throughout node_operations

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* cleanup

* cleanup

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Bump version to 0.22.0pre0

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
5a67e660dc
commit
2864786dd9
3 changed files with 97 additions and 59 deletions
|
|
@@ -74,7 +74,9 @@ async def extract_nodes_reflexion(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
|
prompt_library.extract_nodes.reflexion(context),
|
||||||
|
MissedEntities,
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
missed_entities = llm_response.get('missed_entities', [])
|
missed_entities = llm_response.get('missed_entities', [])
|
||||||
|
|
||||||
|
|
@@ -483,65 +485,95 @@ async def extract_attributes_from_node(
|
||||||
entity_type: type[BaseModel] | None = None,
|
entity_type: type[BaseModel] | None = None,
|
||||||
should_summarize_node: NodeSummaryFilter | None = None,
|
should_summarize_node: NodeSummaryFilter | None = None,
|
||||||
) -> EntityNode:
|
) -> EntityNode:
|
||||||
node_context: dict[str, Any] = {
|
# Extract attributes if entity type is defined and has attributes
|
||||||
'name': node.name,
|
llm_response = await _extract_entity_attributes(
|
||||||
'summary': node.summary,
|
llm_client, node, episode, previous_episodes, entity_type
|
||||||
'entity_types': node.labels,
|
|
||||||
'attributes': node.attributes,
|
|
||||||
}
|
|
||||||
|
|
||||||
attributes_context: dict[str, Any] = {
|
|
||||||
'node': node_context,
|
|
||||||
'episode_content': episode.content if episode is not None else '',
|
|
||||||
'previous_episodes': (
|
|
||||||
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
summary_context: dict[str, Any] = {
|
|
||||||
'node': node_context,
|
|
||||||
'episode_content': episode.content if episode is not None else '',
|
|
||||||
'previous_episodes': (
|
|
||||||
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
has_entity_attributes: bool = bool(
|
|
||||||
entity_type is not None and len(entity_type.model_fields) != 0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_response = (
|
# Extract summary if needed
|
||||||
(
|
await _extract_entity_summary(
|
||||||
await llm_client.generate_response(
|
llm_client, node, episode, previous_episodes, should_summarize_node
|
||||||
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
|
||||||
response_model=entity_type,
|
|
||||||
model_size=ModelSize.small,
|
|
||||||
group_id=node.group_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if has_entity_attributes
|
|
||||||
else {}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine if summary should be generated
|
node.attributes.update(llm_response)
|
||||||
generate_summary = True
|
|
||||||
if should_summarize_node is not None:
|
|
||||||
generate_summary = await should_summarize_node(node)
|
|
||||||
|
|
||||||
# Conditionally generate summary
|
|
||||||
if generate_summary:
|
|
||||||
summary_response = await llm_client.generate_response(
|
|
||||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
|
||||||
response_model=EntitySummary,
|
|
||||||
model_size=ModelSize.small,
|
|
||||||
group_id=node.group_id,
|
|
||||||
)
|
|
||||||
node.summary = summary_response.get('summary', '')
|
|
||||||
|
|
||||||
if has_entity_attributes and entity_type is not None:
|
|
||||||
entity_type(**llm_response)
|
|
||||||
node_attributes = {key: value for key, value in llm_response.items()}
|
|
||||||
|
|
||||||
node.attributes.update(node_attributes)
|
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_entity_attributes(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
node: EntityNode,
|
||||||
|
episode: EpisodicNode | None,
|
||||||
|
previous_episodes: list[EpisodicNode] | None,
|
||||||
|
entity_type: type[BaseModel] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if entity_type is None or len(entity_type.model_fields) == 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
attributes_context = _build_episode_context(
|
||||||
|
# should not include summary
|
||||||
|
node_data={
|
||||||
|
'name': node.name,
|
||||||
|
'entity_types': node.labels,
|
||||||
|
'attributes': node.attributes,
|
||||||
|
},
|
||||||
|
episode=episode,
|
||||||
|
previous_episodes=previous_episodes,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(
|
||||||
|
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
||||||
|
response_model=entity_type,
|
||||||
|
model_size=ModelSize.small,
|
||||||
|
group_id=node.group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# validate response
|
||||||
|
entity_type(**llm_response)
|
||||||
|
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_entity_summary(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
node: EntityNode,
|
||||||
|
episode: EpisodicNode | None,
|
||||||
|
previous_episodes: list[EpisodicNode] | None,
|
||||||
|
should_summarize_node: NodeSummaryFilter | None,
|
||||||
|
) -> None:
|
||||||
|
if should_summarize_node is not None and not await should_summarize_node(node):
|
||||||
|
return
|
||||||
|
|
||||||
|
summary_context = _build_episode_context(
|
||||||
|
node_data={
|
||||||
|
'name': node.name,
|
||||||
|
'summary': node.summary,
|
||||||
|
'entity_types': node.labels,
|
||||||
|
'attributes': node.attributes,
|
||||||
|
},
|
||||||
|
episode=episode,
|
||||||
|
previous_episodes=previous_episodes,
|
||||||
|
)
|
||||||
|
|
||||||
|
summary_response = await llm_client.generate_response(
|
||||||
|
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||||
|
response_model=EntitySummary,
|
||||||
|
model_size=ModelSize.small,
|
||||||
|
group_id=node.group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
node.summary = summary_response.get('summary', '')
|
||||||
|
|
||||||
|
|
||||||
|
def _build_episode_context(
|
||||||
|
node_data: dict[str, Any],
|
||||||
|
episode: EpisodicNode | None,
|
||||||
|
previous_episodes: list[EpisodicNode] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
'node': node_data,
|
||||||
|
'episode_content': episode.content if episode is not None else '',
|
||||||
|
'previous_episodes': (
|
||||||
|
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.21.0"
|
version = "0.22.0pre0"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
8
uv.lock
generated
8
uv.lock
generated
|
|
@@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.10, <4"
|
requires-python = ">=3.10, <4"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.14'",
|
"python_full_version >= '3.14'",
|
||||||
|
|
@@ -803,6 +803,7 @@ anthropic = [
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
{ name = "anthropic" },
|
{ name = "anthropic" },
|
||||||
|
{ name = "boto3" },
|
||||||
{ name = "diskcache-stubs" },
|
{ name = "diskcache-stubs" },
|
||||||
{ name = "falkordb" },
|
{ name = "falkordb" },
|
||||||
{ name = "google-genai" },
|
{ name = "google-genai" },
|
||||||
|
|
@@ -811,9 +812,11 @@ dev = [
|
||||||
{ name = "jupyterlab" },
|
{ name = "jupyterlab" },
|
||||||
{ name = "kuzu" },
|
{ name = "kuzu" },
|
||||||
{ name = "langchain-anthropic" },
|
{ name = "langchain-anthropic" },
|
||||||
|
{ name = "langchain-aws" },
|
||||||
{ name = "langchain-openai" },
|
{ name = "langchain-openai" },
|
||||||
{ name = "langgraph" },
|
{ name = "langgraph" },
|
||||||
{ name = "langsmith" },
|
{ name = "langsmith" },
|
||||||
|
{ name = "opensearch-py" },
|
||||||
{ name = "pyright" },
|
{ name = "pyright" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
|
|
@@ -855,6 +858,7 @@ voyageai = [
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
||||||
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
||||||
|
{ name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" },
|
||||||
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
||||||
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
||||||
{ name = "diskcache", specifier = ">=5.6.3" },
|
{ name = "diskcache", specifier = ">=5.6.3" },
|
||||||
|
|
@@ -870,6 +874,7 @@ requires-dist = [
|
||||||
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
|
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
|
||||||
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
|
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
|
||||||
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
|
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
|
||||||
|
{ name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" },
|
||||||
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
|
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
|
||||||
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
|
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
|
||||||
{ name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
|
{ name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
|
||||||
|
|
@@ -877,6 +882,7 @@ requires-dist = [
|
||||||
{ name = "neo4j", specifier = ">=5.26.0" },
|
{ name = "neo4j", specifier = ">=5.26.0" },
|
||||||
{ name = "numpy", specifier = ">=1.0.0" },
|
{ name = "numpy", specifier = ">=1.0.0" },
|
||||||
{ name = "openai", specifier = ">=1.91.0" },
|
{ name = "openai", specifier = ">=1.91.0" },
|
||||||
|
{ name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" },
|
||||||
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
||||||
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
||||||
{ name = "posthog", specifier = ">=3.0.0" },
|
{ name = "posthog", specifier = ">=3.0.0" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue