don't save vectors to neo4j with aoss

This commit is contained in:
prestonrasmussen 2025-09-07 12:03:47 -04:00
parent 67812fe3d1
commit af2a736002
2 changed files with 53 additions and 17 deletions

View file

@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
""" """
def get_entity_edge_save_query(provider: GraphProvider) -> str: def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
return """ return """
@ -99,17 +99,27 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
case _: # Neo4j case _: # Neo4j
return """ save_embedding_query = (
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
if not has_aoss
else ''
)
return (
(
"""
MATCH (source:Entity {uuid: $edge_data.source_uuid}) MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid}) MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data SET e = $edge_data"""
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding) + save_embedding_query
)
+ """
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
)
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
return """ return """
@ -152,15 +162,23 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
case _: case _:
return """ save_embedding_query = (
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
if not has_aoss
else ''
)
return (
"""
UNWIND $entity_edges AS edge UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid}) MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid}) MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge SET e = edge"""
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding) + save_embedding_query
+ """
RETURN edge.uuid AS uuid RETURN edge.uuid AS uuid
""" """
)
def get_entity_edge_return_query(provider: GraphProvider) -> str: def get_entity_edge_return_query(provider: GraphProvider) -> str:

View file

@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
""" """
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
return f""" return f"""
@ -161,16 +161,26 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: case _:
return f""" save_embedding_query = (
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
if not has_aoss
else ''
)
return (
f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}}) MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels} SET n:{labels}
SET n = $entity_data SET n = $entity_data"""
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) + save_embedding_query
+ """
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
)
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any: def get_entity_node_save_bulk_query(
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
queries = [] queries = []
@ -222,14 +232,22 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j case _: # Neo4j
return """ save_embedding_query = (
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
if not has_aoss
else ''
)
return (
"""
UNWIND $nodes AS node UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid}) MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels) SET n:$(node.labels)
SET n = node SET n = node"""
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) + save_embedding_query
+ """
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
)
def get_entity_node_return_query(provider: GraphProvider) -> str: def get_entity_node_return_query(provider: GraphProvider) -> str: