move summary out of attribute extraction (#792)

* move summary out of attribute extraction

* linter

* linter

* fix db query
Preston Rasmussen 2025-07-31 12:15:21 -04:00 committed by GitHub
parent e5112244e5
commit ab8106cb4f
11 changed files with 93 additions and 57 deletions

View file

@@ -359,10 +359,10 @@ class Graphiti:
         group_id: str | None = None,
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ) -> AddEpisodeResults:
         """
@@ -555,9 +555,9 @@ class Graphiti:
         self,
         bulk_episodes: list[RawEpisode],
         group_id: str | None = None,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ):
         """

View file

@@ -148,7 +148,7 @@ def validate_group_id(group_id: str) -> bool:


 def validate_excluded_entity_types(
-    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
 ) -> bool:
     """
     Validate that excluded entity types are valid type names.
View file

@@ -52,6 +52,13 @@ class EntityClassification(BaseModel):
     )


+class EntitySummary(BaseModel):
+    summary: str = Field(
+        ...,
+        description='Summary containing the important information about the entity. Under 250 words',
+    )
+
+
 class Prompt(Protocol):
     extract_message: PromptVersion
     extract_json: PromptVersion
@@ -59,6 +66,7 @@ class Prompt(Protocol):
     reflexion: PromptVersion
     classify_nodes: PromptVersion
     extract_attributes: PromptVersion
+    extract_summary: PromptVersion


 class Versions(TypedDict):
@@ -68,6 +76,7 @@ class Versions(TypedDict):
     reflexion: PromptFunction
     classify_nodes: PromptFunction
     extract_attributes: PromptFunction
+    extract_summary: PromptFunction


 def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -259,9 +268,39 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
         Guidelines:
         1. Do not hallucinate entity property values if they cannot be found in the current context.
         2. Only use the provided MESSAGES and ENTITY to set attribute values.
+
+        <ENTITY>
+        {context['node']}
+        </ENTITY>
+        """,
+        ),
+    ]
+
+
+def extract_summary(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that extracts entity summaries from the provided text.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        <MESSAGES>
+        {json.dumps(context['previous_episodes'], indent=2)}
+        {json.dumps(context['episode_content'], indent=2)}
+        </MESSAGES>
+
+        Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
+        from the messages and relevant information from the existing summary.
+
+        Guidelines:
+        1. Do not hallucinate entity summary information if they cannot be found in the current context.
+        2. Only use the provided MESSAGES and ENTITY to set attribute values.
         3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
         Summaries must be no longer than 250 words.

         <ENTITY>
         {context['node']}
         </ENTITY>
@@ -275,6 +314,7 @@ versions: Versions = {
     'extract_json': extract_json,
     'extract_text': extract_text,
     'reflexion': reflexion,
+    'extract_summary': extract_summary,
     'classify_nodes': classify_nodes,
     'extract_attributes': extract_attributes,
 }
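Taken together, this file adds a dedicated prompt function and response model for summaries. A short sketch of how the pair fits together after this change; the context payload mirrors the keys read by the f-string above, and the commented-out call mirrors the one made in node_operations.py later in this diff:

from graphiti_core.prompts.extract_nodes import EntitySummary, extract_summary

context = {
    'node': {'name': 'Ada Lovelace', 'summary': '', 'attributes': {}},
    'episode_content': 'Ada wrote the first published algorithm.',
    'previous_episodes': [],
}

messages = extract_summary(context)  # system + user Message pair
# The structured LLM response is then validated against EntitySummary, e.g.:
# summary = await llm_client.generate_response(
#     messages, response_model=EntitySummary, model_size=ModelSize.small
# )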

View file

@ -314,17 +314,15 @@ async def node_fulltext_search(
+ """ + """
YIELD node AS n, score YIELD node AS n, score
WHERE n:Entity AND n.group_id IN $group_ids WHERE n:Entity AND n.group_id IN $group_ids
WITH n, score
LIMIT $limit
""" """
+ filter_query + filter_query
+ """ + """
WITH n, score
ORDER BY score DESC
LIMIT $limit
RETURN RETURN
""" """
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
"""
) )
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
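This is the "fix db query" part of the commit. Previously, LIMIT $limit ran before filter_query, so the limit was applied to the raw fulltext hits and any rows the filter later removed silently shrank the final result set; the rewrite filters first, then applies ORDER BY score DESC LIMIT $limit to the surviving rows. A toy illustration of the difference, with plain Python standing in for Cypher:

hits = [('a', 0.9), ('b', 0.8), ('c', 0.7), ('d', 0.6)]  # (node, score), best first


def passes_filter(node: str) -> bool:
    return node != 'b'  # stand-in for filter_query


limit = 2

# Old shape: LIMIT before the filter; 'b' consumes a slot, then is filtered out.
old = [n for n, _ in hits[:limit] if passes_filter(n)]  # ['a']

# New shape: filter first, then order by score and limit; a full result set.
new = sorted(
    ((n, s) for n, s in hits if passes_filter(n)),
    key=lambda pair: pair[1],
    reverse=True,
)[:limit]  # [('a', 0.9), ('c', 0.7)]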

View file

@@ -169,9 +169,9 @@ async def extract_nodes_and_edges_bulk(
     clients: GraphitiClients,
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     edge_type_map: dict[tuple[str, str], list[str]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
     extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
@@ -202,7 +202,7 @@ async def dedupe_nodes_bulk(
     clients: GraphitiClients,
     extracted_nodes: list[list[EntityNode]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
     embedder = clients.embedder
     min_score = 0.8
@@ -290,7 +290,7 @@ async def dedupe_edges_bulk(
     extracted_edges: list[list[EntityEdge]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     _entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     _edge_type_map: dict[tuple[str, str], list[str]],
 ) -> dict[str, list[EntityEdge]]:
     embedder = clients.embedder

View file

@@ -114,7 +114,7 @@ async def extract_edges(
     previous_episodes: list[EpisodicNode],
     edge_type_map: dict[tuple[str, str], list[str]],
     group_id: str = '',
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> list[EntityEdge]:
     start = time()
@@ -249,7 +249,7 @@ async def resolve_extracted_edges(
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
     entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
@@ -272,7 +272,7 @@ async def resolve_extracted_edges(
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

     # Determine which edge types are relevant for each edge
-    edge_types_lst: list[dict[str, BaseModel]] = []
+    edge_types_lst: list[dict[str, type[BaseModel]]] = []
     for extracted_edge in extracted_edges:
         source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
         target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -381,7 +381,7 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
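These signatures feed the per-edge type resolution visible in the -272 hunk: for each extracted edge, the source and target node labels are matched against edge_type_map, and the matching model classes are collected into one dict per edge (edge_types_lst). A sketch of that lookup under hypothetical labels and types, mirroring only the annotations in this diff:

from pydantic import BaseModel


class WorksAt(BaseModel):
    role: str | None = None


edge_types: dict[str, type[BaseModel]] = {'WORKS_AT': WorksAt}
edge_type_map: dict[tuple[str, str], list[str]] = {
    ('Person', 'Organization'): ['WORKS_AT'],
}

source_labels, target_labels = ['Person'], ['Organization']
relevant: dict[str, type[BaseModel]] = {
    name: edge_types[name]
    for (src, tgt), names in edge_type_map.items()
    if src in source_labels and tgt in target_labels
    for name in names
    if name in edge_types
}
print(relevant)  # {'WORKS_AT': <class 'WorksAt'>}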

View file

@@ -15,13 +15,10 @@ limitations under the License.
 """

 import logging
-from contextlib import suppress
 from time import time
 from typing import Any
-from uuid import uuid4

-import pydantic
-from pydantic import BaseModel, Field
+from pydantic import BaseModel

 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
@@ -31,6 +28,7 @@ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
+    EntitySummary,
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
@@ -70,7 +68,7 @@ async def extract_nodes(
     clients: GraphitiClients,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
 ) -> list[EntityNode]:
     start = time()
@@ -180,7 +178,7 @@ async def resolve_extracted_nodes(
     extracted_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     existing_nodes_override: list[EntityNode] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
@@ -223,7 +221,7 @@ async def resolve_extracted_nodes(
         ],
     )

-    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}

     # Prepare context for LLM
     extracted_nodes_context = [
@@ -297,7 +295,7 @@ async def extract_attributes_from_nodes(
     nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
@@ -326,7 +324,7 @@ async def extract_attributes_from_node(
     node: EntityNode,
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_type: BaseModel | None = None,
+    entity_type: type[BaseModel] | None = None,
 ) -> EntityNode:
     node_context: dict[str, Any] = {
         'name': node.name,
@@ -335,25 +333,14 @@ async def extract_attributes_from_node(
         'attributes': node.attributes,
     }

-    attributes_definitions: dict[str, Any] = {
-        'summary': (
-            str,
-            Field(
-                description='Summary containing the important information about the entity. Under 250 words',
-            ),
-        )
-    }
-
-    if entity_type is not None:
-        for field_name, field_info in entity_type.model_fields.items():
-            attributes_definitions[field_name] = (
-                field_info.annotation,
-                Field(description=field_info.description),
-            )
-
-    unique_model_name = f'EntityAttributes_{uuid4().hex}'
-    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
+    attributes_context: dict[str, Any] = {
+        'node': node_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }

     summary_context: dict[str, Any] = {
         'node': node_context,
         'episode_content': episode.content if episode is not None else '',
@@ -362,20 +349,30 @@ async def extract_attributes_from_node(
         else [],
     }

-    llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.extract_attributes(summary_context),
-        response_model=entity_attributes_model,
+    llm_response = (
+        (
+            await llm_client.generate_response(
+                prompt_library.extract_nodes.extract_attributes(attributes_context),
+                response_model=entity_type,
+                model_size=ModelSize.small,
+            )
+        )
+        if entity_type is not None
+        else {}
+    )
+
+    summary_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_summary(summary_context),
+        response_model=EntitySummary,
         model_size=ModelSize.small,
     )

-    entity_attributes_model(**llm_response)
+    if entity_type is not None:
+        entity_type(**llm_response)

-    node.summary = llm_response.get('summary', '')
+    node.summary = summary_response.get('summary', '')
     node_attributes = {key: value for key, value in llm_response.items()}

-    with suppress(KeyError):
-        del node_attributes['summary']
-
     node.attributes.update(node_attributes)

     return node
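The net effect of the refactor above: the old code built a throwaway pydantic.create_model(...) that bundled a summary field in with the typed attributes and made a single LLM call, then had to strip 'summary' back out; the new code makes two independent structured-output calls, one validated against the caller's entity_type class (skipped entirely for untyped entities) and one against the fixed EntitySummary model. A condensed, self-contained sketch of the new control flow; FakeLLMClient, Person, and the prompt strings are illustrative stand-ins, not graphiti APIs:

import asyncio
from typing import Any

from pydantic import BaseModel, Field


class EntitySummary(BaseModel):
    summary: str = Field(..., description='Under 250 words')


class Person(BaseModel):
    occupation: str | None = None


class FakeLLMClient:
    # Stand-in for graphiti's llm_client: returns canned structured payloads.
    async def generate_response(self, prompt: Any, response_model: type[BaseModel]) -> dict:
        if response_model is EntitySummary:
            return {'summary': 'Ada Lovelace wrote the first published algorithm.'}
        return {'occupation': 'mathematician'}


async def extract(attributes: dict, entity_type: type[BaseModel] | None) -> dict:
    client = FakeLLMClient()
    # 1. Typed attributes: only when a custom entity type was supplied.
    llm_response = (
        await client.generate_response('attributes prompt', response_model=entity_type)
        if entity_type is not None
        else {}
    )
    # 2. Summary: always extracted, against the fixed EntitySummary model.
    summary_response = await client.generate_response('summary prompt', response_model=EntitySummary)
    if entity_type is not None:
        entity_type(**llm_response)  # validate the attribute payload
    attributes.update(llm_response)  # no 'summary' key left to strip
    return {'summary': summary_response['summary'], 'attributes': attributes}


print(asyncio.run(extract({}, Person)))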

View file

@@ -21,7 +21,7 @@ from graphiti_core.nodes import EntityNode


 def validate_entity_types(
-    entity_types: dict[str, BaseModel] | None,
+    entity_types: dict[str, type[BaseModel]] | None,
 ) -> bool:
     if entity_types is None:
         return True

View file

@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.18.1"
+version = "0.18.2"
 authors = [
     { name = "Paul Paliychuk", email = "paul@getzep.com" },
     { name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@@ -64,7 +64,8 @@ async def test_graphiti_init(driver):
     await graphiti.build_indices_and_constraints()

     search_filter = SearchFilters(
-        created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
+        node_labels=['Person'],
+        created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]],
     )

     results = await graphiti.search_(

uv.lock generated
View file

@@ -746,7 +746,7 @@ wheels = [

 [[package]]
 name = "graphiti-core"
-version = "0.18.1"
+version = "0.18.2"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },