graphiti/core/search/search_utils.py

import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from time import time

from neo4j import AsyncDriver
from neo4j import time as neo4j_time

from core.edges import EntityEdge
from core.nodes import EntityNode, EpisodicNode

logger = logging.getLogger(__name__)

RELEVANT_SCHEMA_LIMIT = 3


def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
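    """Convert a Neo4j date value to its native Python equivalent, returning None when missing."""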
return neo_date.to_native() if neo_date else None


async def get_mentioned_nodes(
    driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
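    """Return the Entity nodes referenced by the given episodes via MENTIONS relationships."""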
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
""",
uuids=episode_uuids,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=datetime.now(),
summary=record['summary'],
)
)
return nodes


async def bfs(node_ids: list[str], driver: AsyncDriver):
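    """Expand one hop out from the given node uuids and return a context map keyed by node
    uuid, holding each node's name, summary, and the facts on its outgoing relationships.
    """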
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN DISTINCT
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
context = {}
for record in records:
n_uuid = record['source_node_uuid']
if n_uuid in context:
context[n_uuid]['facts'].append(record['fact'])
else:
context[n_uuid] = {
'name': record['source_name'],
'summary': record['source_summary'],
'facts': [record['fact']],
}
m_uuid = record['target_node_uuid']
if m_uuid not in context:
context[m_uuid] = {
'name': record['target_name'],
'summary': record['target_summary'],
'facts': [],
}
logger.info(f'bfs search returned context: {context}')
return context


async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
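    """Nearest-neighbour search over the fact_embedding vector index; returns up to `limit`
    edges ordered by similarity score.
    """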
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
search_vector=search_vector,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges


async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
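    """Nearest-neighbour search over the name_embedding vector index; returns up to `limit`
    entity nodes ordered by similarity score.
    """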
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
""",
search_vector=search_vector,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=datetime.now(),
summary=record['summary'],
)
)
return nodes


async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
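    """Full-text search over the name_and_summary index; the query gets a trailing '~' so
    Lucene performs fuzzy matching.
    """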
# BM25 search to get top nodes
fuzzy_query = query + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=datetime.now(),
summary=record['summary'],
)
)
return nodes


async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
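    """Full-text search over the name_and_fact relationship index; the query gets a trailing
    '~' so Lucene performs fuzzy matching.
    """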
# fulltext search over facts
fuzzy_query = query + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges


async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
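    """Run full-text and vector similarity searches for every candidate node concurrently and
    return the de-duplicated union of the results.
    """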
start = time()
relevant_nodes: list[EntityNode] = []
relevant_node_uuids = set()
results = await asyncio.gather(
*[entity_fulltext_search(node.name, driver) for node in nodes],
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
)
for result in results:
for node in result:
if node.uuid in relevant_node_uuids:
continue
relevant_node_uuids.add(node.uuid)
relevant_nodes.append(node)
end = time()
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
return relevant_nodes


async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
) -> list[EntityEdge]:
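    """Run vector similarity and full-text searches for every candidate edge concurrently and
    return the de-duplicated union of the results.
    """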
start = time()
relevant_edges: list[EntityEdge] = []
relevant_edge_uuids = set()
results = await asyncio.gather(
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
)
for result in results:
for edge in result:
if edge.uuid in relevant_edge_uuids:
continue
relevant_edge_uuids.add(edge.uuid)
relevant_edges.append(edge)
end = time()
logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
return relevant_edges


# Reciprocal Rank Fusion: combines a list of uuid rankings into a single ranking
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
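    """Score each uuid by the sum of 1 / (position + rank_const) over the input rankings
    (0-based position) and return the uuids ordered by descending fused score.

    >>> rrf([['a', 'b'], ['b', 'c']])
    ['b', 'a', 'c']
    """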
    scores: dict[str, float] = defaultdict(float)
for result in results:
for i, uuid in enumerate(result):
scores[uuid] += 1 / (i + rank_const)
    scored_uuids = list(scores.items())
    scored_uuids.sort(reverse=True, key=lambda term: term[1])

    return [uuid for uuid, _ in scored_uuids]