chore: Update the context for date extraction + bug fixes (#31)
* chore: Update the context for date extraction + bug fixes * chore: Remove logs
This commit is contained in:
parent
c2aaf94be4
commit
427a67b8f8
4 changed files with 16 additions and 21 deletions
|
|
@ -55,7 +55,6 @@ from core.utils.maintenance.graph_data_operations import (
|
|||
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
extract_edge_dates,
|
||||
extract_node_edge_node_triplet,
|
||||
invalidate_edges,
|
||||
prepare_edges_for_invalidation,
|
||||
)
|
||||
|
|
@ -183,22 +182,27 @@ class Graphiti:
|
|||
)
|
||||
|
||||
for edge in invalidated_edges:
|
||||
for existing_edge in existing_edges:
|
||||
if existing_edge.uuid == edge.uuid:
|
||||
existing_edge.expired_at = edge.expired_at
|
||||
for deduped_edge in deduped_edges:
|
||||
if deduped_edge.uuid == edge.uuid:
|
||||
deduped_edge.expired_at = edge.expired_at
|
||||
edge_touched_node_uuids.append(edge.source_node_uuid)
|
||||
edge_touched_node_uuids.append(edge.target_node_uuid)
|
||||
|
||||
edges_to_save = invalidated_edges
|
||||
edges_to_save = existing_edges + deduped_edges
|
||||
|
||||
# There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
|
||||
for deduped_edge in deduped_edges:
|
||||
if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
|
||||
edges_to_save.append(deduped_edge)
|
||||
for deduped_edge in deduped_edges:
|
||||
triplet = extract_node_edge_node_triplet(deduped_edge, nodes)
|
||||
for edge_to_extract_dates_from in edges_to_save:
|
||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||
self.llm_client, triplet, episode.valid_at, episode, previous_episodes
|
||||
self.llm_client,
|
||||
edge_to_extract_dates_from,
|
||||
episode.valid_at,
|
||||
episode,
|
||||
previous_episodes,
|
||||
)
|
||||
deduped_edge.valid_at = valid_at
|
||||
deduped_edge.invalid_at = invalid_at
|
||||
edge_to_extract_dates_from.valid_at = valid_at
|
||||
edge_to_extract_dates_from.invalid_at = invalid_at
|
||||
entity_edges.extend(edges_to_save)
|
||||
|
||||
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
|
||||
|
|
|
|||
|
|
@ -21,9 +21,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
|||
role='user',
|
||||
content=f"""
|
||||
Edge:
|
||||
Source Node: {context['source_node']}
|
||||
Edge Name: {context['edge_name']}
|
||||
Target Node: {context['target_node']}
|
||||
Fact: {context['edge_fact']}
|
||||
|
||||
Current Episode: {context['current_episode']}
|
||||
|
|
|
|||
|
|
@ -81,9 +81,7 @@ async def invalidate_edges(
|
|||
current_episode,
|
||||
previous_episodes,
|
||||
)
|
||||
logger.info(prompt_library.invalidate_edges.v1(context))
|
||||
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
|
||||
logger.info(f'invalidate_edges LLM response: {llm_response}')
|
||||
|
||||
edges_to_invalidate = llm_response.get('invalidated_edges', [])
|
||||
invalidated_edges = process_edge_invalidation_llm_response(
|
||||
|
|
@ -139,17 +137,13 @@ def process_edge_invalidation_llm_response(
|
|||
|
||||
async def extract_edge_dates(
|
||||
llm_client: LLMClient,
|
||||
edge_triplet: NodeEdgeNodeTriplet,
|
||||
edge: EntityEdge,
|
||||
reference_time: datetime,
|
||||
current_episode: EpisodicNode,
|
||||
previous_episodes: List[EpisodicNode],
|
||||
) -> tuple[datetime | None, datetime | None, str]:
|
||||
source_node, edge, target_node = edge_triplet
|
||||
|
||||
context = {
|
||||
'source_node': source_node.name,
|
||||
'edge_name': edge.name,
|
||||
'target_node': target_node.name,
|
||||
'edge_fact': edge.fact,
|
||||
'current_episode': current_episode.content,
|
||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ async def main():
|
|||
await clear_data(client.driver)
|
||||
await client.build_indices_and_constraints()
|
||||
|
||||
# await client.build_indices()
|
||||
for i, message in enumerate(bmw_sales):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue