diff --git a/cognee/triplet_embedding_poc/triplet_embedding_poc_example.py b/cognee/triplet_embedding_poc/triplet_embedding_poc_example.py index c2e889101..999347c90 100644 --- a/cognee/triplet_embedding_poc/triplet_embedding_poc_example.py +++ b/cognee/triplet_embedding_poc/triplet_embedding_poc_example.py @@ -162,7 +162,7 @@ Negotiation and Relationship Building async def main(): - pre_graph_creation = False + pre_graph_creation = True if pre_graph_creation: await cognee.prune.prune_data() diff --git a/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py b/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py index 89d065e74..bbf6461c2 100644 --- a/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py +++ b/cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py @@ -1,15 +1,48 @@ -from typing import Any - +from typing import Any, List +import uuid +import json from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger from cognee.modules.pipelines.tasks.task import Task -import json +from cognee.infrastructure.engine import DataPoint + logger = get_logger("triplet_embedding_poc") +def create_triplet_data_point(triplet: dict) -> "TripletDataPoint": + start_node = triplet.get("start_node", None) + if start_node: + start_node_string = start_node.get("content", None) + else: + start_node_string = "" + + relationship = triplet.get("relationship", "") + + end_node = triplet.get("end_node", None) + if end_node: + end_node_string = end_node.get("content", None) + else: + end_node_string = "" + + triplet_str = start_node_string + " " + relationship + " " + end_node_string + triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str) + + return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str) + + +class TripletDataPoint(DataPoint): + """DataPoint for storing graph triplets with embedded text representation.""" + + payload: str + text: str + metadata: dict = {"index_fields": ["text"]} + + def extract_node_data(node_dict): """Extract relevant data from a node dictionary.""" result = {"id": node_dict["id"]} @@ -50,13 +83,26 @@ async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any: offset += triplets_batch_size -async def add_triplets_to_collection(data) -> None: - print(data) +async def add_triplets_to_collection( + triplets_batch: List[dict], collection_name: str = "Triplets" +) -> None: + vector_adapter = get_vector_engine() + + for triplet_batch in triplets_batch: + data_points = [] + for triplet in triplet_batch: + try: + data_point = create_triplet_data_point(triplet) + data_points.append(data_point) + except Exception as e: + raise ValueError(f"Malformed triplet: {triplet}. Error: {e}") + + await vector_adapter.create_data_points(collection_name, data_points) async def get_triplet_embedding_tasks() -> list[Task]: triplet_embedding_tasks = [ - Task(get_triplets_from_graph_store, triplets_batch_size=100), + Task(get_triplets_from_graph_store, triplets_batch_size=10), Task(add_triplets_to_collection), ]