cognee/cognee/tasks/graph/extract_graph_from_data.py
2025-10-12 22:23:07 +02:00

151 lines
5.6 KiB
Python

import asyncio
from typing import Dict, Type, List, Optional
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (
get_default_ontology_resolver,
get_ontology_resolver_from_env,
)
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.graph.utils import (
expand_with_nodes_and_edges,
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm.extraction import extract_content_graph
from cognee.tasks.graph.exceptions import (
InvalidGraphModelError,
InvalidDataChunksError,
InvalidChunkGraphInputError,
InvalidOntologyAdapterError,
)
async def integrate_chunk_graphs(
    data_chunks: list[DocumentChunk],
    chunk_graphs: list,
    graph_model: Type[BaseModel],
    ontology_resolver: BaseOntologyResolver,
    context: Dict,
) -> List[DocumentChunk]:
    """Integrate chunk graphs with ontology validation and store in databases.

    This function processes document chunks and their associated graphs,
    validates entities against an ontology resolver, and stores the integrated
    data points and edges in the configured databases.

    For graph models other than ``KnowledgeGraph`` nothing is written to the
    databases: each chunk's ``contains`` attribute is set (in place) to its
    graph and the chunks are returned immediately.

    Args:
        data_chunks: List of document chunks containing source data.
        chunk_graphs: List of graphs corresponding, by position, to
            ``data_chunks``; must have the same length.
        graph_model: Pydantic model class the graphs were extracted with.
        ontology_resolver: Resolver used to validate entities against an
            ontology; must expose a ``get_subgraph`` attribute.
        context: Context mapping forwarded unchanged to ``add_data_points``.

    Returns:
        The same ``data_chunks`` list object that was passed in.

    Raises:
        InvalidChunkGraphInputError: If either input is not a list or their
            lengths differ.
        InvalidGraphModelError: If ``graph_model`` is not a ``BaseModel`` subclass.
        InvalidOntologyAdapterError: If ``ontology_resolver`` is None or has no
            ``get_subgraph`` attribute.
    """
    if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
        raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
    if len(data_chunks) != len(chunk_graphs):
        raise InvalidChunkGraphInputError(
            f"length mismatch: {len(data_chunks)} chunks vs {len(chunk_graphs)} graphs."
        )
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)
    if ontology_resolver is None or not hasattr(ontology_resolver, "get_subgraph"):
        # Explicit None check: a resolver object that merely evaluates as falsy
        # must still be reported by its type name, not as "None".
        raise InvalidOntologyAdapterError(
            "None" if ontology_resolver is None else type(ontology_resolver).__name__
        )

    if graph_model is not KnowledgeGraph:
        # Custom graph models are not expanded into nodes/edges here; just
        # attach each graph to its chunk. Lengths were validated above.
        for chunk, chunk_graph in zip(data_chunks, chunk_graphs):
            chunk.contains = chunk_graph
        return data_chunks

    # Acquire the engine only on the path that actually writes edges.
    graph_engine = await get_graph_engine()

    existing_edges_map = await retrieve_existing_edges(data_chunks, chunk_graphs)
    graph_nodes, graph_edges = expand_with_nodes_and_edges(
        data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
    )

    if graph_nodes:
        await add_data_points(graph_nodes, context)
    if graph_edges:
        await graph_engine.add_edges(graph_edges)

    return data_chunks
def _drop_dangling_edges(graph: KnowledgeGraph) -> None:
    """Remove, in place, edges whose source or target id is not a node of ``graph``."""
    valid_node_ids = {node.id for node in graph.nodes}
    graph.edges = [
        edge
        for edge in graph.edges
        if edge.source_node_id in valid_node_ids and edge.target_node_id in valid_node_ids
    ]


def _resolve_ontology_resolver(config: Optional[Config]) -> BaseOntologyResolver:
    """Return the ontology resolver from ``config``, or build one when it is None.

    With no config, a resolver is built from the ontology environment settings
    when the file path, resolver and matching strategy are all set; otherwise
    the default resolver is returned.
    """
    if config is not None:
        return config["ontology_config"]["ontology_resolver"]

    ontology_env = get_ontology_env_config()
    if (
        ontology_env.ontology_file_path
        and ontology_env.ontology_resolver
        and ontology_env.matching_strategy
    ):
        return get_ontology_resolver_from_env(**ontology_env.to_dict())
    return get_default_ontology_resolver()


async def extract_graph_from_data(
    data_chunks: List[DocumentChunk],
    context: Dict,
    graph_model: Type[BaseModel],
    config: Optional[Config] = None,
    custom_prompt: Optional[str] = None,
) -> List[DocumentChunk]:
    """Extract a graph from each chunk's text and integrate the results.

    One ``extract_content_graph`` call per chunk is run concurrently via
    ``asyncio.gather``. For ``KnowledgeGraph`` results, edges referencing a
    node id that is not present in the same graph are dropped before the
    graphs are passed to ``integrate_chunk_graphs``.

    Args:
        data_chunks: Non-empty list of chunks; each must have a ``text`` attribute.
        context: Context mapping forwarded to ``integrate_chunk_graphs``.
        graph_model: Pydantic model class the extractor should produce.
        config: Optional config read as
            ``config["ontology_config"]["ontology_resolver"]``. When None, a
            resolver is built from the ontology environment config if it is
            fully set, otherwise the default resolver is used.
        custom_prompt: Optional prompt forwarded to ``extract_content_graph``.

    Returns:
        The chunk list returned by ``integrate_chunk_graphs``.

    Raises:
        InvalidDataChunksError: If ``data_chunks`` is not a non-empty list or
            any chunk lacks a ``text`` attribute.
        InvalidGraphModelError: If ``graph_model`` is not a ``BaseModel`` subclass.
        Exceptions raised by ``integrate_chunk_graphs`` (e.g.
        ``InvalidOntologyAdapterError``) propagate unchanged.
    """
    if not isinstance(data_chunks, list) or not data_chunks:
        raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")
    if not all(hasattr(chunk, "text") for chunk in data_chunks):
        raise InvalidDataChunksError("each chunk must have a 'text' attribute")
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)

    chunk_graphs = await asyncio.gather(
        *[
            extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
            for chunk in data_chunks
        ]
    )

    # Identity comparison on the class, consistent with integrate_chunk_graphs.
    if graph_model is KnowledgeGraph:
        for graph in chunk_graphs:
            _drop_dangling_edges(graph)

    ontology_resolver = _resolve_ontology_resolver(config)
    return await integrate_chunk_graphs(
        data_chunks, chunk_graphs, graph_model, ontology_resolver, context
    )