diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index b357f3b1..03b1e402 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -359,10 +359,10 @@ class Graphiti:
group_id: str | None = None,
uuid: str | None = None,
update_communities: bool = False,
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
previous_episode_uuids: list[str] | None = None,
- edge_types: dict[str, BaseModel] | None = None,
+ edge_types: dict[str, type[BaseModel]] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
) -> AddEpisodeResults:
"""
@@ -555,9 +555,9 @@ class Graphiti:
self,
bulk_episodes: list[RawEpisode],
group_id: str | None = None,
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
- edge_types: dict[str, BaseModel] | None = None,
+ edge_types: dict[str, type[BaseModel]] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
):
"""
diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index 2bf68b81..9feb3073 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -148,7 +148,7 @@ def validate_group_id(group_id: str) -> bool:
def validate_excluded_entity_types(
- excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+ excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
) -> bool:
"""
Validate that excluded entity types are valid type names.
diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py
index 59d07c88..57a2a577 100644
--- a/graphiti_core/prompts/extract_nodes.py
+++ b/graphiti_core/prompts/extract_nodes.py
@@ -52,6 +52,13 @@ class EntityClassification(BaseModel):
)
+class EntitySummary(BaseModel):
+ summary: str = Field(
+ ...,
+ description='Summary containing the important information about the entity. Under 250 words',
+ )
+
+
class Prompt(Protocol):
extract_message: PromptVersion
extract_json: PromptVersion
@@ -59,6 +66,7 @@ class Prompt(Protocol):
reflexion: PromptVersion
classify_nodes: PromptVersion
extract_attributes: PromptVersion
+ extract_summary: PromptVersion
class Versions(TypedDict):
@@ -68,6 +76,7 @@ class Versions(TypedDict):
reflexion: PromptFunction
classify_nodes: PromptFunction
extract_attributes: PromptFunction
+ extract_summary: PromptFunction
def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -259,9 +268,39 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. Do not hallucinate entity property values if they cannot be found in the current context.
2. Only use the provided MESSAGES and ENTITY to set attribute values.
+
+        <ENTITY>
+        {context['node']}
+        </ENTITY>
+        """,
+ ),
+ ]
+
+
+def extract_summary(context: dict[str, Any]) -> list[Message]:
+ return [
+ Message(
+ role='system',
+ content='You are a helpful assistant that extracts entity summaries from the provided text.',
+ ),
+ Message(
+ role='user',
+ content=f"""
+
+        <MESSAGES>
+        {json.dumps(context['previous_episodes'], indent=2)}
+        {json.dumps(context['episode_content'], indent=2)}
+        </MESSAGES>
+
+ Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
+ from the messages and relevant information from the existing summary.
+
+ Guidelines:
+        1. Do not hallucinate entity summary information if it cannot be found in the current context.
+        2. Only use the provided MESSAGES and ENTITY to create the summary.
3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
Summaries must be no longer than 250 words.
-
+
{context['node']}
@@ -275,6 +314,7 @@ versions: Versions = {
'extract_json': extract_json,
'extract_text': extract_text,
'reflexion': reflexion,
+ 'extract_summary': extract_summary,
'classify_nodes': classify_nodes,
'extract_attributes': extract_attributes,
}
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index af9dce67..7c266c94 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -314,17 +314,15 @@ async def node_fulltext_search(
+ """
YIELD node AS n, score
WHERE n:Entity AND n.group_id IN $group_ids
- WITH n, score
- LIMIT $limit
"""
+ filter_query
+ """
+ WITH n, score
+ ORDER BY score DESC
+ LIMIT $limit
RETURN
"""
+ ENTITY_NODE_RETURN
- + """
- ORDER BY score DESC
- """
)
records, _, _ = await driver.execute_query(
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 9263edf3..b80c4f3b 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -169,9 +169,9 @@ async def extract_nodes_and_edges_bulk(
clients: GraphitiClients,
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
edge_type_map: dict[tuple[str, str], list[str]],
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
- edge_types: dict[str, BaseModel] | None = None,
+ edge_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
*[
@@ -202,7 +202,7 @@ async def dedupe_nodes_bulk(
clients: GraphitiClients,
extracted_nodes: list[list[EntityNode]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
embedder = clients.embedder
min_score = 0.8
@@ -290,7 +290,7 @@ async def dedupe_edges_bulk(
extracted_edges: list[list[EntityEdge]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
_entities: list[EntityNode],
- edge_types: dict[str, BaseModel],
+ edge_types: dict[str, type[BaseModel]],
_edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
embedder = clients.embedder
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 1455653a..ef78db43 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -114,7 +114,7 @@ async def extract_edges(
previous_episodes: list[EpisodicNode],
edge_type_map: dict[tuple[str, str], list[str]],
group_id: str = '',
- edge_types: dict[str, BaseModel] | None = None,
+ edge_types: dict[str, type[BaseModel]] | None = None,
) -> list[EntityEdge]:
start = time()
@@ -249,7 +249,7 @@ async def resolve_extracted_edges(
extracted_edges: list[EntityEdge],
episode: EpisodicNode,
entities: list[EntityNode],
- edge_types: dict[str, BaseModel],
+ edge_types: dict[str, type[BaseModel]],
edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
driver = clients.driver
@@ -272,7 +272,7 @@ async def resolve_extracted_edges(
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
# Determine which edge types are relevant for each edge
- edge_types_lst: list[dict[str, BaseModel]] = []
+ edge_types_lst: list[dict[str, type[BaseModel]]] = []
for extracted_edge in extracted_edges:
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -381,7 +381,7 @@ async def resolve_extracted_edge(
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
episode: EpisodicNode,
- edge_types: dict[str, BaseModel] | None = None,
+ edge_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, [], []
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 9c3256bf..51a7be4b 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -15,13 +15,10 @@ limitations under the License.
"""
import logging
-from contextlib import suppress
from time import time
from typing import Any
-from uuid import uuid4
-import pydantic
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
@@ -31,6 +28,7 @@ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_en
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
+ EntitySummary,
ExtractedEntities,
ExtractedEntity,
MissedEntities,
@@ -70,7 +68,7 @@ async def extract_nodes(
clients: GraphitiClients,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
) -> list[EntityNode]:
start = time()
@@ -180,7 +178,7 @@ async def resolve_extracted_nodes(
extracted_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
llm_client = clients.llm_client
@@ -223,7 +221,7 @@ async def resolve_extracted_nodes(
],
)
- entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+ entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
# Prepare context for LLM
extracted_nodes_context = [
@@ -297,7 +295,7 @@ async def extract_attributes_from_nodes(
nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
- entity_types: dict[str, BaseModel] | None = None,
+ entity_types: dict[str, type[BaseModel]] | None = None,
) -> list[EntityNode]:
llm_client = clients.llm_client
embedder = clients.embedder
@@ -326,7 +324,7 @@ async def extract_attributes_from_node(
node: EntityNode,
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
- entity_type: BaseModel | None = None,
+ entity_type: type[BaseModel] | None = None,
) -> EntityNode:
node_context: dict[str, Any] = {
'name': node.name,
@@ -335,25 +333,14 @@ async def extract_attributes_from_node(
'attributes': node.attributes,
}
- attributes_definitions: dict[str, Any] = {
- 'summary': (
- str,
- Field(
- description='Summary containing the important information about the entity. Under 250 words',
- ),
- )
+ attributes_context: dict[str, Any] = {
+ 'node': node_context,
+ 'episode_content': episode.content if episode is not None else '',
+ 'previous_episodes': [ep.content for ep in previous_episodes]
+ if previous_episodes is not None
+ else [],
}
- if entity_type is not None:
- for field_name, field_info in entity_type.model_fields.items():
- attributes_definitions[field_name] = (
- field_info.annotation,
- Field(description=field_info.description),
- )
-
- unique_model_name = f'EntityAttributes_{uuid4().hex}'
- entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
-
summary_context: dict[str, Any] = {
'node': node_context,
'episode_content': episode.content if episode is not None else '',
@@ -362,20 +349,30 @@ async def extract_attributes_from_node(
else [],
}
- llm_response = await llm_client.generate_response(
- prompt_library.extract_nodes.extract_attributes(summary_context),
- response_model=entity_attributes_model,
+ llm_response = (
+ (
+ await llm_client.generate_response(
+ prompt_library.extract_nodes.extract_attributes(attributes_context),
+ response_model=entity_type,
+ model_size=ModelSize.small,
+ )
+ )
+ if entity_type is not None
+ else {}
+ )
+
+ summary_response = await llm_client.generate_response(
+ prompt_library.extract_nodes.extract_summary(summary_context),
+ response_model=EntitySummary,
model_size=ModelSize.small,
)
- entity_attributes_model(**llm_response)
+ if entity_type is not None:
+ entity_type(**llm_response)
- node.summary = llm_response.get('summary', '')
+ node.summary = summary_response.get('summary', '')
node_attributes = {key: value for key, value in llm_response.items()}
- with suppress(KeyError):
- del node_attributes['summary']
-
node.attributes.update(node_attributes)
return node
diff --git a/graphiti_core/utils/ontology_utils/entity_types_utils.py b/graphiti_core/utils/ontology_utils/entity_types_utils.py
index f6cb08fb..bbc07af7 100644
--- a/graphiti_core/utils/ontology_utils/entity_types_utils.py
+++ b/graphiti_core/utils/ontology_utils/entity_types_utils.py
@@ -21,7 +21,7 @@ from graphiti_core.nodes import EntityNode
def validate_entity_types(
- entity_types: dict[str, BaseModel] | None,
+ entity_types: dict[str, type[BaseModel]] | None,
) -> bool:
if entity_types is None:
return True
diff --git a/pyproject.toml b/pyproject.toml
index 6d33a866..41639b64 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
-version = "0.18.1"
+version = "0.18.2"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 9d98f4de..237e2bfc 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -64,7 +64,8 @@ async def test_graphiti_init(driver):
await graphiti.build_indices_and_constraints()
search_filter = SearchFilters(
- created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
+ node_labels=['Person'],
+ created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]],
)
results = await graphiti.search_(
diff --git a/uv.lock b/uv.lock
index 6fcf17f6..cc771368 100644
--- a/uv.lock
+++ b/uv.lock
@@ -746,7 +746,7 @@ wheels = [
[[package]]
name = "graphiti-core"
-version = "0.18.1"
+version = "0.18.2"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },