search updates (#19)
* search updates * add helper function * make format * updates
This commit is contained in:
parent
6ae9c4e262
commit
94873f1083
2 changed files with 14 additions and 13 deletions
|
|
@ -5,6 +5,7 @@ from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
from neo4j import time as neo4j_time
|
||||||
|
|
||||||
from core.edges import EntityEdge
|
from core.edges import EntityEdge
|
||||||
from core.nodes import EntityNode, EpisodicNode
|
from core.nodes import EntityNode, EpisodicNode
|
||||||
|
|
@ -14,6 +15,10 @@ logger = logging.getLogger(__name__)
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 3
|
||||||
|
|
||||||
|
|
||||||
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j temporal value into its native Python equivalent.

    Args:
        neo_date: Temporal value read from a query record, or ``None`` when
            the property is unset in the database.

    Returns:
        The result of ``neo_date.to_native()``, or ``None`` when ``neo_date``
        is ``None`` (or otherwise falsy).
    """
    # Annotated as DateTime rather than Date: Date.to_native() yields a
    # datetime.date, which would contradict the declared datetime return type.
    return neo_date.to_native() if neo_date else None
|
||||||
|
|
||||||
|
|
||||||
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -122,8 +127,6 @@ async def edge_similarity_search(
|
||||||
|
|
||||||
edges: list[EntityEdge] = []
|
edges: list[EntityEdge] = []
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
edge = EntityEdge(
|
edge = EntityEdge(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
|
|
@ -133,10 +136,10 @@ async def edge_similarity_search(
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
episodes=record['episodes'],
|
episodes=record['episodes'],
|
||||||
fact_embedding=record['fact_embedding'],
|
fact_embedding=record['fact_embedding'],
|
||||||
created_at=now,
|
created_at=record['created_at'].to_native(),
|
||||||
expired_at=now,
|
expired_at=parse_db_date(record['expired_at']),
|
||||||
valid_at=now,
|
valid_at=parse_db_date(record['valid_at']),
|
||||||
invalid_At=now,
|
invalid_at=parse_db_date(record['invalid_at']),
|
||||||
)
|
)
|
||||||
|
|
||||||
edges.append(edge)
|
edges.append(edge)
|
||||||
|
|
@ -244,8 +247,6 @@ async def edge_fulltext_search(
|
||||||
|
|
||||||
edges: list[EntityEdge] = []
|
edges: list[EntityEdge] = []
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
edge = EntityEdge(
|
edge = EntityEdge(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
|
|
@ -255,10 +256,10 @@ async def edge_fulltext_search(
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
episodes=record['episodes'],
|
episodes=record['episodes'],
|
||||||
fact_embedding=record['fact_embedding'],
|
fact_embedding=record['fact_embedding'],
|
||||||
created_at=now,
|
created_at=record['created_at'].to_native(),
|
||||||
expired_at=now,
|
expired_at=parse_db_date(record['expired_at']),
|
||||||
valid_at=now,
|
valid_at=parse_db_date(record['valid_at']),
|
||||||
invalid_At=now,
|
invalid_at=parse_db_date(record['invalid_at']),
|
||||||
)
|
)
|
||||||
|
|
||||||
edges.append(edge)
|
edges.append(edge)
|
||||||
|
|
|
||||||
|
|
@ -232,7 +232,7 @@ async def dedupe_edge_list(
|
||||||
unique_edges_data = llm_response.get('unique_edges', [])
|
unique_edges_data = llm_response.get('unique_edges', [])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms ')
|
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
||||||
|
|
||||||
# Get full edge data
|
# Get full edge data
|
||||||
unique_edges = []
|
unique_edges = []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue