adds triplet embedding task to the pipeline

This commit is contained in:
hajdul88 2025-08-03 12:53:33 +02:00
parent ad07e681b8
commit 0bbb7e3781
2 changed files with 53 additions and 7 deletions

View file

@ -162,7 +162,7 @@ Negotiation and Relationship Building
async def main():
pre_graph_creation = False
pre_graph_creation = True
if pre_graph_creation:
await cognee.prune.prune_data()

View file

@ -1,15 +1,48 @@
from typing import Any
from typing import Any, List
import uuid
import json
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
import json
from cognee.infrastructure.engine import DataPoint
# Module-level logger, registered under the name "triplet_embedding_poc".
logger = get_logger("triplet_embedding_poc")
def create_triplet_data_point(triplet: dict) -> "TripletDataPoint":
    """Build a ``TripletDataPoint`` from one graph triplet.

    The triplet is rendered as
    ``"<start content> <relationship> <end content>"``. That string becomes
    the data point's ``text`` and also seeds a deterministic UUIDv5 id, so
    identical triplet text always maps to the same id. The whole triplet is
    kept alongside as a JSON string in ``payload``.

    Args:
        triplet: Mapping with optional ``"start_node"`` / ``"end_node"``
            entries (mappings that may carry a ``"content"`` value) and an
            optional ``"relationship"`` string. Any missing or ``None`` part
            contributes an empty string to the text.

    Returns:
        The data point describing this triplet.

    Raises:
        TypeError: If a present content/relationship value is not a string,
            or the triplet cannot be serialized to JSON.
    """

    def node_text(key: str) -> str:
        node = triplet.get(key)
        if not node:
            return ""
        # A node may exist without a "content" entry; treat that the same
        # way as a missing node instead of failing on ``None + " "``.
        return node.get("content") or ""

    start_node_string = node_text("start_node")
    relationship = triplet.get("relationship") or ""
    end_node_string = node_text("end_node")

    triplet_str = " ".join((start_node_string, relationship, end_node_string))
    # uuid5 is a pure function of the text: re-embedding the same triplet
    # yields the same id rather than a fresh one.
    triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str)

    return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str)
class TripletDataPoint(DataPoint):
    """DataPoint for storing graph triplets with embedded text representation.

    Instances are built by ``create_triplet_data_point``, which fills
    ``payload`` with the JSON-serialized triplet and ``text`` with the
    space-separated "start relationship end" string.
    """

    # JSON string holding the full triplet dict.
    payload: str
    # Plain-text rendering of the triplet.
    text: str
    # NOTE(review): presumably tells the vector engine that "text" is the
    # field to embed/index -- confirm against the DataPoint base class.
    metadata: dict = {"index_fields": ["text"]}
def extract_node_data(node_dict):
"""Extract relevant data from a node dictionary."""
result = {"id": node_dict["id"]}
@ -50,13 +83,26 @@ async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
offset += triplets_batch_size
async def add_triplets_to_collection(data) -> None:
print(data)
async def add_triplets_to_collection(
    triplets_batch: List[Any], collection_name: str = "Triplets"
) -> None:
    """Convert triplets to data points and store them in a vector collection.

    Args:
        triplets_batch: Either a list of batches (each batch a list of
            triplet dicts) or a single flat batch of triplet dicts. A flat
            batch is treated as one batch.
            NOTE(review): the previous annotation was ``List[dict]`` while
            the loop iterated nested batches; both shapes are accepted here
            until the upstream task's output shape is confirmed.
        collection_name: Name of the target collection in the vector store.

    Raises:
        ValueError: If any triplet cannot be turned into a data point. The
            original error is attached as ``__cause__``.
    """
    vector_adapter = get_vector_engine()

    # A flat list of triplet dicts would otherwise be iterated key by key
    # ("start_node", "relationship", ...) and rejected as malformed.
    if isinstance(triplets_batch, list) and triplets_batch and isinstance(triplets_batch[0], dict):
        batches = [triplets_batch]
    else:
        batches = triplets_batch

    for triplet_batch in batches:
        data_points = []
        for triplet in triplet_batch:
            try:
                data_points.append(create_triplet_data_point(triplet))
            except Exception as e:
                # Chain explicitly so the root cause stays in the traceback.
                raise ValueError(f"Malformed triplet: {triplet}. Error: {e}") from e

        if not data_points:
            # Nothing to embed or store for an empty batch.
            continue

        await vector_adapter.create_data_points(collection_name, data_points)
async def get_triplet_embedding_tasks() -> list[Task]:
triplet_embedding_tasks = [
Task(get_triplets_from_graph_store, triplets_batch_size=100),
Task(get_triplets_from_graph_store, triplets_batch_size=10),
Task(add_triplets_to_collection),
]