diff --git a/cognee/tasks/translation/config.py b/cognee/tasks/translation/config.py index 99ed560de..db8a23870 100644 --- a/cognee/tasks/translation/config.py +++ b/cognee/tasks/translation/config.py @@ -1,6 +1,7 @@ from functools import lru_cache from typing import Literal, Optional +from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -23,7 +24,7 @@ class TranslationConfig(BaseSettings): # Translation provider settings translation_provider: TranslationProviderType = "openai" target_language: str = "en" - confidence_threshold: float = 0.8 + confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0) # Google Translate settings google_translate_api_key: Optional[str] = None @@ -57,7 +58,7 @@ class TranslationConfig(BaseSettings): } -@lru_cache +@lru_cache() def get_translation_config() -> TranslationConfig: """Get the translation configuration singleton.""" return TranslationConfig() diff --git a/cognee/tasks/translation/detect_language.py b/cognee/tasks/translation/detect_language.py index e223083c0..00b0bf012 100644 --- a/cognee/tasks/translation/detect_language.py +++ b/cognee/tasks/translation/detect_language.py @@ -88,7 +88,7 @@ def get_language_name(language_code: str) -> str: def detect_language( text: str, target_language: str = "en", - confidence_threshold: float = None, + confidence_threshold: Optional[float] = None, ) -> LanguageDetectionResult: """ Detect the language of the given text. @@ -184,7 +184,7 @@ async def detect_language_async( """ import asyncio - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() return await loop.run_in_executor( None, detect_language, text, target_language, confidence_threshold ) diff --git a/cognee/tasks/translation/exceptions.py b/cognee/tasks/translation/exceptions.py index ba5e74510..d5db128de 100644 --- a/cognee/tasks/translation/exceptions.py +++ b/cognee/tasks/translation/exceptions.py @@ -38,6 +38,7 @@ class UnsupportedLanguageError(TranslationError): language: str, provider: str = None, message: str = None, + original_error: Exception = None, ): self.language = language self.provider = provider @@ -45,11 +46,15 @@ class UnsupportedLanguageError(TranslationError): message = f"Language '{language}' is not supported" if provider: message += f" by {provider}" - super().__init__(message) + super().__init__(message, original_error) class TranslationConfigError(TranslationError): """Exception raised when translation configuration is invalid.""" - def __init__(self, message: str = "Invalid translation configuration"): - super().__init__(message) + def __init__( + self, + message: str = "Invalid translation configuration", + original_error: Exception = None, + ): + super().__init__(message, original_error) diff --git a/cognee/tasks/translation/providers/__init__.py b/cognee/tasks/translation/providers/__init__.py index 79a28a586..2fb8480ef 100644 --- a/cognee/tasks/translation/providers/__init__.py +++ b/cognee/tasks/translation/providers/__init__.py @@ -9,6 +9,7 @@ __all__ = [ "OpenAITranslationProvider", "GoogleTranslationProvider", "AzureTranslationProvider", + "get_translation_provider", ] diff --git a/cognee/tasks/translation/providers/azure_provider.py b/cognee/tasks/translation/providers/azure_provider.py index 2ee1f45d7..349445ca1 100644 --- a/cognee/tasks/translation/providers/azure_provider.py +++ b/cognee/tasks/translation/providers/azure_provider.py @@ -1,4 +1,3 @@ -import asyncio from typing import Optional import aiohttp @@ -142,12 +141,12 @@ class AzureTranslationProvider(TranslationProvider): batch_size = min(100, self._config.batch_size) all_results = [] - for i in range(0, len(texts), batch_size): - batch = texts[i : i + batch_size] - body = [{"text": text} for text in batch] + try: + async with aiohttp.ClientSession() as session: + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + body = [{"text": text} for text in batch] - try: - async with aiohttp.ClientSession() as session: async with session.post( endpoint, params=params, @@ -158,24 +157,24 @@ class AzureTranslationProvider(TranslationProvider): response.raise_for_status() results = await response.json() - for result in results: - translation = result["translations"][0] - detected_language = result.get("detectedLanguage", {}) + for result in results: + translation = result["translations"][0] + detected_language = result.get("detectedLanguage", {}) - all_results.append( - TranslationResult( - translated_text=translation["text"], - source_language=source_language - or detected_language.get("language", "unknown"), - target_language=target_language, - confidence_score=detected_language.get("score", 0.9), - provider=self.provider_name, - raw_response=result, + all_results.append( + TranslationResult( + translated_text=translation["text"], + source_language=source_language + or detected_language.get("language", "unknown"), + target_language=target_language, + confidence_score=detected_language.get("score", 0.9), + provider=self.provider_name, + raw_response=result, + ) ) - ) - except Exception as e: - logger.error(f"Azure batch translation failed: {e}") - raise + except Exception as e: + logger.error(f"Azure batch translation failed: {e}") + raise return all_results diff --git a/cognee/tasks/translation/providers/base.py b/cognee/tasks/translation/providers/base.py index d8e5e981e..c92f2f552 100644 --- a/cognee/tasks/translation/providers/base.py +++ b/cognee/tasks/translation/providers/base.py @@ -64,6 +64,13 @@ class TranslationProvider(ABC): """ pass + @abstractmethod def is_available(self) -> bool: - """Check if this provider is available (has required credentials).""" - return True + """Check if this provider is available (has required credentials). + + All providers must implement this method to validate their credentials. + + Returns: + True if the provider has valid credentials and is ready to use. + """ + pass diff --git a/cognee/tasks/translation/providers/google_provider.py b/cognee/tasks/translation/providers/google_provider.py index 0a7373b54..f007575cd 100644 --- a/cognee/tasks/translation/providers/google_provider.py +++ b/cognee/tasks/translation/providers/google_provider.py @@ -48,7 +48,8 @@ class GoogleTranslationProvider(TranslationProvider): try: self._get_client() return True - except Exception: + except Exception as e: + logger.debug(f"Google Translate not available: {e}") return False async def translate( @@ -72,7 +73,7 @@ class GoogleTranslationProvider(TranslationProvider): client = self._get_client() # Run in thread pool since google-cloud-translate is synchronous - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() if source_language: result = await loop.run_in_executor( @@ -122,7 +123,7 @@ class GoogleTranslationProvider(TranslationProvider): """ try: client = self._get_client() - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() if source_language: results = await loop.run_in_executor( diff --git a/cognee/tasks/translation/providers/openai_provider.py b/cognee/tasks/translation/providers/openai_provider.py index a888d688e..95597e368 100644 --- a/cognee/tasks/translation/providers/openai_provider.py +++ b/cognee/tasks/translation/providers/openai_provider.py @@ -52,6 +52,15 @@ class OpenAITranslationProvider(TranslationProvider): try: system_prompt = read_query_prompt("translate_content.txt") + # Validate system prompt was loaded successfully + if system_prompt is None: + logger.warning("translate_content.txt prompt file not found, using default prompt") + system_prompt = ( + "You are a professional translator. Translate the given text accurately " + "while preserving the original meaning, tone, and style. " + "Detect the source language if not provided." + ) + # Build the input with context if source_language: input_text = ( @@ -75,6 +84,8 @@ class OpenAITranslationProvider(TranslationProvider): translated_text=result.translated_text, source_language=source_language or result.detected_source_language, target_language=target_language, + # TODO: Consider deriving confidence from LLM response metadata + # or making configurable via TranslationConfig confidence_score=0.95, # LLM translations are generally high quality provider=self.provider_name, raw_response={"notes": result.translation_notes}, @@ -103,3 +114,9 @@ class OpenAITranslationProvider(TranslationProvider): """ tasks = [self.translate(text, target_language, source_language) for text in texts] return await asyncio.gather(*tasks) + + def is_available(self) -> bool: + """Check if OpenAI provider is available (has required credentials).""" + import os + + return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY")) diff --git a/cognee/tasks/translation/translate_content.py b/cognee/tasks/translation/translate_content.py index 1c869b132..fcf6ae430 100644 --- a/cognee/tasks/translation/translate_content.py +++ b/cognee/tasks/translation/translate_content.py @@ -44,6 +44,13 @@ async def translate_content( Chunks that required translation will have TranslatedContent objects in their 'contains' list. + Note: + This function mutates the input chunks in-place. Specifically: + - chunk.text is replaced with the translated text + - chunk.contains is updated with LanguageMetadata and TranslatedContent + The original text is preserved in TranslatedContent.original_text + if preserve_original=True. + Example: ```python from cognee.tasks.translation import translate_content diff --git a/cognee/tests/tasks/translation/config_test.py b/cognee/tests/tasks/translation/config_test.py index ee8d6019c..80f76a5f0 100644 --- a/cognee/tests/tasks/translation/config_test.py +++ b/cognee/tests/tasks/translation/config_test.py @@ -2,7 +2,6 @@ Unit tests for translation configuration """ -import os from typing import get_args from cognee.tasks.translation.config import ( get_translation_config, @@ -15,9 +14,15 @@ def test_default_translation_config(): """Test default translation configuration""" config = get_translation_config() - assert isinstance(config, TranslationConfig) - assert config.translation_provider in ["openai", "google", "azure"] - assert 0.0 <= config.confidence_threshold <= 1.0 + assert isinstance(config, TranslationConfig), "Config should be TranslationConfig instance" + assert config.translation_provider in [ + "openai", + "google", + "azure", + ], f"Invalid provider: {config.translation_provider}" + assert 0.0 <= config.confidence_threshold <= 1.0, ( + f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]" + ) def test_translation_provider_type_literal(): @@ -25,17 +30,52 @@ def test_translation_provider_type_literal(): # Get the allowed values from the Literal type allowed_values = get_args(TranslationProviderType) - assert "openai" in allowed_values - assert "google" in allowed_values - assert "azure" in allowed_values - assert len(allowed_values) == 3 + assert "openai" in allowed_values, "openai should be an allowed provider" + assert "google" in allowed_values, "google should be an allowed provider" + assert "azure" in allowed_values, "azure should be an allowed provider" + assert len(allowed_values) == 3, f"Expected 3 providers, got {len(allowed_values)}" def test_confidence_threshold_bounds(): """Test confidence threshold validation""" config = TranslationConfig(translation_provider="openai", confidence_threshold=0.9) - assert 0.0 <= config.confidence_threshold <= 1.0 + assert 0.0 <= config.confidence_threshold <= 1.0, ( + f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]" + ) + + +def test_confidence_threshold_validation(): + """Test that invalid confidence thresholds are rejected or clamped""" + # Test boundary values - these should work + config_min = TranslationConfig(translation_provider="openai", confidence_threshold=0.0) + assert config_min.confidence_threshold == 0.0, "Minimum bound (0.0) should be valid" + + config_max = TranslationConfig(translation_provider="openai", confidence_threshold=1.0) + assert config_max.confidence_threshold == 1.0, "Maximum bound (1.0) should be valid" + + # Test invalid values - these should either raise ValidationError or be clamped + try: + config_invalid_low = TranslationConfig( + translation_provider="openai", confidence_threshold=-0.1 + ) + # If no error, verify it was clamped to valid range + assert 0.0 <= config_invalid_low.confidence_threshold <= 1.0, ( + f"Invalid low value should be clamped, got {config_invalid_low.confidence_threshold}" + ) + except Exception: + pass # Expected validation error + + try: + config_invalid_high = TranslationConfig( + translation_provider="openai", confidence_threshold=1.5 + ) + # If no error, verify it was clamped to valid range + assert 0.0 <= config_invalid_high.confidence_threshold <= 1.0, ( + f"Invalid high value should be clamped, got {config_invalid_high.confidence_threshold}" + ) + except Exception: + pass # Expected validation error def test_multiple_provider_keys(): @@ -46,21 +86,5 @@ def test_multiple_provider_keys(): azure_translator_key="azure_key", ) - assert config.google_translate_api_key == "google_key" - assert config.azure_translator_key == "azure_key" - - -if __name__ == "__main__": - test_default_translation_config() - print("✓ test_default_translation_config passed") - - test_translation_provider_type_literal() - print("✓ test_translation_provider_type_literal passed") - - test_confidence_threshold_bounds() - print("✓ test_confidence_threshold_bounds passed") - - test_multiple_provider_keys() - print("✓ test_multiple_provider_keys passed") - - print("\nAll config tests passed!") + assert config.google_translate_api_key == "google_key", "Google API key not set correctly" + assert config.azure_translator_key == "azure_key", "Azure API key not set correctly" diff --git a/cognee/tests/tasks/translation/detect_language_test.py b/cognee/tests/tasks/translation/detect_language_test.py index 907c94df8..3845777ba 100644 --- a/cognee/tests/tasks/translation/detect_language_test.py +++ b/cognee/tests/tasks/translation/detect_language_test.py @@ -2,7 +2,7 @@ Unit tests for language detection functionality """ -import asyncio +import pytest from cognee.tasks.translation.detect_language import ( detect_language_async, LanguageDetectionResult, @@ -10,6 +10,7 @@ from cognee.tasks.translation.detect_language import ( from cognee.tasks.translation.exceptions import LanguageDetectionError +@pytest.mark.asyncio async def test_detect_english(): """Test detection of English text""" result = await detect_language_async("Hello world, this is a test.", target_language="en") @@ -20,6 +21,7 @@ async def test_detect_english(): assert result.language_name == "English" +@pytest.mark.asyncio async def test_detect_spanish(): """Test detection of Spanish text""" result = await detect_language_async("Hola mundo, esta es una prueba.", target_language="en") @@ -30,6 +32,7 @@ async def test_detect_spanish(): assert result.language_name == "Spanish" +@pytest.mark.asyncio async def test_detect_french(): """Test detection of French text""" result = await detect_language_async( @@ -42,6 +45,7 @@ async def test_detect_french(): assert result.language_name == "French" +@pytest.mark.asyncio async def test_detect_german(): """Test detection of German text""" result = await detect_language_async("Hallo Welt, das ist ein Test.", target_language="en") @@ -51,15 +55,17 @@ async def test_detect_german(): assert result.confidence > 0.9 +@pytest.mark.asyncio async def test_detect_chinese(): """Test detection of Chinese text""" result = await detect_language_async("你好世界,这是一个测试。", target_language="en") - assert result.language_code == "zh-cn" + assert result.language_code.startswith("zh"), f"Expected Chinese, got {result.language_code}" assert result.requires_translation is True assert result.confidence > 0.9 +@pytest.mark.asyncio async def test_already_target_language(): """Test when text is already in target language""" result = await detect_language_async("This text is already in English.", target_language="en") @@ -67,6 +73,7 @@ async def test_already_target_language(): assert result.requires_translation is False +@pytest.mark.asyncio async def test_short_text(): """Test detection with very short text""" result = await detect_language_async("Hi", target_language="es") @@ -76,6 +83,7 @@ async def test_short_text(): assert result.character_count == 2 +@pytest.mark.asyncio async def test_empty_text(): """Test detection with empty text - returns unknown by default""" result = await detect_language_async("", target_language="en") @@ -88,6 +96,7 @@ async def test_empty_text(): assert result.character_count == 0 +@pytest.mark.asyncio async def test_confidence_threshold(): """Test detection respects confidence threshold""" result = await detect_language_async( @@ -97,6 +106,7 @@ async def test_confidence_threshold(): assert result.confidence >= 0.5 +@pytest.mark.asyncio async def test_mixed_language_text(): """Test detection with mixed language text (predominantly one language)""" # Predominantly Spanish with English word @@ -106,42 +116,3 @@ async def test_mixed_language_text(): assert result.language_code == "es" # Should detect as Spanish assert result.requires_translation is True - - -async def main(): - """Run all language detection tests""" - await test_detect_english() - print("✓ test_detect_english passed") - - await test_detect_spanish() - print("✓ test_detect_spanish passed") - - await test_detect_french() - print("✓ test_detect_french passed") - - await test_detect_german() - print("✓ test_detect_german passed") - - await test_detect_chinese() - print("✓ test_detect_chinese passed") - - await test_already_target_language() - print("✓ test_already_target_language passed") - - await test_short_text() - print("✓ test_short_text passed") - - await test_empty_text() - print("✓ test_empty_text passed") - - await test_confidence_threshold() - print("✓ test_confidence_threshold passed") - - await test_mixed_language_text() - print("✓ test_mixed_language_text passed") - - print("\nAll language detection tests passed!") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/tasks/translation/integration_test.py b/cognee/tests/tasks/translation/integration_test.py index 98fcae5d5..ff2877959 100644 --- a/cognee/tests/tasks/translation/integration_test.py +++ b/cognee/tests/tasks/translation/integration_test.py @@ -4,9 +4,10 @@ Integration tests for multilingual content translation feature. Tests the full cognify pipeline with translation enabled. """ -import asyncio import os +import pytest + from cognee import add, cognify, prune, search, SearchType from cognee.tasks.translation import translate_text from cognee.tasks.translation.detect_language import detect_language_async @@ -17,12 +18,10 @@ def has_openai_key(): return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY")) +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_quick_translation(): """Quick smoke test for translation feature""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - await prune.prune_data() await prune.prune_system(metadata=True) @@ -39,12 +38,10 @@ async def test_quick_translation(): assert result is not None +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_basic(): """Test basic translation functionality with English text""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - await prune.prune_data() await prune.prune_system(metadata=True) @@ -67,12 +64,10 @@ async def test_translation_basic(): assert search_results is not None +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_spanish(): """Test translation with Spanish text""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - await prune.prune_data() await prune.prune_system(metadata=True) @@ -100,12 +95,10 @@ async def test_translation_spanish(): assert search_results is not None +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_french(): """Test translation with French text""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - await prune.prune_data() await prune.prune_system(metadata=True) @@ -133,12 +126,10 @@ async def test_translation_french(): assert search_results is not None +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_disabled(): """Test that cognify works without translation""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - await prune.prune_data() await prune.prune_system(metadata=True) @@ -153,12 +144,10 @@ async def test_translation_disabled(): assert result is not None +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_mixed_languages(): """Test with multiple documents in different languages""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - await prune.prune_data() await prune.prune_system(metadata=True) @@ -186,12 +175,10 @@ async def test_translation_mixed_languages(): assert search_results is not None +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_direct_translation_function(): """Test the translate_text convenience function directly""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - result = await translate_text( text="Hola, ¿cómo estás? Espero que tengas un buen día.", target_language="en", @@ -204,6 +191,7 @@ async def test_direct_translation_function(): assert result.provider == "openai" +@pytest.mark.asyncio async def test_language_detection(): """Test language detection directly""" test_texts = [ @@ -220,36 +208,3 @@ async def test_language_detection(): # Only check requires_translation for high-confidence detections if result.confidence > 0.8: assert result.requires_translation == should_translate - - -async def main(): - """Run all translation integration tests""" - await test_quick_translation() - print("✓ test_quick_translation passed") - - await test_language_detection() - print("✓ test_language_detection passed") - - await test_direct_translation_function() - print("✓ test_direct_translation_function passed") - - await test_translation_basic() - print("✓ test_translation_basic passed") - - await test_translation_spanish() - print("✓ test_translation_spanish passed") - - await test_translation_french() - print("✓ test_translation_french passed") - - await test_translation_disabled() - print("✓ test_translation_disabled passed") - - await test_translation_mixed_languages() - print("✓ test_translation_mixed_languages passed") - - print("\nAll translation integration tests passed!") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/tasks/translation/providers_test.py b/cognee/tests/tasks/translation/providers_test.py index 5be88a5ed..243a66fe8 100644 --- a/cognee/tests/tasks/translation/providers_test.py +++ b/cognee/tests/tasks/translation/providers_test.py @@ -2,8 +2,10 @@ Unit tests for translation providers """ -import asyncio import os + +import pytest + from cognee.tasks.translation.providers import ( get_translation_provider, OpenAITranslationProvider, @@ -17,12 +19,10 @@ def has_openai_key(): return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY")) +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_openai_provider_basic_translation(): """Test basic translation with OpenAI provider""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() result = await provider.translate(text="Hola mundo", target_language="en", source_language="es") @@ -35,12 +35,10 @@ async def test_openai_provider_basic_translation(): assert result.provider == "openai" +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_openai_provider_auto_detect_source(): """Test translation with automatic source language detection""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() result = await provider.translate( @@ -53,12 +51,10 @@ async def test_openai_provider_auto_detect_source(): assert result.target_language == "en" +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_openai_provider_long_text(): """Test translation of longer text""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() long_text = """ @@ -88,12 +84,10 @@ def test_get_translation_provider_invalid(): pass +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_openai_batch_translation(): """Test batch translation with OpenAI provider""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() texts = ["Hola", "¿Cómo estás?", "Adiós"] @@ -110,12 +104,10 @@ async def test_openai_batch_translation(): assert result.target_language == "en" +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_preserves_formatting(): """Test that translation preserves basic formatting""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() text_with_newlines = "Primera línea.\nSegunda línea." @@ -129,12 +121,10 @@ async def test_translation_preserves_formatting(): assert len(result.translated_text) > 0 +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translation_special_characters(): """Test translation with special characters""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() text = "¡Hola! ¿Cómo estás? Está bien." @@ -145,12 +135,10 @@ async def test_translation_special_characters(): assert len(result.translated_text) > 0 +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_empty_text_translation(): """Test translation with empty text - should return empty or handle gracefully""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - provider = OpenAITranslationProvider() # Empty text may either raise an error or return an empty result @@ -161,41 +149,3 @@ async def test_empty_text_translation(): except TranslationError: # This is also acceptable behavior pass - - -async def main(): - """Run all provider tests""" - # Sync tests - test_get_translation_provider_factory() - print("✓ test_get_translation_provider_factory passed") - - test_get_translation_provider_invalid() - print("✓ test_get_translation_provider_invalid passed") - - # Async tests - await test_openai_provider_basic_translation() - print("✓ test_openai_provider_basic_translation passed") - - await test_openai_provider_auto_detect_source() - print("✓ test_openai_provider_auto_detect_source passed") - - await test_openai_provider_long_text() - print("✓ test_openai_provider_long_text passed") - - await test_openai_batch_translation() - print("✓ test_openai_batch_translation passed") - - await test_translation_preserves_formatting() - print("✓ test_translation_preserves_formatting passed") - - await test_translation_special_characters() - print("✓ test_translation_special_characters passed") - - await test_empty_text_translation() - print("✓ test_empty_text_translation passed") - - print("\nAll provider tests passed!") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/tasks/translation/translate_content_test.py b/cognee/tests/tasks/translation/translate_content_test.py index 0d92e339e..35b5e60b3 100644 --- a/cognee/tests/tasks/translation/translate_content_test.py +++ b/cognee/tests/tasks/translation/translate_content_test.py @@ -2,9 +2,11 @@ Unit tests for translate_content task """ -import asyncio import os from uuid import uuid4 + +import pytest + from cognee.modules.chunking.models import DocumentChunk from cognee.modules.data.processing.document_types import TextDocument from cognee.tasks.translation import translate_content @@ -37,12 +39,10 @@ def create_test_chunk(text: str, chunk_index: int = 0): ) +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translate_content_basic(): """Test basic content translation""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - # Create test chunk with Spanish text original_text = "Hola mundo, esta es una prueba." chunk = create_test_chunk(original_text) @@ -61,12 +61,10 @@ async def test_translate_content_basic(): assert has_translated_content +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translate_content_preserves_original(): """Test that original text is preserved""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - original_text = "Bonjour le monde" chunk = create_test_chunk(original_text) @@ -86,6 +84,7 @@ async def test_translate_content_preserves_original(): assert translated_content.translated_text != original_text +@pytest.mark.asyncio async def test_translate_content_skip_english(): """Test skipping translation for English text""" # This test doesn't require API call since English text is skipped @@ -110,12 +109,10 @@ async def test_translate_content_skip_english(): assert not has_translated_content +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translate_content_multiple_chunks(): """Test translation of multiple chunks""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - # Use longer texts to ensure reliable language detection original_texts = [ "Hola mundo, esta es una prueba de traducción.", @@ -136,6 +133,7 @@ async def test_translate_content_multiple_chunks(): assert translated_count >= 2 # At least 2 chunks should be translated +@pytest.mark.asyncio async def test_translate_content_empty_list(): """Test with empty chunk list""" result = await translate_content(data_chunks=[], target_language="en") @@ -143,6 +141,7 @@ async def test_translate_content_empty_list(): assert result == [] +@pytest.mark.asyncio async def test_translate_content_empty_text(): """Test with chunk containing empty text""" chunk = create_test_chunk("") @@ -153,12 +152,10 @@ async def test_translate_content_empty_text(): assert result[0].text == "" +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translate_content_language_metadata(): """Test that LanguageMetadata is created correctly""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - # Use a longer, distinctly Spanish text to ensure reliable detection chunk = create_test_chunk( "La inteligencia artificial está cambiando el mundo de manera significativa" @@ -180,12 +177,10 @@ async def test_translate_content_language_metadata(): assert language_metadata.language_confidence > 0.0 +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translate_content_confidence_threshold(): """Test with custom confidence threshold""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - # Use longer text for more reliable detection chunk = create_test_chunk("Hola mundo, esta es una frase más larga para mejor detección") @@ -196,12 +191,10 @@ async def test_translate_content_confidence_threshold(): assert len(result) == 1 +@pytest.mark.asyncio +@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available") async def test_translate_content_no_preserve_original(): """Test translation without preserving original""" - if not has_openai_key(): - print(" (skipped - no API key)") - return - # Use longer text for more reliable detection chunk = create_test_chunk("Bonjour le monde, comment allez-vous aujourd'hui") @@ -218,39 +211,3 @@ async def test_translate_content_no_preserve_original(): assert translated_content is not None assert translated_content.original_text == "" # Should be empty - - -async def main(): - """Run all translate_content tests""" - await test_translate_content_basic() - print("✓ test_translate_content_basic passed") - - await test_translate_content_preserves_original() - print("✓ test_translate_content_preserves_original passed") - - await test_translate_content_skip_english() - print("✓ test_translate_content_skip_english passed") - - await test_translate_content_multiple_chunks() - print("✓ test_translate_content_multiple_chunks passed") - - await test_translate_content_empty_list() - print("✓ test_translate_content_empty_list passed") - - await test_translate_content_empty_text() - print("✓ test_translate_content_empty_text passed") - - await test_translate_content_language_metadata() - print("✓ test_translate_content_language_metadata passed") - - await test_translate_content_confidence_threshold() - print("✓ test_translate_content_confidence_threshold passed") - - await test_translate_content_no_preserve_original() - print("✓ test_translate_content_no_preserve_original passed") - - print("\nAll translate_content tests passed!") - - -if __name__ == "__main__": - asyncio.run(main())