From 299021173bf6aa26290d0cc301e01358c3e849bb Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Thu, 5 Sep 2024 12:05:44 -0400
Subject: [PATCH] Add episode refactor (#85)

* temp commit while moving

* fix name embedding bug

* invalidation

* format

* tests on runner examples

* format

* ellipsis

* ruff

* fix

* format

* minor prompt change
---
 examples/ecommerce/runner.py                  |  11 +-
 examples/podcast/podcast_runner.py            |   2 +-
 graphiti_core/graphiti.py                     | 120 ++++++++----------
 graphiti_core/prompts/dedupe_edges.py         |   2 +-
 graphiti_core/prompts/invalidate_edges.py     |  38 +++++-
 graphiti_core/search/search_utils.py          |  52 ++++----
 .../utils/maintenance/edge_operations.py      | 108 ++++++++++++++--
 .../utils/maintenance/temporal_operations.py  |  34 +++++
 8 files changed, 261 insertions(+), 106 deletions(-)

diff --git a/examples/ecommerce/runner.py b/examples/ecommerce/runner.py
index 0653b503..bb4317e0 100644
--- a/examples/ecommerce/runner.py
+++ b/examples/ecommerce/runner.py
@@ -94,7 +94,7 @@ async def main():
 
 async def ingest_products_data(client: Graphiti):
     script_dir = Path(__file__).parent
-    json_file_path = script_dir / 'allbirds_products.json'
+    json_file_path = script_dir / '../data/manybirds_products.json'
 
     with open(json_file_path) as file:
         products = json.load(file)['products']
@@ -110,7 +110,14 @@ async def ingest_products_data(client: Graphiti):
         for i, product in enumerate(products)
     ]
 
-    await client.add_episode_bulk(episodes)
+    for episode in episodes:
+        await client.add_episode(
+            episode.name,
+            episode.content,
+            episode.source_description,
+            episode.reference_time,
+            episode.source,
+        )
 
 
 asyncio.run(main())
diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 1de7eeed..f100926e 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
         await client.add_episode_bulk(episodes)
 
 
-asyncio.run(main(True))
+asyncio.run(main(False))
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 3f4dcfcb..6684fcc0 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -59,11 +59,6 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
     resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import (
-    extract_edge_dates,
-    invalidate_edges,
-    prepare_edges_for_invalidation,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -293,7 +288,7 @@ class Graphiti:
             *[node.generate_name_embedding(embedder) for node in extracted_nodes]
         )
 
-        # Resolve extracted nodes with nodes already in the graph
+        # Resolve extracted nodes with nodes already in the graph and extract facts
         existing_nodes_lists: list[list[EntityNode]] = list(
             await asyncio.gather(
                 *[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
@@ -302,22 +297,27 @@ class Graphiti:
             )
         )
 
         logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
 
-        mentioned_nodes, _ = await resolve_extracted_nodes(
-            self.llm_client, extracted_nodes, existing_nodes_lists
+        (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+            resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
+            extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
         )
 
         logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
         nodes.extend(mentioned_nodes)
 
-        # Extract facts as edges given entity nodes
-        extracted_edges = await extract_edges(
-            self.llm_client, episode, mentioned_nodes, previous_episodes
+        extracted_edges_with_resolved_pointers = resolve_edge_pointers(
+            extracted_edges, uuid_map
         )
 
         # calculate embeddings
-        await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
+        await asyncio.gather(
+            *[
+                edge.generate_embedding(embedder)
+                for edge in extracted_edges_with_resolved_pointers
+            ]
+        )
 
-        # Resolve extracted edges with edges already in the graph
-        existing_edges_list: list[list[EntityEdge]] = list(
+        # Resolve extracted edges with related edges already in the graph
+        related_edges_list: list[list[EntityEdge]] = list(
             await asyncio.gather(
                 *[
                     get_relevant_edges(
@@ -327,74 +327,66 @@ class Graphiti:
                         edge.target_node_uuid,
                         RELEVANT_SCHEMA_LIMIT,
                     )
-                    for edge in extracted_edges
+                    for edge in extracted_edges_with_resolved_pointers
                 ]
             )
         )
 
         logger.info(
-            f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
+            f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
         )
 
-        logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
-
-        deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
-            self.llm_client, extracted_edges, existing_edges_list
+        logger.info(
+            f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
         )
 
-        # Extract dates for the newly extracted edges
-        edge_dates = await asyncio.gather(
-            *[
-                extract_edge_dates(
-                    self.llm_client,
-                    edge,
-                    episode,
-                    previous_episodes,
-                )
-                for edge in deduped_edges
-            ]
+        existing_source_edges_list: list[list[EntityEdge]] = list(
+            await asyncio.gather(
+                *[
+                    get_relevant_edges(
+                        self.driver,
+                        [edge],
+                        edge.source_node_uuid,
+                        None,
+                        RELEVANT_SCHEMA_LIMIT,
+                    )
+                    for edge in extracted_edges_with_resolved_pointers
+                ]
+            )
         )
 
-        for i, edge in enumerate(deduped_edges):
-            valid_at = edge_dates[i][0]
-            invalid_at = edge_dates[i][1]
+        existing_target_edges_list: list[list[EntityEdge]] = list(
+            await asyncio.gather(
+                *[
+                    get_relevant_edges(
+                        self.driver,
+                        [edge],
+                        None,
+                        edge.target_node_uuid,
+                        RELEVANT_SCHEMA_LIMIT,
+                    )
+                    for edge in extracted_edges_with_resolved_pointers
+                ]
+            )
+        )
 
-            edge.valid_at = valid_at
-            edge.invalid_at = invalid_at
-            if edge.invalid_at is not None:
-                edge.expired_at = now
-
-        entity_edges.extend(deduped_edges)
-
-        existing_edges: list[EntityEdge] = [
-            e for edge_lst in existing_edges_list for e in edge_lst
+        existing_edges_list: list[list[EntityEdge]] = [
+            source_lst + target_lst
+            for source_lst, target_lst in zip(
+                existing_source_edges_list, existing_target_edges_list
+            )
         ]
 
-        (
-            old_edges_with_nodes_pending_invalidation,
-            new_edges_with_nodes,
-        ) = prepare_edges_for_invalidation(
-            existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
-        )
-
-        invalidated_edges = await invalidate_edges(
+        resolved_edges, invalidated_edges = await resolve_extracted_edges(
             self.llm_client,
-            old_edges_with_nodes_pending_invalidation,
-            new_edges_with_nodes,
+            extracted_edges_with_resolved_pointers,
+            related_edges_list,
+            existing_edges_list,
             episode,
             previous_episodes,
         )
 
-        for edge in invalidated_edges:
-            for existing_edge in existing_edges:
-                if existing_edge.uuid == edge.uuid:
-                    existing_edge.expired_at = edge.expired_at
-            for deduped_edge in deduped_edges:
-                if deduped_edge.uuid == edge.uuid:
-                    deduped_edge.expired_at = edge.expired_at
-
         logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
 
+        entity_edges.extend(resolved_edges + invalidated_edges)
 
-        entity_edges.extend(existing_edges)
-
-        logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
+        logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
 
         episodic_edges: list[EpisodicEdge] = build_episodic_edges(
             mentioned_nodes,
diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py
index 9902bf75..29895192 100644
--- a/graphiti_core/prompts/dedupe_edges.py
+++ b/graphiti_core/prompts/dedupe_edges.py
@@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
         Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
 
         Existing Edges:
-        {json.dumps(context['existing_edges'], indent=2)}
+        {json.dumps(context['related_edges'], indent=2)}
 
         New Edge:
         {json.dumps(context['extracted_edges'], indent=2)}
diff --git a/graphiti_core/prompts/invalidate_edges.py b/graphiti_core/prompts/invalidate_edges.py
index 693169a8..5724390b 100644
--- a/graphiti_core/prompts/invalidate_edges.py
+++ b/graphiti_core/prompts/invalidate_edges.py
@@ -21,10 +21,12 @@ from .models import Message, PromptFunction, PromptVersion
 
 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion
 
 
 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction
 
 
 def v1(context: dict[str, Any]) -> list[Message]:
@@ -71,4 +73,38 @@ def v1(context: dict[str, Any]) -> list[Message]:
     ]
 
 
-versions: Versions = {'v1': v1}
+def v2(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
+
+        Existing Edges:
+        {context['existing_edges']}
+
+        New Edge:
+        {context['new_edge']}
+
+
+        For each existing edge that should be invalidated, respond with a JSON object in the following format:
+        {{
+            "invalidated_edges": [
+                {{
+                    "uuid": "The UUID of the edge to be invalidated",
+                    "fact": "Updated fact of the edge"
+                }}
+            ]
+        }}
+
+        If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
+        """,
+        ),
+    ]
+
+
+versions: Versions = {'v1': v1, 'v2': v2}
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 22235cbc..75432a72 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -96,11 +96,11 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 
 
 async def edge_similarity_search(
-        driver: AsyncDriver,
-        search_vector: list[float],
-        source_node_uuid: str | None,
-        target_node_uuid: str | None,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    driver: AsyncDriver,
+    search_vector: list[float],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     query = Query("""
@@ -211,7 +211,7 @@ async def edge_similarity_search(
 
 
 async def entity_similarity_search(
-        search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -247,7 +247,7 @@ async def entity_similarity_search(
 
 
 async def entity_fulltext_search(
-        query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -284,11 +284,11 @@ async def entity_fulltext_search(
 
 
 async def edge_fulltext_search(
-        driver: AsyncDriver,
-        query: str,
-        source_node_uuid: str | None,
-        target_node_uuid: str | None,
-        limit=RELEVANT_SCHEMA_LIMIT,
+    driver: AsyncDriver,
+    query: str,
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
     cypher_query = Query("""
@@ -401,10 +401,10 @@ async def edge_fulltext_search(
 
 
 async def hybrid_node_search(
-        queries: list[str],
-        embeddings: list[list[float]],
-        driver: AsyncDriver,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    queries: list[str],
+    embeddings: list[list[float]],
+    driver: AsyncDriver,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.
@@ -466,8 +466,8 @@ async def hybrid_node_search(
 
 
 async def get_relevant_nodes(
-        nodes: list[EntityNode],
-        driver: AsyncDriver,
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
@@ -503,11 +503,11 @@ async def get_relevant_nodes(
 
 
 async def get_relevant_edges(
-        driver: AsyncDriver,
-        edges: list[EntityEdge],
-        source_node_uuid: str | None,
-        target_node_uuid: str | None,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    driver: AsyncDriver,
+    edges: list[EntityEdge],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -557,7 +557,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 
 
 async def node_distance_reranker(
-        driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
@@ -579,8 +579,8 @@ async def node_distance_reranker(
 
     for record in records:
         if (
-                record['source_uuid'] == center_node_uuid
-                or record['target_uuid'] == center_node_uuid
+            record['source_uuid'] == center_node_uuid
+            or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 00cb852f..0d6aa9eb 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -24,6 +24,10 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
+from graphiti_core.utils.maintenance.temporal_operations import (
+    extract_edge_dates,
+    get_edge_contradictions,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -149,28 +153,110 @@ async def dedupe_extracted_edges(
 async def resolve_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
+    related_edges_lists: list[list[EntityEdge]],
     existing_edges_lists: list[list[EntityEdge]],
-) -> list[EntityEdge]:
-    resolved_edges: list[EntityEdge] = list(
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
+    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await asyncio.gather(
             *[
-                resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
-                for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
+                resolve_extracted_edge(
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    current_episode,
+                    previous_episodes,
+                )
+                for extracted_edge, related_edges, existing_edges in zip(
+                    extracted_edges, related_edges_lists, existing_edges_lists
+                )
             ]
         )
    )
 
-    return resolved_edges
+    resolved_edges: list[EntityEdge] = []
+    invalidated_edges: list[EntityEdge] = []
+    for result in results:
+        resolved_edge = result[0]
+        invalidated_edge_chunk = result[1]
+
+        resolved_edges.append(resolved_edge)
+        invalidated_edges.extend(invalidated_edge_chunk)
+
+    return resolved_edges, invalidated_edges
 
 
 async def resolve_extracted_edge(
-    llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    existing_edges: list[EntityEdge],
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[EntityEdge, list[EntityEdge]]:
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
+        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
+        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    )
+
+    now = datetime.now()
+
+    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
+    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
+    if invalid_at is not None and resolved_edge.expired_at is None:
+        resolved_edge.expired_at = now
+
+    # Determine if the new_edge needs to be expired
+    if resolved_edge.expired_at is None:
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        for candidate in invalidation_candidates:
+            if (
+                candidate.valid_at is not None and resolved_edge.valid_at is not None
+            ) and candidate.valid_at > resolved_edge.valid_at:
+                # Expire new edge since we have information about more recent events
+                resolved_edge.invalid_at = candidate.valid_at
+                resolved_edge.expired_at = now
+                break
+
+    # Determine which contradictory edges need to be expired
+    invalidated_edges: list[EntityEdge] = []
+    for edge in invalidation_candidates:
+        # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
+        if (
+            edge.invalid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.invalid_at < resolved_edge.valid_at
+        ) or (
+            edge.valid_at is not None
+            and resolved_edge.invalid_at is not None
+            and resolved_edge.invalid_at < edge.valid_at
+        ):
+            continue
+        # New edge invalidates edge
+        elif (
+            edge.valid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.valid_at < resolved_edge.valid_at
+        ):
+            edge.invalid_at = resolved_edge.valid_at
+            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
+            invalidated_edges.append(edge)
+
+    return resolved_edge, invalidated_edges
+
+
+async def dedupe_extracted_edge(
+    llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
 ) -> EntityEdge:
     start = time()
 
     # Prepare context for LLM
-    existing_edges_context = [
-        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+    related_edges_context = [
+        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
     ]
 
     extracted_edge_context = {
@@ -180,7 +266,7 @@ async def resolve_extracted_edge(
     }
 
     context = {
-        'existing_edges': existing_edges_context,
+        'related_edges': related_edges_context,
         'extracted_edges': extracted_edge_context,
     }
 
@@ -191,14 +277,14 @@ async def resolve_extracted_edge(
 
     edge = extracted_edge
     if is_duplicate:
-        for existing_edge in existing_edges:
+        for existing_edge in related_edges:
             if existing_edge.uuid != uuid:
                 continue
             edge = existing_edge
 
     end = time()
     logger.info(
-        f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
+        f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
     )
 
     return edge
diff --git a/graphiti_core/utils/maintenance/temporal_operations.py b/graphiti_core/utils/maintenance/temporal_operations.py
index ef168209..6eaa22b4 100644
--- a/graphiti_core/utils/maintenance/temporal_operations.py
+++ b/graphiti_core/utils/maintenance/temporal_operations.py
@@ -16,6 +16,7 @@ limitations under the License.
 
 import logging
 from datetime import datetime
+from time import time
 from typing import List
 
 from graphiti_core.edges import EntityEdge
@@ -181,3 +182,36 @@ async def extract_edge_dates(
     logger.info(f'Edge date extraction explanation: {explanation}')
 
     return valid_at_datetime, invalid_at_datetime
+
+
+async def get_edge_contradictions(
+    llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    start = time()
+    existing_edge_map = {edge.uuid: edge for edge in existing_edges}
+
+    new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact}
+    existing_edge_context = [
+        {'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact}
+        for existing_edge in existing_edges
+    ]
+
+    context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
+
+    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
+
+    contradicted_edge_data = llm_response.get('invalidated_edges', [])
+
+    contradicted_edges: list[EntityEdge] = []
+    for edge_data in contradicted_edge_data:
+        if edge_data['uuid'] in existing_edge_map:
+            contradicted_edge = existing_edge_map[edge_data['uuid']]
+            contradicted_edge.fact = edge_data['fact']
+            contradicted_edges.append(contradicted_edge)
+
+    end = time()
+    logger.info(
+        f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
+    )
+
+    return contradicted_edges
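
Usage sketch (supplementary, not part of the patch): the ecommerce runner change above drops add_episode_bulk in favor of calling add_episode once per episode, so each episode flows through the new pipeline — node resolution, edge dedupe against related edges, date extraction, and contradiction-based invalidation. Below is a minimal, self-contained version of that calling pattern. It assumes a locally reachable Neo4j instance; the connection settings and the sample episode are illustrative only, while Graphiti, add_episode, and EpisodeType are the same graphiti_core APIs used in the diff (add_episode is called with the same five positional arguments the runner passes).

import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType


async def main() -> None:
    # Illustrative connection settings; substitute your own Neo4j credentials.
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # A single toy episode; the real runner builds these from manybirds_products.json.
    episodes = [
        {
            'name': 'Product 0',
            'body': '{"title": "Wool Runner", "availability": "in stock"}',
            'source_description': 'sample products JSON',
        }
    ]

    # Per-episode ingestion, mirroring the updated runner: each call triggers
    # extraction, resolution, edge dating, and invalidation for that episode.
    for episode in episodes:
        await client.add_episode(
            episode['name'],
            episode['body'],
            episode['source_description'],
            datetime.now(timezone.utc),
            EpisodeType.json,
        )


asyncio.run(main())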