chore: Update the context for date extraction + bug fixes (#31)

* chore: Update the context for date extraction + bug fixes

* chore: Remove logs
This commit is contained in:
Pavlo Paliychuk 2024-08-23 16:45:59 -04:00 committed by GitHub
parent c2aaf94be4
commit 427a67b8f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 16 additions and 21 deletions

View file

@ -55,7 +55,6 @@ from core.utils.maintenance.graph_data_operations import (
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
extract_edge_dates,
extract_node_edge_node_triplet,
invalidate_edges,
prepare_edges_for_invalidation,
)
@ -183,22 +182,27 @@ class Graphiti:
)
for edge in invalidated_edges:
for existing_edge in existing_edges:
if existing_edge.uuid == edge.uuid:
existing_edge.expired_at = edge.expired_at
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
edges_to_save = invalidated_edges
edges_to_save = existing_edges + deduped_edges
# There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
for deduped_edge in deduped_edges:
if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
edges_to_save.append(deduped_edge)
for deduped_edge in deduped_edges:
triplet = extract_node_edge_node_triplet(deduped_edge, nodes)
for edge_to_extract_dates_from in edges_to_save:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client, triplet, episode.valid_at, episode, previous_episodes
self.llm_client,
edge_to_extract_dates_from,
episode.valid_at,
episode,
previous_episodes,
)
deduped_edge.valid_at = valid_at
deduped_edge.invalid_at = invalid_at
edge_to_extract_dates_from.valid_at = valid_at
edge_to_extract_dates_from.invalid_at = invalid_at
entity_edges.extend(edges_to_save)
edge_touched_node_uuids = list(set(edge_touched_node_uuids))

View file

@ -21,9 +21,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
role='user',
content=f"""
Edge:
Source Node: {context['source_node']}
Edge Name: {context['edge_name']}
Target Node: {context['target_node']}
Fact: {context['edge_fact']}
Current Episode: {context['current_episode']}

View file

@ -81,9 +81,7 @@ async def invalidate_edges(
current_episode,
previous_episodes,
)
logger.info(prompt_library.invalidate_edges.v1(context))
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
logger.info(f'invalidate_edges LLM response: {llm_response}')
edges_to_invalidate = llm_response.get('invalidated_edges', [])
invalidated_edges = process_edge_invalidation_llm_response(
@ -139,17 +137,13 @@ def process_edge_invalidation_llm_response(
async def extract_edge_dates(
llm_client: LLMClient,
edge_triplet: NodeEdgeNodeTriplet,
edge: EntityEdge,
reference_time: datetime,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
source_node, edge, target_node = edge_triplet
context = {
'source_node': source_node.name,
'edge_name': edge.name,
'target_node': target_node.name,
'edge_fact': edge.fact,
'current_episode': current_episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],

View file

@ -95,7 +95,6 @@ async def main():
await clear_data(client.driver)
await client.build_indices_and_constraints()
# await client.build_indices()
for i, message in enumerate(bmw_sales):
await client.add_episode(
name=f'Message {i}',