diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 50071caef..1b50b6d2f 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -335,25 +335,27 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's ) ) - default_tasks.extend([ - Task( - extract_graph_from_data, - graph_model=graph_model, - config=config, - custom_prompt=custom_prompt, - task_config={"batch_size": chunks_per_batch}, - **kwargs, - ), # Generate knowledge graphs from the document chunks. - Task( - summarize_text, - task_config={"batch_size": chunks_per_batch}, - ), - Task( - add_data_points, - embed_triplets=embed_triplets, - task_config={"batch_size": chunks_per_batch}, - ), - ]) + default_tasks.extend( + [ + Task( + extract_graph_from_data, + graph_model=graph_model, + config=config, + custom_prompt=custom_prompt, + task_config={"batch_size": chunks_per_batch}, + **kwargs, + ), # Generate knowledge graphs from the document chunks. 
+ Task( + summarize_text, + task_config={"batch_size": chunks_per_batch}, + ), + Task( + add_data_points, + embed_triplets=embed_triplets, + task_config={"batch_size": chunks_per_batch}, + ), + ] + ) return default_tasks @@ -413,10 +415,12 @@ async def get_temporal_tasks( ) ) - temporal_tasks.extend([ - Task(extract_events_and_timestamps, task_config={"batch_size": chunks_per_batch}), - Task(extract_knowledge_graph_from_events), - Task(add_data_points, task_config={"batch_size": chunks_per_batch}), - ]) + temporal_tasks.extend( + [ + Task(extract_events_and_timestamps, task_config={"batch_size": chunks_per_batch}), + Task(extract_knowledge_graph_from_events), + Task(add_data_points, task_config={"batch_size": chunks_per_batch}), + ] + ) return temporal_tasks diff --git a/cognee/tasks/translation/exceptions.py b/cognee/tasks/translation/exceptions.py index 322e00c7a..ba5e74510 100644 --- a/cognee/tasks/translation/exceptions.py +++ b/cognee/tasks/translation/exceptions.py @@ -10,7 +10,9 @@ class TranslationError(Exception): class LanguageDetectionError(TranslationError): """Exception raised when language detection fails.""" - def __init__(self, message: str = "Failed to detect language", original_error: Exception = None): + def __init__( + self, message: str = "Failed to detect language", original_error: Exception = None + ): super().__init__(message, original_error) diff --git a/cognee/tasks/translation/providers/azure_provider.py b/cognee/tasks/translation/providers/azure_provider.py index 4618834ff..2ee1f45d7 100644 --- a/cognee/tasks/translation/providers/azure_provider.py +++ b/cognee/tasks/translation/providers/azure_provider.py @@ -89,8 +89,7 @@ class AzureTranslationProvider(TranslationProvider): return TranslationResult( translated_text=translation["text"], - source_language=source_language - or detected_language.get("language", "unknown"), + source_language=source_language or detected_language.get("language", "unknown"), target_language=target_language, 
confidence_score=detected_language.get("score", 0.9), provider=self.provider_name, diff --git a/cognee/tasks/translation/providers/openai_provider.py b/cognee/tasks/translation/providers/openai_provider.py index 2c70c6edb..a888d688e 100644 --- a/cognee/tasks/translation/providers/openai_provider.py +++ b/cognee/tasks/translation/providers/openai_provider.py @@ -101,7 +101,5 @@ class OpenAITranslationProvider(TranslationProvider): Returns: List of TranslationResult objects """ - tasks = [ - self.translate(text, target_language, source_language) for text in texts - ] + tasks = [self.translate(text, target_language, source_language) for text in texts] return await asyncio.gather(*tasks) diff --git a/cognee/tasks/translation/translate_content.py b/cognee/tasks/translation/translate_content.py index e200f659d..1c869b132 100644 --- a/cognee/tasks/translation/translate_content.py +++ b/cognee/tasks/translation/translate_content.py @@ -88,9 +88,7 @@ async def translate_content( try: # Detect language - detection = await detect_language_async( - chunk.text, target_language, threshold - ) + detection = await detect_language_async(chunk.text, target_language, threshold) # Create language metadata language_metadata = LanguageMetadata( @@ -117,8 +115,7 @@ async def translate_content( # Translate the content logger.debug( - f"Translating chunk {chunk.id} from {detection.language_code} " - f"to {target_language}" + f"Translating chunk {chunk.id} from {detection.language_code} to {target_language}" ) translation_result = await provider.translate( diff --git a/cognee/tests/tasks/translation/README.md b/cognee/tests/tasks/translation/README.md new file mode 100644 index 000000000..cb56bf18a --- /dev/null +++ b/cognee/tests/tasks/translation/README.md @@ -0,0 +1,126 @@ +# Translation Task Tests + +Unit and integration tests for the multilingual content translation feature. 
+ +## Test Files + +- **config_test.py** - Tests for translation configuration + - Default configuration + - Provider type validation + - Confidence threshold bounds + - Multiple provider API keys + +- **detect_language_test.py** - Tests for language detection functionality + - English, Spanish, French, German, Chinese detection + - Confidence thresholds + - Edge cases (empty text, short text, mixed languages) + +- **providers_test.py** - Tests for translation provider implementations + - OpenAI provider basic translation + - Auto-detection of source language + - Batch translation + - Special characters and formatting preservation + - Error handling + +- **translate_content_test.py** - Tests for the main translate_content task + - Basic translation workflow + - Original text preservation + - Multiple chunks processing + - Language metadata creation + - Skip translation for target language + - Confidence threshold customization + +- **integration_test.py** - End-to-end integration tests + - Full cognify pipeline with translation + - Spanish/French to English translation + - Mixed language datasets + - Search functionality after translation + - Translation disabled mode + +## Running Tests + +### Run all translation tests +```bash +uv run pytest cognee/tests/tasks/translation/ -v +``` + +### Run specific test file +```bash +uv run pytest cognee/tests/tasks/translation/detect_language_test.py -v +``` + +### Run tests directly (without pytest) +```bash +uv run python cognee/tests/tasks/translation/config_test.py +uv run python cognee/tests/tasks/translation/detect_language_test.py +uv run python cognee/tests/tasks/translation/providers_test.py +uv run python cognee/tests/tasks/translation/translate_content_test.py +uv run python cognee/tests/tasks/translation/integration_test.py +``` + +### Run all tests at once +```bash +for f in cognee/tests/tasks/translation/*_test.py; do uv run python "$f"; done +``` + +### Run with coverage +```bash +uv run pytest 
cognee/tests/tasks/translation/ --cov=cognee.tasks.translation --cov-report=html +``` + +## Prerequisites + +- LLM API key set in environment: `LLM_API_KEY=your_key` (`OPENAI_API_KEY` is also accepted by the tests) +- Tests that call a translation provider return early, printing `(skipped - no API key)`, when neither key is set; tests without that check (e.g. configuration and language detection) still run + +## Test Summary + +| Test File | Tests | Description | +|-----------|-------|-------------| +| config_test.py | 4 | Configuration validation | +| detect_language_test.py | 10 | Language detection | +| providers_test.py | 9 | Translation providers | +| translate_content_test.py | 9 | Content translation task | +| integration_test.py | 8 | End-to-end pipeline | +| **Total** | **40** | | + +## Test Categories + +### Configuration (4 tests) +- ✅ Default configuration values +- ✅ Provider type literal validation +- ✅ Confidence threshold bounds +- ✅ Multiple provider API keys + +### Language Detection (10 tests) +- ✅ Multiple language detection (EN, ES, FR, DE, ZH) +- ✅ Confidence scoring +- ✅ Target language matching +- ✅ Short and empty text handling +- ✅ Mixed language detection + +### Translation Providers (9 tests) +- ✅ Provider factory function +- ✅ OpenAI translation +- ✅ Batch operations +- ✅ Auto source language detection +- ✅ Long text handling +- ✅ Special characters preservation +- ✅ Error handling + +### Content Translation (9 tests) +- ✅ DocumentChunk processing +- ✅ Metadata creation (LanguageMetadata, TranslatedContent) +- ✅ Original text preservation +- ✅ Multiple chunk handling +- ✅ Empty text/list handling +- ✅ Confidence threshold customization + +### Integration (8 tests) +- ✅ Full cognify pipeline with auto_translate=True +- ✅ Spanish to English translation +- ✅ French to English translation +- ✅ Mixed language datasets +- ✅ Translation disabled mode +- ✅ Direct translate_text function +- ✅ Search after translation diff --git a/cognee/tests/tasks/translation/__init__.py b/cognee/tests/tasks/translation/__init__.py new file mode 100644 index 000000000..7284dcfa5 --- /dev/null +++ b/cognee/tests/tasks/translation/__init__.py @@ -0,0
+1 @@ +"""Translation task tests""" diff --git a/cognee/tests/tasks/translation/config_test.py b/cognee/tests/tasks/translation/config_test.py new file mode 100644 index 000000000..ee8d6019c --- /dev/null +++ b/cognee/tests/tasks/translation/config_test.py @@ -0,0 +1,66 @@ +""" +Unit tests for translation configuration +""" + +import os +from typing import get_args +from cognee.tasks.translation.config import ( + get_translation_config, + TranslationConfig, + TranslationProviderType, +) + + +def test_default_translation_config(): + """Test default translation configuration""" + config = get_translation_config() + + assert isinstance(config, TranslationConfig) + assert config.translation_provider in ["openai", "google", "azure"] + assert 0.0 <= config.confidence_threshold <= 1.0 + + +def test_translation_provider_type_literal(): + """Test TranslationProviderType Literal type values""" + # Get the allowed values from the Literal type + allowed_values = get_args(TranslationProviderType) + + assert "openai" in allowed_values + assert "google" in allowed_values + assert "azure" in allowed_values + assert len(allowed_values) == 3 + + +def test_confidence_threshold_bounds(): + """Test confidence threshold validation""" + config = TranslationConfig(translation_provider="openai", confidence_threshold=0.9) + + assert 0.0 <= config.confidence_threshold <= 1.0 + + +def test_multiple_provider_keys(): + """Test configuration with multiple provider API keys""" + config = TranslationConfig( + translation_provider="openai", + google_translate_api_key="google_key", + azure_translator_key="azure_key", + ) + + assert config.google_translate_api_key == "google_key" + assert config.azure_translator_key == "azure_key" + + +if __name__ == "__main__": + test_default_translation_config() + print("✓ test_default_translation_config passed") + + test_translation_provider_type_literal() + print("✓ test_translation_provider_type_literal passed") + + test_confidence_threshold_bounds() + print("✓ 
test_confidence_threshold_bounds passed") + + test_multiple_provider_keys() + print("✓ test_multiple_provider_keys passed") + + print("\nAll config tests passed!") diff --git a/cognee/tests/tasks/translation/detect_language_test.py b/cognee/tests/tasks/translation/detect_language_test.py new file mode 100644 index 000000000..907c94df8 --- /dev/null +++ b/cognee/tests/tasks/translation/detect_language_test.py @@ -0,0 +1,147 @@ +""" +Unit tests for language detection functionality +""" + +import asyncio +from cognee.tasks.translation.detect_language import ( + detect_language_async, + LanguageDetectionResult, +) +from cognee.tasks.translation.exceptions import LanguageDetectionError + + +async def test_detect_english(): + """Test detection of English text""" + result = await detect_language_async("Hello world, this is a test.", target_language="en") + + assert result.language_code == "en" + assert result.requires_translation is False + assert result.confidence > 0.9 + assert result.language_name == "English" + + +async def test_detect_spanish(): + """Test detection of Spanish text""" + result = await detect_language_async("Hola mundo, esta es una prueba.", target_language="en") + + assert result.language_code == "es" + assert result.requires_translation is True + assert result.confidence > 0.9 + assert result.language_name == "Spanish" + + +async def test_detect_french(): + """Test detection of French text""" + result = await detect_language_async( + "Bonjour le monde, ceci est un test.", target_language="en" + ) + + assert result.language_code == "fr" + assert result.requires_translation is True + assert result.confidence > 0.9 + assert result.language_name == "French" + + +async def test_detect_german(): + """Test detection of German text""" + result = await detect_language_async("Hallo Welt, das ist ein Test.", target_language="en") + + assert result.language_code == "de" + assert result.requires_translation is True + assert result.confidence > 0.9 + + +async def 
test_detect_chinese(): + """Test detection of Chinese text""" + result = await detect_language_async("你好世界,这是一个测试。", target_language="en") + + assert result.language_code == "zh-cn" + assert result.requires_translation is True + assert result.confidence > 0.9 + + +async def test_already_target_language(): + """Test when text is already in target language""" + result = await detect_language_async("This text is already in English.", target_language="en") + + assert result.requires_translation is False + + +async def test_short_text(): + """Test detection with very short text""" + result = await detect_language_async("Hi", target_language="es") + + # Short text may return 'unknown' if langdetect can't reliably detect + assert result.language_code in ["en", "unknown"] + assert result.character_count == 2 + + +async def test_empty_text(): + """Test detection with empty text - returns unknown by default""" + result = await detect_language_async("", target_language="en") + + # With skip_detection_for_short_text=True (default), returns unknown + assert result.language_code == "unknown" + assert result.language_name == "Unknown" + assert result.confidence == 0.0 + assert result.requires_translation is False + assert result.character_count == 0 + + +async def test_confidence_threshold(): + """Test detection respects confidence threshold""" + result = await detect_language_async( + "Hello world", target_language="es", confidence_threshold=0.5 + ) + + assert result.confidence >= 0.5 + + +async def test_mixed_language_text(): + """Test detection with mixed language text (predominantly one language)""" + # Predominantly Spanish with English word + result = await detect_language_async( + "La inteligencia artificial es muy importante en technology moderna.", target_language="en" + ) + + assert result.language_code == "es" # Should detect as Spanish + assert result.requires_translation is True + + +async def main(): + """Run all language detection tests""" + await 
test_detect_english() + print("✓ test_detect_english passed") + + await test_detect_spanish() + print("✓ test_detect_spanish passed") + + await test_detect_french() + print("✓ test_detect_french passed") + + await test_detect_german() + print("✓ test_detect_german passed") + + await test_detect_chinese() + print("✓ test_detect_chinese passed") + + await test_already_target_language() + print("✓ test_already_target_language passed") + + await test_short_text() + print("✓ test_short_text passed") + + await test_empty_text() + print("✓ test_empty_text passed") + + await test_confidence_threshold() + print("✓ test_confidence_threshold passed") + + await test_mixed_language_text() + print("✓ test_mixed_language_text passed") + + print("\nAll language detection tests passed!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/tasks/translation/integration_test.py b/cognee/tests/tasks/translation/integration_test.py new file mode 100644 index 000000000..98fcae5d5 --- /dev/null +++ b/cognee/tests/tasks/translation/integration_test.py @@ -0,0 +1,255 @@ +""" +Integration tests for multilingual content translation feature. + +Tests the full cognify pipeline with translation enabled. +""" + +import asyncio +import os + +from cognee import add, cognify, prune, search, SearchType +from cognee.tasks.translation import translate_text +from cognee.tasks.translation.detect_language import detect_language_async + + +def has_openai_key(): + """Check if OpenAI API key is available""" + return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY")) + + +async def test_quick_translation(): + """Quick smoke test for translation feature""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + await prune.prune_data() + await prune.prune_system(metadata=True) + + spanish_text = "La inteligencia artificial está transformando el mundo de la tecnología." 
+ await add(spanish_text, dataset_name="spanish_test") + + result = await cognify( + datasets=["spanish_test"], + auto_translate=True, + target_language="en", + translation_provider="openai", + ) + + assert result is not None + + +async def test_translation_basic(): + """Test basic translation functionality with English text""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + await prune.prune_data() + await prune.prune_system(metadata=True) + + english_text = "Hello, this is a test document about artificial intelligence." + await add(english_text, dataset_name="test_english") + + result = await cognify( + datasets=["test_english"], + auto_translate=True, + target_language="en", + translation_provider="openai", + ) + + assert result is not None + + search_results = await search( + query_text="What is this document about?", + query_type=SearchType.SUMMARIES, + ) + assert search_results is not None + + +async def test_translation_spanish(): + """Test translation with Spanish text""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + await prune.prune_data() + await prune.prune_system(metadata=True) + + spanish_text = """ + La inteligencia artificial es una rama de la informática que se centra en + crear sistemas capaces de realizar tareas que normalmente requieren inteligencia humana. + Estos sistemas pueden aprender, razonar y resolver problemas complejos. 
+ """ + + await add(spanish_text, dataset_name="test_spanish") + + result = await cognify( + datasets=["test_spanish"], + auto_translate=True, + target_language="en", + translation_provider="openai", + ) + + assert result is not None + + search_results = await search( + query_text="What is artificial intelligence?", + query_type=SearchType.SUMMARIES, + ) + assert search_results is not None + + +async def test_translation_french(): + """Test translation with French text""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + await prune.prune_data() + await prune.prune_system(metadata=True) + + french_text = """ + L'apprentissage automatique est une méthode d'analyse de données qui + automatise la construction de modèles analytiques. C'est une branche + de l'intelligence artificielle basée sur l'idée que les systèmes peuvent + apprendre à partir de données, identifier des modèles et prendre des décisions. + """ + + await add(french_text, dataset_name="test_french") + + result = await cognify( + datasets=["test_french"], + auto_translate=True, + target_language="en", + ) + + assert result is not None + + search_results = await search( + query_text="What is machine learning?", + query_type=SearchType.SUMMARIES, + ) + assert search_results is not None + + +async def test_translation_disabled(): + """Test that cognify works without translation""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + await prune.prune_data() + await prune.prune_system(metadata=True) + + text = "This is a baseline test without translation enabled." 
+ await add(text, dataset_name="test_baseline") + + result = await cognify( + datasets=["test_baseline"], + auto_translate=False, + ) + + assert result is not None + + +async def test_translation_mixed_languages(): + """Test with multiple documents in different languages""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + await prune.prune_data() + await prune.prune_system(metadata=True) + + texts = [ + "Artificial intelligence is transforming the world.", + "La tecnología está cambiando nuestras vidas.", + "Les ordinateurs deviennent de plus en plus puissants.", + ] + + for text in texts: + await add(text, dataset_name="test_mixed") + + result = await cognify( + datasets=["test_mixed"], + auto_translate=True, + target_language="en", + ) + + assert result is not None + + search_results = await search( + query_text="What topics are discussed?", + query_type=SearchType.SUMMARIES, + ) + assert search_results is not None + + +async def test_direct_translation_function(): + """Test the translate_text convenience function directly""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + result = await translate_text( + text="Hola, ¿cómo estás? 
Espero que tengas un buen día.", + target_language="en", + translation_provider="openai", + ) + + assert result.translated_text is not None + assert result.translated_text != "" + assert result.target_language == "en" + assert result.provider == "openai" + + +async def test_language_detection(): + """Test language detection directly""" + test_texts = [ + ("Hello world, how are you doing today?", "en", False), + ("Bonjour le monde, comment allez-vous aujourd'hui?", "en", True), + ("Hola mundo, cómo estás hoy?", "en", True), + ("This is already in English language", "en", False), + ] + + for text, target_lang, should_translate in test_texts: + result = await detect_language_async(text, target_lang) + assert result.language_code is not None + assert result.confidence > 0.0 + # Only check requires_translation for high-confidence detections + if result.confidence > 0.8: + assert result.requires_translation == should_translate + + +async def main(): + """Run all translation integration tests""" + await test_quick_translation() + print("✓ test_quick_translation passed") + + await test_language_detection() + print("✓ test_language_detection passed") + + await test_direct_translation_function() + print("✓ test_direct_translation_function passed") + + await test_translation_basic() + print("✓ test_translation_basic passed") + + await test_translation_spanish() + print("✓ test_translation_spanish passed") + + await test_translation_french() + print("✓ test_translation_french passed") + + await test_translation_disabled() + print("✓ test_translation_disabled passed") + + await test_translation_mixed_languages() + print("✓ test_translation_mixed_languages passed") + + print("\nAll translation integration tests passed!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/tasks/translation/providers_test.py b/cognee/tests/tasks/translation/providers_test.py new file mode 100644 index 000000000..5be88a5ed --- /dev/null +++ 
b/cognee/tests/tasks/translation/providers_test.py @@ -0,0 +1,201 @@ +""" +Unit tests for translation providers +""" + +import asyncio +import os +from cognee.tasks.translation.providers import ( + get_translation_provider, + OpenAITranslationProvider, + TranslationResult, +) +from cognee.tasks.translation.exceptions import TranslationError + + +def has_openai_key(): + """Check if OpenAI API key is available""" + return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY")) + + +async def test_openai_provider_basic_translation(): + """Test basic translation with OpenAI provider""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + result = await provider.translate(text="Hola mundo", target_language="en", source_language="es") + + assert isinstance(result, TranslationResult) + assert result.translated_text is not None + assert len(result.translated_text) > 0 + assert result.source_language == "es" + assert result.target_language == "en" + assert result.provider == "openai" + + +async def test_openai_provider_auto_detect_source(): + """Test translation with automatic source language detection""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + result = await provider.translate( + text="Bonjour le monde", + target_language="en", + # source_language not provided - should auto-detect + ) + + assert result.translated_text is not None + assert result.target_language == "en" + + +async def test_openai_provider_long_text(): + """Test translation of longer text""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + long_text = """ + La inteligencia artificial es una rama de la informática que se centra en + crear sistemas capaces de realizar tareas que normalmente requieren inteligencia humana. 
+ Estos sistemas pueden aprender, razonar y resolver problemas complejos. + """ + + result = await provider.translate(text=long_text, target_language="en", source_language="es") + + assert len(result.translated_text) > 0 + assert result.source_language == "es" + + +def test_get_translation_provider_factory(): + """Test provider factory function""" + provider = get_translation_provider("openai") + assert isinstance(provider, OpenAITranslationProvider) + + +def test_get_translation_provider_invalid(): + """Test provider factory with invalid provider name""" + try: + get_translation_provider("invalid_provider") + assert False, "Expected TranslationError or ValueError" + except (TranslationError, ValueError): + pass + + +async def test_openai_batch_translation(): + """Test batch translation with OpenAI provider""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + texts = ["Hola", "¿Cómo estás?", "Adiós"] + + results = await provider.translate_batch( + texts=texts, target_language="en", source_language="es" + ) + + assert len(results) == len(texts) + for result in results: + assert isinstance(result, TranslationResult) + assert result.translated_text is not None + assert result.source_language == "es" + assert result.target_language == "en" + + +async def test_translation_preserves_formatting(): + """Test that translation preserves basic formatting""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + text_with_newlines = "Primera línea.\nSegunda línea." 
+ + result = await provider.translate( + text=text_with_newlines, target_language="en", source_language="es" + ) + + # Should preserve structure (though exact newlines may vary) + assert result.translated_text is not None + assert len(result.translated_text) > 0 + + +async def test_translation_special_characters(): + """Test translation with special characters""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + text = "¡Hola! ¿Cómo estás? Está bien." + + result = await provider.translate(text=text, target_language="en", source_language="es") + + assert result.translated_text is not None + assert len(result.translated_text) > 0 + + +async def test_empty_text_translation(): + """Test translation with empty text - should return empty or handle gracefully""" + if not has_openai_key(): + print(" (skipped - no API key)") + return + + provider = OpenAITranslationProvider() + + # Empty text may either raise an error or return an empty result + try: + result = await provider.translate(text="", target_language="en", source_language="es") + # If no error, should return a TranslationResult (possibly with empty text) + assert isinstance(result, TranslationResult) + except TranslationError: + # This is also acceptable behavior + pass + + +async def main(): + """Run all provider tests""" + # Sync tests + test_get_translation_provider_factory() + print("✓ test_get_translation_provider_factory passed") + + test_get_translation_provider_invalid() + print("✓ test_get_translation_provider_invalid passed") + + # Async tests + await test_openai_provider_basic_translation() + print("✓ test_openai_provider_basic_translation passed") + + await test_openai_provider_auto_detect_source() + print("✓ test_openai_provider_auto_detect_source passed") + + await test_openai_provider_long_text() + print("✓ test_openai_provider_long_text passed") + + await test_openai_batch_translation() + print("✓ test_openai_batch_translation 
passed")
+
+    await test_translation_preserves_formatting()
+    print("✓ test_translation_preserves_formatting passed")
+
+    await test_translation_special_characters()
+    print("✓ test_translation_special_characters passed")
+
+    await test_empty_text_translation()
+    print("✓ test_empty_text_translation passed")
+
+    print("\nAll provider tests passed!")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/cognee/tests/tasks/translation/translate_content_test.py b/cognee/tests/tasks/translation/translate_content_test.py
new file mode 100644
index 000000000..0d92e339e
--- /dev/null
+++ b/cognee/tests/tasks/translation/translate_content_test.py
@@ -0,0 +1,254 @@
+"""
+Unit tests for translate_content task
+"""
+
+import asyncio
+import os
+from uuid import uuid4
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.tasks.translation import translate_content
+from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
+
+
+def has_openai_key():
+    """Return True when LLM_API_KEY or OPENAI_API_KEY is set to a non-empty value"""
+    return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
+
+
+def create_test_chunk(text: str, chunk_index: int = 0):
+    """Helper to create a DocumentChunk with all required fields"""
+    # Create a minimal Document for the is_part_of field
+    doc = TextDocument(
+        id=uuid4(),
+        name="test_doc",
+        raw_data_location="/tmp/test.txt",
+        external_metadata=None,
+        mime_type="text/plain",
+    )
+
+    return DocumentChunk(
+        id=uuid4(),
+        text=text,
+        chunk_index=chunk_index,
+        chunk_size=len(text),
+        cut_type="sentence",
+        is_part_of=doc,
+    )
+
+
+def find_contained(chunk, item_type):
+    """Return the first item of item_type in chunk.contains, or None.
+
+    A missing (None) ``contains`` is treated as empty, so callers fail on their
+    own assertion instead of a TypeError raised while iterating None.
+    """
+    return next((item for item in (chunk.contains or []) if isinstance(item, item_type)), None)
+
+
+async def test_translate_content_basic():
+    """Test basic content translation"""
+    if not has_openai_key():
+        print(" (skipped - no API key)")
+        return
+
+    # Create test chunk with Spanish text
+    original_text = "Hola mundo, esta es una prueba."
+    chunk = create_test_chunk(original_text)
+
+    result = await translate_content(
+        data_chunks=[chunk], target_language="en", translation_provider="openai"
+    )
+
+    assert len(result) == 1
+    # The chunk's text should now be translated (different from original Spanish)
+    assert result[0].text != original_text  # Text should be translated to English
+    assert result[0].contains is not None
+
+    # Check for TranslatedContent in contains
+    has_translated_content = any(isinstance(item, TranslatedContent) for item in result[0].contains)
+    assert has_translated_content
+
+
+async def test_translate_content_preserves_original():
+    """Test that original text is preserved"""
+    if not has_openai_key():
+        print(" (skipped - no API key)")
+        return
+
+    original_text = "Bonjour le monde"
+    chunk = create_test_chunk(original_text)
+
+    result = await translate_content(
+        data_chunks=[chunk], target_language="en", preserve_original=True
+    )
+
+    translated_content = find_contained(result[0], TranslatedContent)
+
+    assert translated_content is not None
+    assert translated_content.original_text == original_text
+    assert translated_content.translated_text != original_text
+
+
+async def test_translate_content_skip_english():
+    """Test skipping translation for English text"""
+    # This test doesn't require API call since English text is skipped
+    original_text = "Hello world, this is a test."
+    chunk = create_test_chunk(original_text)
+
+    result = await translate_content(
+        data_chunks=[chunk], target_language="en", skip_if_target_language=True
+    )
+
+    # Compare against the captured input string rather than chunk.text: if the
+    # task hands back the same chunk object, chunk.text would always equal itself
+    assert result[0].text == original_text
+
+    # Should have LanguageMetadata but not TranslatedContent
+    has_language_metadata = any(
+        isinstance(item, LanguageMetadata) for item in (result[0].contains or [])
+    )
+    has_translated_content = any(
+        isinstance(item, TranslatedContent) for item in (result[0].contains or [])
+    )
+
+    assert has_language_metadata
+    assert not has_translated_content
+
+
+async def test_translate_content_multiple_chunks():
+    """Test translation of multiple chunks"""
+    if not has_openai_key():
+        print(" (skipped - no API key)")
+        return
+
+    # Use longer texts to ensure reliable language detection
+    original_texts = [
+        "Hola mundo, esta es una prueba de traducción.",
+        "Bonjour le monde, ceci est un test de traduction.",
+        "Ciao mondo, questo è un test di traduzione.",
+    ]
+    chunks = [create_test_chunk(text, i) for i, text in enumerate(original_texts)]
+
+    result = await translate_content(data_chunks=chunks, target_language="en")
+
+    assert len(result) == 3
+    # Check that at least some chunks were translated
+    translated_count = sum(
+        1
+        for chunk in result
+        if any(isinstance(item, TranslatedContent) for item in (chunk.contains or []))
+    )
+    assert translated_count >= 2  # At least 2 chunks should be translated
+
+
+async def test_translate_content_empty_list():
+    """Test with empty chunk list"""
+    result = await translate_content(data_chunks=[], target_language="en")
+
+    assert result == []
+
+
+async def test_translate_content_empty_text():
+    """Test with chunk containing empty text"""
+    chunk = create_test_chunk("")
+
+    result = await translate_content(data_chunks=[chunk], target_language="en")
+
+    assert len(result) == 1
+    assert result[0].text == ""
+
+
+async def test_translate_content_language_metadata():
+    """Test that LanguageMetadata is created correctly"""
+    if not has_openai_key():
+        print(" (skipped - no API key)")
+        return
+
+    # Use a longer, distinctly Spanish text to ensure reliable detection
+    chunk = create_test_chunk(
+        "La inteligencia artificial está cambiando el mundo de manera significativa"
+    )
+
+    result = await translate_content(data_chunks=[chunk], target_language="en")
+
+    language_metadata = find_contained(result[0], LanguageMetadata)
+
+    assert language_metadata is not None
+    # Just check that a language was detected (short texts can be ambiguous)
+    assert language_metadata.detected_language is not None
+    assert language_metadata.requires_translation is True
+    assert language_metadata.language_confidence > 0.0
+
+
+async def test_translate_content_confidence_threshold():
+    """Test with custom confidence threshold"""
+    if not has_openai_key():
+        print(" (skipped - no API key)")
+        return
+
+    # Use longer text for more reliable detection
+    chunk = create_test_chunk("Hola mundo, esta es una frase más larga para mejor detección")
+
+    result = await translate_content(
+        data_chunks=[chunk], target_language="en", confidence_threshold=0.5
+    )
+
+    assert len(result) == 1
+
+
+async def test_translate_content_no_preserve_original():
+    """Test translation without preserving original"""
+    if not has_openai_key():
+        print(" (skipped - no API key)")
+        return
+
+    # Use longer text for more reliable detection
+    chunk = create_test_chunk("Bonjour le monde, comment allez-vous aujourd'hui")
+
+    result = await translate_content(
+        data_chunks=[chunk], target_language="en", preserve_original=False
+    )
+
+    translated_content = find_contained(result[0], TranslatedContent)
+
+    assert translated_content is not None
+    assert translated_content.original_text == ""  # Should be empty
+
+
+async def main():
+    """Run all translate_content tests"""
+    # NOTE: key-gated tests return early after printing "(skipped - no API key)",
+    # so their "passed" line below only means the coroutine completed without raising
+    await test_translate_content_basic()
+    print("✓ test_translate_content_basic passed")
+
+    await test_translate_content_preserves_original()
+    print("✓ test_translate_content_preserves_original passed")
+
+    await test_translate_content_skip_english()
+    print("✓ test_translate_content_skip_english passed")
+
+    await test_translate_content_multiple_chunks()
+    print("✓ test_translate_content_multiple_chunks passed")
+
+    await test_translate_content_empty_list()
+    print("✓ test_translate_content_empty_list passed")
+
+    await test_translate_content_empty_text()
+    print("✓ test_translate_content_empty_text passed")
+
+    await test_translate_content_language_metadata()
+    print("✓ test_translate_content_language_metadata passed")
+
+    await test_translate_content_confidence_threshold()
+    print("✓ test_translate_content_confidence_threshold passed")
+
+    await test_translate_content_no_preserve_original()
+    print("✓ test_translate_content_no_preserve_original passed")
+
+    print("\nAll translate_content tests passed!")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())