cleanup
This commit is contained in:
parent
8e442d4634
commit
b036d45329
4 changed files with 75 additions and 83 deletions
|
|
@ -195,6 +195,9 @@ class GraphDriver(ABC):
|
||||||
|
|
||||||
async def create_aoss_indices(self):
|
async def create_aoss_indices(self):
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
|
if not client:
|
||||||
|
logger.warning('No OpenSearch client found')
|
||||||
|
return
|
||||||
|
|
||||||
for index in aoss_indices:
|
for index in aoss_indices:
|
||||||
alias_name = index['index_name']
|
alias_name = index['index_name']
|
||||||
|
|
@ -220,10 +223,20 @@ class GraphDriver(ABC):
|
||||||
for index in aoss_indices:
|
for index in aoss_indices:
|
||||||
index_name = index['index_name']
|
index_name = index['index_name']
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
|
|
||||||
|
if not client:
|
||||||
|
logger.warning('No OpenSearch client found')
|
||||||
|
return
|
||||||
|
|
||||||
if client.indices.exists(index=index_name):
|
if client.indices.exists(index=index_name):
|
||||||
client.indices.delete(index=index_name)
|
client.indices.delete(index=index_name)
|
||||||
|
|
||||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
|
client = self.aoss_client
|
||||||
|
if not client:
|
||||||
|
logger.warning('No OpenSearch client found')
|
||||||
|
return 0
|
||||||
|
|
||||||
for index in aoss_indices:
|
for index in aoss_indices:
|
||||||
if name.lower() == index['index_name']:
|
if name.lower() == index['index_name']:
|
||||||
to_index = []
|
to_index = []
|
||||||
|
|
@ -237,7 +250,7 @@ class GraphDriver(ABC):
|
||||||
item[p] = d[p]
|
item[p] = d[p]
|
||||||
to_index.append(item)
|
to_index.append(item)
|
||||||
|
|
||||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
success, failed = helpers.bulk(client, to_index, stats_only=True)
|
||||||
|
|
||||||
return success if failed == 0 else success
|
return success if failed == 0 else success
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class FalkorDriverSession(GraphDriverSession):
|
class FalkorDriverSession(GraphDriverSession):
|
||||||
provider = GraphProvider.FALKORDB
|
provider = GraphProvider.FALKORDB
|
||||||
|
aoss_client: None
|
||||||
|
|
||||||
def __init__(self, graph: FalkorGraph):
|
def __init__(self, graph: FalkorGraph):
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
|
||||||
|
|
||||||
class KuzuDriver(GraphDriver):
|
class KuzuDriver(GraphDriver):
|
||||||
provider: GraphProvider = GraphProvider.KUZU
|
provider: GraphProvider = GraphProvider.KUZU
|
||||||
|
aoss_client: None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -989,7 +989,7 @@ async def episode_fulltext_search(
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
query={
|
||||||
'bool': {
|
'bool': {
|
||||||
'filter': [{'terms': {'group_id': group_ids}}],
|
'filter': {'terms': group_ids},
|
||||||
'must': [
|
'must': [
|
||||||
{
|
{
|
||||||
'multi_match': {
|
'multi_match': {
|
||||||
|
|
@ -1005,37 +1005,14 @@ async def episode_fulltext_search(
|
||||||
)
|
)
|
||||||
|
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
# Calculate Cosine similarity then return the edge ids
|
input_uuids = {}
|
||||||
input_uuids = []
|
|
||||||
for r in res['hits']['hits']:
|
for r in res['hits']['hits']:
|
||||||
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']})
|
input_uuids[r['_source']['uuid']] = r['_score']
|
||||||
|
|
||||||
# Match the edge ides and return the values
|
# Get nodes
|
||||||
query = """
|
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
UNWIND $uuids as i
|
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||||
MATCH (e:Episodic)
|
return episodes
|
||||||
WHERE e.uuid=i.uuid
|
|
||||||
RETURN
|
|
||||||
e.content AS content,
|
|
||||||
e.created_at AS created_at,
|
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.uuid AS uuid,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source,
|
|
||||||
e.entity_edges AS entity_edges
|
|
||||||
ORDER BY i.score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
uuids=input_uuids,
|
|
||||||
query=fuzzy_query,
|
|
||||||
limit=limit,
|
|
||||||
routing_='r',
|
|
||||||
**filter_params,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue