From 9f6b2dca51a936a9de482fc9f3c64934502240b6 Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Sat, 13 Sep 2025 07:57:43 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20Add=20docstrings=20to=20`auto-tr?= =?UTF-8?q?anslate-task`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstrings generation was requested by @subhash-0000. * https://github.com/topoteretes/cognee/pull/1353#issuecomment-3287760071 The following files were modified: * `cognee/api/v1/cognify/cognify.py` * `cognee/tasks/translation/test_translation.py` * `cognee/tasks/translation/translate_content.py` * `examples/python/translation_example.py` --- cognee/api/v1/cognify/cognify.py | 385 +++++----- cognee/tasks/translation/test_translation.py | 496 +++++++++++++ cognee/tasks/translation/translate_content.py | 660 ++++++++++++++++++ examples/python/translation_example.py | 86 +++ 4 files changed, 1432 insertions(+), 195 deletions(-) create mode 100644 cognee/tasks/translation/test_translation.py create mode 100644 cognee/tasks/translation/translate_content.py create mode 100644 examples/python/translation_example.py diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index e4f91b44c..bc3a3b1fd 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -1,18 +1,21 @@ -import asyncio from pydantic import BaseModel -from typing import Union, Optional +from typing import Union, Optional, Type from uuid import UUID +import os + -from cognee.shared.logging_utils import get_logger from cognee.shared.data_models import KnowledgeGraph -from cognee.infrastructure.llm import get_max_chunk_tokens +from cognee.infrastructure.llm.utils import get_max_chunk_tokens +from cognee.shared.logging_utils import get_logger -from cognee.modules.pipelines import run_pipeline +from cognee.modules.pipelines.operations.pipeline import run_pipeline from 
cognee.modules.pipelines.tasks.task import Task from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver from cognee.modules.users.models import User +logger = get_logger() + from cognee.tasks.documents import ( check_permissions_on_dataset, classify_documents, @@ -21,179 +24,101 @@ from cognee.tasks.documents import ( from cognee.tasks.graph import extract_graph_from_data from cognee.tasks.storage import add_data_points from cognee.tasks.summarization import summarize_text +from cognee.tasks.translation import translate_content, get_available_providers, validate_provider from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor -from cognee.tasks.temporal_graph.extract_events_and_entities import extract_events_and_timestamps -from cognee.tasks.temporal_graph.extract_knowledge_graph_from_events import ( - extract_knowledge_graph_from_events, -) -logger = get_logger("cognify") +class TranslationProviderError(ValueError): + """Error related to translation provider initialization.""" + pass -update_status_lock = asyncio.Lock() +class UnknownTranslationProviderError(TranslationProviderError): + """Unknown translation provider name.""" + +class ProviderInitializationError(TranslationProviderError): + """Provider failed to initialize (likely missing dependency or bad config).""" -async def cognify( - datasets: Union[str, list[str], list[UUID]] = None, - user: User = None, - graph_model: BaseModel = KnowledgeGraph, +_WARNED_ENV_VARS: set[str] = set() + +def _parse_batch_env(var: str, default: int = 10) -> int: + """ + Parse an environment variable as a positive integer (minimum 1), falling back to a default. + + If the environment variable named `var` is unset, the provided `default` is returned. 
+ If the variable is set but cannot be parsed as an integer, `default` is returned and a + one-time warning is logged for that variable (the variable name is recorded in + `_WARNED_ENV_VARS` to avoid repeated warnings). + + Parameters: + var: Name of the environment variable to read. + default: Fallback integer value returned when the variable is missing or invalid. + + Returns: + An integer >= 1 representing the parsed value or the fallback `default`. + """ + raw = os.getenv(var) + if raw is None: + return default + try: + return max(1, int(raw)) + except (TypeError, ValueError): + if var not in _WARNED_ENV_VARS: + logger.warning("Invalid int for %s=%r; using default=%d", var, raw, default) + _WARNED_ENV_VARS.add(var) + return default + +# Constants for batch processing +DEFAULT_BATCH_SIZE = _parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE", 10) + +async def cognify( # pylint: disable=too-many-arguments,too-many-positional-arguments + datasets: Optional[Union[str, UUID, list[str], list[UUID]]] = None, + user: Optional[User] = None, + graph_model: Type[BaseModel] = KnowledgeGraph, chunker=TextChunker, - chunk_size: int = None, + chunk_size: Optional[int] = None, ontology_file_path: Optional[str] = None, - vector_db_config: dict = None, - graph_db_config: dict = None, + vector_db_config: Optional[dict] = None, + graph_db_config: Optional[dict] = None, run_in_background: bool = False, incremental_loading: bool = True, custom_prompt: Optional[str] = None, - temporal_cognify: bool = False, ): """ - Transform ingested data into a structured knowledge graph. - - This is the core processing step in Cognee that converts raw text and documents - into an intelligent knowledge graph. It analyzes content, extracts entities and - relationships, and creates semantic connections for enhanced search and reasoning. 
- - Prerequisites: - - **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation) - - **Data Added**: Must have data previously added via `cognee.add()` - - **Vector Database**: Must be accessible for embeddings storage - - **Graph Database**: Must be accessible for relationship storage - - Input Requirements: - - **Datasets**: Must contain data previously added via `cognee.add()` - - **Content Types**: Works with any text-extractable content including: - * Natural language documents - * Structured data (CSV, JSON) - * Code repositories - * Academic papers and technical documentation - * Mixed multimedia content (with text extraction) - - Processing Pipeline: - 1. **Document Classification**: Identifies document types and structures - 2. **Permission Validation**: Ensures user has processing rights - 3. **Text Chunking**: Breaks content into semantically meaningful segments - 4. **Entity Extraction**: Identifies key concepts, people, places, organizations - 5. **Relationship Detection**: Discovers connections between entities - 6. **Graph Construction**: Builds semantic knowledge graph with embeddings - 7. **Content Summarization**: Creates hierarchical summaries for navigation - - Graph Model Customization: - The `graph_model` parameter allows custom knowledge structures: - - **Default**: General-purpose KnowledgeGraph for any domain - - **Custom Models**: Domain-specific schemas (e.g., scientific papers, code analysis) - - **Ontology Integration**: Use `ontology_file_path` for predefined vocabularies - - Args: - datasets: Dataset name(s) or dataset uuid to process. Processes all available data if None. - - Single dataset: "my_dataset" - - Multiple datasets: ["docs", "research", "reports"] - - None: Process all datasets for the user - user: User context for authentication and data access. Uses default if None. - graph_model: Pydantic model defining the knowledge graph structure. 
- Defaults to KnowledgeGraph for general-purpose processing. - chunker: Text chunking strategy (TextChunker, LangchainChunker). - - TextChunker: Paragraph-based chunking (default, most reliable) - - LangchainChunker: Recursive character splitting with overlap - Determines how documents are segmented for processing. - chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None. - Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2) - Default limits: ~512-8192 tokens depending on models. - Smaller chunks = more granular but potentially fragmented knowledge. - ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types. - Useful for specialized fields like medical or legal documents. - vector_db_config: Custom vector database configuration for embeddings storage. - graph_db_config: Custom graph database configuration for relationship storage. - run_in_background: If True, starts processing asynchronously and returns immediately. - If False, waits for completion before returning. - Background mode recommended for large datasets (>100MB). - Use pipeline_run_id from return value to monitor progress. - custom_prompt: Optional custom prompt string to use for entity extraction and graph generation. - If provided, this prompt will be used instead of the default prompts for - knowledge graph extraction. The prompt should guide the LLM on how to - extract entities and relationships from the text content. - + Orchestrate processing of datasets into a knowledge graph. + + Builds the default Cognify task sequence (classification, permission check, chunking, + graph extraction, summarization, indexing) and executes it via the pipeline + executor. Use get_default_tasks_with_translation(...) to include an automatic + translation step before graph extraction. + + Parameters: + datasets: Optional dataset id or list of ids to process. If None, processes all + datasets available to the user. 
+ user: Optional user context used for permission checks; defaults to the current + runtime user if omitted. + graph_model: Pydantic model type that defines the structure of produced graph + DataPoints (default: KnowledgeGraph). + chunker: Chunking strategy/class used to split documents (default: TextChunker). + chunk_size: Optional max tokens per chunk; when None a sensible default is used. + ontology_file_path: Optional path to an ontology (RDF/OWL) used by the extractor. + vector_db_config: Optional mapping of vector DB configuration (overrides defaults). + graph_db_config: Optional mapping of graph DB configuration (overrides defaults). + run_in_background: If True, starts the pipeline asynchronously and returns + background run info; if False, waits for completion and returns results. + incremental_loading: If True, performs incremental loading to avoid reprocessing + unchanged content. + custom_prompt: Optional prompt to override the default prompt used for graph + extraction. + Returns: - Union[dict, list[PipelineRunInfo]]: - - **Blocking mode**: Dictionary mapping dataset_id -> PipelineRunInfo with: - * Processing status (completed/failed/in_progress) - * Extracted entity and relationship counts - * Processing duration and resource usage - * Error details if any failures occurred - - **Background mode**: List of PipelineRunInfo objects for tracking progress - * Use pipeline_run_id to monitor status - * Check completion via pipeline monitoring APIs - - Next Steps: - After successful cognify processing, use search functions to query the knowledge: - - ```python - import cognee - from cognee import SearchType - - # Process your data into knowledge graph - await cognee.cognify() - - # Query for insights using different search types: - - # 1. Natural language completion with graph context - insights = await cognee.search( - "What are the main themes?", - query_type=SearchType.GRAPH_COMPLETION - ) - - # 2. 
Get entity relationships and connections - relationships = await cognee.search( - "connections between concepts", - query_type=SearchType.INSIGHTS - ) - - # 3. Find relevant document chunks - chunks = await cognee.search( - "specific topic", - query_type=SearchType.CHUNKS - ) - ``` - - Advanced Usage: - ```python - # Custom domain model for scientific papers - class ScientificPaper(DataPoint): - title: str - authors: List[str] - methodology: str - findings: List[str] - - await cognee.cognify( - datasets=["research_papers"], - graph_model=ScientificPaper, - ontology_file_path="scientific_ontology.owl" - ) - - # Background processing for large datasets - run_info = await cognee.cognify( - datasets=["large_corpus"], - run_in_background=True - ) - # Check status later with run_info.pipeline_run_id - ``` - - - Environment Variables: - Required: - - LLM_API_KEY: API key for your LLM provider - - Optional (same as add function): - - LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER - - LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False) - - LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60) + The pipeline executor result. In blocking mode this is the pipeline run result + (per-dataset run info and status). In background mode this returns information + required to track the background run (e.g., pipeline_run_id and submission status). 
""" - if temporal_cognify: - tasks = await get_temporal_tasks(user, chunker, chunk_size) - else: - tasks = await get_default_tasks( - user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt - ) + tasks = get_default_tasks( + user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt + ) # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background) @@ -211,20 +136,49 @@ async def cognify( ) -async def get_default_tasks( # TODO: Find out a better way to do this (Boris's comment) - user: User = None, - graph_model: BaseModel = KnowledgeGraph, +def get_default_tasks( # pylint: disable=too-many-arguments,too-many-positional-arguments + user: Optional[User] = None, + graph_model: Type[BaseModel] = KnowledgeGraph, chunker=TextChunker, - chunk_size: int = None, + chunk_size: Optional[int] = None, ontology_file_path: Optional[str] = None, custom_prompt: Optional[str] = None, ) -> list[Task]: + """ + Return the standard, non-translation Task list used by the cognify pipeline. + + This builds the default processing pipeline (no automatic translation) and returns + a list of Task objects in execution order: + 1. classify_documents + 2. check_permissions_on_dataset (enforces write permission for `user`) + 3. extract_chunks_from_documents (uses `chunker` and `chunk_size`) + 4. extract_graph_from_data (uses `graph_model`, optional `ontology_file_path`, and `custom_prompt`) + 5. summarize_text + 6. add_data_points + + Notes: + - Batch sizes for downstream tasks use the module-level DEFAULT_BATCH_SIZE. + - If `chunk_size` is not provided, the token limit from get_max_chunk_tokens() is used. + + Parameters: + user: Optional user context used for the permission check. + graph_model: Model class used to construct knowledge graph instances. 
+ chunker: Chunking strategy or class used to split documents into chunks. + chunk_size: Optional max tokens per chunk; if omitted, defaults to get_max_chunk_tokens(). + ontology_file_path: Optional path to an ontology file passed to the extractor. + custom_prompt: Optional custom prompt applied during graph extraction. + + Returns: + List[Task]: Ordered list of Task objects for the cognify pipeline (no translation). + """ + # Precompute max_chunk_size for stability + max_chunk = chunk_size or get_max_chunk_tokens() default_tasks = [ Task(classify_documents), Task(check_permissions_on_dataset, user=user, permissions=["write"]), Task( extract_chunks_from_documents, - max_chunk_size=chunk_size or get_max_chunk_tokens(), + max_chunk_size=max_chunk, chunker=chunker, ), # Extract text chunks based on the document type. Task( @@ -232,51 +186,92 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's graph_model=graph_model, ontology_adapter=OntologyResolver(ontology_file=ontology_file_path), custom_prompt=custom_prompt, - task_config={"batch_size": 10}, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, ), # Generate knowledge graphs from the document chunks. 
Task( summarize_text, - task_config={"batch_size": 10}, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, ), - Task(add_data_points, task_config={"batch_size": 10}), + Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}), ] return default_tasks -async def get_temporal_tasks( - user: User = None, chunker=TextChunker, chunk_size: int = None +def get_default_tasks_with_translation( # pylint: disable=too-many-arguments,too-many-positional-arguments + user: Optional[User] = None, + graph_model: Type[BaseModel] = KnowledgeGraph, + chunker=TextChunker, + chunk_size: Optional[int] = None, + ontology_file_path: Optional[str] = None, + custom_prompt: Optional[str] = None, + translation_provider: str = "noop", ) -> list[Task]: """ - Builds and returns a list of temporal processing tasks to be executed in sequence. - - The pipeline includes: - 1. Document classification. - 2. Dataset permission checks (requires "write" access). - 3. Document chunking with a specified or default chunk size. - 4. Event and timestamp extraction from chunks. - 5. Knowledge graph extraction from events. - 6. Batched insertion of data points. - - Args: - user (User, optional): The user requesting task execution, used for permission checks. - chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker. - chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default. - + Return the default Cognify pipeline task list with an added translation step. + + Constructs the standard processing pipeline (classify -> permission check -> chunk extraction -> translate -> graph extraction -> summarize -> add data points), + validates and initializes the named translation provider, and applies module DEFAULT_BATCH_SIZE to downstream batchable tasks. + + Parameters: + translation_provider (str): Name of a registered translation provider (case-insensitive). Defaults to `"noop"` which is a no-op provider. 
+ Returns: - list[Task]: A list of Task objects representing the temporal processing pipeline. + list[Task]: Ordered Task objects ready to be executed by the pipeline executor. + + Raises: + UnknownTranslationProviderError: If the given provider name is not in get_available_providers(). + ProviderInitializationError: If the provider fails to initialize or validate via validate_provider(). """ - temporal_tasks = [ + # Fail fast on unknown providers (keeps errors close to the API surface) + translation_provider = (translation_provider or "noop").strip().lower() + # Validate provider using public API + if translation_provider not in get_available_providers(): + available = ", ".join(get_available_providers()) + logger.error("Unknown provider '%s'. Available: %s", translation_provider, available) + raise UnknownTranslationProviderError(f"Unknown provider '{translation_provider}'") + # Instantiate to validate dependencies; include provider-specific config errors + try: + validate_provider(translation_provider) + except Exception as e: # we want to convert provider init errors + available = ", ".join(get_available_providers()) + logger.error( + "Provider '%s' failed to initialize (available: %s).", + translation_provider, + available, + exc_info=True, + ) + raise ProviderInitializationError() from e + + # Precompute max_chunk_size for stability + max_chunk = chunk_size or get_max_chunk_tokens() + + default_tasks = [ Task(classify_documents), Task(check_permissions_on_dataset, user=user, permissions=["write"]), Task( extract_chunks_from_documents, - max_chunk_size=chunk_size or get_max_chunk_tokens(), + max_chunk_size=max_chunk, chunker=chunker, + ), # Extract text chunks based on the document type. 
+ Task( + translate_content, + target_language="en", + translation_provider=translation_provider, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, + ), # Auto-translate non-English content and attach metadata + Task( + extract_graph_from_data, + graph_model=graph_model, + ontology_adapter=OntologyResolver(ontology_file=ontology_file_path), + custom_prompt=custom_prompt, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, + ), # Generate knowledge graphs from the document chunks. + Task( + summarize_text, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, ), - Task(extract_events_and_timestamps, task_config={"chunk_size": 10}), - Task(extract_knowledge_graph_from_events), - Task(add_data_points, task_config={"batch_size": 10}), + Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}), ] - return temporal_tasks + return default_tasks diff --git a/cognee/tasks/translation/test_translation.py b/cognee/tasks/translation/test_translation.py new file mode 100644 index 000000000..39526bd92 --- /dev/null +++ b/cognee/tasks/translation/test_translation.py @@ -0,0 +1,496 @@ +""" +Unit tests for translation functionality. 
+ +Tests cover: +- Translation provider registry and discovery +- Language detection across providers +- Translation functionality +- Error handling and fallbacks +- Model validation and serialization +""" + +import pytest # type: ignore[import-untyped] +from typing import Tuple, Optional, Dict +from pydantic import ValidationError +import cognee.tasks.translation.translate_content as tr + +from cognee.tasks.translation.translate_content import ( + translate_content, + register_translation_provider, + get_available_providers, + TranslationProvider, + NoOpProvider, + _get_provider, +) +from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata + + +class TestDetectionError(Exception): # pylint: disable=too-few-public-methods + """Test exception for detection failures.""" + + +class TestTranslationError(Exception): # pylint: disable=too-few-public-methods + """Test exception for translation failures.""" + + +# Ensure registry isolation across tests using public helpers +@pytest.fixture(autouse=True) +def _restore_registry(): + """ + Pytest fixture that snapshots the translation provider registry and restores it after the test. + + Use to isolate tests that register or modify providers: the current registry state is captured before the test runs, and always restored when the fixture completes (including on exceptions). + """ + snapshot = tr.snapshot_registry() + try: + yield + finally: + tr.restore_registry(snapshot) + + +class MockDocumentChunk: # pylint: disable=too-few-public-methods + """Mock document chunk for testing.""" + + def __init__(self, text: str, chunk_id: str = "test_chunk", metadata: Optional[Dict] = None): + """ + Initialize a mock document chunk used in tests. + + Parameters: + text (str): Chunk text content. + chunk_id (str): Identifier for the chunk; also used as chunk_index for tests. Defaults to "test_chunk". + metadata (Optional[Dict]): Optional mapping of metadata values; defaults to an empty dict. 
+ """ + self.text = text + self.id = chunk_id + self.chunk_index = chunk_id + self.metadata = metadata or {} + + +class MockTranslationProvider: + """Mock provider for testing custom provider registration.""" + + async def detect_language(self, text: str) -> Tuple[str, float]: + """ + Detect the language of the given text and return an ISO 639-1 language code with a confidence score. + + This mock implementation uses simple keyword heuristics: returns ("es", 0.95) if the text contains "hola", + ("fr", 0.90) if it contains "bonjour", and ("en", 0.85) otherwise. + + Parameters: + text (str): Input text to analyze. + + Returns: + Tuple[str, float]: A tuple of (language_code, confidence) where language_code is an ISO 639-1 code and + confidence is a float between 0.0 and 1.0 indicating detection confidence. + """ + if "hola" in text.lower(): + return "es", 0.95 + if "bonjour" in text.lower(): + return "fr", 0.90 + return "en", 0.85 + + async def translate(self, text: str, target_language: str) -> Tuple[str, float]: + """ + Simulate translating `text` into `target_language` and return a mock translated string with a confidence score. + + If `target_language` is "en", returns the input prefixed with "[MOCK TRANSLATED]" and a confidence of 0.88. For any other target language, returns the original `text` and a confidence of 0.0. + + Parameters: + text (str): The text to translate. + target_language (str): The target language code (e.g., "en"). + + Returns: + Tuple[str, float]: A pair of (translated_text, confidence) where confidence is in [0.0, 1.0]. 
+ """ + if target_language == "en": + return f"[MOCK TRANSLATED] {text}", 0.88 + return text, 0.0 + + +class TestProviderRegistry: + """Test translation provider registration and discovery.""" + + def test_get_available_providers_includes_builtin(self): + """Test that built-in providers are included in available list.""" + providers = get_available_providers() + assert "noop" in providers + assert "langdetect" in providers + + def test_register_custom_provider(self): + """Test custom provider registration.""" + register_translation_provider("mock", MockTranslationProvider) + providers = get_available_providers() + assert "mock" in providers + + # Test provider can be retrieved + provider = _get_provider("mock") + assert isinstance(provider, MockTranslationProvider) + + def test_provider_name_normalization(self): + """Test provider names are normalized to lowercase.""" + register_translation_provider("CUSTOM_PROVIDER", MockTranslationProvider) + providers = get_available_providers() + assert "custom_provider" in providers + + # Should be retrievable with different casing + provider1 = _get_provider("CUSTOM_PROVIDER") + provider2 = _get_provider("custom_provider") + assert provider1.__class__ is provider2.__class__ + + def test_unknown_provider_raises(self): + """Test unknown providers raise ValueError.""" + with pytest.raises(ValueError): + _get_provider("nonexistent_provider") + + +class TestNoOpProvider: + """Test NoOp provider functionality.""" + + @pytest.mark.asyncio + async def test_detect_language_ascii(self): + """Test language detection for ASCII text.""" + provider = NoOpProvider() + lang, conf = await provider.detect_language("Hello world") + assert lang is None + assert conf == 0.0 + + @pytest.mark.asyncio + async def test_detect_language_unicode(self): + """Test language detection for Unicode text.""" + provider = NoOpProvider() + lang, conf = await provider.detect_language("Hëllo wörld") + assert lang is None + assert conf == 0.0 + + 
@pytest.mark.asyncio + async def test_translate_returns_original(self): + """Test translation returns original text with zero confidence.""" + provider = NoOpProvider() + text = "Test text" + translated, conf = await provider.translate(text, "es") + assert translated == text + assert conf == 0.0 + + +class TestTranslationModels: + """Test Pydantic models for translation data.""" + + def test_translated_content_validation(self): + """Test TranslatedContent model validation.""" + content = TranslatedContent( + original_chunk_id="chunk_1", + original_text="Hello", + translated_text="Hola", + source_language="en", + target_language="es", + translation_provider="test", + confidence_score=0.9 + ) + assert content.original_chunk_id == "chunk_1" + assert content.confidence_score == 0.9 + + def test_translated_content_confidence_validation(self): + """Test confidence score validation bounds.""" + # Valid confidence scores + TranslatedContent( + original_chunk_id="test", + original_text="test", + translated_text="test", + source_language="en", + confidence_score=0.0 + ) + TranslatedContent( + original_chunk_id="test", + original_text="test", + translated_text="test", + source_language="en", + confidence_score=1.0 + ) + + # Invalid confidence scores should raise validation error + with pytest.raises(ValidationError): + TranslatedContent( + original_chunk_id="test", + original_text="test", + translated_text="test", + source_language="en", + confidence_score=-0.1 + ) + + with pytest.raises(ValidationError): + TranslatedContent( + original_chunk_id="test", + original_text="test", + translated_text="test", + source_language="en", + confidence_score=1.1 + ) + + def test_language_metadata_validation(self): + """Test LanguageMetadata model validation.""" + metadata = LanguageMetadata( + content_id="chunk_1", + detected_language="es", + language_confidence=0.95, + requires_translation=True, + character_count=100 + ) + assert metadata.content_id == "chunk_1" + assert 
metadata.requires_translation is True + assert metadata.character_count == 100 + + def test_language_metadata_character_count_validation(self): + """Test character count cannot be negative.""" + with pytest.raises(ValidationError): + LanguageMetadata( + content_id="test", + detected_language="en", + character_count=-1 + ) + + +class TestTranslateContentFunction: + """Test main translate_content function.""" + + @pytest.mark.asyncio + async def test_noop_provider_processing(self): + """Test processing with noop provider.""" + chunks = [ + MockDocumentChunk("Hello world", "chunk_1"), + MockDocumentChunk("Test content", "chunk_2") + ] + + result = await translate_content( + chunks, + target_language="en", + translation_provider="noop", + confidence_threshold=0.8 + ) + + assert len(result) == 2 + for chunk in result: + assert "language" in chunk.metadata + assert chunk.metadata["language"]["detected_language"] == "unknown" + # No translation should occur with noop provider + assert "translation" not in chunk.metadata + + @pytest.mark.asyncio + async def test_translation_with_custom_provider(self): + """Test translation with custom registered provider.""" + # Register mock provider + register_translation_provider("test_provider", MockTranslationProvider) + + chunks = [MockDocumentChunk("Hola mundo", "chunk_1")] + + result = await translate_content( + chunks, + target_language="en", + translation_provider="test_provider", + confidence_threshold=0.8 + ) + + chunk = result[0] + assert "language" in chunk.metadata + assert "translation" in chunk.metadata + + # Check language metadata + lang_meta = chunk.metadata["language"] + assert lang_meta["detected_language"] == "es" + assert lang_meta["requires_translation"] is True + + # Check translation metadata + trans_meta = chunk.metadata["translation"] + assert trans_meta["original_text"] == "Hola mundo" + assert "[MOCK TRANSLATED]" in trans_meta["translated_text"] + assert trans_meta["source_language"] == "es" + assert 
trans_meta["target_language"] == "en" + assert trans_meta["translation_provider"] == "test_provider" + + # Check chunk text was updated + assert "[MOCK TRANSLATED]" in chunk.text + + @pytest.mark.asyncio + async def test_low_confidence_no_translation(self): + """Test that low confidence detection doesn't trigger translation.""" + register_translation_provider("low_conf", MockTranslationProvider) + + chunks = [MockDocumentChunk("Hello world", "chunk_1")] # English text + + result = await translate_content( + chunks, + target_language="en", + translation_provider="low_conf", + confidence_threshold=0.9 # High threshold + ) + + chunk = result[0] + assert "language" in chunk.metadata + # Should not translate due to high threshold and English detection + assert "translation" not in chunk.metadata + + @pytest.mark.asyncio + async def test_error_handling_in_detection(self): + """Test graceful error handling in language detection.""" + class FailingProvider: + async def detect_language(self, _text: str) -> Tuple[str, float]: + """ + Simulate a language detection failure by always raising TestDetectionError. + + This async method is used in tests to emulate a provider that fails during language detection. It accepts a text string but does not return; it always raises TestDetectionError. + """ + raise TestDetectionError() + + async def translate(self, text: str, _target_language: str) -> Tuple[str, float]: + """ + Return the input text unchanged and a translation confidence of 0.0. + + This no-op translator performs no translation; the supplied target language is ignored. + + Parameters: + text (str): Source text to "translate". + _target_language (str): Target language (ignored). + + Returns: + Tuple[str, float]: A tuple containing the original text and a confidence score (always 0.0). 
+ """ + return text, 0.0 + + register_translation_provider("failing", FailingProvider) + + chunks = [MockDocumentChunk("Test text", "chunk_1")] + + # Disable 'langdetect' fallback to force unknown + ld = tr._provider_registry.pop("langdetect", None) + try: + result = await translate_content(chunks, translation_provider="failing") + finally: + if ld is not None: + tr._provider_registry["langdetect"] = ld + + chunk = result[0] + assert "language" in chunk.metadata + # Should have unknown language due to detection failure + lang_meta = chunk.metadata["language"] + assert lang_meta["detected_language"] == "unknown" + assert lang_meta["language_confidence"] == 0.0 + + @pytest.mark.asyncio + async def test_error_handling_in_translation(self): + """Test graceful error handling in translation.""" + class PartialProvider: + async def detect_language(self, _text: str) -> Tuple[str, float]: + """ + Mock language detection used in tests. + + Parameters: + _text (str): Input text (ignored by this mock). + + Returns: + Tuple[str, float]: A fixed detected language code ("es") and confidence (0.9). + """ + return "es", 0.9 + + async def translate(self, _text: str, _target_language: str) -> Tuple[str, float]: + """ + Simulate a failing translation by always raising TestTranslationError. + + This async method ignores its inputs and is used in tests to emulate a provider-side failure during translation. + + Parameters: + _text (str): Unused input text. + _target_language (str): Unused target language code. + + Raises: + TestTranslationError: Always raised to simulate a translation failure. 
+ """ + raise TestTranslationError() + + register_translation_provider("partial", PartialProvider) + + chunks = [MockDocumentChunk("Hola", "chunk_1")] + + result = await translate_content( + chunks, + translation_provider="partial", + confidence_threshold=0.8 + ) + + chunk = result[0] + # Should have detected Spanish but failed translation + assert chunk.metadata["language"]["detected_language"] == "es" + # Should still create translation metadata with original text + assert "translation" in chunk.metadata + trans_meta = chunk.metadata["translation"] + assert trans_meta["translated_text"] == "Hola" # Original text due to failure + assert trans_meta["confidence_score"] == 0.0 + + @pytest.mark.asyncio + async def test_no_translation_when_same_language(self): + """Test no translation occurs when source equals target language.""" + register_translation_provider("same_lang", MockTranslationProvider) + + chunks = [MockDocumentChunk("Hello world", "chunk_1")] + + result = await translate_content( + chunks, + target_language="en", # Same as detected language + translation_provider="same_lang" + ) + + chunk = result[0] + assert "language" in chunk.metadata + # No translation should occur for same language + assert "translation" not in chunk.metadata + + @pytest.mark.asyncio + async def test_metadata_serialization(self): + """Test that metadata is properly serialized to dicts.""" + register_translation_provider("serialize_test", MockTranslationProvider) + + chunks = [MockDocumentChunk("Hola", "chunk_1")] + + result = await translate_content( + chunks, + translation_provider="serialize_test", + confidence_threshold=0.8 + ) + + chunk = result[0] + + # Metadata should be plain dicts, not Pydantic models + assert isinstance(chunk.metadata["language"], dict) + if "translation" in chunk.metadata: + assert isinstance(chunk.metadata["translation"], dict) + + def test_model_serialization_compatibility(self): + """ + Verify that a TranslatedContent instance can be dumped to a 
JSON-serializable dict. + + Creates a TranslatedContent with sample fields, calls model_dump(), and asserts: + - the result is a dict, + - required fields like `original_chunk_id`, `translation_timestamp`, and `metadata` are present and preserved, + - the dict can be round-tripped through json.dumps/json.loads without losing `original_chunk_id`. + """ + content = TranslatedContent( + original_chunk_id="test", + original_text="Hello", + translated_text="Hola", + source_language="en", + target_language="es" + ) + + # Should serialize to dict + data = content.model_dump() + assert isinstance(data, dict) + assert data["original_chunk_id"] == "test" + assert "translation_timestamp" in data + assert "metadata" in data + + # Should be JSON serializable + import json + json_str = json.dumps(data) + parsed = json.loads(json_str) + assert parsed["original_chunk_id"] == "test" + + + diff --git a/cognee/tasks/translation/translate_content.py b/cognee/tasks/translation/translate_content.py new file mode 100644 index 000000000..da77aa15e --- /dev/null +++ b/cognee/tasks/translation/translate_content.py @@ -0,0 +1,660 @@ +# pylint: disable=R0903, W0221 +"""This module provides content translation capabilities for the Cognee framework.""" +import asyncio +import math +import os +from dataclasses import dataclass, field +from typing import Any, Dict, Type, Protocol, Tuple, Optional + +from cognee.shared.logging_utils import get_logger +from .models import TranslatedContent, LanguageMetadata + +logger = get_logger() + +# Custom exceptions for better error handling +class TranslationDependencyError(ImportError): + """Raised when a required translation dependency is missing.""" + +class LangDetectError(TranslationDependencyError): + """LangDetect library required.""" + +class OpenAIError(TranslationDependencyError): + """OpenAI library required.""" + +class GoogleTranslateError(TranslationDependencyError): + """GoogleTrans library required.""" + +class 
+        Implementations determine the most likely language of the input and a detection probability.
+ """ + + async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]: + """ + Translate the given text into the specified target language asynchronously. + + Parameters: + text: The source text to translate. + target_language: Target language code (e.g., "en", "es", "fr-CA"). + + Returns: + A tuple (translated_text, confidence) on success, where `confidence` is a float in [0.0, 1.0] (may be 0.0 if the provider does not supply a score), or None if translation failed or was unavailable. + """ + +# Registry for translation providers +_provider_registry: Dict[str, Type[TranslationProvider]] = {} + +def register_translation_provider(name: str, provider: Type[TranslationProvider]): + """ + Register a translation provider under a canonical lowercase key. + + The provided class will be stored in the internal provider registry and looked up by its lowercased `name`. If an entry with the same key already exists it will be replaced. + + Parameters: + name (str): Human-readable provider name (case-insensitive); stored as lower-case. + provider (Type[TranslationProvider]): Provider class implementing the TranslationProvider protocol; instances are constructed when the provider is resolved. + """ + _provider_registry[name.lower()] = provider + +def get_available_providers(): + """Returns a list of available translation providers.""" + return sorted(_provider_registry.keys()) + +def _get_provider(translation_provider: str) -> TranslationProvider: + """ + Resolve and instantiate a registered translation provider by name. + + The lookup is case-insensitive: `translation_provider` should be the provider key (e.g., "openai", "google", "noop"). + Returns an instance of the provider implementing the TranslationProvider protocol. + + Raises: + ValueError: if no provider is registered under the given name; the error message lists available providers. 
+ """ + provider_class = _provider_registry.get(translation_provider.lower()) + if not provider_class: + available = ', '.join(get_available_providers()) + msg = f"Unknown translation provider: {translation_provider}. Available providers: {available}" + raise ValueError(msg) + return provider_class() +# Helpers +def _normalize_lang_code(code: Optional[str]) -> str: + """ + Normalize a language code to a canonical form or return "unknown". + + Normalizes common language code formats: + - Two-letter codes (e.g., "en", "EN", " en ") -> "en" + - Locale codes with region (e.g., "en-us", "en_US", "EN-us") -> "en-US" + - Returns "unknown" for empty, non-string, or unrecognized inputs. + + Parameters: + code (Optional[str]): Language code or locale string to normalize. + + Returns: + str: Normalized language code in either "xx" or "xx-YY" form, or "unknown" if input is invalid. + """ + if not isinstance(code, str) or not code.strip(): + return "unknown" + c = code.strip().replace("_", "-") + parts = c.split("-") + if len(parts) == 1 and len(parts[0]) == 2 and parts[0].isalpha(): + return parts[0].lower() + if len(parts) >= 2 and len(parts[0]) == 2 and parts[1]: + return f"{parts[0].lower()}-{parts[1][:2].upper()}" + return "unknown" + +def _provider_name(provider: TranslationProvider) -> str: + """Return the canonical registry key for a provider instance, or a best-effort name.""" + return next( + (name for name, cls in _provider_registry.items() if isinstance(provider, cls)), + provider.__class__.__name__.replace("Provider", "").lower(), + ) + +async def _detect_language_with_fallback(provider: TranslationProvider, text: str, content_id: str) -> Tuple[str, float]: + """ + Detect the language of `text`, falling back to the registered "langdetect" provider if the primary provider fails. + + Attempts to call the primary provider's `detect_language`. If that call returns None or raises, and a different "langdetect" provider is registered, it will try the fallback. 
Detection failures are logged; exceptions are not propagated. + + Parameters: + text (str): The text to detect language for. + content_id (str): Identifier used in logs to correlate errors to the input content. + + Returns: + Tuple[str, float]: A normalized language code (e.g., "en" or "pt-BR") and a confidence score in [0.0, 1.0]. + On detection failure returns ("unknown", 0.0). Confidence values are coerced to float, NaNs converted to 0.0, and clamped to the [0.0, 1.0] range. + """ + try: + detection = await provider.detect_language(text) + except Exception: + logger.exception("Language detection failed for content_id=%s", content_id) + detection = None + + if detection is None: + fallback_cls = _provider_registry.get("langdetect") + if fallback_cls is not None and not isinstance(provider, fallback_cls): + try: + detection = await fallback_cls().detect_language(text) + except Exception: + logger.exception("Fallback language detection failed for content_id=%s", content_id) + detection = None + + if detection is None: + return "unknown", 0.0 + + lang_code, conf = detection + detected_language = _normalize_lang_code(lang_code) + try: + conf = float(conf) + except (TypeError, ValueError): + conf = 0.0 + if math.isnan(conf): + conf = 0.0 + conf = max(0.0, min(1.0, conf)) + return detected_language, conf + +def _decide_if_translation_is_required(ctx: TranslationContext) -> None: + """ + Decide whether a translation should be performed and update ctx.requires_translation. + + Normalizes the configured target language and marks translation as required only when: + - The provider can perform translations (not "noop" or "langdetect"), and + - Either the detected language is "unknown" and the text is non-empty, or + - The detected language (normalized) differs from the target language and the detection confidence meets or exceeds ctx.confidence_threshold. + + The function mutates the provided TranslationContext in-place and does not return a value. 
+ """ + # Normalize to align with detected_language normalization and model regex. + target_language = _normalize_lang_code(ctx.target_language) + can_translate = ctx.provider_name not in ("noop", "langdetect") + + if ctx.detected_language == "unknown": + ctx.requires_translation = can_translate and bool(ctx.text.strip()) + else: + ctx.requires_translation = ( + ctx.detected_language != target_language + and ctx.detection_confidence >= ctx.confidence_threshold + ) + +def _attach_language_metadata(ctx: TranslationContext) -> None: + """ + Attach language detection and translation decision metadata to the context's chunk. + + Ensures the chunk has a metadata mapping, builds a LanguageMetadata record from + the context (content_id, detected language and confidence, whether translation is + required, and character count of the text), serializes it, and stores it under + the "language" key in chunk.metadata. + + Parameters: + ctx (TranslationContext): Context containing the chunk and detection/decision values. + """ + ctx.chunk.metadata = getattr(ctx.chunk, "metadata", {}) or {} + lang_meta = LanguageMetadata( + content_id=str(ctx.content_id), + detected_language=ctx.detected_language, + language_confidence=ctx.detection_confidence, + requires_translation=ctx.requires_translation, + character_count=len(ctx.text), + ) + ctx.chunk.metadata["language"] = lang_meta.model_dump() + +async def _translate_and_update(ctx: TranslationContext) -> None: + """ + Translate the text in the provided TranslationContext and update the chunk and its metadata. + + Performs an async translation via ctx.provider.translate, and when a non-empty, changed translation is returned: + - replaces ctx.chunk.text with the translated text, + - attempts to update ctx.chunk.chunk_size (if present), + - attaches a `translation` entry in ctx.chunk.metadata containing a TranslatedContent dict (original/translated text, source/target languages, provider, and confidence). 
+ + If translation fails (exception or None) the original text is preserved and a TranslatedContent record is still attached with confidence 0.0. If the provider returns the same text unchanged, no metadata is attached and the function returns without modifying the chunk. + + Parameters: + ctx (TranslationContext): context carrying provider, chunk, original text, target language, detected language, and content_id. + + Returns: + None + """ + try: + tr = await ctx.provider.translate(ctx.text, ctx.target_language) + except Exception: + logger.exception("Translation failed for content_id=%s", ctx.content_id) + tr = None + + translated_text = None + translation_confidence = 0.0 + provider_used = _provider_name(ctx.provider) + target_for_meta = _normalize_lang_code(ctx.target_language) + + if tr and isinstance(tr[0], str) and tr[0].strip() and tr[0] != ctx.text: + translated_text, translation_confidence = tr + ctx.chunk.text = translated_text + if hasattr(ctx.chunk, "chunk_size"): + try: + ctx.chunk.chunk_size = len(translated_text.split()) + except (AttributeError, ValueError, TypeError): + logger.debug( + "Could not update chunk_size for content_id=%s", + ctx.content_id, + exc_info=True, + ) + elif tr is None: + # Translation failed, keep original text + translated_text = ctx.text + else: + # Provider returned unchanged text + logger.info("Provider returned unchanged text; skipping translation metadata (content_id=%s)", ctx.content_id) + return + + trans = TranslatedContent( + original_chunk_id=str(ctx.content_id), + original_text=ctx.text, + translated_text=translated_text, + source_language=ctx.detected_language, + target_language=target_for_meta, + translation_provider=provider_used, + confidence_score=translation_confidence or 0.0, + ) + ctx.chunk.metadata["translation"] = trans.model_dump() + + +# Test helpers for registry isolation +def snapshot_registry() -> Dict[str, Type[TranslationProvider]]: + """Return a shallow copy snapshot of the provider registry (for 
tests).""" + return dict(_provider_registry) + +def restore_registry(snapshot: Dict[str, Type[TranslationProvider]]) -> None: + """ + Restore the global translation provider registry from a previously captured snapshot. + + This replaces the current internal provider registry with the given snapshot (clears then updates), + typically used by tests to restore provider registration state. + + Parameters: + snapshot (Dict[str, Type[TranslationProvider]]): Mapping of provider name keys to provider classes. + """ + _provider_registry.clear() + _provider_registry.update(snapshot) + +def validate_provider(name: str) -> None: + """Ensure a provider can be resolved and instantiated or raise.""" + _get_provider(name) + +# Built-in Providers +class NoOpProvider: + """A provider that does nothing, used for testing or disabling translation.""" + async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]: + """ + No-op language detection: intentionally performs no detection and always returns None. + + The `_text` parameter is ignored. Returns None to indicate that this provider does not provide a language detection result. + """ + return None + + async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]: + """ + Return the input text unchanged and a confidence score of 0.0. + + This provider does not perform any translation; it mirrors the source text back to the caller. + Parameters: + text (str): Source text to "translate". + _target_language (str): Unused target language parameter. + Returns: + Optional[Tuple[str, float]]: A tuple of (text, 0.0). + """ + return text, 0.0 + +class LangDetectProvider: + """ + A provider that uses the 'langdetect' library for offline language detection. + This provider only detects the language and does not perform translation. + """ + def __init__(self): + """ + Initialize the LangDetectProvider by loading the `langdetect.detect_langs` function. 
+ + Attempts to import `detect_langs` from the `langdetect` package and stores it on the instance as `_detect_langs`. Raises `LangDetectError` if the `langdetect` dependency is not available. + """ + try: + from langdetect import detect_langs # type: ignore[import-untyped] + self._detect_langs = detect_langs + except ImportError as e: + raise LangDetectError() from e + + async def detect_language(self, text: str) -> Optional[Tuple[str, float]]: + """ + Detect the language of `text` using the provider's langdetect backend. + + Returns a tuple of (language_code, confidence) where `language_code` is the top + detected language (e.g., "en") and `confidence` is the detection probability + in [0.0, 1.0]. Returns None if detection fails or no result is available. + """ + try: + detections = await asyncio.to_thread(self._detect_langs, text) + except Exception: + logger.exception("Error during language detection") + return None + + if not detections: + return None + best_detection = detections[0] + return best_detection.lang, best_detection.prob + + async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]: + # This provider only detects language, does not translate. + """ + No-op translation: returns the input text unchanged with a 0.0 confidence. + + This provider only performs language detection; translate is a passthrough that returns the original `text` + and a confidence of 0.0 to indicate no translated content was produced. + + Returns: + A tuple of (text, confidence) where `text` is the original input and `confidence` is 0.0. + """ + return text, 0.0 + +class OpenAIProvider: + """A provider that uses OpenAI's API for translation.""" + def __init__(self): + """ + Initialize the OpenAIProvider by creating an AsyncOpenAI client and loading configuration. + + Reads the following environment variables: + - OPENAI_API_KEY: API key passed to AsyncOpenAI for authentication. 
+ - OPENAI_TRANSLATE_MODEL: model name to use for translations (default: "gpt-4o-mini"). + - OPENAI_TIMEOUT: request timeout in seconds (default: "30", parsed as float). + + Raises: + OpenAIError: if the OpenAI SDK (AsyncOpenAI) cannot be imported. + """ + try: + from openai import AsyncOpenAI # type: ignore[import-untyped] + self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) + self.model = os.getenv("OPENAI_TRANSLATE_MODEL", "gpt-4o-mini") + self.timeout = float(os.getenv("OPENAI_TIMEOUT", "30")) + except ImportError as e: + raise OpenAIError() from e + + async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]: + # OpenAI's API does not have a separate language detection endpoint. + # This can be implemented as part of the translation prompt if needed. + """ + Indicates that this provider does not perform standalone language detection. + + The OpenAI-based provider does not expose a separate detection endpoint and therefore + always returns None. Language detection can be achieved by using another provider + (e.g., the registered langdetect provider) or by incorporating detection into a + translation prompt if needed. + """ + return None + + async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]: + """ + Translate the given text to the specified target language using the OpenAI chat completions client. + + Parameters: + text (str): Source text to translate. + target_language (str): Target language name or code (used verbatim in the translation prompt). + + Returns: + Optional[Tuple[str, float]]: A tuple of (translated_text, confidence). Confidence is 0.0 because no calibrated confidence is available. + Returns None if translation failed or an error occurred. 
+ """ + try: + response = await self.client.with_options(timeout=self.timeout).chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": f"Translate the following text to {target_language}."}, + {"role": "user", "content": text}, + ], + temperature=0.0, + ) + except Exception: + logger.exception("Error during OpenAI translation (model=%s)", self.model) + return None + + translated_text = response.choices[0].message.content.strip() + return translated_text, 0.0 # No calibrated confidence available. + +class GoogleTranslateProvider: + """A provider that uses the 'googletrans' library for translation.""" + def __init__(self): + """ + Initialize the GoogleTranslateProvider by importing and instantiating googletrans.Translator. + + Raises: + GoogleTranslateError: If the `googletrans` library is not installed or cannot be imported. + """ + try: + from googletrans import Translator # type: ignore[import-untyped] + self.translator = Translator() + except ImportError as e: + raise GoogleTranslateError() from e + + async def detect_language(self, text: str) -> Optional[Tuple[str, float]]: + """ + Detect the language of the given text using the configured googletrans Translator. + + Uses a thread to call the synchronous translator.detect method; on failure returns None. + + Parameters: + text: The text to detect the language for. + + Returns: + A tuple (language_code, confidence) where `language_code` is the detected language string from the translator (e.g. "en") and `confidence` is a float in [0.0, 1.0]. Returns None if detection fails. 
+ """ + try: + detection = await asyncio.to_thread(self.translator.detect, text) + except Exception: + logger.exception("Error during Google Translate language detection") + return None + + try: + conf = float(detection.confidence) if detection.confidence is not None else 0.0 + except (TypeError, ValueError): + conf = 0.0 + return detection.lang, conf + + async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]: + """ + Translate `text` to `target_language` using the configured googletrans Translator. + + Returns a tuple (translated_text, confidence) on success — confidence is always 0.0 because googletrans does not provide a confidence score — or None if translation fails. + """ + try: + translation = await asyncio.to_thread(self.translator.translate, text, dest=target_language) + except Exception: + logger.exception("Error during Google Translate translation") + return None + + return translation.text, 0.0 # Confidence not provided. + +class AzureTranslatorProvider: + """A provider that uses Azure's Translator service.""" + def __init__(self): + """ + Initialize the AzureTranslatorProvider. + + Attempts to import Azure SDK classes, reads AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT, + and AZURE_TRANSLATOR_REGION from the environment, verifies the key is present, and constructs + a TextTranslationClient using an AzureKeyCredential. + + Raises: + AzureConfigError: if AZURE_TRANSLATOR_KEY is not set. + AzureTranslateError: if required Azure SDK imports are unavailable. 
+ """ + try: + from azure.core.credentials import AzureKeyCredential # type: ignore[import-untyped] + from azure.ai.translation.text import TextTranslationClient # type: ignore[import-untyped] + + self.key = os.getenv("AZURE_TRANSLATOR_KEY") + self.endpoint = os.getenv("AZURE_TRANSLATOR_ENDPOINT", "https://api.cognitive.microsofttranslator.com/") + self.region = os.getenv("AZURE_TRANSLATOR_REGION", "global") + + if not self.key: + raise AzureConfigError() + + self.client = TextTranslationClient( + endpoint=self.endpoint, + credential=AzureKeyCredential(self.key), + ) + except ImportError as e: + raise AzureTranslateError() from e + + async def detect_language(self, text: str) -> Optional[Tuple[str, float]]: + """ + Detect the language of the given text using the Azure Translator client's detect API. + + Attempts to call the Azure client's detect method (using a two-letter region as a country hint when available) + and returns a tuple of (language_code, confidence_score). Returns None if detection fails or an exception occurs. + + Parameters: + text (str): The text to detect language for. + + Returns: + Optional[Tuple[str, float]]: (ISO language code, confidence between 0.0 and 1.0), or None on error. + """ + try: + # Use a valid country hint only when it looks like ISO 3166-1 alpha-2; otherwise omit. + hint = self.region.lower() if isinstance(self.region, str) and len(self.region) == 2 else None + response = await asyncio.to_thread(self.client.detect, content=[text], country_hint=hint) + except Exception: + logger.exception("Error during Azure language detection") + return None + + detection = response[0].primary_language + return detection.language, detection.score + + async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]: + """ + Translate the given text to the target language using the Azure Translator client. + + Parameters: + text (str): Plain text to translate. 
+ target_language (str): BCP-47 or ISO language code to translate the text into. + + Returns: + Optional[Tuple[str, float]]: A tuple of (translated_text, confidence). Returns None on error. + The provider does not surface a numeric confidence score, so the returned confidence is always 0.0. + """ + try: + response = await asyncio.to_thread(self.client.translate, content=[text], to=[target_language]) + except Exception: + logger.exception("Error during Azure translation") + return None + + translation = response[0].translations[0] + return translation.text, 0.0 # Confidence not provided. + +# Register built-in providers +register_translation_provider("noop", NoOpProvider) +register_translation_provider("langdetect", LangDetectProvider) +register_translation_provider("openai", OpenAIProvider) +register_translation_provider("google", GoogleTranslateProvider) +register_translation_provider("azure", AzureTranslatorProvider) + +async def translate_content( # pylint: disable=too-many-locals,too-many-branches + *data_chunks, + target_language: str = TARGET_LANGUAGE, + translation_provider: str = "noop", + confidence_threshold: float = CONFIDENCE_THRESHOLD, +): + """ + Translate content chunks to a target language and attach language and translation metadata. + + This function accepts either multiple chunk objects as varargs or a single list of chunks. + For each chunk it: + - Resolves the named translation provider. + - Detects the chunk's language (with a fallback detector when available). + - Decides whether translation is required based on detected language, confidence threshold, and provider. + - Attaches language metadata (LanguageMetadata) to chunk.metadata. + - If required, performs translation and updates the chunk text and metadata (TranslatedContent). + + Parameters: + *data_chunks: One or more chunk objects, or a single list of chunk objects. Each chunk must expose a `text` attribute and a `metadata` mapping (the function will create `metadata` if missing). 
+ target_language (str): Language code to translate into (defaults to TARGET_LANGUAGE). + translation_provider (str): Registered provider name to use for detection/translation (defaults to "noop"). + confidence_threshold (float): Minimum detection confidence required to skip translation (defaults to CONFIDENCE_THRESHOLD). + + Returns: + list: The list of processed chunk objects (same objects, possibly modified). Metadata keys added include language detection results and, when a translation occurs, translation details. + """ + provider = _get_provider(translation_provider) + results = [] + + if len(data_chunks) == 1 and isinstance(data_chunks[0], list): + _chunks = data_chunks[0] + else: + _chunks = list(data_chunks) + + for chunk in _chunks: + ctx = TranslationContext( + provider=provider, + chunk=chunk, + text=getattr(chunk, "text", "") or "", + target_language=target_language, + confidence_threshold=confidence_threshold, + provider_name=translation_provider.lower(), + ) + + ctx.detected_language, ctx.detection_confidence = await _detect_language_with_fallback( + ctx.provider, ctx.text, str(ctx.content_id) + ) + + _decide_if_translation_is_required(ctx) + _attach_language_metadata(ctx) + + if ctx.requires_translation: + await _translate_and_update(ctx) + + results.append(ctx.chunk) + + return results diff --git a/examples/python/translation_example.py b/examples/python/translation_example.py new file mode 100644 index 000000000..bb18a30ef --- /dev/null +++ b/examples/python/translation_example.py @@ -0,0 +1,86 @@ +import asyncio +import os +import cognee +from cognee.api.v1.search import SearchType +from cognee.api.v1.cognify.cognify import get_default_tasks_with_translation +from cognee.modules.pipelines.operations.pipeline import run_pipeline + +# Prerequisites: +# 1. Set up your environment with API keys for your chosen translation provider. 
#    - For OpenAI: OPENAI_API_KEY
#    - For Azure: AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT, AZURE_TRANSLATOR_REGION
# 2. Specify the translation provider via an environment variable (optional, defaults to "noop"):
#    COGNEE_TRANSLATION_PROVIDER="openai"  # Or "google", "azure", "langdetect"
# 3. Install any required libraries for your provider:
#    - pip install langdetect googletrans==4.0.0rc1 azure-ai-translation-text


async def main():
    """
    Run an end-to-end translation-enabled Cognify demo.

    The demo resets the workspace (prune data + system metadata), seeds three
    multilingual documents, executes the translation-enabled cognify pipeline
    (provider chosen via COGNEE_TRANSLATION_PROVIDER, defaulting to "noop"),
    and finally issues an English INSIGHTS search against the processed index.

    Provider setup problems (ValueError / ImportError) are reported to stdout
    and abort the run early. No return value; mutates persistent Cognee state.
    """
    # Step 1: reset state and seed multilingual content.
    print("Setting up demo environment...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    multilingual_texts = [
        "El procesamiento de lenguaje natural (PLN) es un subcampo de la IA.",
        "Le traitement automatique du langage naturel (TALN) est un sous-domaine de l'IA.",
        "Natural language processing (NLP) is a subfield of AI.",
    ]

    print("Adding multilingual texts...")
    for sample in multilingual_texts:
        await cognee.add(sample)
    print("Texts added successfully.\n")

    # Step 2: run the cognify pipeline with translation enabled.
    provider = os.getenv('COGNEE_TRANSLATION_PROVIDER', 'noop').lower()
    print(f"Running cognify with translation provider: {provider}")

    try:
        tasks_with_translation = get_default_tasks_with_translation(
            translation_provider=provider
        )
        async for _ in run_pipeline(tasks=tasks_with_translation):
            pass
    except (ValueError, ImportError) as error:
        print(f"Error during cognify: {error}")
        print("Please ensure the selected provider is installed and configured correctly.")
        return
    print("Cognify pipeline with translation completed successfully.")

    # Step 3: search in English — translated documents should now be retrievable.
    query_text = "Tell me about NLP"
    print(f"\nSearching for: '{query_text}'")
    search_results = await cognee.search(query_text, query_type=SearchType.INSIGHTS)

    print("\nSearch Results:")
    if not search_results:
        print("No results found.")
    else:
        for result in search_results:
            print(f"- {result.text}")


if __name__ == "__main__":
    asyncio.run(main())