Compare commits


1 commit

Author: coderabbitai[bot]
SHA1: 9f6b2dca51

📝 Add docstrings to auto-translate-task

Docstring generation was requested by @subhash-0000.

* https://github.com/topoteretes/cognee/pull/1353#issuecomment-3287760071

The following files were modified:

* `cognee/api/v1/cognify/cognify.py`
* `cognee/tasks/translation/test_translation.py`
* `cognee/tasks/translation/translate_content.py`
* `examples/python/translation_example.py`
Date: 2025-09-13 07:57:43 +00:00
4 changed files with 1432 additions and 195 deletions

`cognee/api/v1/cognify/cognify.py`

@@ -1,18 +1,21 @@
import asyncio
from pydantic import BaseModel
from typing import Union, Optional
from typing import Union, Optional, Type
from uuid import UUID
import os
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.infrastructure.llm.utils import get_max_chunk_tokens
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines import run_pipeline
from cognee.modules.pipelines.operations.pipeline import run_pipeline
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.modules.users.models import User
logger = get_logger()
from cognee.tasks.documents import (
check_permissions_on_dataset,
classify_documents,
@@ -21,179 +24,101 @@ from cognee.tasks.documents import (
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.tasks.translation import translate_content, get_available_providers, validate_provider
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
from cognee.tasks.temporal_graph.extract_events_and_entities import extract_events_and_timestamps
from cognee.tasks.temporal_graph.extract_knowledge_graph_from_events import (
extract_knowledge_graph_from_events,
)
logger = get_logger("cognify")
class TranslationProviderError(ValueError):
"""Error related to translation provider initialization."""
pass
update_status_lock = asyncio.Lock()
class UnknownTranslationProviderError(TranslationProviderError):
"""Unknown translation provider name."""
class ProviderInitializationError(TranslationProviderError):
"""Provider failed to initialize (likely missing dependency or bad config)."""
async def cognify(
datasets: Union[str, list[str], list[UUID]] = None,
user: User = None,
graph_model: BaseModel = KnowledgeGraph,
_WARNED_ENV_VARS: set[str] = set()
def _parse_batch_env(var: str, default: int = 10) -> int:
"""
Parse an environment variable as a positive integer (minimum 1), falling back to a default.
If the environment variable named `var` is unset, the provided `default` is returned.
If the variable is set but cannot be parsed as an integer, `default` is returned and a
one-time warning is logged for that variable (the variable name is recorded in
`_WARNED_ENV_VARS` to avoid repeated warnings).
Parameters:
var: Name of the environment variable to read.
default: Fallback integer value returned when the variable is missing or invalid.
Returns:
An integer >= 1 representing the parsed value or the fallback `default`.
"""
raw = os.getenv(var)
if raw is None:
return default
try:
return max(1, int(raw))
except (TypeError, ValueError):
if var not in _WARNED_ENV_VARS:
logger.warning("Invalid int for %s=%r; using default=%d", var, raw, default)
_WARNED_ENV_VARS.add(var)
return default
# Constants for batch processing
DEFAULT_BATCH_SIZE = _parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE", 10)
async def cognify( # pylint: disable=too-many-arguments,too-many-positional-arguments
datasets: Optional[Union[str, UUID, list[str], list[UUID]]] = None,
user: Optional[User] = None,
graph_model: Type[BaseModel] = KnowledgeGraph,
chunker=TextChunker,
chunk_size: int = None,
chunk_size: Optional[int] = None,
ontology_file_path: Optional[str] = None,
vector_db_config: dict = None,
graph_db_config: dict = None,
vector_db_config: Optional[dict] = None,
graph_db_config: Optional[dict] = None,
run_in_background: bool = False,
incremental_loading: bool = True,
custom_prompt: Optional[str] = None,
temporal_cognify: bool = False,
):
"""
Transform ingested data into a structured knowledge graph.
This is the core processing step in Cognee that converts raw text and documents
into an intelligent knowledge graph. It analyzes content, extracts entities and
relationships, and creates semantic connections for enhanced search and reasoning.
Prerequisites:
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
- **Data Added**: Must have data previously added via `cognee.add()`
- **Vector Database**: Must be accessible for embeddings storage
- **Graph Database**: Must be accessible for relationship storage
Input Requirements:
- **Datasets**: Must contain data previously added via `cognee.add()`
- **Content Types**: Works with any text-extractable content including:
* Natural language documents
* Structured data (CSV, JSON)
* Code repositories
* Academic papers and technical documentation
* Mixed multimedia content (with text extraction)
Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
6. **Graph Construction**: Builds semantic knowledge graph with embeddings
7. **Content Summarization**: Creates hierarchical summaries for navigation
Graph Model Customization:
The `graph_model` parameter allows custom knowledge structures:
- **Default**: General-purpose KnowledgeGraph for any domain
- **Custom Models**: Domain-specific schemas (e.g., scientific papers, code analysis)
- **Ontology Integration**: Use `ontology_file_path` for predefined vocabularies
Args:
datasets: Dataset name(s) or dataset uuid to process. Processes all available data if None.
- Single dataset: "my_dataset"
- Multiple datasets: ["docs", "research", "reports"]
- None: Process all datasets for the user
user: User context for authentication and data access. Uses default if None.
graph_model: Pydantic model defining the knowledge graph structure.
Defaults to KnowledgeGraph for general-purpose processing.
chunker: Text chunking strategy (TextChunker, LangchainChunker).
- TextChunker: Paragraph-based chunking (default, most reliable)
- LangchainChunker: Recursive character splitting with overlap
Determines how documents are segmented for processing.
chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
Default limits: ~512-8192 tokens depending on models.
Smaller chunks = more granular but potentially fragmented knowledge.
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
Useful for specialized fields like medical or legal documents.
vector_db_config: Custom vector database configuration for embeddings storage.
graph_db_config: Custom graph database configuration for relationship storage.
run_in_background: If True, starts processing asynchronously and returns immediately.
If False, waits for completion before returning.
Background mode recommended for large datasets (>100MB).
Use pipeline_run_id from return value to monitor progress.
custom_prompt: Optional custom prompt string to use for entity extraction and graph generation.
If provided, this prompt will be used instead of the default prompts for
knowledge graph extraction. The prompt should guide the LLM on how to
extract entities and relationships from the text content.
Orchestrate processing of datasets into a knowledge graph.
Builds the default Cognify task sequence (classification, permission check, chunking,
graph extraction, summarization, indexing) and executes it via the pipeline
executor. Use get_default_tasks_with_translation(...) to include an automatic
translation step before graph extraction.
Parameters:
datasets: Optional dataset id or list of ids to process. If None, processes all
datasets available to the user.
user: Optional user context used for permission checks; defaults to the current
runtime user if omitted.
graph_model: Pydantic model type that defines the structure of produced graph
DataPoints (default: KnowledgeGraph).
chunker: Chunking strategy/class used to split documents (default: TextChunker).
chunk_size: Optional max tokens per chunk; when None a sensible default is used.
ontology_file_path: Optional path to an ontology (RDF/OWL) used by the extractor.
vector_db_config: Optional mapping of vector DB configuration (overrides defaults).
graph_db_config: Optional mapping of graph DB configuration (overrides defaults).
run_in_background: If True, starts the pipeline asynchronously and returns
background run info; if False, waits for completion and returns results.
incremental_loading: If True, performs incremental loading to avoid reprocessing
unchanged content.
custom_prompt: Optional prompt to override the default prompt used for graph
extraction.
Returns:
Union[dict, list[PipelineRunInfo]]:
- **Blocking mode**: Dictionary mapping dataset_id -> PipelineRunInfo with:
* Processing status (completed/failed/in_progress)
* Extracted entity and relationship counts
* Processing duration and resource usage
* Error details if any failures occurred
- **Background mode**: List of PipelineRunInfo objects for tracking progress
* Use pipeline_run_id to monitor status
* Check completion via pipeline monitoring APIs
Next Steps:
After successful cognify processing, use search functions to query the knowledge:
```python
import cognee
from cognee import SearchType
# Process your data into knowledge graph
await cognee.cognify()
# Query for insights using different search types:
# 1. Natural language completion with graph context
insights = await cognee.search(
"What are the main themes?",
query_type=SearchType.GRAPH_COMPLETION
)
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
)
# 3. Find relevant document chunks
chunks = await cognee.search(
"specific topic",
query_type=SearchType.CHUNKS
)
```
Advanced Usage:
```python
# Custom domain model for scientific papers
class ScientificPaper(DataPoint):
title: str
authors: List[str]
methodology: str
findings: List[str]
await cognee.cognify(
datasets=["research_papers"],
graph_model=ScientificPaper,
ontology_file_path="scientific_ontology.owl"
)
# Background processing for large datasets
run_info = await cognee.cognify(
datasets=["large_corpus"],
run_in_background=True
)
# Check status later with run_info.pipeline_run_id
```
Environment Variables:
Required:
- LLM_API_KEY: API key for your LLM provider
Optional (same as add function):
- LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
The pipeline executor result. In blocking mode this is the pipeline run result
(per-dataset run info and status). In background mode this returns information
required to track the background run (e.g., pipeline_run_id and submission status).
"""
if temporal_cognify:
tasks = await get_temporal_tasks(user, chunker, chunk_size)
else:
tasks = await get_default_tasks(
user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
)
tasks = get_default_tasks(
user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
)
# get_pipeline_executor returns either a function that launches run_pipeline in the background or one whose completion we must await
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
@@ -211,20 +136,49 @@ async def cognify(
)
async def get_default_tasks( # TODO: Find out a better way to do this (Boris's comment)
user: User = None,
graph_model: BaseModel = KnowledgeGraph,
def get_default_tasks( # pylint: disable=too-many-arguments,too-many-positional-arguments
user: Optional[User] = None,
graph_model: Type[BaseModel] = KnowledgeGraph,
chunker=TextChunker,
chunk_size: int = None,
chunk_size: Optional[int] = None,
ontology_file_path: Optional[str] = None,
custom_prompt: Optional[str] = None,
) -> list[Task]:
"""
Return the standard, non-translation Task list used by the cognify pipeline.
This builds the default processing pipeline (no automatic translation) and returns
a list of Task objects in execution order:
1. classify_documents
2. check_permissions_on_dataset (enforces write permission for `user`)
3. extract_chunks_from_documents (uses `chunker` and `chunk_size`)
4. extract_graph_from_data (uses `graph_model`, optional `ontology_file_path`, and `custom_prompt`)
5. summarize_text
6. add_data_points
Notes:
- Batch sizes for downstream tasks use the module-level DEFAULT_BATCH_SIZE.
- If `chunk_size` is not provided, the token limit from get_max_chunk_tokens() is used.
Parameters:
user: Optional user context used for the permission check.
graph_model: Model class used to construct knowledge graph instances.
chunker: Chunking strategy or class used to split documents into chunks.
chunk_size: Optional max tokens per chunk; if omitted, defaults to get_max_chunk_tokens().
ontology_file_path: Optional path to an ontology file passed to the extractor.
custom_prompt: Optional custom prompt applied during graph extraction.
Returns:
List[Task]: Ordered list of Task objects for the cognify pipeline (no translation).
"""
# Precompute max_chunk_size for stability
max_chunk = chunk_size or get_max_chunk_tokens()
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
max_chunk_size=max_chunk,
chunker=chunker,
), # Extract text chunks based on the document type.
Task(
@@ -232,51 +186,92 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
graph_model=graph_model,
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
custom_prompt=custom_prompt,
task_config={"batch_size": 10},
task_config={"batch_size": DEFAULT_BATCH_SIZE},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
task_config={"batch_size": 10},
task_config={"batch_size": DEFAULT_BATCH_SIZE},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
]
return default_tasks
async def get_temporal_tasks(
user: User = None, chunker=TextChunker, chunk_size: int = None
def get_default_tasks_with_translation( # pylint: disable=too-many-arguments,too-many-positional-arguments
user: Optional[User] = None,
graph_model: Type[BaseModel] = KnowledgeGraph,
chunker=TextChunker,
chunk_size: Optional[int] = None,
ontology_file_path: Optional[str] = None,
custom_prompt: Optional[str] = None,
translation_provider: str = "noop",
) -> list[Task]:
"""
Builds and returns a list of temporal processing tasks to be executed in sequence.
The pipeline includes:
1. Document classification.
2. Dataset permission checks (requires "write" access).
3. Document chunking with a specified or default chunk size.
4. Event and timestamp extraction from chunks.
5. Knowledge graph extraction from events.
6. Batched insertion of data points.
Args:
user (User, optional): The user requesting task execution, used for permission checks.
chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
Return the default Cognify pipeline task list with an added translation step.
Constructs the standard processing pipeline (classify -> permission check -> chunk extraction -> translate -> graph extraction -> summarize -> add data points),
validates and initializes the named translation provider, and applies the module-level DEFAULT_BATCH_SIZE to downstream batchable tasks.
Parameters:
translation_provider (str): Name of a registered translation provider (case-insensitive). Defaults to `"noop"`, which is a no-op provider.
Returns:
list[Task]: A list of Task objects representing the temporal processing pipeline.
list[Task]: Ordered Task objects ready to be executed by the pipeline executor.
Raises:
UnknownTranslationProviderError: If the given provider name is not in get_available_providers().
ProviderInitializationError: If the provider fails to initialize or validate via validate_provider().
"""
temporal_tasks = [
# Fail fast on unknown providers (keeps errors close to the API surface)
translation_provider = (translation_provider or "noop").strip().lower()
# Validate provider using public API
if translation_provider not in get_available_providers():
available = ", ".join(get_available_providers())
logger.error("Unknown provider '%s'. Available: %s", translation_provider, available)
raise UnknownTranslationProviderError(f"Unknown provider '{translation_provider}'")
# Instantiate to validate dependencies; include provider-specific config errors
try:
validate_provider(translation_provider)
except Exception as e: # we want to convert provider init errors
available = ", ".join(get_available_providers())
logger.error(
"Provider '%s' failed to initialize (available: %s).",
translation_provider,
available,
exc_info=True,
)
raise ProviderInitializationError() from e
# Precompute max_chunk_size for stability
max_chunk = chunk_size or get_max_chunk_tokens()
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
max_chunk_size=max_chunk,
chunker=chunker,
), # Extract text chunks based on the document type.
Task(
translate_content,
target_language="en",
translation_provider=translation_provider,
task_config={"batch_size": DEFAULT_BATCH_SIZE},
), # Auto-translate non-English content and attach metadata
Task(
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
custom_prompt=custom_prompt,
task_config={"batch_size": DEFAULT_BATCH_SIZE},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
task_config={"batch_size": DEFAULT_BATCH_SIZE},
),
Task(extract_events_and_timestamps, task_config={"chunk_size": 10}),
Task(extract_knowledge_graph_from_events),
Task(add_data_points, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
]
return temporal_tasks
return default_tasks
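
Taken together, the new surface is easy to drive from application code. A minimal sketch of wiring the translation-enabled task list into a pipeline run (mirroring the example script further down; it assumes data was already ingested via `cognee.add()`):

```python
import asyncio

from cognee.api.v1.cognify.cognify import get_default_tasks_with_translation
from cognee.modules.pipelines.operations.pipeline import run_pipeline


async def run_with_translation():
    # Build the default cognify task list with the translation step inserted
    # before graph extraction; "noop" detects and translates nothing, so it
    # runs without any extra dependencies.
    tasks = get_default_tasks_with_translation(translation_provider="noop")
    # Drain the pipeline generator to completion (blocking-style usage).
    async for _ in run_pipeline(tasks=tasks):
        pass


asyncio.run(run_with_translation())
```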

`cognee/tasks/translation/test_translation.py`

@@ -0,0 +1,496 @@
"""
Unit tests for translation functionality.
Tests cover:
- Translation provider registry and discovery
- Language detection across providers
- Translation functionality
- Error handling and fallbacks
- Model validation and serialization
"""
import pytest # type: ignore[import-untyped]
from typing import Tuple, Optional, Dict
from pydantic import ValidationError
import cognee.tasks.translation.translate_content as tr
from cognee.tasks.translation.translate_content import (
translate_content,
register_translation_provider,
get_available_providers,
TranslationProvider,
NoOpProvider,
_get_provider,
)
from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
class TestDetectionError(Exception): # pylint: disable=too-few-public-methods
"""Test exception for detection failures."""
class TestTranslationError(Exception): # pylint: disable=too-few-public-methods
"""Test exception for translation failures."""
# Ensure registry isolation across tests using public helpers
@pytest.fixture(autouse=True)
def _restore_registry():
"""
Pytest fixture that snapshots the translation provider registry and restores it after the test.
Use to isolate tests that register or modify providers: the current registry state is captured before the test runs, and always restored when the fixture completes (including on exceptions).
"""
snapshot = tr.snapshot_registry()
try:
yield
finally:
tr.restore_registry(snapshot)
class MockDocumentChunk: # pylint: disable=too-few-public-methods
"""Mock document chunk for testing."""
def __init__(self, text: str, chunk_id: str = "test_chunk", metadata: Optional[Dict] = None):
"""
Initialize a mock document chunk used in tests.
Parameters:
text (str): Chunk text content.
chunk_id (str): Identifier for the chunk; also used as chunk_index for tests. Defaults to "test_chunk".
metadata (Optional[Dict]): Optional mapping of metadata values; defaults to an empty dict.
"""
self.text = text
self.id = chunk_id
self.chunk_index = chunk_id
self.metadata = metadata or {}
class MockTranslationProvider:
"""Mock provider for testing custom provider registration."""
async def detect_language(self, text: str) -> Tuple[str, float]:
"""
Detect the language of the given text and return an ISO 639-1 language code with a confidence score.
This mock implementation uses simple keyword heuristics: returns ("es", 0.95) if the text contains "hola",
("fr", 0.90) if it contains "bonjour", and ("en", 0.85) otherwise.
Parameters:
text (str): Input text to analyze.
Returns:
Tuple[str, float]: A tuple of (language_code, confidence) where language_code is an ISO 639-1 code and
confidence is a float between 0.0 and 1.0 indicating detection confidence.
"""
if "hola" in text.lower():
return "es", 0.95
if "bonjour" in text.lower():
return "fr", 0.90
return "en", 0.85
async def translate(self, text: str, target_language: str) -> Tuple[str, float]:
"""
Simulate translating `text` into `target_language` and return a mock translated string with a confidence score.
If `target_language` is "en", returns the input prefixed with "[MOCK TRANSLATED]" and a confidence of 0.88. For any other target language, returns the original `text` and a confidence of 0.0.
Parameters:
text (str): The text to translate.
target_language (str): The target language code (e.g., "en").
Returns:
Tuple[str, float]: A pair of (translated_text, confidence) where confidence is in [0.0, 1.0].
"""
if target_language == "en":
return f"[MOCK TRANSLATED] {text}", 0.88
return text, 0.0
class TestProviderRegistry:
"""Test translation provider registration and discovery."""
def test_get_available_providers_includes_builtin(self):
"""Test that built-in providers are included in available list."""
providers = get_available_providers()
assert "noop" in providers
assert "langdetect" in providers
def test_register_custom_provider(self):
"""Test custom provider registration."""
register_translation_provider("mock", MockTranslationProvider)
providers = get_available_providers()
assert "mock" in providers
# Test provider can be retrieved
provider = _get_provider("mock")
assert isinstance(provider, MockTranslationProvider)
def test_provider_name_normalization(self):
"""Test provider names are normalized to lowercase."""
register_translation_provider("CUSTOM_PROVIDER", MockTranslationProvider)
providers = get_available_providers()
assert "custom_provider" in providers
# Should be retrievable with different casing
provider1 = _get_provider("CUSTOM_PROVIDER")
provider2 = _get_provider("custom_provider")
assert provider1.__class__ is provider2.__class__
def test_unknown_provider_raises(self):
"""Test unknown providers raise ValueError."""
with pytest.raises(ValueError):
_get_provider("nonexistent_provider")
class TestNoOpProvider:
"""Test NoOp provider functionality."""
@pytest.mark.asyncio
async def test_detect_language_ascii(self):
"""Test language detection for ASCII text."""
provider = NoOpProvider()
lang, conf = await provider.detect_language("Hello world")
assert lang is None
assert conf == 0.0
@pytest.mark.asyncio
async def test_detect_language_unicode(self):
"""Test language detection for Unicode text."""
provider = NoOpProvider()
lang, conf = await provider.detect_language("Hëllo wörld")
assert lang is None
assert conf == 0.0
@pytest.mark.asyncio
async def test_translate_returns_original(self):
"""Test translation returns original text with zero confidence."""
provider = NoOpProvider()
text = "Test text"
translated, conf = await provider.translate(text, "es")
assert translated == text
assert conf == 0.0
class TestTranslationModels:
"""Test Pydantic models for translation data."""
def test_translated_content_validation(self):
"""Test TranslatedContent model validation."""
content = TranslatedContent(
original_chunk_id="chunk_1",
original_text="Hello",
translated_text="Hola",
source_language="en",
target_language="es",
translation_provider="test",
confidence_score=0.9
)
assert content.original_chunk_id == "chunk_1"
assert content.confidence_score == 0.9
def test_translated_content_confidence_validation(self):
"""Test confidence score validation bounds."""
# Valid confidence scores
TranslatedContent(
original_chunk_id="test",
original_text="test",
translated_text="test",
source_language="en",
confidence_score=0.0
)
TranslatedContent(
original_chunk_id="test",
original_text="test",
translated_text="test",
source_language="en",
confidence_score=1.0
)
# Invalid confidence scores should raise validation error
with pytest.raises(ValidationError):
TranslatedContent(
original_chunk_id="test",
original_text="test",
translated_text="test",
source_language="en",
confidence_score=-0.1
)
with pytest.raises(ValidationError):
TranslatedContent(
original_chunk_id="test",
original_text="test",
translated_text="test",
source_language="en",
confidence_score=1.1
)
def test_language_metadata_validation(self):
"""Test LanguageMetadata model validation."""
metadata = LanguageMetadata(
content_id="chunk_1",
detected_language="es",
language_confidence=0.95,
requires_translation=True,
character_count=100
)
assert metadata.content_id == "chunk_1"
assert metadata.requires_translation is True
assert metadata.character_count == 100
def test_language_metadata_character_count_validation(self):
"""Test character count cannot be negative."""
with pytest.raises(ValidationError):
LanguageMetadata(
content_id="test",
detected_language="en",
character_count=-1
)
class TestTranslateContentFunction:
"""Test main translate_content function."""
@pytest.mark.asyncio
async def test_noop_provider_processing(self):
"""Test processing with noop provider."""
chunks = [
MockDocumentChunk("Hello world", "chunk_1"),
MockDocumentChunk("Test content", "chunk_2")
]
result = await translate_content(
chunks,
target_language="en",
translation_provider="noop",
confidence_threshold=0.8
)
assert len(result) == 2
for chunk in result:
assert "language" in chunk.metadata
assert chunk.metadata["language"]["detected_language"] == "unknown"
# No translation should occur with noop provider
assert "translation" not in chunk.metadata
@pytest.mark.asyncio
async def test_translation_with_custom_provider(self):
"""Test translation with custom registered provider."""
# Register mock provider
register_translation_provider("test_provider", MockTranslationProvider)
chunks = [MockDocumentChunk("Hola mundo", "chunk_1")]
result = await translate_content(
chunks,
target_language="en",
translation_provider="test_provider",
confidence_threshold=0.8
)
chunk = result[0]
assert "language" in chunk.metadata
assert "translation" in chunk.metadata
# Check language metadata
lang_meta = chunk.metadata["language"]
assert lang_meta["detected_language"] == "es"
assert lang_meta["requires_translation"] is True
# Check translation metadata
trans_meta = chunk.metadata["translation"]
assert trans_meta["original_text"] == "Hola mundo"
assert "[MOCK TRANSLATED]" in trans_meta["translated_text"]
assert trans_meta["source_language"] == "es"
assert trans_meta["target_language"] == "en"
assert trans_meta["translation_provider"] == "test_provider"
# Check chunk text was updated
assert "[MOCK TRANSLATED]" in chunk.text
@pytest.mark.asyncio
async def test_low_confidence_no_translation(self):
"""Test that low confidence detection doesn't trigger translation."""
register_translation_provider("low_conf", MockTranslationProvider)
chunks = [MockDocumentChunk("Hello world", "chunk_1")] # English text
result = await translate_content(
chunks,
target_language="en",
translation_provider="low_conf",
confidence_threshold=0.9 # High threshold
)
chunk = result[0]
assert "language" in chunk.metadata
# Should not translate due to high threshold and English detection
assert "translation" not in chunk.metadata
@pytest.mark.asyncio
async def test_error_handling_in_detection(self):
"""Test graceful error handling in language detection."""
class FailingProvider:
async def detect_language(self, _text: str) -> Tuple[str, float]:
"""
Simulate a language detection failure by always raising TestDetectionError.
This async method is used in tests to emulate a provider that fails during language detection. It accepts a text string but does not return; it always raises TestDetectionError.
"""
raise TestDetectionError()
async def translate(self, text: str, _target_language: str) -> Tuple[str, float]:
"""
Return the input text unchanged and a translation confidence of 0.0.
This no-op translator performs no translation; the supplied target language is ignored.
Parameters:
text (str): Source text to "translate".
_target_language (str): Target language (ignored).
Returns:
Tuple[str, float]: A tuple containing the original text and a confidence score (always 0.0).
"""
return text, 0.0
register_translation_provider("failing", FailingProvider)
chunks = [MockDocumentChunk("Test text", "chunk_1")]
# Disable 'langdetect' fallback to force unknown
ld = tr._provider_registry.pop("langdetect", None)
try:
result = await translate_content(chunks, translation_provider="failing")
finally:
if ld is not None:
tr._provider_registry["langdetect"] = ld
chunk = result[0]
assert "language" in chunk.metadata
# Should have unknown language due to detection failure
lang_meta = chunk.metadata["language"]
assert lang_meta["detected_language"] == "unknown"
assert lang_meta["language_confidence"] == 0.0
@pytest.mark.asyncio
async def test_error_handling_in_translation(self):
"""Test graceful error handling in translation."""
class PartialProvider:
async def detect_language(self, _text: str) -> Tuple[str, float]:
"""
Mock language detection used in tests.
Parameters:
_text (str): Input text (ignored by this mock).
Returns:
Tuple[str, float]: A fixed detected language code ("es") and confidence (0.9).
"""
return "es", 0.9
async def translate(self, _text: str, _target_language: str) -> Tuple[str, float]:
"""
Simulate a failing translation by always raising TestTranslationError.
This async method ignores its inputs and is used in tests to emulate a provider-side failure during translation.
Parameters:
_text (str): Unused input text.
_target_language (str): Unused target language code.
Raises:
TestTranslationError: Always raised to simulate a translation failure.
"""
raise TestTranslationError()
register_translation_provider("partial", PartialProvider)
chunks = [MockDocumentChunk("Hola", "chunk_1")]
result = await translate_content(
chunks,
translation_provider="partial",
confidence_threshold=0.8
)
chunk = result[0]
# Should have detected Spanish but failed translation
assert chunk.metadata["language"]["detected_language"] == "es"
# Should still create translation metadata with original text
assert "translation" in chunk.metadata
trans_meta = chunk.metadata["translation"]
assert trans_meta["translated_text"] == "Hola" # Original text due to failure
assert trans_meta["confidence_score"] == 0.0
@pytest.mark.asyncio
async def test_no_translation_when_same_language(self):
"""Test no translation occurs when source equals target language."""
register_translation_provider("same_lang", MockTranslationProvider)
chunks = [MockDocumentChunk("Hello world", "chunk_1")]
result = await translate_content(
chunks,
target_language="en", # Same as detected language
translation_provider="same_lang"
)
chunk = result[0]
assert "language" in chunk.metadata
# No translation should occur for same language
assert "translation" not in chunk.metadata
@pytest.mark.asyncio
async def test_metadata_serialization(self):
"""Test that metadata is properly serialized to dicts."""
register_translation_provider("serialize_test", MockTranslationProvider)
chunks = [MockDocumentChunk("Hola", "chunk_1")]
result = await translate_content(
chunks,
translation_provider="serialize_test",
confidence_threshold=0.8
)
chunk = result[0]
# Metadata should be plain dicts, not Pydantic models
assert isinstance(chunk.metadata["language"], dict)
if "translation" in chunk.metadata:
assert isinstance(chunk.metadata["translation"], dict)
def test_model_serialization_compatibility(self):
"""
Verify that a TranslatedContent instance can be dumped to a JSON-serializable dict.
Creates a TranslatedContent with sample fields, calls model_dump(), and asserts:
- the result is a dict,
- required fields like `original_chunk_id`, `translation_timestamp`, and `metadata` are present and preserved,
- the dict can be round-tripped through json.dumps/json.loads without losing `original_chunk_id`.
"""
content = TranslatedContent(
original_chunk_id="test",
original_text="Hello",
translated_text="Hola",
source_language="en",
target_language="es"
)
# Should serialize to dict
data = content.model_dump()
assert isinstance(data, dict)
assert data["original_chunk_id"] == "test"
assert "translation_timestamp" in data
assert "metadata" in data
# Should be JSON serializable
import json
json_str = json.dumps(data)
parsed = json.loads(json_str)
assert parsed["original_chunk_id"] == "test"
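
Outside of pytest, the snapshot/restore helpers used by the autouse fixture above (defined in `translate_content.py`, diffed below) keep registry mutations contained in the same way. A minimal sketch:

```python
import cognee.tasks.translation.translate_content as tr
from cognee.tasks.translation.translate_content import (
    NoOpProvider,
    get_available_providers,
    register_translation_provider,
)

snapshot = tr.snapshot_registry()
try:
    # Temporarily alias the built-in no-op provider under a new name.
    register_translation_provider("disabled", NoOpProvider)
    assert "disabled" in get_available_providers()
finally:
    tr.restore_registry(snapshot)  # registry is back to its original state

assert "disabled" not in get_available_providers()
```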

`cognee/tasks/translation/translate_content.py`

@@ -0,0 +1,660 @@
# pylint: disable=R0903, W0221
"""This module provides content translation capabilities for the Cognee framework."""
import asyncio
import math
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Type, Protocol, Tuple, Optional
from cognee.shared.logging_utils import get_logger
from .models import TranslatedContent, LanguageMetadata
logger = get_logger()
# Custom exceptions for better error handling
class TranslationDependencyError(ImportError):
"""Raised when a required translation dependency is missing."""
class LangDetectError(TranslationDependencyError):
"""LangDetect library required."""
class OpenAIError(TranslationDependencyError):
"""OpenAI library required."""
class GoogleTranslateError(TranslationDependencyError):
"""GoogleTrans library required."""
class AzureTranslateError(TranslationDependencyError):
"""Azure AI Translation library required."""
class AzureConfigError(ValueError):
"""Azure configuration error."""
# Environment variables for configuration
TARGET_LANGUAGE = os.getenv("COGNEE_TRANSLATION_TARGET_LANGUAGE", "en")
try:
CONFIDENCE_THRESHOLD = float(os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD", "0.80"))
except (TypeError, ValueError):
logger.warning(
"Invalid float for COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD=%r; defaulting to 0.80",
os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD"),
)
CONFIDENCE_THRESHOLD = 0.80
@dataclass
class TranslationContext:
"""A context object to hold data for a single translation operation."""
provider: "TranslationProvider"
chunk: Any
text: str
target_language: str
confidence_threshold: float
provider_name: str
content_id: str = field(init=False)
detected_language: str = "unknown"
detection_confidence: float = 0.0
requires_translation: bool = False
def __post_init__(self):
"""
Initialize derived fields after dataclass construction.
Sets self.content_id to the first available identifier on self.chunk in this order:
- self.chunk.id
- self.chunk.chunk_index
If neither attribute exists, content_id is set to the string "unknown".
"""
self.content_id = getattr(self.chunk, "id", getattr(self.chunk, "chunk_index", "unknown"))
class TranslationProvider(Protocol):
"""Protocol for translation providers."""
async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
"""
Detect the language of the provided text.
Implementations determine the most likely language and its probability.
Returns a tuple (language_code, confidence) where `language_code` is a normalized short code (e.g., "en", "fr" or "unknown") and `confidence` is a float in [0.0, 1.0]. Returns None when detection fails (empty input, an error, or no reliable result).
"""
async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
"""
Translate the given text into the specified target language asynchronously.
Parameters:
text: The source text to translate.
target_language: Target language code (e.g., "en", "es", "fr-CA").
Returns:
A tuple (translated_text, confidence) on success, where `confidence` is a float in [0.0, 1.0] (may be 0.0 if the provider does not supply a score), or None if translation failed or was unavailable.
"""
# Registry for translation providers
_provider_registry: Dict[str, Type[TranslationProvider]] = {}
def register_translation_provider(name: str, provider: Type[TranslationProvider]):
"""
Register a translation provider under a canonical lowercase key.
The provided class will be stored in the internal provider registry and looked up by its lowercased `name`. If an entry with the same key already exists it will be replaced.
Parameters:
name (str): Human-readable provider name (case-insensitive); stored as lower-case.
provider (Type[TranslationProvider]): Provider class implementing the TranslationProvider protocol; instances are constructed when the provider is resolved.
"""
_provider_registry[name.lower()] = provider
def get_available_providers():
"""Returns a list of available translation providers."""
return sorted(_provider_registry.keys())
def _get_provider(translation_provider: str) -> TranslationProvider:
"""
Resolve and instantiate a registered translation provider by name.
The lookup is case-insensitive: `translation_provider` should be the provider key (e.g., "openai", "google", "noop").
Returns an instance of the provider implementing the TranslationProvider protocol.
Raises:
ValueError: if no provider is registered under the given name; the error message lists available providers.
"""
provider_class = _provider_registry.get(translation_provider.lower())
if not provider_class:
available = ', '.join(get_available_providers())
msg = f"Unknown translation provider: {translation_provider}. Available providers: {available}"
raise ValueError(msg)
return provider_class()
# Helpers
def _normalize_lang_code(code: Optional[str]) -> str:
"""
Normalize a language code to a canonical form or return "unknown".
Normalizes common language code formats:
- Two-letter codes (e.g., "en", "EN", " en ") -> "en"
- Locale codes with region (e.g., "en-us", "en_US", "EN-us") -> "en-US"
- Returns "unknown" for empty, non-string, or unrecognized inputs.
Parameters:
code (Optional[str]): Language code or locale string to normalize.
Returns:
str: Normalized language code in either "xx" or "xx-YY" form, or "unknown" if input is invalid.
"""
if not isinstance(code, str) or not code.strip():
return "unknown"
c = code.strip().replace("_", "-")
parts = c.split("-")
if len(parts) == 1 and len(parts[0]) == 2 and parts[0].isalpha():
return parts[0].lower()
if len(parts) >= 2 and len(parts[0]) == 2 and parts[1]:
return f"{parts[0].lower()}-{parts[1][:2].upper()}"
return "unknown"
def _provider_name(provider: TranslationProvider) -> str:
"""Return the canonical registry key for a provider instance, or a best-effort name."""
return next(
(name for name, cls in _provider_registry.items() if isinstance(provider, cls)),
provider.__class__.__name__.replace("Provider", "").lower(),
)
async def _detect_language_with_fallback(provider: TranslationProvider, text: str, content_id: str) -> Tuple[str, float]:
"""
Detect the language of `text`, falling back to the registered "langdetect" provider if the primary provider fails.
Attempts to call the primary provider's `detect_language`. If that call returns None or raises, and a different "langdetect" provider is registered, it will try the fallback. Detection failures are logged; exceptions are not propagated.
Parameters:
text (str): The text to detect language for.
content_id (str): Identifier used in logs to correlate errors to the input content.
Returns:
Tuple[str, float]: A normalized language code (e.g., "en" or "pt-BR") and a confidence score in [0.0, 1.0].
On detection failure returns ("unknown", 0.0). Confidence values are coerced to float, NaNs converted to 0.0, and clamped to the [0.0, 1.0] range.
"""
try:
detection = await provider.detect_language(text)
except Exception:
logger.exception("Language detection failed for content_id=%s", content_id)
detection = None
if detection is None:
fallback_cls = _provider_registry.get("langdetect")
if fallback_cls is not None and not isinstance(provider, fallback_cls):
try:
detection = await fallback_cls().detect_language(text)
except Exception:
logger.exception("Fallback language detection failed for content_id=%s", content_id)
detection = None
if detection is None:
return "unknown", 0.0
lang_code, conf = detection
detected_language = _normalize_lang_code(lang_code)
try:
conf = float(conf)
except (TypeError, ValueError):
conf = 0.0
if math.isnan(conf):
conf = 0.0
conf = max(0.0, min(1.0, conf))
return detected_language, conf
def _decide_if_translation_is_required(ctx: TranslationContext) -> None:
"""
Decide whether a translation should be performed and update ctx.requires_translation.
Normalizes the configured target language and marks translation as required when either:
- The detected language is "unknown", the text is non-empty, and the provider can perform translations (not "noop" or "langdetect"), or
- The detected language (normalized) differs from the target language and the detection confidence meets or exceeds ctx.confidence_threshold.
The function mutates the provided TranslationContext in-place and does not return a value.
"""
# Normalize to align with detected_language normalization and model regex.
target_language = _normalize_lang_code(ctx.target_language)
can_translate = ctx.provider_name not in ("noop", "langdetect")
if ctx.detected_language == "unknown":
ctx.requires_translation = can_translate and bool(ctx.text.strip())
else:
ctx.requires_translation = (
ctx.detected_language != target_language
and ctx.detection_confidence >= ctx.confidence_threshold
)
def _attach_language_metadata(ctx: TranslationContext) -> None:
"""
Attach language detection and translation decision metadata to the context's chunk.
Ensures the chunk has a metadata mapping, builds a LanguageMetadata record from
the context (content_id, detected language and confidence, whether translation is
required, and character count of the text), serializes it, and stores it under
the "language" key in chunk.metadata.
Parameters:
ctx (TranslationContext): Context containing the chunk and detection/decision values.
"""
ctx.chunk.metadata = getattr(ctx.chunk, "metadata", {}) or {}
lang_meta = LanguageMetadata(
content_id=str(ctx.content_id),
detected_language=ctx.detected_language,
language_confidence=ctx.detection_confidence,
requires_translation=ctx.requires_translation,
character_count=len(ctx.text),
)
ctx.chunk.metadata["language"] = lang_meta.model_dump()
async def _translate_and_update(ctx: TranslationContext) -> None:
"""
Translate the text in the provided TranslationContext and update the chunk and its metadata.
Performs an async translation via ctx.provider.translate, and when a non-empty, changed translation is returned:
- replaces ctx.chunk.text with the translated text,
- attempts to update ctx.chunk.chunk_size (if present),
- attaches a `translation` entry in ctx.chunk.metadata containing a TranslatedContent dict (original/translated text, source/target languages, provider, and confidence).
If translation fails (exception or None) the original text is preserved and a TranslatedContent record is still attached with confidence 0.0. If the provider returns the same text unchanged, no metadata is attached and the function returns without modifying the chunk.
Parameters:
ctx (TranslationContext): context carrying provider, chunk, original text, target language, detected language, and content_id.
Returns:
None
"""
try:
tr = await ctx.provider.translate(ctx.text, ctx.target_language)
except Exception:
logger.exception("Translation failed for content_id=%s", ctx.content_id)
tr = None
translated_text = None
translation_confidence = 0.0
provider_used = _provider_name(ctx.provider)
target_for_meta = _normalize_lang_code(ctx.target_language)
if tr and isinstance(tr[0], str) and tr[0].strip() and tr[0] != ctx.text:
translated_text, translation_confidence = tr
ctx.chunk.text = translated_text
if hasattr(ctx.chunk, "chunk_size"):
try:
ctx.chunk.chunk_size = len(translated_text.split())
except (AttributeError, ValueError, TypeError):
logger.debug(
"Could not update chunk_size for content_id=%s",
ctx.content_id,
exc_info=True,
)
elif tr is None:
# Translation failed, keep original text
translated_text = ctx.text
else:
# Provider returned unchanged text
logger.info("Provider returned unchanged text; skipping translation metadata (content_id=%s)", ctx.content_id)
return
trans = TranslatedContent(
original_chunk_id=str(ctx.content_id),
original_text=ctx.text,
translated_text=translated_text,
source_language=ctx.detected_language,
target_language=target_for_meta,
translation_provider=provider_used,
confidence_score=translation_confidence or 0.0,
)
ctx.chunk.metadata["translation"] = trans.model_dump()
# Test helpers for registry isolation
def snapshot_registry() -> Dict[str, Type[TranslationProvider]]:
"""Return a shallow copy snapshot of the provider registry (for tests)."""
return dict(_provider_registry)
def restore_registry(snapshot: Dict[str, Type[TranslationProvider]]) -> None:
"""
Restore the global translation provider registry from a previously captured snapshot.
This replaces the current internal provider registry with the given snapshot (clears then updates),
typically used by tests to restore provider registration state.
Parameters:
snapshot (Dict[str, Type[TranslationProvider]]): Mapping of provider name keys to provider classes.
"""
_provider_registry.clear()
_provider_registry.update(snapshot)
def validate_provider(name: str) -> None:
"""Ensure a provider can be resolved and instantiated or raise."""
_get_provider(name)
# Built-in Providers
class NoOpProvider:
"""A provider that does nothing, used for testing or disabling translation."""
async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]:
"""
No-op language detection: intentionally performs no detection and always returns None.
The `_text` parameter is ignored. Returns None to indicate that this provider does not provide a language detection result.
"""
return None
async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
"""
Return the input text unchanged and a confidence score of 0.0.
This provider does not perform any translation; it mirrors the source text back to the caller.
Parameters:
text (str): Source text to "translate".
_target_language (str): Unused target language parameter.
Returns:
Optional[Tuple[str, float]]: A tuple of (text, 0.0).
"""
return text, 0.0
class LangDetectProvider:
"""
A provider that uses the 'langdetect' library for offline language detection.
This provider only detects the language and does not perform translation.
"""
def __init__(self):
"""
Initialize the LangDetectProvider by loading the `langdetect.detect_langs` function.
Attempts to import `detect_langs` from the `langdetect` package and stores it on the instance as `_detect_langs`. Raises `LangDetectError` if the `langdetect` dependency is not available.
"""
try:
from langdetect import detect_langs # type: ignore[import-untyped]
self._detect_langs = detect_langs
except ImportError as e:
raise LangDetectError() from e
async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
"""
Detect the language of `text` using the provider's langdetect backend.
Returns a tuple of (language_code, confidence) where `language_code` is the top
detected language (e.g., "en") and `confidence` is the detection probability
in [0.0, 1.0]. Returns None if detection fails or no result is available.
"""
try:
detections = await asyncio.to_thread(self._detect_langs, text)
except Exception:
logger.exception("Error during language detection")
return None
if not detections:
return None
best_detection = detections[0]
return best_detection.lang, best_detection.prob
async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
# This provider only detects language, does not translate.
"""
No-op translation: returns the input text unchanged with a 0.0 confidence.
This provider only performs language detection; translate is a passthrough that returns the original `text`
and a confidence of 0.0 to indicate no translated content was produced.
Returns:
A tuple of (text, confidence) where `text` is the original input and `confidence` is 0.0.
"""
return text, 0.0
class OpenAIProvider:
"""A provider that uses OpenAI's API for translation."""
def __init__(self):
"""
Initialize the OpenAIProvider by creating an AsyncOpenAI client and loading configuration.
Reads the following environment variables:
- OPENAI_API_KEY: API key passed to AsyncOpenAI for authentication.
- OPENAI_TRANSLATE_MODEL: model name to use for translations (default: "gpt-4o-mini").
- OPENAI_TIMEOUT: request timeout in seconds (default: "30", parsed as float).
Raises:
OpenAIError: if the OpenAI SDK (AsyncOpenAI) cannot be imported.
"""
try:
from openai import AsyncOpenAI # type: ignore[import-untyped]
self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
self.model = os.getenv("OPENAI_TRANSLATE_MODEL", "gpt-4o-mini")
self.timeout = float(os.getenv("OPENAI_TIMEOUT", "30"))
except ImportError as e:
raise OpenAIError() from e
async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]:
# OpenAI's API does not have a separate language detection endpoint.
# This can be implemented as part of the translation prompt if needed.
"""
Indicates that this provider does not perform standalone language detection.
The OpenAI-based provider does not expose a separate detection endpoint and therefore
always returns None. Language detection can be achieved by using another provider
(e.g., the registered langdetect provider) or by incorporating detection into a
translation prompt if needed.
"""
return None
async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
"""
Translate the given text to the specified target language using the OpenAI chat completions client.
Parameters:
text (str): Source text to translate.
target_language (str): Target language name or code (used verbatim in the translation prompt).
Returns:
Optional[Tuple[str, float]]: A tuple of (translated_text, confidence). Confidence is 0.0 because no calibrated confidence is available.
Returns None if translation failed or an error occurred.
"""
try:
response = await self.client.with_options(timeout=self.timeout).chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": f"Translate the following text to {target_language}."},
{"role": "user", "content": text},
],
temperature=0.0,
)
except Exception:
logger.exception("Error during OpenAI translation (model=%s)", self.model)
return None
translated_text = response.choices[0].message.content.strip()
return translated_text, 0.0 # No calibrated confidence available.
class GoogleTranslateProvider:
"""A provider that uses the 'googletrans' library for translation."""
def __init__(self):
"""
Initialize the GoogleTranslateProvider by importing and instantiating googletrans.Translator.
Raises:
GoogleTranslateError: If the `googletrans` library is not installed or cannot be imported.
"""
try:
from googletrans import Translator # type: ignore[import-untyped]
self.translator = Translator()
except ImportError as e:
raise GoogleTranslateError() from e
async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
"""
Detect the language of the given text using the configured googletrans Translator.
Uses a thread to call the synchronous translator.detect method; on failure returns None.
Parameters:
text: The text to detect the language for.
Returns:
A tuple (language_code, confidence) where `language_code` is the detected language string from the translator (e.g. "en") and `confidence` is a float in [0.0, 1.0]. Returns None if detection fails.
"""
try:
detection = await asyncio.to_thread(self.translator.detect, text)
except Exception:
logger.exception("Error during Google Translate language detection")
return None
try:
conf = float(detection.confidence) if detection.confidence is not None else 0.0
except (TypeError, ValueError):
conf = 0.0
return detection.lang, conf
async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
"""
Translate `text` to `target_language` using the configured googletrans Translator.
Returns a tuple (translated_text, confidence) on success (confidence is always 0.0 because googletrans does not provide a confidence score), or None if translation fails.
"""
try:
translation = await asyncio.to_thread(self.translator.translate, text, dest=target_language)
except Exception:
logger.exception("Error during Google Translate translation")
return None
return translation.text, 0.0 # Confidence not provided.
class AzureTranslatorProvider:
"""A provider that uses Azure's Translator service."""
def __init__(self):
"""
Initialize the AzureTranslatorProvider.
Attempts to import Azure SDK classes, reads AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT,
and AZURE_TRANSLATOR_REGION from the environment, verifies the key is present, and constructs
a TextTranslationClient using an AzureKeyCredential.
Raises:
AzureConfigError: if AZURE_TRANSLATOR_KEY is not set.
AzureTranslateError: if required Azure SDK imports are unavailable.
"""
try:
from azure.core.credentials import AzureKeyCredential # type: ignore[import-untyped]
from azure.ai.translation.text import TextTranslationClient # type: ignore[import-untyped]
self.key = os.getenv("AZURE_TRANSLATOR_KEY")
self.endpoint = os.getenv("AZURE_TRANSLATOR_ENDPOINT", "https://api.cognitive.microsofttranslator.com/")
self.region = os.getenv("AZURE_TRANSLATOR_REGION", "global")
if not self.key:
raise AzureConfigError()
self.client = TextTranslationClient(
endpoint=self.endpoint,
credential=AzureKeyCredential(self.key),
)
except ImportError as e:
raise AzureTranslateError() from e
async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
"""
Detect the language of the given text using the Azure Translator client's detect API.
Attempts to call the Azure client's detect method (using a two-letter region as a country hint when available)
and returns a tuple of (language_code, confidence_score). Returns None if detection fails or an exception occurs.
Parameters:
text (str): The text to detect language for.
Returns:
Optional[Tuple[str, float]]: (ISO language code, confidence between 0.0 and 1.0), or None on error.
"""
try:
# Use a valid country hint only when it looks like ISO 3166-1 alpha-2; otherwise omit.
hint = self.region.lower() if isinstance(self.region, str) and len(self.region) == 2 else None
response = await asyncio.to_thread(self.client.detect, content=[text], country_hint=hint)
except Exception:
logger.exception("Error during Azure language detection")
return None
detection = response[0].primary_language
return detection.language, detection.score
async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
"""
Translate the given text to the target language using the Azure Translator client.
Parameters:
text (str): Plain text to translate.
target_language (str): BCP-47 or ISO language code to translate the text into.
Returns:
Optional[Tuple[str, float]]: A tuple of (translated_text, confidence). Returns None on error.
The provider does not surface a numeric confidence score, so the returned confidence is always 0.0.
"""
try:
response = await asyncio.to_thread(self.client.translate, content=[text], to=[target_language])
except Exception:
logger.exception("Error during Azure translation")
return None
translation = response[0].translations[0]
return translation.text, 0.0 # Confidence not provided.
# Register built-in providers
register_translation_provider("noop", NoOpProvider)
register_translation_provider("langdetect", LangDetectProvider)
register_translation_provider("openai", OpenAIProvider)
register_translation_provider("google", GoogleTranslateProvider)
register_translation_provider("azure", AzureTranslatorProvider)
async def translate_content( # pylint: disable=too-many-locals,too-many-branches
*data_chunks,
target_language: str = TARGET_LANGUAGE,
translation_provider: str = "noop",
confidence_threshold: float = CONFIDENCE_THRESHOLD,
):
"""
Translate content chunks to a target language and attach language and translation metadata.
This function accepts either multiple chunk objects as varargs or a single list of chunks.
For each chunk it:
- Resolves the named translation provider.
- Detects the chunk's language (with a fallback detector when available).
- Decides whether translation is required based on detected language, confidence threshold, and provider.
- Attaches language metadata (LanguageMetadata) to chunk.metadata.
- If required, performs translation and updates the chunk text and metadata (TranslatedContent).
Parameters:
*data_chunks: One or more chunk objects, or a single list of chunk objects. Each chunk must expose a `text` attribute and a `metadata` mapping (the function will create `metadata` if missing).
target_language (str): Language code to translate into (defaults to TARGET_LANGUAGE).
translation_provider (str): Registered provider name to use for detection/translation (defaults to "noop").
confidence_threshold (float): Minimum detection confidence required for a detected non-target language to trigger translation (defaults to CONFIDENCE_THRESHOLD).
Returns:
list: The list of processed chunk objects (same objects, possibly modified). Metadata keys added include language detection results and, when a translation occurs, translation details.
"""
provider = _get_provider(translation_provider)
results = []
if len(data_chunks) == 1 and isinstance(data_chunks[0], list):
_chunks = data_chunks[0]
else:
_chunks = list(data_chunks)
for chunk in _chunks:
ctx = TranslationContext(
provider=provider,
chunk=chunk,
text=getattr(chunk, "text", "") or "",
target_language=target_language,
confidence_threshold=confidence_threshold,
provider_name=translation_provider.lower(),
)
ctx.detected_language, ctx.detection_confidence = await _detect_language_with_fallback(
ctx.provider, ctx.text, str(ctx.content_id)
)
_decide_if_translation_is_required(ctx)
_attach_language_metadata(ctx)
if ctx.requires_translation:
await _translate_and_update(ctx)
results.append(ctx.chunk)
return results
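
The registry makes it straightforward to plug in a custom provider without touching this module. A minimal, self-contained sketch (the `UppercaseProvider`, `DemoChunk`, and the tagging "translation" are purely illustrative, not part of the commit):

```python
import asyncio
from typing import Optional, Tuple

from cognee.tasks.translation.translate_content import (
    register_translation_provider,
    translate_content,
)


class UppercaseProvider:
    """Toy provider: claims everything is Spanish and 'translates' by tagging."""

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        return "es", 0.95  # pretend detection with high confidence

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        return f"[{target_language.upper()}] {text}", 0.5


class DemoChunk:
    """Stand-in for a DocumentChunk: exposes text, id, and metadata."""

    def __init__(self, text: str):
        self.text = text
        self.id = "demo_chunk"
        self.metadata = {}


async def demo():
    register_translation_provider("uppercase", UppercaseProvider)
    chunks = await translate_content(
        [DemoChunk("Hola mundo")],
        target_language="en",
        translation_provider="uppercase",
    )
    print(chunks[0].text)                     # [EN] Hola mundo
    print(chunks[0].metadata["language"])     # detection metadata dict
    print(chunks[0].metadata["translation"])  # translation metadata dict


asyncio.run(demo())
```

Because the detected language ("es") differs from the target ("en") and the confidence (0.95) clears the default threshold, the chunk text is replaced and both metadata entries are attached.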

`examples/python/translation_example.py`

@@ -0,0 +1,86 @@
import asyncio
import os
import cognee
from cognee.api.v1.search import SearchType
from cognee.api.v1.cognify.cognify import get_default_tasks_with_translation
from cognee.modules.pipelines.operations.pipeline import run_pipeline
# Prerequisites:
# 1. Set up your environment with API keys for your chosen translation provider.
# - For OpenAI: OPENAI_API_KEY
# - For Azure: AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT, AZURE_TRANSLATOR_REGION
# 2. Specify the translation provider via an environment variable (optional, defaults to "noop"):
# COGNEE_TRANSLATION_PROVIDER="openai" # Or "google", "azure", "langdetect"
# 3. Install any required libraries for your provider:
# - pip install langdetect googletrans==4.0.0rc1 azure-ai-translation-text
async def main():
"""
Demonstrates an end-to-end translation-enabled Cognify workflow using the Cognee SDK.
Performs three main steps:
1. Resets the demo workspace by pruning stored data and system metadata.
2. Seeds three multilingual documents, builds translation-enabled Cognify tasks using the
provider specified by the COGNEE_TRANSLATION_PROVIDER environment variable (defaults to "noop"),
and executes the pipeline to translate and process the documents.
- If the selected provider is missing or invalid, the function prints the error and returns early.
3. Issues an English search query (using SearchType.INSIGHTS) against the processed index and
prints any returned result texts.
Side effects:
- Mutates persistent Cognee state (prune, add, cognify pipeline execution).
- Prints status and result messages to stdout.
Notes:
- No return value.
- ValueError and ImportError are caught and handled by printing an error message and returning early.
"""
# 1. Set up cognee and add multilingual content
print("Setting up demo environment...")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
multilingual_texts = [
"El procesamiento de lenguaje natural (PLN) es un subcampo de la IA.",
"Le traitement automatique du langage naturel (TALN) est un sous-domaine de l'IA.",
"Natural language processing (NLP) is a subfield of AI.",
]
print("Adding multilingual texts...")
for text in multilingual_texts:
await cognee.add(text)
print("Texts added successfully.\n")
# 2. Run the cognify pipeline with translation enabled
provider = os.getenv('COGNEE_TRANSLATION_PROVIDER', 'noop').lower()
print(f"Running cognify with translation provider: {provider}")
try:
# Build translation-enabled tasks and execute the pipeline
translation_enabled_tasks = get_default_tasks_with_translation(
translation_provider=provider
)
async for _ in run_pipeline(tasks=translation_enabled_tasks):
pass
print("Cognify pipeline with translation completed successfully.")
except (ValueError, ImportError) as e:
print(f"Error during cognify: {e}")
print("Please ensure the selected provider is installed and configured correctly.")
return
# 3. Search for content in English
query_text = "Tell me about NLP"
print(f"\nSearching for: '{query_text}'")
# The search should now return results from all documents, as they have been translated.
search_results = await cognee.search(query_text, query_type=SearchType.INSIGHTS)
print("\nSearch Results:")
if search_results:
for result in search_results:
print(f"- {result.text}")
else:
print("No results found.")
if __name__ == "__main__":
asyncio.run(main())