feat(translation): address PR review feedback
- Add langdetect>=1.0.9 as a direct dependency in pyproject.toml
- Wrap exceptions with TranslationProviderError in azure_provider.py
- Add progress logging for large batch translations (every 100 chunks)
- Add clear_translation_config_cache helper for testing
- Set __cause__ on exceptions for proper exception chaining
- Change TranslationResult.confidence_score to Optional[float]
- Google provider: set confidence_score=None (API doesn't provide it)
- Google provider: simplify translate methods with kwargs dict
- Add assertion for result length in integration test
This commit is contained in:
parent
1a9c09e93d
commit
04616e3083
10 changed files with 50 additions and 27 deletions
|
|
@ -63,3 +63,8 @@ class TranslationConfig(BaseSettings):
|
||||||
def get_translation_config() -> TranslationConfig:
|
def get_translation_config() -> TranslationConfig:
|
||||||
"""Get the translation configuration singleton."""
|
"""Get the translation configuration singleton."""
|
||||||
return TranslationConfig()
|
return TranslationConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def clear_translation_config_cache():
|
||||||
|
"""Clear the cached config for testing purposes."""
|
||||||
|
get_translation_config.cache_clear()
|
||||||
|
|
|
||||||
|
|
@ -169,7 +169,7 @@ def detect_language(
|
||||||
async def detect_language_async(
|
async def detect_language_async(
|
||||||
text: str,
|
text: str,
|
||||||
target_language: str = "en",
|
target_language: str = "en",
|
||||||
confidence_threshold: float = None,
|
confidence_threshold: Optional[float] = None,
|
||||||
) -> LanguageDetectionResult:
|
) -> LanguageDetectionResult:
|
||||||
"""
|
"""
|
||||||
Async wrapper for language detection.
|
Async wrapper for language detection.
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ class TranslationError(Exception):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.original_error = original_error
|
self.original_error = original_error
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
if original_error:
|
||||||
|
self.__cause__ = original_error
|
||||||
|
|
||||||
|
|
||||||
class LanguageDetectionError(TranslationError):
|
class LanguageDetectionError(TranslationError):
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,11 @@ class AzureTranslationProvider(TranslationProvider):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Azure translation failed: {e}")
|
logger.error(f"Azure translation failed: {e}")
|
||||||
raise
|
raise TranslationProviderError(
|
||||||
|
provider=self.provider_name,
|
||||||
|
message=f"Translation failed: {e}",
|
||||||
|
original_error=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def translate_batch(
|
async def translate_batch(
|
||||||
self,
|
self,
|
||||||
|
|
@ -176,6 +180,10 @@ class AzureTranslationProvider(TranslationProvider):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Azure batch translation failed: {e}")
|
logger.error(f"Azure batch translation failed: {e}")
|
||||||
raise
|
raise TranslationProviderError(
|
||||||
|
provider=self.provider_name,
|
||||||
|
message=f"Batch translation failed: {e}",
|
||||||
|
original_error=e,
|
||||||
|
)
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,8 @@ class TranslationResult:
|
||||||
translated_text: str
|
translated_text: str
|
||||||
source_language: str
|
source_language: str
|
||||||
target_language: str
|
target_language: str
|
||||||
confidence_score: float
|
# Confidence score from the provider, or None if not available (e.g., Google Translate)
|
||||||
|
confidence_score: Optional[float]
|
||||||
provider: str
|
provider: str
|
||||||
raw_response: Optional[dict] = None
|
raw_response: Optional[dict] = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,17 +75,15 @@ class GoogleTranslationProvider(TranslationProvider):
|
||||||
# Run in thread pool since google-cloud-translate is synchronous
|
# Run in thread pool since google-cloud-translate is synchronous
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
# Build kwargs for translate call
|
||||||
|
translate_kwargs = {"target_language": target_language}
|
||||||
if source_language:
|
if source_language:
|
||||||
result = await loop.run_in_executor(
|
translate_kwargs["source_language"] = source_language
|
||||||
None,
|
|
||||||
lambda: client.translate(
|
result = await loop.run_in_executor(
|
||||||
text, target_language=target_language, source_language=source_language
|
None,
|
||||||
),
|
lambda: client.translate(text, **translate_kwargs),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
result = await loop.run_in_executor(
|
|
||||||
None, lambda: client.translate(text, target_language=target_language)
|
|
||||||
)
|
|
||||||
|
|
||||||
detected_language = result.get("detectedSourceLanguage", source_language or "unknown")
|
detected_language = result.get("detectedSourceLanguage", source_language or "unknown")
|
||||||
|
|
||||||
|
|
@ -93,7 +91,8 @@ class GoogleTranslationProvider(TranslationProvider):
|
||||||
translated_text=result["translatedText"],
|
translated_text=result["translatedText"],
|
||||||
source_language=detected_language,
|
source_language=detected_language,
|
||||||
target_language=target_language,
|
target_language=target_language,
|
||||||
confidence_score=0.9, # Google Translate is generally reliable
|
# Google Translate API does not provide confidence scores
|
||||||
|
confidence_score=None,
|
||||||
provider=self.provider_name,
|
provider=self.provider_name,
|
||||||
raw_response=result,
|
raw_response=result,
|
||||||
)
|
)
|
||||||
|
|
@ -125,17 +124,15 @@ class GoogleTranslationProvider(TranslationProvider):
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
# Build kwargs for translate call
|
||||||
|
translate_kwargs = {"target_language": target_language}
|
||||||
if source_language:
|
if source_language:
|
||||||
results = await loop.run_in_executor(
|
translate_kwargs["source_language"] = source_language
|
||||||
None,
|
|
||||||
lambda: client.translate(
|
results = await loop.run_in_executor(
|
||||||
texts, target_language=target_language, source_language=source_language
|
None,
|
||||||
),
|
lambda: client.translate(texts, **translate_kwargs),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
results = await loop.run_in_executor(
|
|
||||||
None, lambda: client.translate(texts, target_language=target_language)
|
|
||||||
)
|
|
||||||
|
|
||||||
translation_results = []
|
translation_results = []
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
@ -147,7 +144,8 @@ class GoogleTranslationProvider(TranslationProvider):
|
||||||
translated_text=result["translatedText"],
|
translated_text=result["translatedText"],
|
||||||
source_language=detected_language,
|
source_language=detected_language,
|
||||||
target_language=target_language,
|
target_language=target_language,
|
||||||
confidence_score=0.9,
|
# Google Translate API does not provide confidence scores
|
||||||
|
confidence_score=None,
|
||||||
provider=self.provider_name,
|
provider=self.provider_name,
|
||||||
raw_response=result,
|
raw_response=result,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -87,8 +87,13 @@ async def translate_content(
|
||||||
|
|
||||||
# Process chunks
|
# Process chunks
|
||||||
processed_chunks = []
|
processed_chunks = []
|
||||||
|
total_chunks = len(data_chunks)
|
||||||
|
|
||||||
|
for chunk_index, chunk in enumerate(data_chunks):
|
||||||
|
# Log progress for large batches
|
||||||
|
if chunk_index > 0 and chunk_index % 100 == 0:
|
||||||
|
logger.info(f"Translation progress: {chunk_index}/{total_chunks} chunks processed")
|
||||||
|
|
||||||
for chunk in data_chunks:
|
|
||||||
if not hasattr(chunk, "text") or not chunk.text:
|
if not hasattr(chunk, "text") or not chunk.text:
|
||||||
processed_chunks.append(chunk)
|
processed_chunks.append(chunk)
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ async def test_quick_translation():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ dependencies = [
|
||||||
"fakeredis[lua]>=2.32.0",
|
"fakeredis[lua]>=2.32.0",
|
||||||
"diskcache>=5.6.3",
|
"diskcache>=5.6.3",
|
||||||
"aiolimiter>=1.2.1",
|
"aiolimiter>=1.2.1",
|
||||||
|
"langdetect>=1.0.9",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -965,6 +965,7 @@ dependencies = [
|
||||||
{ name = "jinja2" },
|
{ name = "jinja2" },
|
||||||
{ name = "kuzu" },
|
{ name = "kuzu" },
|
||||||
{ name = "lancedb" },
|
{ name = "lancedb" },
|
||||||
|
{ name = "langdetect" },
|
||||||
{ name = "limits" },
|
{ name = "limits" },
|
||||||
{ name = "litellm" },
|
{ name = "litellm" },
|
||||||
{ name = "mistralai" },
|
{ name = "mistralai" },
|
||||||
|
|
@ -1160,6 +1161,7 @@ requires-dist = [
|
||||||
{ name = "lancedb", specifier = ">=0.24.0,<1.0.0" },
|
{ name = "lancedb", specifier = ">=0.24.0,<1.0.0" },
|
||||||
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.22" },
|
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.22" },
|
||||||
{ name = "langchain-text-splitters", marker = "extra == 'langchain'", specifier = ">=0.3.2,<1.0.0" },
|
{ name = "langchain-text-splitters", marker = "extra == 'langchain'", specifier = ">=0.3.2,<1.0.0" },
|
||||||
|
{ name = "langdetect", specifier = ">=1.0.9" },
|
||||||
{ name = "langfuse", marker = "extra == 'monitoring'", specifier = ">=2.32.0,<3" },
|
{ name = "langfuse", marker = "extra == 'monitoring'", specifier = ">=2.32.0,<3" },
|
||||||
{ name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" },
|
{ name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" },
|
||||||
{ name = "limits", specifier = ">=4.4.1,<5" },
|
{ name = "limits", specifier = ">=4.4.1,<5" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue