move summary out of attribute extraction (#792)

* move summary out of attribute extraction
* linter
* linter
* fix db query
parent e5112244e5
commit ab8106cb4f

11 changed files with 93 additions and 57 deletions
@@ -359,10 +359,10 @@ class Graphiti:
         group_id: str | None = None,
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ) -> AddEpisodeResults:
         """
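Most hunks in this commit make the same annotation fix: the custom entity and edge type registries hold Pydantic model classes, not instances, so the correct element type is type[BaseModel] rather than BaseModel. A minimal sketch of the distinction (the Person type is illustrative, not from this diff):

    from pydantic import BaseModel, Field


    class Person(BaseModel):
        occupation: str | None = Field(None, description='The occupation of the person')


    # Callers register the class itself, so type[BaseModel] is the annotation
    # that type-checks; an instance would be annotated as BaseModel.
    entity_types: dict[str, type[BaseModel]] = {'Person': Person}
    assert isinstance(entity_types['Person'], type)  # a class, not an instance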
@@ -555,9 +555,9 @@ class Graphiti:
         self,
         bulk_episodes: list[RawEpisode],
         group_id: str | None = None,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ):
         """
@@ -148,7 +148,7 @@ def validate_group_id(group_id: str) -> bool:


 def validate_excluded_entity_types(
-    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
 ) -> bool:
     """
     Validate that excluded entity types are valid type names.
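The body of validate_excluded_entity_types sits outside this hunk. A hedged sketch of what a validator with this signature plausibly checks, assuming excluded names must refer to the built-in 'Entity' label or a registered custom type (the helper name and rule are this sketch's assumptions, not code from the commit):

    from pydantic import BaseModel


    def validate_excluded_entity_types_sketch(
        excluded_entity_types: list[str] | None,
        entity_types: dict[str, type[BaseModel]] | None = None,
    ) -> bool:
        # Nothing excluded is trivially valid.
        if not excluded_entity_types:
            return True
        # Assumed rule: every excluded name must be a known type name.
        known = {'Entity'} | set(entity_types or {})
        invalid = set(excluded_entity_types) - known
        if invalid:
            raise ValueError(f'Invalid excluded entity types: {sorted(invalid)}')
        return True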
@@ -52,6 +52,13 @@ class EntityClassification(BaseModel):
     )


+class EntitySummary(BaseModel):
+    summary: str = Field(
+        ...,
+        description='Summary containing the important information about the entity. Under 250 words',
+    )
+
+
 class Prompt(Protocol):
     extract_message: PromptVersion
     extract_json: PromptVersion
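The new EntitySummary model is what makes the separate summary call type-safe: the structured LLM response is validated by constructing the model. A small self-contained sketch (the response dicts are made up):

    from pydantic import BaseModel, Field, ValidationError


    class EntitySummary(BaseModel):
        summary: str = Field(
            ...,
            description='Summary containing the important information about the entity. Under 250 words',
        )


    raw_response = {'summary': 'Alice is a software engineer working on graph databases.'}
    validated = EntitySummary(**raw_response)  # raises ValidationError on a bad shape

    try:
        EntitySummary(**{})
    except ValidationError:
        pass  # a malformed LLM response is caught here rather than downstream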
@@ -59,6 +66,7 @@ class Prompt(Protocol):
     reflexion: PromptVersion
     classify_nodes: PromptVersion
     extract_attributes: PromptVersion
+    extract_summary: PromptVersion


 class Versions(TypedDict):
@@ -68,6 +76,7 @@ class Versions(TypedDict):
     reflexion: PromptFunction
     classify_nodes: PromptFunction
     extract_attributes: PromptFunction
+    extract_summary: PromptFunction


 def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -259,9 +268,39 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
         Guidelines:
         1. Do not hallucinate entity property values if they cannot be found in the current context.
         2. Only use the provided MESSAGES and ENTITY to set attribute values.

         <ENTITY>
         {context['node']}
         </ENTITY>
         """,
         ),
     ]
+
+
+def extract_summary(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that extracts entity summaries from the provided text.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+
+        <MESSAGES>
+        {json.dumps(context['previous_episodes'], indent=2)}
+        {json.dumps(context['episode_content'], indent=2)}
+        </MESSAGES>
+
+        Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
+        from the messages and relevant information from the existing summary.
+
+        Guidelines:
+        1. Do not hallucinate entity summary information if they cannot be found in the current context.
+        2. Only use the provided MESSAGES and ENTITY to set attribute values.
+        3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
+        Summaries must be no longer than 250 words.
+
+        <ENTITY>
+        {context['node']}
+        </ENTITY>
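For orientation, a sketch of how the new prompt function is driven, assuming the extract_summary function added above is in scope. The context keys mirror the ones referenced in the f-string; the values are invented for illustration:

    context = {
        'node': {'name': 'Alice', 'summary': '', 'attributes': {}},
        'episode_content': 'Alice moved to the graph team in March.',
        'previous_episodes': ['Alice interviewed with the graph team.'],
    }

    messages = extract_summary(context)
    # messages[0] carries the fixed system prompt; messages[1] embeds the
    # serialized MESSAGES and ENTITY blocks the model summarizes from.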
@@ -275,6 +314,7 @@ versions: Versions = {
     'extract_json': extract_json,
     'extract_text': extract_text,
     'reflexion': reflexion,
+    'extract_summary': extract_summary,
     'classify_nodes': classify_nodes,
     'extract_attributes': extract_attributes,
 }
@@ -314,17 +314,15 @@ async def node_fulltext_search(
         + """
         YIELD node AS n, score
         WHERE n:Entity AND n.group_id IN $group_ids
-        WITH n, score
-        LIMIT $limit
         """
         + filter_query
         + """
+        WITH n, score
+        ORDER BY score DESC
+        LIMIT $limit
         RETURN
         """
         + ENTITY_NODE_RETURN
-        + """
-        ORDER BY score DESC
-        """
     )

     records, _, _ = await driver.execute_query(
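This hunk is the "fix db query" part of the commit: the old query capped results with LIMIT $limit before filter_query ran, so rows discarded by the filter were never backfilled, and ORDER BY score DESC was applied only after RETURN, on the already-truncated set. The reordered query filters first, then sorts by score and caps. An illustrative sketch of the two assembled query shapes (the fulltext CALL prefix and the <...> fragments are placeholders, not the literal strings):

    old_shape = """
    CALL db.index.fulltext.queryNodes(...) YIELD node AS n, score
    WHERE n:Entity AND n.group_id IN $group_ids
    WITH n, score
    LIMIT $limit              // truncates before the filter runs
    <filter_query>
    RETURN <ENTITY_NODE_RETURN>
    ORDER BY score DESC       // orders only what survived truncation
    """

    new_shape = """
    CALL db.index.fulltext.queryNodes(...) YIELD node AS n, score
    WHERE n:Entity AND n.group_id IN $group_ids
    <filter_query>
    WITH n, score
    ORDER BY score DESC       // order the filtered rows first
    LIMIT $limit              // then cap them
    RETURN <ENTITY_NODE_RETURN>
    """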
@@ -169,9 +169,9 @@ async def extract_nodes_and_edges_bulk(
     clients: GraphitiClients,
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     edge_type_map: dict[tuple[str, str], list[str]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
     extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
@@ -202,7 +202,7 @@ async def dedupe_nodes_bulk(
     clients: GraphitiClients,
     extracted_nodes: list[list[EntityNode]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
     embedder = clients.embedder
     min_score = 0.8
@@ -290,7 +290,7 @@ async def dedupe_edges_bulk(
     extracted_edges: list[list[EntityEdge]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     _entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     _edge_type_map: dict[tuple[str, str], list[str]],
 ) -> dict[str, list[EntityEdge]]:
     embedder = clients.embedder
@@ -114,7 +114,7 @@ async def extract_edges(
     previous_episodes: list[EpisodicNode],
     edge_type_map: dict[tuple[str, str], list[str]],
     group_id: str = '',
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> list[EntityEdge]:
     start = time()

@@ -249,7 +249,7 @@ async def resolve_extracted_edges(
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
     entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
@@ -272,7 +272,7 @@ async def resolve_extracted_edges(
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

     # Determine which edge types are relevant for each edge
-    edge_types_lst: list[dict[str, BaseModel]] = []
+    edge_types_lst: list[dict[str, type[BaseModel]]] = []
     for extracted_edge in extracted_edges:
         source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
         target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -381,7 +381,7 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
@@ -15,13 +15,10 @@ limitations under the License.
 """

 import logging
-from contextlib import suppress
 from time import time
 from typing import Any
-from uuid import uuid4

-import pydantic
-from pydantic import BaseModel, Field
+from pydantic import BaseModel

 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
@@ -31,6 +28,7 @@ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_en
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
+    EntitySummary,
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
@@ -70,7 +68,7 @@ async def extract_nodes(
     clients: GraphitiClients,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
 ) -> list[EntityNode]:
     start = time()
@@ -180,7 +178,7 @@ async def resolve_extracted_nodes(
     extracted_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     existing_nodes_override: list[EntityNode] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
@@ -223,7 +221,7 @@ async def resolve_extracted_nodes(
         ],
     )

-    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}

     # Prepare context for LLM
     extracted_nodes_context = [
@@ -297,7 +295,7 @@ async def extract_attributes_from_nodes(
     nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
@@ -326,7 +324,7 @@ async def extract_attributes_from_node(
     node: EntityNode,
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_type: BaseModel | None = None,
+    entity_type: type[BaseModel] | None = None,
 ) -> EntityNode:
     node_context: dict[str, Any] = {
         'name': node.name,
@@ -335,25 +333,14 @@ async def extract_attributes_from_node(
         'attributes': node.attributes,
     }

-    attributes_definitions: dict[str, Any] = {
-        'summary': (
-            str,
-            Field(
-                description='Summary containing the important information about the entity. Under 250 words',
-            ),
-        )
-    }
-
-    if entity_type is not None:
-        for field_name, field_info in entity_type.model_fields.items():
-            attributes_definitions[field_name] = (
-                field_info.annotation,
-                Field(description=field_info.description),
-            )
-
-    unique_model_name = f'EntityAttributes_{uuid4().hex}'
-    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
-
+    attributes_context: dict[str, Any] = {
+        'node': node_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }
+
     summary_context: dict[str, Any] = {
         'node': node_context,
         'episode_content': episode.content if episode is not None else '',
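For contrast, the deleted lines built a throwaway model per call with pydantic.create_model, folding 'summary' in with the typed attributes. A condensed, runnable sketch of that removed pattern (the occupation field is illustrative):

    from uuid import uuid4

    import pydantic
    from pydantic import Field

    attributes_definitions = {
        'summary': (str, Field(description='Summary of the entity, under 250 words')),
        'occupation': (str | None, Field(None, description='The occupation of the entity')),
    }

    # One dynamic model per call, uniquely named to avoid class-name clashes.
    EntityAttributes = pydantic.create_model(
        f'EntityAttributes_{uuid4().hex}', **attributes_definitions
    )

    # After this commit, the static EntitySummary model handles the summary and
    # the user-supplied entity type validates attributes directly, so no
    # dynamic model (and no uuid4/pydantic.create_model import) is needed.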
@@ -362,20 +349,30 @@ async def extract_attributes_from_node(
         else [],
     }

-    llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.extract_attributes(summary_context),
-        response_model=entity_attributes_model,
-        model_size=ModelSize.small,
-    )
+    llm_response = (
+        (
+            await llm_client.generate_response(
+                prompt_library.extract_nodes.extract_attributes(attributes_context),
+                response_model=entity_type,
+                model_size=ModelSize.small,
+            )
+        )
+        if entity_type is not None
+        else {}
+    )
+
+    summary_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_summary(summary_context),
+        response_model=EntitySummary,
+        model_size=ModelSize.small,
+    )

-    entity_attributes_model(**llm_response)
+    if entity_type is not None:
+        entity_type(**llm_response)

-    node.summary = llm_response.get('summary', '')
+    node.summary = summary_response.get('summary', '')
     node_attributes = {key: value for key, value in llm_response.items()}

-    with suppress(KeyError):
-        del node_attributes['summary']
-
     node.attributes.update(node_attributes)

     return node
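Taken together, the rewrite above replaces one combined structured call with two focused ones: attributes are requested only when a custom entity type exists, while the summary is always refreshed through EntitySummary, so the old `del node_attributes['summary']` cleanup disappears. A self-contained sketch of the flow with a stubbed client (StubLLMClient and its canned responses are invented for illustration):

    import asyncio

    from pydantic import BaseModel, Field


    class EntitySummary(BaseModel):
        summary: str = Field(...)


    class Person(BaseModel):  # illustrative custom entity type
        occupation: str | None = None


    class StubLLMClient:
        # Stand-in for the real client: returns canned dicts per response model.
        async def generate_response(self, messages, response_model):
            if response_model is EntitySummary:
                return {'summary': 'Alice works on graph databases.'}
            return {'occupation': 'software engineer'}


    async def main() -> None:
        client = StubLLMClient()
        entity_type: type[BaseModel] | None = Person

        # Two focused calls replace the old single combined call.
        llm_response = (
            await client.generate_response(['attribute prompt'], response_model=entity_type)
            if entity_type is not None
            else {}
        )
        summary_response = await client.generate_response(
            ['summary prompt'], response_model=EntitySummary
        )

        if entity_type is not None:
            entity_type(**llm_response)  # validate the attribute shape

        print(summary_response['summary'], llm_response)


    asyncio.run(main())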
@@ -21,7 +21,7 @@ from graphiti_core.nodes import EntityNode


 def validate_entity_types(
-    entity_types: dict[str, BaseModel] | None,
+    entity_types: dict[str, type[BaseModel]] | None,
 ) -> bool:
     if entity_types is None:
         return True
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.18.1"
+version = "0.18.2"
 authors = [
     { name = "Paul Paliychuk", email = "paul@getzep.com" },
     { name = "Preston Rasmussen", email = "preston@getzep.com" },
@@ -64,7 +64,8 @@ async def test_graphiti_init(driver):
     await graphiti.build_indices_and_constraints()

     search_filter = SearchFilters(
-        created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
+        node_labels=['Person'],
+        created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]],
     )

     results = await graphiti.search_(
uv.lock (generated)

@@ -746,7 +746,7 @@ wheels = [

 [[package]]
 name = "graphiti-core"
-version = "0.18.1"
+version = "0.18.2"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },