feat: add input checks for extract graph from data
This commit is contained in:
parent
df3a3df117
commit
9f965c44b4
3 changed files with 73 additions and 1 deletions
12
cognee/tasks/graph/exceptions/__init__.py
Normal file
12
cognee/tasks/graph/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""
|
||||||
|
Custom exceptions for the Cognee API.
|
||||||
|
|
||||||
|
This module defines a set of exceptions for handling various data errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
InvalidDataChunksError,
|
||||||
|
InvalidGraphModelError,
|
||||||
|
InvalidOntologyAdapterError,
|
||||||
|
InvalidChunkGraphInputError
|
||||||
|
)
|
||||||
37
cognee/tasks/graph/exceptions/exceptions.py
Normal file
37
cognee/tasks/graph/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
from cognee.exceptions import (
|
||||||
|
CogneeValidationError,
|
||||||
|
CogneeConfigurationError,
|
||||||
|
)
|
||||||
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidDataChunksError(CogneeValidationError):
    """Validation error raised for a malformed ``data_chunks`` argument."""

    def __init__(self, detail: str):
        # Fold the caller-supplied detail into one uniform client-facing message.
        message = f"Invalid data_chunks: {detail}"
        super().__init__(
            message=message,
            name="InvalidDataChunksError",
            status_code=status.HTTP_400_BAD_REQUEST,
        )
|
||||||
|
|
||||||
|
class InvalidGraphModelError(CogneeValidationError):
    """Validation error raised when ``graph_model`` is not a BaseModel subclass."""

    def __init__(self, got):
        # ``got`` is whatever object the caller passed; it is interpolated verbatim.
        super().__init__(
            name="InvalidGraphModelError",
            message=f"graph_model must be a subclass of BaseModel (got {got}).",
            status_code=status.HTTP_400_BAD_REQUEST,
        )
|
||||||
|
|
||||||
|
class InvalidOntologyAdapterError(CogneeConfigurationError):
    """Configuration error raised when the ontology adapter is unusable.

    Reported as HTTP 500 because a bad adapter is a server-side setup
    problem, not a client-input problem (unlike the 400 validation errors
    in this module).
    """

    def __init__(self, got):
        # ``got`` is a type name: call sites pass type(adapter).__name__ (or "None").
        super().__init__(
            message=f"ontology_adapter lacks required interface (got {got}).",
            name="InvalidOntologyAdapterError",
            # Closing paren and trailing comma normalized to match the sibling
            # exception classes' multi-line call style.
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidChunkGraphInputError(CogneeValidationError):
    """Validation error raised for mismatched or malformed chunk/graph inputs.

    Call sites raise this when ``data_chunks``/``chunk_graphs`` are not lists
    or their lengths disagree.
    """

    def __init__(self, detail: str):
        super().__init__(
            message=f"Invalid chunk inputs or LLM Chunkgraphs: {detail}",
            name="InvalidChunkGraphInputError",
            # Closing paren and trailing comma normalized to match the sibling
            # exception classes' multi-line call style.
            status_code=status.HTTP_400_BAD_REQUEST,
        )
|
||||||
|
|
@ -12,7 +12,12 @@ from cognee.modules.graph.utils import (
|
||||||
)
|
)
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
from cognee.tasks.graph.exceptions import (
|
||||||
|
InvalidGraphModelError,
|
||||||
|
InvalidDataChunksError,
|
||||||
|
InvalidChunkGraphInputError,
|
||||||
|
InvalidOntologyAdapterError,
|
||||||
|
)
|
||||||
|
|
||||||
async def integrate_chunk_graphs(
|
async def integrate_chunk_graphs(
|
||||||
data_chunks: list[DocumentChunk],
|
data_chunks: list[DocumentChunk],
|
||||||
|
|
@ -21,6 +26,16 @@ async def integrate_chunk_graphs(
|
||||||
ontology_adapter: OntologyResolver,
|
ontology_adapter: OntologyResolver,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""Updates DocumentChunk objects, integrates data points and edges into databases."""
|
"""Updates DocumentChunk objects, integrates data points and edges into databases."""
|
||||||
|
|
||||||
|
if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
|
||||||
|
raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
|
||||||
|
if len(data_chunks) != len(chunk_graphs):
|
||||||
|
raise InvalidChunkGraphInputError(f"length mismatch: {len(data_chunks)} chunks vs {len(chunk_graphs)} graphs.")
|
||||||
|
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
|
||||||
|
raise InvalidGraphModelError(graph_model)
|
||||||
|
if ontology_adapter is None or not hasattr(ontology_adapter, "get_subgraph"):
|
||||||
|
raise InvalidOntologyAdapterError(type(ontology_adapter).__name__ if ontology_adapter else "None")
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
if graph_model is not KnowledgeGraph:
|
if graph_model is not KnowledgeGraph:
|
||||||
|
|
@ -55,6 +70,14 @@ async def extract_graph_from_data(
|
||||||
"""
|
"""
|
||||||
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
|
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not isinstance(data_chunks, list) or not data_chunks:
|
||||||
|
raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")
|
||||||
|
if not all(hasattr(c, "text") for c in data_chunks):
|
||||||
|
raise InvalidDataChunksError("each chunk must have a 'text' attribute")
|
||||||
|
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
|
||||||
|
raise InvalidGraphModelError(graph_model)
|
||||||
|
|
||||||
chunk_graphs = await asyncio.gather(
|
chunk_graphs = await asyncio.gather(
|
||||||
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue