search updates (#19)

* search updates

* add helper function

* make format

* updates
This commit is contained in:
Preston Rasmussen 2024-08-22 17:24:59 -04:00 committed by GitHub
parent 6ae9c4e262
commit 94873f1083
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 13 deletions

View file

@ -5,6 +5,7 @@ from datetime import datetime
from time import time
from neo4j import AsyncDriver
from neo4j import time as neo4j_time
from core.edges import EntityEdge
from core.nodes import EntityNode, EpisodicNode
@ -14,6 +15,10 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
@ -122,8 +127,6 @@ async def edge_similarity_search(
edges: list[EntityEdge] = []
now = datetime.now()
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
@ -133,10 +136,10 @@ async def edge_similarity_search(
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=now,
expired_at=now,
valid_at=now,
invalid_At=now,
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
@ -244,8 +247,6 @@ async def edge_fulltext_search(
edges: list[EntityEdge] = []
now = datetime.now()
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
@ -255,10 +256,10 @@ async def edge_fulltext_search(
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=now,
expired_at=now,
valid_at=now,
invalid_At=now,
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)

View file

@ -232,7 +232,7 @@ async def dedupe_edge_list(
unique_edges_data = llm_response.get('unique_edges', [])
end = time()
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms ')
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
# Get full edge data
unique_edges = []