refactor: rename OpenAI translation provider to LLM provider
- Rename OpenAITranslationProvider to LLMTranslationProvider
- Rename openai_provider.py to llm_provider.py
- Change provider type from 'openai' to 'llm' in TranslationProviderType
- Update all test files to use 'llm' provider and has_llm_api_key()
- Add AliasChoices for explicit env var mapping in TranslationConfig
- Update translate_content.py to fall back to config.target_language
- Update cognify.py docstrings to reference 'llm' provider
- Update .env.template and test README documentation

The LLM provider now uses whatever LLM is configured in cognee (OpenAI, Azure, Ollama, Anthropic, etc.) instead of being tied to OpenAI.
This commit is contained in:
parent
2a9d795723
commit
b6aa33f343
13 changed files with 277 additions and 119 deletions
|
|
@ -145,6 +145,35 @@ VECTOR_DATASET_DATABASE_HANDLER="lancedb"
|
|||
# ONTOLOGY_FILE_PATH=YOUR_FULL_FILE_PATH # Default: empty
|
||||
# To add ontology resolvers, either set them as shown in ontology_example or add full_path and settings as env vars.
|
||||
|
||||
################################################################################
|
||||
# 🌐 Translation Settings
|
||||
################################################################################
|
||||
|
||||
# Translation provider: llm (uses configured LLM), google, or azure
|
||||
# "llm" uses whichever LLM is configured above (OpenAI, Azure, Ollama, Anthropic, etc.)
|
||||
# "google" and "azure" use dedicated translation APIs
|
||||
TRANSLATION_PROVIDER="llm"
|
||||
|
||||
# Default target language for translations (ISO 639-1 code, e.g., en, es, fr, de)
|
||||
TARGET_LANGUAGE="en"
|
||||
|
||||
# Minimum confidence threshold for language detection (0.0 to 1.0)
|
||||
CONFIDENCE_THRESHOLD=0.8
|
||||
|
||||
# -- Google Translate settings (required if using google provider) -----------
|
||||
# GOOGLE_TRANSLATE_API_KEY="your-google-api-key"
|
||||
# GOOGLE_PROJECT_ID="your-google-project-id"
|
||||
|
||||
# -- Azure Translator settings (required if using azure provider) ------------
|
||||
# AZURE_TRANSLATOR_KEY="your-azure-translator-key"
|
||||
# AZURE_TRANSLATOR_REGION="westeurope"
|
||||
# AZURE_TRANSLATOR_ENDPOINT="https://api.cognitive.microsofttranslator.com"
|
||||
|
||||
# -- Performance settings ----------------------------------------------------
|
||||
# TRANSLATION_BATCH_SIZE=10
|
||||
# TRANSLATION_MAX_RETRIES=3
|
||||
# TRANSLATION_TIMEOUT_SECONDS=30
|
||||
|
||||
################################################################################
|
||||
# 🔄 MIGRATION (RELATIONAL → GRAPH) SETTINGS
|
||||
################################################################################
|
||||
|
|
|
|||
|
|
@ -128,10 +128,10 @@ async def cognify(
|
|||
content that needs translation. Defaults to False.
|
||||
target_language: Target language code for translation (e.g., "en", "es", "fr").
|
||||
Only used when auto_translate=True. Defaults to "en" (English).
|
||||
translation_provider: Translation service to use ("openai", "google", "azure").
|
||||
OpenAI uses the existing LLM infrastructure, Google requires
|
||||
translation_provider: Translation service to use ("llm", "google", "azure").
|
||||
LLM uses the existing LLM infrastructure, Google requires
|
||||
GOOGLE_TRANSLATE_API_KEY, Azure requires AZURE_TRANSLATOR_KEY.
|
||||
If not specified, uses TRANSLATION_PROVIDER env var or defaults to "openai".
|
||||
If not specified, uses TRANSLATION_PROVIDER env var or defaults to "llm".
|
||||
|
||||
Returns:
|
||||
Union[dict, list[PipelineRunInfo]]:
|
||||
|
|
@ -202,7 +202,7 @@ async def cognify(
|
|||
await cognee.cognify(
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai" # or "google", "azure"
|
||||
translation_provider="llm" # or "google", "azure"
|
||||
)
|
||||
```
|
||||
|
||||
|
|
@ -215,7 +215,7 @@ async def cognify(
|
|||
- LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER
|
||||
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
|
||||
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
|
||||
- TRANSLATION_PROVIDER: Default translation provider ("openai", "google", "azure")
|
||||
- TRANSLATION_PROVIDER: Default translation provider ("llm", "google", "azure")
|
||||
- GOOGLE_TRANSLATE_API_KEY: API key for Google Translate
|
||||
- AZURE_TRANSLATOR_KEY: API key for Azure Translator
|
||||
"""
|
||||
|
|
@ -387,7 +387,7 @@ async def get_temporal_tasks(
|
|||
chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify
|
||||
auto_translate (bool, optional): If True, translate non-English content. Defaults to False.
|
||||
target_language (str, optional): Target language for translation. Defaults to "en".
|
||||
translation_provider (str, optional): Translation provider to use ("openai", "google", "azure").
|
||||
translation_provider (str, optional): Translation provider to use ("llm", "google", "azure").
|
||||
|
||||
Returns:
|
||||
list[Task]: A list of Task objects representing the temporal processing pipeline.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.llm.config import (
|
|||
get_llm_config,
|
||||
)
|
||||
from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config
|
||||
from cognee.tasks.translation.config import get_translation_config
|
||||
from cognee.api.v1.exceptions.exceptions import InvalidConfigAttributeError
|
||||
|
||||
|
||||
|
|
@ -176,3 +177,29 @@ class config:
|
|||
def set_vector_db_url(db_url: str):
|
||||
vector_db_config = get_vectordb_config()
|
||||
vector_db_config.vector_db_url = db_url
|
||||
|
||||
# Translation configuration methods
|
||||
|
||||
@staticmethod
|
||||
def set_translation_provider(provider: str):
|
||||
"""Set the translation provider (llm, google, azure)."""
|
||||
translation_config = get_translation_config()
|
||||
translation_config.translation_provider = provider
|
||||
|
||||
@staticmethod
|
||||
def set_translation_target_language(target_language: str):
|
||||
"""Set the default target language for translations."""
|
||||
translation_config = get_translation_config()
|
||||
translation_config.target_language = target_language
|
||||
|
||||
@staticmethod
|
||||
def set_translation_config(config_dict: dict):
|
||||
"""
|
||||
Updates the translation config with values from config_dict.
|
||||
"""
|
||||
translation_config = get_translation_config()
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(translation_config, key):
|
||||
object.__setattr__(translation_config, key, value)
|
||||
else:
|
||||
raise InvalidConfigAttributeError(attribute=key)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ Main Components:
|
|||
- LanguageMetadata: DataPoint model for language information
|
||||
|
||||
Supported Translation Providers:
|
||||
- OpenAI (default): Uses GPT models via existing LLM infrastructure
|
||||
- LLM (default): Uses the configured LLM via existing infrastructure
|
||||
- Google Translate: Requires google-cloud-translate package
|
||||
- Azure Translator: Requires Azure Translator API key
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ Example Usage:
|
|||
translated_chunks = await translate_content(
|
||||
chunks,
|
||||
target_language="en",
|
||||
translation_provider="openai"
|
||||
translation_provider="llm"
|
||||
)
|
||||
|
||||
# Translate a single text
|
||||
|
|
@ -54,7 +54,7 @@ from .providers import (
|
|||
TranslationProvider,
|
||||
TranslationResult,
|
||||
get_translation_provider,
|
||||
OpenAITranslationProvider,
|
||||
LLMTranslationProvider,
|
||||
GoogleTranslationProvider,
|
||||
AzureTranslationProvider,
|
||||
)
|
||||
|
|
@ -84,7 +84,7 @@ __all__ = [
|
|||
"TranslationProvider",
|
||||
"TranslationResult",
|
||||
"get_translation_provider",
|
||||
"OpenAITranslationProvider",
|
||||
"LLMTranslationProvider",
|
||||
"GoogleTranslationProvider",
|
||||
"AzureTranslationProvider",
|
||||
# Exceptions
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from functools import lru_cache
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
TranslationProviderType = Literal["openai", "google", "azure"]
|
||||
TranslationProviderType = Literal["llm", "google", "azure"]
|
||||
|
||||
|
||||
class TranslationConfig(BaseSettings):
|
||||
|
|
@ -13,34 +13,74 @@ class TranslationConfig(BaseSettings):
|
|||
Configuration settings for the translation task.
|
||||
|
||||
Environment variables can be used to configure these settings:
|
||||
- TRANSLATION_PROVIDER: The translation service to use
|
||||
- TRANSLATION_TARGET_LANGUAGE: Default target language
|
||||
- TRANSLATION_CONFIDENCE_THRESHOLD: Minimum confidence for language detection
|
||||
- TRANSLATION_PROVIDER: The translation service to use ("llm", "google", "azure")
|
||||
- TARGET_LANGUAGE: Default target language (ISO 639-1 code, e.g., "en", "es", "fr")
|
||||
- CONFIDENCE_THRESHOLD: Minimum confidence for language detection (0.0 to 1.0)
|
||||
- GOOGLE_TRANSLATE_API_KEY: API key for Google Translate
|
||||
- GOOGLE_PROJECT_ID: Google Cloud project ID
|
||||
- AZURE_TRANSLATOR_KEY: API key for Azure Translator
|
||||
- AZURE_TRANSLATOR_REGION: Region for Azure Translator
|
||||
- AZURE_TRANSLATOR_ENDPOINT: Endpoint URL for Azure Translator
|
||||
- TRANSLATION_BATCH_SIZE: Number of texts to translate per batch
|
||||
- TRANSLATION_MAX_RETRIES: Maximum retry attempts on failure
|
||||
- TRANSLATION_TIMEOUT_SECONDS: Request timeout in seconds
|
||||
"""
|
||||
|
||||
# Translation provider settings
|
||||
translation_provider: TranslationProviderType = "openai"
|
||||
target_language: str = "en"
|
||||
confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0)
|
||||
translation_provider: TranslationProviderType = Field(
|
||||
default="llm",
|
||||
validation_alias=AliasChoices("TRANSLATION_PROVIDER", "translation_provider"),
|
||||
)
|
||||
target_language: str = Field(
|
||||
default="en",
|
||||
validation_alias=AliasChoices("TARGET_LANGUAGE", "target_language"),
|
||||
)
|
||||
confidence_threshold: float = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
validation_alias=AliasChoices("CONFIDENCE_THRESHOLD", "confidence_threshold"),
|
||||
)
|
||||
|
||||
# Google Translate settings
|
||||
google_translate_api_key: Optional[str] = None
|
||||
google_project_id: Optional[str] = None
|
||||
google_translate_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("GOOGLE_TRANSLATE_API_KEY", "google_translate_api_key"),
|
||||
)
|
||||
google_project_id: Optional[str] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("GOOGLE_PROJECT_ID", "google_project_id"),
|
||||
)
|
||||
|
||||
# Azure Translator settings
|
||||
azure_translator_key: Optional[str] = None
|
||||
azure_translator_region: Optional[str] = None
|
||||
azure_translator_endpoint: str = "https://api.cognitive.microsofttranslator.com"
|
||||
azure_translator_key: Optional[str] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("AZURE_TRANSLATOR_KEY", "azure_translator_key"),
|
||||
)
|
||||
azure_translator_region: Optional[str] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("AZURE_TRANSLATOR_REGION", "azure_translator_region"),
|
||||
)
|
||||
azure_translator_endpoint: str = Field(
|
||||
default="https://api.cognitive.microsofttranslator.com",
|
||||
validation_alias=AliasChoices("AZURE_TRANSLATOR_ENDPOINT", "azure_translator_endpoint"),
|
||||
)
|
||||
|
||||
# OpenAI uses the existing LLM configuration
|
||||
# LLM provider uses the existing LLM configuration
|
||||
|
||||
# Performance settings
|
||||
batch_size: int = 10
|
||||
max_retries: int = 3
|
||||
timeout_seconds: int = 30
|
||||
# Performance settings (with TRANSLATION_ prefix for env vars)
|
||||
batch_size: int = Field(
|
||||
default=10,
|
||||
validation_alias=AliasChoices("TRANSLATION_BATCH_SIZE", "batch_size"),
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
validation_alias=AliasChoices("TRANSLATION_MAX_RETRIES", "max_retries"),
|
||||
)
|
||||
timeout_seconds: int = Field(
|
||||
default=30,
|
||||
validation_alias=AliasChoices("TRANSLATION_TIMEOUT_SECONDS", "timeout_seconds"),
|
||||
)
|
||||
|
||||
# Language detection settings
|
||||
min_text_length_for_detection: int = 10
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from .base import TranslationProvider, TranslationResult
|
||||
from .openai_provider import OpenAITranslationProvider
|
||||
from .llm_provider import LLMTranslationProvider
|
||||
from .google_provider import GoogleTranslationProvider
|
||||
from .azure_provider import AzureTranslationProvider
|
||||
|
||||
__all__ = [
|
||||
"TranslationProvider",
|
||||
"TranslationResult",
|
||||
"OpenAITranslationProvider",
|
||||
"LLMTranslationProvider",
|
||||
"GoogleTranslationProvider",
|
||||
"AzureTranslationProvider",
|
||||
"get_translation_provider",
|
||||
|
|
@ -18,7 +18,10 @@ def get_translation_provider(provider_name: str) -> TranslationProvider:
|
|||
Factory function to get the appropriate translation provider.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the provider ("openai", "google", or "azure")
|
||||
provider_name: Name of the provider:
|
||||
- "llm": Uses the configured LLM (OpenAI, Azure, Ollama, Anthropic, etc.)
|
||||
- "google": Uses Google Cloud Translation API
|
||||
- "azure": Uses Azure Translator API
|
||||
|
||||
Returns:
|
||||
TranslationProvider instance
|
||||
|
|
@ -27,7 +30,7 @@ def get_translation_provider(provider_name: str) -> TranslationProvider:
|
|||
ValueError: If the provider name is not recognized
|
||||
"""
|
||||
providers = {
|
||||
"openai": OpenAITranslationProvider,
|
||||
"llm": LLMTranslationProvider,
|
||||
"google": GoogleTranslationProvider,
|
||||
"azure": AzureTranslationProvider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
|
@ -20,17 +21,24 @@ class TranslationOutput(BaseModel):
|
|||
translation_notes: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAITranslationProvider(TranslationProvider):
|
||||
class LLMTranslationProvider(TranslationProvider):
|
||||
"""
|
||||
Translation provider using OpenAI's LLM for translation.
|
||||
Translation provider using the configured LLM for translation.
|
||||
|
||||
This provider leverages the existing LLM infrastructure in Cognee
|
||||
to perform translations using GPT models.
|
||||
to perform translations using any LLM configured via LLM_PROVIDER
|
||||
(OpenAI, Azure, Ollama, Anthropic, etc.).
|
||||
|
||||
The LLM used is determined by the cognee LLM configuration settings:
|
||||
- LLM_PROVIDER: The LLM provider (openai, azure, ollama, etc.)
|
||||
- LLM_MODEL: The model to use
|
||||
- LLM_API_KEY: API key for the provider
|
||||
"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "openai"
|
||||
"""Return 'llm' as the provider name."""
|
||||
return "llm"
|
||||
|
||||
async def translate(
|
||||
self,
|
||||
|
|
@ -39,7 +47,7 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
source_language: Optional[str] = None,
|
||||
) -> TranslationResult:
|
||||
"""
|
||||
Translate text using OpenAI's LLM.
|
||||
Translate text using the configured LLM.
|
||||
|
||||
Args:
|
||||
text: The text to translate
|
||||
|
|
@ -92,7 +100,7 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI translation failed: {e}")
|
||||
logger.error(f"LLM translation failed: {e}")
|
||||
raise
|
||||
|
||||
async def translate_batch(
|
||||
|
|
@ -103,7 +111,7 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
max_concurrent: int = 5,
|
||||
) -> list[TranslationResult]:
|
||||
"""
|
||||
Translate multiple texts using OpenAI's LLM.
|
||||
Translate multiple texts using the configured LLM.
|
||||
|
||||
Uses a semaphore to limit concurrent requests and avoid API rate limits.
|
||||
|
||||
|
|
@ -126,7 +134,10 @@ class OpenAITranslationProvider(TranslationProvider):
|
|||
return await asyncio.gather(*tasks)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if OpenAI provider is available (has required credentials)."""
|
||||
import os
|
||||
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
"""Check if LLM provider is available (has required credentials)."""
|
||||
try:
|
||||
llm_config = get_llm_config()
|
||||
# Check if API key is configured (required for most providers)
|
||||
return bool(llm_config.llm_api_key)
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -16,7 +16,7 @@ logger = get_logger(__name__)
|
|||
|
||||
async def translate_content(
|
||||
data_chunks: List[DocumentChunk],
|
||||
target_language: str = "en",
|
||||
target_language: str = None,
|
||||
translation_provider: TranslationProviderType = None,
|
||||
confidence_threshold: float = None,
|
||||
skip_if_target_language: bool = True,
|
||||
|
|
@ -32,7 +32,8 @@ async def translate_content(
|
|||
Args:
|
||||
data_chunks: List of DocumentChunk objects to process
|
||||
target_language: Target language code (default: "en" for English)
|
||||
translation_provider: Translation service to use ("openai", "google", "azure")
|
||||
If not provided, uses config default
|
||||
translation_provider: Translation service to use ("llm", "google", "azure")
|
||||
If not provided, uses config default
|
||||
confidence_threshold: Minimum confidence for language detection (0.0 to 1.0)
|
||||
If not provided, uses config default
|
||||
|
|
@ -61,7 +62,7 @@ async def translate_content(
|
|||
# Translate with specific provider
|
||||
translated_chunks = await translate_content(
|
||||
chunks,
|
||||
translation_provider="openai",
|
||||
translation_provider="llm",
|
||||
confidence_threshold=0.9
|
||||
)
|
||||
```
|
||||
|
|
@ -75,11 +76,12 @@ async def translate_content(
|
|||
# Get configuration
|
||||
config = get_translation_config()
|
||||
provider_name = translation_provider or config.translation_provider
|
||||
target_lang = target_language or config.target_language
|
||||
threshold = confidence_threshold or config.confidence_threshold
|
||||
|
||||
logger.info(
|
||||
f"Starting translation task for {len(data_chunks)} chunks "
|
||||
f"using {provider_name} provider, target language: {target_language}"
|
||||
f"using {provider_name} provider, target language: {target_lang}"
|
||||
)
|
||||
|
||||
# Get the translation provider
|
||||
|
|
@ -100,7 +102,7 @@ async def translate_content(
|
|||
|
||||
try:
|
||||
# Detect language
|
||||
detection = await detect_language_async(chunk.text, target_language, threshold)
|
||||
detection = await detect_language_async(chunk.text, target_lang, threshold)
|
||||
|
||||
# Create language metadata
|
||||
language_metadata = LanguageMetadata(
|
||||
|
|
@ -127,12 +129,12 @@ async def translate_content(
|
|||
|
||||
# Translate the content
|
||||
logger.debug(
|
||||
f"Translating chunk {chunk.id} from {detection.language_code} to {target_language}"
|
||||
f"Translating chunk {chunk.id} from {detection.language_code} to {target_lang}"
|
||||
)
|
||||
|
||||
translation_result = await provider.translate(
|
||||
text=chunk.text,
|
||||
target_language=target_language,
|
||||
target_language=target_lang,
|
||||
source_language=detection.language_code,
|
||||
)
|
||||
|
||||
|
|
@ -160,7 +162,7 @@ async def translate_content(
|
|||
|
||||
logger.debug(
|
||||
f"Successfully translated chunk {chunk.id}: "
|
||||
f"{detection.language_code} -> {target_language}"
|
||||
f"{detection.language_code} -> {target_lang}"
|
||||
)
|
||||
|
||||
except LanguageDetectionError as e:
|
||||
|
|
@ -186,7 +188,7 @@ def _add_to_chunk_contains(chunk: DocumentChunk, item) -> None:
|
|||
|
||||
async def translate_text(
|
||||
text: str,
|
||||
target_language: str = "en",
|
||||
target_language: str = None,
|
||||
translation_provider: TranslationProviderType = None,
|
||||
source_language: Optional[str] = None,
|
||||
) -> TranslationResult:
|
||||
|
|
@ -198,8 +200,10 @@ async def translate_text(
|
|||
|
||||
Args:
|
||||
text: The text to translate
|
||||
target_language: Target language code (default: "en")
|
||||
target_language: Target language code (default: uses config, typically "en")
|
||||
If not provided, uses config default
|
||||
translation_provider: Translation service to use
|
||||
If not provided, uses config default
|
||||
source_language: Source language code (optional, auto-detected if not provided)
|
||||
|
||||
Returns:
|
||||
|
|
@ -219,19 +223,20 @@ async def translate_text(
|
|||
"""
|
||||
config = get_translation_config()
|
||||
provider_name = translation_provider or config.translation_provider
|
||||
target_lang = target_language or config.target_language
|
||||
|
||||
provider = get_translation_provider(provider_name)
|
||||
|
||||
return await provider.translate(
|
||||
text=text,
|
||||
target_language=target_language,
|
||||
target_language=target_lang,
|
||||
source_language=source_language,
|
||||
)
|
||||
|
||||
|
||||
async def batch_translate_texts(
|
||||
texts: List[str],
|
||||
target_language: str = "en",
|
||||
target_language: str = None,
|
||||
translation_provider: TranslationProviderType = None,
|
||||
source_language: Optional[str] = None,
|
||||
) -> List[TranslationResult]:
|
||||
|
|
@ -243,8 +248,10 @@ async def batch_translate_texts(
|
|||
|
||||
Args:
|
||||
texts: List of texts to translate
|
||||
target_language: Target language code (default: "en")
|
||||
target_language: Target language code (default: uses config, typically "en")
|
||||
If not provided, uses config default
|
||||
translation_provider: Translation service to use
|
||||
If not provided, uses config default
|
||||
source_language: Source language code (optional)
|
||||
|
||||
Returns:
|
||||
|
|
@ -264,11 +271,12 @@ async def batch_translate_texts(
|
|||
"""
|
||||
config = get_translation_config()
|
||||
provider_name = translation_provider or config.translation_provider
|
||||
target_lang = target_language or config.target_language
|
||||
|
||||
provider = get_translation_provider(provider_name)
|
||||
|
||||
return await provider.translate_batch(
|
||||
texts=texts,
|
||||
target_language=target_language,
|
||||
target_language=target_lang,
|
||||
source_language=source_language,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ Unit and integration tests for the multilingual content translation feature.
|
|||
- Edge cases (empty text, short text, mixed languages)
|
||||
|
||||
- **providers_test.py** - Tests for translation provider implementations
|
||||
- OpenAI provider basic translation
|
||||
- LLM provider basic translation
|
||||
- Auto-detection of source language
|
||||
- Batch translation
|
||||
- Special characters and formatting preservation
|
||||
|
|
@ -73,6 +73,46 @@ uv run pytest cognee/tests/tasks/translation/ --cov=cognee.tasks.translation --c
|
|||
- LLM API key set in environment: `LLM_API_KEY=your_key`
|
||||
- Tests will be skipped if no API key is available
|
||||
|
||||
**Note:** The translation feature uses the same LLM model configured for other cognee tasks (via `LLM_MODEL` and `LLM_PROVIDER` environment variables). This means any LLM provider supported by cognee (OpenAI, Azure, Anthropic, Ollama, etc.) can be used for translation.
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
import cognee
|
||||
from cognee.tasks.translation import translate_text
|
||||
|
||||
# Configure translation (optional - defaults to LLM provider)
|
||||
cognee.config.set_translation_config(
|
||||
provider="llm", # Uses configured LLM (default)
|
||||
target_language="en", # Target language code
|
||||
confidence_threshold=0.7 # Minimum confidence for language detection
|
||||
)
|
||||
|
||||
# Translate text directly
|
||||
result = await translate_text(
|
||||
text="Bonjour le monde",
|
||||
target_language="en"
|
||||
)
|
||||
print(result.translated_text) # "Hello world"
|
||||
|
||||
# Or use auto-translation in the cognify pipeline
|
||||
await cognee.add("Hola, ¿cómo estás?")
|
||||
await cognee.cognify(auto_translate=True)
|
||||
|
||||
# Search works on translated content
|
||||
results = await cognee.search("how are you")
|
||||
```
|
||||
|
||||
### Alternative Translation Providers
|
||||
|
||||
```python
|
||||
# Use Google Cloud Translate (requires GOOGLE_TRANSLATE_API_KEY)
|
||||
cognee.config.set_translation_provider("google")
|
||||
|
||||
# Use Azure Translator (requires AZURE_TRANSLATOR_KEY and AZURE_TRANSLATOR_REGION)
|
||||
cognee.config.set_translation_provider("azure")
|
||||
```
|
||||
|
||||
## Test Summary
|
||||
|
||||
| Test File | Tests | Description |
|
||||
|
|
@ -101,7 +141,7 @@ uv run pytest cognee/tests/tasks/translation/ --cov=cognee.tasks.translation --c
|
|||
|
||||
### Translation Providers (9 tests)
|
||||
- ✅ Provider factory function
|
||||
- ✅ OpenAI translation
|
||||
- ✅ LLM translation
|
||||
- ✅ Batch operations
|
||||
- ✅ Auto source language detection
|
||||
- ✅ Long text handling
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def test_default_translation_config():
|
|||
|
||||
assert isinstance(config, TranslationConfig), "Config should be TranslationConfig instance"
|
||||
assert config.translation_provider in [
|
||||
"openai",
|
||||
"llm",
|
||||
"google",
|
||||
"azure",
|
||||
], f"Invalid provider: {config.translation_provider}"
|
||||
|
|
@ -30,7 +30,7 @@ def test_translation_provider_type_literal():
|
|||
# Get the allowed values from the Literal type
|
||||
allowed_values = get_args(TranslationProviderType)
|
||||
|
||||
assert "openai" in allowed_values, "openai should be an allowed provider"
|
||||
assert "llm" in allowed_values, "llm should be an allowed provider"
|
||||
assert "google" in allowed_values, "google should be an allowed provider"
|
||||
assert "azure" in allowed_values, "azure should be an allowed provider"
|
||||
assert len(allowed_values) == 3, f"Expected 3 providers, got {len(allowed_values)}"
|
||||
|
|
@ -38,7 +38,7 @@ def test_translation_provider_type_literal():
|
|||
|
||||
def test_confidence_threshold_bounds():
|
||||
"""Test confidence threshold validation"""
|
||||
config = TranslationConfig(translation_provider="openai", confidence_threshold=0.9)
|
||||
config = TranslationConfig(translation_provider="llm", confidence_threshold=0.9)
|
||||
|
||||
assert 0.0 <= config.confidence_threshold <= 1.0, (
|
||||
f"Confidence threshold {config.confidence_threshold} out of bounds [0.0, 1.0]"
|
||||
|
|
@ -48,16 +48,16 @@ def test_confidence_threshold_bounds():
|
|||
def test_confidence_threshold_validation():
|
||||
"""Test that invalid confidence thresholds are rejected or clamped"""
|
||||
# Test boundary values - these should work
|
||||
config_min = TranslationConfig(translation_provider="openai", confidence_threshold=0.0)
|
||||
config_min = TranslationConfig(translation_provider="llm", confidence_threshold=0.0)
|
||||
assert config_min.confidence_threshold == 0.0, "Minimum bound (0.0) should be valid"
|
||||
|
||||
config_max = TranslationConfig(translation_provider="openai", confidence_threshold=1.0)
|
||||
config_max = TranslationConfig(translation_provider="llm", confidence_threshold=1.0)
|
||||
assert config_max.confidence_threshold == 1.0, "Maximum bound (1.0) should be valid"
|
||||
|
||||
# Test invalid values - these should either raise ValidationError or be clamped
|
||||
try:
|
||||
config_invalid_low = TranslationConfig(
|
||||
translation_provider="openai", confidence_threshold=-0.1
|
||||
translation_provider="llm", confidence_threshold=-0.1
|
||||
)
|
||||
# If no error, verify it was clamped to valid range
|
||||
assert 0.0 <= config_invalid_low.confidence_threshold <= 1.0, (
|
||||
|
|
@ -68,7 +68,7 @@ def test_confidence_threshold_validation():
|
|||
|
||||
try:
|
||||
config_invalid_high = TranslationConfig(
|
||||
translation_provider="openai", confidence_threshold=1.5
|
||||
translation_provider="llm", confidence_threshold=1.5
|
||||
)
|
||||
# If no error, verify it was clamped to valid range
|
||||
assert 0.0 <= config_invalid_high.confidence_threshold <= 1.0, (
|
||||
|
|
@ -81,7 +81,7 @@ def test_confidence_threshold_validation():
|
|||
def test_multiple_provider_keys():
|
||||
"""Test configuration with multiple provider API keys"""
|
||||
config = TranslationConfig(
|
||||
translation_provider="openai",
|
||||
translation_provider="llm",
|
||||
google_translate_api_key="google_key",
|
||||
azure_translator_key="azure_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from cognee.tasks.translation import translate_text
|
|||
from cognee.tasks.translation.detect_language import detect_language_async
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
"""Check if OpenAI API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
def has_llm_api_key():
|
||||
"""Check if LLM API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_quick_translation():
|
||||
"""Quick smoke test for translation feature"""
|
||||
await prune.prune_data()
|
||||
|
|
@ -32,7 +32,7 @@ async def test_quick_translation():
|
|||
datasets=["spanish_test"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
translation_provider="llm",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
|
@ -40,7 +40,7 @@ async def test_quick_translation():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_basic():
|
||||
"""Test basic translation functionality with English text"""
|
||||
await prune.prune_data()
|
||||
|
|
@ -53,7 +53,7 @@ async def test_translation_basic():
|
|||
datasets=["test_english"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
translation_provider="llm",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
|
@ -66,7 +66,7 @@ async def test_translation_basic():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_spanish():
|
||||
"""Test translation with Spanish text"""
|
||||
await prune.prune_data()
|
||||
|
|
@ -84,7 +84,7 @@ async def test_translation_spanish():
|
|||
datasets=["test_spanish"],
|
||||
auto_translate=True,
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
translation_provider="llm",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
|
@ -97,7 +97,7 @@ async def test_translation_spanish():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_french():
|
||||
"""Test translation with French text"""
|
||||
await prune.prune_data()
|
||||
|
|
@ -128,7 +128,7 @@ async def test_translation_french():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_disabled():
|
||||
"""Test that cognify works without translation"""
|
||||
await prune.prune_data()
|
||||
|
|
@ -146,7 +146,7 @@ async def test_translation_disabled():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_mixed_languages():
|
||||
"""Test with multiple documents in different languages"""
|
||||
await prune.prune_data()
|
||||
|
|
@ -177,19 +177,19 @@ async def test_translation_mixed_languages():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_direct_translation_function():
|
||||
"""Test the translate_text convenience function directly"""
|
||||
result = await translate_text(
|
||||
text="Hola, ¿cómo estás? Espero que tengas un buen día.",
|
||||
target_language="en",
|
||||
translation_provider="openai",
|
||||
translation_provider="llm",
|
||||
)
|
||||
|
||||
assert result.translated_text is not None
|
||||
assert result.translated_text != ""
|
||||
assert result.target_language == "en"
|
||||
assert result.provider == "openai"
|
||||
assert result.provider == "llm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -8,22 +8,22 @@ import pytest
|
|||
|
||||
from cognee.tasks.translation.providers import (
|
||||
get_translation_provider,
|
||||
OpenAITranslationProvider,
|
||||
LLMTranslationProvider,
|
||||
TranslationResult,
|
||||
)
|
||||
from cognee.tasks.translation.exceptions import TranslationError
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
"""Check if OpenAI API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
def has_llm_api_key():
|
||||
"""Check if LLM API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_provider_basic_translation():
|
||||
"""Test basic translation with OpenAI provider"""
|
||||
provider = OpenAITranslationProvider()
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_llm_provider_basic_translation():
|
||||
"""Test basic translation with LLM provider (uses configured LLM)"""
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
result = await provider.translate(text="Hola mundo", target_language="en", source_language="es")
|
||||
|
||||
|
|
@ -32,14 +32,14 @@ async def test_openai_provider_basic_translation():
|
|||
assert len(result.translated_text) > 0
|
||||
assert result.source_language == "es"
|
||||
assert result.target_language == "en"
|
||||
assert result.provider == "openai"
|
||||
assert result.provider == "llm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_provider_auto_detect_source():
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_llm_provider_auto_detect_source():
|
||||
"""Test translation with automatic source language detection"""
|
||||
provider = OpenAITranslationProvider()
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
result = await provider.translate(
|
||||
text="Bonjour le monde",
|
||||
|
|
@ -52,10 +52,10 @@ async def test_openai_provider_auto_detect_source():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_provider_long_text():
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_llm_provider_long_text():
|
||||
"""Test translation of longer text"""
|
||||
provider = OpenAITranslationProvider()
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
long_text = """
|
||||
La inteligencia artificial es una rama de la informática que se centra en
|
||||
|
|
@ -71,8 +71,8 @@ async def test_openai_provider_long_text():
|
|||
|
||||
def test_get_translation_provider_factory():
|
||||
"""Test provider factory function"""
|
||||
provider = get_translation_provider("openai")
|
||||
assert isinstance(provider, OpenAITranslationProvider)
|
||||
provider = get_translation_provider("llm")
|
||||
assert isinstance(provider, LLMTranslationProvider)
|
||||
|
||||
|
||||
def test_get_translation_provider_invalid():
|
||||
|
|
@ -85,10 +85,10 @@ def test_get_translation_provider_invalid():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
async def test_openai_batch_translation():
|
||||
"""Test batch translation with OpenAI provider"""
|
||||
provider = OpenAITranslationProvider()
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_llm_batch_translation():
|
||||
"""Test batch translation with LLM provider"""
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
texts = ["Hola", "¿Cómo estás?", "Adiós"]
|
||||
|
||||
|
|
@ -105,10 +105,10 @@ async def test_openai_batch_translation():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_preserves_formatting():
|
||||
"""Test that translation preserves basic formatting"""
|
||||
provider = OpenAITranslationProvider()
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
text_with_newlines = "Primera línea.\nSegunda línea."
|
||||
|
||||
|
|
@ -122,10 +122,10 @@ async def test_translation_preserves_formatting():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translation_special_characters():
|
||||
"""Test translation with special characters"""
|
||||
provider = OpenAITranslationProvider()
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
text = "¡Hola! ¿Cómo estás? Está bien."
|
||||
|
||||
|
|
@ -136,10 +136,10 @@ async def test_translation_special_characters():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_empty_text_translation():
|
||||
"""Test translation with empty text - should return empty or handle gracefully"""
|
||||
provider = OpenAITranslationProvider()
|
||||
provider = LLMTranslationProvider()
|
||||
|
||||
# Empty text may either raise an error or return an empty result
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ from cognee.tasks.translation import translate_content
|
|||
from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
"""Check if OpenAI API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY"))
|
||||
def has_llm_api_key():
|
||||
"""Check if LLM API key is available"""
|
||||
return bool(os.environ.get("LLM_API_KEY"))
|
||||
|
||||
|
||||
def create_test_chunk(text: str, chunk_index: int = 0):
|
||||
|
|
@ -40,7 +40,7 @@ def create_test_chunk(text: str, chunk_index: int = 0):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translate_content_basic():
|
||||
"""Test basic content translation"""
|
||||
# Create test chunk with Spanish text
|
||||
|
|
@ -48,7 +48,7 @@ async def test_translate_content_basic():
|
|||
chunk = create_test_chunk(original_text)
|
||||
|
||||
result = await translate_content(
|
||||
data_chunks=[chunk], target_language="en", translation_provider="openai"
|
||||
data_chunks=[chunk], target_language="en", translation_provider="llm"
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
|
@ -62,7 +62,7 @@ async def test_translate_content_basic():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translate_content_preserves_original():
|
||||
"""Test that original text is preserved"""
|
||||
original_text = "Bonjour le monde"
|
||||
|
|
@ -110,7 +110,7 @@ async def test_translate_content_skip_english():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translate_content_multiple_chunks():
|
||||
"""Test translation of multiple chunks"""
|
||||
# Use longer texts to ensure reliable language detection
|
||||
|
|
@ -153,7 +153,7 @@ async def test_translate_content_empty_text():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translate_content_language_metadata():
|
||||
"""Test that LanguageMetadata is created correctly"""
|
||||
# Use a longer, distinctly Spanish text to ensure reliable detection
|
||||
|
|
@ -178,7 +178,7 @@ async def test_translate_content_language_metadata():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translate_content_confidence_threshold():
|
||||
"""Test with custom confidence threshold"""
|
||||
# Use longer text for more reliable detection
|
||||
|
|
@ -192,7 +192,7 @@ async def test_translate_content_confidence_threshold():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not has_openai_key(), reason="No OpenAI API key available")
|
||||
@pytest.mark.skipif(not has_llm_api_key(), reason="No LLM API key available")
|
||||
async def test_translate_content_no_preserve_original():
|
||||
"""Test translation without preserving original"""
|
||||
# Use longer text for more reliable detection
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue