diff --git a/core/graphiti.py b/core/graphiti.py index b25e165b..1632cfa2 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -55,7 +55,6 @@ from core.utils.maintenance.graph_data_operations import ( from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes from core.utils.maintenance.temporal_operations import ( extract_edge_dates, - extract_node_edge_node_triplet, invalidate_edges, prepare_edges_for_invalidation, ) @@ -183,22 +182,27 @@ class Graphiti: ) for edge in invalidated_edges: + for existing_edge in existing_edges: + if existing_edge.uuid == edge.uuid: + existing_edge.expired_at = edge.expired_at + for deduped_edge in deduped_edges: + if deduped_edge.uuid == edge.uuid: + deduped_edge.expired_at = edge.expired_at edge_touched_node_uuids.append(edge.source_node_uuid) edge_touched_node_uuids.append(edge.target_node_uuid) - edges_to_save = invalidated_edges + edges_to_save = existing_edges + deduped_edges - # There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one - for deduped_edge in deduped_edges: - if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]: - edges_to_save.append(deduped_edge) - for deduped_edge in deduped_edges: - triplet = extract_node_edge_node_triplet(deduped_edge, nodes) + for edge_to_extract_dates_from in edges_to_save: valid_at, invalid_at, _ = await extract_edge_dates( - self.llm_client, triplet, episode.valid_at, episode, previous_episodes + self.llm_client, + edge_to_extract_dates_from, + episode.valid_at, + episode, + previous_episodes, ) - deduped_edge.valid_at = valid_at - deduped_edge.invalid_at = invalid_at + edge_to_extract_dates_from.valid_at = valid_at + edge_to_extract_dates_from.invalid_at = invalid_at entity_edges.extend(edges_to_save) edge_touched_node_uuids = list(set(edge_touched_node_uuids)) diff --git a/core/prompts/extract_edge_dates.py b/core/prompts/extract_edge_dates.py index cae639af..6d52becd 100644 --- a/core/prompts/extract_edge_dates.py +++ b/core/prompts/extract_edge_dates.py @@ -21,9 +21,7 @@ def v1(context: dict[str, Any]) -> list[Message]: role='user', content=f""" Edge: - Source Node: {context['source_node']} Edge Name: {context['edge_name']} - Target Node: {context['target_node']} Fact: {context['edge_fact']} Current Episode: {context['current_episode']} diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py index 832ad9c0..d8858eba 100644 --- a/core/utils/maintenance/temporal_operations.py +++ b/core/utils/maintenance/temporal_operations.py @@ -81,9 +81,7 @@ async def invalidate_edges( current_episode, previous_episodes, ) - logger.info(prompt_library.invalidate_edges.v1(context)) llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context)) - logger.info(f'invalidate_edges LLM response: {llm_response}') edges_to_invalidate = llm_response.get('invalidated_edges', []) invalidated_edges = process_edge_invalidation_llm_response( @@ -139,17 +137,13 @@ def process_edge_invalidation_llm_response( async def extract_edge_dates( llm_client: LLMClient, - edge_triplet: NodeEdgeNodeTriplet, + edge: EntityEdge, reference_time: datetime, current_episode: EpisodicNode, previous_episodes: List[EpisodicNode], ) -> tuple[datetime | None, datetime | None, str]: - source_node, edge, target_node = edge_triplet - context = { - 'source_node': source_node.name, 'edge_name': edge.name, - 'target_node': target_node.name, 'edge_fact': edge.fact, 'current_episode': current_episode.content, 'previous_episodes': [ep.content for ep in previous_episodes], diff --git a/runner.py b/runner.py index a1338dc6..616dbd51 100644 --- a/runner.py +++ b/runner.py @@ -95,7 +95,6 @@ async def main(): await clear_data(client.driver) await client.build_indices_and_constraints() - # await client.build_indices() for i, message in enumerate(bmw_sales): await client.add_episode( name=f'Message {i}',