refactor: address code review feedback
- Made is_available() abstract in base.py with proper implementation in providers
- Added original_error parameter to UnsupportedLanguageError and TranslationConfigError
- Added Field validation for confidence_threshold bounds (0.0-1.0)
- Changed @lru_cache to @lru_cache() for explicit style
- Added get_translation_provider to __all__ in providers/__init__.py
- Replaced deprecated asyncio.get_event_loop() with get_running_loop()
- Added debug logging to is_available() in GoogleTranslationProvider
- Added TODO comment for confidence score improvement in OpenAIProvider
- Added None check for read_query_prompt() with fallback default prompt
- Moved ClientSession outside batch loop in AzureTranslationProvider
- Fixed Optional[float] type annotation in detect_language()
- Added Note section documenting in-place mutation in translate_content()
- Added test_confidence_threshold_validation() for bounds testing
- Added descriptive assertion messages to config tests
- Converted all async tests to use @pytest.mark.asyncio decorators
- Replaced manual skip checks with @pytest.mark.skipif
- Removed manual main() blocks, tests now pytest-only
- Changed Chinese language assertion to use startswith('zh') for flexibility
This commit is contained in:
parent
05d34c9915
commit
69c25b43d7
14 changed files with 187 additions and 292 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from functools import lru_cache
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
|
|
@ -23,7 +24,7 @@ class TranslationConfig(BaseSettings):
|
|||
# Translation provider settings
|
||||
translation_provider: TranslationProviderType = "openai"
|
||||
target_language: str = "en"
|
||||
confidence_threshold: float = 0.8
|
||||
confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0)
|
||||
|
||||
# Google Translate settings
|
||||
google_translate_api_key: Optional[str] = None
|
||||
|
|
@ -57,7 +58,7 @@ class TranslationConfig(BaseSettings):
|
|||
}
|
||||
|
||||
|
||||
@lru_cache
|
||||
@lru_cache()
|
||||
def get_translation_config() -> TranslationConfig:
|
||||
"""Get the translation configuration singleton."""
|
||||
return TranslationConfig()
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ def get_language_name(language_code: str) -> str:
|
|||
def detect_language(
|
||||
text: str,
|
||||
target_language: str = "en",
|
||||
confidence_threshold: float = None,
|
||||
confidence_threshold: Optional[float] = None,
|
||||
) -> LanguageDetectionResult:
|
||||
"""
|
||||
Detect the language of the given text.
|
||||
|
|
@ -184,7 +184,7 @@ async def detect_language_async(
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, detect_language, text, target_language, confidence_threshold
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class UnsupportedLanguageError(TranslationError):
|
|||
language: str,
|
||||
provider: str = None,
|
||||
message: str = None,
|
||||
original_error: Exception = None,
|
||||
):
|
||||
self.language = language
|
||||
self.provider = provider
|
||||
|
|
@ -45,11 +46,15 @@ class UnsupportedLanguageError(TranslationError):
|
|||
message = f"Language '{language}' is not supported"
|
||||
if provider:
|
||||
message += f" by {provider}"
|
||||
super().__init__(message)
|
||||
super().__init__(message, original_error)
|
||||
|
||||
|
||||
class TranslationConfigError(TranslationError):
|
||||
"""Exception raised when translation configuration is invalid."""
|
||||
|
||||
def __init__(self, message: str = "Invalid translation configuration"):
|
||||
super().__init__(message)
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Invalid translation configuration",
|
||||
original_error: Exception = None,
|
||||
):
|
||||
super().__init__(message, original_error)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ __all__ = [
|
|||
"OpenAITranslationProvider",
|
||||
"GoogleTranslationProvider",
|
||||
"AzureTranslationProvider",
|
||||
"get_translation_provider",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -142,12 +141,12 @@ class AzureTranslationProvider(TranslationProvider):
|
|||
batch_size = min(100, self._config.batch_size)
|
||||
all_results = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
body = [{"text": text} for text in batch]
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
body = [{"text": text} for text in batch]
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
endpoint,
|
||||
params=params,
|
||||
|
|
@ -158,24 +157,24 @@ class AzureTranslationProvider(TranslationProvider):
|
|||
response.raise_for_status()
|
||||
results = await response.json()
|
||||
|
||||
for result in results:
|
||||
translation = result["translations"][0]
|
||||
detected_language = result.get("detectedLanguage", {})
|
||||
for result in results:
|
||||
translation = result["translations"][0]
|
||||
detected_language = result.get("detectedLanguage", {})
|
||||
|
||||
all_results.append(
|
||||
TranslationResult(
|
||||
translated_text=translation["text"],
|
||||
source_language=source_language
|
||||
or detected_language.get("language", "unknown"),
|
||||
target_language=target_language,
|
||||
confidence_score=detected_language.get("score", 0.9),
|
||||
provider=self.provider_name,
|
||||
raw_response=result,
|
||||
all_results.append(
|
||||
TranslationResult(
|
||||
translated_text=translation["text"],
|
||||
source_language=source_language
|
||||
or detected_language.get("language", "unknown"),
|
||||
target_language=target_language,
|
||||
confidence_score=detected_language.get("score", 0.9),
|
||||
provider=self.provider_name,
|
||||
raw_response=result,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Azure batch translation failed: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Azure batch translation failed: {e}")
|
||||
raise
|
||||
|
||||
return all_results
|
||||
|
|
|
|||
|
|
@ -64,6 +64,13 @@ class TranslationProvider(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Check if this provider is available (has required credentials)."""
|
||||
return True
|
||||
"""Check if this provider is available (has required credentials).
|
||||
|
||||
All providers must implement this method to validate their credentials.
|
||||
|
||||
Returns:
|
||||
True if the provider has valid credentials and is ready to use.
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ class GoogleTranslationProvider(TranslationProvider):
|
|||
try:
|
||||
self._get_client()
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.debug(f"Google Translate not available: {e}")
|
||||
return False
|
||||
|
||||
async def translate(
|
||||
|
|
@ -72,7 +73,7 @@ class GoogleTranslationProvider(TranslationProvider):
|
|||
client = self._get_client()
|
||||
|
||||
# Run in thread pool since google-cloud-translate is synchronous
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if source_language:
|
||||
result = await loop.run_in_executor(
|
||||
|
|
@ -122,7 +123,7 @@ class GoogleTranslationProvider(TranslationProvider):
|
|||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if source_language:
|
||||
results = await loop.run_in_executor(
|
||||
|
|
|
|||
|
|
@ -52,6 +52,15 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
try:
|
||||
system_prompt = read_query_prompt("translate_content.txt")
|
||||
|
||||
# Validate system prompt was loaded successfully
|
||||
if system_prompt is None:
|
||||
logger.warning("translate_content.txt prompt file not found, using default prompt")
|
||||
system_prompt = (
|
||||
"You are a professional translator. Translate the given text accurately "
|
||||
"while preserving the original meaning, tone, and style. "
|
||||
"Detect the source language if not provided."
|
||||
)
|
||||
|
||||
# Build the input with context
|
||||
if source_language:
|
||||
input_text = (
|
||||
|
|
@ -75,6 +84,8 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
translated_text=result.translated_text,
|
||||
source_language=source_language or result.detected_source_language,
|
||||
target_language=target_language,
|
||||
# TODO: Consider deriving confidence from LLM response metadata
|
||||
# or making configurable via TranslationConfig
|
||||
confidence_score=0.95, # LLM translations are generally high quality
|
||||
provider=self.provider_name,
|
||||
raw_response={"notes": result.translation_notes},
|
||||
|
|
@ -103,3 +114,9 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
"""
|
||||
tasks = [self.translate(text, target_language, source_language) for text in texts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if OpenAI provider is available (has required credentials)."""
|
||||
import os
|
||||
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
|
|
|
|||
|
|
@ -44,6 +44,13 @@ async def translate_content(
|
|||
Chunks that required translation will have TranslatedContent
|
||||
objects in their 'contains' list.
|
||||
|
||||
Note:
|
||||
This function mutates the input chunks in-place. Specifically:
|
||||
- chunk.text is replaced with the translated text
|
||||
- chunk.contains is updated with LanguageMetadata and TranslatedContent
|
||||
The original text is preserved in TranslatedContent.original_text
|
||||
if preserve_original=True.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from cognee.tasks.translation import translate_content
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
Unit tests for translation configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import get_args
|
||||
from cognee.tasks.translation.config import (
|
||||
get_translation_config,
|
||||
|
|
@ -15,9 +14,15 @@ def test_default_translation_config():
|
|||
"""Test default translation configuration"""
|
||||
config = get_translation_config()
|
||||
|
||||
assert isinstance(config, TranslationConfig)
|
||||
assert config.translation_provider in ["openai", "google", "azure"]
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0
|
||||
assert isinstance(config, TranslationConfig), "Config should be TranslationConfig instance"
|
||||
assert config.translation_provider in [
|
||||
"openai",
|
||||
"google",
|
||||
"azure",
|
||||
], f"Invalid provider: {config.translation_provider}"
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0, (
|
||||
f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]"
|
||||
)
|
||||
|
||||
|
||||
def test_translation_provider_type_literal():
|
||||
|
|
@ -25,17 +30,52 @@ def test_translation_provider_type_literal():
|
|||
# Get the allowed values from the Literal type
|
||||
allowed_values = get_args(TranslationProviderType)
|
||||
|
||||
assert "openai" in allowed_values
|
||||
assert "google" in allowed_values
|
||||
assert "azure" in allowed_values
|
||||
assert len(allowed_values) == 3
|
||||
assert "openai" in allowed_values, "openai should be an allowed provider"
|
||||
assert "google" in allowed_values, "google should be an allowed provider"
|
||||
assert "azure" in allowed_values, "azure should be an allowed provider"
|
||||
assert len(allowed_values) == 3, f"Expected 3 providers, got {len(allowed_values)}"
|
||||
|
||||
|
||||
def test_confidence_threshold_bounds():
|
||||
"""Test confidence threshold validation"""
|
||||
config = TranslationConfig(translation_provider="openai", confidence_threshold=0.9)
|
||||
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0, (
|
||||
f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]"
|
||||
)
|
||||
|
||||
|
||||
def test_confidence_threshold_validation():
|
||||
"""Test that invalid confidence thresholds are rejected or clamped"""
|
||||
# Test boundary values - these should work
|
||||
config_min = TranslationConfig(translation_provider="openai", confidence_threshold=0.0)
|
||||
assert config_min.confidence_threshold == 0.0, "Minimum bound (0.0) should be valid"
|
||||
|
||||
config_max = TranslationConfig(translation_provider="openai", confidence_threshold=1.0)
|
||||
assert config_max.confidence_threshold == 1.0, "Maximum bound (1.0) should be valid"
|
||||
|
||||
# Test invalid values - these should either raise ValidationError or be clamped
|
||||
try:
|
||||
config_invalid_low = TranslationConfig(
|
||||
translation_provider="openai", confidence_threshold=-0.1
|
||||
)
|
||||
# If no error, verify it was clamped to valid range
|
||||
assert 0.0 <= config_invalid_low.confidence_threshold <= 1.0, (
|
||||
f"Invalid low value should be clamped, got {config_invalid_low.confidence_threshold}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Expected validation error
|
||||
|
||||
try:
|
||||
config_invalid_high = TranslationConfig(
|
||||
translation_provider="openai", confidence_threshold=1.5
|
||||
)
|
||||
# If no error, verify it was clamped to valid range
|
||||
assert 0.0 <= config_invalid_high.confidence_threshold <= 1.0, (
|
||||
f"Invalid high value should be clamped, got {config_invalid_high.confidence_threshold}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Expected validation error
|
||||
|
||||
|
||||
def test_multiple_provider_keys():
|
||||
|
|
@ -46,21 +86,5 @@ def test_multiple_provider_keys():
|
|||
azure_translator_key="azure_key",
|
||||
)
|
||||
|
||||
assert config.google_translate_api_key == "google_key"
|
||||
assert config.azure_translator_key == "azure_key"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_default_translation_config()
|
||||
print("✓ test_default_translation_config passed")
|
||||
|
||||
test_translation_provider_type_literal()
|
||||
print("✓ test_translation_provider_type_literal passed")
|
||||
|
||||
test_confidence_threshold_bounds()
|
||||
print("✓ test_confidence_threshold_bounds passed")
|
||||
|
||||
test_multiple_provider_keys()
|
||||
print("✓ test_multiple_provider_keys passed")
|
||||
|
||||
print("\nAll config tests passed!")
|
||||
assert config.google_translate_api_key == "google_key", "Google API key not set correctly"
|
||||
assert config.azure_translator_key == "azure_key", "Azure API key not set correctly"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Unit tests for language detection functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from cognee.tasks.translation.detect_language import (
|
||||
detect_language_async,
|
||||
LanguageDetectionResult,
|
||||
|
|
@ -10,6 +10,7 @@ from cognee.tasks.translation.detect_language import (
|
|||
from cognee.tasks.translation.exceptions import LanguageDetectionError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_english():
|
||||
"""Test detection of English text"""
|
||||
result = await detect_language_async("Hello world, this is a test.", target_language="en")
|
||||
|
|
@ -20,6 +21,7 @@ async def test_detect_english():
|
|||
assert result.language_name == "English"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_spanish():
|
||||
"""Test detection of Spanish text"""
|
||||
result = await detect_language_async("Hola mundo, esta es una prueba.", target_language="en")
|
||||
|
|
@ -30,6 +32,7 @@ async def test_detect_spanish():
|
|||
assert result.language_name == "Spanish"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_french():
|
||||
"""Test detection of French text"""
|
||||
result = await detect_language_async(
|
||||
|
|
@ -42,6 +45,7 @@ async def test_detect_french():
|
|||
assert result.language_name == "French"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_german():
|
||||
"""Test detection of German text"""
|
||||
result = await detect_language_async("Hallo Welt, das ist ein Test.", target_language="en")
|
||||
|
|
@ -51,15 +55,17 @@ async def test_detect_german():
|
|||
assert result.confidence > 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_chinese():
|
||||
"""Test detection of Chinese text"""
|
||||
result = await detect_language_async("你好世界,这是一个测试。", target_language="en")
|
||||
|
||||
assert result.language_code == "zh-cn"
|
||||
assert result.language_code.startswith("zh"), f"Expected Chinese, got {result.language_code}"
|
||||
assert result.requires_translation is True
|
||||
assert result.confidence > 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_target_language():
|
||||
"""Test when text is already in target language"""
|
||||
result = await detect_language_async("This text is already in English.", target_language="en")
|
||||
|
|
@ -67,6 +73,7 @@ async def test_already_target_language():
|
|||
assert result.requires_translation is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_text():
|
||||
"""Test detection with very short text"""
|
||||
result = await detect_language_async("Hi", target_language="es")
|
||||
|
|
@ -76,6 +83,7 @@ async def test_short_text():
|
|||
assert result.character_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_text():
|
||||
"""Test detection with empty text - returns unknown by default"""
|
||||
result = await detect_language_async("", target_language="en")
|
||||
|
|
@ -88,6 +96,7 @@ async def test_empty_text():
|
|||
assert result.character_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confidence_threshold():
|
||||
"""Test detection respects confidence threshold"""
|
||||
result = await detect_language_async(
|
||||
|
|
@ -97,6 +106,7 @@ async def test_confidence_threshold():
|
|||
assert result.confidence >= 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_language_text():
|
||||
"""Test detection with mixed language text (predominantly one language)"""
|
||||
# Predominantly Spanish with English word
|
||||
|
|
@ -106,42 +116,3 @@ async def test_mixed_language_text():
|
|||
|
||||
assert result.language_code == "es" # Should detect as Spanish
|
||||
assert result.requires_translation is True
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all language detection tests"""
|
||||
await test_detect_english()
|
||||
print("✓ test_detect_english passed")
|
||||
|
||||
await test_detect_spanish()
|
||||
print("✓ test_detect_spanish passed")
|
||||
|
||||
await test_detect_french()
|
||||
print("✓ test_detect_french passed")
|
||||
|
||||
await test_detect_german()
|
||||
print("✓ test_detect_german passed")
|
||||
|
||||
await test_detect_chinese()
|
||||
print("✓ test_detect_chinese passed")
|
||||
|
||||
await test_already_target_language()
|
||||
print("✓ test_already_target_language passed")
|
||||
|
||||
await test_short_text()
|
||||
print("✓ test_short_text passed")
|
||||
|
||||
await test_empty_text()
|
||||
print("✓ test_empty_text passed")
|
||||
|
||||
await test_confidence_threshold()
|
||||
print("✓ test_confidence_threshold passed")
|
||||
|
||||
await test_mixed_language_text()
|
||||
print("✓ test_mixed_language_text passed")
|
||||
|
||||
print("\nAll language detection tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ Integration tests for multilingual content translation feature.
|
|||
Tests the full cognify pipeline with translation enabled.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee import add, cognify, prune, search, SearchType
|
||||
from cognee.tasks.translation import translate_text
|
||||
from cognee.tasks.translation.detect_language import detect_language_async
|
||||
|
|
@ -17,12 +18,10 @@ def has_openai_key():
|
|||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_quick_translation():
|
||||
"""Quick smoke test for translation feature"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -39,12 +38,10 @@ async def test_quick_translation():
|
|||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_basic():
|
||||
"""Test basic translation functionality with English text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -67,12 +64,10 @@ async def test_translation_basic():
|
|||
assert search_results is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_spanish():
|
||||
"""Test translation with Spanish text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -100,12 +95,10 @@ async def test_translation_spanish():
|
|||
assert search_results is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_french():
|
||||
"""Test translation with French text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -133,12 +126,10 @@ async def test_translation_french():
|
|||
assert search_results is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_disabled():
|
||||
"""Test that cognify works without translation"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -153,12 +144,10 @@ async def test_translation_disabled():
|
|||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_mixed_languages():
|
||||
"""Test with multiple documents in different languages"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
|
|
@ -186,12 +175,10 @@ async def test_translation_mixed_languages():
|
|||
assert search_results is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_direct_translation_function():
|
||||
"""Test the translate_text convenience function directly"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
result = await translate_text(
|
||||
text="Hola, ¿cómo estás? Espero que tengas un buen día.",
|
||||
target_language="en",
|
||||
|
|
@ -204,6 +191,7 @@ async def test_direct_translation_function():
|
|||
assert result.provider == "openai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_language_detection():
|
||||
"""Test language detection directly"""
|
||||
test_texts = [
|
||||
|
|
@ -220,36 +208,3 @@ async def test_language_detection():
|
|||
# Only check requires_translation for high-confidence detections
|
||||
if result.confidence > 0.8:
|
||||
assert result.requires_translation == should_translate
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all translation integration tests"""
|
||||
await test_quick_translation()
|
||||
print("✓ test_quick_translation passed")
|
||||
|
||||
await test_language_detection()
|
||||
print("✓ test_language_detection passed")
|
||||
|
||||
await test_direct_translation_function()
|
||||
print("✓ test_direct_translation_function passed")
|
||||
|
||||
await test_translation_basic()
|
||||
print("✓ test_translation_basic passed")
|
||||
|
||||
await test_translation_spanish()
|
||||
print("✓ test_translation_spanish passed")
|
||||
|
||||
await test_translation_french()
|
||||
print("✓ test_translation_french passed")
|
||||
|
||||
await test_translation_disabled()
|
||||
print("✓ test_translation_disabled passed")
|
||||
|
||||
await test_translation_mixed_languages()
|
||||
print("✓ test_translation_mixed_languages passed")
|
||||
|
||||
print("\nAll translation integration tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
Unit tests for translation providers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.tasks.translation.providers import (
|
||||
get_translation_provider,
|
||||
OpenAITranslationProvider,
|
||||
|
|
@ -17,12 +19,10 @@ def has_openai_key():
|
|||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_provider_basic_translation():
|
||||
"""Test basic translation with OpenAI provider"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
result = await provider.translate(text="Hola mundo", target_language="en", source_language="es")
|
||||
|
|
@ -35,12 +35,10 @@ async def test_openai_provider_basic_translation():
|
|||
assert result.provider == "openai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_provider_auto_detect_source():
|
||||
"""Test translation with automatic source language detection"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
result = await provider.translate(
|
||||
|
|
@ -53,12 +51,10 @@ async def test_openai_provider_auto_detect_source():
|
|||
assert result.target_language == "en"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_provider_long_text():
|
||||
"""Test translation of longer text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
long_text = """
|
||||
|
|
@ -88,12 +84,10 @@ def test_get_translation_provider_invalid():
|
|||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_batch_translation():
|
||||
"""Test batch translation with OpenAI provider"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
texts = ["Hola", "¿Cómo estás?", "Adiós"]
|
||||
|
|
@ -110,12 +104,10 @@ async def test_openai_batch_translation():
|
|||
assert result.target_language == "en"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_preserves_formatting():
|
||||
"""Test that translation preserves basic formatting"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
text_with_newlines = "Primera línea.\nSegunda línea."
|
||||
|
|
@ -129,12 +121,10 @@ async def test_translation_preserves_formatting():
|
|||
assert len(result.translated_text) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translation_special_characters():
|
||||
"""Test translation with special characters"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
text = "¡Hola! ¿Cómo estás? Está bien."
|
||||
|
|
@ -145,12 +135,10 @@ async def test_translation_special_characters():
|
|||
assert len(result.translated_text) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_empty_text_translation():
|
||||
"""Test translation with empty text - should return empty or handle gracefully"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
# Empty text may either raise an error or return an empty result
|
||||
|
|
@ -161,41 +149,3 @@ async def test_empty_text_translation():
|
|||
except TranslationError:
|
||||
# This is also acceptable behavior
|
||||
pass
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all provider tests"""
|
||||
# Sync tests
|
||||
test_get_translation_provider_factory()
|
||||
print("✓ test_get_translation_provider_factory passed")
|
||||
|
||||
test_get_translation_provider_invalid()
|
||||
print("✓ test_get_translation_provider_invalid passed")
|
||||
|
||||
# Async tests
|
||||
await test_openai_provider_basic_translation()
|
||||
print("✓ test_openai_provider_basic_translation passed")
|
||||
|
||||
await test_openai_provider_auto_detect_source()
|
||||
print("✓ test_openai_provider_auto_detect_source passed")
|
||||
|
||||
await test_openai_provider_long_text()
|
||||
print("✓ test_openai_provider_long_text passed")
|
||||
|
||||
await test_openai_batch_translation()
|
||||
print("✓ test_openai_batch_translation passed")
|
||||
|
||||
await test_translation_preserves_formatting()
|
||||
print("✓ test_translation_preserves_formatting passed")
|
||||
|
||||
await test_translation_special_characters()
|
||||
print("✓ test_translation_special_characters passed")
|
||||
|
||||
await test_empty_text_translation()
|
||||
print("✓ test_empty_text_translation passed")
|
||||
|
||||
print("\nAll provider tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@
|
|||
Unit tests for translate_content task
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.tasks.translation import translate_content
|
||||
|
|
@ -37,12 +39,10 @@ def create_test_chunk(text: str, chunk_index: int = 0):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translate_content_basic():
|
||||
"""Test basic content translation"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Create test chunk with Spanish text
|
||||
original_text = "Hola mundo, esta es una prueba."
|
||||
chunk = create_test_chunk(original_text)
|
||||
|
|
@ -61,12 +61,10 @@ async def test_translate_content_basic():
|
|||
assert has_translated_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translate_content_preserves_original():
|
||||
"""Test that original text is preserved"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
original_text = "Bonjour le monde"
|
||||
chunk = create_test_chunk(original_text)
|
||||
|
||||
|
|
@ -86,6 +84,7 @@ async def test_translate_content_preserves_original():
|
|||
assert translated_content.translated_text != original_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_translate_content_skip_english():
|
||||
"""Test skipping translation for English text"""
|
||||
# This test doesn't require API call since English text is skipped
|
||||
|
|
@ -110,12 +109,10 @@ async def test_translate_content_skip_english():
|
|||
assert not has_translated_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translate_content_multiple_chunks():
|
||||
"""Test translation of multiple chunks"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use longer texts to ensure reliable language detection
|
||||
original_texts = [
|
||||
"Hola mundo, esta es una prueba de traducción.",
|
||||
|
|
@ -136,6 +133,7 @@ async def test_translate_content_multiple_chunks():
|
|||
assert translated_count >= 2 # At least 2 chunks should be translated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_translate_content_empty_list():
|
||||
"""Test with empty chunk list"""
|
||||
result = await translate_content(data_chunks=[], target_language="en")
|
||||
|
|
@ -143,6 +141,7 @@ async def test_translate_content_empty_list():
|
|||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_translate_content_empty_text():
|
||||
"""Test with chunk containing empty text"""
|
||||
chunk = create_test_chunk("")
|
||||
|
|
@ -153,12 +152,10 @@ async def test_translate_content_empty_text():
|
|||
assert result[0].text == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translate_content_language_metadata():
|
||||
"""Test that LanguageMetadata is created correctly"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use a longer, distinctly Spanish text to ensure reliable detection
|
||||
chunk = create_test_chunk(
|
||||
"La inteligencia artificial está cambiando el mundo de manera significativa"
|
||||
|
|
@ -180,12 +177,10 @@ async def test_translate_content_language_metadata():
|
|||
assert language_metadata.language_confidence > 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translate_content_confidence_threshold():
|
||||
"""Test with custom confidence threshold"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use longer text for more reliable detection
|
||||
chunk = create_test_chunk("Hola mundo, esta es una frase más larga para mejor detección")
|
||||
|
||||
|
|
@ -196,12 +191,10 @@ async def test_translate_content_confidence_threshold():
|
|||
assert len(result) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_translate_content_no_preserve_original():
|
||||
"""Test translation without preserving original"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use longer text for more reliable detection
|
||||
chunk = create_test_chunk("Bonjour le monde, comment allez-vous aujourd'hui")
|
||||
|
||||
|
|
@ -218,39 +211,3 @@ async def test_translate_content_no_preserve_original():
|
|||
|
||||
assert translated_content is not None
|
||||
assert translated_content.original_text == "" # Should be empty
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all translate_content tests"""
|
||||
await test_translate_content_basic()
|
||||
print("✓ test_translate_content_basic passed")
|
||||
|
||||
await test_translate_content_preserves_original()
|
||||
print("✓ test_translate_content_preserves_original passed")
|
||||
|
||||
await test_translate_content_skip_english()
|
||||
print("✓ test_translate_content_skip_english passed")
|
||||
|
||||
await test_translate_content_multiple_chunks()
|
||||
print("✓ test_translate_content_multiple_chunks passed")
|
||||
|
||||
await test_translate_content_empty_list()
|
||||
print("✓ test_translate_content_empty_list passed")
|
||||
|
||||
await test_translate_content_empty_text()
|
||||
print("✓ test_translate_content_empty_text passed")
|
||||
|
||||
await test_translate_content_language_metadata()
|
||||
print("✓ test_translate_content_language_metadata passed")
|
||||
|
||||
await test_translate_content_confidence_threshold()
|
||||
print("✓ test_translate_content_confidence_threshold passed")
|
||||
|
||||
await test_translate_content_no_preserve_original()
|
||||
print("✓ test_translate_content_no_preserve_original passed")
|
||||
|
||||
print("\nAll translate_content tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue