load embeddings from aoss

This commit is contained in:
prestonrasmussen 2025-09-07 12:51:34 -04:00
parent af2a736002
commit 58c1f7e395
4 changed files with 46 additions and 23 deletions

View file

@ -38,7 +38,7 @@ class GraphProvider(Enum):
aoss_indices = [
{
'index_name': 'node_name_and_summary',
'index_name': 'entities',
'body': {
'mappings': {
'properties': {
@ -68,7 +68,7 @@ aoss_indices = [
},
},
{
'index_name': 'community_name',
'index_name': 'communities',
'body': {
'mappings': {
'properties': {
@ -84,7 +84,7 @@ aoss_indices = [
},
},
{
'index_name': 'episode_content',
'index_name': 'episodes',
'body': {
'mappings': {
'properties': {
@ -109,7 +109,7 @@ aoss_indices = [
},
},
{
'index_name': 'edge_name_and_fact',
'index_name': 'entity_edges',
'body': {
'mappings': {
'properties': {

View file

@ -255,6 +255,20 @@ class EntityEdge(Edge):
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
elif driver.aoss_client:
resp = driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entity_edges',
)
if resp['hits']['hits']:
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
return
else:
raise EdgeNotFoundError(self.uuid)
if driver.provider == GraphProvider.KUZU:
query = """

View file

@ -273,20 +273,6 @@ class EpisodicNode(Node):
)
async def save(self, driver: GraphDriver):
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episode_content',
[
{
'uuid': self.uuid,
'group_id': self.group_id,
'source': self.source.value,
'content': self.content,
'source_description': self.source_description,
}
],
)
episode_args = {
'uuid': self.uuid,
'name': self.name,
@ -299,6 +285,12 @@ class EpisodicNode(Node):
'source': self.source.value,
}
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episode_content',
[episode_args],
)
result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args
)
@ -433,6 +425,21 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid})
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
"""
elif driver.aoss_client:
resp = driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entities',
)
if resp['hits']['hits']:
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
return
else:
raise NodeNotFoundError(self.uuid)
else:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})

View file

@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if driver.provider == GraphProvider.NEPTUNE:
if driver.aoss_client:
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
return
if delete_existing:
@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
# Don't create fulltext indices if OpenSearch is being used
if not driver.aoss_client:
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
if driver.provider == GraphProvider.KUZU:
# Skip creating fulltext indices if they already exist. Need to do this manually
@ -149,9 +151,9 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ query_filter
+ """
RETURN