adds triplet embedding task to the pipeline
This commit is contained in:
parent
ad07e681b8
commit
0bbb7e3781
2 changed files with 53 additions and 7 deletions
|
|
@ -162,7 +162,7 @@ Negotiation and Relationship Building
|
|||
|
||||
|
||||
async def main():
|
||||
pre_graph_creation = False
|
||||
pre_graph_creation = True
|
||||
|
||||
if pre_graph_creation:
|
||||
await cognee.prune.prune_data()
|
||||
|
|
|
|||
|
|
@ -1,15 +1,48 @@
|
|||
from typing import Any
|
||||
|
||||
from typing import Any, List
|
||||
import uuid
|
||||
import json
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
import json
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
logger = get_logger("triplet_embedding_poc")
|
||||
|
||||
|
||||
def create_triplet_data_point(triplet: dict) -> "TripletDataPoint":
|
||||
start_node = triplet.get("start_node", None)
|
||||
if start_node:
|
||||
start_node_string = start_node.get("content", None)
|
||||
else:
|
||||
start_node_string = ""
|
||||
|
||||
relationship = triplet.get("relationship", "")
|
||||
|
||||
end_node = triplet.get("end_node", None)
|
||||
if end_node:
|
||||
end_node_string = end_node.get("content", None)
|
||||
else:
|
||||
end_node_string = ""
|
||||
|
||||
triplet_str = start_node_string + " " + relationship + " " + end_node_string
|
||||
triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str)
|
||||
|
||||
return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str)
|
||||
|
||||
|
||||
class TripletDataPoint(DataPoint):
|
||||
"""DataPoint for storing graph triplets with embedded text representation."""
|
||||
|
||||
payload: str
|
||||
text: str
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
def extract_node_data(node_dict):
|
||||
"""Extract relevant data from a node dictionary."""
|
||||
result = {"id": node_dict["id"]}
|
||||
|
|
@ -50,13 +83,26 @@ async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
|
|||
offset += triplets_batch_size
|
||||
|
||||
|
||||
async def add_triplets_to_collection(data) -> None:
|
||||
print(data)
|
||||
async def add_triplets_to_collection(
|
||||
triplets_batch: List[dict], collection_name: str = "Triplets"
|
||||
) -> None:
|
||||
vector_adapter = get_vector_engine()
|
||||
|
||||
for triplet_batch in triplets_batch:
|
||||
data_points = []
|
||||
for triplet in triplet_batch:
|
||||
try:
|
||||
data_point = create_triplet_data_point(triplet)
|
||||
data_points.append(data_point)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Malformed triplet: {triplet}. Error: {e}")
|
||||
|
||||
await vector_adapter.create_data_points(collection_name, data_points)
|
||||
|
||||
|
||||
async def get_triplet_embedding_tasks() -> list[Task]:
|
||||
triplet_embedding_tasks = [
|
||||
Task(get_triplets_from_graph_store, triplets_batch_size=100),
|
||||
Task(get_triplets_from_graph_store, triplets_batch_size=10),
|
||||
Task(add_triplets_to_collection),
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue