<!-- .github/pull_request_template.md --> ## Description This PR fixes the distributed pipeline and updates core changes in distributed logic. ## Type of Change <!-- Please check the relevant option --> - [x] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [x] Performance improvement - [ ] Other (please specify): ## Changes Made Fixes the distributed pipeline: - Changed spawning logic and adds incremental loading to run_tasks_distributed - Adds batching to consumer nodes - Fixes consumer stopping criteria by adding a stop signal + handling - Changed the edge-embedding solution to avoid huge network load in the case of a multi-container environment ## Testing Tested it by running 1GB on Modal + manually ## Screenshots/Videos (if applicable) None ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## Related Issues None ## Additional Notes None ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Boris <boris@topoteretes.com> Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
import asyncio
|
|
from typing import Type, List, Optional
|
|
from pydantic import BaseModel
|
|
|
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
|
from cognee.tasks.storage import index_graph_edges
|
|
from cognee.tasks.storage.add_data_points import add_data_points
|
|
from cognee.modules.ontology.ontology_config import Config
|
|
from cognee.modules.ontology.get_default_ontology_resolver import (
|
|
get_default_ontology_resolver,
|
|
get_ontology_resolver_from_env,
|
|
)
|
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
|
from cognee.modules.graph.utils import (
|
|
expand_with_nodes_and_edges,
|
|
retrieve_existing_edges,
|
|
)
|
|
from cognee.shared.data_models import KnowledgeGraph
|
|
from cognee.infrastructure.llm.extraction import extract_content_graph
|
|
from cognee.tasks.graph.exceptions import (
|
|
InvalidGraphModelError,
|
|
InvalidDataChunksError,
|
|
InvalidChunkGraphInputError,
|
|
InvalidOntologyAdapterError,
|
|
)
|
|
|
|
|
|
async def integrate_chunk_graphs(
    data_chunks: list[DocumentChunk],
    chunk_graphs: list,
    graph_model: Type[BaseModel],
    ontology_resolver: BaseOntologyResolver,
) -> List[DocumentChunk]:
    """Integrate chunk graphs with ontology validation and store in databases.

    This function processes document chunks and their associated knowledge graphs,
    validates entities against an ontology resolver, and stores the integrated
    data points and edges in the configured databases.

    Args:
        data_chunks: List of document chunks containing source data
        chunk_graphs: List of knowledge graphs corresponding to each chunk
        graph_model: Pydantic model class for graph data validation
        ontology_resolver: Resolver for validating entities against ontology

    Returns:
        List of updated DocumentChunk objects with integrated data

    Raises:
        InvalidChunkGraphInputError: If input validation fails
        InvalidGraphModelError: If graph model validation fails
        InvalidOntologyAdapterError: If ontology resolver validation fails
    """
    if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
        raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
    if len(data_chunks) != len(chunk_graphs):
        raise InvalidChunkGraphInputError(
            f"length mismatch: {len(data_chunks)} chunks vs {len(chunk_graphs)} graphs."
        )
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)
    # Duck-typed check: any object exposing get_subgraph is accepted as a resolver.
    if ontology_resolver is None or not hasattr(ontology_resolver, "get_subgraph"):
        raise InvalidOntologyAdapterError(
            type(ontology_resolver).__name__ if ontology_resolver else "None"
        )

    # Custom (non-KnowledgeGraph) models skip ontology expansion entirely:
    # attach the raw graph to each chunk and return without any database work.
    if graph_model is not KnowledgeGraph:
        for chunk_index, chunk_graph in enumerate(chunk_graphs):
            data_chunks[chunk_index].contains = chunk_graph
        return data_chunks

    # Fetch the graph engine only after the early return above — the original
    # acquired it unconditionally, paying for a database connection the
    # custom-model path never used.
    graph_engine = await get_graph_engine()

    existing_edges_map = await retrieve_existing_edges(
        data_chunks,
        chunk_graphs,
    )

    graph_nodes, graph_edges = expand_with_nodes_and_edges(
        data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
    )

    if graph_nodes:
        await add_data_points(graph_nodes)

    if graph_edges:
        await graph_engine.add_edges(graph_edges)
        # Keep the edge embedding index in sync with the newly stored edges.
        await index_graph_edges(graph_edges)

    return data_chunks
|
|
|
|
|
|
def _resolve_ontology_resolver(config: Optional[Config]) -> BaseOntologyResolver:
    """Return the ontology resolver from *config*, or build one from the environment.

    When *config* is None, a resolver is constructed from environment settings if
    all three ontology variables are present; otherwise the default resolver is used.
    """
    if config is not None:
        return config["ontology_config"]["ontology_resolver"]

    ontology_config = get_ontology_env_config()
    # An env-driven resolver needs all three settings; fall back to the default
    # resolver when any of them is missing.
    if (
        ontology_config.ontology_file_path
        and ontology_config.ontology_resolver
        and ontology_config.matching_strategy
    ):
        return get_ontology_resolver_from_env(**ontology_config.to_dict())
    return get_default_ontology_resolver()


async def extract_graph_from_data(
    data_chunks: List[DocumentChunk],
    graph_model: Type[BaseModel],
    config: Config = None,
    custom_prompt: Optional[str] = None,
) -> List[DocumentChunk]:
    """Extract and integrate a knowledge graph from the text of document chunks.

    Runs LLM graph extraction concurrently over all chunks, prunes dangling
    edges, then delegates storage/integration to ``integrate_chunk_graphs``.

    Args:
        data_chunks: Non-empty list of chunks, each with a ``text`` attribute.
        graph_model: Pydantic model class the extracted graphs conform to.
        config: Optional config carrying a pre-built ontology resolver; when
            None the resolver is derived from the environment.
        custom_prompt: Optional prompt override passed to the extractor.

    Returns:
        The updated DocumentChunk objects with integrated graph data.

    Raises:
        InvalidDataChunksError: If data_chunks is empty or malformed.
        InvalidGraphModelError: If graph_model is not a BaseModel subclass.
    """
    if not isinstance(data_chunks, list) or not data_chunks:
        raise InvalidDataChunksError("must be a non-empty list of DocumentChunk.")
    if not all(hasattr(c, "text") for c in data_chunks):
        raise InvalidDataChunksError("each chunk must have a 'text' attribute")
    if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
        raise InvalidGraphModelError(graph_model)

    # Extract one graph per chunk concurrently; results are ordered like data_chunks.
    chunk_graphs = await asyncio.gather(
        *[
            extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
            for chunk in data_chunks
        ]
    )

    # Drop edges whose endpoints were never emitted as nodes — LLM output can
    # reference node ids it did not produce. Identity check (`is`) matches the
    # convention in integrate_chunk_graphs; the original used `==`.
    if graph_model is KnowledgeGraph:
        for graph in chunk_graphs:
            valid_node_ids = {node.id for node in graph.nodes}
            graph.edges = [
                edge
                for edge in graph.edges
                if edge.source_node_id in valid_node_ids
                and edge.target_node_id in valid_node_ids
            ]

    ontology_resolver = _resolve_ontology_resolver(config)

    return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver)
|