Merge 3dca104bdf into 2ef347f8fa
This commit is contained in:
commit
f4c9ebfbc2
23 changed files with 2239 additions and 2 deletions
|
|
@ -145,6 +145,35 @@ VECTOR_DATASET_DATABASE_HANDLER="lancedb"
|
|||
# ONTOLOGY_FILE_PATH=YOUR_FULL_FILE_PATH # Default: empty
|
||||
# To add ontology resolvers, either configure them as shown in ontology_example or provide full_path and settings as environment variables.
|
||||
|
||||
################################################################################
|
||||
# 🌐 Translation Settings
|
||||
################################################################################
|
||||
|
||||
# Translation provider: llm (uses configured LLM), google, or azure
|
||||
# "llm" uses whichever LLM is configured above (OpenAI, Azure, Ollama, Anthropic, etc.)
|
||||
# "google" and "azure" use dedicated translation APIs
|
||||
TRANSLATION_PROVIDER="llm"
|
||||
|
||||
# Default target language for translations (ISO 639-1 code, e.g., en, es, fr, de)
|
||||
TARGET_LANGUAGE="en"
|
||||
|
||||
# Minimum confidence threshold for language detection (0.0 to 1.0)
|
||||
CONFIDENCE_THRESHOLD=0.8
|
||||
|
||||
# -- Google Translate settings (required if using google provider) -----------
|
||||
# GOOGLE_TRANSLATE_API_KEY="your-google-api-key"
|
||||
# GOOGLE_PROJECT_ID="your-google-project-id"
|
||||
|
||||
# -- Azure Translator settings (required if using azure provider) ------------
|
||||
# AZURE_TRANSLATOR_KEY="your-azure-translator-key"
|
||||
# AZURE_TRANSLATOR_REGION="westeurope"
|
||||
# AZURE_TRANSLATOR_ENDPOINT="https://api.cognitive.microsofttranslator.com"
|
||||
|
||||
# -- Performance settings ----------------------------------------------------
|
||||
# TRANSLATION_BATCH_SIZE=10
|
||||
# TRANSLATION_MAX_RETRIES=3
|
||||
# TRANSLATION_TIMEOUT_SECONDS=30
|
||||
|
||||
################################################################################
|
||||
# 🔄 MIGRATION (RELATIONAL → GRAPH) SETTINGS
|
||||
################################################################################
|
||||
|
|
|
|||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -6,6 +6,9 @@ cognee/.data/
|
|||
|
||||
code_pipeline_output*/
|
||||
|
||||
# Test output files
|
||||
test_outputs/
|
||||
|
||||
*.lance/
|
||||
.DS_Store
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.llm.config import (
|
|||
get_llm_config,
|
||||
)
|
||||
from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config
|
||||
from cognee.tasks.translation.config import get_translation_config
|
||||
from cognee.api.v1.exceptions.exceptions import InvalidConfigAttributeError
|
||||
|
||||
|
||||
|
|
@ -176,3 +177,29 @@ class config:
|
|||
def set_vector_db_url(db_url: str):
|
||||
vector_db_config = get_vectordb_config()
|
||||
vector_db_config.vector_db_url = db_url
|
||||
|
||||
# Translation configuration methods
|
||||
|
||||
@staticmethod
def set_translation_provider(provider: str):
    """
    Select the translation backend.

    Stores ``provider`` on the shared translation config returned by
    ``get_translation_config()``. The documented choices are "llm", "google" and
    "azure"; this method itself does not check the value.
    """
    get_translation_config().translation_provider = provider
|
||||
|
||||
@staticmethod
def set_translation_target_language(target_language: str):
    """
    Choose the language translations default to.

    Stores ``target_language`` on the shared translation config returned by
    ``get_translation_config()``.
    """
    get_translation_config().target_language = target_language
|
||||
|
||||
@staticmethod
def set_translation_config(config_dict: dict):
    """
    Update the translation config with the values from config_dict.

    Every key must name an existing attribute of the translation config. All keys
    are checked before anything is written, so an invalid key leaves the config
    untouched instead of partially updated.

    Args:
        config_dict: Mapping of translation config attribute names to new values.

    Raises:
        InvalidConfigAttributeError: If a key is not an attribute of the
            translation config. The first offending key (in iteration order)
            is reported.
    """
    translation_config = get_translation_config()

    # Validate first: raising midway through the writes would leave the keys
    # processed so far applied and the rest not.
    for key in config_dict:
        if not hasattr(translation_config, key):
            raise InvalidConfigAttributeError(attribute=key)

    for key, value in config_dict.items():
        # Written through object.__setattr__, i.e. without going through the
        # settings class's own attribute-assignment hook.
        object.__setattr__(translation_config, key, value)
|
||||
|
|
|
|||
19
cognee/infrastructure/llm/prompts/translate_content.txt
Normal file
19
cognee/infrastructure/llm/prompts/translate_content.txt
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
You are an expert translator with deep knowledge of languages, cultures, and linguistics.
|
||||
|
||||
Your task is to:
|
||||
1. Detect the source language of the provided text if not specified
|
||||
2. Translate the text accurately to the target language
|
||||
3. Preserve the original meaning, tone, and intent
|
||||
4. Maintain proper grammar and natural phrasing in the target language
|
||||
|
||||
Guidelines:
|
||||
- Preserve technical terms, proper nouns, and specialized vocabulary appropriately
|
||||
- Maintain formatting such as paragraphs, lists, and emphasis where applicable
|
||||
- If the text contains code, URLs, or other non-translatable content, preserve them as-is
|
||||
- Handle idioms and cultural references thoughtfully, adapting when necessary
|
||||
- Ensure the translation reads naturally to a native speaker of the target language
|
||||
|
||||
Provide the translation in a structured format with:
|
||||
- The translated text
|
||||
- The detected source language (ISO 639-1 code like "en", "es", "fr", "de", etc.)
|
||||
- Any notes about the translation (optional, for ambiguous terms or cultural adaptations)
|
||||
96
cognee/tasks/translation/__init__.py
Normal file
96
cognee/tasks/translation/__init__.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Translation task for Cognee.
|
||||
|
||||
This module provides multilingual content translation capabilities,
|
||||
allowing automatic detection and translation of non-English content
|
||||
to a target language while preserving original text and metadata.
|
||||
|
||||
Main Components:
|
||||
- translate_content: Main task function for translating document chunks
|
||||
- translate_text: Convenience function for translating single texts
|
||||
- batch_translate_texts: Batch translation for multiple texts
|
||||
- detect_language: Language detection utility
|
||||
- TranslatedContent: DataPoint model for translated content
|
||||
- LanguageMetadata: DataPoint model for language information
|
||||
|
||||
Supported Translation Providers:
|
||||
- LLM (default): Uses the configured LLM via existing infrastructure
|
||||
- Google Translate: Requires google-cloud-translate package
|
||||
- Azure Translator: Requires Azure Translator API key
|
||||
|
||||
Example Usage:
|
||||
```python
|
||||
from cognee.tasks.translation import translate_content, translate_text
|
||||
|
||||
# Translate document chunks in a pipeline
|
||||
translated_chunks = await translate_content(
|
||||
chunks,
|
||||
target_language="en",
|
||||
translation_provider="llm"
|
||||
)
|
||||
|
||||
# Translate a single text
|
||||
result = await translate_text("Bonjour le monde!")
|
||||
print(result.translated_text) # "Hello world!"
|
||||
```
|
||||
"""
|
||||
|
||||
from .config import get_translation_config, TranslationConfig
|
||||
from .detect_language import (
|
||||
detect_language,
|
||||
detect_language_async,
|
||||
LanguageDetectionResult,
|
||||
get_language_name,
|
||||
)
|
||||
from .exceptions import (
|
||||
TranslationError,
|
||||
LanguageDetectionError,
|
||||
TranslationProviderError,
|
||||
UnsupportedLanguageError,
|
||||
TranslationConfigError,
|
||||
)
|
||||
from .models import TranslatedContent, LanguageMetadata
|
||||
from .providers import (
|
||||
TranslationProvider,
|
||||
TranslationResult,
|
||||
get_translation_provider,
|
||||
LLMTranslationProvider,
|
||||
GoogleTranslationProvider,
|
||||
AzureTranslationProvider,
|
||||
)
|
||||
from .translate_content import (
|
||||
translate_content,
|
||||
translate_text,
|
||||
batch_translate_texts,
|
||||
)
|
||||
|
||||
# Names re-exported as the public interface of the translation package,
# grouped by the area they belong to.
__all__ = [
    # Main task functions
    "translate_content",
    "translate_text",
    "batch_translate_texts",
    # Language detection
    "detect_language",
    "detect_language_async",
    "LanguageDetectionResult",
    "get_language_name",
    # Models
    "TranslatedContent",
    "LanguageMetadata",
    # Configuration
    "get_translation_config",
    "TranslationConfig",
    # Providers
    "TranslationProvider",
    "TranslationResult",
    "get_translation_provider",
    "LLMTranslationProvider",
    "GoogleTranslationProvider",
    "AzureTranslationProvider",
    # Exceptions
    "TranslationError",
    "LanguageDetectionError",
    "TranslationProviderError",
    "UnsupportedLanguageError",
    "TranslationConfigError",
]
|
||||
110
cognee/tasks/translation/config.py
Normal file
110
cognee/tasks/translation/config.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
from functools import lru_cache
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
# Closed set of provider identifiers used as the type of
# TranslationConfig.translation_provider.
TranslationProviderType = Literal["llm", "google", "azure"]
|
||||
|
||||
|
||||
class TranslationConfig(BaseSettings):
    """
    Configuration settings for the translation task.

    Values are loaded from the environment and from a ``.env`` file; settings
    that are not declared here are allowed rather than rejected (``extra="allow"``).

    Environment variables can be used to configure these settings:
    - TRANSLATION_PROVIDER: The translation service to use ("llm", "google", "azure")
    - TARGET_LANGUAGE: Default target language (ISO 639-1 code, e.g., "en", "es", "fr")
    - CONFIDENCE_THRESHOLD: Minimum confidence for language detection (0.0 to 1.0)
    - GOOGLE_TRANSLATE_API_KEY: API key for Google Translate
    - GOOGLE_PROJECT_ID: Google Cloud project ID
    - AZURE_TRANSLATOR_KEY: API key for Azure Translator
    - AZURE_TRANSLATOR_REGION: Region for Azure Translator
    - AZURE_TRANSLATOR_ENDPOINT: Endpoint URL for Azure Translator
    - TRANSLATION_BATCH_SIZE: Number of texts to translate per batch
    - TRANSLATION_MAX_RETRIES: Maximum retry attempts on failure
    - TRANSLATION_TIMEOUT_SECONDS: Request timeout in seconds
    """

    # Translation provider settings
    translation_provider: TranslationProviderType = Field(
        default="llm",
        validation_alias=AliasChoices("TRANSLATION_PROVIDER", "translation_provider"),
    )
    target_language: str = Field(
        default="en",
        validation_alias=AliasChoices("TARGET_LANGUAGE", "target_language"),
    )
    # Constrained to the closed interval [0.0, 1.0] by ge/le.
    confidence_threshold: float = Field(
        default=0.8,
        ge=0.0,
        le=1.0,
        validation_alias=AliasChoices("CONFIDENCE_THRESHOLD", "confidence_threshold"),
    )

    # Google Translate settings
    google_translate_api_key: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices("GOOGLE_TRANSLATE_API_KEY", "google_translate_api_key"),
    )
    google_project_id: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices("GOOGLE_PROJECT_ID", "google_project_id"),
    )

    # Azure Translator settings
    azure_translator_key: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices("AZURE_TRANSLATOR_KEY", "azure_translator_key"),
    )
    azure_translator_region: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices("AZURE_TRANSLATOR_REGION", "azure_translator_region"),
    )
    azure_translator_endpoint: str = Field(
        default="https://api.cognitive.microsofttranslator.com",
        validation_alias=AliasChoices("AZURE_TRANSLATOR_ENDPOINT", "azure_translator_endpoint"),
    )

    # LLM provider uses the existing LLM configuration

    # Performance settings (with TRANSLATION_ prefix for env vars)
    # NOTE(review): none of these three declare a lower bound, so zero or negative
    # values are accepted here; consumers have to cope with them.
    batch_size: int = Field(
        default=10,
        validation_alias=AliasChoices("TRANSLATION_BATCH_SIZE", "batch_size"),
    )
    max_retries: int = Field(
        default=3,
        validation_alias=AliasChoices("TRANSLATION_MAX_RETRIES", "max_retries"),
    )
    timeout_seconds: int = Field(
        default=30,
        validation_alias=AliasChoices("TRANSLATION_TIMEOUT_SECONDS", "timeout_seconds"),
    )

    # Language detection settings
    # These two fields declare no validation alias, unlike the fields above.
    min_text_length_for_detection: int = 10
    skip_detection_for_short_text: bool = True

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def to_dict(self) -> dict:
        """
        Return a plain-dict summary of the general settings.

        Only the provider choice, target language, confidence threshold and the
        three performance settings are included; the Google/Azure keys, region,
        project id and endpoint are left out, as are the detection settings.
        """
        return {
            "translation_provider": self.translation_provider,
            "target_language": self.target_language,
            "confidence_threshold": self.confidence_threshold,
            "batch_size": self.batch_size,
            "max_retries": self.max_retries,
            "timeout_seconds": self.timeout_seconds,
        }
|
||||
|
||||
|
||||
@lru_cache()
def get_translation_config() -> TranslationConfig:
    """
    Return the shared TranslationConfig instance.

    The first call builds the settings object; lru_cache then hands back that
    same object on every later call until the cache is cleared.
    """
    settings = TranslationConfig()
    return settings
|
||||
|
||||
|
||||
def clear_translation_config_cache():
    """Forget the memoized TranslationConfig so the next lookup rebuilds it (for tests)."""
    reset_cache = get_translation_config.cache_clear
    reset_cache()
|
||||
190
cognee/tasks/translation/detect_language.py
Normal file
190
cognee/tasks/translation/detect_language.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .config import get_translation_config
|
||||
from .exceptions import LanguageDetectionError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Language code to English display-name mapping, keyed by lowercase code.
# Most keys are two-letter ISO 639-1 codes; "zh-cn" and "zh-tw" are
# region-qualified codes rather than plain two-letter ones.
LANGUAGE_NAMES = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "bg": "Bulgarian",
    "bn": "Bengali",
    "ca": "Catalan",
    "cs": "Czech",
    "cy": "Welsh",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "et": "Estonian",
    "fa": "Persian",
    "fi": "Finnish",
    "fr": "French",
    "gu": "Gujarati",
    "he": "Hebrew",
    "hi": "Hindi",
    "hr": "Croatian",
    "hu": "Hungarian",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "ko": "Korean",
    "lt": "Lithuanian",
    "lv": "Latvian",
    "mk": "Macedonian",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "nl": "Dutch",
    "no": "Norwegian",
    "pa": "Punjabi",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "so": "Somali",
    "sq": "Albanian",
    "sv": "Swedish",
    "sw": "Swahili",
    "ta": "Tamil",
    "te": "Telugu",
    "th": "Thai",
    "tl": "Tagalog",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "zh-cn": "Chinese (Simplified)",
    "zh-tw": "Chinese (Traditional)",
}
|
||||
|
||||
|
||||
@dataclass
class LanguageDetectionResult:
    """Result of language detection, as produced by detect_language()."""

    # Code of the most likely language; "unknown" when detection was skipped
    # because the text was empty or too short.
    language_code: str
    # Display name for language_code ("Unknown" when detection was skipped).
    language_name: str
    # Probability reported for the top detection; 0.0 when detection was skipped.
    confidence: float
    # True only when the detected language differs from the target language and
    # the confidence reaches the threshold.
    requires_translation: bool
    # Length of the input text (0 for empty/None input).
    character_count: int
|
||||
|
||||
|
||||
def get_language_name(language_code: str) -> str:
    """
    Look up the display name for a language code.

    The lookup ignores case. A code that has no entry in LANGUAGE_NAMES is
    returned exactly as it was passed in.
    """
    normalized_code = language_code.lower()
    if normalized_code in LANGUAGE_NAMES:
        return LANGUAGE_NAMES[normalized_code]
    return language_code
|
||||
|
||||
|
||||
def detect_language(
    text: str,
    target_language: str = "en",
    confidence_threshold: Optional[float] = None,
) -> LanguageDetectionResult:
    """
    Detect the language of the given text.

    Detection is done with the langdetect library, which is imported lazily
    inside this function.

    Args:
        text: The text to analyze
        target_language: The target language for translation comparison
        confidence_threshold: Minimum confidence to consider detection reliable.
            When None, the configured ``confidence_threshold`` is used; an
            explicit 0.0 is honoured.

    Returns:
        LanguageDetectionResult with language info and translation requirement.
        For empty or too-short text (and ``skip_detection_for_short_text``
        enabled) the result has language_code "unknown", confidence 0.0 and
        requires_translation False.

    Raises:
        LanguageDetectionError: If the text is too short and short texts are not
            skipped, if langdetect is not installed, or if detection fails.
    """
    config = get_translation_config()
    # Compare against None: `or` would silently replace an explicit 0.0
    # threshold with the configured default.
    threshold = (
        confidence_threshold if confidence_threshold is not None else config.confidence_threshold
    )
    # text may be None or empty here; never call len() on None.
    character_count = len(text) if text else 0

    # Handle empty or very short text
    if not text or len(text.strip()) < config.min_text_length_for_detection:
        if config.skip_detection_for_short_text:
            return LanguageDetectionResult(
                language_code="unknown",
                language_name="Unknown",
                confidence=0.0,
                requires_translation=False,
                character_count=character_count,
            )
        raise LanguageDetectionError(
            f"Text too short for reliable language detection: {character_count} characters"
        )

    try:
        from langdetect import detect_langs, LangDetectException
    except ImportError as e:
        raise LanguageDetectionError(
            "langdetect is required for language detection. Install it with: pip install langdetect",
            original_error=e,
        )

    try:
        # Get detection results with probabilities
        detections = detect_langs(text)

        if not detections:
            raise LanguageDetectionError("No language detected")

        # Get the most likely language
        best_detection = detections[0]
        language_code = best_detection.lang
        confidence = best_detection.prob

        # Translation is only requested when the detection is both different
        # from the target and confident enough.
        requires_translation = (
            language_code.lower() != target_language.lower() and confidence >= threshold
        )

        return LanguageDetectionResult(
            language_code=language_code,
            language_name=get_language_name(language_code),
            confidence=confidence,
            requires_translation=requires_translation,
            character_count=len(text),
        )

    except LanguageDetectionError:
        # Already the documented error type; let it through instead of having the
        # generic handler below re-wrap it as an "unexpected" error.
        raise
    except LangDetectException as e:
        logger.warning(f"Language detection failed: {e}")
        raise LanguageDetectionError(f"Language detection failed: {e}", original_error=e)
    except Exception as e:
        logger.error(f"Unexpected error during language detection: {e}")
        raise LanguageDetectionError(
            f"Unexpected error during language detection: {e}", original_error=e
        )
|
||||
|
||||
|
||||
async def detect_language_async(
    text: str,
    target_language: str = "en",
    confidence_threshold: Optional[float] = None,
) -> LanguageDetectionResult:
    """
    Run detect_language without blocking the running event loop.

    The synchronous detect_language call is handed to the loop's default
    executor and awaited.

    Args:
        text: The text to analyze
        target_language: The target language for translation comparison
        confidence_threshold: Minimum confidence to consider detection reliable

    Returns:
        LanguageDetectionResult with language info and translation requirement
    """
    import asyncio

    detection_args = (text, target_language, confidence_threshold)
    event_loop = asyncio.get_running_loop()
    pending_detection = event_loop.run_in_executor(None, detect_language, *detection_args)
    return await pending_detection
|
||||
62
cognee/tasks/translation/exceptions.py
Normal file
62
cognee/tasks/translation/exceptions.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
class TranslationError(Exception):
    """
    Root of the translation exception hierarchy.

    Keeps the human-readable ``message`` and, optionally, the lower-level
    ``original_error`` that triggered it. When an original error is given it is
    also installed as ``__cause__``.
    """

    def __init__(self, message: str, original_error: Exception = None):
        super().__init__(message)
        self.message = message
        self.original_error = original_error
        if not original_error:
            return
        self.__cause__ = original_error
|
||||
|
||||
|
||||
class LanguageDetectionError(TranslationError):
    """Signals that the language of a text could not be determined."""

    def __init__(
        self,
        message: str = "Failed to detect language",
        original_error: Exception = None,
    ):
        super().__init__(message=message, original_error=original_error)
|
||||
|
||||
|
||||
class TranslationProviderError(TranslationError):
    """
    Signals a failure inside a translation provider.

    The provider name is kept on ``provider`` and prefixed, in square brackets,
    to the error message.
    """

    def __init__(
        self,
        provider: str,
        message: str = "Translation provider error",
        original_error: Exception = None,
    ):
        self.provider = provider
        super().__init__(f"[{provider}] {message}", original_error)
|
||||
|
||||
|
||||
class UnsupportedLanguageError(TranslationError):
    """
    Signals that a language cannot be handled.

    ``language`` and ``provider`` are stored on the instance. When no message is
    supplied, one is built from the language and, if given, the provider.
    """

    def __init__(
        self,
        language: str,
        provider: str = None,
        message: str = None,
        original_error: Exception = None,
    ):
        self.language = language
        self.provider = provider
        if message is None:
            # Mention the provider only when one was named.
            provider_suffix = f" by {provider}" if provider else ""
            message = f"Language '{language}' is not supported" + provider_suffix
        super().__init__(message, original_error)
|
||||
|
||||
|
||||
class TranslationConfigError(TranslationError):
    """Signals that the translation configuration is not valid."""

    def __init__(
        self,
        message: str = "Invalid translation configuration",
        original_error: Exception = None,
    ):
        super().__init__(message=message, original_error=original_error)
|
||||
72
cognee/tasks/translation/models.py
Normal file
72
cognee/tasks/translation/models.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
|
||||
|
||||
class TranslatedContent(DataPoint):
    """
    Represents translated content with quality metrics.

    This class stores both the original and translated versions of content,
    along with metadata about the translation process including source and
    target languages, translation provider used, and confidence scores.

    Instance variables include:

    - original_chunk_id: UUID of the original document chunk
    - original_text: The original text before translation
    - translated_text: The translated text content
    - source_language: Detected or specified source language code (e.g., "es", "fr", "de")
    - target_language: Target language code for translation (default: "en")
    - translation_provider: Name of the translation service used
    - confidence_score: Translation quality/confidence score (0.0 to 1.0)
    - translation_timestamp: When the translation was performed; set to the
      current UTC time at construction when not supplied
    - translated_from: Reference to the original DocumentChunk
    """

    original_chunk_id: UUID
    original_text: str
    translated_text: str
    source_language: str
    target_language: str = "en"
    translation_provider: str
    confidence_score: float
    # NOTE(review): annotated as datetime but defaults to None; __init__ below
    # replaces a missing or None value with the current UTC time.
    translation_timestamp: datetime = None
    translated_from: Optional[DocumentChunk] = None

    # presumably read by the DataPoint machinery to choose which fields get
    # indexed — TODO confirm against DataPoint
    metadata: dict = {"index_fields": ["source_language", "translated_text"]}

    def __init__(self, **data):
        # Stamp the translation time here when the caller did not provide one
        # (key absent or explicitly None).
        if data.get("translation_timestamp") is None:
            data["translation_timestamp"] = datetime.now(timezone.utc)
        super().__init__(**data)
|
||||
|
||||
|
||||
class LanguageMetadata(DataPoint):
    """
    Language information for content.

    This class stores metadata about the detected language of content,
    including confidence scores and whether translation is required.

    Instance variables include:

    - content_id: UUID of the associated content
    - detected_language: ISO 639-1 language code (e.g., "en", "es", "fr")
    - language_confidence: Confidence score for language detection (0.0 to 1.0)
    - requires_translation: Whether the content needs translation
    - character_count: Number of characters in the content
    - language_name: Human-readable language name (e.g., "English", "Spanish")
    """

    content_id: UUID
    detected_language: str
    language_confidence: float
    requires_translation: bool
    character_count: int
    # The only optional field; defaults to None when no display name is given.
    language_name: Optional[str] = None

    # presumably read by the DataPoint machinery to choose which fields get
    # indexed — TODO confirm against DataPoint
    metadata: dict = {"index_fields": ["detected_language"]}
|
||||
44
cognee/tasks/translation/providers/__init__.py
Normal file
44
cognee/tasks/translation/providers/__init__.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from .base import TranslationProvider, TranslationResult
|
||||
from .llm_provider import LLMTranslationProvider
|
||||
from .google_provider import GoogleTranslationProvider
|
||||
from .azure_provider import AzureTranslationProvider
|
||||
|
||||
# Public names of the providers package: the base types, the three concrete
# providers, and the factory function defined below.
__all__ = [
    "TranslationProvider",
    "TranslationResult",
    "LLMTranslationProvider",
    "GoogleTranslationProvider",
    "AzureTranslationProvider",
    "get_translation_provider",
]
|
||||
|
||||
|
||||
def get_translation_provider(provider_name: str) -> TranslationProvider:
    """
    Build the translation provider registered under ``provider_name``.

    The name is matched case-insensitively and a fresh provider instance is
    created on every call.

    Args:
        provider_name: Name of the provider:
            - "llm": Uses the configured LLM (OpenAI, Azure, Ollama, Anthropic, etc.)
            - "google": Uses Google Cloud Translation API
            - "azure": Uses Azure Translator API

    Returns:
        TranslationProvider instance

    Raises:
        ValueError: If the provider name is not recognized
    """
    providers = {
        "llm": LLMTranslationProvider,
        "google": GoogleTranslationProvider,
        "azure": AzureTranslationProvider,
    }

    lookup_key = provider_name.lower()
    provider_class = providers.get(lookup_key)
    if provider_class is None:
        raise ValueError(
            f"Unknown translation provider: {provider_name}. "
            f"Available providers: {list(providers.keys())}"
        )

    return provider_class()
|
||||
192
cognee/tasks/translation/providers/azure_provider.py
Normal file
192
cognee/tasks/translation/providers/azure_provider.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .base import TranslationProvider, TranslationResult
|
||||
from ..config import get_translation_config
|
||||
from ..exceptions import TranslationProviderError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AzureTranslationProvider(TranslationProvider):
    """
    Translation provider using Azure Translator API (api-version 3.0).

    Requires:
    - AZURE_TRANSLATOR_KEY environment variable
    - AZURE_TRANSLATOR_REGION environment variable (optional)

    The endpoint, timeout and batch size are read from the translation config.
    """

    # Azure supports up to 100 texts per request.
    # NOTE(review): the service also limits the total characters per request;
    # that limit is not enforced here — confirm against the API reference.
    _MAX_TEXTS_PER_REQUEST = 100

    def __init__(self):
        self._config = get_translation_config()

    @property
    def provider_name(self) -> str:
        return "azure"

    def is_available(self) -> bool:
        """Check if Azure Translator is available (a non-empty API key is configured)."""
        # An empty string (e.g. AZURE_TRANSLATOR_KEY="") is as unusable as a missing key.
        return bool(self._config.azure_translator_key)

    def _ensure_available(self) -> None:
        """Raise TranslationProviderError when no API key is configured."""
        if not self.is_available():
            raise TranslationProviderError(
                provider=self.provider_name,
                message="Azure Translator API key not configured. Set AZURE_TRANSLATOR_KEY environment variable.",
            )

    def _translate_url(self) -> str:
        """Return the URL of the translate operation."""
        # Strip a trailing slash so a configured ".../" does not produce "//translate".
        base_url = self._config.azure_translator_endpoint.rstrip("/")
        return f"{base_url}/translate"

    def _request_params(self, target_language: str, source_language: Optional[str]) -> dict:
        """Build the query parameters; "from" is only sent when a source language is given."""
        params = {
            "api-version": "3.0",
            "to": target_language,
        }
        if source_language:
            params["from"] = source_language
        return params

    def _request_headers(self) -> dict:
        """Build the request headers; the region header is only sent when a region is configured."""
        headers = {
            "Ocp-Apim-Subscription-Key": self._config.azure_translator_key,
            "Content-Type": "application/json",
        }
        if self._config.azure_translator_region:
            headers["Ocp-Apim-Subscription-Region"] = self._config.azure_translator_region
        return headers

    def _to_result(
        self,
        item: dict,
        target_language: str,
        source_language: Optional[str],
    ) -> TranslationResult:
        """Convert one element of the response array into a TranslationResult."""
        translation = item["translations"][0]
        detected_language = item.get("detectedLanguage", {})

        return TranslationResult(
            translated_text=translation["text"],
            # A caller-supplied source language wins over the detected one.
            source_language=source_language or detected_language.get("language", "unknown"),
            target_language=target_language,
            # 0.9 is the fallback used when the response item carries no detection
            # score (presumably when the source language was supplied — confirm).
            confidence_score=detected_language.get("score", 0.9),
            provider=self.provider_name,
            raw_response=item,
        )

    async def translate(
        self,
        text: str,
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> TranslationResult:
        """
        Translate text using Azure Translator API.

        Args:
            text: The text to translate
            target_language: Target language code (default: "en")
            source_language: Source language code (optional)

        Returns:
            TranslationResult with translated text and metadata

        Raises:
            TranslationProviderError: If no API key is configured, or if the
                request or the parsing of its response fails.
        """
        self._ensure_available()

        endpoint = self._translate_url()
        params = self._request_params(target_language, source_language)
        headers = self._request_headers()
        body = [{"text": text}]

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    endpoint,
                    params=params,
                    headers=headers,
                    json=body,
                    timeout=aiohttp.ClientTimeout(total=self._config.timeout_seconds),
                ) as response:
                    response.raise_for_status()
                    result = await response.json()

                    return self._to_result(result[0], target_language, source_language)

        except Exception as e:
            logger.error(f"Azure translation failed: {e}")
            raise TranslationProviderError(
                provider=self.provider_name,
                message=f"Translation failed: {e}",
                original_error=e,
            )

    async def translate_batch(
        self,
        texts: list[str],
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> list[TranslationResult]:
        """
        Translate multiple texts using Azure Translator API.

        Texts are sent in consecutive slices of at most
        ``min(100, configured batch_size)`` items (never fewer than one per
        request), all over one HTTP session. Results are returned in request order.

        Args:
            texts: List of texts to translate
            target_language: Target language code
            source_language: Source language code (optional)

        Returns:
            List of TranslationResult objects

        Raises:
            TranslationProviderError: If no API key is configured, or if any
                request or the parsing of its response fails.
        """
        self._ensure_available()

        if not texts:
            # Nothing to send; avoid opening an HTTP session for no requests.
            return []

        endpoint = self._translate_url()
        params = self._request_params(target_language, source_language)
        headers = self._request_headers()

        # Clamp to at least 1: a zero step makes range() raise and a negative
        # step would silently translate nothing.
        batch_size = max(1, min(self._MAX_TEXTS_PER_REQUEST, self._config.batch_size))
        all_results = []

        try:
            async with aiohttp.ClientSession() as session:
                for start in range(0, len(texts), batch_size):
                    batch = texts[start : start + batch_size]
                    body = [{"text": text} for text in batch]

                    async with session.post(
                        endpoint,
                        params=params,
                        headers=headers,
                        json=body,
                        timeout=aiohttp.ClientTimeout(total=self._config.timeout_seconds),
                    ) as response:
                        response.raise_for_status()
                        results = await response.json()

                        all_results.extend(
                            self._to_result(result, target_language, source_language)
                            for result in results
                        )

        except Exception as e:
            logger.error(f"Azure batch translation failed: {e}")
            raise TranslationProviderError(
                provider=self.provider_name,
                message=f"Batch translation failed: {e}",
                original_error=e,
            )

        return all_results
|
||||
85
cognee/tasks/translation/providers/base.py
Normal file
85
cognee/tasks/translation/providers/base.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""
|
||||
Base classes for translation providers.
|
||||
|
||||
This module defines the abstract interface that all translation providers must implement.
|
||||
Providers handle the actual translation of text using external services like OpenAI,
|
||||
Google Translate, or Azure Translator.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class TranslationResult:
    """Outcome of a single translation request.

    Bundles the translated text with the language pair, the name of the
    provider that produced it and, optionally, the provider's raw payload.
    """

    # Text rendered in the target language.
    translated_text: str
    # Language code of the input text (caller-supplied or detected).
    source_language: str
    # Language code the text was translated into.
    target_language: str
    # Provider-reported confidence, or None when the provider exposes no
    # score (e.g., Google Translate).
    confidence_score: Optional[float]
    # Name of the provider that performed the translation.
    provider: str
    # Unprocessed provider response, when one is kept.
    raw_response: Optional[dict] = None
|
||||
|
||||
|
||||
class TranslationProvider(ABC):
    """Interface that every translation backend has to implement."""

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Name identifying this translation provider."""
        ...

    @abstractmethod
    async def translate(
        self,
        text: str,
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> TranslationResult:
        """
        Translate a single text into the target language.

        Args:
            text: The text to translate
            target_language: Code of the language to translate into (default: "en")
            source_language: Code of the input language; when omitted the
                provider is expected to detect it

        Returns:
            TranslationResult holding the translated text and its metadata
        """
        ...

    @abstractmethod
    async def translate_batch(
        self,
        texts: list[str],
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> list[TranslationResult]:
        """
        Translate several texts into the target language.

        Args:
            texts: The texts to translate
            target_language: Code of the language to translate into (default: "en")
            source_language: Code of the input language (optional)

        Returns:
            List of TranslationResult objects
        """
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether this provider can be used (has required credentials).

        Every concrete provider must implement this check.

        Returns:
            True if the provider has valid credentials and is ready to use.
        """
        ...
|
||||
158
cognee/tasks/translation/providers/google_provider.py
Normal file
158
cognee/tasks/translation/providers/google_provider.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .base import TranslationProvider, TranslationResult
|
||||
from ..config import get_translation_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleTranslationProvider(TranslationProvider):
    """
    Translation provider backed by the Google Cloud Translation (v2) client.

    The client is created lazily on first use, and the synchronous library
    calls are run in the event loop's default executor.

    Requires:
        - google-cloud-translate package

    NOTE(review): the client is built with ``translate.Client()`` and no
    explicit credentials, so authentication depends on whatever the library
    resolves on its own. Confirm that the GOOGLE_TRANSLATE_API_KEY /
    GOOGLE_PROJECT_ID settings are actually honoured.
    """

    def __init__(self):
        self._client = None
        self._config = get_translation_config()

    @property
    def provider_name(self) -> str:
        """Return 'google' as the provider name."""
        return "google"

    def _get_client(self):
        """Return the Google Translate client, creating it on first use.

        Raises:
            ImportError: If google-cloud-translate is not installed.
            Exception: Any other error raised while constructing the client
                is logged and re-raised unchanged.
        """
        if self._client is None:
            try:
                from google.cloud import translate_v2 as translate

                self._client = translate.Client()
            except ImportError as import_error:
                # Chain the original error so the real import failure stays visible.
                raise ImportError(
                    "google-cloud-translate is required for Google translation. "
                    "Install it with: pip install google-cloud-translate"
                ) from import_error
            except Exception as e:
                logger.error(f"Failed to initialize Google Translate client: {e}")
                raise
        return self._client

    def is_available(self) -> bool:
        """Check if Google Translate is available.

        Returns:
            True if the client could be created, False otherwise. Failures
            are logged at debug level instead of being raised.
        """
        try:
            self._get_client()
            return True
        except Exception as e:
            logger.debug(f"Google Translate not available: {e}")
            return False

    @staticmethod
    def _build_translate_kwargs(target_language: str, source_language: Optional[str]) -> dict:
        """Build the keyword arguments shared by every ``client.translate`` call."""
        # The inputs are plain text. Without an explicit format the API treats
        # them as HTML and returns HTML-escaped output (e.g. "&#39;" for "'").
        kwargs = {"target_language": target_language, "format_": "text"}
        if source_language:
            kwargs["source_language"] = source_language
        return kwargs

    def _to_result(
        self,
        result: dict,
        target_language: str,
        source_language: Optional[str],
    ) -> TranslationResult:
        """Convert one Google response entry into a TranslationResult."""
        detected_language = result.get("detectedSourceLanguage", source_language or "unknown")
        return TranslationResult(
            translated_text=result["translatedText"],
            source_language=detected_language,
            target_language=target_language,
            # Google Translate API does not provide confidence scores
            confidence_score=None,
            provider=self.provider_name,
            raw_response=result,
        )

    async def translate(
        self,
        text: str,
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> TranslationResult:
        """
        Translate text using Google Translate API.

        Args:
            text: The text to translate
            target_language: Target language code (default: "en")
            source_language: Source language code (optional)

        Returns:
            TranslationResult with translated text and metadata

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            client = self._get_client()
            translate_kwargs = self._build_translate_kwargs(target_language, source_language)

            # Run in thread pool since google-cloud-translate is synchronous
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: client.translate(text, **translate_kwargs),
            )

            return self._to_result(result, target_language, source_language)

        except Exception as e:
            logger.error(f"Google translation failed: {e}")
            raise

    async def translate_batch(
        self,
        texts: list[str],
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> list[TranslationResult]:
        """
        Translate multiple texts using Google Translate API.

        All texts are sent to the client in a single ``translate`` call.

        Args:
            texts: List of texts to translate
            target_language: Target language code
            source_language: Source language code (optional)

        Returns:
            List of TranslationResult objects (empty when ``texts`` is empty)

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        # Nothing to translate: skip client creation and the remote call.
        if not texts:
            return []

        try:
            client = self._get_client()
            translate_kwargs = self._build_translate_kwargs(target_language, source_language)

            # Run in thread pool since google-cloud-translate is synchronous
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(
                None,
                lambda: client.translate(texts, **translate_kwargs),
            )

            return [
                self._to_result(result, target_language, source_language) for result in results
            ]

        except Exception as e:
            logger.error(f"Google batch translation failed: {e}")
            raise
|
||||
143
cognee/tasks/translation/providers/llm_provider.py
Normal file
143
cognee/tasks/translation/providers/llm_provider.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .base import TranslationProvider, TranslationResult
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TranslationOutput(BaseModel):
    """Schema the LLM must fill in when answering a translation request."""

    # The input text rendered in the target language.
    translated_text: str
    # Language the LLM identified for the input text.
    detected_source_language: str
    # Free-form remarks from the LLM about the translation, if any.
    translation_notes: Optional[str] = None
|
||||
|
||||
|
||||
class LLMTranslationProvider(TranslationProvider):
    """
    Translation provider using the configured LLM for translation.

    This provider leverages the existing LLM infrastructure in Cognee
    to perform translations using any LLM configured via LLM_PROVIDER
    (OpenAI, Azure, Ollama, Anthropic, etc.).

    The LLM used is determined by the cognee LLM configuration settings:
    - LLM_PROVIDER: The LLM provider (openai, azure, ollama, etc.)
    - LLM_MODEL: The model to use
    - LLM_API_KEY: API key for the provider
    """

    @property
    def provider_name(self) -> str:
        """Return 'llm' as the provider name."""
        return "llm"

    async def translate(
        self,
        text: str,
        target_language: str = "en",
        source_language: Optional[str] = None,
    ) -> TranslationResult:
        """
        Translate text using the configured LLM.

        Args:
            text: The text to translate
            target_language: Target language code (default: "en")
            source_language: Source language code (optional); when omitted
                the LLM is asked to detect it

        Returns:
            TranslationResult with translated text and metadata

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            system_prompt = read_query_prompt("translate_content.txt")

            # Fall back to a built-in prompt when the prompt file is missing.
            if system_prompt is None:
                logger.warning("translate_content.txt prompt file not found, using default prompt")
                system_prompt = (
                    "You are a professional translator. Translate the given text accurately "
                    "while preserving the original meaning, tone, and style. "
                    "Detect the source language if not provided."
                )

            # Build the input with context
            if source_language:
                input_text = (
                    f"Translate the following text from {source_language} to {target_language}.\n\n"
                    f"Text to translate:\n{text}"
                )
            else:
                input_text = (
                    f"Translate the following text to {target_language}. "
                    f"First detect the source language.\n\n"
                    f"Text to translate:\n{text}"
                )

            result = await LLMGateway.acreate_structured_output(
                text_input=input_text,
                system_prompt=system_prompt,
                response_model=TranslationOutput,
            )

            return TranslationResult(
                translated_text=result.translated_text,
                # A caller-supplied source language wins over the LLM's detection.
                source_language=source_language or result.detected_source_language,
                target_language=target_language,
                # TODO: Consider deriving confidence from LLM response metadata
                # or making configurable via TranslationConfig
                confidence_score=0.95,  # fixed placeholder, not reported by the LLM
                provider=self.provider_name,
                raw_response={"notes": result.translation_notes},
            )

        except Exception as e:
            logger.error(f"LLM translation failed: {e}")
            raise

    async def translate_batch(
        self,
        texts: list[str],
        target_language: str = "en",
        source_language: Optional[str] = None,
        max_concurrent: int = 5,
    ) -> list[TranslationResult]:
        """
        Translate multiple texts using the configured LLM.

        Each text is translated with its own ``translate`` call; a semaphore
        limits how many of those calls run at the same time.

        Args:
            texts: List of texts to translate
            target_language: Target language code
            source_language: Source language code (optional)
            max_concurrent: Maximum concurrent translation requests (default: 5).
                Values below 1 are treated as 1.

        Returns:
            List of TranslationResult objects, in the same order as ``texts``

        Raises:
            Exception: The first failure raised by any ``translate`` call.
        """
        # A semaphore created with 0 permits would block every task forever
        # (and a negative value raises), so always allow at least one request.
        semaphore = asyncio.Semaphore(max(1, max_concurrent))

        async def limited_translate(text: str) -> TranslationResult:
            async with semaphore:
                return await self.translate(text, target_language, source_language)

        tasks = [limited_translate(text) for text in texts]
        return await asyncio.gather(*tasks)

    def is_available(self) -> bool:
        """Check if LLM provider is available (has required credentials).

        Returns:
            True when the LLM configuration holds a non-empty API key; False
            when it does not or when the configuration cannot be loaded.
        """
        try:
            llm_config = get_llm_config()
            # Check if API key is configured (required for most providers)
            return bool(llm_config.llm_api_key)
        except Exception:
            return False
|
||||
282
cognee/tasks/translation/translate_content.py
Normal file
282
cognee/tasks/translation/translate_content.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
import asyncio
|
||||
from typing import List, Optional
|
||||
from uuid import uuid5
|
||||
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .config import get_translation_config, TranslationProviderType
|
||||
from .detect_language import detect_language_async, LanguageDetectionResult
|
||||
from .exceptions import TranslationError, LanguageDetectionError
|
||||
from .models import TranslatedContent, LanguageMetadata
|
||||
from .providers import get_translation_provider, TranslationResult
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def translate_content(
    data_chunks: List[DocumentChunk],
    target_language: Optional[str] = None,
    translation_provider: Optional[TranslationProviderType] = None,
    confidence_threshold: Optional[float] = None,
    skip_if_target_language: bool = True,
    preserve_original: bool = True,
) -> List[DocumentChunk]:
    """
    Translate chunks that are not already in the target language.

    This task detects the language of each document chunk and translates
    non-target-language content using the specified translation provider.
    Original text is preserved alongside translated versions.

    Args:
        data_chunks: List of DocumentChunk objects to process
        target_language: Target language code.
                        If not provided, uses config default
        translation_provider: Translation service to use ("llm", "google", "azure")
                             If not provided, uses config default
        confidence_threshold: Minimum confidence for language detection (0.0 to 1.0)
                             If not provided (None), uses config default;
                             an explicit 0.0 is honoured
        skip_if_target_language: If True, skip chunks already in target language.
                                If False, such chunks are still sent to the provider
        preserve_original: If True, store original text in TranslatedContent

    Returns:
        List of DocumentChunk objects with translated content.
        Chunks that required translation will have TranslatedContent
        objects in their 'contains' list. Chunks without text, and chunks
        whose detection or translation raised, are passed through after
        the failure is logged.

    Raises:
        TranslationError: If data_chunks is not a list.

    Note:
        This function mutates the input chunks in-place. Specifically:
        - chunk.text is replaced with the translated text
        - chunk.contains is updated with LanguageMetadata and TranslatedContent
        The original text is preserved in TranslatedContent.original_text
        if preserve_original=True.

    Example:
        ```python
        from cognee.tasks.translation import translate_content

        # Translate chunks using default settings
        translated_chunks = await translate_content(chunks)

        # Translate with specific provider
        translated_chunks = await translate_content(
            chunks,
            translation_provider="llm",
            confidence_threshold=0.9
        )
        ```
    """
    if not isinstance(data_chunks, list):
        raise TranslationError("data_chunks must be a list")

    if len(data_chunks) == 0:
        return data_chunks

    # Get configuration
    config = get_translation_config()
    provider_name = translation_provider or config.translation_provider
    target_lang = target_language or config.target_language
    # Compare against None explicitly: 0.0 is a valid threshold and must not
    # silently fall back to the configured default.
    threshold = (
        config.confidence_threshold if confidence_threshold is None else confidence_threshold
    )

    logger.info(
        f"Starting translation task for {len(data_chunks)} chunks "
        f"using {provider_name} provider, target language: {target_lang}"
    )

    # Get the translation provider
    provider = get_translation_provider(provider_name)

    # Process chunks
    processed_chunks = []
    total_chunks = len(data_chunks)

    for chunk_index, chunk in enumerate(data_chunks):
        # Log progress for large batches
        if chunk_index > 0 and chunk_index % 100 == 0:
            logger.info(f"Translation progress: {chunk_index}/{total_chunks} chunks processed")

        # Chunks without text have nothing to detect or translate.
        if not hasattr(chunk, "text") or not chunk.text:
            processed_chunks.append(chunk)
            continue

        try:
            # Detect language
            detection = await detect_language_async(chunk.text, target_lang, threshold)

            # Create language metadata
            language_metadata = LanguageMetadata(
                id=uuid5(chunk.id, "LanguageMetadata"),
                content_id=chunk.id,
                detected_language=detection.language_code,
                language_confidence=detection.confidence,
                requires_translation=detection.requires_translation,
                character_count=detection.character_count,
                language_name=detection.language_name,
            )

            # Skip if already in target language
            if not detection.requires_translation:
                if skip_if_target_language:
                    logger.debug(
                        f"Skipping chunk {chunk.id}: already in target language "
                        f"({detection.language_code})"
                    )
                    # Add language metadata to chunk
                    _add_to_chunk_contains(chunk, language_metadata)
                    processed_chunks.append(chunk)
                    continue

            # Translate the content
            logger.debug(
                f"Translating chunk {chunk.id} from {detection.language_code} to {target_lang}"
            )

            translation_result = await provider.translate(
                text=chunk.text,
                target_language=target_lang,
                source_language=detection.language_code,
            )

            # Create TranslatedContent data point (built before chunk.text is
            # overwritten so that original_text still holds the source text)
            translated_content = TranslatedContent(
                id=uuid5(chunk.id, "TranslatedContent"),
                original_chunk_id=chunk.id,
                original_text=chunk.text if preserve_original else "",
                translated_text=translation_result.translated_text,
                source_language=translation_result.source_language,
                target_language=translation_result.target_language,
                translation_provider=translation_result.provider,
                confidence_score=translation_result.confidence_score,
                translated_from=chunk,
            )

            # Update chunk text with translated content
            chunk.text = translation_result.translated_text

            # Add metadata to chunk's contains list
            _add_to_chunk_contains(chunk, language_metadata)
            _add_to_chunk_contains(chunk, translated_content)

            processed_chunks.append(chunk)

            logger.debug(
                f"Successfully translated chunk {chunk.id}: "
                f"{detection.language_code} -> {target_lang}"
            )

        # Best-effort task: a failing chunk is logged and passed through so
        # the remaining chunks are still processed.
        except LanguageDetectionError as e:
            logger.warning(f"Language detection failed for chunk {chunk.id}: {e}")
            processed_chunks.append(chunk)
        except TranslationError as e:
            logger.error(f"Translation failed for chunk {chunk.id}: {e}")
            processed_chunks.append(chunk)
        except Exception as e:
            logger.error(f"Unexpected error processing chunk {chunk.id}: {e}")
            processed_chunks.append(chunk)

    logger.info(f"Translation task completed for {len(processed_chunks)} chunks")
    return processed_chunks
|
||||
|
||||
|
||||
def _add_to_chunk_contains(chunk: DocumentChunk, item) -> None:
|
||||
"""Helper to add an item to a chunk's contains list."""
|
||||
if chunk.contains is None:
|
||||
chunk.contains = []
|
||||
chunk.contains.append(item)
|
||||
|
||||
|
||||
async def translate_text(
    text: str,
    target_language: str = None,
    translation_provider: TranslationProviderType = None,
    source_language: Optional[str] = None,
) -> TranslationResult:
    """
    Translate one text string without wrapping it in a DocumentChunk.

    Args:
        text: The text to translate
        target_language: Target language code.
                        Falls back to the configured default when not given
        translation_provider: Translation service to use.
                             Falls back to the configured default when not given
        source_language: Source language code; passed to the provider as-is,
                        which is expected to auto-detect it when omitted

    Returns:
        TranslationResult with translated text and metadata

    Example:
        ```python
        from cognee.tasks.translation import translate_text

        result = await translate_text(
            "Bonjour le monde!",
            target_language="en"
        )
        print(result.translated_text)  # "Hello world!"
        print(result.source_language)  # "fr"
        ```
    """
    settings = get_translation_config()

    # Explicit arguments take precedence over the configured defaults.
    chosen_provider = (
        translation_provider if translation_provider else settings.translation_provider
    )
    chosen_language = target_language if target_language else settings.target_language

    translator = get_translation_provider(chosen_provider)
    outcome = await translator.translate(
        text=text,
        target_language=chosen_language,
        source_language=source_language,
    )
    return outcome
|
||||
|
||||
|
||||
async def batch_translate_texts(
    texts: List[str],
    target_language: str = None,
    translation_provider: TranslationProviderType = None,
    source_language: Optional[str] = None,
) -> List[TranslationResult]:
    """
    Translate several text strings with one provider batch call.

    The whole list is handed to the provider's ``translate_batch`` method,
    so providers with native batch support can handle it in fewer requests.

    Args:
        texts: List of texts to translate
        target_language: Target language code.
                        Falls back to the configured default when not given
        translation_provider: Translation service to use.
                             Falls back to the configured default when not given
        source_language: Source language code (optional)

    Returns:
        List of TranslationResult objects

    Example:
        ```python
        from cognee.tasks.translation import batch_translate_texts

        results = await batch_translate_texts(
            ["Hola", "¿Cómo estás?", "Adiós"],
            target_language="en"
        )
        for result in results:
            print(f"{result.source_language}: {result.translated_text}")
        ```
    """
    settings = get_translation_config()

    # Explicit arguments take precedence over the configured defaults.
    chosen_provider = (
        translation_provider if translation_provider else settings.translation_provider
    )
    chosen_language = target_language if target_language else settings.target_language

    translator = get_translation_provider(chosen_provider)
    outcomes = await translator.translate_batch(
        texts=texts,
        target_language=chosen_language,
        source_language=source_language,
    )
    return outcomes
|
||||
147
cognee/tests/tasks/translation/README.md
Normal file
147
cognee/tests/tasks/translation/README.md
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
# Translation Task Tests
|
||||
|
||||
Unit and integration tests for the multilingual content translation feature.
|
||||
|
||||
## Test Files
|
||||
|
||||
- **config_test.py** - Tests for translation configuration
|
||||
- Default configuration
|
||||
- Provider type validation
|
||||
- Confidence threshold bounds
|
||||
- Multiple provider API keys
|
||||
|
||||
- **detect_language_test.py** - Tests for language detection functionality
|
||||
- English, Spanish, French, German, Chinese detection
|
||||
- Confidence thresholds
|
||||
- Edge cases (empty text, short text, mixed languages)
|
||||
|
||||
- **providers_test.py** - Tests for translation provider implementations
|
||||
- LLM provider basic translation
|
||||
- Auto-detection of source language
|
||||
- Batch translation
|
||||
- Special characters and formatting preservation
|
||||
- Error handling
|
||||
|
||||
- **translate_content_test.py** - Tests for the main translate_content task
|
||||
- Basic translation workflow
|
||||
- Original text preservation
|
||||
- Multiple chunks processing
|
||||
- Language metadata creation
|
||||
- Skip translation for target language
|
||||
- Confidence threshold customization
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Run all translation tests
|
||||
```bash
|
||||
uv run pytest cognee/tests/tasks/translation/ -v
|
||||
```
|
||||
|
||||
### Run specific test file
|
||||
```bash
|
||||
uv run pytest cognee/tests/tasks/translation/detect_language_test.py -v
|
||||
```
|
||||
|
||||
### Run tests directly (without pytest)
|
||||
```bash
|
||||
uv run python cognee/tests/tasks/translation/config_test.py
|
||||
uv run python cognee/tests/tasks/translation/detect_language_test.py
|
||||
uv run python cognee/tests/tasks/translation/providers_test.py
|
||||
uv run python cognee/tests/tasks/translation/translate_content_test.py
|
||||
uv run python cognee/tests/tasks/translation/integration_test.py
|
||||
```
|
||||
|
||||
### Run all tests at once
|
||||
```bash
|
||||
for f in cognee/tests/tasks/translation/*_test.py; do uv run python "$f"; done
|
||||
```
|
||||
|
||||
### Run with coverage
|
||||
```bash
|
||||
uv run pytest cognee/tests/tasks/translation/ --cov=cognee.tasks.translation --cov-report=html
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- LLM API key set in environment: `LLM_API_KEY=your_key`
|
||||
- Tests will be skipped if no API key is available
|
||||
|
||||
**Note:** The translation feature uses the same LLM model configured for other cognee tasks (via `LLM_MODEL` and `LLM_PROVIDER` environment variables). This means any LLM provider supported by cognee (OpenAI, Azure, Anthropic, Ollama, etc.) can be used for translation.
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
import cognee
|
||||
from cognee.tasks.translation import translate_text
|
||||
|
||||
# Configure translation (optional - defaults to LLM provider)
|
||||
cognee.config.set_translation_config(
|
||||
provider="llm", # Uses configured LLM (default)
|
||||
target_language="en", # Target language code
|
||||
confidence_threshold=0.7 # Minimum confidence for language detection
|
||||
)
|
||||
|
||||
# Translate text directly
|
||||
result = await translate_text(
|
||||
text="Bonjour le monde",
|
||||
target_language="en"
|
||||
)
|
||||
print(result.translated_text) # "Hello world"
|
||||
```
|
||||
|
||||
### Alternative Translation Providers
|
||||
|
||||
```python
|
||||
# Use Google Cloud Translate (requires GOOGLE_TRANSLATE_API_KEY)
|
||||
cognee.config.set_translation_provider("google")
|
||||
|
||||
# Use Azure Translator (requires AZURE_TRANSLATOR_KEY and AZURE_TRANSLATOR_REGION)
|
||||
cognee.config.set_translation_provider("azure")
|
||||
```
|
||||
|
||||
## Test Summary
|
||||
|
||||
| Test File | Tests | Description |
|
||||
|-----------|-------|-------------|
|
||||
| config_test.py | 5 | Configuration validation |
| detect_language_test.py | 10 | Language detection |
| providers_test.py | 9 | Translation providers |
| translate_content_test.py | 9 | Content translation task |
| integration_test.py | 2 | Standalone translation tests |
| **Total** | **35** | |

## Test Categories

### Configuration (5 tests)
- ✅ Default configuration values
- ✅ Provider type literal validation
- ✅ Confidence threshold bounds
- ✅ Invalid confidence threshold handling (rejected or clamped)
- ✅ Multiple provider API keys
|
||||
|
||||
### Language Detection (10 tests)
|
||||
- ✅ Multiple language detection (EN, ES, FR, DE, ZH)
|
||||
- ✅ Confidence scoring
|
||||
- ✅ Target language matching
|
||||
- ✅ Short and empty text handling
|
||||
- ✅ Mixed language detection
|
||||
|
||||
### Translation Providers (9 tests)
|
||||
- ✅ Provider factory function
|
||||
- ✅ LLM translation
|
||||
- ✅ Batch operations
|
||||
- ✅ Auto source language detection
|
||||
- ✅ Long text handling
|
||||
- ✅ Special characters preservation
|
||||
- ✅ Error handling
|
||||
|
||||
### Content Translation (9 tests)
|
||||
- ✅ DocumentChunk processing
|
||||
- ✅ Metadata creation (LanguageMetadata, TranslatedContent)
|
||||
- ✅ Original text preservation
|
||||
- ✅ Multiple chunk handling
|
||||
- ✅ Empty text/list handling
|
||||
- ✅ Confidence threshold customization
|
||||
|
||||
### Integration (2 tests)
|
||||
- ✅ Direct translate_text function
|
||||
- ✅ Language detection functionality
|
||||
1
cognee/tests/tasks/translation/__init__.py
Normal file
1
cognee/tests/tasks/translation/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Translation task tests"""
|
||||
93
cognee/tests/tasks/translation/config_test.py
Normal file
93
cognee/tests/tasks/translation/config_test.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""
|
||||
Unit tests for translation configuration
|
||||
"""
|
||||
|
||||
from typing import get_args
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from cognee.tasks.translation.config import (
|
||||
get_translation_config,
|
||||
TranslationConfig,
|
||||
TranslationProviderType,
|
||||
)
|
||||
|
||||
|
||||
def test_default_translation_config():
    """The default config is a TranslationConfig with a known provider and a valid threshold."""
    config = get_translation_config()
    known_providers = ["llm", "google", "azure"]
    threshold_in_range = 0.0 <= config.confidence_threshold <= 1.0

    assert isinstance(config, TranslationConfig), "Config should be TranslationConfig instance"
    assert config.translation_provider in known_providers, (
        f"Invalid provider: {config.translation_provider}"
    )
    assert threshold_in_range, (
        f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]"
    )
|
||||
|
||||
|
||||
def test_translation_provider_type_literal():
    """TranslationProviderType allows exactly the llm, google and azure providers."""
    # Values permitted by the Literal type
    allowed_values = get_args(TranslationProviderType)

    for provider in ("llm", "google", "azure"):
        assert provider in allowed_values, f"{provider} should be an allowed provider"

    provider_count = len(allowed_values)
    assert provider_count == 3, f"Expected 3 providers, got {provider_count}"
|
||||
|
||||
|
||||
def test_confidence_threshold_bounds():
    """A config built with threshold 0.9 reports a threshold inside [0.0, 1.0]."""
    config = TranslationConfig(translation_provider="llm", confidence_threshold=0.9)
    threshold_in_range = 0.0 <= config.confidence_threshold <= 1.0

    assert threshold_in_range, (
        f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]"
    )
|
||||
|
||||
|
||||
def test_confidence_threshold_validation():
    """Boundary thresholds are accepted; out-of-range ones are rejected or clamped.

    0.0 and 1.0 must be stored unchanged. For -0.1 and 1.5 the test accepts
    either a ``ValidationError`` or a stored value inside [0.0, 1.0].
    """
    # Inclusive boundaries must be accepted as-is.
    lower_bound_config = TranslationConfig(translation_provider="llm", confidence_threshold=0.0)
    assert lower_bound_config.confidence_threshold == 0.0, "Minimum bound (0.0) should be valid"

    upper_bound_config = TranslationConfig(translation_provider="llm", confidence_threshold=1.0)
    assert upper_bound_config.confidence_threshold == 1.0, "Maximum bound (1.0) should be valid"

    # Below the range: rejection is fine, otherwise the value must be clamped.
    try:
        config_invalid_low = TranslationConfig(
            translation_provider="llm", confidence_threshold=-0.1
        )
    except ValidationError:
        pass  # Expected validation error
    else:
        assert 0.0 <= config_invalid_low.confidence_threshold <= 1.0, (
            f"Invalid low value should be clamped, got {config_invalid_low.confidence_threshold}"
        )

    # Above the range: same two acceptable outcomes.
    try:
        config_invalid_high = TranslationConfig(
            translation_provider="llm", confidence_threshold=1.5
        )
    except ValidationError:
        pass  # Expected validation error
    else:
        assert 0.0 <= config_invalid_high.confidence_threshold <= 1.0, (
            f"Invalid high value should be clamped, got {config_invalid_high.confidence_threshold}"
        )
|
||||
|
||||
|
||||
def test_multiple_provider_keys():
    """Google and Azure keys can both be supplied alongside the "llm" provider."""
    settings = {
        "translation_provider": "llm",
        "google_translate_api_key": "google_key",
        "azure_translator_key": "azure_key",
    }
    config = TranslationConfig(**settings)

    assert config.google_translate_api_key == "google_key", "Google API key not set correctly"
    assert config.azure_translator_key == "azure_key", "Azure API key not set correctly"
|
||||
118
cognee/tests/tasks/translation/detect_language_test.py
Normal file
118
cognee/tests/tasks/translation/detect_language_test.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
Unit tests for language detection functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from cognee.tasks.translation.detect_language import (
|
||||
detect_language_async,
|
||||
LanguageDetectionResult,
|
||||
)
|
||||
from cognee.tasks.translation.exceptions import LanguageDetectionError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detect_english():
    """English input is reported as "en" and needs no translation to English."""
    sample = "Hello world, this is a test."

    detection = await detect_language_async(sample, target_language="en")

    assert detection.language_code == "en"
    assert detection.requires_translation is False
    assert detection.confidence > 0.9
    assert detection.language_name == "English"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detect_spanish():
    """Spanish input is reported as "es" and flagged for translation to English."""
    sample = "Hola mundo, esta es una prueba."

    detection = await detect_language_async(sample, target_language="en")

    assert detection.language_code == "es"
    assert detection.requires_translation is True
    assert detection.confidence > 0.9
    assert detection.language_name == "Spanish"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detect_french():
    """French input is reported as "fr" and flagged for translation to English."""
    sample = "Bonjour le monde, ceci est un test."

    detection = await detect_language_async(sample, target_language="en")

    assert detection.language_code == "fr"
    assert detection.requires_translation is True
    assert detection.confidence > 0.9
    assert detection.language_name == "French"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detect_german():
    """German input is reported as "de" and flagged for translation to English."""
    sample = "Hallo Welt, das ist ein Test."

    detection = await detect_language_async(sample, target_language="en")

    assert detection.language_code == "de"
    assert detection.requires_translation is True
    assert detection.confidence > 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detect_chinese():
    """Chinese input yields a "zh"-prefixed code and is flagged for translation."""
    sample = "你好世界,这是一个测试。"

    result = await detect_language_async(sample, target_language="en")

    # Only the prefix is checked, so any code beginning with "zh" is accepted.
    is_chinese = result.language_code.startswith("zh")
    assert is_chinese, f"Expected Chinese, got {result.language_code}"
    assert result.requires_translation is True
    assert result.confidence > 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_already_target_language():
    """Text already in the target language is not flagged for translation."""
    sample = "This text is already in English."

    detection = await detect_language_async(sample, target_language="en")

    assert detection.requires_translation is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_short_text():
    """A two-character input is detected as "en" or "unknown" and counted as 2 chars."""
    sample = "Hi"

    detection = await detect_language_async(sample, target_language="es")

    # Very short input may not be reliably detectable, so "unknown" is accepted too.
    acceptable_codes = ["en", "unknown"]
    assert detection.language_code in acceptable_codes
    assert detection.character_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_text():
    """Empty input yields the "unknown" result with zero confidence and length."""
    detection = await detect_language_async("", target_language="en")

    # Per the original author's note this relies on the default
    # skip_detection_for_short_text=True behaviour.
    assert detection.language_code == "unknown"
    assert detection.language_name == "Unknown"
    assert detection.confidence == 0.0
    assert detection.requires_translation is False
    assert detection.character_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_confidence_threshold():
    """Detecting "Hello world" with a 0.5 threshold reports confidence >= 0.5."""
    sample = "Hello world"

    detection = await detect_language_async(
        sample,
        target_language="es",
        confidence_threshold=0.5,
    )

    assert detection.confidence >= 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mixed_language_text():
    """Mostly-Spanish text containing one English word is still detected as "es"."""
    # Spanish sentence with the English word "technology" embedded in it.
    sample = "La inteligencia artificial es muy importante en technology moderna."

    detection = await detect_language_async(sample, target_language="en")

    assert detection.language_code == "es"  # Dominant language wins
    assert detection.requires_translation is True
|
||||
151
cognee/tests/tasks/translation/providers_test.py
Normal file
151
cognee/tests/tasks/translation/providers_test.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""
|
||||
Unit tests for translation providers
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.tasks.translation.providers import (
|
||||
get_translation_provider,
|
||||
LLMTranslationProvider,
|
||||
TranslationResult,
|
||||
)
|
||||
from cognee.tasks.translation.exceptions import TranslationError
|
||||
|
||||
|
||||
def has_llm_api_key():
    """Report whether ``LLM_API_KEY`` is set to a non-empty value in the environment."""
    api_key = os.environ.get("LLM_API_KEY", "")
    return len(api_key) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_llm_provider_basic_translation():
    """The LLM provider translates a short Spanish phrase into English.

    Checks the result type, that some translated text came back, and that the
    source language, target language and provider name are echoed correctly.
    """
    llm_provider = LLMTranslationProvider()

    translation = await llm_provider.translate(
        text="Hola mundo",
        target_language="en",
        source_language="es",
    )

    assert isinstance(translation, TranslationResult)
    assert translation.translated_text is not None
    assert len(translation.translated_text) > 0
    assert translation.source_language == "es"
    assert translation.target_language == "en"
    assert translation.provider == "llm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_llm_provider_auto_detect_source():
    """Translation works when no source language is passed to the provider."""
    llm_provider = LLMTranslationProvider()

    # source_language is deliberately omitted; the provider is expected to
    # work out the source language on its own.
    translation = await llm_provider.translate(text="Bonjour le monde", target_language="en")

    assert translation.translated_text is not None
    assert translation.target_language == "en"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_llm_provider_long_text():
    """A multi-sentence Spanish paragraph is translated and keeps its source tag."""
    llm_provider = LLMTranslationProvider()

    paragraph = """
La inteligencia artificial es una rama de la informática que se centra en
crear sistemas capaces de realizar tareas que normalmente requieren inteligencia humana.
Estos sistemas pueden aprender, razonar y resolver problemas complejos.
"""

    translation = await llm_provider.translate(
        text=paragraph,
        target_language="en",
        source_language="es",
    )

    assert len(translation.translated_text) > 0
    assert translation.source_language == "es"
|
||||
|
||||
|
||||
def test_get_translation_provider_factory():
    """Requesting the "llm" provider yields an ``LLMTranslationProvider``."""
    assert isinstance(get_translation_provider("llm"), LLMTranslationProvider)
|
||||
|
||||
|
||||
def test_get_translation_provider_invalid():
    """An unknown provider name makes the factory raise.

    Either ``TranslationError`` or ``ValueError`` is accepted.
    """
    # pytest.raises fails the test by itself when nothing is raised, so the
    # check no longer depends on an ``assert False`` that ``python -O`` strips.
    with pytest.raises((TranslationError, ValueError)):
        get_translation_provider("invalid_provider")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_llm_batch_translation():
    """Batch translation returns one well-formed result per input text."""
    llm_provider = LLMTranslationProvider()
    phrases = ["Hola", "¿Cómo estás?", "Adiós"]

    translations = await llm_provider.translate_batch(
        texts=phrases,
        target_language="en",
        source_language="es",
    )

    assert len(translations) == len(phrases)
    for translation in translations:
        assert isinstance(translation, TranslationResult)
        assert translation.translated_text is not None
        assert translation.source_language == "es"
        assert translation.target_language == "en"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translation_preserves_formatting():
    """Text containing a newline is translated into a non-empty result."""
    llm_provider = LLMTranslationProvider()
    two_line_text = "Primera línea.\nSegunda línea."

    translation = await llm_provider.translate(
        text=two_line_text,
        target_language="en",
        source_language="es",
    )

    # Only non-emptiness is asserted; the exact newline layout is not checked.
    assert translation.translated_text is not None
    assert len(translation.translated_text) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translation_special_characters():
    """Spanish punctuation and accented letters do not prevent translation."""
    llm_provider = LLMTranslationProvider()
    accented_text = "¡Hola! ¿Cómo estás? Está bien."

    translation = await llm_provider.translate(
        text=accented_text,
        target_language="en",
        source_language="es",
    )

    assert translation.translated_text is not None
    assert len(translation.translated_text) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_empty_text_translation():
    """Empty input either raises ``TranslationError`` or returns a ``TranslationResult``."""
    llm_provider = LLMTranslationProvider()

    try:
        translation = await llm_provider.translate(
            text="", target_language="en", source_language="es"
        )
    except TranslationError:
        # Refusing empty input is acceptable behaviour.
        pass
    else:
        # Otherwise a result object (possibly with empty text) must come back.
        assert isinstance(translation, TranslationResult)
|
||||
213
cognee/tests/tasks/translation/translate_content_test.py
Normal file
213
cognee/tests/tasks/translation/translate_content_test.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""
|
||||
Unit tests for translate_content task
|
||||
"""
|
||||
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.tasks.translation import translate_content
|
||||
from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
|
||||
|
||||
|
||||
def has_llm_api_key():
    """Report whether ``LLM_API_KEY`` is set to a non-empty value in the environment."""
    api_key = os.environ.get("LLM_API_KEY", "")
    return len(api_key) > 0
|
||||
|
||||
|
||||
def create_test_chunk(text: str, chunk_index: int = 0):
    """Build a ``DocumentChunk`` holding *text*, attached to a throwaway ``TextDocument``.

    The chunk gets a fresh UUID, ``chunk_size`` equal to ``len(text)`` and a
    ``cut_type`` of "sentence".
    """
    # The chunk's is_part_of field needs a parent document, so make a minimal one.
    parent_document = TextDocument(
        id=uuid4(),
        name="test_doc",
        raw_data_location="/tmp/test.txt",
        external_metadata=None,
        mime_type="text/plain",
    )

    chunk_fields = {
        "id": uuid4(),
        "text": text,
        "chunk_index": chunk_index,
        "chunk_size": len(text),
        "cut_type": "sentence",
        "is_part_of": parent_document,
    }
    return DocumentChunk(**chunk_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translate_content_basic():
    """A Spanish chunk comes back translated with a ``TranslatedContent`` attached."""
    original_text = "Hola mundo, esta es una prueba."
    spanish_chunk = create_test_chunk(original_text)

    translated_chunks = await translate_content(
        data_chunks=[spanish_chunk],
        target_language="en",
        translation_provider="llm",
    )

    assert len(translated_chunks) == 1

    first_chunk = translated_chunks[0]
    # The chunk text itself is expected to differ from the Spanish input.
    assert first_chunk.text != original_text
    assert first_chunk.contains is not None

    # At least one attached item must be a TranslatedContent record.
    assert any(isinstance(entry, TranslatedContent) for entry in first_chunk.contains)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translate_content_preserves_original():
    """With ``preserve_original=True`` the record keeps the untranslated text."""
    original_text = "Bonjour le monde"
    french_chunk = create_test_chunk(original_text)

    translated_chunks = await translate_content(
        data_chunks=[french_chunk],
        target_language="en",
        preserve_original=True,
    )

    # Pick the first TranslatedContent attached to the chunk, if any.
    translated_content = next(
        (entry for entry in translated_chunks[0].contains if isinstance(entry, TranslatedContent)),
        None,
    )

    assert translated_content is not None
    assert translated_content.original_text == original_text
    assert translated_content.translated_text != original_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_translate_content_skip_english():
    """English chunks are left untranslated when the target language is English.

    Expects the chunk text to stay as it was, and the chunk's ``contains`` to
    carry a ``LanguageMetadata`` entry but no ``TranslatedContent`` entry.
    """
    # No skipif marker here: per the original author's note, English text is
    # skipped, so no translation API call is expected.
    original_text = "Hello world, this is a test."
    chunk = create_test_chunk(original_text)

    result = await translate_content(
        data_chunks=[chunk], target_language="en", skip_if_target_language=True
    )

    # Compare against the captured string rather than ``chunk.text``: if the
    # task mutates and returns the same chunk object, ``result[0].text ==
    # chunk.text`` would hold even after an unwanted translation.
    assert result[0].text == original_text

    # Should have LanguageMetadata but not TranslatedContent.
    contained_items = result[0].contains or []
    has_language_metadata = any(isinstance(item, LanguageMetadata) for item in contained_items)
    has_translated_content = any(isinstance(item, TranslatedContent) for item in contained_items)

    assert has_language_metadata
    assert not has_translated_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translate_content_multiple_chunks():
    """Three chunks in three languages are returned, at least two of them translated."""
    # Full sentences are used so language detection has enough text to work with.
    sentences = [
        "Hola mundo, esta es una prueba de traducción.",
        "Bonjour le monde, ceci est un test de traduction.",
        "Ciao mondo, questo è un test di traduzione.",
    ]
    input_chunks = [create_test_chunk(sentence, index) for index, sentence in enumerate(sentences)]

    processed = await translate_content(data_chunks=input_chunks, target_language="en")

    assert len(processed) == 3

    # Collect the chunks that carry at least one TranslatedContent record.
    translated_chunks = [
        chunk
        for chunk in processed
        if any(isinstance(entry, TranslatedContent) for entry in (chunk.contains or []))
    ]
    assert len(translated_chunks) >= 2  # At least 2 chunks should be translated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_translate_content_empty_list():
    """An empty list of chunks yields an empty list."""
    no_chunks = []

    processed = await translate_content(data_chunks=no_chunks, target_language="en")

    assert processed == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_translate_content_empty_text():
    """A chunk whose text is empty is returned with its text still empty."""
    empty_chunk = create_test_chunk("")

    processed = await translate_content(data_chunks=[empty_chunk], target_language="en")

    assert len(processed) == 1
    assert processed[0].text == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translate_content_language_metadata():
    """A translated Spanish chunk carries a populated ``LanguageMetadata`` record."""
    # A longer, clearly Spanish sentence is used to make detection reliable.
    spanish_text = "La inteligencia artificial está cambiando el mundo de manera significativa"
    spanish_chunk = create_test_chunk(spanish_text)

    processed = await translate_content(data_chunks=[spanish_chunk], target_language="en")

    # Pick the first LanguageMetadata attached to the chunk, if any.
    language_metadata = next(
        (entry for entry in processed[0].contains if isinstance(entry, LanguageMetadata)),
        None,
    )

    assert language_metadata is not None
    # Only the presence of a detected language is checked, not which one.
    assert language_metadata.detected_language is not None
    assert language_metadata.requires_translation is True
    assert language_metadata.language_confidence > 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translate_content_confidence_threshold():
    """Passing a custom ``confidence_threshold`` still returns one chunk per input."""
    # A longer sentence is used to make detection reliable.
    spanish_chunk = create_test_chunk(
        "Hola mundo, esta es una frase más larga para mejor detección"
    )

    processed = await translate_content(
        data_chunks=[spanish_chunk],
        target_language="en",
        confidence_threshold=0.5,
    )

    assert len(processed) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
async def test_translate_content_no_preserve_original():
    """With ``preserve_original=False`` the record's original text is empty."""
    # A longer sentence is used to make detection reliable.
    french_chunk = create_test_chunk("Bonjour le monde, comment allez-vous aujourd'hui")

    processed = await translate_content(
        data_chunks=[french_chunk],
        target_language="en",
        preserve_original=False,
    )

    # Pick the first TranslatedContent attached to the chunk, if any.
    translated_content = next(
        (entry for entry in processed[0].contains if isinstance(entry, TranslatedContent)),
        None,
    )

    assert translated_content is not None
    assert translated_content.original_text == ""  # Should be empty
|
||||
|
|
@ -61,8 +61,8 @@ dependencies = [
|
|||
"diskcache>=5.6.3",
|
||||
"aiolimiter>=1.2.1",
|
||||
"urllib3>=2.6.0",
|
||||
"cbor2>=5.8.0"
|
||||
|
||||
"cbor2>=5.8.0",
|
||||
"langdetect>=1.0.9",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -934,6 +934,7 @@ dependencies = [
|
|||
{ name = "jinja2" },
|
||||
{ name = "kuzu" },
|
||||
{ name = "lancedb" },
|
||||
{ name = "langdetect" },
|
||||
{ name = "limits" },
|
||||
{ name = "litellm" },
|
||||
{ name = "mistralai" },
|
||||
|
|
@ -1134,6 +1135,7 @@ requires-dist = [
|
|||
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.22" },
|
||||
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = ">=1.2.5" },
|
||||
{ name = "langchain-text-splitters", marker = "extra == 'langchain'", specifier = ">=0.3.2,<1.0.0" },
|
||||
{ name = "langdetect", specifier = ">=1.0.9" },
|
||||
{ name = "langfuse", marker = "extra == 'monitoring'", specifier = ">=2.32.0,<3" },
|
||||
{ name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" },
|
||||
{ name = "limits", specifier = ">=4.4.1,<5" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue