don't save vectors to neo4j with aoss
This commit is contained in:
parent
67812fe3d1
commit
af2a736002
2 changed files with 53 additions and 17 deletions
|
|
@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
||||||
match provider:
|
match provider:
|
||||||
case GraphProvider.FALKORDB:
|
case GraphProvider.FALKORDB:
|
||||||
return """
|
return """
|
||||||
|
|
@ -99,17 +99,27 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||||
RETURN e.uuid AS uuid
|
RETURN e.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
return """
|
save_embedding_query = (
|
||||||
|
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
|
||||||
|
if not has_aoss
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
"""
|
||||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
SET e = $edge_data
|
SET e = $edge_data"""
|
||||||
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
+ save_embedding_query
|
||||||
|
)
|
||||||
|
+ """
|
||||||
RETURN e.uuid AS uuid
|
RETURN e.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
||||||
match provider:
|
match provider:
|
||||||
case GraphProvider.FALKORDB:
|
case GraphProvider.FALKORDB:
|
||||||
return """
|
return """
|
||||||
|
|
@ -152,15 +162,23 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
RETURN e.uuid AS uuid
|
RETURN e.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _:
|
case _:
|
||||||
return """
|
save_embedding_query = (
|
||||||
|
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
|
||||||
|
if not has_aoss
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"""
|
||||||
UNWIND $entity_edges AS edge
|
UNWIND $entity_edges AS edge
|
||||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
SET e = edge
|
SET e = edge"""
|
||||||
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
+ save_embedding_query
|
||||||
|
+ """
|
||||||
RETURN edge.uuid AS uuid
|
RETURN edge.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_return_query(provider: GraphProvider) -> str:
|
def get_entity_edge_return_query(provider: GraphProvider) -> str:
|
||||||
|
|
|
||||||
|
|
@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
|
||||||
match provider:
|
match provider:
|
||||||
case GraphProvider.FALKORDB:
|
case GraphProvider.FALKORDB:
|
||||||
return f"""
|
return f"""
|
||||||
|
|
@ -161,16 +161,26 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _:
|
case _:
|
||||||
return f"""
|
save_embedding_query = (
|
||||||
|
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
|
||||||
|
if not has_aoss
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"""
|
||||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
SET n:{labels}
|
SET n:{labels}
|
||||||
SET n = $entity_data
|
SET n = $entity_data"""
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
+ save_embedding_query
|
||||||
|
+ """
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
def get_entity_node_save_bulk_query(
|
||||||
|
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
|
||||||
|
) -> str | Any:
|
||||||
match provider:
|
match provider:
|
||||||
case GraphProvider.FALKORDB:
|
case GraphProvider.FALKORDB:
|
||||||
queries = []
|
queries = []
|
||||||
|
|
@ -222,14 +232,22 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
return """
|
save_embedding_query = (
|
||||||
|
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
|
||||||
|
if not has_aoss
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MERGE (n:Entity {uuid: node.uuid})
|
MERGE (n:Entity {uuid: node.uuid})
|
||||||
SET n:$(node.labels)
|
SET n:$(node.labels)
|
||||||
SET n = node
|
SET n = node"""
|
||||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
+ save_embedding_query
|
||||||
|
+ """
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_entity_node_return_query(provider: GraphProvider) -> str:
|
def get_entity_node_return_query(provider: GraphProvider) -> str:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue