import asyncio
from typing import Type, List, Optional
from pydantic import BaseModel

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (
    get_default_ontology_resolver,
    get_ontology_resolver_from_env,
)
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.graph.utils import (
    expand_with_nodes_and_edges,
    retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm.extraction import extract_content_graph
from cognee.tasks.graph.exceptions import (
    InvalidGraphModelError,
    InvalidDataChunksError,
    InvalidChunkGraphInputError,
    InvalidOntologyAdapterError,
)


async def integrate_chunk_graphs(
    data_chunks: list[DocumentChunk],
    chunk_graphs: list,
    graph_model: Type[BaseModel],
    ontology_resolver: BaseOntologyResolver,
) -> List[DocumentChunk]:
    """Integrate per-chunk graphs into the chunks and, for KnowledgeGraph, persist them.

    For a custom ``graph_model`` (anything other than ``KnowledgeGraph``) each
    extracted graph is simply attached to its chunk as ``chunk.contains`` and
    nothing is written to storage.

    For ``KnowledgeGraph`` the chunks and graphs are passed, together with the
    ontology resolver and the map returned by ``retrieve_existing_edges``, to
    ``expand_with_nodes_and_edges``. The resulting nodes are stored with
    ``add_data_points``; the resulting edges are added through the graph
    engine and then indexed with ``index_graph_edges``.

    Args:
        data_chunks: Document chunks, positionally aligned with ``chunk_graphs``.
        chunk_graphs: One extracted graph per chunk (same length as ``data_chunks``).
        graph_model: Pydantic model class the graphs were extracted with.
        ontology_resolver: Resolver forwarded to ``expand_with_nodes_and_edges``;
            must expose a ``get_subgraph`` attribute.

    Returns:
        The same ``data_chunks`` list that was passed in.

    Raises:
        InvalidChunkGraphInputError: If either input is not a list or the
            lengths differ.
        InvalidGraphModelError: If ``graph_model`` is not a ``BaseModel`` subclass.
        InvalidOntologyAdapterError: If ``ontology_resolver`` is ``None`` or has
            no ``get_subgraph`` attribute.
    """
    if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
        raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
    if len(data_chunks) != len(chunk_graphs):
        raise InvalidChunkGraphInputError(
            f"length mismatch: {len(data_chunks)} chunks vs {len(chunk_graphs)} graphs."
        )
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)
    if ontology_resolver is None or not hasattr(ontology_resolver, "get_subgraph"):
        # Compare against None explicitly: a falsy but non-None object (e.g. an
        # empty container) should be reported under its real type name.
        raise InvalidOntologyAdapterError(
            type(ontology_resolver).__name__ if ontology_resolver is not None else "None"
        )

    if graph_model is not KnowledgeGraph:
        # Custom models are not expanded or stored; just attach each graph to
        # its chunk. Lengths were validated above, so zip drops nothing.
        for chunk, chunk_graph in zip(data_chunks, chunk_graphs):
            chunk.contains = chunk_graph
        return data_chunks

    # Only the KnowledgeGraph path writes edges, so the engine is acquired
    # after the early return above rather than unconditionally.
    graph_engine = await get_graph_engine()

    existing_edges_map = await retrieve_existing_edges(
        data_chunks,
        chunk_graphs,
    )

    graph_nodes, graph_edges = expand_with_nodes_and_edges(
        data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
    )

    if len(graph_nodes) > 0:
        await add_data_points(graph_nodes)

    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)
        await index_graph_edges(graph_edges)

    return data_chunks


def _resolve_ontology_resolver(config: Optional[Config]) -> BaseOntologyResolver:
    """Pick the ontology resolver from ``config``, the environment, or the default."""
    if config is not None:
        return config["ontology_config"]["ontology_resolver"]

    env_config = get_ontology_env_config()
    # The environment-based resolver is used only when all three settings are set.
    if (
        env_config.ontology_file_path
        and env_config.ontology_resolver
        and env_config.matching_strategy
    ):
        return get_ontology_resolver_from_env(**env_config.to_dict())

    return get_default_ontology_resolver()


async def extract_graph_from_data(
    data_chunks: List[DocumentChunk],
    graph_model: Type[BaseModel],
    config: Optional[Config] = None,
    custom_prompt: Optional[str] = None,
) -> List[DocumentChunk]:
    """
    Extracts and integrates a knowledge graph from the text content of document chunks
    using a specified graph model.

    A graph is extracted concurrently for every chunk via ``extract_content_graph``.
    When ``graph_model`` is ``KnowledgeGraph``, edges whose source or target node id
    is not among the graph's own nodes are removed. The graphs are then handed to
    ``integrate_chunk_graphs`` together with the selected ontology resolver.

    Args:
        data_chunks: Non-empty list of chunks; each must have a ``text`` attribute.
        graph_model: Pydantic model class describing the graph to extract.
        config: Optional configuration; when given, the resolver is read from
            ``config["ontology_config"]["ontology_resolver"]``. When ``None``, the
            resolver comes from the ontology environment settings if they are all
            set, otherwise the default resolver is used.
        custom_prompt: Optional prompt forwarded to ``extract_content_graph``.

    Returns:
        The result of ``integrate_chunk_graphs`` for the given chunks.

    Raises:
        InvalidDataChunksError: If ``data_chunks`` is not a non-empty list or a
            chunk lacks a ``text`` attribute.
        InvalidGraphModelError: If ``graph_model`` is not a ``BaseModel`` subclass.
    """
    if not isinstance(data_chunks, list) or not data_chunks:
        raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")
    if not all(hasattr(c, "text") for c in data_chunks):
        raise InvalidDataChunksError("each chunk must have a 'text' attribute")
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)

    chunk_graphs = await asyncio.gather(
        *[
            extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
            for chunk in data_chunks
        ]
    )

    # Note: Filter edges with missing source or target nodes
    if graph_model is KnowledgeGraph:
        for graph in chunk_graphs:
            valid_node_ids = {node.id for node in graph.nodes}
            graph.edges = [
                edge
                for edge in graph.edges
                if edge.source_node_id in valid_node_ids
                and edge.target_node_id in valid_node_ids
            ]

    ontology_resolver = _resolve_ontology_resolver(config)

    return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver)