feat: adds input checks for extract graph from data

hajdul88 2025-08-14 13:53:39 +02:00
parent df3a3df117
commit 9f965c44b4
3 changed files with 73 additions and 1 deletion

View file

@@ -0,0 +1,12 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various data errors
"""
from .exceptions import (
    InvalidDataChunksError,
    InvalidGraphModelError,
    InvalidOntologyAdapterError,
    InvalidChunkGraphInputError,
)

View file

@@ -0,0 +1,37 @@
from cognee.exceptions import (
    CogneeValidationError,
    CogneeConfigurationError,
)
from fastapi import status


class InvalidDataChunksError(CogneeValidationError):
    """Raised when the provided data_chunks do not have the expected structure."""

    def __init__(self, detail: str):
        super().__init__(
            message=f"Invalid data_chunks: {detail}",
            name="InvalidDataChunksError",
            status_code=status.HTTP_400_BAD_REQUEST,
        )


class InvalidGraphModelError(CogneeValidationError):
    """Raised when graph_model is not a subclass of BaseModel."""

    def __init__(self, got):
        super().__init__(
            message=f"graph_model must be a subclass of BaseModel (got {got}).",
            name="InvalidGraphModelError",
            status_code=status.HTTP_400_BAD_REQUEST,
        )


class InvalidOntologyAdapterError(CogneeConfigurationError):
    """Raised when the ontology adapter does not expose the required interface."""

    def __init__(self, got):
        super().__init__(
            message=f"ontology_adapter lacks required interface (got {got}).",
            name="InvalidOntologyAdapterError",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )


class InvalidChunkGraphInputError(CogneeValidationError):
    """Raised when chunk inputs and LLM chunk graphs are malformed or inconsistent."""

    def __init__(self, detail: str):
        super().__init__(
            message=f"Invalid chunk inputs or LLM Chunkgraphs: {detail}",
            name="InvalidChunkGraphInputError",
            status_code=status.HTTP_400_BAD_REQUEST,
        )
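
For orientation only (not part of the diff), a minimal sketch of how calling code might handle these exceptions. It assumes nothing beyond the constructor signatures shown above; since the new classes derive from the Cognee base errors, callers can catch them individually or at the base-class level.

from cognee.exceptions import CogneeValidationError
from cognee.tasks.graph.exceptions import InvalidDataChunksError

def check_inputs(data_chunks):
    # Illustrative guard, mirroring the checks added in this commit.
    if not isinstance(data_chunks, list) or not data_chunks:
        raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")

try:
    check_inputs("not a list")
except InvalidDataChunksError as error:
    print(error)                 # handle the specific validation failure
except CogneeValidationError:
    pass                         # or catch any Cognee validation error here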

View file

@@ -12,7 +12,12 @@ from cognee.modules.graph.utils import (
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.tasks.graph.exceptions import (
    InvalidGraphModelError,
    InvalidDataChunksError,
    InvalidChunkGraphInputError,
    InvalidOntologyAdapterError,
)


async def integrate_chunk_graphs(
    data_chunks: list[DocumentChunk],
@@ -21,6 +26,16 @@ async def integrate_chunk_graphs(
    ontology_adapter: OntologyResolver,
) -> List[DocumentChunk]:
    """Updates DocumentChunk objects, integrates data points and edges into databases."""

    if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
        raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
    if len(data_chunks) != len(chunk_graphs):
        raise InvalidChunkGraphInputError(f"length mismatch: {len(data_chunks)} chunks vs {len(chunk_graphs)} graphs.")
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)
    if ontology_adapter is None or not hasattr(ontology_adapter, "get_subgraph"):
        raise InvalidOntologyAdapterError(type(ontology_adapter).__name__ if ontology_adapter is not None else "None")

    graph_engine = await get_graph_engine()
    if graph_model is not KnowledgeGraph:
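
As an illustration only (not part of the diff), a hypothetical invocation of the new guard in integrate_chunk_graphs: mismatched list lengths raise InvalidChunkGraphInputError before the graph engine is touched. The module path and keyword usage below are assumptions based on the imports in this diff, and the placeholder objects stand in for real DocumentChunk instances.

import asyncio

from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.graph.exceptions import InvalidChunkGraphInputError
from cognee.tasks.graph.extract_graph_from_data import integrate_chunk_graphs  # assumed module path

async def demo_mismatch():
    try:
        await integrate_chunk_graphs(
            data_chunks=[object()],   # placeholder instead of a real DocumentChunk
            chunk_graphs=[],          # deliberately mismatched length
            graph_model=KnowledgeGraph,
            ontology_adapter=None,    # never reached: the length check fails first
        )
    except InvalidChunkGraphInputError as error:
        print(error)                  # log / handle the rejected input

asyncio.run(demo_mismatch())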
@@ -55,6 +70,14 @@ async def extract_graph_from_data(
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
"""
if not isinstance(data_chunks, list) or not data_chunks:
raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")
if not all(hasattr(c, "text") for c in data_chunks):
raise InvalidDataChunksError("each chunk must have a 'text' attribute")
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
raise InvalidGraphModelError(graph_model)
chunk_graphs = await asyncio.gather(
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
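
Finally, a hypothetical sketch of the extract_graph_from_data guards. It assumes the function can be called with just the chunk list and the graph model (any further parameters defaulted) and that the module path matches the imports above; chunks without a text attribute are rejected before any LLM request is made.

import asyncio

from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.graph.exceptions import InvalidDataChunksError
from cognee.tasks.graph.extract_graph_from_data import extract_graph_from_data  # assumed module path

class FakeChunk:
    """Stand-in object that deliberately has no `text` attribute."""

async def demo_bad_chunks():
    try:
        await extract_graph_from_data([FakeChunk()], KnowledgeGraph)
    except InvalidDataChunksError as error:
        print(error)  # log / handle the rejected chunks

asyncio.run(demo_bad_chunks())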