This commit is contained in:
prestonrasmussen 2025-09-08 09:53:20 -04:00
parent 8e442d4634
commit b036d45329
4 changed files with 75 additions and 83 deletions

View file

@ -195,6 +195,9 @@ class GraphDriver(ABC):
async def create_aoss_indices(self): async def create_aoss_indices(self):
client = self.aoss_client client = self.aoss_client
if not client:
logger.warning('No OpenSearch client found')
return
for index in aoss_indices: for index in aoss_indices:
alias_name = index['index_name'] alias_name = index['index_name']
@ -220,10 +223,20 @@ class GraphDriver(ABC):
for index in aoss_indices: for index in aoss_indices:
index_name = index['index_name'] index_name = index['index_name']
client = self.aoss_client client = self.aoss_client
if not client:
logger.warning('No OpenSearch client found')
return
if client.indices.exists(index=index_name): if client.indices.exists(index=index_name):
client.indices.delete(index=index_name) client.indices.delete(index=index_name)
def save_to_aoss(self, name: str, data: list[dict]) -> int: def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client
if not client:
logger.warning('No OpenSearch client found')
return 0
for index in aoss_indices: for index in aoss_indices:
if name.lower() == index['index_name']: if name.lower() == index['index_name']:
to_index = [] to_index = []
@ -237,7 +250,7 @@ class GraphDriver(ABC):
item[p] = d[p] item[p] = d[p]
to_index.append(item) to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True) success, failed = helpers.bulk(client, to_index, stats_only=True)
return success if failed == 0 else success return success if failed == 0 else success

View file

@ -39,6 +39,7 @@ logger = logging.getLogger(__name__)
class FalkorDriverSession(GraphDriverSession): class FalkorDriverSession(GraphDriverSession):
provider = GraphProvider.FALKORDB provider = GraphProvider.FALKORDB
aoss_client: None
def __init__(self, graph: FalkorGraph): def __init__(self, graph: FalkorGraph):
self.graph = graph self.graph = graph

View file

@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
class KuzuDriver(GraphDriver): class KuzuDriver(GraphDriver):
provider: GraphProvider = GraphProvider.KUZU provider: GraphProvider = GraphProvider.KUZU
aoss_client: None
def __init__( def __init__(
self, self,

View file

@ -989,7 +989,7 @@ async def episode_fulltext_search(
_source=['uuid'], _source=['uuid'],
query={ query={
'bool': { 'bool': {
'filter': [{'terms': {'group_id': group_ids}}], 'filter': {'terms': group_ids},
'must': [ 'must': [
{ {
'multi_match': { 'multi_match': {
@ -1005,37 +1005,14 @@ async def episode_fulltext_search(
) )
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids input_uuids = {}
input_uuids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) input_uuids[r['_source']['uuid']] = r['_score']
# Match the edge ides and return the values # Get nodes
query = """ episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
UNWIND $uuids as i episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
MATCH (e:Episodic) return episodes
WHERE e.uuid=i.uuid
RETURN
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
e.entity_edges AS entity_edges
ORDER BY i.score DESC
LIMIT $limit
"""
records, _, _ = await driver.execute_query(
query,
uuids=input_uuids,
query=fuzzy_query,
limit=limit,
routing_='r',
**filter_params,
)
else: else:
return [] return []
else: else: