diff --git a/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py b/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py index bbf6461c2..6ab9ba4a4 100644 --- a/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py +++ b/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py @@ -29,7 +29,21 @@ def create_triplet_data_point(triplet: dict) -> "TripletDataPoint": else: end_node_string = "" - triplet_str = start_node_string + " " + relationship + " " + end_node_string + start_node_type = triplet.get("start_node_type", "") + end_node_type = triplet.get("end_node_type", "") + + triplet_str = ( + start_node_string + + "-" + + start_node_type + + "-" + + relationship + + "-" + + end_node_string + + "-" + + end_node_type + ) + triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str) return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str) @@ -61,6 +75,7 @@ def extract_node_data(node_dict): async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any: graph_engine = await get_graph_engine() + counter = 0 offset = 0 while True: query = f""" @@ -74,11 +89,16 @@ async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any: payload = [ { "start_node": extract_node_data(result["start_node"]), + "start_node_type": result["start_node"]["type"], "relationship": result["relationship"][1], + "end_node_type": result["end_node"]["type"], "end_node": extract_node_data(result["end_node"]), } for result in results ] + + counter += len(payload) + logger.info("Processed %d triplets", counter) yield payload offset += triplets_batch_size