feat: adds input checks for extract graph from data
This commit is contained in:
parent
df3a3df117
commit
9f965c44b4
3 changed files with 73 additions and 1 deletions
12
cognee/tasks/graph/exceptions/__init__.py
Normal file
12
cognee/tasks/graph/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various data errors
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
InvalidDataChunksError,
|
||||
InvalidGraphModelError,
|
||||
InvalidOntologyAdapterError,
|
||||
InvalidChunkGraphInputError
|
||||
)
|
||||
37
cognee/tasks/graph/exceptions/exceptions.py
Normal file
37
cognee/tasks/graph/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from cognee.exceptions import (
|
||||
CogneeValidationError,
|
||||
CogneeConfigurationError,
|
||||
)
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class InvalidDataChunksError(CogneeValidationError):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
message=f"Invalid data_chunks: {detail}",
|
||||
name="InvalidDataChunksError",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
class InvalidGraphModelError(CogneeValidationError):
|
||||
def __init__(self, got):
|
||||
super().__init__(
|
||||
message=f"graph_model must be a subclass of BaseModel (got {got}).",
|
||||
name="InvalidGraphModelError",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
class InvalidOntologyAdapterError(CogneeConfigurationError):
|
||||
def __init__(self, got):
|
||||
super().__init__(
|
||||
message=f"ontology_adapter lacks required interface (got {got}).",
|
||||
name="InvalidOntologyAdapterError",
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
class InvalidChunkGraphInputError(CogneeValidationError):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
message=f"Invalid chunk inputs or LLM Chunkgraphs: {detail}",
|
||||
name="InvalidChunkGraphInputError",
|
||||
status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
|
@ -12,7 +12,12 @@ from cognee.modules.graph.utils import (
|
|||
)
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
from cognee.tasks.graph.exceptions import (
|
||||
InvalidGraphModelError,
|
||||
InvalidDataChunksError,
|
||||
InvalidChunkGraphInputError,
|
||||
InvalidOntologyAdapterError,
|
||||
)
|
||||
|
||||
async def integrate_chunk_graphs(
|
||||
data_chunks: list[DocumentChunk],
|
||||
|
|
@ -21,6 +26,16 @@ async def integrate_chunk_graphs(
|
|||
ontology_adapter: OntologyResolver,
|
||||
) -> List[DocumentChunk]:
|
||||
"""Updates DocumentChunk objects, integrates data points and edges into databases."""
|
||||
|
||||
if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
|
||||
raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
|
||||
if len(data_chunks) != len(chunk_graphs):
|
||||
raise InvalidChunkGraphInputError(f"length mismatch: {len(data_chunks)} chunks vs {len(chunk_graphs)} graphs.")
|
||||
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
|
||||
raise InvalidGraphModelError(graph_model)
|
||||
if ontology_adapter is None or not hasattr(ontology_adapter, "get_subgraph"):
|
||||
raise InvalidOntologyAdapterError(type(ontology_adapter).__name__ if ontology_adapter else "None")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
if graph_model is not KnowledgeGraph:
|
||||
|
|
@ -55,6 +70,14 @@ async def extract_graph_from_data(
|
|||
"""
|
||||
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
|
||||
"""
|
||||
|
||||
if not isinstance(data_chunks, list) or not data_chunks:
|
||||
raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")
|
||||
if not all(hasattr(c, "text") for c in data_chunks):
|
||||
raise InvalidDataChunksError("each chunk must have a 'text' attribute")
|
||||
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
|
||||
raise InvalidGraphModelError(graph_model)
|
||||
|
||||
chunk_graphs = await asyncio.gather(
|
||||
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue