From db7595fe63fda51d3e42169340ed30cd00b45fed Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 19 May 2025 13:30:56 -0400 Subject: [PATCH] Edge types (#501) * update entity edge attributes * Adding prompts * extract fact attributes * edge types * edge types no regressions * mypy * mypy update * Update graphiti_core/prompts/dedupe_edges.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * Update graphiti_core/prompts/dedupe_edges.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- graphiti_core/edges.py | 58 ++++++++--- graphiti_core/graphiti.py | 30 +++++- graphiti_core/models/edges/edge_db_queries.py | 6 +- graphiti_core/prompts/dedupe_edges.py | 8 ++ graphiti_core/prompts/extract_edges.py | 48 ++++++++- graphiti_core/search/search_utils.py | 23 +++-- graphiti_core/utils/bulk_utils.py | 20 +++- .../utils/maintenance/edge_operations.py | 97 ++++++++++++++++++- 8 files changed, 252 insertions(+), 38 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 491afa2b..700775f3 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """ e.episodes AS episodes, e.expired_at AS expired_at, e.valid_at AS valid_at, - e.invalid_at AS invalid_at""" + e.invalid_at AS invalid_at, + properties(e) AS attributes + """ class Edge(BaseModel, ABC): @@ -209,6 +211,9 @@ class EntityEdge(Edge): invalid_at: datetime | None = Field( default=None, description='datetime of when the fact stopped being true' ) + attributes: dict[str, Any] = Field( + default={}, description='Additional attributes of the edge. 
Dependent on edge name' + ) async def generate_embedding(self, embedder: EmbedderClient): start = time() @@ -236,20 +241,26 @@ class EntityEdge(Edge): self.fact_embedding = records[0]['fact_embedding'] async def save(self, driver: AsyncDriver): + edge_data: dict[str, Any] = { + 'source_uuid': self.source_node_uuid, + 'target_uuid': self.target_node_uuid, + 'uuid': self.uuid, + 'name': self.name, + 'group_id': self.group_id, + 'fact': self.fact, + 'fact_embedding': self.fact_embedding, + 'episodes': self.episodes, + 'created_at': self.created_at, + 'expired_at': self.expired_at, + 'valid_at': self.valid_at, + 'invalid_at': self.invalid_at, + } + + edge_data.update(self.attributes or {}) + result = await driver.execute_query( ENTITY_EDGE_SAVE, - source_uuid=self.source_node_uuid, - target_uuid=self.target_node_uuid, - uuid=self.uuid, - name=self.name, - group_id=self.group_id, - fact=self.fact, - fact_embedding=self.fact_embedding, - episodes=self.episodes, - created_at=self.created_at, - expired_at=self.expired_at, - valid_at=self.valid_at, - invalid_at=self.invalid_at, + edge_data=edge_data, database_=DEFAULT_DATABASE, ) @@ -334,8 +345,8 @@ class EntityEdge(Edge): async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str): query: LiteralString = ( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + ENTITY_EDGE_RETURN ) records, _, _ = await driver.execute_query( @@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: def get_entity_edge_from_record(record: Any) -> EntityEdge: - return EntityEdge( + edge = EntityEdge( uuid=record['uuid'], source_node_uuid=record['source_node_uuid'], target_node_uuid=record['target_node_uuid'], @@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge: expired_at=parse_db_date(record['expired_at']), valid_at=parse_db_date(record['valid_at']), invalid_at=parse_db_date(record['invalid_at']), + attributes=record['attributes'], ) + edge.attributes.pop('uuid', None) + edge.attributes.pop('source_node_uuid', None) + edge.attributes.pop('target_node_uuid', None) + edge.attributes.pop('fact', None) + edge.attributes.pop('name', None) + edge.attributes.pop('group_id', None) + edge.attributes.pop('episodes', None) + edge.attributes.pop('created_at', None) + edge.attributes.pop('expired_at', None) + edge.attributes.pop('valid_at', None) + edge.attributes.pop('invalid_at', None) + + return edge + def get_community_edge_from_record(record: Any): return CommunityEdge( diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 5417728f..930f66d4 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -273,6 +273,8 @@ class Graphiti: update_communities: bool = False, entity_types: dict[str, BaseModel] | None = None, previous_episode_uuids: list[str] | None = None, + edge_types: dict[str, BaseModel] | None = None, + edge_type_map: dict[tuple[str, str], list[str]] | None = None, ) -> AddEpisodeResults: """ Process an episode and update the graph. 
@@ -355,6 +357,13 @@ class Graphiti: ) ) + # Create default edge type map + edge_type_map_default = ( + {('Entity', 'Entity'): list(edge_types.keys())} + if edge_types is not None + else {('Entity', 'Entity'): []} + ) + # Extract entities as nodes extracted_nodes = await extract_nodes( @@ -370,7 +379,9 @@ class Graphiti: previous_episodes, entity_types, ), - extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id), + extract_edges( + self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types + ), ) edges = resolve_edge_pointers(extracted_edges, uuid_map) @@ -380,6 +391,9 @@ class Graphiti: self.clients, edges, episode, + nodes, + edge_types or {}, + edge_type_map or edge_type_map_default, ), extract_attributes_from_nodes( self.clients, nodes, episode, previous_episodes, entity_types @@ -686,7 +700,19 @@ class Graphiti: )[0] resolved_edge, invalidated_edges = await resolve_extracted_edge( - self.llm_client, updated_edge, related_edges, existing_edges + self.llm_client, + updated_edge, + related_edges, + existing_edges, + EpisodicNode( + name='', + source=EpisodeType.text, + source_description='', + content='', + valid_at=edge.valid_at or utc_now(), + entity_edges=[], + group_id=edge.group_id, + ), ) await add_nodes_and_edges_bulk( diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index ba687219..a1cf547f 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """ MATCH (source:Entity {uuid: $source_uuid}) MATCH (target:Entity {uuid: $target_uuid}) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) - SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes, - created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at} + SET r = $edge_data WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding) RETURN r.uuid AS uuid""" @@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """ MATCH (source:Entity {uuid: edge.source_node_uuid}) MATCH (target:Entity {uuid: edge.target_node_uuid}) MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) - SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, - created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at} + SET r = edge WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding) RETURN edge.uuid AS uuid """ diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py index f63011d4..6ccbbf26 100644 --- a/graphiti_core/prompts/dedupe_edges.py +++ b/graphiti_core/prompts/dedupe_edges.py @@ -31,6 +31,7 @@ class EdgeDuplicate(BaseModel): ..., description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.', ) + fact_type: str = Field(..., description='One of the provided fact types or DEFAULT') class UniqueFact(BaseModel): @@ -133,11 +134,18 @@ def resolve_edge(context: dict[str, Any]) -> list[Message]: {context['edge_invalidation_candidates']} + + {context['edge_types']} + + Task: If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact. If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1. 
+ Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types. + Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES. + Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts. Return a list containing all idx's of the facts that are contradicted by the NEW FACT. If there are no contradicted facts, return an empty list. diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py index e7f41cdb..37db4699 100644 --- a/graphiti_core/prompts/extract_edges.py +++ b/graphiti_core/prompts/extract_edges.py @@ -48,11 +48,13 @@ class MissingFacts(BaseModel): class Prompt(Protocol): edge: PromptVersion reflexion: PromptVersion + extract_attributes: PromptVersion class Versions(TypedDict): edge: PromptFunction reflexion: PromptFunction + extract_attributes: PromptFunction def edge(context: dict[str, Any]) -> list[Message]: @@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]: {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions + +{context['edge_types']} + + # TASK Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE. Only extract facts that: - involve two DISTINCT ENTITIES from the ENTITIES list, - are clearly stated or unambiguously implied in the CURRENT MESSAGE, -- and can be represented as edges in a knowledge graph. + and can be represented as edges in a knowledge graph. +- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that + could be classified into one of the provided fact types You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity. @@ -145,4 +153,40 @@ determine if any facts haven't been extracted. ] -versions: Versions = {'edge': edge, 'reflexion': reflexion} +def extract_attributes(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that extracts fact properties from the provided text.', + ), + Message( + role='user', + content=f""" + + + {json.dumps(context['episode_content'], indent=2)} + + + {context['reference_time']} + + + Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided + in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined. + + Guidelines: + 1. Do not hallucinate entity property values if they cannot be found in the current context. + 2. Only use the provided MESSAGES and FACT to set attribute values. 
+ + + {context['fact']} + + """, + ), + ] + + +versions: Versions = { + 'edge': edge, + 'reflexion': reflexion, + 'extract_attributes': extract_attributes, +} diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index ca24c903..86f26ef3 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -174,7 +174,8 @@ async def edge_fulltext_search( r.episodes AS episodes, r.expired_at AS expired_at, r.valid_at AS valid_at, - r.invalid_at AS invalid_at + r.invalid_at AS invalid_at, + properties(r) AS attributes ORDER BY score DESC LIMIT $limit """ ) @@ -243,7 +244,8 @@ async def edge_similarity_search( r.episodes AS episodes, r.expired_at AS expired_at, r.valid_at AS valid_at, - r.invalid_at AS invalid_at + r.invalid_at AS invalid_at, + properties(r) AS attributes ORDER BY score DESC LIMIT $limit """ @@ -301,7 +303,8 @@ async def edge_bfs_search( r.episodes AS episodes, r.expired_at AS expired_at, r.valid_at AS valid_at, - r.invalid_at AS invalid_at + r.invalid_at AS invalid_at, + properties(r) AS attributes LIMIT $limit """ ) @@ -337,10 +340,10 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -771,7 +774,8 @@ async def get_relevant_edges( episodes: e.episodes, expired_at: e.expired_at, valid_at: e.valid_at, - invalid_at: e.invalid_at + invalid_at: e.invalid_at, + attributes: properties(e) })[..$limit] AS matches """ ) @@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates( episodes: e.episodes, expired_at: e.expired_at, valid_at: e.valid_at, - invalid_at: e.invalid_at + invalid_at: e.invalid_at, + attributes: properties(e) })[..$limit] AS matches """ ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index ed3fd00e..a4fe0651 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx( entity_data['labels'] = list(set(node.labels + ['Entity'])) nodes.append(entity_data) + edges: list[dict[str, Any]] = [] for edge in entity_edges: if edge.fact_embedding is None: await edge.generate_embedding(embedder) + edge_data: dict[str, Any] = { + 'uuid': edge.uuid, + 'source_node_uuid': edge.source_node_uuid, + 'target_node_uuid': edge.target_node_uuid, + 'name': edge.name, + 'fact': edge.fact, + 'fact_embedding': edge.fact_embedding, + 'group_id': edge.group_id, + 'episodes': edge.episodes, + 'created_at': edge.created_at, + 'expired_at': edge.expired_at, + 'valid_at': edge.valid_at, + 'invalid_at': edge.invalid_at, + } + + edge_data.update(edge.attributes or {}) + edges.append(edge_data) await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes) await tx.run( EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges] ) - await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges]) + await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges) async def extract_nodes_and_edges_bulk( diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index d90fba52..06ffdd85 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ 
b/graphiti_core/utils/maintenance/edge_operations.py @@ -18,6 +18,8 @@ import logging from datetime import datetime from time import time +from pydantic import BaseModel + from graphiti_core.edges import ( CommunityEdge, EntityEdge, @@ -83,6 +85,7 @@ async def extract_edges( nodes: list[EntityNode], previous_episodes: list[EpisodicNode], group_id: str = '', + edge_types: dict[str, BaseModel] | None = None, ) -> list[EntityEdge]: start = time() @@ -91,12 +94,25 @@ async def extract_edges( node_uuids_by_name_map = {node.name: node.uuid for node in nodes} + edge_types_context = ( + [ + { + 'fact_type_name': type_name, + 'fact_type_description': type_model.__doc__, + } + for type_name, type_model in edge_types.items() + ] + if edge_types is not None + else [] + ) + # Prepare context for LLM context = { 'episode_content': episode.content, 'nodes': [node.name for node in nodes], 'previous_episodes': [ep.content for ep in previous_episodes], 'reference_time': episode.valid_at, + 'edge_types': edge_types_context, 'custom_prompt': '', } @@ -233,6 +249,9 @@ async def resolve_extracted_edges( clients: GraphitiClients, extracted_edges: list[EntityEdge], episode: EpisodicNode, + entities: list[EntityNode], + edge_types: dict[str, BaseModel], + edge_type_map: dict[tuple[str, str], list[str]], ) -> tuple[list[EntityEdge], list[EntityEdge]]: driver = clients.driver llm_client = clients.llm_client @@ -251,15 +270,50 @@ async def resolve_extracted_edges( f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}' ) + # Build entity hash table + uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities} + + # Determine which edge types are relevant for each edge + edge_types_lst: list[dict[str, BaseModel]] = [] + for extracted_edge in extracted_edges: + source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels + target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + label_tuples = [ + (source_label, target_label) + for source_label in source_node_labels + for target_label in target_node_labels + ] + + extracted_edge_types = {} + for label_tuple in label_tuples: + type_names = edge_type_map.get(label_tuple, []) + for type_name in type_names: + type_model = edge_types.get(type_name) + if type_model is None: + continue + + extracted_edge_types[type_name] = type_model + + edge_types_lst.append(extracted_edge_types) + # resolve edges with related edges in the graph and find invalidation candidates results: list[tuple[EntityEdge, list[EntityEdge]]] = list( await semaphore_gather( *[ resolve_extracted_edge( - llm_client, extracted_edge, related_edges, existing_edges, episode + llm_client, + extracted_edge, + related_edges, + existing_edges, + episode, + extracted_edge_types, ) - for extracted_edge, related_edges, existing_edges in zip( - extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True + for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip( + extracted_edges, + related_edges_lists, + edge_invalidation_candidates, + edge_types_lst, + strict=True, ) ] ) @@ -322,7 +376,8 @@ async def resolve_extracted_edge( extracted_edge: EntityEdge, related_edges: list[EntityEdge], existing_edges: list[EntityEdge], - episode: EpisodicNode | None = None, + episode: EpisodicNode, + edge_types: dict[str, BaseModel] | None = None, ) -> tuple[EntityEdge, list[EntityEdge]]: if len(related_edges) == 0 and len(existing_edges) == 0: return extracted_edge, [] @@ -338,10 +393,24 
@@ async def resolve_extracted_edge( {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges) ] + edge_types_context = ( + [ + { + 'fact_type_id': i, + 'fact_type_name': type_name, + 'fact_type_description': type_model.__doc__, + } + for i, (type_name, type_model) in enumerate(edge_types.items()) + ] + if edge_types is not None + else [] + ) + context = { 'existing_edges': related_edges_context, 'new_edge': extracted_edge.fact, 'edge_invalidation_candidates': invalidation_edge_candidates_context, + 'edge_types': edge_types_context, } llm_response = await llm_client.generate_response( @@ -365,6 +434,26 @@ async def resolve_extracted_edge( invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts] + fact_type: str = str(llm_response.get('fact_type')) + if fact_type.upper() != 'DEFAULT' and edge_types is not None: + resolved_edge.name = fact_type + + edge_attributes_context = { + 'message': episode.content, + 'reference_time': episode.valid_at, + 'fact': resolved_edge.fact, + } + + edge_model = edge_types.get(fact_type) + + edge_attributes_response = await llm_client.generate_response( + prompt_library.extract_edges.extract_attributes(edge_attributes_context), + response_model=edge_model, # type: ignore + model_size=ModelSize.small, + ) + + resolved_edge.attributes = edge_attributes_response + end = time() logger.debug( f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
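
Usage sketch (not part of the diff): the snippet below illustrates how the edge_types and edge_type_map parameters introduced by this patch might be passed to add_episode. The Employment model, its fields, and the episode text are hypothetical, and the pre-existing keyword arguments (name, episode_body, source, source_description, reference_time) are assumed from the current public Graphiti API rather than taken from this patch.

from datetime import datetime, timezone

from pydantic import BaseModel, Field

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType


class Employment(BaseModel):
    """Fact describing a person working for an organization."""

    position: str | None = Field(None, description='Job title held by the person')
    start_date: datetime | None = Field(None, description='When the employment began')


async def ingest_with_edge_types(graphiti: Graphiti) -> None:
    # Keys of edge_types become the fact type names offered to the LLM; each model's
    # docstring is sent as the fact type description, and its fields are filled in by
    # the new extract_attributes prompt, ending up in EntityEdge.attributes.
    # edge_type_map restricts which fact types apply to which (source label, target label)
    # pair; when omitted, the patch falls back to {('Entity', 'Entity'): all type names}.
    await graphiti.add_episode(
        name='episode-1',
        episode_body='Alice joined Acme Corp as a staff engineer in March 2025.',
        source=EpisodeType.text,
        source_description='example text',
        reference_time=datetime.now(timezone.utc),
        edge_types={'Employment': Employment},
        edge_type_map={('Entity', 'Entity'): ['Employment']},
    )

Note that when the dedupe/resolve prompt returns a fact_type other than DEFAULT, resolve_extracted_edge overwrites the extracted edge's name with that type, so custom fact type names surface directly as the name property on RELATES_TO edges.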