From 94873f10836c0b6adaec58898237d0bb95bab91f Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:24:59 -0400 Subject: [PATCH] search updates (#19) * search updates * add helper function * make format * updates --- core/search/search_utils.py | 25 ++++++++++++----------- core/utils/maintenance/edge_operations.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/core/search/search_utils.py b/core/search/search_utils.py index 73e06a33..c230cae6 100644 --- a/core/search/search_utils.py +++ b/core/search/search_utils.py @@ -5,6 +5,7 @@ from datetime import datetime from time import time from neo4j import AsyncDriver +from neo4j import time as neo4j_time from core.edges import EntityEdge from core.nodes import EntityNode, EpisodicNode @@ -14,6 +15,10 @@ logger = logging.getLogger(__name__) RELEVANT_SCHEMA_LIMIT = 3 +def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None: + return neo_date.to_native() if neo_date else None + + async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]): episode_uuids = [episode.uuid for episode in episodes] records, _, _ = await driver.execute_query( @@ -122,8 +127,6 @@ async def edge_similarity_search( edges: list[EntityEdge] = [] - now = datetime.now() - for record in records: edge = EntityEdge( uuid=record['uuid'], @@ -133,10 +136,10 @@ async def edge_similarity_search( name=record['name'], episodes=record['episodes'], fact_embedding=record['fact_embedding'], - created_at=now, - expired_at=now, - valid_at=now, - invalid_At=now, + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), ) edges.append(edge) @@ -244,8 +247,6 @@ async def edge_fulltext_search( edges: list[EntityEdge] = [] - now = datetime.now() - for record in records: edge = EntityEdge( uuid=record['uuid'], @@ -255,10 +256,10 @@ async def edge_fulltext_search( name=record['name'], episodes=record['episodes'], fact_embedding=record['fact_embedding'], - created_at=now, - expired_at=now, - valid_at=now, - invalid_At=now, + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), ) edges.append(edge) diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index dae62407..7c61caf7 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -232,7 +232,7 @@ async def dedupe_edge_list( unique_edges_data = llm_response.get('unique_edges', []) end = time() - logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms ') + logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ') # Get full edge data unique_edges = []