feat: adds feedback weights to edges
This commit is contained in:
parent
c6ec22a5a0
commit
4a5d5f70d0
4 changed files with 74 additions and 0 deletions
|
|
@ -1654,3 +1654,41 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
id_list = [row[0] for row in rows]
|
id_list = [row[0] for row in rows]
|
||||||
return id_list
|
return id_list
|
||||||
|
|
||||||
|
async def apply_feedback_weight(
|
||||||
|
self,
|
||||||
|
node_ids: List[str],
|
||||||
|
weight: float,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Increment `feedback_weight` inside r.properties JSON for edges where
|
||||||
|
relationship_name = 'used_graph_element_to_answer'.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Step 1: fetch matching edges
|
||||||
|
query = """
|
||||||
|
MATCH (n:Node)-[r:EDGE]->()
|
||||||
|
WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
|
||||||
|
RETURN r.properties, n.id
|
||||||
|
"""
|
||||||
|
results = await self.query(query, {"node_ids": node_ids})
|
||||||
|
|
||||||
|
# Step 2: update JSON client-side
|
||||||
|
updates = []
|
||||||
|
for props_json, source_id in results:
|
||||||
|
try:
|
||||||
|
props = json.loads(props_json) if props_json else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
props = {}
|
||||||
|
|
||||||
|
props["feedback_weight"] = props.get("feedback_weight", 0) + weight
|
||||||
|
updates.append((source_id, json.dumps(props)))
|
||||||
|
|
||||||
|
# Step 3: write back
|
||||||
|
for node_id, new_props in updates:
|
||||||
|
update_query = """
|
||||||
|
MATCH (n:Node)-[r:EDGE]->()
|
||||||
|
WHERE n.id = $node_id AND r.relationship_name = 'used_graph_element_to_answer'
|
||||||
|
SET r.properties = $props
|
||||||
|
"""
|
||||||
|
await self.query(update_query, {"node_id": node_id, "props": new_props})
|
||||||
|
|
|
||||||
|
|
@ -1345,3 +1345,29 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
id_list = [row["id"] for row in rows if "id" in row]
|
id_list = [row["id"] for row in rows if "id" in row]
|
||||||
return id_list
|
return id_list
|
||||||
|
|
||||||
|
async def apply_feedback_weight(
|
||||||
|
self,
|
||||||
|
node_ids: List[str],
|
||||||
|
weight: float,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Increment `feedback_weight` on relationships `:used_graph_element_to_answer`
|
||||||
|
outgoing from nodes whose `id` is in `node_ids`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_ids: List of node IDs to match.
|
||||||
|
weight: Amount to add to `r.feedback_weight` (can be negative).
|
||||||
|
|
||||||
|
Side effects:
|
||||||
|
Updates relationship property `feedback_weight`, defaulting missing values to 0.
|
||||||
|
"""
|
||||||
|
query = """
|
||||||
|
MATCH (n)-[r]->()
|
||||||
|
WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
|
||||||
|
SET r.feedback_weight = coalesce(r.feedback_weight, 0) + $weight
|
||||||
|
"""
|
||||||
|
await self.query(
|
||||||
|
query,
|
||||||
|
params={"weight": float(weight), "node_ids": list(node_ids)},
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -248,6 +248,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
"source_node_id": source_id,
|
"source_node_id": source_id,
|
||||||
"target_node_id": target_id_1,
|
"target_node_id": target_id_1,
|
||||||
"ontology_valid": False,
|
"ontology_valid": False,
|
||||||
|
"feedback_weight": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -262,6 +263,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
"source_node_id": source_id,
|
"source_node_id": source_id,
|
||||||
"target_node_id": target_id_2,
|
"target_node_id": target_id_2,
|
||||||
"ontology_valid": False,
|
"ontology_valid": False,
|
||||||
|
"feedback_weight": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ class UserQAFeedback(BaseFeedback):
|
||||||
|
|
||||||
relationships = []
|
relationships = []
|
||||||
relationship_name = "gives_feedback_to"
|
relationship_name = "gives_feedback_to"
|
||||||
|
to_node_ids = []
|
||||||
|
|
||||||
for interaction_id in last_interaction_ids:
|
for interaction_id in last_interaction_ids:
|
||||||
target_id_1 = feedback_id
|
target_id_1 = feedback_id
|
||||||
|
|
@ -70,9 +71,16 @@ class UserQAFeedback(BaseFeedback):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
to_node_ids.append(str(target_id_2))
|
||||||
|
|
||||||
|
|
||||||
if len(relationships) > 0:
|
if len(relationships) > 0:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
await graph_engine.add_edges(relationships)
|
await graph_engine.add_edges(relationships)
|
||||||
|
await graph_engine.apply_feedback_weight(
|
||||||
|
node_ids=to_node_ids,
|
||||||
|
weight=feedback_sentiment.score
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
return [feedback_text]
|
return [feedback_text]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue