search updates (#19)

* search updates

* add helper function

* make format

* updates
This commit is contained in:
Preston Rasmussen 2024-08-22 17:24:59 -04:00 committed by GitHub
parent 6ae9c4e262
commit 94873f1083
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 13 deletions

View file

@ -5,6 +5,7 @@ from datetime import datetime
from time import time from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
from neo4j import time as neo4j_time
from core.edges import EntityEdge from core.edges import EntityEdge
from core.nodes import EntityNode, EpisodicNode from core.nodes import EntityNode, EpisodicNode
@ -14,6 +15,10 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3 RELEVANT_SCHEMA_LIMIT = 3
def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]): async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -122,8 +127,6 @@ async def edge_similarity_search(
edges: list[EntityEdge] = [] edges: list[EntityEdge] = []
now = datetime.now()
for record in records: for record in records:
edge = EntityEdge( edge = EntityEdge(
uuid=record['uuid'], uuid=record['uuid'],
@ -133,10 +136,10 @@ async def edge_similarity_search(
name=record['name'], name=record['name'],
episodes=record['episodes'], episodes=record['episodes'],
fact_embedding=record['fact_embedding'], fact_embedding=record['fact_embedding'],
created_at=now, created_at=record['created_at'].to_native(),
expired_at=now, expired_at=parse_db_date(record['expired_at']),
valid_at=now, valid_at=parse_db_date(record['valid_at']),
invalid_At=now, invalid_at=parse_db_date(record['invalid_at']),
) )
edges.append(edge) edges.append(edge)
@ -244,8 +247,6 @@ async def edge_fulltext_search(
edges: list[EntityEdge] = [] edges: list[EntityEdge] = []
now = datetime.now()
for record in records: for record in records:
edge = EntityEdge( edge = EntityEdge(
uuid=record['uuid'], uuid=record['uuid'],
@ -255,10 +256,10 @@ async def edge_fulltext_search(
name=record['name'], name=record['name'],
episodes=record['episodes'], episodes=record['episodes'],
fact_embedding=record['fact_embedding'], fact_embedding=record['fact_embedding'],
created_at=now, created_at=record['created_at'].to_native(),
expired_at=now, expired_at=parse_db_date(record['expired_at']),
valid_at=now, valid_at=parse_db_date(record['valid_at']),
invalid_At=now, invalid_at=parse_db_date(record['invalid_at']),
) )
edges.append(edge) edges.append(edge)

View file

@ -232,7 +232,7 @@ async def dedupe_edge_list(
unique_edges_data = llm_response.get('unique_edges', []) unique_edges_data = llm_response.get('unique_edges', [])
end = time() end = time()
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms ') logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
# Get full edge data # Get full edge data
unique_edges = [] unique_edges = []