feat: adds feedback weights to edges
This commit is contained in:
parent
c6ec22a5a0
commit
4a5d5f70d0
4 changed files with 74 additions and 0 deletions
|
|
@ -1654,3 +1654,41 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
id_list = [row[0] for row in rows]
|
||||
return id_list
|
||||
|
||||
async def apply_feedback_weight(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
weight: float,
|
||||
) -> None:
|
||||
"""
|
||||
Increment `feedback_weight` inside r.properties JSON for edges where
|
||||
relationship_name = 'used_graph_element_to_answer'.
|
||||
|
||||
"""
|
||||
# Step 1: fetch matching edges
|
||||
query = """
|
||||
MATCH (n:Node)-[r:EDGE]->()
|
||||
WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
|
||||
RETURN r.properties, n.id
|
||||
"""
|
||||
results = await self.query(query, {"node_ids": node_ids})
|
||||
|
||||
# Step 2: update JSON client-side
|
||||
updates = []
|
||||
for props_json, source_id in results:
|
||||
try:
|
||||
props = json.loads(props_json) if props_json else {}
|
||||
except json.JSONDecodeError:
|
||||
props = {}
|
||||
|
||||
props["feedback_weight"] = props.get("feedback_weight", 0) + weight
|
||||
updates.append((source_id, json.dumps(props)))
|
||||
|
||||
# Step 3: write back
|
||||
for node_id, new_props in updates:
|
||||
update_query = """
|
||||
MATCH (n:Node)-[r:EDGE]->()
|
||||
WHERE n.id = $node_id AND r.relationship_name = 'used_graph_element_to_answer'
|
||||
SET r.properties = $props
|
||||
"""
|
||||
await self.query(update_query, {"node_id": node_id, "props": new_props})
|
||||
|
|
|
|||
|
|
@ -1345,3 +1345,29 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
id_list = [row["id"] for row in rows if "id" in row]
|
||||
return id_list
|
||||
|
||||
async def apply_feedback_weight(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
weight: float,
|
||||
) -> None:
|
||||
"""
|
||||
Increment `feedback_weight` on relationships `:used_graph_element_to_answer`
|
||||
outgoing from nodes whose `id` is in `node_ids`.
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs to match.
|
||||
weight: Amount to add to `r.feedback_weight` (can be negative).
|
||||
|
||||
Side effects:
|
||||
Updates relationship property `feedback_weight`, defaulting missing values to 0.
|
||||
"""
|
||||
query = """
|
||||
MATCH (n)-[r]->()
|
||||
WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
|
||||
SET r.feedback_weight = coalesce(r.feedback_weight, 0) + $weight
|
||||
"""
|
||||
await self.query(
|
||||
query,
|
||||
params={"weight": float(weight), "node_ids": list(node_ids)},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -248,6 +248,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
"source_node_id": source_id,
|
||||
"target_node_id": target_id_1,
|
||||
"ontology_valid": False,
|
||||
"feedback_weight": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -262,6 +263,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
"source_node_id": source_id,
|
||||
"target_node_id": target_id_2,
|
||||
"ontology_valid": False,
|
||||
"feedback_weight": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class UserQAFeedback(BaseFeedback):
|
|||
|
||||
relationships = []
|
||||
relationship_name = "gives_feedback_to"
|
||||
to_node_ids = []
|
||||
|
||||
for interaction_id in last_interaction_ids:
|
||||
target_id_1 = feedback_id
|
||||
|
|
@ -70,9 +71,16 @@ class UserQAFeedback(BaseFeedback):
|
|||
},
|
||||
)
|
||||
)
|
||||
to_node_ids.append(str(target_id_2))
|
||||
|
||||
|
||||
if len(relationships) > 0:
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.add_edges(relationships)
|
||||
await graph_engine.apply_feedback_weight(
|
||||
node_ids=to_node_ids,
|
||||
weight=feedback_sentiment.score
|
||||
)
|
||||
|
||||
|
||||
return [feedback_text]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue