cognee/cognee/modules/data/extraction/knowledge_graph/expand_knowledge_graph.py

import json
import asyncio
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime, timezone
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .extract_knowledge_graph import extract_content_graph
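

# Payload schema for entity nodes mirrored into the vector collection.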
class EntityNode(BaseModel):
    uuid: str
    name: str
    type: str
    description: str
    created_at: datetime
    updated_at: datetime


async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
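    """Extract a knowledge graph from each chunk and merge the new nodes and edges
    into the graph store, mirroring entities and entity types into the vector
    collection. Returns the chunks unchanged so downstream steps can keep using them.
    """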
    chunk_graphs = await asyncio.gather(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )
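
    # Resolve storage engines and make sure the vector collection for entities exists.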
    vector_engine = get_vector_engine()
    graph_engine = await get_graph_engine()

    has_collection = await vector_engine.has_collection(collection_name)

    if not has_collection:
        await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
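
    # First pass: collect every candidate edge so existing ones can be filtered
    # out with a single has_edges call against the graph store.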
    processed_nodes = {}
    type_node_edges = []
    entity_node_edges = []
    type_entity_edges = []
    graph_node_edges = []

    for (chunk_index, chunk) in enumerate(data_chunks):
        chunk_graph = chunk_graphs[chunk_index]

        if chunk_graph is None:
            continue

        for node in chunk_graph.nodes:
            type_node_id = generate_node_id(node.type)
            entity_node_id = generate_node_id(node.id)

            if type_node_id not in processed_nodes:
                type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
                processed_nodes[type_node_id] = True

            if entity_node_id not in processed_nodes:
                entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
                type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
                processed_nodes[entity_node_id] = True

        graph_node_edges += [
            (edge.source_node_id, edge.target_node_id, edge.relationship_name)
            for edge in chunk_graph.edges
        ]
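
    # Ask the graph store which candidate edges already exist and index them
    # (and their source nodes) for constant-time duplicate checks below.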
    existing_edges = await graph_engine.has_edges([
        *type_node_edges,
        *entity_node_edges,
        *type_entity_edges,
        *graph_node_edges,
    ])

    existing_edges_map = {}
    existing_nodes_map = {}

    for edge in existing_edges:
        existing_edges_map[edge[0] + edge[1] + edge[2]] = True
        existing_nodes_map[edge[0]] = True
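
    # Second pass: build the graph nodes/edges and vector data points that are
    # not already present.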
    graph_nodes = []
    graph_edges = []
    data_points = []

    for (chunk_index, chunk) in enumerate(data_chunks):
        graph = chunk_graphs[chunk_index]

        if graph is None:
            continue

        for node in graph.nodes:
            node_id = generate_node_id(node.id)
            node_name = generate_name(node.name)
            type_node_id = generate_node_id(node.type)
            type_node_name = generate_name(node.type)
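
            # Create the entity node once: add it to the graph and embed its
            # name into the vector collection.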
            if node_id not in existing_nodes_map:
                node_data = dict(
                    uuid = node_id,
                    name = node_name,
                    type = node_name,
                    description = node.description,
                    created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                    updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                )

                graph_nodes.append((
                    node_id,
                    dict(
                        **node_data,
                        properties = json.dumps(node.properties),
                    )
                ))

                data_points.append(DataPoint[EntityNode](
                    id = str(uuid5(NAMESPACE_OID, node_id)),
                    payload = node_data,
                    embed_field = "name",
                ))

                existing_nodes_map[node_id] = True
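
            # Link the chunk to the entity it mentions.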
            edge_key = str(chunk.chunk_id) + node_id + "contains_entity"

            if edge_key not in existing_edges_map:
                graph_edges.append((
                    str(chunk.chunk_id),
                    node_id,
                    "contains_entity",
                    dict(
                        relationship_name = "contains_entity",
                        source_node_id = str(chunk.chunk_id),
                        target_node_id = node_id,
                    ),
                ))

                # Add relationship between the entity and its type: "Jake is Person".
                graph_edges.append((
                    node_id,
                    type_node_id,
                    "is_entity_type",
                    dict(
                        relationship_name = "is_entity_type",
                        source_node_id = node_id,
                        target_node_id = type_node_id,
                    ),
                ))

                existing_edges_map[edge_key] = True
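
            # Create the entity type node once, mirroring it into the vector
            # collection as well.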
            if type_node_id not in existing_nodes_map:
                type_node_data = dict(
                    uuid = type_node_id,
                    name = type_node_name,
                    type = type_node_id,
                    description = type_node_name,
                    created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                    updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                )

                graph_nodes.append((type_node_id, dict(
                    **type_node_data,
                    properties = json.dumps(node.properties),
                )))

                data_points.append(DataPoint[EntityNode](
                    id = str(uuid5(NAMESPACE_OID, type_node_id)),
                    payload = type_node_data,
                    embed_field = "name",
                ))

                existing_nodes_map[type_node_id] = True
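
            # Link the chunk to the entity type it mentions.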
            edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"

            if edge_key not in existing_edges_map:
                graph_edges.append((
                    str(chunk.chunk_id),
                    type_node_id,
                    "contains_entity_type",
                    dict(
                        relationship_name = "contains_entity_type",
                        source_node_id = str(chunk.chunk_id),
                        target_node_id = type_node_id,
                    ),
                ))

                existing_edges_map[edge_key] = True
        # Add the relationships that came from the extracted graphs.
        for edge in graph.edges:
            source_node_id = generate_node_id(edge.source_node_id)
            target_node_id = generate_node_id(edge.target_node_id)
            relationship_name = generate_name(edge.relationship_name)

            edge_key = source_node_id + target_node_id + relationship_name

            if edge_key not in existing_edges_map:
                graph_edges.append((
                    source_node_id,
                    target_node_id,
                    edge.relationship_name,
                    dict(
                        relationship_name = relationship_name,
                        source_node_id = source_node_id,
                        target_node_id = target_node_id,
                        properties = json.dumps(edge.properties),
                    ),
                ))

                existing_edges_map[edge_key] = True
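
    # Persist everything in batches; skip empty writes.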
    if len(data_points) > 0:
        await vector_engine.create_data_points(collection_name, data_points)

    if len(graph_nodes) > 0:
        await graph_engine.add_nodes(graph_nodes)

    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)

    return data_chunks
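

# Normalization helpers: lowercase, snake_case names and ids so the same
# entity extracted from different chunks maps onto the same graph node.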
def generate_name(name: str) -> str:
    return name.lower().replace(" ", "_").replace("'", "")


def generate_node_id(node_id: str) -> str:
    return node_id.lower().replace(" ", "_").replace("'", "")