Merge branch 'main' into ollama_embedder

kavenGw 2025-08-04 15:59:57 +08:00
commit b7f1716c58
12 changed files with 101 additions and 57 deletions


@@ -359,10 +359,10 @@ class Graphiti:
         group_id: str | None = None,
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ) -> AddEpisodeResults:
         """
@@ -555,9 +555,9 @@ class Graphiti:
         self,
         bulk_episodes: list[RawEpisode],
         group_id: str | None = None,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ):
         """


@@ -148,7 +148,7 @@ def validate_group_id(group_id: str) -> bool:
 def validate_excluded_entity_types(
-    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
 ) -> bool:
     """
     Validate that excluded entity types are valid type names.
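The validator body is not shown in this hunk; below is a rough sketch of the check its docstring describes. The default 'Entity' label and the raise-on-invalid behavior are assumptions, not taken from this diff:

from pydantic import BaseModel

def validate_excluded_entity_types_sketch(
    excluded_entity_types: list[str] | None,
    entity_types: dict[str, type[BaseModel]] | None = None,
) -> bool:
    # No exclusions: trivially valid.
    if not excluded_entity_types:
        return True
    # Assumed: valid names are the default 'Entity' label plus registered custom types.
    valid_names = {'Entity'} | set(entity_types or {})
    invalid = [name for name in excluded_entity_types if name not in valid_names]
    if invalid:
        raise ValueError(f'Invalid excluded entity types: {invalid}')
    return True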


@@ -52,6 +52,13 @@ class EntityClassification(BaseModel):
     )
+class EntitySummary(BaseModel):
+    summary: str = Field(
+        ...,
+        description='Summary containing the important information about the entity. Under 250 words',
+    )
 class Prompt(Protocol):
     extract_message: PromptVersion
     extract_json: PromptVersion
@@ -59,6 +66,7 @@ class Prompt(Protocol):
     reflexion: PromptVersion
     classify_nodes: PromptVersion
     extract_attributes: PromptVersion
+    extract_summary: PromptVersion
 class Versions(TypedDict):
@@ -68,6 +76,7 @@ class Versions(TypedDict):
     reflexion: PromptFunction
     classify_nodes: PromptFunction
     extract_attributes: PromptFunction
+    extract_summary: PromptFunction
 def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -259,9 +268,39 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
         Guidelines:
         1. Do not hallucinate entity property values if they cannot be found in the current context.
         2. Only use the provided MESSAGES and ENTITY to set attribute values.
         <ENTITY>
         {context['node']}
         </ENTITY>
         """,
         ),
     ]
+def extract_summary(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that extracts entity summaries from the provided text.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        <MESSAGES>
+        {json.dumps(context['previous_episodes'], indent=2)}
+        {json.dumps(context['episode_content'], indent=2)}
+        </MESSAGES>
+        Given the above MESSAGES and the following ENTITY, update the summary, combining relevant information about the entity
+        from the messages with relevant information from the existing summary.
+        Guidelines:
+        1. Do not hallucinate entity summary information if it cannot be found in the current context.
+        2. Only use the provided MESSAGES and ENTITY to update the summary.
+        3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the ENTITY from the MESSAGES.
+        Summaries must be no longer than 250 words.
+        <ENTITY>
+        {context['node']}
+        </ENTITY>
+        """,
+        ),
+    ]
@@ -275,6 +314,7 @@ versions: Versions = {
     'extract_json': extract_json,
     'extract_text': extract_text,
     'reflexion': reflexion,
+    'extract_summary': extract_summary,
     'classify_nodes': classify_nodes,
     'extract_attributes': extract_attributes,
 }
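With the registration above, the new prompt is reachable as prompt_library.extract_nodes.extract_summary and is designed to pair with the EntitySummary response model. A usage sketch; the context values are illustrative, and only the keys read by the f-string are required:

from graphiti_core.prompts import prompt_library

summary_context = {
    'node': {'name': 'Alice', 'attributes': {}},  # serialized entity; shape illustrative
    'episode_content': 'Alice is a software engineer at Acme Corp.',
    'previous_episodes': [],
}
messages = prompt_library.extract_nodes.extract_summary(summary_context)
# Downstream (see node_operations below) the messages are sent with:
#   response_model=EntitySummary, model_size=ModelSize.small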


@@ -314,17 +314,15 @@ async def node_fulltext_search(
+ """
YIELD node AS n, score
WHERE n:Entity AND n.group_id IN $group_ids
WITH n, score
LIMIT $limit
"""
+ filter_query
+ """
WITH n, score
ORDER BY score DESC
LIMIT $limit
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
"""
)
records, _, _ = await driver.execute_query(
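The reordered clauses still feed the same parameterized call. A sketch of how the truncated execute_query line above plausibly continues; parameter names are taken from the $-placeholders in the query, and the cypher_query and fuzzy_query variable names are assumptions. In the neo4j-style driver API, the Cypher text is the first positional argument (named query_ in the neo4j driver precisely so it cannot collide with a Cypher parameter named query), keyword arguments become query parameters, and the eager result unpacks into records, summary, and keys:

records, _, _ = await driver.execute_query(
    cypher_query,         # the concatenated Cypher string built above (name assumed)
    query=fuzzy_query,    # $query: the fulltext search input (name assumed)
    group_ids=group_ids,  # $group_ids
    limit=limit,          # $limit
)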


@@ -169,9 +169,9 @@ async def extract_nodes_and_edges_bulk(
     clients: GraphitiClients,
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     edge_type_map: dict[tuple[str, str], list[str]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
     extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
@@ -202,7 +202,7 @@ async def dedupe_nodes_bulk(
     clients: GraphitiClients,
     extracted_nodes: list[list[EntityNode]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
     embedder = clients.embedder
     min_score = 0.8
@@ -290,7 +290,7 @@ async def dedupe_edges_bulk(
     extracted_edges: list[list[EntityEdge]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     _entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     _edge_type_map: dict[tuple[str, str], list[str]],
 ) -> dict[str, list[EntityEdge]]:
     embedder = clients.embedder


@@ -114,7 +114,7 @@ async def extract_edges(
     previous_episodes: list[EpisodicNode],
     edge_type_map: dict[tuple[str, str], list[str]],
     group_id: str = '',
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> list[EntityEdge]:
     start = time()
@@ -249,7 +249,7 @@ async def resolve_extracted_edges(
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
     entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
@@ -272,7 +272,7 @@ async def resolve_extracted_edges(
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
     # Determine which edge types are relevant for each edge
-    edge_types_lst: list[dict[str, BaseModel]] = []
+    edge_types_lst: list[dict[str, type[BaseModel]]] = []
     for extracted_edge in extracted_edges:
         source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
         target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -381,7 +381,7 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
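Throughout these signatures, edge_types maps an edge name to the Pydantic class holding that edge's attributes, while edge_type_map whitelists which edge names may connect a given (source label, target label) pair. An illustrative sketch; the names are examples, not from this diff:

from pydantic import BaseModel, Field

class WorksAt(BaseModel):
    role: str | None = Field(None, description='Role held at the company')

edge_types: dict[str, type[BaseModel]] = {'WORKS_AT': WorksAt}

# (source entity label, target entity label) -> allowed edge names
edge_type_map: dict[tuple[str, str], list[str]] = {
    ('Person', 'Organization'): ['WORKS_AT'],
}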


@@ -15,13 +15,10 @@ limitations under the License.
 """
 import logging
 from contextlib import suppress
 from time import time
 from typing import Any
-from uuid import uuid4
-import pydantic
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
@@ -31,6 +28,7 @@ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_en
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
+    EntitySummary,
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
@@ -70,7 +68,7 @@ async def extract_nodes(
     clients: GraphitiClients,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
 ) -> list[EntityNode]:
     start = time()
@@ -180,7 +178,7 @@ async def resolve_extracted_nodes(
     extracted_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     existing_nodes_override: list[EntityNode] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
@@ -223,7 +221,7 @@ async def resolve_extracted_nodes(
         ],
     )
-    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
     # Prepare context for LLM
     extracted_nodes_context = [
@@ -297,7 +295,7 @@ async def extract_attributes_from_nodes(
     nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
@@ -326,7 +324,7 @@ async def extract_attributes_from_node(
     node: EntityNode,
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_type: BaseModel | None = None,
+    entity_type: type[BaseModel] | None = None,
 ) -> EntityNode:
     node_context: dict[str, Any] = {
         'name': node.name,
@@ -335,25 +333,14 @@ async def extract_attributes_from_node(
         'attributes': node.attributes,
     }
-    attributes_definitions: dict[str, Any] = {
-        'summary': (
-            str,
-            Field(
-                description='Summary containing the important information about the entity. Under 250 words',
-            ),
-        )
-    }
+    attributes_context: dict[str, Any] = {
+        'node': node_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }
-    if entity_type is not None:
-        for field_name, field_info in entity_type.model_fields.items():
-            attributes_definitions[field_name] = (
-                field_info.annotation,
-                Field(description=field_info.description),
-            )
-    unique_model_name = f'EntityAttributes_{uuid4().hex}'
-    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
     summary_context: dict[str, Any] = {
         'node': node_context,
         'episode_content': episode.content if episode is not None else '',
@@ -362,20 +349,30 @@ async def extract_attributes_from_node(
         else [],
     }
-    llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.extract_attributes(summary_context),
-        response_model=entity_attributes_model,
-        model_size=ModelSize.small,
-    )
+    llm_response = (
+        (
+            await llm_client.generate_response(
+                prompt_library.extract_nodes.extract_attributes(attributes_context),
+                response_model=entity_type,
+                model_size=ModelSize.small,
+            )
+        )
+        if entity_type is not None
+        else {}
+    )
+    summary_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_summary(summary_context),
+        response_model=EntitySummary,
+        model_size=ModelSize.small,
+    )
-    entity_attributes_model(**llm_response)
+    if entity_type is not None:
+        entity_type(**llm_response)
-    node.summary = llm_response.get('summary', '')
+    node.summary = summary_response.get('summary', '')
     node_attributes = {key: value for key, value in llm_response.items()}
     with suppress(KeyError):
         del node_attributes['summary']
     node.attributes.update(node_attributes)
     return node
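Taken together, the two hunks above retire the old single-call design, in which a throwaway EntityAttributes_<uuid> model was assembled with pydantic.create_model and the summary rode along as one of its fields. The new flow makes two independent calls: attributes are validated directly against the caller's entity type class, and the summary comes from the new extract_summary prompt. A condensed restatement of the new control flow, not a drop-in replacement:

llm_response: dict = {}
if entity_type is not None:
    llm_response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_attributes(attributes_context),
        response_model=entity_type,  # the caller's class; no dynamic model needed
        model_size=ModelSize.small,
    )
    entity_type(**llm_response)  # raises if the LLM output does not fit the schema

summary_response = await llm_client.generate_response(
    prompt_library.extract_nodes.extract_summary(summary_context),
    response_model=EntitySummary,  # summary now has its own prompt and model
    model_size=ModelSize.small,
)
node.summary = summary_response.get('summary', '')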


@@ -21,7 +21,7 @@ from graphiti_core.nodes import EntityNode
 def validate_entity_types(
-    entity_types: dict[str, BaseModel] | None,
+    entity_types: dict[str, type[BaseModel]] | None,
 ) -> bool:
     if entity_types is None:
         return True


@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.18.1"
+version = "0.18.2"
 authors = [
     { name = "Paul Paliychuk", email = "paul@getzep.com" },
     { name = "Preston Rasmussen", email = "preston@getzep.com" },


@@ -255,6 +255,14 @@
       "created_at": "2025-07-29T20:00:27Z",
       "repoId": 840056306,
       "pullRequestNo": 782
     },
+    {
+      "name": "bechbd",
+      "id": 6898505,
+      "comment_id": 3140501814,
+      "created_at": "2025-07-31T15:58:08Z",
+      "repoId": 840056306,
+      "pullRequestNo": 793
+    }
   ]
 }


@@ -64,7 +64,8 @@ async def test_graphiti_init(driver):
     await graphiti.build_indices_and_constraints()
     search_filter = SearchFilters(
-        created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
+        node_labels=['Person'],
+        created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]],
     )
     results = await graphiti.search_(

uv.lock generated

@@ -746,7 +746,7 @@ wheels = [
 [[package]]
 name = "graphiti-core"
-version = "0.18.1"
+version = "0.18.2"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },