cognee/cognee/tasks/graph/extract_graph_from_data.py
2024-11-12 09:01:01 +01:00

121 lines
4.5 KiB
Python

import asyncio
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models import EntityType, Entity
from cognee.modules.engine.utils import generate_node_id, generate_node_name
from cognee.tasks.storage import add_data_points
async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []
for (chunk_index, chunk) in enumerate(data_chunks):
chunk_graph = chunk_graphs[chunk_index]
for node in chunk_graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)
if str(type_node_id) not in processed_nodes:
type_node_edges.append((str(chunk.id), str(type_node_id), "exists_in"))
processed_nodes[str(type_node_id)] = True
if str(entity_node_id) not in processed_nodes:
entity_node_edges.append((str(chunk.id), entity_node_id, "mentioned_in"))
type_entity_edges.append((str(entity_node_id), str(type_node_id), "is_a"))
processed_nodes[str(entity_node_id)] = True
graph_node_edges = [
(edge.target_node_id, edge.source_node_id, edge.relationship_name) \
for edge in chunk_graph.edges
]
graph_engine = await get_graph_engine()
existing_edges = await graph_engine.has_edges([
*type_node_edges,
*entity_node_edges,
*type_entity_edges,
*graph_node_edges,
])
existing_edges_map = {}
for edge in existing_edges:
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
added_nodes_map = {}
graph_edges = []
data_points = []
for (chunk_index, chunk) in enumerate(data_chunks):
graph = chunk_graphs[chunk_index]
if graph is None:
continue
for node in graph.nodes:
node_id = generate_node_id(node.id)
node_name = generate_node_name(node.name)
type_node_id = generate_node_id(node.type)
type_node_name = generate_node_name(node.type)
if f"{str(type_node_id)}_type" not in added_nodes_map:
type_node = EntityType(
id = type_node_id,
name = type_node_name,
type = type_node_name,
description = type_node_name,
exists_in = chunk,
)
added_nodes_map[f"{str(type_node_id)}_type"] = type_node
else:
type_node = added_nodes_map[f"{str(type_node_id)}_type"]
if f"{str(node_id)}_entity" not in added_nodes_map:
entity_node = Entity(
id = node_id,
name = node_name,
is_a = type_node,
description = node.description,
mentioned_in = chunk,
)
data_points.append(entity_node)
added_nodes_map[f"{str(node_id)}_entity"] = entity_node
# Add relationship that came from graphs.
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_node_name(edge.relationship_name)
edge_key = str(source_node_id) + str(target_node_id) + relationship_name
if edge_key not in existing_edges_map:
graph_edges.append((
source_node_id,
target_node_id,
edge.relationship_name,
dict(
relationship_name = generate_node_name(edge.relationship_name),
source_node_id = source_node_id,
target_node_id = target_node_id,
),
))
existing_edges_map[edge_key] = True
if len(data_points) > 0:
await add_data_points(data_points)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
return data_chunks