load embeddings from aoss
This commit is contained in:
parent
af2a736002
commit
58c1f7e395
4 changed files with 46 additions and 23 deletions
|
|
@ -38,7 +38,7 @@ class GraphProvider(Enum):
|
|||
|
||||
aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
'index_name': 'entities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -68,7 +68,7 @@ aoss_indices = [
|
|||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'community_name',
|
||||
'index_name': 'communities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -84,7 +84,7 @@ aoss_indices = [
|
|||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'episode_content',
|
||||
'index_name': 'episodes',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -109,7 +109,7 @@ aoss_indices = [
|
|||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'edge_name_and_fact',
|
||||
'index_name': 'entity_edges',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
|
|||
|
|
@ -255,6 +255,20 @@ class EntityEdge(Edge):
|
|||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
||||
"""
|
||||
elif driver.aoss_client:
|
||||
resp = driver.aoss_client.search(
|
||||
body={
|
||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||
'size': 1,
|
||||
},
|
||||
index='entity_edges',
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
|
||||
return
|
||||
else:
|
||||
raise EdgeNotFoundError(self.uuid)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
query = """
|
||||
|
|
|
|||
|
|
@ -273,20 +273,6 @@ class EpisodicNode(Node):
|
|||
)
|
||||
|
||||
async def save(self, driver: GraphDriver):
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episode_content',
|
||||
[
|
||||
{
|
||||
'uuid': self.uuid,
|
||||
'group_id': self.group_id,
|
||||
'source': self.source.value,
|
||||
'content': self.content,
|
||||
'source_description': self.source_description,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
episode_args = {
|
||||
'uuid': self.uuid,
|
||||
'name': self.name,
|
||||
|
|
@ -299,6 +285,12 @@ class EpisodicNode(Node):
|
|||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episode_content',
|
||||
[episode_args],
|
||||
)
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_episode_node_save_query(driver.provider), **episode_args
|
||||
)
|
||||
|
|
@ -433,6 +425,21 @@ class EntityNode(Node):
|
|||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
||||
"""
|
||||
elif driver.aoss_client:
|
||||
resp = driver.aoss_client.search(
|
||||
body={
|
||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||
'size': 1,
|
||||
},
|
||||
index='entities',
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
|
||||
return
|
||||
else:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
|
||||
else:
|
||||
query: LiteralString = """
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
if driver.aoss_client:
|
||||
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
||||
return
|
||||
if delete_existing:
|
||||
|
|
@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|||
|
||||
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
||||
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||
# Don't create fulltext indices if OpenSearch is being used
|
||||
if not driver.aoss_client:
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
# Skip creating fulltext indices if they already exist. Need to do this manually
|
||||
|
|
@ -149,9 +151,9 @@ async def retrieve_episodes(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
+ query_filter
|
||||
+ """
|
||||
RETURN
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue