Compare commits
main...coderabbit, 1 commit (9f6b2dca51)

4 changed files with 1,432 additions and 195 deletions
@@ -1,18 +1,21 @@
-import asyncio
 from pydantic import BaseModel
-from typing import Union, Optional
+from typing import Union, Optional, Type
 from uuid import UUID
+import os
+
+
-from cognee.shared.logging_utils import get_logger
 from cognee.shared.data_models import KnowledgeGraph
-from cognee.infrastructure.llm import get_max_chunk_tokens
+from cognee.infrastructure.llm.utils import get_max_chunk_tokens
+from cognee.shared.logging_utils import get_logger

-from cognee.modules.pipelines import run_pipeline
+from cognee.modules.pipelines.operations.pipeline import run_pipeline
 from cognee.modules.pipelines.tasks.task import Task
 from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
 from cognee.modules.users.models import User

+logger = get_logger()

 from cognee.tasks.documents import (
     check_permissions_on_dataset,
     classify_documents,
@@ -21,179 +24,101 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.tasks.translation import translate_content, get_available_providers, validate_provider
 from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
-from cognee.tasks.temporal_graph.extract_events_and_entities import extract_events_and_timestamps
-from cognee.tasks.temporal_graph.extract_knowledge_graph_from_events import (
-    extract_knowledge_graph_from_events,
-)


-logger = get_logger("cognify")
+class TranslationProviderError(ValueError):
+    """Error related to translation provider initialization."""
+    pass

-update_status_lock = asyncio.Lock()
+class UnknownTranslationProviderError(TranslationProviderError):
+    """Unknown translation provider name."""


+class ProviderInitializationError(TranslationProviderError):
+    """Provider failed to initialize (likely missing dependency or bad config)."""


-async def cognify(
-    datasets: Union[str, list[str], list[UUID]] = None,
-    user: User = None,
-    graph_model: BaseModel = KnowledgeGraph,
+_WARNED_ENV_VARS: set[str] = set()
+
+
+def _parse_batch_env(var: str, default: int = 10) -> int:
+    """
+    Parse an environment variable as a positive integer (minimum 1), falling back to a default.
+
+    If the environment variable named `var` is unset, the provided `default` is returned.
+    If the variable is set but cannot be parsed as an integer, `default` is returned and a
+    one-time warning is logged for that variable (the variable name is recorded in
+    `_WARNED_ENV_VARS` to avoid repeated warnings).
+
+    Parameters:
+        var: Name of the environment variable to read.
+        default: Fallback integer value returned when the variable is missing or invalid.
+
+    Returns:
+        An integer >= 1 representing the parsed value or the fallback `default`.
+    """
+    raw = os.getenv(var)
+    if raw is None:
+        return default
+    try:
+        return max(1, int(raw))
+    except (TypeError, ValueError):
+        if var not in _WARNED_ENV_VARS:
+            logger.warning("Invalid int for %s=%r; using default=%d", var, raw, default)
+            _WARNED_ENV_VARS.add(var)
+        return default
+
+
+# Constants for batch processing
+DEFAULT_BATCH_SIZE = _parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE", 10)
+
+
+async def cognify(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    datasets: Optional[Union[str, UUID, list[str], list[UUID]]] = None,
+    user: Optional[User] = None,
+    graph_model: Type[BaseModel] = KnowledgeGraph,
     chunker=TextChunker,
-    chunk_size: int = None,
+    chunk_size: Optional[int] = None,
     ontology_file_path: Optional[str] = None,
-    vector_db_config: dict = None,
-    graph_db_config: dict = None,
+    vector_db_config: Optional[dict] = None,
+    graph_db_config: Optional[dict] = None,
     run_in_background: bool = False,
     incremental_loading: bool = True,
     custom_prompt: Optional[str] = None,
-    temporal_cognify: bool = False,
 ):
     """
-    Transform ingested data into a structured knowledge graph.
-
-    This is the core processing step in Cognee that converts raw text and documents
-    into an intelligent knowledge graph. It analyzes content, extracts entities and
-    relationships, and creates semantic connections for enhanced search and reasoning.
-
-    Prerequisites:
-    - **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
-    - **Data Added**: Must have data previously added via `cognee.add()`
-    - **Vector Database**: Must be accessible for embeddings storage
-    - **Graph Database**: Must be accessible for relationship storage
-
-    Input Requirements:
-    - **Datasets**: Must contain data previously added via `cognee.add()`
-    - **Content Types**: Works with any text-extractable content including:
-        * Natural language documents
-        * Structured data (CSV, JSON)
-        * Code repositories
-        * Academic papers and technical documentation
-        * Mixed multimedia content (with text extraction)
-
-    Processing Pipeline:
-    1. **Document Classification**: Identifies document types and structures
-    2. **Permission Validation**: Ensures user has processing rights
-    3. **Text Chunking**: Breaks content into semantically meaningful segments
-    4. **Entity Extraction**: Identifies key concepts, people, places, organizations
-    5. **Relationship Detection**: Discovers connections between entities
-    6. **Graph Construction**: Builds semantic knowledge graph with embeddings
-    7. **Content Summarization**: Creates hierarchical summaries for navigation
-
-    Graph Model Customization:
-    The `graph_model` parameter allows custom knowledge structures:
-    - **Default**: General-purpose KnowledgeGraph for any domain
-    - **Custom Models**: Domain-specific schemas (e.g., scientific papers, code analysis)
-    - **Ontology Integration**: Use `ontology_file_path` for predefined vocabularies
-
-    Args:
-        datasets: Dataset name(s) or dataset uuid to process. Processes all available data if None.
-            - Single dataset: "my_dataset"
-            - Multiple datasets: ["docs", "research", "reports"]
-            - None: Process all datasets for the user
-        user: User context for authentication and data access. Uses default if None.
-        graph_model: Pydantic model defining the knowledge graph structure.
-            Defaults to KnowledgeGraph for general-purpose processing.
-        chunker: Text chunking strategy (TextChunker, LangchainChunker).
-            - TextChunker: Paragraph-based chunking (default, most reliable)
-            - LangchainChunker: Recursive character splitting with overlap
-            Determines how documents are segmented for processing.
-        chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
-            Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
-            Default limits: ~512-8192 tokens depending on models.
-            Smaller chunks = more granular but potentially fragmented knowledge.
-        ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
-            Useful for specialized fields like medical or legal documents.
-        vector_db_config: Custom vector database configuration for embeddings storage.
-        graph_db_config: Custom graph database configuration for relationship storage.
-        run_in_background: If True, starts processing asynchronously and returns immediately.
-            If False, waits for completion before returning.
-            Background mode recommended for large datasets (>100MB).
-            Use pipeline_run_id from return value to monitor progress.
-        custom_prompt: Optional custom prompt string to use for entity extraction and graph generation.
-            If provided, this prompt will be used instead of the default prompts for
-            knowledge graph extraction. The prompt should guide the LLM on how to
-            extract entities and relationships from the text content.
+    Orchestrate processing of datasets into a knowledge graph.
+
+    Builds the default Cognify task sequence (classification, permission check, chunking,
+    graph extraction, summarization, indexing) and executes it via the pipeline
+    executor. Use get_default_tasks_with_translation(...) to include an automatic
+    translation step before graph extraction.
+
+    Parameters:
+        datasets: Optional dataset id or list of ids to process. If None, processes all
+            datasets available to the user.
+        user: Optional user context used for permission checks; defaults to the current
+            runtime user if omitted.
+        graph_model: Pydantic model type that defines the structure of produced graph
+            DataPoints (default: KnowledgeGraph).
+        chunker: Chunking strategy/class used to split documents (default: TextChunker).
+        chunk_size: Optional max tokens per chunk; when None a sensible default is used.
+        ontology_file_path: Optional path to an ontology (RDF/OWL) used by the extractor.
+        vector_db_config: Optional mapping of vector DB configuration (overrides defaults).
+        graph_db_config: Optional mapping of graph DB configuration (overrides defaults).
+        run_in_background: If True, starts the pipeline asynchronously and returns
+            background run info; if False, waits for completion and returns results.
+        incremental_loading: If True, performs incremental loading to avoid reprocessing
+            unchanged content.
+        custom_prompt: Optional prompt to override the default prompt used for graph
+            extraction.

     Returns:
-        Union[dict, list[PipelineRunInfo]]:
-        - **Blocking mode**: Dictionary mapping dataset_id -> PipelineRunInfo with:
-            * Processing status (completed/failed/in_progress)
-            * Extracted entity and relationship counts
-            * Processing duration and resource usage
-            * Error details if any failures occurred
-        - **Background mode**: List of PipelineRunInfo objects for tracking progress
-            * Use pipeline_run_id to monitor status
-            * Check completion via pipeline monitoring APIs
-
-    Next Steps:
-    After successful cognify processing, use search functions to query the knowledge:
-
-    ```python
-    import cognee
-    from cognee import SearchType
-
-    # Process your data into knowledge graph
-    await cognee.cognify()
-
-    # Query for insights using different search types:
-
-    # 1. Natural language completion with graph context
-    insights = await cognee.search(
-        "What are the main themes?",
-        query_type=SearchType.GRAPH_COMPLETION
-    )
-
-    # 2. Get entity relationships and connections
-    relationships = await cognee.search(
-        "connections between concepts",
-        query_type=SearchType.INSIGHTS
-    )
-
-    # 3. Find relevant document chunks
-    chunks = await cognee.search(
-        "specific topic",
-        query_type=SearchType.CHUNKS
-    )
-    ```
-
-    Advanced Usage:
-    ```python
-    # Custom domain model for scientific papers
-    class ScientificPaper(DataPoint):
-        title: str
-        authors: List[str]
-        methodology: str
-        findings: List[str]
-
-    await cognee.cognify(
-        datasets=["research_papers"],
-        graph_model=ScientificPaper,
-        ontology_file_path="scientific_ontology.owl"
-    )
-
-    # Background processing for large datasets
-    run_info = await cognee.cognify(
-        datasets=["large_corpus"],
-        run_in_background=True
-    )
-    # Check status later with run_info.pipeline_run_id
-    ```
-
-    Environment Variables:
-    Required:
-    - LLM_API_KEY: API key for your LLM provider
-
-    Optional (same as add function):
-    - LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER
-    - LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
-    - LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
+        The pipeline executor result. In blocking mode this is the pipeline run result
+        (per-dataset run info and status). In background mode this returns information
+        required to track the background run (e.g., pipeline_run_id and submission status).
     """
-    if temporal_cognify:
-        tasks = await get_temporal_tasks(user, chunker, chunk_size)
-    else:
-        tasks = await get_default_tasks(
-            user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
-        )
+    tasks = get_default_tasks(
+        user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
+    )

     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
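The new batch-size constant above is controlled by a single environment variable. Below is a minimal, self-contained sketch of the same parsing rule (standard library only; the module version additionally remembers which variable names it has already warned about so the warning is logged once):

```python
import os


def parse_batch_env(var: str, default: int = 10) -> int:
    # Mirrors _parse_batch_env from the diff: missing or invalid values fall back
    # to the default, and parsed values are clamped to a minimum of 1.
    raw = os.getenv(var)
    if raw is None:
        return default
    try:
        return max(1, int(raw))
    except (TypeError, ValueError):
        return default


os.environ["COGNEE_DEFAULT_BATCH_SIZE"] = "25"
print(parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE"))  # 25
os.environ["COGNEE_DEFAULT_BATCH_SIZE"] = "not-a-number"
print(parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE"))  # 10 (fallback)
```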
@@ -211,20 +136,49 @@ async def cognify(
     )


-async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's comment)
-    user: User = None,
-    graph_model: BaseModel = KnowledgeGraph,
+def get_default_tasks(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    user: Optional[User] = None,
+    graph_model: Type[BaseModel] = KnowledgeGraph,
     chunker=TextChunker,
-    chunk_size: int = None,
+    chunk_size: Optional[int] = None,
     ontology_file_path: Optional[str] = None,
     custom_prompt: Optional[str] = None,
 ) -> list[Task]:
+    """
+    Return the standard, non-translation Task list used by the cognify pipeline.
+
+    This builds the default processing pipeline (no automatic translation) and returns
+    a list of Task objects in execution order:
+    1. classify_documents
+    2. check_permissions_on_dataset (enforces write permission for `user`)
+    3. extract_chunks_from_documents (uses `chunker` and `chunk_size`)
+    4. extract_graph_from_data (uses `graph_model`, optional `ontology_file_path`, and `custom_prompt`)
+    5. summarize_text
+    6. add_data_points
+
+    Notes:
+    - Batch sizes for downstream tasks use the module-level DEFAULT_BATCH_SIZE.
+    - If `chunk_size` is not provided, the token limit from get_max_chunk_tokens() is used.
+
+    Parameters:
+        user: Optional user context used for the permission check.
+        graph_model: Model class used to construct knowledge graph instances.
+        chunker: Chunking strategy or class used to split documents into chunks.
+        chunk_size: Optional max tokens per chunk; if omitted, defaults to get_max_chunk_tokens().
+        ontology_file_path: Optional path to an ontology file passed to the extractor.
+        custom_prompt: Optional custom prompt applied during graph extraction.
+
+    Returns:
+        List[Task]: Ordered list of Task objects for the cognify pipeline (no translation).
+    """
+    # Precompute max_chunk_size for stability
+    max_chunk = chunk_size or get_max_chunk_tokens()
     default_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_dataset, user=user, permissions=["write"]),
         Task(
             extract_chunks_from_documents,
-            max_chunk_size=chunk_size or get_max_chunk_tokens(),
+            max_chunk_size=max_chunk,
             chunker=chunker,
         ),  # Extract text chunks based on the document type.
         Task(
@@ -232,51 +186,92 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             graph_model=graph_model,
             ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
             custom_prompt=custom_prompt,
-            task_config={"batch_size": 10},
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
         ),  # Generate knowledge graphs from the document chunks.
         Task(
             summarize_text,
-            task_config={"batch_size": 10},
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
         ),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
     ]

     return default_tasks


-async def get_temporal_tasks(
-    user: User = None, chunker=TextChunker, chunk_size: int = None
+def get_default_tasks_with_translation(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    user: Optional[User] = None,
+    graph_model: Type[BaseModel] = KnowledgeGraph,
+    chunker=TextChunker,
+    chunk_size: Optional[int] = None,
+    ontology_file_path: Optional[str] = None,
+    custom_prompt: Optional[str] = None,
+    translation_provider: str = "noop",
 ) -> list[Task]:
     """
-    Builds and returns a list of temporal processing tasks to be executed in sequence.
-
-    The pipeline includes:
-    1. Document classification.
-    2. Dataset permission checks (requires "write" access).
-    3. Document chunking with a specified or default chunk size.
-    4. Event and timestamp extraction from chunks.
-    5. Knowledge graph extraction from events.
-    6. Batched insertion of data points.
-
-    Args:
-        user (User, optional): The user requesting task execution, used for permission checks.
-        chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
-        chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
+    Return the default Cognify pipeline task list with an added translation step.
+
+    Constructs the standard processing pipeline (classify -> permission check -> chunk extraction -> translate -> graph extraction -> summarize -> add data points),
+    validates and initializes the named translation provider, and applies module DEFAULT_BATCH_SIZE to downstream batchable tasks.
+
+    Parameters:
+        translation_provider (str): Name of a registered translation provider (case-insensitive). Defaults to `"noop"` which is a no-op provider.

     Returns:
-        list[Task]: A list of Task objects representing the temporal processing pipeline.
+        list[Task]: Ordered Task objects ready to be executed by the pipeline executor.
+
+    Raises:
+        UnknownTranslationProviderError: If the given provider name is not in get_available_providers().
+        ProviderInitializationError: If the provider fails to initialize or validate via validate_provider().
     """
-    temporal_tasks = [
+    # Fail fast on unknown providers (keeps errors close to the API surface)
+    translation_provider = (translation_provider or "noop").strip().lower()
+    # Validate provider using public API
+    if translation_provider not in get_available_providers():
+        available = ", ".join(get_available_providers())
+        logger.error("Unknown provider '%s'. Available: %s", translation_provider, available)
+        raise UnknownTranslationProviderError(f"Unknown provider '{translation_provider}'")
+    # Instantiate to validate dependencies; include provider-specific config errors
+    try:
+        validate_provider(translation_provider)
+    except Exception as e:  # we want to convert provider init errors
+        available = ", ".join(get_available_providers())
+        logger.error(
+            "Provider '%s' failed to initialize (available: %s).",
+            translation_provider,
+            available,
+            exc_info=True,
+        )
+        raise ProviderInitializationError() from e
+
+    # Precompute max_chunk_size for stability
+    max_chunk = chunk_size or get_max_chunk_tokens()
+
+    default_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_dataset, user=user, permissions=["write"]),
         Task(
             extract_chunks_from_documents,
-            max_chunk_size=chunk_size or get_max_chunk_tokens(),
+            max_chunk_size=max_chunk,
             chunker=chunker,
+        ),  # Extract text chunks based on the document type.
+        Task(
+            translate_content,
+            target_language="en",
+            translation_provider=translation_provider,
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
+        ),  # Auto-translate non-English content and attach metadata
+        Task(
+            extract_graph_from_data,
+            graph_model=graph_model,
+            ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
+            custom_prompt=custom_prompt,
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
+        ),  # Generate knowledge graphs from the document chunks.
+        Task(
+            summarize_text,
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
         ),
-        Task(extract_events_and_timestamps, task_config={"chunk_size": 10}),
-        Task(extract_knowledge_graph_from_events),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
     ]

-    return temporal_tasks
+    return default_tasks
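get_default_tasks_with_translation validates the requested provider before any pipeline work starts, using the public helpers that this diff imports from cognee.tasks.translation. A short sketch of the same fail-fast check performed outside the task builder (the provider name here is the built-in "noop"; any other name is only valid if something has registered it):

```python
from cognee.tasks.translation import get_available_providers, validate_provider

requested = "noop"
if requested not in get_available_providers():
    # Same condition the task builder uses before raising UnknownTranslationProviderError.
    raise ValueError(f"Unknown provider '{requested}'")

# Raises if the provider cannot be constructed, e.g. a missing optional dependency;
# the task builder wraps this in ProviderInitializationError.
validate_provider(requested)
```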
cognee/tasks/translation/test_translation.py (new file, 496 lines)
@@ -0,0 +1,496 @@
"""
Unit tests for translation functionality.

Tests cover:
- Translation provider registry and discovery
- Language detection across providers
- Translation functionality
- Error handling and fallbacks
- Model validation and serialization
"""

import pytest  # type: ignore[import-untyped]
from typing import Tuple, Optional, Dict
from pydantic import ValidationError
import cognee.tasks.translation.translate_content as tr

from cognee.tasks.translation.translate_content import (
    translate_content,
    register_translation_provider,
    get_available_providers,
    TranslationProvider,
    NoOpProvider,
    _get_provider,
)
from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata


class TestDetectionError(Exception):  # pylint: disable=too-few-public-methods
    """Test exception for detection failures."""


class TestTranslationError(Exception):  # pylint: disable=too-few-public-methods
    """Test exception for translation failures."""


# Ensure registry isolation across tests using public helpers
@pytest.fixture(autouse=True)
def _restore_registry():
    """
    Pytest fixture that snapshots the translation provider registry and restores it after the test.

    Use to isolate tests that register or modify providers: the current registry state is captured before the test runs, and always restored when the fixture completes (including on exceptions).
    """
    snapshot = tr.snapshot_registry()
    try:
        yield
    finally:
        tr.restore_registry(snapshot)


class MockDocumentChunk:  # pylint: disable=too-few-public-methods
    """Mock document chunk for testing."""

    def __init__(self, text: str, chunk_id: str = "test_chunk", metadata: Optional[Dict] = None):
        """
        Initialize a mock document chunk used in tests.

        Parameters:
            text (str): Chunk text content.
            chunk_id (str): Identifier for the chunk; also used as chunk_index for tests. Defaults to "test_chunk".
            metadata (Optional[Dict]): Optional mapping of metadata values; defaults to an empty dict.
        """
        self.text = text
        self.id = chunk_id
        self.chunk_index = chunk_id
        self.metadata = metadata or {}


class MockTranslationProvider:
    """Mock provider for testing custom provider registration."""

    async def detect_language(self, text: str) -> Tuple[str, float]:
        """
        Detect the language of the given text and return an ISO 639-1 language code with a confidence score.

        This mock implementation uses simple keyword heuristics: returns ("es", 0.95) if the text contains "hola",
        ("fr", 0.90) if it contains "bonjour", and ("en", 0.85) otherwise.

        Parameters:
            text (str): Input text to analyze.

        Returns:
            Tuple[str, float]: A tuple of (language_code, confidence) where language_code is an ISO 639-1 code and
            confidence is a float between 0.0 and 1.0 indicating detection confidence.
        """
        if "hola" in text.lower():
            return "es", 0.95
        if "bonjour" in text.lower():
            return "fr", 0.90
        return "en", 0.85

    async def translate(self, text: str, target_language: str) -> Tuple[str, float]:
        """
        Simulate translating `text` into `target_language` and return a mock translated string with a confidence score.

        If `target_language` is "en", returns the input prefixed with "[MOCK TRANSLATED]" and a confidence of 0.88. For any other target language, returns the original `text` and a confidence of 0.0.

        Parameters:
            text (str): The text to translate.
            target_language (str): The target language code (e.g., "en").

        Returns:
            Tuple[str, float]: A pair of (translated_text, confidence) where confidence is in [0.0, 1.0].
        """
        if target_language == "en":
            return f"[MOCK TRANSLATED] {text}", 0.88
        return text, 0.0


class TestProviderRegistry:
    """Test translation provider registration and discovery."""

    def test_get_available_providers_includes_builtin(self):
        """Test that built-in providers are included in available list."""
        providers = get_available_providers()
        assert "noop" in providers
        assert "langdetect" in providers

    def test_register_custom_provider(self):
        """Test custom provider registration."""
        register_translation_provider("mock", MockTranslationProvider)
        providers = get_available_providers()
        assert "mock" in providers

        # Test provider can be retrieved
        provider = _get_provider("mock")
        assert isinstance(provider, MockTranslationProvider)

    def test_provider_name_normalization(self):
        """Test provider names are normalized to lowercase."""
        register_translation_provider("CUSTOM_PROVIDER", MockTranslationProvider)
        providers = get_available_providers()
        assert "custom_provider" in providers

        # Should be retrievable with different casing
        provider1 = _get_provider("CUSTOM_PROVIDER")
        provider2 = _get_provider("custom_provider")
        assert provider1.__class__ is provider2.__class__

    def test_unknown_provider_raises(self):
        """Test unknown providers raise ValueError."""
        with pytest.raises(ValueError):
            _get_provider("nonexistent_provider")


class TestNoOpProvider:
    """Test NoOp provider functionality."""

    @pytest.mark.asyncio
    async def test_detect_language_ascii(self):
        """Test language detection for ASCII text."""
        provider = NoOpProvider()
        lang, conf = await provider.detect_language("Hello world")
        assert lang is None
        assert conf == 0.0

    @pytest.mark.asyncio
    async def test_detect_language_unicode(self):
        """Test language detection for Unicode text."""
        provider = NoOpProvider()
        lang, conf = await provider.detect_language("Hëllo wörld")
        assert lang is None
        assert conf == 0.0

    @pytest.mark.asyncio
    async def test_translate_returns_original(self):
        """Test translation returns original text with zero confidence."""
        provider = NoOpProvider()
        text = "Test text"
        translated, conf = await provider.translate(text, "es")
        assert translated == text
        assert conf == 0.0


class TestTranslationModels:
    """Test Pydantic models for translation data."""

    def test_translated_content_validation(self):
        """Test TranslatedContent model validation."""
        content = TranslatedContent(
            original_chunk_id="chunk_1",
            original_text="Hello",
            translated_text="Hola",
            source_language="en",
            target_language="es",
            translation_provider="test",
            confidence_score=0.9
        )
        assert content.original_chunk_id == "chunk_1"
        assert content.confidence_score == 0.9

    def test_translated_content_confidence_validation(self):
        """Test confidence score validation bounds."""
        # Valid confidence scores
        TranslatedContent(
            original_chunk_id="test",
            original_text="test",
            translated_text="test",
            source_language="en",
            confidence_score=0.0
        )
        TranslatedContent(
            original_chunk_id="test",
            original_text="test",
            translated_text="test",
            source_language="en",
            confidence_score=1.0
        )

        # Invalid confidence scores should raise validation error
        with pytest.raises(ValidationError):
            TranslatedContent(
                original_chunk_id="test",
                original_text="test",
                translated_text="test",
                source_language="en",
                confidence_score=-0.1
            )

        with pytest.raises(ValidationError):
            TranslatedContent(
                original_chunk_id="test",
                original_text="test",
                translated_text="test",
                source_language="en",
                confidence_score=1.1
            )

    def test_language_metadata_validation(self):
        """Test LanguageMetadata model validation."""
        metadata = LanguageMetadata(
            content_id="chunk_1",
            detected_language="es",
            language_confidence=0.95,
            requires_translation=True,
            character_count=100
        )
        assert metadata.content_id == "chunk_1"
        assert metadata.requires_translation is True
        assert metadata.character_count == 100

    def test_language_metadata_character_count_validation(self):
        """Test character count cannot be negative."""
        with pytest.raises(ValidationError):
            LanguageMetadata(
                content_id="test",
                detected_language="en",
                character_count=-1
            )


class TestTranslateContentFunction:
    """Test main translate_content function."""

    @pytest.mark.asyncio
    async def test_noop_provider_processing(self):
        """Test processing with noop provider."""
        chunks = [
            MockDocumentChunk("Hello world", "chunk_1"),
            MockDocumentChunk("Test content", "chunk_2")
        ]

        result = await translate_content(
            chunks,
            target_language="en",
            translation_provider="noop",
            confidence_threshold=0.8
        )

        assert len(result) == 2
        for chunk in result:
            assert "language" in chunk.metadata
            assert chunk.metadata["language"]["detected_language"] == "unknown"
            # No translation should occur with noop provider
            assert "translation" not in chunk.metadata

    @pytest.mark.asyncio
    async def test_translation_with_custom_provider(self):
        """Test translation with custom registered provider."""
        # Register mock provider
        register_translation_provider("test_provider", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hola mundo", "chunk_1")]

        result = await translate_content(
            chunks,
            target_language="en",
            translation_provider="test_provider",
            confidence_threshold=0.8
        )

        chunk = result[0]
        assert "language" in chunk.metadata
        assert "translation" in chunk.metadata

        # Check language metadata
        lang_meta = chunk.metadata["language"]
        assert lang_meta["detected_language"] == "es"
        assert lang_meta["requires_translation"] is True

        # Check translation metadata
        trans_meta = chunk.metadata["translation"]
        assert trans_meta["original_text"] == "Hola mundo"
        assert "[MOCK TRANSLATED]" in trans_meta["translated_text"]
        assert trans_meta["source_language"] == "es"
        assert trans_meta["target_language"] == "en"
        assert trans_meta["translation_provider"] == "test_provider"

        # Check chunk text was updated
        assert "[MOCK TRANSLATED]" in chunk.text

    @pytest.mark.asyncio
    async def test_low_confidence_no_translation(self):
        """Test that low confidence detection doesn't trigger translation."""
        register_translation_provider("low_conf", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hello world", "chunk_1")]  # English text

        result = await translate_content(
            chunks,
            target_language="en",
            translation_provider="low_conf",
            confidence_threshold=0.9  # High threshold
        )

        chunk = result[0]
        assert "language" in chunk.metadata
        # Should not translate due to high threshold and English detection
        assert "translation" not in chunk.metadata

    @pytest.mark.asyncio
    async def test_error_handling_in_detection(self):
        """Test graceful error handling in language detection."""
        class FailingProvider:
            async def detect_language(self, _text: str) -> Tuple[str, float]:
                """
                Simulate a language detection failure by always raising TestDetectionError.

                This async method is used in tests to emulate a provider that fails during language detection. It accepts a text string but does not return; it always raises TestDetectionError.
                """
                raise TestDetectionError()

            async def translate(self, text: str, _target_language: str) -> Tuple[str, float]:
                """
                Return the input text unchanged and a translation confidence of 0.0.

                This no-op translator performs no translation; the supplied target language is ignored.

                Parameters:
                    text (str): Source text to "translate".
                    _target_language (str): Target language (ignored).

                Returns:
                    Tuple[str, float]: A tuple containing the original text and a confidence score (always 0.0).
                """
                return text, 0.0

        register_translation_provider("failing", FailingProvider)

        chunks = [MockDocumentChunk("Test text", "chunk_1")]

        # Disable 'langdetect' fallback to force unknown
        ld = tr._provider_registry.pop("langdetect", None)
        try:
            result = await translate_content(chunks, translation_provider="failing")
        finally:
            if ld is not None:
                tr._provider_registry["langdetect"] = ld

        chunk = result[0]
        assert "language" in chunk.metadata
        # Should have unknown language due to detection failure
        lang_meta = chunk.metadata["language"]
        assert lang_meta["detected_language"] == "unknown"
        assert lang_meta["language_confidence"] == 0.0

    @pytest.mark.asyncio
    async def test_error_handling_in_translation(self):
        """Test graceful error handling in translation."""
        class PartialProvider:
            async def detect_language(self, _text: str) -> Tuple[str, float]:
                """
                Mock language detection used in tests.

                Parameters:
                    _text (str): Input text (ignored by this mock).

                Returns:
                    Tuple[str, float]: A fixed detected language code ("es") and confidence (0.9).
                """
                return "es", 0.9

            async def translate(self, _text: str, _target_language: str) -> Tuple[str, float]:
                """
                Simulate a failing translation by always raising TestTranslationError.

                This async method ignores its inputs and is used in tests to emulate a provider-side failure during translation.

                Parameters:
                    _text (str): Unused input text.
                    _target_language (str): Unused target language code.

                Raises:
                    TestTranslationError: Always raised to simulate a translation failure.
                """
                raise TestTranslationError()

        register_translation_provider("partial", PartialProvider)

        chunks = [MockDocumentChunk("Hola", "chunk_1")]

        result = await translate_content(
            chunks,
            translation_provider="partial",
            confidence_threshold=0.8
        )

        chunk = result[0]
        # Should have detected Spanish but failed translation
        assert chunk.metadata["language"]["detected_language"] == "es"
        # Should still create translation metadata with original text
        assert "translation" in chunk.metadata
        trans_meta = chunk.metadata["translation"]
        assert trans_meta["translated_text"] == "Hola"  # Original text due to failure
        assert trans_meta["confidence_score"] == 0.0

    @pytest.mark.asyncio
    async def test_no_translation_when_same_language(self):
        """Test no translation occurs when source equals target language."""
        register_translation_provider("same_lang", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hello world", "chunk_1")]

        result = await translate_content(
            chunks,
            target_language="en",  # Same as detected language
            translation_provider="same_lang"
        )

        chunk = result[0]
        assert "language" in chunk.metadata
        # No translation should occur for same language
        assert "translation" not in chunk.metadata

    @pytest.mark.asyncio
    async def test_metadata_serialization(self):
        """Test that metadata is properly serialized to dicts."""
        register_translation_provider("serialize_test", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hola", "chunk_1")]

        result = await translate_content(
            chunks,
            translation_provider="serialize_test",
            confidence_threshold=0.8
        )

        chunk = result[0]

        # Metadata should be plain dicts, not Pydantic models
        assert isinstance(chunk.metadata["language"], dict)
        if "translation" in chunk.metadata:
            assert isinstance(chunk.metadata["translation"], dict)

    def test_model_serialization_compatibility(self):
        """
        Verify that a TranslatedContent instance can be dumped to a JSON-serializable dict.

        Creates a TranslatedContent with sample fields, calls model_dump(), and asserts:
        - the result is a dict,
        - required fields like `original_chunk_id`, `translation_timestamp`, and `metadata` are present and preserved,
        - the dict can be round-tripped through json.dumps/json.loads without losing `original_chunk_id`.
        """
        content = TranslatedContent(
            original_chunk_id="test",
            original_text="Hello",
            translated_text="Hola",
            source_language="en",
            target_language="es"
        )

        # Should serialize to dict
        data = content.model_dump()
        assert isinstance(data, dict)
        assert data["original_chunk_id"] == "test"
        assert "translation_timestamp" in data
        assert "metadata" in data

        # Should be JSON serializable
        import json
        json_str = json.dumps(data)
        parsed = json.loads(json_str)
        assert parsed["original_chunk_id"] == "test"
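The tests above exercise the public extension point end to end: register a provider class, then run translate_content over chunk-like objects. A compact sketch of that flow with a toy provider (the provider class and chunk type here are illustrative only, not part of the codebase; the call signature mirrors what the tests use):

```python
import asyncio
from typing import Tuple

from cognee.tasks.translation.translate_content import (
    register_translation_provider,
    translate_content,
)


class UppercaseProvider:
    """Toy provider: 'detects' Spanish for text containing 'hola', uppercases on translate."""

    async def detect_language(self, text: str) -> Tuple[str, float]:
        return ("es", 0.95) if "hola" in text.lower() else ("en", 0.85)

    async def translate(self, text: str, target_language: str) -> Tuple[str, float]:
        return text.upper(), 0.9


class Chunk:
    # Minimal chunk-like object: translate_content only needs text, an id, and metadata.
    def __init__(self, text: str, chunk_id: str):
        self.text = text
        self.id = chunk_id
        self.metadata = {}


async def main():
    register_translation_provider("uppercase", UppercaseProvider)
    chunks = [Chunk("hola mundo", "chunk_1")]
    result = await translate_content(
        chunks, target_language="en", translation_provider="uppercase", confidence_threshold=0.8
    )
    print(result[0].text)                  # translated text written back onto the chunk
    print(result[0].metadata["language"])  # detection metadata attached by the task


asyncio.run(main())
```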
cognee/tasks/translation/translate_content.py (new file, 660 lines)
@@ -0,0 +1,660 @@
|
||||||
|
# pylint: disable=R0903, W0221
|
||||||
|
"""This module provides content translation capabilities for the Cognee framework."""
|
||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, Type, Protocol, Tuple, Optional
|
||||||
|
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from .models import TranslatedContent, LanguageMetadata
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
# Custom exceptions for better error handling
|
||||||
|
class TranslationDependencyError(ImportError):
|
||||||
|
"""Raised when a required translation dependency is missing."""
|
||||||
|
|
||||||
|
class LangDetectError(TranslationDependencyError):
|
||||||
|
"""LangDetect library required."""
|
||||||
|
|
||||||
|
class OpenAIError(TranslationDependencyError):
|
||||||
|
"""OpenAI library required."""
|
||||||
|
|
||||||
|
class GoogleTranslateError(TranslationDependencyError):
|
||||||
|
"""GoogleTrans library required."""
|
||||||
|
|
||||||
|
class AzureTranslateError(TranslationDependencyError):
|
||||||
|
"""Azure AI Translation library required."""
|
||||||
|
|
||||||
|
class AzureConfigError(ValueError):
|
||||||
|
"""Azure configuration error."""
|
||||||
|
|
||||||
|
# Environment variables for configuration
|
||||||
|
TARGET_LANGUAGE = os.getenv("COGNEE_TRANSLATION_TARGET_LANGUAGE", "en")
|
||||||
|
try:
|
||||||
|
CONFIDENCE_THRESHOLD = float(os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD", "0.80"))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.warning(
|
||||||
|
"Invalid float for COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD=%r; defaulting to 0.80",
|
||||||
|
os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD"),
|
||||||
|
)
|
||||||
|
CONFIDENCE_THRESHOLD = 0.80
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TranslationContext:
|
||||||
|
"""A context object to hold data for a single translation operation."""
|
||||||
|
provider: "TranslationProvider"
|
||||||
|
chunk: Any
|
||||||
|
text: str
|
||||||
|
target_language: str
|
||||||
|
confidence_threshold: float
|
||||||
|
provider_name: str
|
||||||
|
content_id: str = field(init=False)
|
||||||
|
detected_language: str = "unknown"
|
||||||
|
detection_confidence: float = 0.0
|
||||||
|
requires_translation: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""
|
||||||
|
Initialize derived fields after dataclass construction.
|
||||||
|
|
||||||
|
Sets self.content_id to the first available identifier on self.chunk in this order:
|
||||||
|
- self.chunk.id
|
||||||
|
- self.chunk.chunk_index
|
||||||
|
If neither attribute exists, content_id is set to the string "unknown".
|
||||||
|
"""
|
||||||
|
self.content_id = getattr(self.chunk, "id", getattr(self.chunk, "chunk_index", "unknown"))
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationProvider(Protocol):
|
||||||
|
"""Protocol for translation providers."""
|
||||||
|
async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
|
||||||
|
"""
|
||||||
|
Detect the language of the provided text.
|
||||||
|
|
||||||
|
Uses the langdetect library to determine the most likely language and its probability.
|
||||||
|
Returns a tuple (language_code, confidence) where `language_code` is a normalized short code (e.g., "en", "fr" or "unknown") and `confidence` is a float in [0.0, 1.0]. Returns None when detection fails (empty input, an error, or no reliable result).
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
|
||||||
|
"""
|
||||||
|
Translate the given text into the specified target language asynchronously.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
text: The source text to translate.
|
||||||
|
target_language: Target language code (e.g., "en", "es", "fr-CA").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple (translated_text, confidence) on success, where `confidence` is a float in [0.0, 1.0] (may be 0.0 if the provider does not supply a score), or None if translation failed or was unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Registry for translation providers
|
||||||
|
_provider_registry: Dict[str, Type[TranslationProvider]] = {}
|
||||||
|
|
||||||
|
def register_translation_provider(name: str, provider: Type[TranslationProvider]):
|
||||||
|
"""
|
||||||
|
Register a translation provider under a canonical lowercase key.
|
||||||
|
|
||||||
|
The provided class will be stored in the internal provider registry and looked up by its lowercased `name`. If an entry with the same key already exists it will be replaced.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
name (str): Human-readable provider name (case-insensitive); stored as lower-case.
|
||||||
|
provider (Type[TranslationProvider]): Provider class implementing the TranslationProvider protocol; instances are constructed when the provider is resolved.
|
||||||
|
"""
|
||||||
|
_provider_registry[name.lower()] = provider
|
||||||
|
|
||||||
|
def get_available_providers():
|
||||||
|
"""Returns a list of available translation providers."""
|
||||||
|
return sorted(_provider_registry.keys())
|
||||||
|
|
||||||
|
def _get_provider(translation_provider: str) -> TranslationProvider:
|
||||||
|
"""
|
||||||
|
Resolve and instantiate a registered translation provider by name.
|
||||||
|
|
||||||
|
The lookup is case-insensitive: `translation_provider` should be the provider key (e.g., "openai", "google", "noop").
|
||||||
|
Returns an instance of the provider implementing the TranslationProvider protocol.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if no provider is registered under the given name; the error message lists available providers.
|
||||||
|
"""
|
||||||
|
provider_class = _provider_registry.get(translation_provider.lower())
|
||||||
|
if not provider_class:
|
||||||
|
available = ', '.join(get_available_providers())
|
||||||
|
msg = f"Unknown translation provider: {translation_provider}. Available providers: {available}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return provider_class()
|
||||||
|
# Helpers
|
||||||
|
def _normalize_lang_code(code: Optional[str]) -> str:
|
||||||
|
"""
|
||||||
|
Normalize a language code to a canonical form or return "unknown".
|
||||||
|
|
||||||
|
Normalizes common language code formats:
|
||||||
|
- Two-letter codes (e.g., "en", "EN", " en ") -> "en"
|
||||||
|
- Locale codes with region (e.g., "en-us", "en_US", "EN-us") -> "en-US"
|
||||||
|
- Returns "unknown" for empty, non-string, or unrecognized inputs.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
code (Optional[str]): Language code or locale string to normalize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Normalized language code in either "xx" or "xx-YY" form, or "unknown" if input is invalid.
|
||||||
|
"""
|
||||||
|
if not isinstance(code, str) or not code.strip():
|
||||||
|
return "unknown"
|
||||||
|
c = code.strip().replace("_", "-")
|
||||||
|
parts = c.split("-")
|
||||||
|
if len(parts) == 1 and len(parts[0]) == 2 and parts[0].isalpha():
|
||||||
|
return parts[0].lower()
|
||||||
|
if len(parts) >= 2 and len(parts[0]) == 2 and parts[1]:
|
||||||
|
return f"{parts[0].lower()}-{parts[1][:2].upper()}"
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _provider_name(provider: TranslationProvider) -> str:
|
||||||
|
"""Return the canonical registry key for a provider instance, or a best-effort name."""
|
||||||
|
return next(
|
||||||
|
(name for name, cls in _provider_registry.items() if isinstance(provider, cls)),
|
||||||
|
provider.__class__.__name__.replace("Provider", "").lower(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _detect_language_with_fallback(provider: TranslationProvider, text: str, content_id: str) -> Tuple[str, float]:
|
||||||
|
"""
|
||||||
|
Detect the language of `text`, falling back to the registered "langdetect" provider if the primary provider fails.
|
||||||
|
|
||||||
|
Attempts to call the primary provider's `detect_language`. If that call returns None or raises, and a different "langdetect" provider is registered, it will try the fallback. Detection failures are logged; exceptions are not propagated.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
text (str): The text to detect language for.
|
||||||
|
content_id (str): Identifier used in logs to correlate errors to the input content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, float]: A normalized language code (e.g., "en" or "pt-BR") and a confidence score in [0.0, 1.0].
|
||||||
|
On detection failure returns ("unknown", 0.0). Confidence values are coerced to float, NaNs converted to 0.0, and clamped to the [0.0, 1.0] range.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
detection = await provider.detect_language(text)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Language detection failed for content_id=%s", content_id)
|
||||||
|
detection = None
|
||||||
|
|
||||||
|
if detection is None:
|
||||||
|
fallback_cls = _provider_registry.get("langdetect")
|
||||||
|
if fallback_cls is not None and not isinstance(provider, fallback_cls):
|
||||||
|
try:
|
||||||
|
detection = await fallback_cls().detect_language(text)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Fallback language detection failed for content_id=%s", content_id)
|
||||||
|
detection = None
|
||||||
|
|
||||||
|
if detection is None:
|
||||||
|
return "unknown", 0.0
|
||||||
|
|
||||||
|
lang_code, conf = detection
|
||||||
|
detected_language = _normalize_lang_code(lang_code)
|
||||||
|
try:
|
||||||
|
conf = float(conf)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
conf = 0.0
|
||||||
|
if math.isnan(conf):
|
||||||
|
conf = 0.0
|
||||||
|
conf = max(0.0, min(1.0, conf))
|
||||||
|
return detected_language, conf
|
||||||
|
|
||||||
|
def _decide_if_translation_is_required(ctx: TranslationContext) -> None:
|
||||||
|
"""
|
||||||
|
Decide whether a translation should be performed and update ctx.requires_translation.
|
||||||
|
|
||||||
|
Normalizes the configured target language and marks translation as required only when:
|
||||||
|
- The provider can perform translations (not "noop" or "langdetect"), and
|
||||||
|
- Either the detected language is "unknown" and the text is non-empty, or
|
||||||
|
- The detected language (normalized) differs from the target language and the detection confidence meets or exceeds ctx.confidence_threshold.
|
||||||
|
|
||||||
|
The function mutates the provided TranslationContext in-place and does not return a value.
|
||||||
|
"""
|
||||||
|
# Normalize to align with detected_language normalization and model regex.
|
||||||
|
target_language = _normalize_lang_code(ctx.target_language)
|
||||||
|
can_translate = ctx.provider_name not in ("noop", "langdetect")
|
||||||
|
|
||||||
|
if ctx.detected_language == "unknown":
|
||||||
|
ctx.requires_translation = can_translate and bool(ctx.text.strip())
|
||||||
|
else:
|
||||||
|
ctx.requires_translation = (
|
||||||
|
ctx.detected_language != target_language
|
||||||
|
and ctx.detection_confidence >= ctx.confidence_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
def _attach_language_metadata(ctx: TranslationContext) -> None:
    """
    Attach language detection and translation decision metadata to the context's chunk.

    Ensures the chunk has a metadata mapping, builds a LanguageMetadata record from
    the context (content_id, detected language and confidence, whether translation is
    required, and character count of the text), serializes it, and stores it under
    the "language" key in chunk.metadata.

    Parameters:
        ctx (TranslationContext): Context containing the chunk and detection/decision values.
    """
    ctx.chunk.metadata = getattr(ctx.chunk, "metadata", {}) or {}
    lang_meta = LanguageMetadata(
        content_id=str(ctx.content_id),
        detected_language=ctx.detected_language,
        language_confidence=ctx.detection_confidence,
        requires_translation=ctx.requires_translation,
        character_count=len(ctx.text),
    )
    ctx.chunk.metadata["language"] = lang_meta.model_dump()


async def _translate_and_update(ctx: TranslationContext) -> None:
    """
    Translate the text in the provided TranslationContext and update the chunk and its metadata.

    Performs an async translation via ctx.provider.translate and, when a non-empty, changed translation is returned:
    - replaces ctx.chunk.text with the translated text,
    - attempts to update ctx.chunk.chunk_size (if present),
    - attaches a `translation` entry in ctx.chunk.metadata containing a TranslatedContent dict (original/translated text, source/target languages, provider, and confidence).

    If translation fails (exception or None), the original text is preserved and a TranslatedContent record is still attached with confidence 0.0. If the provider returns the same text unchanged, no metadata is attached and the function returns without modifying the chunk.

    Parameters:
        ctx (TranslationContext): Context carrying provider, chunk, original text, target language, detected language, and content_id.

    Returns:
        None
    """
    try:
        tr = await ctx.provider.translate(ctx.text, ctx.target_language)
    except Exception:
        logger.exception("Translation failed for content_id=%s", ctx.content_id)
        tr = None

    translated_text = None
    translation_confidence = 0.0
    provider_used = _provider_name(ctx.provider)
    target_for_meta = _normalize_lang_code(ctx.target_language)

    if tr and isinstance(tr[0], str) and tr[0].strip() and tr[0] != ctx.text:
        translated_text, translation_confidence = tr
        ctx.chunk.text = translated_text
        if hasattr(ctx.chunk, "chunk_size"):
            try:
                ctx.chunk.chunk_size = len(translated_text.split())
            except (AttributeError, ValueError, TypeError):
                logger.debug(
                    "Could not update chunk_size for content_id=%s",
                    ctx.content_id,
                    exc_info=True,
                )
    elif tr is None:
        # Translation failed; keep the original text.
        translated_text = ctx.text
    else:
        # Provider returned unchanged text.
        logger.info(
            "Provider returned unchanged text; skipping translation metadata (content_id=%s)",
            ctx.content_id,
        )
        return

    trans = TranslatedContent(
        original_chunk_id=str(ctx.content_id),
        original_text=ctx.text,
        translated_text=translated_text,
        source_language=ctx.detected_language,
        target_language=target_for_meta,
        translation_provider=provider_used,
        confidence_score=translation_confidence or 0.0,
    )
    ctx.chunk.metadata["translation"] = trans.model_dump()


# Test helpers for registry isolation
def snapshot_registry() -> Dict[str, Type[TranslationProvider]]:
    """Return a shallow-copy snapshot of the provider registry (for tests)."""
    return dict(_provider_registry)


def restore_registry(snapshot: Dict[str, Type[TranslationProvider]]) -> None:
    """
    Restore the global translation provider registry from a previously captured snapshot.

    This replaces the current internal provider registry with the given snapshot (clears then updates);
    typically used by tests to restore provider registration state.

    Parameters:
        snapshot (Dict[str, Type[TranslationProvider]]): Mapping of provider name keys to provider classes.
    """
    _provider_registry.clear()
    _provider_registry.update(snapshot)


def validate_provider(name: str) -> None:
    """Ensure a provider can be resolved and instantiated, or raise."""
    _get_provider(name)
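
# Typical test usage of the helpers above (a sketch; the fixture and test names are illustrative,
# and the exact exception type raised for an unknown provider name is an assumption here;
# ValueError matches what the example script further below catches):
#
#     import pytest
#
#     @pytest.fixture
#     def isolated_registry():
#         snapshot = snapshot_registry()
#         try:
#             yield
#         finally:
#             restore_registry(snapshot)
#
#     def test_unknown_provider_is_rejected(isolated_registry):
#         with pytest.raises(ValueError):
#             validate_provider("does-not-exist")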


# Built-in Providers
class NoOpProvider:
    """A provider that does nothing, used for testing or for disabling translation."""

    async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]:
        """
        No-op language detection: intentionally performs no detection and always returns None.

        The `_text` parameter is ignored. Returns None to indicate that this provider does not provide a language detection result.
        """
        return None

    async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
        """
        Return the input text unchanged and a confidence score of 0.0.

        This provider does not perform any translation; it mirrors the source text back to the caller.

        Parameters:
            text (str): Source text to "translate".
            _target_language (str): Unused target language parameter.

        Returns:
            Optional[Tuple[str, float]]: A tuple of (text, 0.0).
        """
        return text, 0.0


class LangDetectProvider:
    """
    A provider that uses the 'langdetect' library for offline language detection.

    This provider only detects the language and does not perform translation.
    """

    def __init__(self):
        """
        Initialize the LangDetectProvider by loading the `langdetect.detect_langs` function.

        Attempts to import `detect_langs` from the `langdetect` package and stores it on the instance as `_detect_langs`. Raises `LangDetectError` if the `langdetect` dependency is not available.
        """
        try:
            from langdetect import detect_langs  # type: ignore[import-untyped]

            self._detect_langs = detect_langs
        except ImportError as e:
            raise LangDetectError() from e

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """
        Detect the language of `text` using the provider's langdetect backend.

        Returns a tuple of (language_code, confidence) where `language_code` is the top
        detected language (e.g., "en") and `confidence` is the detection probability
        in [0.0, 1.0]. Returns None if detection fails or no result is available.
        """
        try:
            detections = await asyncio.to_thread(self._detect_langs, text)
        except Exception:
            logger.exception("Error during language detection")
            return None

        if not detections:
            return None
        best_detection = detections[0]
        return best_detection.lang, best_detection.prob

    async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
        # This provider only detects language; it does not translate.
        """
        No-op translation: returns the input text unchanged with a 0.0 confidence.

        This provider only performs language detection; translate is a passthrough that returns the original `text`
        and a confidence of 0.0 to indicate no translated content was produced.

        Returns:
            A tuple of (text, confidence) where `text` is the original input and `confidence` is 0.0.
        """
        return text, 0.0


class OpenAIProvider:
    """A provider that uses OpenAI's API for translation."""

    def __init__(self):
        """
        Initialize the OpenAIProvider by creating an AsyncOpenAI client and loading configuration.

        Reads the following environment variables:
        - OPENAI_API_KEY: API key passed to AsyncOpenAI for authentication.
        - OPENAI_TRANSLATE_MODEL: model name to use for translations (default: "gpt-4o-mini").
        - OPENAI_TIMEOUT: request timeout in seconds (default: "30", parsed as float).

        Raises:
            OpenAIError: if the OpenAI SDK (AsyncOpenAI) cannot be imported.
        """
        try:
            from openai import AsyncOpenAI  # type: ignore[import-untyped]

            self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            self.model = os.getenv("OPENAI_TRANSLATE_MODEL", "gpt-4o-mini")
            self.timeout = float(os.getenv("OPENAI_TIMEOUT", "30"))
        except ImportError as e:
            raise OpenAIError() from e

    async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]:
        # OpenAI's API does not have a separate language detection endpoint.
        # This can be implemented as part of the translation prompt if needed.
        """
        Indicates that this provider does not perform standalone language detection.

        The OpenAI-based provider does not expose a separate detection endpoint and therefore
        always returns None. Language detection can be achieved by using another provider
        (e.g., the registered langdetect provider) or by incorporating detection into a
        translation prompt if needed.
        """
        return None

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """
        Translate the given text to the specified target language using the OpenAI chat completions client.

        Parameters:
            text (str): Source text to translate.
            target_language (str): Target language name or code (used verbatim in the translation prompt).

        Returns:
            Optional[Tuple[str, float]]: A tuple of (translated_text, confidence). Confidence is 0.0 because no calibrated confidence is available.
            Returns None if translation failed or an error occurred.
        """
        try:
            response = await self.client.with_options(timeout=self.timeout).chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": f"Translate the following text to {target_language}."},
                    {"role": "user", "content": text},
                ],
                temperature=0.0,
            )
        except Exception:
            logger.exception("Error during OpenAI translation (model=%s)", self.model)
            return None

        translated_text = response.choices[0].message.content.strip()
        return translated_text, 0.0  # No calibrated confidence available.


class GoogleTranslateProvider:
    """A provider that uses the 'googletrans' library for translation."""

    def __init__(self):
        """
        Initialize the GoogleTranslateProvider by importing and instantiating googletrans.Translator.

        Raises:
            GoogleTranslateError: If the `googletrans` library is not installed or cannot be imported.
        """
        try:
            from googletrans import Translator  # type: ignore[import-untyped]

            self.translator = Translator()
        except ImportError as e:
            raise GoogleTranslateError() from e

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """
        Detect the language of the given text using the configured googletrans Translator.

        Uses a thread to call the synchronous translator.detect method; on failure returns None.

        Parameters:
            text: The text to detect the language for.

        Returns:
            A tuple (language_code, confidence) where `language_code` is the detected language string from the translator (e.g. "en") and `confidence` is a float in [0.0, 1.0]. Returns None if detection fails.
        """
        try:
            detection = await asyncio.to_thread(self.translator.detect, text)
        except Exception:
            logger.exception("Error during Google Translate language detection")
            return None

        try:
            conf = float(detection.confidence) if detection.confidence is not None else 0.0
        except (TypeError, ValueError):
            conf = 0.0
        return detection.lang, conf

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """
        Translate `text` to `target_language` using the configured googletrans Translator.

        Returns a tuple (translated_text, confidence) on success, where confidence is always 0.0 because
        googletrans does not provide a confidence score; returns None if translation fails.
        """
        try:
            translation = await asyncio.to_thread(self.translator.translate, text, dest=target_language)
        except Exception:
            logger.exception("Error during Google Translate translation")
            return None

        return translation.text, 0.0  # Confidence not provided.


class AzureTranslatorProvider:
    """A provider that uses Azure's Translator service."""

    def __init__(self):
        """
        Initialize the AzureTranslatorProvider.

        Attempts to import Azure SDK classes, reads AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT,
        and AZURE_TRANSLATOR_REGION from the environment, verifies the key is present, and constructs
        a TextTranslationClient using an AzureKeyCredential.

        Raises:
            AzureConfigError: if AZURE_TRANSLATOR_KEY is not set.
            AzureTranslateError: if required Azure SDK imports are unavailable.
        """
        try:
            from azure.core.credentials import AzureKeyCredential  # type: ignore[import-untyped]
            from azure.ai.translation.text import TextTranslationClient  # type: ignore[import-untyped]

            self.key = os.getenv("AZURE_TRANSLATOR_KEY")
            self.endpoint = os.getenv("AZURE_TRANSLATOR_ENDPOINT", "https://api.cognitive.microsofttranslator.com/")
            self.region = os.getenv("AZURE_TRANSLATOR_REGION", "global")

            if not self.key:
                raise AzureConfigError()

            self.client = TextTranslationClient(
                endpoint=self.endpoint,
                credential=AzureKeyCredential(self.key),
            )
        except ImportError as e:
            raise AzureTranslateError() from e

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """
        Detect the language of the given text using the Azure Translator client's detect API.

        Attempts to call the Azure client's detect method (using a two-letter region as a country hint when available)
        and returns a tuple of (language_code, confidence_score). Returns None if detection fails or an exception occurs.

        Parameters:
            text (str): The text to detect language for.

        Returns:
            Optional[Tuple[str, float]]: (ISO language code, confidence between 0.0 and 1.0), or None on error.
        """
        try:
            # Use a valid country hint only when it looks like ISO 3166-1 alpha-2; otherwise omit.
            hint = self.region.lower() if isinstance(self.region, str) and len(self.region) == 2 else None
            response = await asyncio.to_thread(self.client.detect, content=[text], country_hint=hint)
        except Exception:
            logger.exception("Error during Azure language detection")
            return None

        detection = response[0].primary_language
        return detection.language, detection.score

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """
        Translate the given text to the target language using the Azure Translator client.

        Parameters:
            text (str): Plain text to translate.
            target_language (str): BCP-47 or ISO language code to translate the text into.

        Returns:
            Optional[Tuple[str, float]]: A tuple of (translated_text, confidence). Returns None on error.
            The provider does not surface a numeric confidence score, so the returned confidence is always 0.0.
        """
        try:
            response = await asyncio.to_thread(self.client.translate, content=[text], to=[target_language])
        except Exception:
            logger.exception("Error during Azure translation")
            return None

        translation = response[0].translations[0]
        return translation.text, 0.0  # Confidence not provided.


# Register built-in providers
register_translation_provider("noop", NoOpProvider)
register_translation_provider("langdetect", LangDetectProvider)
register_translation_provider("openai", OpenAIProvider)
register_translation_provider("google", GoogleTranslateProvider)
register_translation_provider("azure", AzureTranslatorProvider)


async def translate_content(  # pylint: disable=too-many-locals,too-many-branches
    *data_chunks,
    target_language: str = TARGET_LANGUAGE,
    translation_provider: str = "noop",
    confidence_threshold: float = CONFIDENCE_THRESHOLD,
):
    """
    Translate content chunks to a target language and attach language and translation metadata.

    This function accepts either multiple chunk objects as varargs or a single list of chunks.
    For each chunk it:
    - Resolves the named translation provider.
    - Detects the chunk's language (with a fallback detector when available).
    - Decides whether translation is required based on detected language, confidence threshold, and provider.
    - Attaches language metadata (LanguageMetadata) to chunk.metadata.
    - If required, performs translation and updates the chunk text and metadata (TranslatedContent).

    Parameters:
        *data_chunks: One or more chunk objects, or a single list of chunk objects. Each chunk must expose a `text` attribute and a `metadata` mapping (the function will create `metadata` if missing).
        target_language (str): Language code to translate into (defaults to TARGET_LANGUAGE).
        translation_provider (str): Registered provider name to use for detection/translation (defaults to "noop").
        confidence_threshold (float): Minimum detection confidence a differing detected language must reach before translation is triggered (defaults to CONFIDENCE_THRESHOLD).

    Returns:
        list: The list of processed chunk objects (same objects, possibly modified). Metadata keys added include language detection results and, when a translation occurs, translation details.
    """
    provider = _get_provider(translation_provider)
    results = []

    if len(data_chunks) == 1 and isinstance(data_chunks[0], list):
        _chunks = data_chunks[0]
    else:
        _chunks = list(data_chunks)

    for chunk in _chunks:
        ctx = TranslationContext(
            provider=provider,
            chunk=chunk,
            text=getattr(chunk, "text", "") or "",
            target_language=target_language,
            confidence_threshold=confidence_threshold,
            provider_name=translation_provider.lower(),
        )

        ctx.detected_language, ctx.detection_confidence = await _detect_language_with_fallback(
            ctx.provider, ctx.text, str(ctx.content_id)
        )

        _decide_if_translation_is_required(ctx)
        _attach_language_metadata(ctx)

        if ctx.requires_translation:
            await _translate_and_update(ctx)

        results.append(ctx.chunk)

    return results
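
# Direct-usage sketch (illustrative; outside the cognify pipeline). The chunk object here is a
# stand-in for the real DocumentChunk type and assumes TranslationContext can derive a
# content_id for it without extra fields:
#
#     from types import SimpleNamespace
#
#     chunk = SimpleNamespace(text="Bonjour tout le monde", metadata={})
#     [processed] = await translate_content(
#         chunk,
#         target_language="en",
#         translation_provider="langdetect",  # detection only; the text itself is left unchanged
#     )
#     print(processed.metadata["language"]["detected_language"])  # e.g. "fr"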

86 examples/python/translation_example.py Normal file

@ -0,0 +1,86 @@
import asyncio
import os

import cognee
from cognee.api.v1.search import SearchType
from cognee.api.v1.cognify.cognify import get_default_tasks_with_translation
from cognee.modules.pipelines.operations.pipeline import run_pipeline

# Prerequisites:
# 1. Set up your environment with API keys for your chosen translation provider.
#    - For OpenAI: OPENAI_API_KEY
#    - For Azure: AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT, AZURE_TRANSLATOR_REGION
# 2. Specify the translation provider via an environment variable (optional, defaults to "noop"):
#    COGNEE_TRANSLATION_PROVIDER="openai"  # Or "google", "azure", "langdetect"
# 3. Install any required libraries for your provider:
#    - pip install langdetect googletrans==4.0.0rc1 azure-ai-translation-text


async def main():
    """
    Demonstrates an end-to-end translation-enabled Cognify workflow using the Cognee SDK.

    Performs three main steps:
    1. Resets the demo workspace by pruning stored data and system metadata.
    2. Seeds three multilingual documents, builds translation-enabled Cognify tasks using the
       provider specified by the COGNEE_TRANSLATION_PROVIDER environment variable (defaults to "noop"),
       and executes the pipeline to translate and process the documents.
       - If the selected provider is missing or invalid, the function prints the error and returns early.
    3. Issues an English search query (using SearchType.INSIGHTS) against the processed index and
       prints any returned result texts.

    Side effects:
    - Mutates persistent Cognee state (prune, add, cognify pipeline execution).
    - Prints status and result messages to stdout.

    Notes:
    - No return value.
    - ValueError and ImportError are caught and handled by printing an error and exiting the function.
    """
    # 1. Set up cognee and add multilingual content
    print("Setting up demo environment...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    multilingual_texts = [
        "El procesamiento de lenguaje natural (PLN) es un subcampo de la IA.",
        "Le traitement automatique du langage naturel (TALN) est un sous-domaine de l'IA.",
        "Natural language processing (NLP) is a subfield of AI.",
    ]

    print("Adding multilingual texts...")
    for text in multilingual_texts:
        await cognee.add(text)
    print("Texts added successfully.\n")

    # 2. Run the cognify pipeline with translation enabled
    provider = os.getenv("COGNEE_TRANSLATION_PROVIDER", "noop").lower()
    print(f"Running cognify with translation provider: {provider}")

    try:
        # Build translation-enabled tasks and execute the pipeline
        translation_enabled_tasks = get_default_tasks_with_translation(
            translation_provider=provider
        )
        async for _ in run_pipeline(tasks=translation_enabled_tasks):
            pass
        print("Cognify pipeline with translation completed successfully.")
    except (ValueError, ImportError) as e:
        print(f"Error during cognify: {e}")
        print("Please ensure the selected provider is installed and configured correctly.")
        return

    # 3. Search for content in English
    query_text = "Tell me about NLP"
    print(f"\nSearching for: '{query_text}'")

    # The search should now return results from all documents, as they have been translated.
    search_results = await cognee.search(query_text, query_type=SearchType.INSIGHTS)

    print("\nSearch Results:")
    if search_results:
        for result in search_results:
            print(f"- {result.text}")
    else:
        print("No results found.")


if __name__ == "__main__":
    asyncio.run(main())