diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 2c53889e..146d20e6 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,6 +63,10 @@ class Person(BaseModel):
occupation: str | None = Field(..., description="The person's work occupation")
+# NOTE(review): presumably this docstring doubles as the edge-type description shown to the LLM
+# (entity types' __doc__ is used that way in node_operations) - confirm before rewording it.
+class IsPresidentOf(BaseModel):
+ """Relationship between a person and the entity they are a president of"""
+
+
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
@@ -84,6 +88,8 @@ async def main():
source_description='Podcast Transcript',
group_id=group_id,
entity_types={'Person': Person},
+ # Names listed in edge_type_map must match the keys registered in edge_types.
+ edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
+ edge_type_map={('Person', 'Entity'): ['IS_PRESIDENT_OF']},
previous_episode_uuids=episode_uuids,
)
diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py
index 318d4c9f..16fee8d9 100644
--- a/graphiti_core/prompts/dedupe_nodes.py
+++ b/graphiti_core/prompts/dedupe_nodes.py
@@ -137,8 +137,12 @@ def nodes(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['extracted_nodes'], indent=2)}
+
+
+ {json.dumps(context['existing_nodes'], indent=2)}
+
- For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
+ For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
@@ -152,9 +156,9 @@ def nodes(context: dict[str, Any]) -> list[Message]:
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
as an integer.
- - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
+ - If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
duplicate of.
- - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
+ - If an entity is not a duplicate of one of the EXISTING ENTITIES, return -1 as the duplicate_idx
""",
),
]
diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py
index 37db4699..cd73edfa 100644
--- a/graphiti_core/prompts/extract_edges.py
+++ b/graphiti_core/prompts/extract_edges.py
@@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
class Edge(BaseModel):
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
- source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
- target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
+ source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
+ target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
fact: str = Field(..., description='')
valid_at: str | None = Field(
None,
@@ -77,7 +77,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
-{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
+{context['nodes']}
@@ -94,8 +94,9 @@ Only extract facts that:
- involve two DISTINCT ENTITIES from the ENTITIES list,
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
and can be represented as edges in a knowledge graph.
-- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
- could be classified into one of the provided fact types
+- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
+- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
+ of the FACT TYPES
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 332706b7..9c90a8e9 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -92,8 +92,6 @@ async def extract_edges(
extract_edges_max_tokens = 16384
llm_client = clients.llm_client
- node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
-
edge_types_context = (
[
{
@@ -109,7 +107,7 @@ async def extract_edges(
# Prepare context for LLM
context = {
'episode_content': episode.content,
- 'nodes': [node.name for node in nodes],
+ 'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
'previous_episodes': [ep.content for ep in previous_episodes],
'reference_time': episode.valid_at,
'edge_types': edge_types_context,
@@ -160,14 +158,16 @@ async def extract_edges(
invalid_at = edge_data.get('invalid_at', None)
valid_at_datetime = None
invalid_at_datetime = None
- source_node_uuid = node_uuids_by_name_map.get(edge_data.get('source_entity_name', ''), '')
- target_node_uuid = node_uuids_by_name_map.get(edge_data.get('target_entity_name', ''), '')
- if source_node_uuid == '' or target_node_uuid == '':
+ # The LLM refers to entities by their position in `nodes`; skip edges whose ids are missing or out of range.
+ source_node_idx = edge_data.get('source_entity_id', -1)
+ target_node_idx = edge_data.get('target_entity_id', -1)
+ if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
logger.warning(
- f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_uuid} and target_node_uuid: {target_node_uuid} '
+ f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_idx: {source_node_idx} and target_node_idx: {target_node_idx} '
)
continue
+ # Reuse the indices validated above instead of reading edge_data a second time.
+ source_node_uuid = nodes[source_node_idx].uuid
+ target_node_uuid = nodes[target_node_idx].uuid
if valid_at:
try:
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 2b3de99e..ac572765 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
+from graphiti_core.prompts.dedupe_nodes import NodeResolutions
from graphiti_core.prompts.extract_nodes import (
ExtractedEntities,
ExtractedEntity,
@@ -241,7 +241,25 @@ async def resolve_extracted_nodes(
]
)
- existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
+ existing_nodes_dict: dict[str, EntityNode] = {
+ node.uuid: node for result in search_results for node in result.nodes
+ }
+
+ existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
+
+ # Plain list of candidate dicts; `idx` is the position in `existing_nodes` that the LLM
+ # returns as duplicate_idx. (A parenthesized list with a trailing comma would be a 1-tuple
+ # and serialize as a nested list in the prompt.)
+ existing_nodes_context = [
+ {
+ 'idx': i,
+ 'name': candidate.name,
+ 'entity_types': candidate.labels,
+ **candidate.attributes,
+ }
+ for i, candidate in enumerate(existing_nodes)
+ ]
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
@@ -255,23 +273,13 @@ async def resolve_extracted_nodes(
next((item for item in node.labels if item != 'Entity'), '')
).__doc__
or 'Default Entity Type',
- 'duplication_candidates': [
- {
- **{
- 'idx': j,
- 'name': candidate.name,
- 'entity_types': candidate.labels,
- },
- **candidate.attributes,
- }
- for j, candidate in enumerate(existing_nodes_lists[i])
- ],
}
for i, node in enumerate(extracted_nodes)
]
context = {
'extracted_nodes': extracted_nodes_context,
+ 'existing_nodes': existing_nodes_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
@@ -294,8 +302,8 @@ async def resolve_extracted_nodes(
extracted_node = extracted_nodes[resolution_id]
resolved_node = (
- existing_nodes_lists[resolution_id][duplicate_idx]
- if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
+ existing_nodes[duplicate_idx]
+ if 0 <= duplicate_idx < len(existing_nodes)
else extracted_node
)
@@ -309,70 +317,6 @@ async def resolve_extracted_nodes(
return resolved_nodes, uuid_map
-async def resolve_extracted_node(
- llm_client: LLMClient,
- extracted_node: EntityNode,
- existing_nodes: list[EntityNode],
- episode: EpisodicNode | None = None,
- previous_episodes: list[EpisodicNode] | None = None,
- entity_type: BaseModel | None = None,
-) -> EntityNode:
- start = time()
- if len(existing_nodes) == 0:
- return extracted_node
-
- # Prepare context for LLM
- existing_nodes_context = [
- {
- **{
- 'id': i,
- 'name': node.name,
- 'entity_types': node.labels,
- },
- **node.attributes,
- }
- for i, node in enumerate(existing_nodes)
- ]
-
- extracted_node_context = {
- 'name': extracted_node.name,
- 'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
- }
-
- context = {
- 'existing_nodes': existing_nodes_context,
- 'extracted_node': extracted_node_context,
- 'entity_type_description': entity_type.__doc__
- if entity_type is not None
- else 'Default Entity Type',
- 'episode_content': episode.content if episode is not None else '',
- 'previous_episodes': [ep.content for ep in previous_episodes]
- if previous_episodes is not None
- else [],
- }
-
- llm_response = await llm_client.generate_response(
- prompt_library.dedupe_nodes.node(context),
- response_model=NodeDuplicate,
- model_size=ModelSize.small,
- )
-
- duplicate_id: int = llm_response.get('duplicate_node_id', -1)
-
- node = (
- existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
- )
-
- node.name = llm_response.get('name', '')
-
- end = time()
- logger.debug(
- f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
- )
-
- return node
-
-
async def extract_attributes_from_nodes(
clients: GraphitiClients,
nodes: list[EntityNode],
diff --git a/pyproject.toml b/pyproject.toml
index e94b0cd0..64e46bef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
-version = "0.12.0pre4"
+version = "0.12.0"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },