search updates (#19)
* search updates * add helper function * make format * updates
This commit is contained in:
parent
6ae9c4e262
commit
94873f1083
2 changed files with 14 additions and 13 deletions
|
|
@ -5,6 +5,7 @@ from datetime import datetime
|
|||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from neo4j import time as neo4j_time
|
||||
|
||||
from core.edges import EntityEdge
|
||||
from core.nodes import EntityNode, EpisodicNode
|
||||
|
|
@ -14,6 +15,10 @@ logger = logging.getLogger(__name__)
|
|||
RELEVANT_SCHEMA_LIMIT = 3
|
||||
|
||||
|
||||
def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
|
||||
return neo_date.to_native() if neo_date else None
|
||||
|
||||
|
||||
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -122,8 +127,6 @@ async def edge_similarity_search(
|
|||
|
||||
edges: list[EntityEdge] = []
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
for record in records:
|
||||
edge = EntityEdge(
|
||||
uuid=record['uuid'],
|
||||
|
|
@ -133,10 +136,10 @@ async def edge_similarity_search(
|
|||
name=record['name'],
|
||||
episodes=record['episodes'],
|
||||
fact_embedding=record['fact_embedding'],
|
||||
created_at=now,
|
||||
expired_at=now,
|
||||
valid_at=now,
|
||||
invalid_At=now,
|
||||
created_at=record['created_at'].to_native(),
|
||||
expired_at=parse_db_date(record['expired_at']),
|
||||
valid_at=parse_db_date(record['valid_at']),
|
||||
invalid_at=parse_db_date(record['invalid_at']),
|
||||
)
|
||||
|
||||
edges.append(edge)
|
||||
|
|
@ -244,8 +247,6 @@ async def edge_fulltext_search(
|
|||
|
||||
edges: list[EntityEdge] = []
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
for record in records:
|
||||
edge = EntityEdge(
|
||||
uuid=record['uuid'],
|
||||
|
|
@ -255,10 +256,10 @@ async def edge_fulltext_search(
|
|||
name=record['name'],
|
||||
episodes=record['episodes'],
|
||||
fact_embedding=record['fact_embedding'],
|
||||
created_at=now,
|
||||
expired_at=now,
|
||||
valid_at=now,
|
||||
invalid_At=now,
|
||||
created_at=record['created_at'].to_native(),
|
||||
expired_at=parse_db_date(record['expired_at']),
|
||||
valid_at=parse_db_date(record['valid_at']),
|
||||
invalid_at=parse_db_date(record['invalid_at']),
|
||||
)
|
||||
|
||||
edges.append(edge)
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ async def dedupe_edge_list(
|
|||
unique_edges_data = llm_response.get('unique_edges', [])
|
||||
|
||||
end = time()
|
||||
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms ')
|
||||
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
||||
|
||||
# Get full edge data
|
||||
unique_edges = []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue