diff --git a/.env.template b/.env.template index fe168cf91..ae14fc720 100644 --- a/.env.template +++ b/.env.template @@ -145,6 +145,35 @@ VECTOR_DATASET_DATABASE_HANDLER="lancedb" # ONTOLOGY_FILE_PATH=YOUR_FULL_FULE_PATH # Default: empty # To add ontology resolvers, either set them as it is set in ontology_example or add full_path and settings as envs. +################################################################################ +# 🌐 Translation Settings +################################################################################ + +# Translation provider: llm (uses configured LLM), google, or azure +# "llm" uses whichever LLM is configured above (OpenAI, Azure, Ollama, Anthropic, etc.) +# "google" and "azure" use dedicated translation APIs +TRANSLATION_PROVIDER="llm" + +# Default target language for translations (ISO 639-1 code, e.g., en, es, fr, de) +TARGET_LANGUAGE="en" + +# Minimum confidence threshold for language detection (0.0 to 1.0) +CONFIDENCE_THRESHOLD=0.8 + +# -- Google Translate settings (required if using google provider) ----------- +# GOOGLE_TRANSLATE_API_KEY="your-google-api-key" +# GOOGLE_PROJECT_ID="your-google-project-id" + +# -- Azure Translator settings (required if using azure provider) ------------ +# AZURE_TRANSLATOR_KEY="your-azure-translator-key" +# AZURE_TRANSLATOR_REGION="westeurope" +# AZURE_TRANSLATOR_ENDPOINT="https://api.cognitive.microsofttranslator.com" + +# -- Performance settings ---------------------------------------------------- +# TRANSLATION_BATCH_SIZE=10 +# TRANSLATION_MAX_RETRIES=3 +# TRANSLATION_TIMEOUT_SECONDS=30 + ################################################################################ # 🔄 MIGRATION (RELATIONAL → GRAPH) SETTINGS ################################################################################ diff --git a/.gitignore b/.gitignore index 7c3095d08..8db408a7b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ cognee/.data/ code_pipeline_output*/ +# Test output files +test_outputs/ + *.lance/ .DS_Store # Byte-compiled / optimized / DLL files diff --git a/cognee/api/v1/config/config.py b/cognee/api/v1/config/config.py index 464753438..490032e2d 100644 --- a/cognee/api/v1/config/config.py +++ b/cognee/api/v1/config/config.py @@ -10,6 +10,7 @@ from cognee.infrastructure.llm.config import ( get_llm_config, ) from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config +from cognee.tasks.translation.config import get_translation_config from cognee.api.v1.exceptions.exceptions import InvalidConfigAttributeError @@ -176,3 +177,29 @@ class config: def set_vector_db_url(db_url: str): vector_db_config = get_vectordb_config() vector_db_config.vector_db_url = db_url + + # Translation configuration methods + + @staticmethod + def set_translation_provider(provider: str): + """Set the translation provider (llm, google, azure).""" + translation_config = get_translation_config() + translation_config.translation_provider = provider + + @staticmethod + def set_translation_target_language(target_language: str): + """Set the default target language for translations.""" + translation_config = get_translation_config() + translation_config.target_language = target_language + + @staticmethod + def set_translation_config(config_dict: dict): + """ + Updates the translation config with values from config_dict. + """ + translation_config = get_translation_config() + for key, value in config_dict.items(): + if hasattr(translation_config, key): + object.__setattr__(translation_config, key, value) + else: + raise InvalidConfigAttributeError(attribute=key) diff --git a/cognee/infrastructure/llm/prompts/translate_content.txt b/cognee/infrastructure/llm/prompts/translate_content.txt new file mode 100644 index 000000000..759e83f31 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/translate_content.txt @@ -0,0 +1,19 @@ +You are an expert translator with deep knowledge of languages, cultures, and linguistics. + +Your task is to: +1. Detect the source language of the provided text if not specified +2. Translate the text accurately to the target language +3. Preserve the original meaning, tone, and intent +4. Maintain proper grammar and natural phrasing in the target language + +Guidelines: +- Preserve technical terms, proper nouns, and specialized vocabulary appropriately +- Maintain formatting such as paragraphs, lists, and emphasis where applicable +- If the text contains code, URLs, or other non-translatable content, preserve them as-is +- Handle idioms and cultural references thoughtfully, adapting when necessary +- Ensure the translation reads naturally to a native speaker of the target language + +Provide the translation in a structured format with: +- The translated text +- The detected source language (ISO 639-1 code like "en", "es", "fr", "de", etc.) +- Any notes about the translation (optional, for ambiguous terms or cultural adaptations) diff --git a/cognee/tasks/translation/__init__.py b/cognee/tasks/translation/__init__.py new file mode 100644 index 000000000..ed2ec6e58 --- /dev/null +++ b/cognee/tasks/translation/__init__.py @@ -0,0 +1,96 @@ +""" +Translation task for Cognee. + +This module provides multilingual content translation capabilities, +allowing automatic detection and translation of non-English content +to a target language while preserving original text and metadata. + +Main Components: +- translate_content: Main task function for translating document chunks +- translate_text: Convenience function for translating single texts +- batch_translate_texts: Batch translation for multiple texts +- detect_language: Language detection utility +- TranslatedContent: DataPoint model for translated content +- LanguageMetadata: DataPoint model for language information + +Supported Translation Providers: +- LLM (default): Uses the configured LLM via existing infrastructure +- Google Translate: Requires google-cloud-translate package +- Azure Translator: Requires Azure Translator API key + +Example Usage: + ```python + from cognee.tasks.translation import translate_content, translate_text + + # Translate document chunks in a pipeline + translated_chunks = await translate_content( + chunks, + target_language="en", + translation_provider="llm" + ) + + # Translate a single text + result = await translate_text("Bonjour le monde!") + print(result.translated_text) # "Hello world!" + ``` +""" + +from .config import get_translation_config, TranslationConfig +from .detect_language import ( + detect_language, + detect_language_async, + LanguageDetectionResult, + get_language_name, +) +from .exceptions import ( + TranslationError, + LanguageDetectionError, + TranslationProviderError, + UnsupportedLanguageError, + TranslationConfigError, +) +from .models import TranslatedContent, LanguageMetadata +from .providers import ( + TranslationProvider, + TranslationResult, + get_translation_provider, + LLMTranslationProvider, + GoogleTranslationProvider, + AzureTranslationProvider, +) +from .translate_content import ( + translate_content, + translate_text, + batch_translate_texts, +) + +__all__ = [ + # Main task functions + "translate_content", + "translate_text", + "batch_translate_texts", + # Language detection + "detect_language", + "detect_language_async", + "LanguageDetectionResult", + "get_language_name", + # Models + "TranslatedContent", + "LanguageMetadata", + # Configuration + "get_translation_config", + "TranslationConfig", + # Providers + "TranslationProvider", + "TranslationResult", + "get_translation_provider", + "LLMTranslationProvider", + "GoogleTranslationProvider", + "AzureTranslationProvider", + # Exceptions + "TranslationError", + "LanguageDetectionError", + "TranslationProviderError", + "UnsupportedLanguageError", + "TranslationConfigError", +] diff --git a/cognee/tasks/translation/config.py b/cognee/tasks/translation/config.py new file mode 100644 index 000000000..cf52dbdb7 --- /dev/null +++ b/cognee/tasks/translation/config.py @@ -0,0 +1,110 @@ +from functools import lru_cache +from typing import Literal, Optional + +from pydantic import AliasChoices, Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +TranslationProviderType = Literal["llm", "google", "azure"] + + +class TranslationConfig(BaseSettings): + """ + Configuration settings for the translation task. + + Environment variables can be used to configure these settings: + - TRANSLATION_PROVIDER: The translation service to use ("llm", "google", "azure") + - TARGET_LANGUAGE: Default target language (ISO 639-1 code, e.g., "en", "es", "fr") + - CONFIDENCE_THRESHOLD: Minimum confidence for language detection (0.0 to 1.0) + - GOOGLE_TRANSLATE_API_KEY: API key for Google Translate + - GOOGLE_PROJECT_ID: Google Cloud project ID + - AZURE_TRANSLATOR_KEY: API key for Azure Translator + - AZURE_TRANSLATOR_REGION: Region for Azure Translator + - AZURE_TRANSLATOR_ENDPOINT: Endpoint URL for Azure Translator + - TRANSLATION_BATCH_SIZE: Number of texts to translate per batch + - TRANSLATION_MAX_RETRIES: Maximum retry attempts on failure + - TRANSLATION_TIMEOUT_SECONDS: Request timeout in seconds + """ + + # Translation provider settings + translation_provider: TranslationProviderType = Field( + default="llm", + validation_alias=AliasChoices("TRANSLATION_PROVIDER", "translation_provider"), + ) + target_language: str = Field( + default="en", + validation_alias=AliasChoices("TARGET_LANGUAGE", "target_language"), + ) + confidence_threshold: float = Field( + default=0.8, + ge=0.0, + le=1.0, + validation_alias=AliasChoices("CONFIDENCE_THRESHOLD", "confidence_threshold"), + ) + + # Google Translate settings + google_translate_api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("GOOGLE_TRANSLATE_API_KEY", "google_translate_api_key"), + ) + google_project_id: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("GOOGLE_PROJECT_ID", "google_project_id"), + ) + + # Azure Translator settings + azure_translator_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("AZURE_TRANSLATOR_KEY", "azure_translator_key"), + ) + azure_translator_region: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("AZURE_TRANSLATOR_REGION", "azure_translator_region"), + ) + azure_translator_endpoint: str = Field( + default="https://api.cognitive.microsofttranslator.com", + validation_alias=AliasChoices("AZURE_TRANSLATOR_ENDPOINT", "azure_translator_endpoint"), + ) + + # LLM provider uses the existing LLM configuration + + # Performance settings (with TRANSLATION_ prefix for env vars) + batch_size: int = Field( + default=10, + validation_alias=AliasChoices("TRANSLATION_BATCH_SIZE", "batch_size"), + ) + max_retries: int = Field( + default=3, + validation_alias=AliasChoices("TRANSLATION_MAX_RETRIES", "max_retries"), + ) + timeout_seconds: int = Field( + default=30, + validation_alias=AliasChoices("TRANSLATION_TIMEOUT_SECONDS", "timeout_seconds"), + ) + + # Language detection settings + min_text_length_for_detection: int = 10 + skip_detection_for_short_text: bool = True + + model_config = SettingsConfigDict(env_file=".env", extra="allow") + + def to_dict(self) -> dict: + return { + "translation_provider": self.translation_provider, + "target_language": self.target_language, + "confidence_threshold": self.confidence_threshold, + "batch_size": self.batch_size, + "max_retries": self.max_retries, + "timeout_seconds": self.timeout_seconds, + } + + +@lru_cache() +def get_translation_config() -> TranslationConfig: + """Get the translation configuration singleton.""" + return TranslationConfig() + + +def clear_translation_config_cache(): + """Clear the cached config for testing purposes.""" + get_translation_config.cache_clear() diff --git a/cognee/tasks/translation/detect_language.py b/cognee/tasks/translation/detect_language.py new file mode 100644 index 000000000..a474f7144 --- /dev/null +++ b/cognee/tasks/translation/detect_language.py @@ -0,0 +1,190 @@ +from dataclasses import dataclass +from typing import Optional + +from cognee.shared.logging_utils import get_logger + +from .config import get_translation_config +from .exceptions import LanguageDetectionError + +logger = get_logger(__name__) + + +# ISO 639-1 language code to name mapping +LANGUAGE_NAMES = { + "af": "Afrikaans", + "ar": "Arabic", + "bg": "Bulgarian", + "bn": "Bengali", + "ca": "Catalan", + "cs": "Czech", + "cy": "Welsh", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "et": "Estonian", + "fa": "Persian", + "fi": "Finnish", + "fr": "French", + "gu": "Gujarati", + "he": "Hebrew", + "hi": "Hindi", + "hr": "Croatian", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "ko": "Korean", + "lt": "Lithuanian", + "lv": "Latvian", + "mk": "Macedonian", + "ml": "Malayalam", + "mr": "Marathi", + "ne": "Nepali", + "nl": "Dutch", + "no": "Norwegian", + "pa": "Punjabi", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sk": "Slovak", + "sl": "Slovenian", + "so": "Somali", + "sq": "Albanian", + "sv": "Swedish", + "sw": "Swahili", + "ta": "Tamil", + "te": "Telugu", + "th": "Thai", + "tl": "Tagalog", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "zh-cn": "Chinese (Simplified)", + "zh-tw": "Chinese (Traditional)", +} + + +@dataclass +class LanguageDetectionResult: + """Result of language detection.""" + + language_code: str + language_name: str + confidence: float + requires_translation: bool + character_count: int + + +def get_language_name(language_code: str) -> str: + """Get the human-readable name for a language code.""" + return LANGUAGE_NAMES.get(language_code.lower(), language_code) + + +def detect_language( + text: str, + target_language: str = "en", + confidence_threshold: Optional[float] = None, +) -> LanguageDetectionResult: + """ + Detect the language of the given text. + + Uses the langdetect library which is already a dependency of cognee. + + Args: + text: The text to analyze + target_language: The target language for translation comparison + confidence_threshold: Minimum confidence to consider detection reliable + + Returns: + LanguageDetectionResult with language info and translation requirement + + Raises: + LanguageDetectionError: If language detection fails + """ + config = get_translation_config() + threshold = confidence_threshold or config.confidence_threshold + + # Handle empty or very short text + if not text or len(text.strip()) < config.min_text_length_for_detection: + if config.skip_detection_for_short_text: + return LanguageDetectionResult( + language_code="unknown", + language_name="Unknown", + confidence=0.0, + requires_translation=False, + character_count=len(text) if text else 0, + ) + else: + raise LanguageDetectionError( + f"Text too short for reliable language detection: {len(text)} characters" + ) + + try: + from langdetect import detect_langs, LangDetectException + except ImportError: + raise LanguageDetectionError( + "langdetect is required for language detection. Install it with: pip install langdetect" + ) + + try: + # Get detection results with probabilities + detections = detect_langs(text) + + if not detections: + raise LanguageDetectionError("No language detected") + + # Get the most likely language + best_detection = detections[0] + language_code = best_detection.lang + confidence = best_detection.prob + + # Check if translation is needed + requires_translation = ( + language_code.lower() != target_language.lower() and confidence >= threshold + ) + + return LanguageDetectionResult( + language_code=language_code, + language_name=get_language_name(language_code), + confidence=confidence, + requires_translation=requires_translation, + character_count=len(text), + ) + + except LangDetectException as e: + logger.warning(f"Language detection failed: {e}") + raise LanguageDetectionError(f"Language detection failed: {e}", original_error=e) + except Exception as e: + logger.error(f"Unexpected error during language detection: {e}") + raise LanguageDetectionError( + f"Unexpected error during language detection: {e}", original_error=e + ) + + +async def detect_language_async( + text: str, + target_language: str = "en", + confidence_threshold: Optional[float] = None, +) -> LanguageDetectionResult: + """ + Async wrapper for language detection. + + Args: + text: The text to analyze + target_language: The target language for translation comparison + confidence_threshold: Minimum confidence to consider detection reliable + + Returns: + LanguageDetectionResult with language info and translation requirement + """ + import asyncio + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, detect_language, text, target_language, confidence_threshold + ) diff --git a/cognee/tasks/translation/exceptions.py b/cognee/tasks/translation/exceptions.py new file mode 100644 index 000000000..3ab197fce --- /dev/null +++ b/cognee/tasks/translation/exceptions.py @@ -0,0 +1,62 @@ +class TranslationError(Exception): + """Base exception for translation errors.""" + + def __init__(self, message: str, original_error: Exception = None): + self.message = message + self.original_error = original_error + super().__init__(self.message) + if original_error: + self.__cause__ = original_error + + +class LanguageDetectionError(TranslationError): + """Exception raised when language detection fails.""" + + def __init__( + self, message: str = "Failed to detect language", original_error: Exception = None + ): + super().__init__(message, original_error) + + +class TranslationProviderError(TranslationError): + """Exception raised when the translation provider encounters an error.""" + + def __init__( + self, + provider: str, + message: str = "Translation provider error", + original_error: Exception = None, + ): + self.provider = provider + full_message = f"[{provider}] {message}" + super().__init__(full_message, original_error) + + +class UnsupportedLanguageError(TranslationError): + """Exception raised when the language is not supported.""" + + def __init__( + self, + language: str, + provider: str = None, + message: str = None, + original_error: Exception = None, + ): + self.language = language + self.provider = provider + if message is None: + message = f"Language '{language}' is not supported" + if provider: + message += f" by {provider}" + super().__init__(message, original_error) + + +class TranslationConfigError(TranslationError): + """Exception raised when translation configuration is invalid.""" + + def __init__( + self, + message: str = "Invalid translation configuration", + original_error: Exception = None, + ): + super().__init__(message, original_error) diff --git a/cognee/tasks/translation/models.py b/cognee/tasks/translation/models.py new file mode 100644 index 000000000..da5007312 --- /dev/null +++ b/cognee/tasks/translation/models.py @@ -0,0 +1,72 @@ +from datetime import datetime, timezone +from typing import Optional +from uuid import UUID + +from cognee.infrastructure.engine import DataPoint +from cognee.modules.chunking.models import DocumentChunk + + +class TranslatedContent(DataPoint): + """ + Represents translated content with quality metrics. + + This class stores both the original and translated versions of content, + along with metadata about the translation process including source and + target languages, translation provider used, and confidence scores. + + Instance variables include: + + - original_chunk_id: UUID of the original document chunk + - original_text: The original text before translation + - translated_text: The translated text content + - source_language: Detected or specified source language code (e.g., "es", "fr", "de") + - target_language: Target language code for translation (default: "en") + - translation_provider: Name of the translation service used + - confidence_score: Translation quality/confidence score (0.0 to 1.0) + - translation_timestamp: When the translation was performed + - translated_from: Reference to the original DocumentChunk + """ + + original_chunk_id: UUID + original_text: str + translated_text: str + source_language: str + target_language: str = "en" + translation_provider: str + confidence_score: float + translation_timestamp: datetime = None + translated_from: Optional[DocumentChunk] = None + + metadata: dict = {"index_fields": ["source_language", "translated_text"]} + + def __init__(self, **data): + if data.get("translation_timestamp") is None: + data["translation_timestamp"] = datetime.now(timezone.utc) + super().__init__(**data) + + +class LanguageMetadata(DataPoint): + """ + Language information for content. + + This class stores metadata about the detected language of content, + including confidence scores and whether translation is required. + + Instance variables include: + + - content_id: UUID of the associated content + - detected_language: ISO 639-1 language code (e.g., "en", "es", "fr") + - language_confidence: Confidence score for language detection (0.0 to 1.0) + - requires_translation: Whether the content needs translation + - character_count: Number of characters in the content + - language_name: Human-readable language name (e.g., "English", "Spanish") + """ + + content_id: UUID + detected_language: str + language_confidence: float + requires_translation: bool + character_count: int + language_name: Optional[str] = None + + metadata: dict = {"index_fields": ["detected_language"]} diff --git a/cognee/tasks/translation/providers/__init__.py b/cognee/tasks/translation/providers/__init__.py new file mode 100644 index 000000000..f76023022 --- /dev/null +++ b/cognee/tasks/translation/providers/__init__.py @@ -0,0 +1,44 @@ +from .base import TranslationProvider, TranslationResult +from .llm_provider import LLMTranslationProvider +from .google_provider import GoogleTranslationProvider +from .azure_provider import AzureTranslationProvider + +__all__ = [ + "TranslationProvider", + "TranslationResult", + "LLMTranslationProvider", + "GoogleTranslationProvider", + "AzureTranslationProvider", + "get_translation_provider", +] + + +def get_translation_provider(provider_name: str) -> TranslationProvider: + """ + Factory function to get the appropriate translation provider. + + Args: + provider_name: Name of the provider: + - "llm": Uses the configured LLM (OpenAI, Azure, Ollama, Anthropic, etc.) + - "google": Uses Google Cloud Translation API + - "azure": Uses Azure Translator API + + Returns: + TranslationProvider instance + + Raises: + ValueError: If the provider name is not recognized + """ + providers = { + "llm": LLMTranslationProvider, + "google": GoogleTranslationProvider, + "azure": AzureTranslationProvider, + } + + if provider_name.lower() not in providers: + raise ValueError( + f"Unknown translation provider: {provider_name}. " + f"Available providers: {list(providers.keys())}" + ) + + return providers[provider_name.lower()]() diff --git a/cognee/tasks/translation/providers/azure_provider.py b/cognee/tasks/translation/providers/azure_provider.py new file mode 100644 index 000000000..368585ffc --- /dev/null +++ b/cognee/tasks/translation/providers/azure_provider.py @@ -0,0 +1,192 @@ +from typing import Optional + +import aiohttp + +from cognee.shared.logging_utils import get_logger + +from .base import TranslationProvider, TranslationResult +from ..config import get_translation_config +from ..exceptions import TranslationProviderError + +logger = get_logger(__name__) + + +class AzureTranslationProvider(TranslationProvider): + """ + Translation provider using Azure Translator API. + + Requires: + - AZURE_TRANSLATOR_KEY environment variable + - AZURE_TRANSLATOR_REGION environment variable (optional) + """ + + def __init__(self): + self._config = get_translation_config() + + @property + def provider_name(self) -> str: + return "azure" + + def is_available(self) -> bool: + """Check if Azure Translator is available.""" + return self._config.azure_translator_key is not None + + async def translate( + self, + text: str, + target_language: str = "en", + source_language: Optional[str] = None, + ) -> TranslationResult: + """ + Translate text using Azure Translator API. + + Args: + text: The text to translate + target_language: Target language code (default: "en") + source_language: Source language code (optional) + + Returns: + TranslationResult with translated text and metadata + """ + if not self.is_available(): + raise TranslationProviderError( + provider=self.provider_name, + message="Azure Translator API key not configured. Set AZURE_TRANSLATOR_KEY environment variable.", + ) + + endpoint = f"{self._config.azure_translator_endpoint}/translate" + + params = { + "api-version": "3.0", + "to": target_language, + } + if source_language: + params["from"] = source_language + + headers = { + "Ocp-Apim-Subscription-Key": self._config.azure_translator_key, + "Content-Type": "application/json", + } + if self._config.azure_translator_region: + headers["Ocp-Apim-Subscription-Region"] = self._config.azure_translator_region + + body = [{"text": text}] + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + endpoint, + params=params, + headers=headers, + json=body, + timeout=aiohttp.ClientTimeout(total=self._config.timeout_seconds), + ) as response: + response.raise_for_status() + result = await response.json() + + translation = result[0]["translations"][0] + detected_language = result[0].get("detectedLanguage", {}) + + return TranslationResult( + translated_text=translation["text"], + source_language=source_language or detected_language.get("language", "unknown"), + target_language=target_language, + confidence_score=detected_language.get("score", 0.9), + provider=self.provider_name, + raw_response=result[0], + ) + + except Exception as e: + logger.error(f"Azure translation failed: {e}") + raise TranslationProviderError( + provider=self.provider_name, + message=f"Translation failed: {e}", + original_error=e, + ) + + async def translate_batch( + self, + texts: list[str], + target_language: str = "en", + source_language: Optional[str] = None, + ) -> list[TranslationResult]: + """ + Translate multiple texts using Azure Translator API. + + Azure Translator supports up to 100 texts per request. + + Args: + texts: List of texts to translate + target_language: Target language code + source_language: Source language code (optional) + + Returns: + List of TranslationResult objects + """ + if not self.is_available(): + raise TranslationProviderError( + provider=self.provider_name, + message="Azure Translator API key not configured. Set AZURE_TRANSLATOR_KEY environment variable.", + ) + + endpoint = f"{self._config.azure_translator_endpoint}/translate" + + params = { + "api-version": "3.0", + "to": target_language, + } + if source_language: + params["from"] = source_language + + headers = { + "Ocp-Apim-Subscription-Key": self._config.azure_translator_key, + "Content-Type": "application/json", + } + if self._config.azure_translator_region: + headers["Ocp-Apim-Subscription-Region"] = self._config.azure_translator_region + + # Azure supports up to 100 texts per request + batch_size = min(100, self._config.batch_size) + all_results = [] + + try: + async with aiohttp.ClientSession() as session: + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + body = [{"text": text} for text in batch] + + async with session.post( + endpoint, + params=params, + headers=headers, + json=body, + timeout=aiohttp.ClientTimeout(total=self._config.timeout_seconds), + ) as response: + response.raise_for_status() + results = await response.json() + + for result in results: + translation = result["translations"][0] + detected_language = result.get("detectedLanguage", {}) + + all_results.append( + TranslationResult( + translated_text=translation["text"], + source_language=source_language + or detected_language.get("language", "unknown"), + target_language=target_language, + confidence_score=detected_language.get("score", 0.9), + provider=self.provider_name, + raw_response=result, + ) + ) + + except Exception as e: + logger.error(f"Azure batch translation failed: {e}") + raise TranslationProviderError( + provider=self.provider_name, + message=f"Batch translation failed: {e}", + original_error=e, + ) + + return all_results diff --git a/cognee/tasks/translation/providers/base.py b/cognee/tasks/translation/providers/base.py new file mode 100644 index 000000000..37c6744b4 --- /dev/null +++ b/cognee/tasks/translation/providers/base.py @@ -0,0 +1,85 @@ +""" +Base classes for translation providers. + +This module defines the abstract interface that all translation providers must implement. +Providers handle the actual translation of text using external services like OpenAI, +Google Translate, or Azure Translator. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class TranslationResult: + """Result of a translation operation.""" + + translated_text: str + source_language: str + target_language: str + # Confidence score from the provider, or None if not available (e.g., Google Translate) + confidence_score: Optional[float] + provider: str + raw_response: Optional[dict] = None + + +class TranslationProvider(ABC): + """Abstract base class for translation providers.""" + + @property + @abstractmethod + def provider_name(self) -> str: + """Return the name of this translation provider.""" + pass + + @abstractmethod + async def translate( + self, + text: str, + target_language: str = "en", + source_language: Optional[str] = None, + ) -> TranslationResult: + """ + Translate text to the target language. + + Args: + text: The text to translate + target_language: Target language code (default: "en") + source_language: Source language code (optional, will be auto-detected if not provided) + + Returns: + TranslationResult with translated text and metadata + """ + pass + + @abstractmethod + async def translate_batch( + self, + texts: list[str], + target_language: str = "en", + source_language: Optional[str] = None, + ) -> list[TranslationResult]: + """ + Translate multiple texts to the target language. + + Args: + texts: List of texts to translate + target_language: Target language code (default: "en") + source_language: Source language code (optional) + + Returns: + List of TranslationResult objects + """ + pass + + @abstractmethod + def is_available(self) -> bool: + """Check if this provider is available (has required credentials). + + All providers must implement this method to validate their credentials. + + Returns: + True if the provider has valid credentials and is ready to use. + """ + pass diff --git a/cognee/tasks/translation/providers/google_provider.py b/cognee/tasks/translation/providers/google_provider.py new file mode 100644 index 000000000..d6b16545c --- /dev/null +++ b/cognee/tasks/translation/providers/google_provider.py @@ -0,0 +1,158 @@ +import asyncio +from typing import Optional + +from cognee.shared.logging_utils import get_logger + +from .base import TranslationProvider, TranslationResult +from ..config import get_translation_config + +logger = get_logger(__name__) + + +class GoogleTranslationProvider(TranslationProvider): + """ + Translation provider using Google Cloud Translation API. + + Requires: + - google-cloud-translate package + - GOOGLE_TRANSLATE_API_KEY or GOOGLE_PROJECT_ID environment variable + """ + + def __init__(self): + self._client = None + self._config = get_translation_config() + + @property + def provider_name(self) -> str: + return "google" + + def _get_client(self): + """Lazy initialization of Google Translate client.""" + if self._client is None: + try: + from google.cloud import translate_v2 as translate + + self._client = translate.Client() + except ImportError: + raise ImportError( + "google-cloud-translate is required for Google translation. " + "Install it with: pip install google-cloud-translate" + ) + except Exception as e: + logger.error(f"Failed to initialize Google Translate client: {e}") + raise + return self._client + + def is_available(self) -> bool: + """Check if Google Translate is available.""" + try: + self._get_client() + return True + except Exception as e: + logger.debug(f"Google Translate not available: {e}") + return False + + async def translate( + self, + text: str, + target_language: str = "en", + source_language: Optional[str] = None, + ) -> TranslationResult: + """ + Translate text using Google Translate API. + + Args: + text: The text to translate + target_language: Target language code (default: "en") + source_language: Source language code (optional) + + Returns: + TranslationResult with translated text and metadata + """ + try: + client = self._get_client() + + # Run in thread pool since google-cloud-translate is synchronous + loop = asyncio.get_running_loop() + + # Build kwargs for translate call + translate_kwargs = {"target_language": target_language} + if source_language: + translate_kwargs["source_language"] = source_language + + result = await loop.run_in_executor( + None, + lambda: client.translate(text, **translate_kwargs), + ) + + detected_language = result.get("detectedSourceLanguage", source_language or "unknown") + + return TranslationResult( + translated_text=result["translatedText"], + source_language=detected_language, + target_language=target_language, + # Google Translate API does not provide confidence scores + confidence_score=None, + provider=self.provider_name, + raw_response=result, + ) + + except Exception as e: + logger.error(f"Google translation failed: {e}") + raise + + async def translate_batch( + self, + texts: list[str], + target_language: str = "en", + source_language: Optional[str] = None, + ) -> list[TranslationResult]: + """ + Translate multiple texts using Google Translate API. + + Google Translate supports batch translation natively. + + Args: + texts: List of texts to translate + target_language: Target language code + source_language: Source language code (optional) + + Returns: + List of TranslationResult objects + """ + try: + client = self._get_client() + loop = asyncio.get_running_loop() + + # Build kwargs for translate call + translate_kwargs = {"target_language": target_language} + if source_language: + translate_kwargs["source_language"] = source_language + + results = await loop.run_in_executor( + None, + lambda: client.translate(texts, **translate_kwargs), + ) + + translation_results = [] + for result in results: + detected_language = result.get( + "detectedSourceLanguage", source_language or "unknown" + ) + translation_results.append( + TranslationResult( + translated_text=result["translatedText"], + source_language=detected_language, + target_language=target_language, + # Google Translate API does not provide confidence scores + confidence_score=None, + provider=self.provider_name, + raw_response=result, + ) + ) + + return translation_results + + except Exception as e: + logger.error(f"Google batch translation failed: {e}") + raise diff --git a/cognee/tasks/translation/providers/llm_provider.py b/cognee/tasks/translation/providers/llm_provider.py new file mode 100644 index 000000000..2e92811ee --- /dev/null +++ b/cognee/tasks/translation/providers/llm_provider.py @@ -0,0 +1,143 @@ +import asyncio +from typing import Optional + +from pydantic import BaseModel + +from cognee.infrastructure.llm.LLMGateway import LLMGateway +from cognee.infrastructure.llm.config import get_llm_config +from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.shared.logging_utils import get_logger + +from .base import TranslationProvider, TranslationResult + +logger = get_logger(__name__) + + +class TranslationOutput(BaseModel): + """Pydantic model for structured translation output from LLM.""" + + translated_text: str + detected_source_language: str + translation_notes: Optional[str] = None + + +class LLMTranslationProvider(TranslationProvider): + """ + Translation provider using the configured LLM for translation. + + This provider leverages the existing LLM infrastructure in Cognee + to perform translations using any LLM configured via LLM_PROVIDER + (OpenAI, Azure, Ollama, Anthropic, etc.). + + The LLM used is determined by the cognee LLM configuration settings: + - LLM_PROVIDER: The LLM provider (openai, azure, ollama, etc.) + - LLM_MODEL: The model to use + - LLM_API_KEY: API key for the provider + """ + + @property + def provider_name(self) -> str: + """Return 'llm' as the provider name.""" + return "llm" + + async def translate( + self, + text: str, + target_language: str = "en", + source_language: Optional[str] = None, + ) -> TranslationResult: + """ + Translate text using the configured LLM. + + Args: + text: The text to translate + target_language: Target language code (default: "en") + source_language: Source language code (optional) + + Returns: + TranslationResult with translated text and metadata + """ + try: + system_prompt = read_query_prompt("translate_content.txt") + + # Validate system prompt was loaded successfully + if system_prompt is None: + logger.warning("translate_content.txt prompt file not found, using default prompt") + system_prompt = ( + "You are a professional translator. Translate the given text accurately " + "while preserving the original meaning, tone, and style. " + "Detect the source language if not provided." + ) + + # Build the input with context + if source_language: + input_text = ( + f"Translate the following text from {source_language} to {target_language}.\n\n" + f"Text to translate:\n{text}" + ) + else: + input_text = ( + f"Translate the following text to {target_language}. " + f"First detect the source language.\n\n" + f"Text to translate:\n{text}" + ) + + result = await LLMGateway.acreate_structured_output( + text_input=input_text, + system_prompt=system_prompt, + response_model=TranslationOutput, + ) + + return TranslationResult( + translated_text=result.translated_text, + source_language=source_language or result.detected_source_language, + target_language=target_language, + # TODO: Consider deriving confidence from LLM response metadata + # or making configurable via TranslationConfig + confidence_score=0.95, # LLM translations are generally high quality + provider=self.provider_name, + raw_response={"notes": result.translation_notes}, + ) + + except Exception as e: + logger.error(f"LLM translation failed: {e}") + raise + + async def translate_batch( + self, + texts: list[str], + target_language: str = "en", + source_language: Optional[str] = None, + max_concurrent: int = 5, + ) -> list[TranslationResult]: + """ + Translate multiple texts using the configured LLM. + + Uses a semaphore to limit concurrent requests and avoid API rate limits. + + Args: + texts: List of texts to translate + target_language: Target language code + source_language: Source language code (optional) + max_concurrent: Maximum concurrent translation requests (default: 5) + + Returns: + List of TranslationResult objects + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def limited_translate(text: str) -> TranslationResult: + async with semaphore: + return await self.translate(text, target_language, source_language) + + tasks = [limited_translate(text) for text in texts] + return await asyncio.gather(*tasks) + + def is_available(self) -> bool: + """Check if LLM provider is available (has required credentials).""" + try: + llm_config = get_llm_config() + # Check if API key is configured (required for most providers) + return bool(llm_config.llm_api_key) + except Exception: + return False diff --git a/cognee/tasks/translation/translate_content.py b/cognee/tasks/translation/translate_content.py new file mode 100644 index 000000000..984082469 --- /dev/null +++ b/cognee/tasks/translation/translate_content.py @@ -0,0 +1,282 @@ +import asyncio +from typing import List, Optional +from uuid import uuid5 + +from cognee.modules.chunking.models import DocumentChunk +from cognee.shared.logging_utils import get_logger + +from .config import get_translation_config, TranslationProviderType +from .detect_language import detect_language_async, LanguageDetectionResult +from .exceptions import TranslationError, LanguageDetectionError +from .models import TranslatedContent, LanguageMetadata +from .providers import get_translation_provider, TranslationResult + +logger = get_logger(__name__) + + +async def translate_content( + data_chunks: List[DocumentChunk], + target_language: str = None, + translation_provider: TranslationProviderType = None, + confidence_threshold: float = None, + skip_if_target_language: bool = True, + preserve_original: bool = True, +) -> List[DocumentChunk]: + """ + Translate non-English content to the target language. + + This task detects the language of each document chunk and translates + non-target-language content using the specified translation provider. + Original text is preserved alongside translated versions. + + Args: + data_chunks: List of DocumentChunk objects to process + target_language: Target language code (default: "en" for English) + If not provided, uses config default + translation_provider: Translation service to use ("llm", "google", "azure") + If not provided, uses config default + confidence_threshold: Minimum confidence for language detection (0.0 to 1.0) + If not provided, uses config default + skip_if_target_language: If True, skip chunks already in target language + preserve_original: If True, store original text in TranslatedContent + + Returns: + List of DocumentChunk objects with translated content. + Chunks that required translation will have TranslatedContent + objects in their 'contains' list. + + Note: + This function mutates the input chunks in-place. Specifically: + - chunk.text is replaced with the translated text + - chunk.contains is updated with LanguageMetadata and TranslatedContent + The original text is preserved in TranslatedContent.original_text + if preserve_original=True. + + Example: + ```python + from cognee.tasks.translation import translate_content + + # Translate chunks using default settings + translated_chunks = await translate_content(chunks) + + # Translate with specific provider + translated_chunks = await translate_content( + chunks, + translation_provider="llm", + confidence_threshold=0.9 + ) + ``` + """ + if not isinstance(data_chunks, list): + raise TranslationError("data_chunks must be a list") + + if len(data_chunks) == 0: + return data_chunks + + # Get configuration + config = get_translation_config() + provider_name = translation_provider or config.translation_provider + target_lang = target_language or config.target_language + threshold = confidence_threshold or config.confidence_threshold + + logger.info( + f"Starting translation task for {len(data_chunks)} chunks " + f"using {provider_name} provider, target language: {target_lang}" + ) + + # Get the translation provider + provider = get_translation_provider(provider_name) + + # Process chunks + processed_chunks = [] + total_chunks = len(data_chunks) + + for chunk_index, chunk in enumerate(data_chunks): + # Log progress for large batches + if chunk_index > 0 and chunk_index % 100 == 0: + logger.info(f"Translation progress: {chunk_index}/{total_chunks} chunks processed") + + if not hasattr(chunk, "text") or not chunk.text: + processed_chunks.append(chunk) + continue + + try: + # Detect language + detection = await detect_language_async(chunk.text, target_lang, threshold) + + # Create language metadata + language_metadata = LanguageMetadata( + id=uuid5(chunk.id, "LanguageMetadata"), + content_id=chunk.id, + detected_language=detection.language_code, + language_confidence=detection.confidence, + requires_translation=detection.requires_translation, + character_count=detection.character_count, + language_name=detection.language_name, + ) + + # Skip if already in target language + if not detection.requires_translation: + if skip_if_target_language: + logger.debug( + f"Skipping chunk {chunk.id}: already in target language " + f"({detection.language_code})" + ) + # Add language metadata to chunk + _add_to_chunk_contains(chunk, language_metadata) + processed_chunks.append(chunk) + continue + + # Translate the content + logger.debug( + f"Translating chunk {chunk.id} from {detection.language_code} to {target_lang}" + ) + + translation_result = await provider.translate( + text=chunk.text, + target_language=target_lang, + source_language=detection.language_code, + ) + + # Create TranslatedContent data point + translated_content = TranslatedContent( + id=uuid5(chunk.id, "TranslatedContent"), + original_chunk_id=chunk.id, + original_text=chunk.text if preserve_original else "", + translated_text=translation_result.translated_text, + source_language=translation_result.source_language, + target_language=translation_result.target_language, + translation_provider=translation_result.provider, + confidence_score=translation_result.confidence_score, + translated_from=chunk, + ) + + # Update chunk text with translated content + chunk.text = translation_result.translated_text + + # Add metadata to chunk's contains list + _add_to_chunk_contains(chunk, language_metadata) + _add_to_chunk_contains(chunk, translated_content) + + processed_chunks.append(chunk) + + logger.debug( + f"Successfully translated chunk {chunk.id}: " + f"{detection.language_code} -> {target_lang}" + ) + + except LanguageDetectionError as e: + logger.warning(f"Language detection failed for chunk {chunk.id}: {e}") + processed_chunks.append(chunk) + except TranslationError as e: + logger.error(f"Translation failed for chunk {chunk.id}: {e}") + processed_chunks.append(chunk) + except Exception as e: + logger.error(f"Unexpected error processing chunk {chunk.id}: {e}") + processed_chunks.append(chunk) + + logger.info(f"Translation task completed for {len(processed_chunks)} chunks") + return processed_chunks + + +def _add_to_chunk_contains(chunk: DocumentChunk, item) -> None: + """Helper to add an item to a chunk's contains list.""" + if chunk.contains is None: + chunk.contains = [] + chunk.contains.append(item) + + +async def translate_text( + text: str, + target_language: str = None, + translation_provider: TranslationProviderType = None, + source_language: Optional[str] = None, +) -> TranslationResult: + """ + Translate a single text string. + + This is a convenience function for translating individual texts + without creating DocumentChunk objects. + + Args: + text: The text to translate + target_language: Target language code (default: uses config, typically "en") + If not provided, uses config default + translation_provider: Translation service to use + If not provided, uses config default + source_language: Source language code (optional, auto-detected if not provided) + + Returns: + TranslationResult with translated text and metadata + + Example: + ```python + from cognee.tasks.translation import translate_text + + result = await translate_text( + "Bonjour le monde!", + target_language="en" + ) + print(result.translated_text) # "Hello world!" + print(result.source_language) # "fr" + ``` + """ + config = get_translation_config() + provider_name = translation_provider or config.translation_provider + target_lang = target_language or config.target_language + + provider = get_translation_provider(provider_name) + + return await provider.translate( + text=text, + target_language=target_lang, + source_language=source_language, + ) + + +async def batch_translate_texts( + texts: List[str], + target_language: str = None, + translation_provider: TranslationProviderType = None, + source_language: Optional[str] = None, +) -> List[TranslationResult]: + """ + Translate multiple text strings in batch. + + This is more efficient than translating texts individually, + especially for providers that support native batch operations. + + Args: + texts: List of texts to translate + target_language: Target language code (default: uses config, typically "en") + If not provided, uses config default + translation_provider: Translation service to use + If not provided, uses config default + source_language: Source language code (optional) + + Returns: + List of TranslationResult objects + + Example: + ```python + from cognee.tasks.translation import batch_translate_texts + + results = await batch_translate_texts( + ["Hola", "¿Cómo estás?", "Adiós"], + target_language="en" + ) + for result in results: + print(f"{result.source_language}: {result.translated_text}") + ``` + """ + config = get_translation_config() + provider_name = translation_provider or config.translation_provider + target_lang = target_language or config.target_language + + provider = get_translation_provider(provider_name) + + return await provider.translate_batch( + texts=texts, + target_language=target_lang, + source_language=source_language, + ) diff --git a/cognee/tests/tasks/translation/README.md b/cognee/tests/tasks/translation/README.md new file mode 100644 index 000000000..075dc71db --- /dev/null +++ b/cognee/tests/tasks/translation/README.md @@ -0,0 +1,147 @@ +# Translation Task Tests + +Unit and integration tests for the multilingual content translation feature. + +## Test Files + +- **config_test.py** - Tests for translation configuration + - Default configuration + - Provider type validation + - Confidence threshold bounds + - Multiple provider API keys + +- **detect_language_test.py** - Tests for language detection functionality + - English, Spanish, French, German, Chinese detection + - Confidence thresholds + - Edge cases (empty text, short text, mixed languages) + +- **providers_test.py** - Tests for translation provider implementations + - LLM provider basic translation + - Auto-detection of source language + - Batch translation + - Special characters and formatting preservation + - Error handling + +- **translate_content_test.py** - Tests for the main translate_content task + - Basic translation workflow + - Original text preservation + - Multiple chunks processing + - Language metadata creation + - Skip translation for target language + - Confidence threshold customization + +## Running Tests + +### Run all translation tests +```bash +uv run pytest cognee/tests/tasks/translation/ -v +``` + +### Run specific test file +```bash +uv run pytest cognee/tests/tasks/translation/detect_language_test.py -v +``` + +### Run tests directly (without pytest) +```bash +uv run python cognee/tests/tasks/translation/config_test.py +uv run python cognee/tests/tasks/translation/detect_language_test.py +uv run python cognee/tests/tasks/translation/providers_test.py +uv run python cognee/tests/tasks/translation/translate_content_test.py +uv run python cognee/tests/tasks/translation/integration_test.py +``` + +### Run all tests at once +```bash +for f in cognee/tests/tasks/translation/*_test.py; do uv run python "$f"; done +``` + +### Run with coverage +```bash +uv run pytest cognee/tests/tasks/translation/ --cov=cognee.tasks.translation --cov-report=html +``` + +## Prerequisites + +- LLM API key set in environment: `LLM_API_KEY=your_key` +- Tests will be skipped if no API key is available + +**Note:** The translation feature uses the same LLM model configured for other cognee tasks (via `LLM_MODEL` and `LLM_PROVIDER` environment variables). This means any LLM provider supported by cognee (OpenAI, Azure, Anthropic, Ollama, etc.) can be used for translation. + +## Usage Example + +```python +import cognee +from cognee.tasks.translation import translate_text + +# Configure translation (optional - defaults to LLM provider) +cognee.config.set_translation_config( + provider="llm", # Uses configured LLM (default) + target_language="en", # Target language code + confidence_threshold=0.7 # Minimum confidence for language detection +) + +# Translate text directly +result = await translate_text( + text="Bonjour le monde", + target_language="en" +) +print(result.translated_text) # "Hello world" +``` + +### Alternative Translation Providers + +```python +# Use Google Cloud Translate (requires GOOGLE_TRANSLATE_API_KEY) +cognee.config.set_translation_provider("google") + +# Use Azure Translator (requires AZURE_TRANSLATOR_KEY and AZURE_TRANSLATOR_REGION) +cognee.config.set_translation_provider("azure") +``` + +## Test Summary + +| Test File | Tests | Description | +|-----------|-------|-------------| +| config_test.py | 4 | Configuration validation | +| detect_language_test.py | 10 | Language detection | +| providers_test.py | 9 | Translation providers | +| translate_content_test.py | 9 | Content translation task | +| integration_test.py | 2 | Standalone translation tests | +| **Total** | **34** | | + +## Test Categories + +### Configuration (4 tests) +- ✅ Default configuration values +- ✅ Provider type literal validation +- ✅ Confidence threshold bounds +- ✅ Multiple provider API keys + +### Language Detection (10 tests) +- ✅ Multiple language detection (EN, ES, FR, DE, ZH) +- ✅ Confidence scoring +- ✅ Target language matching +- ✅ Short and empty text handling +- ✅ Mixed language detection + +### Translation Providers (9 tests) +- ✅ Provider factory function +- ✅ LLM translation +- ✅ Batch operations +- ✅ Auto source language detection +- ✅ Long text handling +- ✅ Special characters preservation +- ✅ Error handling + +### Content Translation (9 tests) +- ✅ DocumentChunk processing +- ✅ Metadata creation (LanguageMetadata, TranslatedContent) +- ✅ Original text preservation +- ✅ Multiple chunk handling +- ✅ Empty text/list handling +- ✅ Confidence threshold customization + +### Integration (2 tests) +- ✅ Direct translate_text function +- ✅ Language detection functionality diff --git a/cognee/tests/tasks/translation/__init__.py b/cognee/tests/tasks/translation/__init__.py new file mode 100644 index 000000000..7284dcfa5 --- /dev/null +++ b/cognee/tests/tasks/translation/__init__.py @@ -0,0 +1 @@ +"""Translation task tests""" diff --git a/cognee/tests/tasks/translation/config_test.py b/cognee/tests/tasks/translation/config_test.py new file mode 100644 index 000000000..248bf70f3 --- /dev/null +++ b/cognee/tests/tasks/translation/config_test.py @@ -0,0 +1,93 @@ +""" +Unit tests for translation configuration +""" + +from typing import get_args + +from pydantic import ValidationError + +from cognee.tasks.translation.config import ( + get_translation_config, + TranslationConfig, + TranslationProviderType, +) + + +def test_default_translation_config(): + """Test default translation configuration""" + config = get_translation_config() + + assert isinstance(config, TranslationConfig), "Config should be TranslationConfig instance" + assert config.translation_provider in [ + "llm", + "google", + "azure", + ], f"Invalid provider: {config.translation_provider}" + assert 0.0 <= config.confidence_threshold <= 1.0, ( + f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]" + ) + + +def test_translation_provider_type_literal(): + """Test TranslationProviderType Literal type values""" + # Get the allowed values from the Literal type + allowed_values = get_args(TranslationProviderType) + + assert "llm" in allowed_values, "llm should be an allowed provider" + assert "google" in allowed_values, "google should be an allowed provider" + assert "azure" in allowed_values, "azure should be an allowed provider" + assert len(allowed_values) == 3, f"Expected 3 providers, got {len(allowed_values)}" + + +def test_confidence_threshold_bounds(): + """Test confidence threshold validation""" + config = TranslationConfig(translation_provider="llm", confidence_threshold=0.9) + + assert 0.0 <= config.confidence_threshold <= 1.0, ( + f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]" + ) + + +def test_confidence_threshold_validation(): + """Test that invalid confidence thresholds are rejected or clamped""" + # Test boundary values - these should work + config_min = TranslationConfig(translation_provider="llm", confidence_threshold=0.0) + assert config_min.confidence_threshold == 0.0, "Minimum bound (0.0) should be valid" + + config_max = TranslationConfig(translation_provider="llm", confidence_threshold=1.0) + assert config_max.confidence_threshold == 1.0, "Maximum bound (1.0) should be valid" + + # Test invalid values - these should either raise ValidationError or be clamped + try: + config_invalid_low = TranslationConfig( + translation_provider="llm", confidence_threshold=-0.1 + ) + # If no error, verify it was clamped to valid range + assert 0.0 <= config_invalid_low.confidence_threshold <= 1.0, ( + f"Invalid low value should be clamped, got {config_invalid_low.confidence_threshold}" + ) + except ValidationError: + pass # Expected validation error + + try: + config_invalid_high = TranslationConfig( + translation_provider="llm", confidence_threshold=1.5 + ) + # If no error, verify it was clamped to valid range + assert 0.0 <= config_invalid_high.confidence_threshold <= 1.0, ( + f"Invalid high value should be clamped, got {config_invalid_high.confidence_threshold}" + ) + except ValidationError: + pass # Expected validation error + + +def test_multiple_provider_keys(): + """Test configuration with multiple provider API keys""" + config = TranslationConfig( + translation_provider="llm", + google_translate_api_key="google_key", + azure_translator_key="azure_key", + ) + + assert config.google_translate_api_key == "google_key", "Google API key not set correctly" + assert config.azure_translator_key == "azure_key", "Azure API key not set correctly" diff --git a/cognee/tests/tasks/translation/detect_language_test.py b/cognee/tests/tasks/translation/detect_language_test.py new file mode 100644 index 000000000..3845777ba --- /dev/null +++ b/cognee/tests/tasks/translation/detect_language_test.py @@ -0,0 +1,118 @@ +""" +Unit tests for language detection functionality +""" + +import pytest +from cognee.tasks.translation.detect_language import ( + detect_language_async, + LanguageDetectionResult, +) +from cognee.tasks.translation.exceptions import LanguageDetectionError + + +@pytest.mark.asyncio +async def test_detect_english(): + """Test detection of English text""" + result = await detect_language_async("Hello world, this is a test.", target_language="en") + + assert result.language_code == "en" + assert result.requires_translation is False + assert result.confidence > 0.9 + assert result.language_name == "English" + + +@pytest.mark.asyncio +async def test_detect_spanish(): + """Test detection of Spanish text""" + result = await detect_language_async("Hola mundo, esta es una prueba.", target_language="en") + + assert result.language_code == "es" + assert result.requires_translation is True + assert result.confidence > 0.9 + assert result.language_name == "Spanish" + + +@pytest.mark.asyncio +async def test_detect_french(): + """Test detection of French text""" + result = await detect_language_async( + "Bonjour le monde, ceci est un test.", target_language="en" + ) + + assert result.language_code == "fr" + assert result.requires_translation is True + assert result.confidence > 0.9 + assert result.language_name == "French" + + +@pytest.mark.asyncio +async def test_detect_german(): + """Test detection of German text""" + result = await detect_language_async("Hallo Welt, das ist ein Test.", target_language="en") + + assert result.language_code == "de" + assert result.requires_translation is True + assert result.confidence > 0.9 + + +@pytest.mark.asyncio +async def test_detect_chinese(): + """Test detection of Chinese text""" + result = await detect_language_async("你好世界,这是一个测试。", target_language="en") + + assert result.language_code.startswith("zh"), f"Expected Chinese, got {result.language_code}" + assert result.requires_translation is True + assert result.confidence > 0.9 + + +@pytest.mark.asyncio +async def test_already_target_language(): + """Test when text is already in target language""" + result = await detect_language_async("This text is already in English.", target_language="en") + + assert result.requires_translation is False + + +@pytest.mark.asyncio +async def test_short_text(): + """Test detection with very short text""" + result = await detect_language_async("Hi", target_language="es") + + # Short text may return 'unknown' if langdetect can't reliably detect + assert result.language_code in ["en", "unknown"] + assert result.character_count == 2 + + +@pytest.mark.asyncio +async def test_empty_text(): + """Test detection with empty text - returns unknown by default""" + result = await detect_language_async("", target_language="en") + + # With skip_detection_for_short_text=True (default), returns unknown + assert result.language_code == "unknown" + assert result.language_name == "Unknown" + assert result.confidence == 0.0 + assert result.requires_translation is False + assert result.character_count == 0 + + +@pytest.mark.asyncio +async def test_confidence_threshold(): + """Test detection respects confidence threshold""" + result = await detect_language_async( + "Hello world", target_language="es", confidence_threshold=0.5 + ) + + assert result.confidence >= 0.5 + + +@pytest.mark.asyncio +async def test_mixed_language_text(): + """Test detection with mixed language text (predominantly one language)""" + # Predominantly Spanish with English word + result = await detect_language_async( + "La inteligencia artificial es muy importante en technology moderna.", target_language="en" + ) + + assert result.language_code == "es" # Should detect as Spanish + assert result.requires_translation is True diff --git a/cognee/tests/tasks/translation/providers_test.py b/cognee/tests/tasks/translation/providers_test.py new file mode 100644 index 000000000..0573a974f --- /dev/null +++ b/cognee/tests/tasks/translation/providers_test.py @@ -0,0 +1,151 @@ +""" +Unit tests for translation providers +""" + +import os + +import pytest + +from cognee.tasks.translation.providers import ( + get_translation_provider, + LLMTranslationProvider, + TranslationResult, +) +from cognee.tasks.translation.exceptions import TranslationError + + +def has_llm_api_key(): + """Check if LLM API key is available""" + return bool(os.environ.get("LLM_API_KEY")) + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_llm_provider_basic_translation(): + """Test basic translation with LLM provider (uses configured LLM)""" + provider = LLMTranslationProvider() + + result = await provider.translate(text="Hola mundo", target_language="en", source_language="es") + + assert isinstance(result, TranslationResult) + assert result.translated_text is not None + assert len(result.translated_text) > 0 + assert result.source_language == "es" + assert result.target_language == "en" + assert result.provider == "llm" + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_llm_provider_auto_detect_source(): + """Test translation with automatic source language detection""" + provider = LLMTranslationProvider() + + result = await provider.translate( + text="Bonjour le monde", + target_language="en", + # source_language not provided - should auto-detect + ) + + assert result.translated_text is not None + assert result.target_language == "en" + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_llm_provider_long_text(): + """Test translation of longer text""" + provider = LLMTranslationProvider() + + long_text = """ + La inteligencia artificial es una rama de la informática que se centra en + crear sistemas capaces de realizar tareas que normalmente requieren inteligencia humana. + Estos sistemas pueden aprender, razonar y resolver problemas complejos. + """ + + result = await provider.translate(text=long_text, target_language="en", source_language="es") + + assert len(result.translated_text) > 0 + assert result.source_language == "es" + + +def test_get_translation_provider_factory(): + """Test provider factory function""" + provider = get_translation_provider("llm") + assert isinstance(provider, LLMTranslationProvider) + + +def test_get_translation_provider_invalid(): + """Test provider factory with invalid provider name""" + try: + get_translation_provider("invalid_provider") + assert False, "Expected TranslationError or ValueError" + except (TranslationError, ValueError): + pass + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_llm_batch_translation(): + """Test batch translation with LLM provider""" + provider = LLMTranslationProvider() + + texts = ["Hola", "¿Cómo estás?", "Adiós"] + + results = await provider.translate_batch( + texts=texts, target_language="en", source_language="es" + ) + + assert len(results) == len(texts) + for result in results: + assert isinstance(result, TranslationResult) + assert result.translated_text is not None + assert result.source_language == "es" + assert result.target_language == "en" + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translation_preserves_formatting(): + """Test that translation preserves basic formatting""" + provider = LLMTranslationProvider() + + text_with_newlines = "Primera línea.\nSegunda línea." + + result = await provider.translate( + text=text_with_newlines, target_language="en", source_language="es" + ) + + # Should preserve structure (though exact newlines may vary) + assert result.translated_text is not None + assert len(result.translated_text) > 0 + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translation_special_characters(): + """Test translation with special characters""" + provider = LLMTranslationProvider() + + text = "¡Hola! ¿Cómo estás? Está bien." + + result = await provider.translate(text=text, target_language="en", source_language="es") + + assert result.translated_text is not None + assert len(result.translated_text) > 0 + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_empty_text_translation(): + """Test translation with empty text - should return empty or handle gracefully""" + provider = LLMTranslationProvider() + + # Empty text may either raise an error or return an empty result + try: + result = await provider.translate(text="", target_language="en", source_language="es") + # If no error, should return a TranslationResult (possibly with empty text) + assert isinstance(result, TranslationResult) + except TranslationError: + # This is also acceptable behavior + pass diff --git a/cognee/tests/tasks/translation/translate_content_test.py b/cognee/tests/tasks/translation/translate_content_test.py new file mode 100644 index 000000000..87fa5b67c --- /dev/null +++ b/cognee/tests/tasks/translation/translate_content_test.py @@ -0,0 +1,213 @@ +""" +Unit tests for translate_content task +""" + +import os +from uuid import uuid4 + +import pytest + +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.tasks.translation import translate_content +from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata + + +def has_llm_api_key(): + """Check if LLM API key is available""" + return bool(os.environ.get("LLM_API_KEY")) + + +def create_test_chunk(text: str, chunk_index: int = 0): + """Helper to create a DocumentChunk with all required fields""" + # Create a minimal Document for the is_part_of field + doc = TextDocument( + id=uuid4(), + name="test_doc", + raw_data_location="/tmp/test.txt", + external_metadata=None, + mime_type="text/plain", + ) + + return DocumentChunk( + id=uuid4(), + text=text, + chunk_index=chunk_index, + chunk_size=len(text), + cut_type="sentence", + is_part_of=doc, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translate_content_basic(): + """Test basic content translation""" + # Create test chunk with Spanish text + original_text = "Hola mundo, esta es una prueba." + chunk = create_test_chunk(original_text) + + result = await translate_content( + data_chunks=[chunk], target_language="en", translation_provider="llm" + ) + + assert len(result) == 1 + # The chunk's text should now be translated (different from original Spanish) + assert result[0].text != original_text # Text should be translated to English + assert result[0].contains is not None + + # Check for TranslatedContent in contains + has_translated_content = any(isinstance(item, TranslatedContent) for item in result[0].contains) + assert has_translated_content + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translate_content_preserves_original(): + """Test that original text is preserved""" + original_text = "Bonjour le monde" + chunk = create_test_chunk(original_text) + + result = await translate_content( + data_chunks=[chunk], target_language="en", preserve_original=True + ) + + # Find TranslatedContent in contains + translated_content = None + for item in result[0].contains: + if isinstance(item, TranslatedContent): + translated_content = item + break + + assert translated_content is not None + assert translated_content.original_text == original_text + assert translated_content.translated_text != original_text + + +@pytest.mark.asyncio +async def test_translate_content_skip_english(): + """Test skipping translation for English text""" + # This test doesn't require API call since English text is skipped + chunk = create_test_chunk("Hello world, this is a test.") + + result = await translate_content( + data_chunks=[chunk], target_language="en", skip_if_target_language=True + ) + + # Text should remain unchanged + assert result[0].text == chunk.text + + # Should have LanguageMetadata but not TranslatedContent + has_language_metadata = any( + isinstance(item, LanguageMetadata) for item in (result[0].contains or []) + ) + has_translated_content = any( + isinstance(item, TranslatedContent) for item in (result[0].contains or []) + ) + + assert has_language_metadata + assert not has_translated_content + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translate_content_multiple_chunks(): + """Test translation of multiple chunks""" + # Use longer texts to ensure reliable language detection + original_texts = [ + "Hola mundo, esta es una prueba de traducción.", + "Bonjour le monde, ceci est un test de traduction.", + "Ciao mondo, questo è un test di traduzione.", + ] + chunks = [create_test_chunk(text, i) for i, text in enumerate(original_texts)] + + result = await translate_content(data_chunks=chunks, target_language="en") + + assert len(result) == 3 + # Check that at least some chunks were translated + translated_count = sum( + 1 + for chunk in result + if any(isinstance(item, TranslatedContent) for item in (chunk.contains or [])) + ) + assert translated_count >= 2 # At least 2 chunks should be translated + + +@pytest.mark.asyncio +async def test_translate_content_empty_list(): + """Test with empty chunk list""" + result = await translate_content(data_chunks=[], target_language="en") + + assert result == [] + + +@pytest.mark.asyncio +async def test_translate_content_empty_text(): + """Test with chunk containing empty text""" + chunk = create_test_chunk("") + + result = await translate_content(data_chunks=[chunk], target_language="en") + + assert len(result) == 1 + assert result[0].text == "" + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translate_content_language_metadata(): + """Test that LanguageMetadata is created correctly""" + # Use a longer, distinctly Spanish text to ensure reliable detection + chunk = create_test_chunk( + "La inteligencia artificial está cambiando el mundo de manera significativa" + ) + + result = await translate_content(data_chunks=[chunk], target_language="en") + + # Find LanguageMetadata + language_metadata = None + for item in result[0].contains: + if isinstance(item, LanguageMetadata): + language_metadata = item + break + + assert language_metadata is not None + # Just check that a language was detected (short texts can be ambiguous) + assert language_metadata.detected_language is not None + assert language_metadata.requires_translation is True + assert language_metadata.language_confidence > 0.0 + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translate_content_confidence_threshold(): + """Test with custom confidence threshold""" + # Use longer text for more reliable detection + chunk = create_test_chunk("Hola mundo, esta es una frase más larga para mejor detección") + + result = await translate_content( + data_chunks=[chunk], target_language="en", confidence_threshold=0.5 + ) + + assert len(result) == 1 + + +@pytest.mark.asyncio +@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available") +async def test_translate_content_no_preserve_original(): + """Test translation without preserving original""" + # Use longer text for more reliable detection + chunk = create_test_chunk("Bonjour le monde, comment allez-vous aujourd'hui") + + result = await translate_content( + data_chunks=[chunk], target_language="en", preserve_original=False + ) + + # Find TranslatedContent + translated_content = None + for item in result[0].contains: + if isinstance(item, TranslatedContent): + translated_content = item + break + + assert translated_content is not None + assert translated_content.original_text == "" # Should be empty diff --git a/pyproject.toml b/pyproject.toml index 1fff69c85..d9f45f651 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,8 +61,8 @@ dependencies = [ "diskcache>=5.6.3", "aiolimiter>=1.2.1", "urllib3>=2.6.0", - "cbor2>=5.8.0" - + "cbor2>=5.8.0", + "langdetect>=1.0.9", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 812315288..fbad993db 100644 --- a/uv.lock +++ b/uv.lock @@ -934,6 +934,7 @@ dependencies = [ { name = "jinja2" }, { name = "kuzu" }, { name = "lancedb" }, + { name = "langdetect" }, { name = "limits" }, { name = "litellm" }, { name = "mistralai" }, @@ -1134,6 +1135,7 @@ requires-dist = [ { name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.22" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = ">=1.2.5" }, { name = "langchain-text-splitters", marker = "extra == 'langchain'", specifier = ">=0.3.2,<1.0.0" }, + { name = "langdetect", specifier = ">=1.0.9" }, { name = "langfuse", marker = "extra == 'monitoring'", specifier = ">=2.32.0,<3" }, { name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" }, { name = "limits", specifier = ">=4.4.1,<5" },