From db7595fe63fda51d3e42169340ed30cd00b45fed Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Mon, 19 May 2025 13:30:56 -0400
Subject: [PATCH] Edge types (#501)
* update entity edge attributes
* Adding prompts
* extract fact attributes
* edge types
* edge types no regressions
* mypy
* mypy update
* Update graphiti_core/prompts/dedupe_edges.py
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* Update graphiti_core/prompts/dedupe_edges.py
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* mypy
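
Illustrative usage sketch (not part of this patch): the IsEmployeeOf model, its
fields, and the connection details below are hypothetical, and parameter names
other than edge_types / edge_type_map follow the existing add_episode signature.
Custom fact types are supplied as Pydantic models; each model's docstring is
shown to the LLM as the fact type description, and the model's fields become the
attributes extracted onto matching edges.

    import asyncio
    from datetime import datetime, timezone

    from pydantic import BaseModel, Field

    from graphiti_core import Graphiti
    from graphiti_core.nodes import EpisodeType


    class IsEmployeeOf(BaseModel):
        """Fact describing an employment relationship between a person and a company."""

        role: str | None = Field(None, description='Job title held by the employee')
        start_date: str | None = Field(None, description='When the employment began, if stated')


    async def main():
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
        await graphiti.add_episode(
            name='episode-1',
            episode_body='Alice joined Acme Corp as a staff engineer in 2021.',
            source=EpisodeType.text,
            source_description='example text',
            reference_time=datetime.now(timezone.utc),
            edge_types={'IS_EMPLOYEE_OF': IsEmployeeOf},
            # edge_type_map restricts which fact types apply to which
            # (source label, target label) pairs; when omitted, all provided
            # types apply to ('Entity', 'Entity') edges.
            edge_type_map={('Entity', 'Entity'): ['IS_EMPLOYEE_OF']},
        )


    asyncio.run(main())

Edges classified as IS_EMPLOYEE_OF are renamed to that fact type, and the
extracted role / start_date values are stored on EntityEdge.attributes and
persisted via SET r = $edge_data.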
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
graphiti_core/edges.py | 58 ++++++++---
graphiti_core/graphiti.py | 30 +++++-
graphiti_core/models/edges/edge_db_queries.py | 6 +-
graphiti_core/prompts/dedupe_edges.py | 8 ++
graphiti_core/prompts/extract_edges.py | 48 ++++++++-
graphiti_core/search/search_utils.py | 23 +++--
graphiti_core/utils/bulk_utils.py | 20 +++-
.../utils/maintenance/edge_operations.py | 97 ++++++++++++++++++-
8 files changed, 252 insertions(+), 38 deletions(-)
diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py
index 491afa2b..700775f3 100644
--- a/graphiti_core/edges.py
+++ b/graphiti_core/edges.py
@@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
- e.invalid_at AS invalid_at"""
+ e.invalid_at AS invalid_at,
+ properties(e) AS attributes
+ """
class Edge(BaseModel, ABC):
@@ -209,6 +211,9 @@ class EntityEdge(Edge):
invalid_at: datetime | None = Field(
default=None, description='datetime of when the fact stopped being true'
)
+ attributes: dict[str, Any] = Field(
+ default={}, description='Additional attributes of the edge. Dependent on edge name'
+ )
async def generate_embedding(self, embedder: EmbedderClient):
start = time()
@@ -236,20 +241,26 @@ class EntityEdge(Edge):
self.fact_embedding = records[0]['fact_embedding']
async def save(self, driver: AsyncDriver):
+ edge_data: dict[str, Any] = {
+ 'source_uuid': self.source_node_uuid,
+ 'target_uuid': self.target_node_uuid,
+ 'uuid': self.uuid,
+ 'name': self.name,
+ 'group_id': self.group_id,
+ 'fact': self.fact,
+ 'fact_embedding': self.fact_embedding,
+ 'episodes': self.episodes,
+ 'created_at': self.created_at,
+ 'expired_at': self.expired_at,
+ 'valid_at': self.valid_at,
+ 'invalid_at': self.invalid_at,
+ }
+
+ edge_data.update(self.attributes or {})
+
result = await driver.execute_query(
ENTITY_EDGE_SAVE,
- source_uuid=self.source_node_uuid,
- target_uuid=self.target_node_uuid,
- uuid=self.uuid,
- name=self.name,
- group_id=self.group_id,
- fact=self.fact,
- fact_embedding=self.fact_embedding,
- episodes=self.episodes,
- created_at=self.created_at,
- expired_at=self.expired_at,
- valid_at=self.valid_at,
- invalid_at=self.invalid_at,
+ edge_data=edge_data,
database_=DEFAULT_DATABASE,
)
@@ -334,8 +345,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
query: LiteralString = (
"""
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
- """
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+ """
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
@@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
def get_entity_edge_from_record(record: Any) -> EntityEdge:
- return EntityEdge(
+ edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
@@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
+ attributes=record['attributes'],
)
+ edge.attributes.pop('uuid', None)
+ edge.attributes.pop('source_node_uuid', None)
+ edge.attributes.pop('target_node_uuid', None)
+ edge.attributes.pop('fact', None)
+ edge.attributes.pop('name', None)
+ edge.attributes.pop('group_id', None)
+ edge.attributes.pop('episodes', None)
+ edge.attributes.pop('created_at', None)
+ edge.attributes.pop('expired_at', None)
+ edge.attributes.pop('valid_at', None)
+ edge.attributes.pop('invalid_at', None)
+
+ return edge
+
def get_community_edge_from_record(record: Any):
return CommunityEdge(
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 5417728f..930f66d4 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -273,6 +273,8 @@ class Graphiti:
update_communities: bool = False,
entity_types: dict[str, BaseModel] | None = None,
previous_episode_uuids: list[str] | None = None,
+ edge_types: dict[str, BaseModel] | None = None,
+ edge_type_map: dict[tuple[str, str], list[str]] | None = None,
) -> AddEpisodeResults:
"""
Process an episode and update the graph.
@@ -355,6 +357,13 @@ class Graphiti:
)
)
+ # Create default edge type map
+ edge_type_map_default = (
+ {('Entity', 'Entity'): list(edge_types.keys())}
+ if edge_types is not None
+ else {('Entity', 'Entity'): []}
+ )
+
# Extract entities as nodes
extracted_nodes = await extract_nodes(
@@ -370,7 +379,9 @@ class Graphiti:
previous_episodes,
entity_types,
),
- extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
+ extract_edges(
+ self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
+ ),
)
edges = resolve_edge_pointers(extracted_edges, uuid_map)
@@ -380,6 +391,9 @@ class Graphiti:
self.clients,
edges,
episode,
+ nodes,
+ edge_types or {},
+ edge_type_map or edge_type_map_default,
),
extract_attributes_from_nodes(
self.clients, nodes, episode, previous_episodes, entity_types
@@ -686,7 +700,19 @@ class Graphiti:
)[0]
resolved_edge, invalidated_edges = await resolve_extracted_edge(
- self.llm_client, updated_edge, related_edges, existing_edges
+ self.llm_client,
+ updated_edge,
+ related_edges,
+ existing_edges,
+ EpisodicNode(
+ name='',
+ source=EpisodeType.text,
+ source_description='',
+ content='',
+ valid_at=edge.valid_at or utc_now(),
+ entity_edges=[],
+ group_id=edge.group_id,
+ ),
)
await add_nodes_and_edges_bulk(
diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py
index ba687219..a1cf547f 100644
--- a/graphiti_core/models/edges/edge_db_queries.py
+++ b/graphiti_core/models/edges/edge_db_queries.py
@@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
- created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
+ SET r = $edge_data
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""
@@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
- SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
- created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
+ SET r = edge
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN edge.uuid AS uuid
"""
diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py
index f63011d4..6ccbbf26 100644
--- a/graphiti_core/prompts/dedupe_edges.py
+++ b/graphiti_core/prompts/dedupe_edges.py
@@ -31,6 +31,7 @@ class EdgeDuplicate(BaseModel):
...,
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
+ fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
class UniqueFact(BaseModel):
@@ -133,11 +134,18 @@ def resolve_edge(context: dict[str, Any]) -> list[Message]:
{context['edge_invalidation_candidates']}
+ <FACT TYPES>
+ {context['edge_types']}
+ </FACT TYPES>
+
Task:
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
+ Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
Return the fact type as fact_type, or DEFAULT if the NEW FACT does not fit any of the provided FACT TYPES.
+
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
If there are no contradicted facts, return an empty list.
diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py
index e7f41cdb..37db4699 100644
--- a/graphiti_core/prompts/extract_edges.py
+++ b/graphiti_core/prompts/extract_edges.py
@@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
class Prompt(Protocol):
edge: PromptVersion
reflexion: PromptVersion
+ extract_attributes: PromptVersion
class Versions(TypedDict):
edge: PromptFunction
reflexion: PromptFunction
+ extract_attributes: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]:
@@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
+<FACT TYPES>
+{context['edge_types']}
+</FACT TYPES>
+
# TASK
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
Only extract facts that:
- involve two DISTINCT ENTITIES from the ENTITIES list,
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
-- and can be represented as edges in a knowledge graph.
+ and can be represented as edges in a knowledge graph.
+- The FACT TYPES provide a list of the most important types of facts; make sure to extract any facts that
+  could be classified into one of the provided fact types.
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
@@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
]
-versions: Versions = {'edge': edge, 'reflexion': reflexion}
+def extract_attributes(context: dict[str, Any]) -> list[Message]:
+ return [
+ Message(
+ role='system',
+ content='You are a helpful assistant that extracts fact properties from the provided text.',
+ ),
+ Message(
+ role='user',
+ content=f"""
+
+ <MESSAGE>
+ {json.dumps(context['episode_content'], indent=2)}
+ </MESSAGE>
+ <REFERENCE TIME>
+ {context['reference_time']}
+ </REFERENCE TIME>
+
+ Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
+ in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
+
+ Guidelines:
+ 1. Do not hallucinate fact property values if they cannot be found in the current context.
+ 2. Only use the provided MESSAGE and FACT to set attribute values.
+
+ <FACT>
+ {context['fact']}
+ </FACT>
+ """,
+ ),
+ ]
+
+
+versions: Versions = {
+ 'edge': edge,
+ 'reflexion': reflexion,
+ 'extract_attributes': extract_attributes,
+}
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index ca24c903..86f26ef3 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -174,7 +174,8 @@ async def edge_fulltext_search(
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
+ r.invalid_at AS invalid_at,
+ properties(r) AS attributes
ORDER BY score DESC LIMIT $limit
"""
)
@@ -243,7 +244,8 @@ async def edge_similarity_search(
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
+ r.invalid_at AS invalid_at,
+ properties(r) AS attributes
ORDER BY score DESC
LIMIT $limit
"""
@@ -301,7 +303,8 @@ async def edge_bfs_search(
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
+ r.invalid_at AS invalid_at,
+ properties(r) AS attributes
LIMIT $limit
"""
)
@@ -337,10 +340,10 @@ async def node_fulltext_search(
query = (
"""
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
- YIELD node AS n, score
- WHERE n:Entity
- """
+ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+ YIELD node AS n, score
+ WHERE n:Entity
+ """
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@@ -771,7 +774,8 @@ async def get_relevant_edges(
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
- invalid_at: e.invalid_at
+ invalid_at: e.invalid_at,
+ attributes: properties(e)
})[..$limit] AS matches
"""
)
@@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates(
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
- invalid_at: e.invalid_at
+ invalid_at: e.invalid_at,
+ attributes: properties(e)
})[..$limit] AS matches
"""
)
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index ed3fd00e..a4fe0651 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
entity_data['labels'] = list(set(node.labels + ['Entity']))
nodes.append(entity_data)
+ edges: list[dict[str, Any]] = []
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_embedding(embedder)
+ edge_data: dict[str, Any] = {
+ 'uuid': edge.uuid,
+ 'source_node_uuid': edge.source_node_uuid,
+ 'target_node_uuid': edge.target_node_uuid,
+ 'name': edge.name,
+ 'fact': edge.fact,
+ 'fact_embedding': edge.fact_embedding,
+ 'group_id': edge.group_id,
+ 'episodes': edge.episodes,
+ 'created_at': edge.created_at,
+ 'expired_at': edge.expired_at,
+ 'valid_at': edge.valid_at,
+ 'invalid_at': edge.invalid_at,
+ }
+
+ edge_data.update(edge.attributes or {})
+ edges.append(edge_data)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
await tx.run(
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
)
- await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
+ await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
async def extract_nodes_and_edges_bulk(
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index d90fba52..06ffdd85 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -18,6 +18,8 @@ import logging
from datetime import datetime
from time import time
+from pydantic import BaseModel
+
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
@@ -83,6 +85,7 @@ async def extract_edges(
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
group_id: str = '',
+ edge_types: dict[str, BaseModel] | None = None,
) -> list[EntityEdge]:
start = time()
@@ -91,12 +94,25 @@ async def extract_edges(
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
+ edge_types_context = (
+ [
+ {
+ 'fact_type_name': type_name,
+ 'fact_type_description': type_model.__doc__,
+ }
+ for type_name, type_model in edge_types.items()
+ ]
+ if edge_types is not None
+ else []
+ )
+
# Prepare context for LLM
context = {
'episode_content': episode.content,
'nodes': [node.name for node in nodes],
'previous_episodes': [ep.content for ep in previous_episodes],
'reference_time': episode.valid_at,
+ 'edge_types': edge_types_context,
'custom_prompt': '',
}
@@ -233,6 +249,9 @@ async def resolve_extracted_edges(
clients: GraphitiClients,
extracted_edges: list[EntityEdge],
episode: EpisodicNode,
+ entities: list[EntityNode],
+ edge_types: dict[str, BaseModel],
+ edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
driver = clients.driver
llm_client = clients.llm_client
@@ -251,15 +270,50 @@ async def resolve_extracted_edges(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
)
+ # Build entity hash table
+ uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
+
+ # Determine which edge types are relevant for each edge
+ edge_types_lst: list[dict[str, BaseModel]] = []
+ for extracted_edge in extracted_edges:
+ source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
+ target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
+ label_tuples = [
+ (source_label, target_label)
+ for source_label in source_node_labels
+ for target_label in target_node_labels
+ ]
+
+ extracted_edge_types = {}
+ for label_tuple in label_tuples:
+ type_names = edge_type_map.get(label_tuple, [])
+ for type_name in type_names:
+ type_model = edge_types.get(type_name)
+ if type_model is None:
+ continue
+
+ extracted_edge_types[type_name] = type_model
+
+ edge_types_lst.append(extracted_edge_types)
+
# resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
await semaphore_gather(
*[
resolve_extracted_edge(
- llm_client, extracted_edge, related_edges, existing_edges, episode
+ llm_client,
+ extracted_edge,
+ related_edges,
+ existing_edges,
+ episode,
+ extracted_edge_types,
)
- for extracted_edge, related_edges, existing_edges in zip(
- extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
+ for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
+ extracted_edges,
+ related_edges_lists,
+ edge_invalidation_candidates,
+ edge_types_lst,
+ strict=True,
)
]
)
@@ -322,7 +376,8 @@ async def resolve_extracted_edge(
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
- episode: EpisodicNode | None = None,
+ episode: EpisodicNode,
+ edge_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, []
@@ -338,10 +393,24 @@ async def resolve_extracted_edge(
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
]
+ edge_types_context = (
+ [
+ {
+ 'fact_type_id': i,
+ 'fact_type_name': type_name,
+ 'fact_type_description': type_model.__doc__,
+ }
+ for i, (type_name, type_model) in enumerate(edge_types.items())
+ ]
+ if edge_types is not None
+ else []
+ )
+
context = {
'existing_edges': related_edges_context,
'new_edge': extracted_edge.fact,
'edge_invalidation_candidates': invalidation_edge_candidates_context,
+ 'edge_types': edge_types_context,
}
llm_response = await llm_client.generate_response(
@@ -365,6 +434,26 @@ async def resolve_extracted_edge(
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+ fact_type: str = str(llm_response.get('fact_type'))
+ if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+ resolved_edge.name = fact_type
+
+ edge_attributes_context = {
+ 'message': episode.content,
+ 'reference_time': episode.valid_at,
+ 'fact': resolved_edge.fact,
+ }
+
+ edge_model = edge_types.get(fact_type)
+
+ edge_attributes_response = await llm_client.generate_response(
+ prompt_library.extract_edges.extract_attributes(edge_attributes_context),
+ response_model=edge_model, # type: ignore
+ model_size=ModelSize.small,
+ )
+
+ resolved_edge.attributes = edge_attributes_response
+
end = time()
logger.debug(
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'