This commit is contained in:
prestonrasmussen 2025-09-12 12:12:23 -04:00
parent 836668e9ee
commit 7d06888de2
4 changed files with 64 additions and 64 deletions

View file

@ -54,7 +54,7 @@ class GraphProvider(Enum):
aoss_indices = [ aoss_indices = [
{ {
'index_name': ENTTITY_INDEX_NAME, 'index_name': ENTITY_INDEX_NAME,
'body': { 'body': {
'settings': {'index': {'knn': True}}, 'settings': {'index': {'knn': True}},
'mappings': { 'mappings': {

View file

@ -29,7 +29,7 @@ from typing_extensions import LiteralString
from graphiti_core.driver.driver import ( from graphiti_core.driver.driver import (
COMMUNITY_INDEX_NAME, COMMUNITY_INDEX_NAME,
ENTITY_EDGE_INDEX_NAME, ENTITY_EDGE_INDEX_NAME,
ENTTITY_INDEX_NAME, ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME, EPISODE_INDEX_NAME,
GraphDriver, GraphDriver,
GraphProvider, GraphProvider,
@ -117,7 +117,7 @@ class Node(BaseModel, ABC):
if driver.aoss_client: if driver.aoss_client:
# Delete the node from OpenSearch indices # Delete the node from OpenSearch indices
for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete( await driver.aoss_client.delete(
index=index, id=self.uuid, routing=self.group_id index=index, id=self.uuid, routing=self.group_id
) )
@ -202,7 +202,7 @@ class Node(BaseModel, ABC):
) )
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=ENTTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}}, body={'query': {'term': {'group_id': group_id}}},
routing=group_id, routing=group_id,
) )
@ -328,7 +328,7 @@ class Node(BaseModel, ABC):
) )
if driver.aoss_client: if driver.aoss_client:
for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete_by_query( await driver.aoss_client.delete_by_query(
index=index, index=index,
body={'query': {'terms': {'uuid': uuids}}}, body={'query': {'terms': {'uuid': uuids}}},
@ -519,7 +519,7 @@ class EntityNode(Node):
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,
}, },
index=ENTTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=self.group_id, routing=self.group_id,
) )
@ -567,7 +567,7 @@ class EntityNode(Node):
labels = ':'.join(self.labels + ['Entity']) labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss(ENTTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),

View file

@ -25,7 +25,7 @@ from typing_extensions import LiteralString
from graphiti_core.driver.driver import ( from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME, ENTITY_EDGE_INDEX_NAME,
ENTTITY_INDEX_NAME, ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME, EPISODE_INDEX_NAME,
GraphDriver, GraphDriver,
GraphProvider, GraphProvider,
@ -662,7 +662,7 @@ async def node_fulltext_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index=ENTTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
size=limit, size=limit,
@ -807,7 +807,7 @@ async def node_similarity_search(
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = driver.aoss_client.search(
index=ENTTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
size=limit, size=limit,

View file

@ -25,7 +25,7 @@ from typing_extensions import Any
from graphiti_core.driver.driver import ( from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME, ENTITY_EDGE_INDEX_NAME,
ENTTITY_INDEX_NAME, ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME, EPISODE_INDEX_NAME,
GraphDriver, GraphDriver,
GraphDriverSession, GraphDriverSession,
@ -211,7 +211,7 @@ async def add_nodes_and_edges_bulk_tx(
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
driver.save_to_aoss(ENTTITY_INDEX_NAME, nodes) driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)