adds first get_triplets_from_graph_store_task to the postprocessing pipeline
parent 023f98b33e
commit ad07e681b8
1 changed file with 42 additions and 6 deletions
@@ -1,17 +1,53 @@
 from typing import Any

+from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
 from cognee.modules.users.methods import get_default_user
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.pipelines.tasks.task import Task
-
+import json

 logger = get_logger("triplet_embedding_poc")


-async def get_triplets_from_graph_store(data) -> Any:
-    for i in range(0, 5):
-        yield i
+def extract_node_data(node_dict):
+    """Extract relevant data from a node dictionary."""
+    result = {"id": node_dict["id"]}
+    if "metadata" not in node_dict:
+        return result
+    metadata = json.loads(node_dict["metadata"])
+    if "index_fields" not in metadata or not metadata["index_fields"]:
+        return result
+    index_field_name = metadata["index_fields"][0]  # Always one entry
+    if index_field_name not in node_dict:
+        return result
+    result["content"] = node_dict[index_field_name]
+    result["index_field_name"] = index_field_name
+    return result
+
+
+async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
+    graph_engine = await get_graph_engine()
+    offset = 0
+    while True:
+        query = f"""
+        MATCH (start_node)-[relationship]->(end_node)
+        RETURN start_node, relationship, end_node
+        SKIP {offset} LIMIT {triplets_batch_size}
+        """
+        results = await graph_engine.query(query=query)
+        if not results:
+            break
+        payload = [
+            {
+                "start_node": extract_node_data(result["start_node"]),
+                "relationship": result["relationship"][1],
+                "end_node": extract_node_data(result["end_node"]),
+            }
+            for result in results
+        ]
+        yield payload
+        offset += triplets_batch_size


 async def add_triplets_to_collection(data) -> None:
@@ -20,7 +56,7 @@ async def add_triplets_to_collection(data) -> None:

 async def get_triplet_embedding_tasks() -> list[Task]:
     triplet_embedding_tasks = [
-        Task(get_triplets_from_graph_store),
+        Task(get_triplets_from_graph_store, triplets_batch_size=100),
         Task(add_triplets_to_collection),
     ]

@@ -31,4 +67,4 @@ async def triplet_embedding_postprocessing():
     tasks = await get_triplet_embedding_tasks()

     async for result in run_tasks_base(tasks, user=await get_default_user(), data=[]):
-        print(result)
+        pass
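
For reference, a minimal sketch of how the new extract_node_data helper behaves. The function body is copied from the diff above; the sample node dict (its id, its name field, and its metadata payload) is purely hypothetical, chosen only to show the path where metadata carries an index_fields entry.

import json


def extract_node_data(node_dict):
    """Extract relevant data from a node dictionary."""
    result = {"id": node_dict["id"]}
    if "metadata" not in node_dict:
        return result
    metadata = json.loads(node_dict["metadata"])
    if "index_fields" not in metadata or not metadata["index_fields"]:
        return result
    index_field_name = metadata["index_fields"][0]  # Always one entry
    if index_field_name not in node_dict:
        return result
    result["content"] = node_dict[index_field_name]
    result["index_field_name"] = index_field_name
    return result


# Hypothetical node row as a graph query might return it: the "metadata"
# value is a JSON string whose index_fields names the property to embed.
node = {
    "id": "node-1",
    "name": "Alan Turing",
    "metadata": json.dumps({"index_fields": ["name"]}),
}

print(extract_node_data(node))
# -> {'id': 'node-1', 'content': 'Alan Turing', 'index_field_name': 'name'}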
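
And a self-contained sketch of the offset/limit batching pattern that get_triplets_from_graph_store uses: page through the store, yield one batch at a time, stop when a page comes back empty. The in-memory FAKE_TRIPLETS list and the fake_query helper are stand-ins assumed only for illustration; they are not cognee's graph engine or its Cypher query.

import asyncio

# Hypothetical in-memory stand-in for the graph store's triplet rows.
FAKE_TRIPLETS = [
    {"start_node": {"id": f"n{i}"}, "relationship": "related_to", "end_node": {"id": f"n{i + 1}"}}
    for i in range(25)
]


async def fake_query(offset: int, limit: int) -> list[dict]:
    # Stand-in for graph_engine.query(...) with SKIP/LIMIT semantics.
    return FAKE_TRIPLETS[offset:offset + limit]


async def get_triplets(triplets_batch_size: int = 10):
    # Same control flow as get_triplets_from_graph_store in the diff:
    # fetch a page, stop on an empty result, otherwise yield and advance.
    offset = 0
    while True:
        batch = await fake_query(offset, triplets_batch_size)
        if not batch:
            break
        yield batch
        offset += triplets_batch_size


async def main():
    async for batch in get_triplets(triplets_batch_size=10):
        print(f"received a batch of {len(batch)} triplets")  # 10, 10, 5


asyncio.run(main())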