Adds ID generation and logging to triplet embedding postprocessing
This commit is contained in:
parent
e0294b38ff
commit
0163a7b6f6
1 changed file with 21 additions and 1 deletion
|
|
@@ -29,7 +29,21 @@ def create_triplet_data_point(triplet: dict) -> "TripletDataPoint":
|
|||
else:
|
||||
end_node_string = ""
|
||||
|
||||
triplet_str = start_node_string + " " + relationship + " " + end_node_string
|
||||
start_node_type = triplet.get("start_node_type", "")
|
||||
end_node_type = triplet.get("end_node_type", "")
|
||||
|
||||
triplet_str = (
|
||||
start_node_string
|
||||
+ "-"
|
||||
+ start_node_type
|
||||
+ "-"
|
||||
+ relationship
|
||||
+ "-"
|
||||
+ end_node_string
|
||||
+ "-"
|
||||
+ end_node_type
|
||||
)
|
||||
|
||||
triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str)
|
||||
|
||||
return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str)
|
||||
|
|
@@ -61,6 +75,7 @@ def extract_node_data(node_dict):
|
|||
|
||||
async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
|
||||
graph_engine = await get_graph_engine()
|
||||
counter = 0
|
||||
offset = 0
|
||||
while True:
|
||||
query = f"""
|
||||
|
|
@@ -74,11 +89,16 @@ async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
|
|||
payload = [
|
||||
{
|
||||
"start_node": extract_node_data(result["start_node"]),
|
||||
"start_node_type": result["start_node"]["type"],
|
||||
"relationship": result["relationship"][1],
|
||||
"end_node_type": result["end_node"]["type"],
|
||||
"end_node": extract_node_data(result["end_node"]),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
|
||||
counter += len(payload)
|
||||
logger.info("Processed %d triplets", counter)
|
||||
yield payload
|
||||
offset += triplets_batch_size
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue