Invalidation updates and improvements (#20)

* wip

* wip

* wip

* fix: linter errors

* fix: formatting

* chore: fix ruff

* fix: duplication

---------

Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Pavlo Paliychuk, 2024-08-22 18:09:44 -04:00, committed via GitHub
parent 94873f1083, commit 1f1652f56c
9 changed files with 507 additions and 54 deletions

View file

@@ -28,7 +28,10 @@ from core.utils.bulk_utils import (
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
)
-from core.utils.maintenance.edge_operations import dedupe_extracted_edges, extract_edges
+from core.utils.maintenance.edge_operations import (
+dedupe_extracted_edges,
+extract_edges,
+)
from core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
@@ -116,6 +119,7 @@ class Graphiti:
)
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
+logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
# Calculate Embeddings
@@ -124,14 +128,14 @@
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
-logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
-new_nodes, _ = await dedupe_extracted_nodes(
+touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
-logger.info(f'Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}')
-nodes.extend(new_nodes)
+logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
+nodes.extend(touched_nodes)
extracted_edges = await extract_edges(
-self.llm_client, episode, new_nodes, previous_episodes
+self.llm_client, episode, touched_nodes, previous_episodes
)
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
@@ -140,10 +144,23 @@
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
+# deduped_edges = await dedupe_extracted_edges_v2(
+# self.llm_client,
+# extract_node_and_edge_triplets(extracted_edges, nodes),
+# extract_node_and_edge_triplets(existing_edges, nodes),
+# )
deduped_edges = await dedupe_extracted_edges(
-self.llm_client, extracted_edges, existing_edges
+self.llm_client,
+extracted_edges,
+existing_edges,
)
+edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
+for edge in deduped_edges:
+edge_touched_node_uuids.append(edge.source_node_uuid)
+edge_touched_node_uuids.append(edge.target_node_uuid)
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
@@ -155,26 +172,36 @@
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
+episode,
+previous_episodes,
)
-entity_edges.extend(invalidated_edges)
+for edge in invalidated_edges:
+edge_touched_node_uuids.append(edge.source_node_uuid)
+edge_touched_node_uuids.append(edge.target_node_uuid)
+edges_to_save = invalidated_edges
+# There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
+for deduped_edge in deduped_edges:
+if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
+edges_to_save.append(deduped_edge)
+entity_edges.extend(edges_to_save)
+edge_touched_node_uuids = list(set(edge_touched_node_uuids))
+involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
+logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
-entity_edges.extend(deduped_edges)
-new_edges = await dedupe_extracted_edges(
-self.llm_client, extracted_edges, existing_edges
-)
-logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in new_edges]}')
-entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
-nodes,
+involved_nodes,
episode,
now,
)
@@ -182,12 +209,6 @@
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
logger.info(f'Built episodic edges: {episodic_edges}')
-# invalidated_edges = await self.invalidate_edges(
-# episode, new_nodes, new_edges, relevant_schema, previous_episodes
-# )
-# edges.extend(invalidated_edges)
# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
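Taken together, the new flow prefers the invalidated copy of an edge whenever the same edge appears in both deduped_edges and invalidated_edges. A standalone sketch of that save-preference rule (Edge here is a stand-in dataclass, not the core.edges model):

from dataclasses import dataclass

@dataclass
class Edge:
    uuid: str
    fact: str

invalidated = [Edge('e1', 'Paul was married to Jane (expired)')]
deduped = [Edge('e1', 'Paul is married to Jane'), Edge('e2', 'Jane misses Paul')]

# Start from the invalidated edges, then add only those deduped edges
# whose uuid is not already present in invalidated form.
edges_to_save = list(invalidated)
invalidated_uuids = {edge.uuid for edge in invalidated}
edges_to_save += [edge for edge in deduped if edge.uuid not in invalidated_uuids]

assert [e.uuid for e in edges_to_save] == ['e1', 'e2']
assert edges_to_save[0].fact.endswith('(expired)')  # the invalidated version wins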

View file

@@ -6,11 +6,12 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
edge_list: PromptVersion
+v2: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
+v2: PromptFunction
edge_list: PromptFunction
@@ -54,6 +55,48 @@ def v1(context: dict[str, any]) -> list[Message]:
]
def v2(context: dict[str, any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates relationships from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. Start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges
Guidelines:
1. Use both the triplet name and fact of edges to determine if they are duplicates;
duplicate edges may have different names that mean the same thing, with slight variations in the facts.
2. If you encounter facts that are semantically equivalent or very similar, keep the original edge
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"triplet": "source_node_name-edge_name-target_node_name",
"fact": "one sentence description of the fact"
}}
]
}}
""",
),
]
def edge_list(context: dict[str, any]) -> list[Message]:
return [
Message(
@@ -90,4 +133,4 @@ def edge_list(context: dict[str, any]) -> list[Message]:
]
-versions: Versions = {'v1': v1, 'edge_list': edge_list}
+versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
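A hedged sketch of how the new v2 prompt might be exercised; the context keys mirror the template above, while the client call is an assumption and is left commented out:

from core.prompts import prompt_library

context = {
    'existing_edges': [{'triplet': 'Paul-IS_MARRIED_TO-Jane', 'fact': 'Paul is married to Jane'}],
    'extracted_edges': [{'triplet': 'Paul-MARRIED-Jane', 'fact': 'Paul and Jane are married'}],
}
messages = prompt_library.dedupe_edges.v2(context)
# llm_response = await llm_client.generate_response(messages)
# The prompt asks for {"new_edges": [{"triplet": ..., "fact": ...}]}, so callers
# would read the deduplicated list back out with llm_response.get('new_edges', []).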

View file

@@ -15,14 +15,20 @@ def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role='system',
-content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.',
+content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
),
Message(
role='user',
content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
-Do not invalidate relationships merely because they weren't mentioned in new edges.
+Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.
+Previous Episodes:
+{context['previous_episodes']}
+Current Episode:
+{context['current_episode']}
Existing Edges (sorted by timestamp, newest first):
{context['existing_edges']}
@@ -30,19 +36,19 @@ def v1(context: dict[str, any]) -> list[Message]:
New Edges:
{context['new_edges']}
-Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
+Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT) (TIMESTAMP)"
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
"reason": "Brief explanation of why this edge is being invalidated"
"fact": "Updated fact of the edge"
}}
]
}}
-If no relationships need to be invalidated, return an empty list for "invalidated_edges".
+If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
""",
),
]
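For reference, a hedged sketch of the edge string this prompt now expects, mirroring prepare_invalidation_context later in this commit (all values invented):

from datetime import datetime, timezone

uuid = '8f0e6d2a'  # invented placeholder; real edges use full UUIDs
source, name, target = 'Paul', 'IS_MARRIED_TO', 'Jane'
fact = 'Paul is married to Jane'
created_at = datetime(2024, 8, 22, tzinfo=timezone.utc)

edge_line = f'{uuid} | {source} - {name} - {target} (Fact: {fact}) ({created_at.isoformat()})'
# 8f0e6d2a | Paul - IS_MARRIED_TO - Jane (Fact: Paul is married to Jane) (2024-08-22T00:00:00+00:00)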

View file

@@ -98,7 +98,7 @@ async def dedupe_nodes_bulk(
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
-nodes, partial_uuid_map = await dedupe_extracted_nodes(
+nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
)

View file

@@ -8,6 +8,7 @@ from core.edges import EntityEdge, EpisodicEdge
from core.llm_client import LLMClient
from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library
+from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet
logger = logging.getLogger(__name__)
@@ -179,6 +180,51 @@ async def extract_edges(
return edges
def create_edge_identifier(
source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
return f'{source_node.name}-{edge.name}-{target_node.name}'
async def dedupe_extracted_edges_v2(
llm_client: LLMClient,
extracted_edges: list[NodeEdgeNodeTriplet],
existing_edges: list[NodeEdgeNodeTriplet],
) -> list[NodeEdgeNodeTriplet]:
# Create edge map
edge_map = {}
for n1, edge, n2 in existing_edges:
edge_map[create_edge_identifier(n1, edge, n2)] = edge
for n1, edge, n2 in extracted_edges:
if create_edge_identifier(n1, edge, n2) in edge_map:
continue
edge_map[create_edge_identifier(n1, edge, n2)] = edge
# Prepare context for LLM
context = {
'extracted_edges': [
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
for n1, edge, n2 in extracted_edges
],
'existing_edges': [
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
for n1, edge, n2 in existing_edges
],
}
logger.info(prompt_library.dedupe_edges.v2(context))
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
new_edges_data = llm_response.get('new_edges', [])
logger.info(f'Extracted new edges: {new_edges_data}')
# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data['triplet']]
edges.append(edge)
return edges
async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],

View file

@@ -108,6 +108,11 @@ async def dedupe_extracted_nodes(
for node in existing_nodes:
node_map[node.name] = node
+# Temp hack
+new_nodes_map = {}
+for node in extracted_nodes:
+new_nodes_map[node.name] = node
# Prepare context for LLM
existing_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in existing_nodes
@@ -131,20 +136,25 @@
uuid_map = {}
for duplicate in duplicate_data:
-uuid = node_map[duplicate['name']].uuid
+uuid = new_nodes_map[duplicate['name']].uuid
uuid_value = node_map[duplicate['duplicate_of']].uuid
uuid_map[uuid] = uuid_value
nodes = []
+brand_new_nodes = []
for node in extracted_nodes:
if node.uuid in uuid_map:
-existing_name = uuid_map[node.name]
-existing_node = node_map[existing_name]
+existing_uuid = uuid_map[node.uuid]
+# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
+# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
+# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
+existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
nodes.append(existing_node)
continue
+brand_new_nodes.append(node)
nodes.append(node)
-return nodes, uuid_map
+return nodes, uuid_map, brand_new_nodes
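For context, a minimal sketch of how the returned uuid_map is meant to be consumed (resolve is a hypothetical helper; the real caller is resolve_edge_pointers in bulk_utils):

from uuid import uuid4

extracted_uuid, existing_uuid = str(uuid4()), str(uuid4())
uuid_map = {extracted_uuid: existing_uuid}  # extracted node judged a duplicate of an existing one

def resolve(uuid: str, uuid_map: dict[str, str]) -> str:
    # Collapse a reference to a duplicate node onto its canonical uuid.
    return uuid_map.get(uuid, uuid)

assert resolve(extracted_uuid, uuid_map) == existing_uuid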
async def dedupe_node_list(

View file

@@ -4,7 +4,7 @@ from typing import List
from core.edges import EntityEdge
from core.llm_client import LLMClient
-from core.nodes import EntityNode
+from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library
logger = logging.getLogger(__name__)
@@ -12,6 +12,20 @@ logger = logging.getLogger(__name__)
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
def extract_node_and_edge_triplets(
edges: list[EntityEdge], nodes: list[EntityNode]
) -> list[NodeEdgeNodeTriplet]:
return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
def extract_node_edge_node_triplet(
edge: EntityEdge, nodes: list[EntityNode]
) -> NodeEdgeNodeTriplet:
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
return (source_node, edge, target_node)
def prepare_edges_for_invalidation(
existing_edges: list[EntityEdge],
new_edges: list[EntityEdge],
@@ -39,13 +53,22 @@
async def invalidate_edges(
llm_client: LLMClient,
-existing_edges_pending_invalidation: List[NodeEdgeNodeTriplet],
-new_edges: List[NodeEdgeNodeTriplet],
-) -> List[EntityEdge]:
+existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
+new_edges: list[NodeEdgeNodeTriplet],
+current_episode: EpisodicNode,
+previous_episodes: list[EpisodicNode],
+) -> list[EntityEdge]:
invalidated_edges = []  # TODO: this is not yet used?
-context = prepare_invalidation_context(existing_edges_pending_invalidation, new_edges)
+context = prepare_invalidation_context(
+existing_edges_pending_invalidation,
+new_edges,
+current_episode,
+previous_episodes,
+)
logger.info(prompt_library.invalidate_edges.v1(context))
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
+logger.info(f'invalidate_edges LLM response: {llm_response}')
edges_to_invalidate = llm_response.get('invalidated_edges', [])
invalidated_edges = process_edge_invalidation_llm_response(
@@ -56,21 +79,26 @@ async def invalidate_edges(
def prepare_invalidation_context(
-existing_edges: List[NodeEdgeNodeTriplet], new_edges: List[NodeEdgeNodeTriplet]
+existing_edges: list[NodeEdgeNodeTriplet],
+new_edges: list[NodeEdgeNodeTriplet],
+current_episode: EpisodicNode,
+previous_episodes: list[EpisodicNode],
) -> dict:
return {
'existing_edges': [
-f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})'
+f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
for source_node, edge, target_node in sorted(
existing_edges, key=lambda x: x[1].created_at, reverse=True
)
],
'new_edges': [
-f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})'
+f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
for source_node, edge, target_node in sorted(
new_edges, key=lambda x: x[1].created_at, reverse=True
)
],
+'current_episode': current_episode.content,
+'previous_episodes': [episode.content for episode in previous_episodes],
}
@@ -86,8 +114,9 @@ def process_edge_invalidation_llm_response(
)
if edge_to_update:
edge_to_update.expired_at = datetime.now()
+edge_to_update.fact = edge_to_invalidate['fact']
invalidated_edges.append(edge_to_update)
logger.info(
-f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Reason: {edge_to_invalidate['reason']}"
+f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
)
return invalidated_edges
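A hedged end-to-end sketch of the updated invalidation flow (Edge is a stand-in for EntityEdge; the response shape matches the v1 invalidation prompt above):

from dataclasses import dataclass
from datetime import datetime

@dataclass
class Edge:
    uuid: str
    name: str
    fact: str
    expired_at: datetime | None = None

edges = [Edge('e1', 'IS_MARRIED_TO', 'Paul is married to Jane')]
llm_response = {'invalidated_edges': [{'edge_uuid': 'e1', 'fact': 'Paul has divorced Jane'}]}

# Mirror process_edge_invalidation_llm_response: expire the edge and overwrite its fact.
for item in llm_response['invalidated_edges']:
    edge = next((e for e in edges if e.uuid == item['edge_uuid']), None)
    if edge:
        edge.expired_at = datetime.now()
        edge.fact = item['fact']

assert edges[0].expired_at is not None and edges[0].fact == 'Paul has divorced Jane'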

View file

@@ -0,0 +1,292 @@
import asyncio
import logging
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from neo4j import time as neo4j_time
from core.edges import EntityEdge
from core.nodes import EntityNode
logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
context = {}
for record in records:
n_uuid = record['source_node_uuid']
if n_uuid in context:
context[n_uuid]['facts'].append(record['fact'])
else:
context[n_uuid] = {
'name': record['source_name'],
'summary': record['source_summary'],
'facts': [record['fact']],
}
m_uuid = record['target_node_uuid']
if m_uuid not in context:
context[m_uuid] = {
'name': record['target_name'],
'summary': record['target_summary'],
'facts': [],
}
logger.info(f'bfs search returned context: {context}')
return context
async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
search_vector=search_vector,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=safely_parse_db_date(record['created_at']),
expired_at=safely_parse_db_date(record['expired_at']),
valid_at=safely_parse_db_date(record['valid_at']),
invalid_at=safely_parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges
async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
RETURN
n.uuid AS uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
""",
search_vector=search_vector,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=[],
created_at=safely_parse_db_date(record['created_at']),
summary=record['summary'],
)
)
return nodes
async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = query + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
node.uuid AS uuid,
node.name AS name,
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=[],
created_at=safely_parse_db_date(record['created_at']),
summary=record['summary'],
)
)
return nodes
async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = query + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=safely_parse_db_date(record['created_at']),
expired_at=safely_parse_db_date(record['expired_at']),
valid_at=safely_parse_db_date(record['valid_at']),
invalid_at=safely_parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges
def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime | None:
if date_str:
return datetime.fromisoformat(date_str.iso_format())
return None
async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
start = time()
relevant_nodes: list[EntityNode] = []
relevant_node_uuids = set()
results = await asyncio.gather(
*[entity_fulltext_search(node.name, driver) for node in nodes],
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
)
for result in results:
for node in result:
if node.uuid in relevant_node_uuids:
continue
relevant_node_uuids.add(node.uuid)
relevant_nodes.append(node)
end = time()
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
return relevant_nodes
async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []
relevant_edge_uuids = set()
results = await asyncio.gather(
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
)
for result in results:
for edge in result:
if edge.uuid in relevant_edge_uuids:
continue
relevant_edge_uuids.add(edge.uuid)
relevant_edges.append(edge)
end = time()
logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
return relevant_edges
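A hedged usage sketch for the hybrid retrieval helpers in this new module; the connection details, module path, and embedding size are assumptions, and a Neo4j instance with the name_and_summary fulltext index and name_embedding vector index must already exist:

import asyncio
from datetime import datetime
from neo4j import AsyncGraphDatabase
from core.nodes import EntityNode
from core.utils.search.search_utils import get_relevant_nodes  # assumed module path

async def demo():
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    node = EntityNode(uuid='n1', name='Paul', labels=[], created_at=datetime.now(), summary='')
    node.name_embedding = [0.0] * 1536  # hypothetical embedding dimension
    # Each node is searched two ways (BM25 over names/summaries, cosine similarity
    # over name embeddings) and the union is deduplicated by uuid.
    relevant = await get_relevant_nodes([node], driver)
    await driver.close()
    return relevant

# asyncio.run(demo())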

View file

@@ -2,6 +2,7 @@ import asyncio
import logging
import os
import sys
+from datetime import datetime
from dotenv import load_dotenv
@@ -42,26 +43,31 @@ async def main():
await clear_data(client.driver)
# await client.build_indices()
await client.add_episode(
name='Message 1',
episode_body='Paul: I love apples',
source_description='WhatsApp Message',
)
await client.add_episode(
name='Message 2',
episode_body='Paul: I hate apples now',
source_description='WhatsApp Message',
)
await client.add_episode(
name='Message 3',
episode_body='Jane: I am married to Paul',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
await client.add_episode(
name='Message 4',
episode_body='Paul: I have divorced Jane',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
await client.add_episode(
name='Message 5',
episode_body='Jane: I miss Paul',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
await client.add_episode(
name='Message 6',
episode_body='Jane: I dont miss Paul anymore, I hate him',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
# await client.add_episode(
# name="Message 3",
# episode_body="Assistant: The best type of apples available are Fuji apples",