chore: Update the context for date extraction + bug fixes (#31)
* chore: Update the context for date extraction + bug fixes * chore: Remove logs
This commit is contained in:
parent
c2aaf94be4
commit
427a67b8f8
4 changed files with 16 additions and 21 deletions
|
|
@ -55,7 +55,6 @@ from core.utils.maintenance.graph_data_operations import (
|
||||||
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
||||||
from core.utils.maintenance.temporal_operations import (
|
from core.utils.maintenance.temporal_operations import (
|
||||||
extract_edge_dates,
|
extract_edge_dates,
|
||||||
extract_node_edge_node_triplet,
|
|
||||||
invalidate_edges,
|
invalidate_edges,
|
||||||
prepare_edges_for_invalidation,
|
prepare_edges_for_invalidation,
|
||||||
)
|
)
|
||||||
|
|
@ -183,22 +182,27 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
|
|
||||||
for edge in invalidated_edges:
|
for edge in invalidated_edges:
|
||||||
|
for existing_edge in existing_edges:
|
||||||
|
if existing_edge.uuid == edge.uuid:
|
||||||
|
existing_edge.expired_at = edge.expired_at
|
||||||
|
for deduped_edge in deduped_edges:
|
||||||
|
if deduped_edge.uuid == edge.uuid:
|
||||||
|
deduped_edge.expired_at = edge.expired_at
|
||||||
edge_touched_node_uuids.append(edge.source_node_uuid)
|
edge_touched_node_uuids.append(edge.source_node_uuid)
|
||||||
edge_touched_node_uuids.append(edge.target_node_uuid)
|
edge_touched_node_uuids.append(edge.target_node_uuid)
|
||||||
|
|
||||||
edges_to_save = invalidated_edges
|
edges_to_save = existing_edges + deduped_edges
|
||||||
|
|
||||||
# There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
|
for edge_to_extract_dates_from in edges_to_save:
|
||||||
for deduped_edge in deduped_edges:
|
|
||||||
if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
|
|
||||||
edges_to_save.append(deduped_edge)
|
|
||||||
for deduped_edge in deduped_edges:
|
|
||||||
triplet = extract_node_edge_node_triplet(deduped_edge, nodes)
|
|
||||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||||
self.llm_client, triplet, episode.valid_at, episode, previous_episodes
|
self.llm_client,
|
||||||
|
edge_to_extract_dates_from,
|
||||||
|
episode.valid_at,
|
||||||
|
episode,
|
||||||
|
previous_episodes,
|
||||||
)
|
)
|
||||||
deduped_edge.valid_at = valid_at
|
edge_to_extract_dates_from.valid_at = valid_at
|
||||||
deduped_edge.invalid_at = invalid_at
|
edge_to_extract_dates_from.invalid_at = invalid_at
|
||||||
entity_edges.extend(edges_to_save)
|
entity_edges.extend(edges_to_save)
|
||||||
|
|
||||||
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
|
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Edge:
|
Edge:
|
||||||
Source Node: {context['source_node']}
|
|
||||||
Edge Name: {context['edge_name']}
|
Edge Name: {context['edge_name']}
|
||||||
Target Node: {context['target_node']}
|
|
||||||
Fact: {context['edge_fact']}
|
Fact: {context['edge_fact']}
|
||||||
|
|
||||||
Current Episode: {context['current_episode']}
|
Current Episode: {context['current_episode']}
|
||||||
|
|
|
||||||
|
|
@ -81,9 +81,7 @@ async def invalidate_edges(
|
||||||
current_episode,
|
current_episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
)
|
)
|
||||||
logger.info(prompt_library.invalidate_edges.v1(context))
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
|
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
|
||||||
logger.info(f'invalidate_edges LLM response: {llm_response}')
|
|
||||||
|
|
||||||
edges_to_invalidate = llm_response.get('invalidated_edges', [])
|
edges_to_invalidate = llm_response.get('invalidated_edges', [])
|
||||||
invalidated_edges = process_edge_invalidation_llm_response(
|
invalidated_edges = process_edge_invalidation_llm_response(
|
||||||
|
|
@ -139,17 +137,13 @@ def process_edge_invalidation_llm_response(
|
||||||
|
|
||||||
async def extract_edge_dates(
|
async def extract_edge_dates(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
edge_triplet: NodeEdgeNodeTriplet,
|
edge: EntityEdge,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
current_episode: EpisodicNode,
|
current_episode: EpisodicNode,
|
||||||
previous_episodes: List[EpisodicNode],
|
previous_episodes: List[EpisodicNode],
|
||||||
) -> tuple[datetime | None, datetime | None, str]:
|
) -> tuple[datetime | None, datetime | None, str]:
|
||||||
source_node, edge, target_node = edge_triplet
|
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'source_node': source_node.name,
|
|
||||||
'edge_name': edge.name,
|
'edge_name': edge.name,
|
||||||
'target_node': target_node.name,
|
|
||||||
'edge_fact': edge.fact,
|
'edge_fact': edge.fact,
|
||||||
'current_episode': current_episode.content,
|
'current_episode': current_episode.content,
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,6 @@ async def main():
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
await client.build_indices_and_constraints()
|
await client.build_indices_and_constraints()
|
||||||
|
|
||||||
# await client.build_indices()
|
|
||||||
for i, message in enumerate(bmw_sales):
|
for i, message in enumerate(bmw_sales):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue