don't save vectors to neo4j with aoss
This commit is contained in:
parent
67812fe3d1
commit
af2a736002
2 changed files with 53 additions and 17 deletions
|
|
@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
|
|||
"""
|
||||
|
||||
|
||||
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||
def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
return """
|
||||
|
|
@ -99,17 +99,27 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
|||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
save_embedding_query = (
|
||||
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
(
|
||||
"""
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
||||
SET e = $edge_data"""
|
||||
+ save_embedding_query
|
||||
)
|
||||
+ """
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||
def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
return """
|
||||
|
|
@ -152,15 +162,23 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
|||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
return """
|
||||
save_embedding_query = (
|
||||
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
"""
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET e = edge
|
||||
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
||||
SET e = edge"""
|
||||
+ save_embedding_query
|
||||
+ """
|
||||
RETURN edge.uuid AS uuid
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_edge_return_query(provider: GraphProvider) -> str:
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
|
|||
"""
|
||||
|
||||
|
||||
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
return f"""
|
||||
|
|
@ -161,16 +161,26 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
|||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
return f"""
|
||||
save_embedding_query = (
|
||||
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
||||
SET n = $entity_data"""
|
||||
+ save_embedding_query
|
||||
+ """
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
||||
def get_entity_node_save_bulk_query(
|
||||
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
|
||||
) -> str | Any:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
queries = []
|
||||
|
|
@ -222,14 +232,22 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
|||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
save_embedding_query = (
|
||||
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n:$(node.labels)
|
||||
SET n = node
|
||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||
SET n = node"""
|
||||
+ save_embedding_query
|
||||
+ """
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_node_return_query(provider: GraphProvider) -> str:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue