Edge types (#501)

* update entity edge attributes

* Adding prompts

* extract fact attributes

* edge types

* edge types no regressions

* mypy

* mypy update

* Update graphiti_core/prompts/dedupe_edges.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update graphiti_core/prompts/dedupe_edges.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-05-19 13:30:56 -04:00 committed by GitHub
parent 619c84e98b
commit db7595fe63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 252 additions and 38 deletions

View file

@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at"""
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
class Edge(BaseModel, ABC):
@ -209,6 +211,9 @@ class EntityEdge(Edge):
invalid_at: datetime | None = Field(
default=None, description='datetime of when the fact stopped being true'
)
attributes: dict[str, Any] = Field(
default={}, description='Additional attributes of the edge. Dependent on edge name'
)
async def generate_embedding(self, embedder: EmbedderClient):
start = time()
@ -236,20 +241,26 @@ class EntityEdge(Edge):
self.fact_embedding = records[0]['fact_embedding']
async def save(self, driver: AsyncDriver):
edge_data: dict[str, Any] = {
'source_uuid': self.source_node_uuid,
'target_uuid': self.target_node_uuid,
'uuid': self.uuid,
'name': self.name,
'group_id': self.group_id,
'fact': self.fact,
'fact_embedding': self.fact_embedding,
'episodes': self.episodes,
'created_at': self.created_at,
'expired_at': self.expired_at,
'valid_at': self.valid_at,
'invalid_at': self.invalid_at,
}
edge_data.update(self.attributes or {})
result = await driver.execute_query(
ENTITY_EDGE_SAVE,
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
fact=self.fact,
fact_embedding=self.fact_embedding,
episodes=self.episodes,
created_at=self.created_at,
expired_at=self.expired_at,
valid_at=self.valid_at,
invalid_at=self.invalid_at,
edge_data=edge_data,
database_=DEFAULT_DATABASE,
)
@ -334,8 +345,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
def get_entity_edge_from_record(record: Any) -> EntityEdge:
return EntityEdge(
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
attributes=record['attributes'],
)
edge.attributes.pop('uuid', None)
edge.attributes.pop('source_node_uuid', None)
edge.attributes.pop('target_node_uuid', None)
edge.attributes.pop('fact', None)
edge.attributes.pop('name', None)
edge.attributes.pop('group_id', None)
edge.attributes.pop('episodes', None)
edge.attributes.pop('created_at', None)
edge.attributes.pop('expired_at', None)
edge.attributes.pop('valid_at', None)
edge.attributes.pop('invalid_at', None)
return edge
def get_community_edge_from_record(record: Any):
return CommunityEdge(

View file

@ -273,6 +273,8 @@ class Graphiti:
update_communities: bool = False,
entity_types: dict[str, BaseModel] | None = None,
previous_episode_uuids: list[str] | None = None,
edge_types: dict[str, BaseModel] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
) -> AddEpisodeResults:
"""
Process an episode and update the graph.
@ -355,6 +357,13 @@ class Graphiti:
)
)
# Create default edge type map
edge_type_map_default = (
{('Entity', 'Entity'): list(edge_types.keys())}
if edge_types is not None
else {('Entity', 'Entity'): []}
)
# Extract entities as nodes
extracted_nodes = await extract_nodes(
@ -370,7 +379,9 @@ class Graphiti:
previous_episodes,
entity_types,
),
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
extract_edges(
self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
),
)
edges = resolve_edge_pointers(extracted_edges, uuid_map)
@ -380,6 +391,9 @@ class Graphiti:
self.clients,
edges,
episode,
nodes,
edge_types or {},
edge_type_map or edge_type_map_default,
),
extract_attributes_from_nodes(
self.clients, nodes, episode, previous_episodes, entity_types
@ -686,7 +700,19 @@ class Graphiti:
)[0]
resolved_edge, invalidated_edges = await resolve_extracted_edge(
self.llm_client, updated_edge, related_edges, existing_edges
self.llm_client,
updated_edge,
related_edges,
existing_edges,
EpisodicNode(
name='',
source=EpisodeType.text,
source_description='',
content='',
valid_at=edge.valid_at or utc_now(),
entity_edges=[],
group_id=edge.group_id,
),
)
await add_nodes_and_edges_bulk(

View file

@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
SET r = $edge_data
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""
@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
SET r = edge
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN edge.uuid AS uuid
"""

View file

@ -31,6 +31,7 @@ class EdgeDuplicate(BaseModel):
...,
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
class UniqueFact(BaseModel):
@ -133,11 +134,18 @@ def resolve_edge(context: dict[str, Any]) -> list[Message]:
{context['edge_invalidation_candidates']}
</FACT INVALIDATION CANDIDATES>
<FACT TYPES>
{context['edge_types']}
</FACT TYPES>
Task:
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
If there are no contradicted facts, return an empty list.

View file

@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
class Prompt(Protocol):
edge: PromptVersion
reflexion: PromptVersion
extract_attributes: PromptVersion
class Versions(TypedDict):
edge: PromptFunction
reflexion: PromptFunction
extract_attributes: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]:
@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
</REFERENCE_TIME>
<FACT TYPES>
{context['edge_types']}
</FACT TYPES>
# TASK
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
Only extract facts that:
- involve two DISTINCT ENTITIES from the ENTITIES list,
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
- and can be represented as edges in a knowledge graph.
and can be represented as edges in a knowledge graph.
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
could be classified into one of the provided fact types
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
]
versions: Versions = {'edge': edge, 'reflexion': reflexion}
def extract_attributes(context: dict[str, Any]) -> list[Message]:
    """Build the LLM prompt that fills in custom attributes for an extracted fact.

    Expects ``context`` to provide:
      - ``episode_content``: the source message, JSON-serialized into the prompt
      - ``reference_time``: timestamp used to resolve relative time mentions
      - ``fact``: the fact whose typed attributes should be populated

    Returns the system/user ``Message`` pair consumed by the prompt library.
    """
    return [
        # Fixed system role: constrains the model to attribute extraction only.
        Message(
            role='system',
            content='You are a helpful assistant that extracts fact properties from the provided text.',
        ),
        # User message embeds the episode, reference time, and target fact.
        Message(
            role='user',
            content=f"""
<MESSAGE>
{json.dumps(context['episode_content'], indent=2)}
</MESSAGE>
<REFERENCE TIME>
{context['reference_time']}
</REFERENCE TIME>
Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
Guidelines:
1. Do not hallucinate entity property values if they cannot be found in the current context.
2. Only use the provided MESSAGES and FACT to set attribute values.
<FACT>
{context['fact']}
</FACT>
""",
        ),
    ]
# Registry of this module's prompt builders, keyed by prompt name; consumed by
# the prompt library (e.g. prompt_library.extract_edges.extract_attributes).
versions: Versions = {
    'edge': edge,
    'reflexion': reflexion,
    'extract_attributes': extract_attributes,
}

View file

@ -174,7 +174,8 @@ async def edge_fulltext_search(
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC LIMIT $limit
"""
)
@ -243,7 +244,8 @@ async def edge_similarity_search(
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC
LIMIT $limit
"""
@ -301,7 +303,8 @@ async def edge_bfs_search(
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
r.invalid_at AS invalid_at,
properties(r) AS attributes
LIMIT $limit
"""
)
@ -337,10 +340,10 @@ async def node_fulltext_search(
query = (
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@ -771,7 +774,8 @@ async def get_relevant_edges(
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at
invalid_at: e.invalid_at,
attributes: properties(e)
})[..$limit] AS matches
"""
)
@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates(
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at
invalid_at: e.invalid_at,
attributes: properties(e)
})[..$limit] AS matches
"""
)

View file

@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
entity_data['labels'] = list(set(node.labels + ['Entity']))
nodes.append(entity_data)
edges: list[dict[str, Any]] = []
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_embedding(embedder)
edge_data: dict[str, Any] = {
'uuid': edge.uuid,
'source_node_uuid': edge.source_node_uuid,
'target_node_uuid': edge.target_node_uuid,
'name': edge.name,
'fact': edge.fact,
'fact_embedding': edge.fact_embedding,
'group_id': edge.group_id,
'episodes': edge.episodes,
'created_at': edge.created_at,
'expired_at': edge.expired_at,
'valid_at': edge.valid_at,
'invalid_at': edge.invalid_at,
}
edge_data.update(edge.attributes or {})
edges.append(edge_data)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
await tx.run(
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
)
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
async def extract_nodes_and_edges_bulk(

View file

@ -18,6 +18,8 @@ import logging
from datetime import datetime
from time import time
from pydantic import BaseModel
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
@ -83,6 +85,7 @@ async def extract_edges(
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
group_id: str = '',
edge_types: dict[str, BaseModel] | None = None,
) -> list[EntityEdge]:
start = time()
@ -91,12 +94,25 @@ async def extract_edges(
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
edge_types_context = (
[
{
'fact_type_name': type_name,
'fact_type_description': type_model.__doc__,
}
for type_name, type_model in edge_types.items()
]
if edge_types is not None
else []
)
# Prepare context for LLM
context = {
'episode_content': episode.content,
'nodes': [node.name for node in nodes],
'previous_episodes': [ep.content for ep in previous_episodes],
'reference_time': episode.valid_at,
'edge_types': edge_types_context,
'custom_prompt': '',
}
@ -233,6 +249,9 @@ async def resolve_extracted_edges(
clients: GraphitiClients,
extracted_edges: list[EntityEdge],
episode: EpisodicNode,
entities: list[EntityNode],
edge_types: dict[str, BaseModel],
edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
driver = clients.driver
llm_client = clients.llm_client
@ -251,15 +270,50 @@ async def resolve_extracted_edges(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
)
# Build entity hash table
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
# Determine which edge types are relevant for each edge
edge_types_lst: list[dict[str, BaseModel]] = []
for extracted_edge in extracted_edges:
source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
label_tuples = [
(source_label, target_label)
for source_label in source_node_labels
for target_label in target_node_labels
]
extracted_edge_types = {}
for label_tuple in label_tuples:
type_names = edge_type_map.get(label_tuple, [])
for type_name in type_names:
type_model = edge_types.get(type_name)
if type_model is None:
continue
extracted_edge_types[type_name] = type_model
edge_types_lst.append(extracted_edge_types)
# resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
await semaphore_gather(
*[
resolve_extracted_edge(
llm_client, extracted_edge, related_edges, existing_edges, episode
llm_client,
extracted_edge,
related_edges,
existing_edges,
episode,
extracted_edge_types,
)
for extracted_edge, related_edges, existing_edges in zip(
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
extracted_edges,
related_edges_lists,
edge_invalidation_candidates,
edge_types_lst,
strict=True,
)
]
)
@ -322,7 +376,8 @@ async def resolve_extracted_edge(
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
episode: EpisodicNode | None = None,
episode: EpisodicNode,
edge_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, []
@ -338,10 +393,24 @@ async def resolve_extracted_edge(
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
]
edge_types_context = (
[
{
'fact_type_id': i,
'fact_type_name': type_name,
'fact_type_description': type_model.__doc__,
}
for i, (type_name, type_model) in enumerate(edge_types.items())
]
if edge_types is not None
else []
)
context = {
'existing_edges': related_edges_context,
'new_edge': extracted_edge.fact,
'edge_invalidation_candidates': invalidation_edge_candidates_context,
'edge_types': edge_types_context,
}
llm_response = await llm_client.generate_response(
@ -365,6 +434,26 @@ async def resolve_extracted_edge(
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
fact_type: str = str(llm_response.get('fact_type'))
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
resolved_edge.name = fact_type
edge_attributes_context = {
'message': episode.content,
'reference_time': episode.valid_at,
'fact': resolved_edge.fact,
}
edge_model = edge_types.get(fact_type)
edge_attributes_response = await llm_client.generate_response(
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
response_model=edge_model, # type: ignore
model_size=ModelSize.small,
)
resolved_edge.attributes = edge_attributes_response
end = time()
logger.debug(
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'