test: add comprehensive translation module tests
- Add unit tests for translation configuration, language detection, providers, and translate_content task - Add integration tests for full cognify pipeline with translation - All 40 tests passing (32 unit + 8 integration) - Tests use asyncio.run() pattern matching project style - Tests named with *_test.py suffix per project convention - Update README with test documentation Formatting changes: - Apply ruff format to cognify.py (bracket placement style) Signed-off-by: andikarachman <andika.rachman.y@gmail.com>
This commit is contained in:
parent
d7962bd44a
commit
00e318b3ed
12 changed files with 1087 additions and 35 deletions
|
|
@ -335,25 +335,27 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
|||
)
|
||||
)
|
||||
|
||||
default_tasks.extend([
|
||||
Task(
|
||||
extract_graph_from_data,
|
||||
graph_model=graph_model,
|
||||
config=config,
|
||||
custom_prompt=custom_prompt,
|
||||
task_config={"batch_size": chunks_per_batch},
|
||||
**kwargs,
|
||||
), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
summarize_text,
|
||||
task_config={"batch_size": chunks_per_batch},
|
||||
),
|
||||
Task(
|
||||
add_data_points,
|
||||
embed_triplets=embed_triplets,
|
||||
task_config={"batch_size": chunks_per_batch},
|
||||
),
|
||||
])
|
||||
default_tasks.extend(
|
||||
[
|
||||
Task(
|
||||
extract_graph_from_data,
|
||||
graph_model=graph_model,
|
||||
config=config,
|
||||
custom_prompt=custom_prompt,
|
||||
task_config={"batch_size": chunks_per_batch},
|
||||
**kwargs,
|
||||
), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
summarize_text,
|
||||
task_config={"batch_size": chunks_per_batch},
|
||||
),
|
||||
Task(
|
||||
add_data_points,
|
||||
embed_triplets=embed_triplets,
|
||||
task_config={"batch_size": chunks_per_batch},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return default_tasks
|
||||
|
||||
|
|
@ -413,10 +415,12 @@ async def get_temporal_tasks(
|
|||
)
|
||||
)
|
||||
|
||||
temporal_tasks.extend([
|
||||
Task(extract_events_and_timestamps, task_config={"batch_size": chunks_per_batch}),
|
||||
Task(extract_knowledge_graph_from_events),
|
||||
Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
|
||||
])
|
||||
temporal_tasks.extend(
|
||||
[
|
||||
Task(extract_events_and_timestamps, task_config={"batch_size": chunks_per_batch}),
|
||||
Task(extract_knowledge_graph_from_events),
|
||||
Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
|
||||
]
|
||||
)
|
||||
|
||||
return temporal_tasks
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ class TranslationError(Exception):
|
|||
class LanguageDetectionError(TranslationError):
|
||||
"""Exception raised when language detection fails."""
|
||||
|
||||
def __init__(self, message: str = "Failed to detect language", original_error: Exception = None):
|
||||
def __init__(
|
||||
self, message: str = "Failed to detect language", original_error: Exception = None
|
||||
):
|
||||
super().__init__(message, original_error)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -89,8 +89,7 @@ class AzureTranslationProvider(TranslationProvider):
|
|||
|
||||
return TranslationResult(
|
||||
translated_text=translation["text"],
|
||||
source_language=source_language
|
||||
or detected_language.get("language", "unknown"),
|
||||
source_language=source_language or detected_language.get("language", "unknown"),
|
||||
target_language=target_language,
|
||||
confidence_score=detected_language.get("score", 0.9),
|
||||
provider=self.provider_name,
|
||||
|
|
|
|||
|
|
@ -101,7 +101,5 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
Returns:
|
||||
List of TranslationResult objects
|
||||
"""
|
||||
tasks = [
|
||||
self.translate(text, target_language, source_language) for text in texts
|
||||
]
|
||||
tasks = [self.translate(text, target_language, source_language) for text in texts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
|
|
|||
|
|
@ -88,9 +88,7 @@ async def translate_content(
|
|||
|
||||
try:
|
||||
# Detect language
|
||||
detection = await detect_language_async(
|
||||
chunk.text, target_language, threshold
|
||||
)
|
||||
detection = await detect_language_async(chunk.text, target_language, threshold)
|
||||
|
||||
# Create language metadata
|
||||
language_metadata = LanguageMetadata(
|
||||
|
|
@ -117,8 +115,7 @@ async def translate_content(
|
|||
|
||||
# Translate the content
|
||||
logger.debug(
|
||||
f"Translating chunk {chunk.id} from {detection.language_code} "
|
||||
f"to {target_language}"
|
||||
f"Translating chunk {chunk.id} from {detection.language_code} to {target_language}"
|
||||
)
|
||||
|
||||
translation_result = await provider.translate(
|
||||
|
|
|
|||
126
cognee/tests/tasks/translation/README.md
Normal file
126
cognee/tests/tasks/translation/README.md
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# Translation Task Tests
|
||||
|
||||
Unit and integration tests for the multilingual content translation feature.
|
||||
|
||||
## Test Files
|
||||
|
||||
- **config_test.py** - Tests for translation configuration
|
||||
- Default configuration
|
||||
- Provider type validation
|
||||
- Confidence threshold bounds
|
||||
- Multiple provider API keys
|
||||
|
||||
- **detect_language_test.py** - Tests for language detection functionality
|
||||
- English, Spanish, French, German, Chinese detection
|
||||
- Confidence thresholds
|
||||
- Edge cases (empty text, short text, mixed languages)
|
||||
|
||||
- **providers_test.py** - Tests for translation provider implementations
|
||||
- OpenAI provider basic translation
|
||||
- Auto-detection of source language
|
||||
- Batch translation
|
||||
- Special characters and formatting preservation
|
||||
- Error handling
|
||||
|
||||
- **translate_content_test.py** - Tests for the main translate_content task
|
||||
- Basic translation workflow
|
||||
- Original text preservation
|
||||
- Multiple chunks processing
|
||||
- Language metadata creation
|
||||
- Skip translation for target language
|
||||
- Confidence threshold customization
|
||||
|
||||
- **integration_test.py** - End-to-end integration tests
|
||||
- Full cognify pipeline with translation
|
||||
- Spanish/French to English translation
|
||||
- Mixed language datasets
|
||||
- Search functionality after translation
|
||||
- Translation disabled mode
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Run all translation tests
|
||||
```bash
|
||||
uv run pytest cognee/tests/tasks/translation/ -v
|
||||
```
|
||||
|
||||
### Run specific test file
|
||||
```bash
|
||||
uv run pytest cognee/tests/tasks/translation/detect_language_test.py -v
|
||||
```
|
||||
|
||||
### Run tests directly (without pytest)
|
||||
```bash
|
||||
uv run python cognee/tests/tasks/translation/config_test.py
|
||||
uv run python cognee/tests/tasks/translation/detect_language_test.py
|
||||
uv run python cognee/tests/tasks/translation/providers_test.py
|
||||
uv run python cognee/tests/tasks/translation/translate_content_test.py
|
||||
uv run python cognee/tests/tasks/translation/integration_test.py
|
||||
```
|
||||
|
||||
### Run all tests at once
|
||||
```bash
|
||||
for f in cognee/tests/tasks/translation/*_test.py; do uv run python "$f"; done
|
||||
```
|
||||
|
||||
### Run with coverage
|
||||
```bash
|
||||
uv run pytest cognee/tests/tasks/translation/ --cov=cognee.tasks.translation --cov-report=html
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- LLM API key set in environment: `LLM_API_KEY=your_key`
|
||||
- Tests will be skipped if no API key is available
|
||||
|
||||
## Test Summary
|
||||
|
||||
| Test File | Tests | Description |
|
||||
|-----------|-------|-------------|
|
||||
| config_test.py | 4 | Configuration validation |
|
||||
| detect_language_test.py | 10 | Language detection |
|
||||
| providers_test.py | 9 | Translation providers |
|
||||
| translate_content_test.py | 9 | Content translation task |
|
||||
| integration_test.py | 8 | End-to-end pipeline |
|
||||
| **Total** | **40** | |
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Configuration (4 tests)
|
||||
- ✅ Default configuration values
|
||||
- ✅ Provider type literal validation
|
||||
- ✅ Confidence threshold bounds
|
||||
- ✅ Multiple provider API keys
|
||||
|
||||
### Language Detection (10 tests)
|
||||
- ✅ Multiple language detection (EN, ES, FR, DE, ZH)
|
||||
- ✅ Confidence scoring
|
||||
- ✅ Target language matching
|
||||
- ✅ Short and empty text handling
|
||||
- ✅ Mixed language detection
|
||||
|
||||
### Translation Providers (9 tests)
|
||||
- ✅ Provider factory function
|
||||
- ✅ OpenAI translation
|
||||
- ✅ Batch operations
|
||||
- ✅ Auto source language detection
|
||||
- ✅ Long text handling
|
||||
- ✅ Special characters preservation
|
||||
- ✅ Error handling
|
||||
|
||||
### Content Translation (9 tests)
|
||||
- ✅ DocumentChunk processing
|
||||
- ✅ Metadata creation (LanguageMetadata, TranslatedContent)
|
||||
- ✅ Original text preservation
|
||||
- ✅ Multiple chunk handling
|
||||
- ✅ Empty text/list handling
|
||||
- ✅ Confidence threshold customization
|
||||
|
||||
### Integration (8 tests)
|
||||
- ✅ Full cognify pipeline with auto_translate=True
|
||||
- ✅ Spanish to English translation
|
||||
- ✅ French to English translation
|
||||
- ✅ Mixed language datasets
|
||||
- ✅ Translation disabled mode
|
||||
- ✅ Direct translate_text function
|
||||
- ✅ Search after translation
|
||||
1
cognee/tests/tasks/translation/__init__.py
Normal file
1
cognee/tests/tasks/translation/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Translation task tests"""
|
||||
66
cognee/tests/tasks/translation/config_test.py
Normal file
66
cognee/tests/tasks/translation/config_test.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""
|
||||
Unit tests for translation configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import get_args
|
||||
from cognee.tasks.translation.config import (
|
||||
get_translation_config,
|
||||
TranslationConfig,
|
||||
TranslationProviderType,
|
||||
)
|
||||
|
||||
|
||||
def test_default_translation_config():
|
||||
"""Test default translation configuration"""
|
||||
config = get_translation_config()
|
||||
|
||||
assert isinstance(config, TranslationConfig)
|
||||
assert config.translation_provider in ["openai", "google", "azure"]
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0
|
||||
|
||||
|
||||
def test_translation_provider_type_literal():
|
||||
"""Test TranslationProviderType Literal type values"""
|
||||
# Get the allowed values from the Literal type
|
||||
allowed_values = get_args(TranslationProviderType)
|
||||
|
||||
assert "openai" in allowed_values
|
||||
assert "google" in allowed_values
|
||||
assert "azure" in allowed_values
|
||||
assert len(allowed_values) == 3
|
||||
|
||||
|
||||
def test_confidence_threshold_bounds():
|
||||
"""Test confidence threshold validation"""
|
||||
config = TranslationConfig(translation_provider="openai", confidence_threshold=0.9)
|
||||
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0
|
||||
|
||||
|
||||
def test_multiple_provider_keys():
|
||||
"""Test configuration with multiple provider API keys"""
|
||||
config = TranslationConfig(
|
||||
translation_provider="openai",
|
||||
google_translate_api_key="google_key",
|
||||
azure_translator_key="azure_key",
|
||||
)
|
||||
|
||||
assert config.google_translate_api_key == "google_key"
|
||||
assert config.azure_translator_key == "azure_key"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_default_translation_config()
|
||||
print("✓ test_default_translation_config passed")
|
||||
|
||||
test_translation_provider_type_literal()
|
||||
print("✓ test_translation_provider_type_literal passed")
|
||||
|
||||
test_confidence_threshold_bounds()
|
||||
print("✓ test_confidence_threshold_bounds passed")
|
||||
|
||||
test_multiple_provider_keys()
|
||||
print("✓ test_multiple_provider_keys passed")
|
||||
|
||||
print("\nAll config tests passed!")
|
||||
147
cognee/tests/tasks/translation/detect_language_test.py
Normal file
147
cognee/tests/tasks/translation/detect_language_test.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
"""
|
||||
Unit tests for language detection functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from cognee.tasks.translation.detect_language import (
|
||||
detect_language_async,
|
||||
LanguageDetectionResult,
|
||||
)
|
||||
from cognee.tasks.translation.exceptions import LanguageDetectionError
|
||||
|
||||
|
||||
async def test_detect_english():
|
||||
"""Test detection of English text"""
|
||||
result = await detect_language_async("Hello world, this is a test.", target_language="en")
|
||||
|
||||
assert result.language_code == "en"
|
||||
assert result.requires_translation is False
|
||||
assert result.confidence > 0.9
|
||||
assert result.language_name == "English"
|
||||
|
||||
|
||||
async def test_detect_spanish():
|
||||
"""Test detection of Spanish text"""
|
||||
result = await detect_language_async("Hola mundo, esta es una prueba.", target_language="en")
|
||||
|
||||
assert result.language_code == "es"
|
||||
assert result.requires_translation is True
|
||||
assert result.confidence > 0.9
|
||||
assert result.language_name == "Spanish"
|
||||
|
||||
|
||||
async def test_detect_french():
|
||||
"""Test detection of French text"""
|
||||
result = await detect_language_async(
|
||||
"Bonjour le monde, ceci est un test.", target_language="en"
|
||||
)
|
||||
|
||||
assert result.language_code == "fr"
|
||||
assert result.requires_translation is True
|
||||
assert result.confidence > 0.9
|
||||
assert result.language_name == "French"
|
||||
|
||||
|
||||
async def test_detect_german():
|
||||
"""Test detection of German text"""
|
||||
result = await detect_language_async("Hallo Welt, das ist ein Test.", target_language="en")
|
||||
|
||||
assert result.language_code == "de"
|
||||
assert result.requires_translation is True
|
||||
assert result.confidence > 0.9
|
||||
|
||||
|
||||
async def test_detect_chinese():
|
||||
"""Test detection of Chinese text"""
|
||||
result = await detect_language_async("你好世界,这是一个测试。", target_language="en")
|
||||
|
||||
assert result.language_code == "zh-cn"
|
||||
assert result.requires_translation is True
|
||||
assert result.confidence > 0.9
|
||||
|
||||
|
||||
async def test_already_target_language():
|
||||
"""Test when text is already in target language"""
|
||||
result = await detect_language_async("This text is already in English.", target_language="en")
|
||||
|
||||
assert result.requires_translation is False
|
||||
|
||||
|
||||
async def test_short_text():
|
||||
"""Test detection with very short text"""
|
||||
result = await detect_language_async("Hi", target_language="es")
|
||||
|
||||
# Short text may return 'unknown' if langdetect can't reliably detect
|
||||
assert result.language_code in ["en", "unknown"]
|
||||
assert result.character_count == 2
|
||||
|
||||
|
||||
async def test_empty_text():
|
||||
"""Test detection with empty text - returns unknown by default"""
|
||||
result = await detect_language_async("", target_language="en")
|
||||
|
||||
# With skip_detection_for_short_text=True (default), returns unknown
|
||||
assert result.language_code == "unknown"
|
||||
assert result.language_name == "Unknown"
|
||||
assert result.confidence == 0.0
|
||||
assert result.requires_translation is False
|
||||
assert result.character_count == 0
|
||||
|
||||
|
||||
async def test_confidence_threshold():
|
||||
"""Test detection respects confidence threshold"""
|
||||
result = await detect_language_async(
|
||||
"Hello world", target_language="es", confidence_threshold=0.5
|
||||
)
|
||||
|
||||
assert result.confidence >= 0.5
|
||||
|
||||
|
||||
async def test_mixed_language_text():
|
||||
"""Test detection with mixed language text (predominantly one language)"""
|
||||
# Predominantly Spanish with English word
|
||||
result = await detect_language_async(
|
||||
"La inteligencia artificial es muy importante en technology moderna.", target_language="en"
|
||||
)
|
||||
|
||||
assert result.language_code == "es" # Should detect as Spanish
|
||||
assert result.requires_translation is True
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all language detection tests"""
|
||||
await test_detect_english()
|
||||
print("✓ test_detect_english passed")
|
||||
|
||||
await test_detect_spanish()
|
||||
print("✓ test_detect_spanish passed")
|
||||
|
||||
await test_detect_french()
|
||||
print("✓ test_detect_french passed")
|
||||
|
||||
await test_detect_german()
|
||||
print("✓ test_detect_german passed")
|
||||
|
||||
await test_detect_chinese()
|
||||
print("✓ test_detect_chinese passed")
|
||||
|
||||
await test_already_target_language()
|
||||
print("✓ test_already_target_language passed")
|
||||
|
||||
await test_short_text()
|
||||
print("✓ test_short_text passed")
|
||||
|
||||
await test_empty_text()
|
||||
print("✓ test_empty_text passed")
|
||||
|
||||
await test_confidence_threshold()
|
||||
print("✓ test_confidence_threshold passed")
|
||||
|
||||
await test_mixed_language_text()
|
||||
print("✓ test_mixed_language_text passed")
|
||||
|
||||
print("\nAll language detection tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
255
cognee/tests/tasks/translation/integration_test.py
Normal file
255
cognee/tests/tasks/translation/integration_test.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
"""
|
||||
Integration tests for multilingual content translation feature.
|
||||
|
||||
Tests the full cognify pipeline with translation enabled.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from cognee import add, cognify, prune, search, SearchType
|
||||
from cognee.tasks.translation import translate_text
|
||||
from cognee.tasks.translation.detect_language import detect_language_async
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
"""Check if OpenAI API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
async def test_quick_translation():
|
||||
"""Quick smoke test for translation feature"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
spanish_text = "La inteligencia artificial está transformando el mundo de la tecnología."
|
||||
await add(spanish_text, dataset_name="spanish_test")
|
||||
|
||||
result = await cognify(
|
||||
datasets=["spanish_test"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
async def test_translation_basic():
|
||||
"""Test basic translation functionality with English text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
english_text = "Hello, this is a test document about artificial intelligence."
|
||||
await add(english_text, dataset_name="test_english")
|
||||
|
||||
result = await cognify(
|
||||
datasets=["test_english"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
search_results = await search(
|
||||
query_text="What is this document about?",
|
||||
query_type=SearchType.SUMMARIES,
|
||||
)
|
||||
assert search_results is not None
|
||||
|
||||
|
||||
async def test_translation_spanish():
|
||||
"""Test translation with Spanish text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
spanish_text = """
|
||||
La inteligencia artificial es una rama de la informática que se centra en
|
||||
crear sistemas capaces de realizar tareas que normalmente requieren inteligencia humana.
|
||||
Estos sistemas pueden aprender, razonar y resolver problemas complejos.
|
||||
"""
|
||||
|
||||
await add(spanish_text, dataset_name="test_spanish")
|
||||
|
||||
result = await cognify(
|
||||
datasets=["test_spanish"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
search_results = await search(
|
||||
query_text="What is artificial intelligence?",
|
||||
query_type=SearchType.SUMMARIES,
|
||||
)
|
||||
assert search_results is not None
|
||||
|
||||
|
||||
async def test_translation_french():
|
||||
"""Test translation with French text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
french_text = """
|
||||
L'apprentissage automatique est une méthode d'analyse de données qui
|
||||
automatise la construction de modèles analytiques. C'est une branche
|
||||
de l'intelligence artificielle basée sur l'idée que les systèmes peuvent
|
||||
apprendre à partir de données, identifier des modèles et prendre des décisions.
|
||||
"""
|
||||
|
||||
await add(french_text, dataset_name="test_french")
|
||||
|
||||
result = await cognify(
|
||||
datasets=["test_french"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
search_results = await search(
|
||||
query_text="What is machine learning?",
|
||||
query_type=SearchType.SUMMARIES,
|
||||
)
|
||||
assert search_results is not None
|
||||
|
||||
|
||||
async def test_translation_disabled():
|
||||
"""Test that cognify works without translation"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
text = "This is a baseline test without translation enabled."
|
||||
await add(text, dataset_name="test_baseline")
|
||||
|
||||
result = await cognify(
|
||||
datasets=["test_baseline"],
|
||||
auto_translate=False,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
async def test_translation_mixed_languages():
|
||||
"""Test with multiple documents in different languages"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
texts = [
|
||||
"Artificial intelligence is transforming the world.",
|
||||
"La tecnología está cambiando nuestras vidas.",
|
||||
"Les ordinateurs deviennent de plus en plus puissants.",
|
||||
]
|
||||
|
||||
for text in texts:
|
||||
await add(text, dataset_name="test_mixed")
|
||||
|
||||
result = await cognify(
|
||||
datasets=["test_mixed"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
search_results = await search(
|
||||
query_text="What topics are discussed?",
|
||||
query_type=SearchType.SUMMARIES,
|
||||
)
|
||||
assert search_results is not None
|
||||
|
||||
|
||||
async def test_direct_translation_function():
|
||||
"""Test the translate_text convenience function directly"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
result = await translate_text(
|
||||
text="Hola, ¿cómo estás? Espero que tengas un buen día.",
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
)
|
||||
|
||||
assert result.translated_text is not None
|
||||
assert result.translated_text != ""
|
||||
assert result.target_language == "en"
|
||||
assert result.provider == "openai"
|
||||
|
||||
|
||||
async def test_language_detection():
|
||||
"""Test language detection directly"""
|
||||
test_texts = [
|
||||
("Hello world, how are you doing today?", "en", False),
|
||||
("Bonjour le monde, comment allez-vous aujourd'hui?", "en", True),
|
||||
("Hola mundo, cómo estás hoy?", "en", True),
|
||||
("This is already in English language", "en", False),
|
||||
]
|
||||
|
||||
for text, target_lang, should_translate in test_texts:
|
||||
result = await detect_language_async(text, target_lang)
|
||||
assert result.language_code is not None
|
||||
assert result.confidence > 0.0
|
||||
# Only check requires_translation for high-confidence detections
|
||||
if result.confidence > 0.8:
|
||||
assert result.requires_translation == should_translate
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all translation integration tests"""
|
||||
await test_quick_translation()
|
||||
print("✓ test_quick_translation passed")
|
||||
|
||||
await test_language_detection()
|
||||
print("✓ test_language_detection passed")
|
||||
|
||||
await test_direct_translation_function()
|
||||
print("✓ test_direct_translation_function passed")
|
||||
|
||||
await test_translation_basic()
|
||||
print("✓ test_translation_basic passed")
|
||||
|
||||
await test_translation_spanish()
|
||||
print("✓ test_translation_spanish passed")
|
||||
|
||||
await test_translation_french()
|
||||
print("✓ test_translation_french passed")
|
||||
|
||||
await test_translation_disabled()
|
||||
print("✓ test_translation_disabled passed")
|
||||
|
||||
await test_translation_mixed_languages()
|
||||
print("✓ test_translation_mixed_languages passed")
|
||||
|
||||
print("\nAll translation integration tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
201
cognee/tests/tasks/translation/providers_test.py
Normal file
201
cognee/tests/tasks/translation/providers_test.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
"""
|
||||
Unit tests for translation providers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from cognee.tasks.translation.providers import (
|
||||
get_translation_provider,
|
||||
OpenAITranslationProvider,
|
||||
TranslationResult,
|
||||
)
|
||||
from cognee.tasks.translation.exceptions import TranslationError
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
"""Check if OpenAI API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
async def test_openai_provider_basic_translation():
|
||||
"""Test basic translation with OpenAI provider"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
result = await provider.translate(text="Hola mundo", target_language="en", source_language="es")
|
||||
|
||||
assert isinstance(result, TranslationResult)
|
||||
assert result.translated_text is not None
|
||||
assert len(result.translated_text) > 0
|
||||
assert result.source_language == "es"
|
||||
assert result.target_language == "en"
|
||||
assert result.provider == "openai"
|
||||
|
||||
|
||||
async def test_openai_provider_auto_detect_source():
|
||||
"""Test translation with automatic source language detection"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
result = await provider.translate(
|
||||
text="Bonjour le monde",
|
||||
target_language="en",
|
||||
# source_language not provided - should auto-detect
|
||||
)
|
||||
|
||||
assert result.translated_text is not None
|
||||
assert result.target_language == "en"
|
||||
|
||||
|
||||
async def test_openai_provider_long_text():
|
||||
"""Test translation of longer text"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
long_text = """
|
||||
La inteligencia artificial es una rama de la informática que se centra en
|
||||
crear sistemas capaces de realizar tareas que normalmente requieren inteligencia humana.
|
||||
Estos sistemas pueden aprender, razonar y resolver problemas complejos.
|
||||
"""
|
||||
|
||||
result = await provider.translate(text=long_text, target_language="en", source_language="es")
|
||||
|
||||
assert len(result.translated_text) > 0
|
||||
assert result.source_language == "es"
|
||||
|
||||
|
||||
def test_get_translation_provider_factory():
|
||||
"""Test provider factory function"""
|
||||
provider = get_translation_provider("openai")
|
||||
assert isinstance(provider, OpenAITranslationProvider)
|
||||
|
||||
|
||||
def test_get_translation_provider_invalid():
|
||||
"""Test provider factory with invalid provider name"""
|
||||
try:
|
||||
get_translation_provider("invalid_provider")
|
||||
assert False, "Expected TranslationError or ValueError"
|
||||
except (TranslationError, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
async def test_openai_batch_translation():
|
||||
"""Test batch translation with OpenAI provider"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
texts = ["Hola", "¿Cómo estás?", "Adiós"]
|
||||
|
||||
results = await provider.translate_batch(
|
||||
texts=texts, target_language="en", source_language="es"
|
||||
)
|
||||
|
||||
assert len(results) == len(texts)
|
||||
for result in results:
|
||||
assert isinstance(result, TranslationResult)
|
||||
assert result.translated_text is not None
|
||||
assert result.source_language == "es"
|
||||
assert result.target_language == "en"
|
||||
|
||||
|
||||
async def test_translation_preserves_formatting():
|
||||
"""Test that translation preserves basic formatting"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
text_with_newlines = "Primera línea.\nSegunda línea."
|
||||
|
||||
result = await provider.translate(
|
||||
text=text_with_newlines, target_language="en", source_language="es"
|
||||
)
|
||||
|
||||
# Should preserve structure (though exact newlines may vary)
|
||||
assert result.translated_text is not None
|
||||
assert len(result.translated_text) > 0
|
||||
|
||||
|
||||
async def test_translation_special_characters():
|
||||
"""Test translation with special characters"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
text = "¡Hola! ¿Cómo estás? Está bien."
|
||||
|
||||
result = await provider.translate(text=text, target_language="en", source_language="es")
|
||||
|
||||
assert result.translated_text is not None
|
||||
assert len(result.translated_text) > 0
|
||||
|
||||
|
||||
async def test_empty_text_translation():
|
||||
"""Test translation with empty text - should return empty or handle gracefully"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
provider = OpenAITranslationProvider()
|
||||
|
||||
# Empty text may either raise an error or return an empty result
|
||||
try:
|
||||
result = await provider.translate(text="", target_language="en", source_language="es")
|
||||
# If no error, should return a TranslationResult (possibly with empty text)
|
||||
assert isinstance(result, TranslationResult)
|
||||
except TranslationError:
|
||||
# This is also acceptable behavior
|
||||
pass
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all provider tests"""
|
||||
# Sync tests
|
||||
test_get_translation_provider_factory()
|
||||
print("✓ test_get_translation_provider_factory passed")
|
||||
|
||||
test_get_translation_provider_invalid()
|
||||
print("✓ test_get_translation_provider_invalid passed")
|
||||
|
||||
# Async tests
|
||||
await test_openai_provider_basic_translation()
|
||||
print("✓ test_openai_provider_basic_translation passed")
|
||||
|
||||
await test_openai_provider_auto_detect_source()
|
||||
print("✓ test_openai_provider_auto_detect_source passed")
|
||||
|
||||
await test_openai_provider_long_text()
|
||||
print("✓ test_openai_provider_long_text passed")
|
||||
|
||||
await test_openai_batch_translation()
|
||||
print("✓ test_openai_batch_translation passed")
|
||||
|
||||
await test_translation_preserves_formatting()
|
||||
print("✓ test_translation_preserves_formatting passed")
|
||||
|
||||
await test_translation_special_characters()
|
||||
print("✓ test_translation_special_characters passed")
|
||||
|
||||
await test_empty_text_translation()
|
||||
print("✓ test_empty_text_translation passed")
|
||||
|
||||
print("\nAll provider tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
256
cognee/tests/tasks/translation/translate_content_test.py
Normal file
256
cognee/tests/tasks/translation/translate_content_test.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
"""
|
||||
Unit tests for translate_content task
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from uuid import uuid4
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.tasks.translation import translate_content
|
||||
from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
"""Check if OpenAI API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
def create_test_chunk(text: str, chunk_index: int = 0):
|
||||
"""Helper to create a DocumentChunk with all required fields"""
|
||||
# Create a minimal Document for the is_part_of field
|
||||
doc = TextDocument(
|
||||
id=uuid4(),
|
||||
name="test_doc",
|
||||
raw_data_location="/tmp/test.txt",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
|
||||
return DocumentChunk(
|
||||
id=uuid4(),
|
||||
text=text,
|
||||
chunk_index=chunk_index,
|
||||
chunk_size=len(text),
|
||||
cut_type="sentence",
|
||||
is_part_of=doc,
|
||||
)
|
||||
|
||||
|
||||
async def test_translate_content_basic():
|
||||
"""Test basic content translation"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Create test chunk with Spanish text
|
||||
original_text = "Hola mundo, esta es una prueba."
|
||||
chunk = create_test_chunk(original_text)
|
||||
|
||||
result = await translate_content(
|
||||
data_chunks=[chunk], target_language="en", translation_provider="openai"
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
# The chunk's text should now be translated (different from original Spanish)
|
||||
assert result[0].text != original_text # Text should be translated to English
|
||||
assert result[0].contains is not None
|
||||
|
||||
# Check for TranslatedContent in contains
|
||||
has_translated_content = any(isinstance(item, TranslatedContent) for item in result[0].contains)
|
||||
assert has_translated_content
|
||||
|
||||
|
||||
async def test_translate_content_preserves_original():
|
||||
"""Test that original text is preserved"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
original_text = "Bonjour le monde"
|
||||
chunk = create_test_chunk(original_text)
|
||||
|
||||
result = await translate_content(
|
||||
data_chunks=[chunk], target_language="en", preserve_original=True
|
||||
)
|
||||
|
||||
# Find TranslatedContent in contains
|
||||
translated_content = None
|
||||
for item in result[0].contains:
|
||||
if isinstance(item, TranslatedContent):
|
||||
translated_content = item
|
||||
break
|
||||
|
||||
assert translated_content is not None
|
||||
assert translated_content.original_text == original_text
|
||||
assert translated_content.translated_text != original_text
|
||||
|
||||
|
||||
async def test_translate_content_skip_english():
|
||||
"""Test skipping translation for English text"""
|
||||
# This test doesn't require API call since English text is skipped
|
||||
chunk = create_test_chunk("Hello world, this is a test.")
|
||||
|
||||
result = await translate_content(
|
||||
data_chunks=[chunk], target_language="en", skip_if_target_language=True
|
||||
)
|
||||
|
||||
# Text should remain unchanged
|
||||
assert result[0].text == chunk.text
|
||||
|
||||
# Should have LanguageMetadata but not TranslatedContent
|
||||
has_language_metadata = any(
|
||||
isinstance(item, LanguageMetadata) for item in (result[0].contains or [])
|
||||
)
|
||||
has_translated_content = any(
|
||||
isinstance(item, TranslatedContent) for item in (result[0].contains or [])
|
||||
)
|
||||
|
||||
assert has_language_metadata
|
||||
assert not has_translated_content
|
||||
|
||||
|
||||
async def test_translate_content_multiple_chunks():
|
||||
"""Test translation of multiple chunks"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use longer texts to ensure reliable language detection
|
||||
original_texts = [
|
||||
"Hola mundo, esta es una prueba de traducción.",
|
||||
"Bonjour le monde, ceci est un test de traduction.",
|
||||
"Ciao mondo, questo è un test di traduzione.",
|
||||
]
|
||||
chunks = [create_test_chunk(text, i) for i, text in enumerate(original_texts)]
|
||||
|
||||
result = await translate_content(data_chunks=chunks, target_language="en")
|
||||
|
||||
assert len(result) == 3
|
||||
# Check that at least some chunks were translated
|
||||
translated_count = sum(
|
||||
1
|
||||
for chunk in result
|
||||
if any(isinstance(item, TranslatedContent) for item in (chunk.contains or []))
|
||||
)
|
||||
assert translated_count >= 2 # At least 2 chunks should be translated
|
||||
|
||||
|
||||
async def test_translate_content_empty_list():
|
||||
"""Test with empty chunk list"""
|
||||
result = await translate_content(data_chunks=[], target_language="en")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
async def test_translate_content_empty_text():
|
||||
"""Test with chunk containing empty text"""
|
||||
chunk = create_test_chunk("")
|
||||
|
||||
result = await translate_content(data_chunks=[chunk], target_language="en")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].text == ""
|
||||
|
||||
|
||||
async def test_translate_content_language_metadata():
|
||||
"""Test that LanguageMetadata is created correctly"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use a longer, distinctly Spanish text to ensure reliable detection
|
||||
chunk = create_test_chunk(
|
||||
"La inteligencia artificial está cambiando el mundo de manera significativa"
|
||||
)
|
||||
|
||||
result = await translate_content(data_chunks=[chunk], target_language="en")
|
||||
|
||||
# Find LanguageMetadata
|
||||
language_metadata = None
|
||||
for item in result[0].contains:
|
||||
if isinstance(item, LanguageMetadata):
|
||||
language_metadata = item
|
||||
break
|
||||
|
||||
assert language_metadata is not None
|
||||
# Just check that a language was detected (short texts can be ambiguous)
|
||||
assert language_metadata.detected_language is not None
|
||||
assert language_metadata.requires_translation is True
|
||||
assert language_metadata.language_confidence > 0.0
|
||||
|
||||
|
||||
async def test_translate_content_confidence_threshold():
|
||||
"""Test with custom confidence threshold"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use longer text for more reliable detection
|
||||
chunk = create_test_chunk("Hola mundo, esta es una frase más larga para mejor detección")
|
||||
|
||||
result = await translate_content(
|
||||
data_chunks=[chunk], target_language="en", confidence_threshold=0.5
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
async def test_translate_content_no_preserve_original():
|
||||
"""Test translation without preserving original"""
|
||||
if not has_openai_key():
|
||||
print(" (skipped - no API key)")
|
||||
return
|
||||
|
||||
# Use longer text for more reliable detection
|
||||
chunk = create_test_chunk("Bonjour le monde, comment allez-vous aujourd'hui")
|
||||
|
||||
result = await translate_content(
|
||||
data_chunks=[chunk], target_language="en", preserve_original=False
|
||||
)
|
||||
|
||||
# Find TranslatedContent
|
||||
translated_content = None
|
||||
for item in result[0].contains:
|
||||
if isinstance(item, TranslatedContent):
|
||||
translated_content = item
|
||||
break
|
||||
|
||||
assert translated_content is not None
|
||||
assert translated_content.original_text == "" # Should be empty
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all translate_content tests"""
|
||||
await test_translate_content_basic()
|
||||
print("✓ test_translate_content_basic passed")
|
||||
|
||||
await test_translate_content_preserves_original()
|
||||
print("✓ test_translate_content_preserves_original passed")
|
||||
|
||||
await test_translate_content_skip_english()
|
||||
print("✓ test_translate_content_skip_english passed")
|
||||
|
||||
await test_translate_content_multiple_chunks()
|
||||
print("✓ test_translate_content_multiple_chunks passed")
|
||||
|
||||
await test_translate_content_empty_list()
|
||||
print("✓ test_translate_content_empty_list passed")
|
||||
|
||||
await test_translate_content_empty_text()
|
||||
print("✓ test_translate_content_empty_text passed")
|
||||
|
||||
await test_translate_content_language_metadata()
|
||||
print("✓ test_translate_content_language_metadata passed")
|
||||
|
||||
await test_translate_content_confidence_threshold()
|
||||
print("✓ test_translate_content_confidence_threshold passed")
|
||||
|
||||
await test_translate_content_no_preserve_original()
|
||||
print("✓ test_translate_content_no_preserve_original passed")
|
||||
|
||||
print("\nAll translate_content tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Add table
Reference in a new issue