feat(translation): address PR review feedback
- Add langdetect>=1.0.9 as direct dependency in pyproject.toml
- Wrap exceptions with TranslationProviderError in azure_provider.py
- Add progress logging for large batch translations (every 100 chunks)
- Add clear_translation_config_cache helper for testing
- Set __cause__ on exceptions for proper exception chaining
- Change TranslationResult.confidence_score to Optional[float]
- Google provider: set confidence_score=None (API doesn't provide it)
- Google provider: simplify translate methods with kwargs dict
- Add assertion for result length in integration test
parent 1a9c09e93d
commit 04616e3083

10 changed files with 50 additions and 27 deletions
@@ -63,3 +63,8 @@ class TranslationConfig(BaseSettings):
 def get_translation_config() -> TranslationConfig:
     """Get the translation configuration singleton."""
     return TranslationConfig()
+
+
+def clear_translation_config_cache():
+    """Clear the cached config for testing purposes."""
+    get_translation_config.cache_clear()
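Note: the cache_clear() call implies get_translation_config is wrapped with functools.lru_cache; that decorator sits outside this hunk, so the following is a minimal self-contained sketch of the assumed pattern (TranslationConfig is stubbed here purely so the snippet runs standalone):

from functools import lru_cache


class TranslationConfig:  # stand-in for the real pydantic BaseSettings class
    pass


@lru_cache(maxsize=1)  # assumed; the real decorator is not visible in this hunk
def get_translation_config() -> TranslationConfig:
    """Get the translation configuration singleton."""
    return TranslationConfig()


def clear_translation_config_cache():
    """Clear the cached config for testing purposes."""
    get_translation_config.cache_clear()


first = get_translation_config()
assert get_translation_config() is first       # cached: the same instance is returned
clear_translation_config_cache()
assert get_translation_config() is not first   # after clearing, a fresh instance is built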
@@ -169,7 +169,7 @@ def detect_language(
 async def detect_language_async(
     text: str,
     target_language: str = "en",
-    confidence_threshold: float = None,
+    confidence_threshold: Optional[float] = None,
 ) -> LanguageDetectionResult:
     """
     Async wrapper for language detection.
@@ -5,6 +5,8 @@ class TranslationError(Exception):
         self.message = message
         self.original_error = original_error
         super().__init__(self.message)
+        if original_error:
+            self.__cause__ = original_error


 class LanguageDetectionError(TranslationError):
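Assigning __cause__ manually has the same effect as raise ... from original_error: tracebacks report the wrapped error as the direct cause. A small sketch of that behaviour, reproducing the constructor body lines visible in this hunk (the signature with a default of None is inferred, not shown in the diff):

class TranslationError(Exception):
    def __init__(self, message, original_error=None):
        self.message = message
        self.original_error = original_error
        super().__init__(self.message)
        if original_error:
            self.__cause__ = original_error


try:
    try:
        raise ValueError("bad payload")
    except ValueError as e:
        raise TranslationError("translation failed", original_error=e)
except TranslationError as err:
    # The traceback now prints "The above exception was the direct cause of the
    # following exception", and the original error stays reachable:
    assert err.__cause__ is err.original_error
    assert isinstance(err.__cause__, ValueError)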
@@ -98,7 +98,11 @@ class AzureTranslationProvider(TranslationProvider):

         except Exception as e:
             logger.error(f"Azure translation failed: {e}")
-            raise
+            raise TranslationProviderError(
+                provider=self.provider_name,
+                message=f"Translation failed: {e}",
+                original_error=e,
+            )

     async def translate_batch(
         self,
@@ -176,6 +180,10 @@ class AzureTranslationProvider(TranslationProvider):

         except Exception as e:
             logger.error(f"Azure batch translation failed: {e}")
-            raise
+            raise TranslationProviderError(
+                provider=self.provider_name,
+                message=f"Batch translation failed: {e}",
+                original_error=e,
+            )

         return all_results
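Because Azure failures are now re-raised as TranslationProviderError, callers can handle provider failures uniformly. An illustrative caller-side sketch (not code from this PR), assuming the error keeps provider and message as attributes and that providers expose an async translate(text, target_language=...) method:

import logging

logger = logging.getLogger(__name__)


async def translate_with_fallback(primary, fallback, text, target_language="en"):
    try:
        return await primary.translate(text, target_language=target_language)
    except TranslationProviderError as e:  # assumed importable from the project's exceptions module
        # __cause__ still carries the original SDK exception for debugging
        logger.warning("%s failed (%s); retrying with fallback provider", e.provider, e.message)
        return await fallback.translate(text, target_language=target_language)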
@@ -18,7 +18,8 @@ class TranslationResult:
     translated_text: str
     source_language: str
     target_language: str
-    confidence_score: float
+    # Confidence score from the provider, or None if not available (e.g., Google Translate)
+    confidence_score: Optional[float]
     provider: str
     raw_response: Optional[dict] = None
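With confidence_score now Optional, consumers must distinguish "no score reported" from "low score". A sketch of the fields as changed here plus one possible consumer-side check (the threshold policy is illustrative, not from this PR):

from dataclasses import dataclass
from typing import Optional


@dataclass
class TranslationResult:  # trimmed to the fields visible in this hunk
    translated_text: str
    source_language: str
    target_language: str
    # Confidence score from the provider, or None if not available (e.g., Google Translate)
    confidence_score: Optional[float]
    provider: str
    raw_response: Optional[dict] = None


def meets_threshold(result: TranslationResult, threshold: float = 0.5) -> bool:
    # None means the provider reports no score at all, which is not the same as a low
    # score; this example chooses to accept such results rather than discard them.
    if result.confidence_score is None:
        return True
    return result.confidence_score >= threshold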
@@ -75,17 +75,15 @@ class GoogleTranslationProvider(TranslationProvider):
         # Run in thread pool since google-cloud-translate is synchronous
         loop = asyncio.get_running_loop()

+        # Build kwargs for translate call
+        translate_kwargs = {"target_language": target_language}
         if source_language:
-            result = await loop.run_in_executor(
-                None,
-                lambda: client.translate(
-                    text, target_language=target_language, source_language=source_language
-                ),
-            )
-        else:
-            result = await loop.run_in_executor(
-                None, lambda: client.translate(text, target_language=target_language)
-            )
+            translate_kwargs["source_language"] = source_language
+
+        result = await loop.run_in_executor(
+            None,
+            lambda: client.translate(text, **translate_kwargs),
+        )

         detected_language = result.get("detectedSourceLanguage", source_language or "unknown")

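The refactor replaces two near-duplicate run_in_executor branches with a single call that builds its keyword arguments up front. A stand-alone sketch of the same pattern; FakeClient stands in for the synchronous google-cloud-translate client so the snippet runs anywhere:

import asyncio


class FakeClient:
    def translate(self, text, **kwargs):
        # The real client returns a dict with "translatedText" and, when the source
        # language was auto-detected, "detectedSourceLanguage".
        return {"translatedText": f"<{kwargs['target_language']}>{text}", **kwargs}


async def translate(client, text, target_language="en", source_language=None):
    loop = asyncio.get_running_loop()

    # Build kwargs for translate call
    translate_kwargs = {"target_language": target_language}
    if source_language:
        translate_kwargs["source_language"] = source_language

    # Run in a thread pool because the client is synchronous
    return await loop.run_in_executor(None, lambda: client.translate(text, **translate_kwargs))


print(asyncio.run(translate(FakeClient(), "hola", source_language="es")))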
@@ -93,7 +91,8 @@ class GoogleTranslationProvider(TranslationProvider):
             translated_text=result["translatedText"],
             source_language=detected_language,
             target_language=target_language,
-            confidence_score=0.9,  # Google Translate is generally reliable
+            # Google Translate API does not provide confidence scores
+            confidence_score=None,
             provider=self.provider_name,
             raw_response=result,
         )
@@ -125,17 +124,15 @@ class GoogleTranslationProvider(TranslationProvider):
         client = self._get_client()
         loop = asyncio.get_running_loop()

+        # Build kwargs for translate call
+        translate_kwargs = {"target_language": target_language}
         if source_language:
-            results = await loop.run_in_executor(
-                None,
-                lambda: client.translate(
-                    texts, target_language=target_language, source_language=source_language
-                ),
-            )
-        else:
-            results = await loop.run_in_executor(
-                None, lambda: client.translate(texts, target_language=target_language)
-            )
+            translate_kwargs["source_language"] = source_language
+
+        results = await loop.run_in_executor(
+            None,
+            lambda: client.translate(texts, **translate_kwargs),
+        )

         translation_results = []
         for result in results:
@@ -147,7 +144,8 @@ class GoogleTranslationProvider(TranslationProvider):
                 translated_text=result["translatedText"],
                 source_language=detected_language,
                 target_language=target_language,
-                confidence_score=0.9,
+                # Google Translate API does not provide confidence scores
+                confidence_score=None,
                 provider=self.provider_name,
                 raw_response=result,
             )
@@ -87,8 +87,13 @@ async def translate_content(

     # Process chunks
     processed_chunks = []
+    total_chunks = len(data_chunks)
+
+    for chunk_index, chunk in enumerate(data_chunks):
+        # Log progress for large batches
+        if chunk_index > 0 and chunk_index % 100 == 0:
+            logger.info(f"Translation progress: {chunk_index}/{total_chunks} chunks processed")
+
-    for chunk in data_chunks:
         if not hasattr(chunk, "text") or not chunk.text:
             processed_chunks.append(chunk)
             continue
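The chunk_index > 0 guard skips a log entry for the first chunk, and the modulo gives one entry per 100 chunks; a tiny stand-alone sketch of the cadence (range stands in for the real chunk list):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("translation")

data_chunks = range(350)          # stand-in for the real chunk objects
total_chunks = len(data_chunks)

for chunk_index, chunk in enumerate(data_chunks):
    if chunk_index > 0 and chunk_index % 100 == 0:
        logger.info(f"Translation progress: {chunk_index}/{total_chunks} chunks processed")
# Logs exactly three times: 100/350, 200/350 and 300/350.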
@@ -36,6 +36,7 @@ async def test_quick_translation():
     )

     assert result is not None
+    assert len(result) > 0


 @pytest.mark.asyncio
@@ -60,6 +60,7 @@ dependencies = [
     "fakeredis[lua]>=2.32.0",
     "diskcache>=5.6.3",
     "aiolimiter>=1.2.1",
+    "langdetect>=1.0.9",
 ]

 [project.optional-dependencies]
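langdetect is promoted to a direct dependency in pyproject.toml. For reference, a minimal sketch of its detection API (illustrative; whether the detection module uses exactly these calls is not shown in this diff):

from langdetect import DetectorFactory, detect, detect_langs

# langdetect is non-deterministic by default; seeding makes results reproducible,
# which matters when detections are compared against a confidence threshold.
DetectorFactory.seed = 0

print(detect("Bonjour tout le monde"))        # e.g. "fr"
print(detect_langs("Bonjour tout le monde"))  # e.g. [fr:0.9999...] with probabilities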
uv.lock (generated file, 2 changes)
@@ -965,6 +965,7 @@ dependencies = [
     { name = "jinja2" },
     { name = "kuzu" },
     { name = "lancedb" },
+    { name = "langdetect" },
     { name = "limits" },
     { name = "litellm" },
     { name = "mistralai" },
@@ -1160,6 +1161,7 @@ requires-dist = [
     { name = "lancedb", specifier = ">=0.24.0,<1.0.0" },
     { name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.22" },
     { name = "langchain-text-splitters", marker = "extra == 'langchain'", specifier = ">=0.3.2,<1.0.0" },
+    { name = "langdetect", specifier = ">=1.0.9" },
     { name = "langfuse", marker = "extra == 'monitoring'", specifier = ">=2.32.0,<3" },
     { name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" },
     { name = "limits", specifier = ">=4.4.1,<5" },