Compare commits
1 commit
main
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f6b2dca51 |
4 changed files with 1432 additions and 195 deletions
|
|
@ -1,18 +1,21 @@
|
|||
import asyncio
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, Type
|
||||
from uuid import UUID
|
||||
import os
|
||||
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
from cognee.infrastructure.llm.utils import get_max_chunk_tokens
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.pipelines import run_pipeline
|
||||
from cognee.modules.pipelines.operations.pipeline import run_pipeline
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_dataset,
|
||||
classify_documents,
|
||||
|
|
@ -21,179 +24,101 @@ from cognee.tasks.documents import (
|
|||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.tasks.translation import translate_content, get_available_providers, validate_provider
|
||||
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
||||
from cognee.tasks.temporal_graph.extract_events_and_entities import extract_events_and_timestamps
|
||||
from cognee.tasks.temporal_graph.extract_knowledge_graph_from_events import (
|
||||
extract_knowledge_graph_from_events,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("cognify")
|
||||
class TranslationProviderError(ValueError):
|
||||
"""Error related to translation provider initialization."""
|
||||
pass
|
||||
|
||||
update_status_lock = asyncio.Lock()
|
||||
class UnknownTranslationProviderError(TranslationProviderError):
|
||||
"""Unknown translation provider name."""
|
||||
|
||||
class ProviderInitializationError(TranslationProviderError):
|
||||
"""Provider failed to initialize (likely missing dependency or bad config)."""
|
||||
|
||||
|
||||
async def cognify(
|
||||
datasets: Union[str, list[str], list[UUID]] = None,
|
||||
user: User = None,
|
||||
graph_model: BaseModel = KnowledgeGraph,
|
||||
_WARNED_ENV_VARS: set[str] = set()
|
||||
|
||||
def _parse_batch_env(var: str, default: int = 10) -> int:
|
||||
"""
|
||||
Parse an environment variable as a positive integer (minimum 1), falling back to a default.
|
||||
|
||||
If the environment variable named `var` is unset, the provided `default` is returned.
|
||||
If the variable is set but cannot be parsed as an integer, `default` is returned and a
|
||||
one-time warning is logged for that variable (the variable name is recorded in
|
||||
`_WARNED_ENV_VARS` to avoid repeated warnings).
|
||||
|
||||
Parameters:
|
||||
var: Name of the environment variable to read.
|
||||
default: Fallback integer value returned when the variable is missing or invalid.
|
||||
|
||||
Returns:
|
||||
An integer >= 1 representing the parsed value or the fallback `default`.
|
||||
"""
|
||||
raw = os.getenv(var)
|
||||
if raw is None:
|
||||
return default
|
||||
try:
|
||||
return max(1, int(raw))
|
||||
except (TypeError, ValueError):
|
||||
if var not in _WARNED_ENV_VARS:
|
||||
logger.warning("Invalid int for %s=%r; using default=%d", var, raw, default)
|
||||
_WARNED_ENV_VARS.add(var)
|
||||
return default
|
||||
|
||||
# Constants for batch processing
|
||||
DEFAULT_BATCH_SIZE = _parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE", 10)
|
||||
|
||||
async def cognify( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
datasets: Optional[Union[str, UUID, list[str], list[UUID]]] = None,
|
||||
user: Optional[User] = None,
|
||||
graph_model: Type[BaseModel] = KnowledgeGraph,
|
||||
chunker=TextChunker,
|
||||
chunk_size: int = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
ontology_file_path: Optional[str] = None,
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
vector_db_config: Optional[dict] = None,
|
||||
graph_db_config: Optional[dict] = None,
|
||||
run_in_background: bool = False,
|
||||
incremental_loading: bool = True,
|
||||
custom_prompt: Optional[str] = None,
|
||||
temporal_cognify: bool = False,
|
||||
):
|
||||
"""
|
||||
Transform ingested data into a structured knowledge graph.
|
||||
|
||||
This is the core processing step in Cognee that converts raw text and documents
|
||||
into an intelligent knowledge graph. It analyzes content, extracts entities and
|
||||
relationships, and creates semantic connections for enhanced search and reasoning.
|
||||
|
||||
Prerequisites:
|
||||
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
|
||||
- **Data Added**: Must have data previously added via `cognee.add()`
|
||||
- **Vector Database**: Must be accessible for embeddings storage
|
||||
- **Graph Database**: Must be accessible for relationship storage
|
||||
|
||||
Input Requirements:
|
||||
- **Datasets**: Must contain data previously added via `cognee.add()`
|
||||
- **Content Types**: Works with any text-extractable content including:
|
||||
* Natural language documents
|
||||
* Structured data (CSV, JSON)
|
||||
* Code repositories
|
||||
* Academic papers and technical documentation
|
||||
* Mixed multimedia content (with text extraction)
|
||||
|
||||
Processing Pipeline:
|
||||
1. **Document Classification**: Identifies document types and structures
|
||||
2. **Permission Validation**: Ensures user has processing rights
|
||||
3. **Text Chunking**: Breaks content into semantically meaningful segments
|
||||
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
||||
5. **Relationship Detection**: Discovers connections between entities
|
||||
6. **Graph Construction**: Builds semantic knowledge graph with embeddings
|
||||
7. **Content Summarization**: Creates hierarchical summaries for navigation
|
||||
|
||||
Graph Model Customization:
|
||||
The `graph_model` parameter allows custom knowledge structures:
|
||||
- **Default**: General-purpose KnowledgeGraph for any domain
|
||||
- **Custom Models**: Domain-specific schemas (e.g., scientific papers, code analysis)
|
||||
- **Ontology Integration**: Use `ontology_file_path` for predefined vocabularies
|
||||
|
||||
Args:
|
||||
datasets: Dataset name(s) or dataset uuid to process. Processes all available data if None.
|
||||
- Single dataset: "my_dataset"
|
||||
- Multiple datasets: ["docs", "research", "reports"]
|
||||
- None: Process all datasets for the user
|
||||
user: User context for authentication and data access. Uses default if None.
|
||||
graph_model: Pydantic model defining the knowledge graph structure.
|
||||
Defaults to KnowledgeGraph for general-purpose processing.
|
||||
chunker: Text chunking strategy (TextChunker, LangchainChunker).
|
||||
- TextChunker: Paragraph-based chunking (default, most reliable)
|
||||
- LangchainChunker: Recursive character splitting with overlap
|
||||
Determines how documents are segmented for processing.
|
||||
chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
|
||||
Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
|
||||
Default limits: ~512-8192 tokens depending on models.
|
||||
Smaller chunks = more granular but potentially fragmented knowledge.
|
||||
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
|
||||
Useful for specialized fields like medical or legal documents.
|
||||
vector_db_config: Custom vector database configuration for embeddings storage.
|
||||
graph_db_config: Custom graph database configuration for relationship storage.
|
||||
run_in_background: If True, starts processing asynchronously and returns immediately.
|
||||
If False, waits for completion before returning.
|
||||
Background mode recommended for large datasets (>100MB).
|
||||
Use pipeline_run_id from return value to monitor progress.
|
||||
custom_prompt: Optional custom prompt string to use for entity extraction and graph generation.
|
||||
If provided, this prompt will be used instead of the default prompts for
|
||||
knowledge graph extraction. The prompt should guide the LLM on how to
|
||||
extract entities and relationships from the text content.
|
||||
|
||||
Orchestrate processing of datasets into a knowledge graph.
|
||||
|
||||
Builds the default Cognify task sequence (classification, permission check, chunking,
|
||||
graph extraction, summarization, indexing) and executes it via the pipeline
|
||||
executor. Use get_default_tasks_with_translation(...) to include an automatic
|
||||
translation step before graph extraction.
|
||||
|
||||
Parameters:
|
||||
datasets: Optional dataset id or list of ids to process. If None, processes all
|
||||
datasets available to the user.
|
||||
user: Optional user context used for permission checks; defaults to the current
|
||||
runtime user if omitted.
|
||||
graph_model: Pydantic model type that defines the structure of produced graph
|
||||
DataPoints (default: KnowledgeGraph).
|
||||
chunker: Chunking strategy/class used to split documents (default: TextChunker).
|
||||
chunk_size: Optional max tokens per chunk; when None a sensible default is used.
|
||||
ontology_file_path: Optional path to an ontology (RDF/OWL) used by the extractor.
|
||||
vector_db_config: Optional mapping of vector DB configuration (overrides defaults).
|
||||
graph_db_config: Optional mapping of graph DB configuration (overrides defaults).
|
||||
run_in_background: If True, starts the pipeline asynchronously and returns
|
||||
background run info; if False, waits for completion and returns results.
|
||||
incremental_loading: If True, performs incremental loading to avoid reprocessing
|
||||
unchanged content.
|
||||
custom_prompt: Optional prompt to override the default prompt used for graph
|
||||
extraction.
|
||||
|
||||
Returns:
|
||||
Union[dict, list[PipelineRunInfo]]:
|
||||
- **Blocking mode**: Dictionary mapping dataset_id -> PipelineRunInfo with:
|
||||
* Processing status (completed/failed/in_progress)
|
||||
* Extracted entity and relationship counts
|
||||
* Processing duration and resource usage
|
||||
* Error details if any failures occurred
|
||||
- **Background mode**: List of PipelineRunInfo objects for tracking progress
|
||||
* Use pipeline_run_id to monitor status
|
||||
* Check completion via pipeline monitoring APIs
|
||||
|
||||
Next Steps:
|
||||
After successful cognify processing, use search functions to query the knowledge:
|
||||
|
||||
```python
|
||||
import cognee
|
||||
from cognee import SearchType
|
||||
|
||||
# Process your data into knowledge graph
|
||||
await cognee.cognify()
|
||||
|
||||
# Query for insights using different search types:
|
||||
|
||||
# 1. Natural language completion with graph context
|
||||
insights = await cognee.search(
|
||||
"What are the main themes?",
|
||||
query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
|
||||
# 2. Get entity relationships and connections
|
||||
relationships = await cognee.search(
|
||||
"connections between concepts",
|
||||
query_type=SearchType.INSIGHTS
|
||||
)
|
||||
|
||||
# 3. Find relevant document chunks
|
||||
chunks = await cognee.search(
|
||||
"specific topic",
|
||||
query_type=SearchType.CHUNKS
|
||||
)
|
||||
```
|
||||
|
||||
Advanced Usage:
|
||||
```python
|
||||
# Custom domain model for scientific papers
|
||||
class ScientificPaper(DataPoint):
|
||||
title: str
|
||||
authors: List[str]
|
||||
methodology: str
|
||||
findings: List[str]
|
||||
|
||||
await cognee.cognify(
|
||||
datasets=["research_papers"],
|
||||
graph_model=ScientificPaper,
|
||||
ontology_file_path="scientific_ontology.owl"
|
||||
)
|
||||
|
||||
# Background processing for large datasets
|
||||
run_info = await cognee.cognify(
|
||||
datasets=["large_corpus"],
|
||||
run_in_background=True
|
||||
)
|
||||
# Check status later with run_info.pipeline_run_id
|
||||
```
|
||||
|
||||
|
||||
Environment Variables:
|
||||
Required:
|
||||
- LLM_API_KEY: API key for your LLM provider
|
||||
|
||||
Optional (same as add function):
|
||||
- LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER
|
||||
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
|
||||
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
|
||||
The pipeline executor result. In blocking mode this is the pipeline run result
|
||||
(per-dataset run info and status). In background mode this returns information
|
||||
required to track the background run (e.g., pipeline_run_id and submission status).
|
||||
"""
|
||||
if temporal_cognify:
|
||||
tasks = await get_temporal_tasks(user, chunker, chunk_size)
|
||||
else:
|
||||
tasks = await get_default_tasks(
|
||||
user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
|
||||
)
|
||||
tasks = get_default_tasks(
|
||||
user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
|
||||
)
|
||||
|
||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||
|
|
@ -211,20 +136,49 @@ async def cognify(
|
|||
)
|
||||
|
||||
|
||||
async def get_default_tasks( # TODO: Find out a better way to do this (Boris's comment)
|
||||
user: User = None,
|
||||
graph_model: BaseModel = KnowledgeGraph,
|
||||
def get_default_tasks( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
user: Optional[User] = None,
|
||||
graph_model: Type[BaseModel] = KnowledgeGraph,
|
||||
chunker=TextChunker,
|
||||
chunk_size: int = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
ontology_file_path: Optional[str] = None,
|
||||
custom_prompt: Optional[str] = None,
|
||||
) -> list[Task]:
|
||||
"""
|
||||
Return the standard, non-translation Task list used by the cognify pipeline.
|
||||
|
||||
This builds the default processing pipeline (no automatic translation) and returns
|
||||
a list of Task objects in execution order:
|
||||
1. classify_documents
|
||||
2. check_permissions_on_dataset (enforces write permission for `user`)
|
||||
3. extract_chunks_from_documents (uses `chunker` and `chunk_size`)
|
||||
4. extract_graph_from_data (uses `graph_model`, optional `ontology_file_path`, and `custom_prompt`)
|
||||
5. summarize_text
|
||||
6. add_data_points
|
||||
|
||||
Notes:
|
||||
- Batch sizes for downstream tasks use the module-level DEFAULT_BATCH_SIZE.
|
||||
- If `chunk_size` is not provided, the token limit from get_max_chunk_tokens() is used.
|
||||
|
||||
Parameters:
|
||||
user: Optional user context used for the permission check.
|
||||
graph_model: Model class used to construct knowledge graph instances.
|
||||
chunker: Chunking strategy or class used to split documents into chunks.
|
||||
chunk_size: Optional max tokens per chunk; if omitted, defaults to get_max_chunk_tokens().
|
||||
ontology_file_path: Optional path to an ontology file passed to the extractor.
|
||||
custom_prompt: Optional custom prompt applied during graph extraction.
|
||||
|
||||
Returns:
|
||||
List[Task]: Ordered list of Task objects for the cognify pipeline (no translation).
|
||||
"""
|
||||
# Precompute max_chunk_size for stability
|
||||
max_chunk = chunk_size or get_max_chunk_tokens()
|
||||
default_tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents,
|
||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||
max_chunk_size=max_chunk,
|
||||
chunker=chunker,
|
||||
), # Extract text chunks based on the document type.
|
||||
Task(
|
||||
|
|
@ -232,51 +186,92 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
|||
graph_model=graph_model,
|
||||
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
|
||||
custom_prompt=custom_prompt,
|
||||
task_config={"batch_size": 10},
|
||||
task_config={"batch_size": DEFAULT_BATCH_SIZE},
|
||||
), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
summarize_text,
|
||||
task_config={"batch_size": 10},
|
||||
task_config={"batch_size": DEFAULT_BATCH_SIZE},
|
||||
),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
|
||||
]
|
||||
|
||||
return default_tasks
|
||||
|
||||
|
||||
async def get_temporal_tasks(
|
||||
user: User = None, chunker=TextChunker, chunk_size: int = None
|
||||
def get_default_tasks_with_translation( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
user: Optional[User] = None,
|
||||
graph_model: Type[BaseModel] = KnowledgeGraph,
|
||||
chunker=TextChunker,
|
||||
chunk_size: Optional[int] = None,
|
||||
ontology_file_path: Optional[str] = None,
|
||||
custom_prompt: Optional[str] = None,
|
||||
translation_provider: str = "noop",
|
||||
) -> list[Task]:
|
||||
"""
|
||||
Builds and returns a list of temporal processing tasks to be executed in sequence.
|
||||
|
||||
The pipeline includes:
|
||||
1. Document classification.
|
||||
2. Dataset permission checks (requires "write" access).
|
||||
3. Document chunking with a specified or default chunk size.
|
||||
4. Event and timestamp extraction from chunks.
|
||||
5. Knowledge graph extraction from events.
|
||||
6. Batched insertion of data points.
|
||||
|
||||
Args:
|
||||
user (User, optional): The user requesting task execution, used for permission checks.
|
||||
chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
|
||||
chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
|
||||
|
||||
Return the default Cognify pipeline task list with an added translation step.
|
||||
|
||||
Constructs the standard processing pipeline (classify -> permission check -> chunk extraction -> translate -> graph extraction -> summarize -> add data points),
|
||||
validates and initializes the named translation provider, and applies module DEFAULT_BATCH_SIZE to downstream batchable tasks.
|
||||
|
||||
Parameters:
|
||||
translation_provider (str): Name of a registered translation provider (case-insensitive). Defaults to `"noop"` which is a no-op provider.
|
||||
|
||||
Returns:
|
||||
list[Task]: A list of Task objects representing the temporal processing pipeline.
|
||||
list[Task]: Ordered Task objects ready to be executed by the pipeline executor.
|
||||
|
||||
Raises:
|
||||
UnknownTranslationProviderError: If the given provider name is not in get_available_providers().
|
||||
ProviderInitializationError: If the provider fails to initialize or validate via validate_provider().
|
||||
"""
|
||||
temporal_tasks = [
|
||||
# Fail fast on unknown providers (keeps errors close to the API surface)
|
||||
translation_provider = (translation_provider or "noop").strip().lower()
|
||||
# Validate provider using public API
|
||||
if translation_provider not in get_available_providers():
|
||||
available = ", ".join(get_available_providers())
|
||||
logger.error("Unknown provider '%s'. Available: %s", translation_provider, available)
|
||||
raise UnknownTranslationProviderError(f"Unknown provider '{translation_provider}'")
|
||||
# Instantiate to validate dependencies; include provider-specific config errors
|
||||
try:
|
||||
validate_provider(translation_provider)
|
||||
except Exception as e: # we want to convert provider init errors
|
||||
available = ", ".join(get_available_providers())
|
||||
logger.error(
|
||||
"Provider '%s' failed to initialize (available: %s).",
|
||||
translation_provider,
|
||||
available,
|
||||
exc_info=True,
|
||||
)
|
||||
raise ProviderInitializationError() from e
|
||||
|
||||
# Precompute max_chunk_size for stability
|
||||
max_chunk = chunk_size or get_max_chunk_tokens()
|
||||
|
||||
default_tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents,
|
||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||
max_chunk_size=max_chunk,
|
||||
chunker=chunker,
|
||||
), # Extract text chunks based on the document type.
|
||||
Task(
|
||||
translate_content,
|
||||
target_language="en",
|
||||
translation_provider=translation_provider,
|
||||
task_config={"batch_size": DEFAULT_BATCH_SIZE},
|
||||
), # Auto-translate non-English content and attach metadata
|
||||
Task(
|
||||
extract_graph_from_data,
|
||||
graph_model=graph_model,
|
||||
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
|
||||
custom_prompt=custom_prompt,
|
||||
task_config={"batch_size": DEFAULT_BATCH_SIZE},
|
||||
), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
summarize_text,
|
||||
task_config={"batch_size": DEFAULT_BATCH_SIZE},
|
||||
),
|
||||
Task(extract_events_and_timestamps, task_config={"chunk_size": 10}),
|
||||
Task(extract_knowledge_graph_from_events),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
|
||||
]
|
||||
|
||||
return temporal_tasks
|
||||
return default_tasks
|
||||
|
|
|
|||
496
cognee/tasks/translation/test_translation.py
Normal file
496
cognee/tasks/translation/test_translation.py
Normal file
|
|
@ -0,0 +1,496 @@
|
|||
"""
|
||||
Unit tests for translation functionality.
|
||||
|
||||
Tests cover:
|
||||
- Translation provider registry and discovery
|
||||
- Language detection across providers
|
||||
- Translation functionality
|
||||
- Error handling and fallbacks
|
||||
- Model validation and serialization
|
||||
"""
|
||||
|
||||
import pytest # type: ignore[import-untyped]
|
||||
from typing import Tuple, Optional, Dict
|
||||
from pydantic import ValidationError
|
||||
import cognee.tasks.translation.translate_content as tr
|
||||
|
||||
from cognee.tasks.translation.translate_content import (
|
||||
translate_content,
|
||||
register_translation_provider,
|
||||
get_available_providers,
|
||||
TranslationProvider,
|
||||
NoOpProvider,
|
||||
_get_provider,
|
||||
)
|
||||
from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
|
||||
|
||||
|
||||
class TestDetectionError(Exception): # pylint: disable=too-few-public-methods
|
||||
"""Test exception for detection failures."""
|
||||
|
||||
|
||||
class TestTranslationError(Exception): # pylint: disable=too-few-public-methods
|
||||
"""Test exception for translation failures."""
|
||||
|
||||
|
||||
# Ensure registry isolation across tests using public helpers
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_registry():
|
||||
"""
|
||||
Pytest fixture that snapshots the translation provider registry and restores it after the test.
|
||||
|
||||
Use to isolate tests that register or modify providers: the current registry state is captured before the test runs, and always restored when the fixture completes (including on exceptions).
|
||||
"""
|
||||
snapshot = tr.snapshot_registry()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tr.restore_registry(snapshot)
|
||||
|
||||
|
||||
class MockDocumentChunk: # pylint: disable=too-few-public-methods
|
||||
"""Mock document chunk for testing."""
|
||||
|
||||
def __init__(self, text: str, chunk_id: str = "test_chunk", metadata: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize a mock document chunk used in tests.
|
||||
|
||||
Parameters:
|
||||
text (str): Chunk text content.
|
||||
chunk_id (str): Identifier for the chunk; also used as chunk_index for tests. Defaults to "test_chunk".
|
||||
metadata (Optional[Dict]): Optional mapping of metadata values; defaults to an empty dict.
|
||||
"""
|
||||
self.text = text
|
||||
self.id = chunk_id
|
||||
self.chunk_index = chunk_id
|
||||
self.metadata = metadata or {}
|
||||
|
||||
|
||||
class MockTranslationProvider:
|
||||
"""Mock provider for testing custom provider registration."""
|
||||
|
||||
async def detect_language(self, text: str) -> Tuple[str, float]:
|
||||
"""
|
||||
Detect the language of the given text and return an ISO 639-1 language code with a confidence score.
|
||||
|
||||
This mock implementation uses simple keyword heuristics: returns ("es", 0.95) if the text contains "hola",
|
||||
("fr", 0.90) if it contains "bonjour", and ("en", 0.85) otherwise.
|
||||
|
||||
Parameters:
|
||||
text (str): Input text to analyze.
|
||||
|
||||
Returns:
|
||||
Tuple[str, float]: A tuple of (language_code, confidence) where language_code is an ISO 639-1 code and
|
||||
confidence is a float between 0.0 and 1.0 indicating detection confidence.
|
||||
"""
|
||||
if "hola" in text.lower():
|
||||
return "es", 0.95
|
||||
if "bonjour" in text.lower():
|
||||
return "fr", 0.90
|
||||
return "en", 0.85
|
||||
|
||||
async def translate(self, text: str, target_language: str) -> Tuple[str, float]:
|
||||
"""
|
||||
Simulate translating `text` into `target_language` and return a mock translated string with a confidence score.
|
||||
|
||||
If `target_language` is "en", returns the input prefixed with "[MOCK TRANSLATED]" and a confidence of 0.88. For any other target language, returns the original `text` and a confidence of 0.0.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to translate.
|
||||
target_language (str): The target language code (e.g., "en").
|
||||
|
||||
Returns:
|
||||
Tuple[str, float]: A pair of (translated_text, confidence) where confidence is in [0.0, 1.0].
|
||||
"""
|
||||
if target_language == "en":
|
||||
return f"[MOCK TRANSLATED] {text}", 0.88
|
||||
return text, 0.0
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Test translation provider registration and discovery."""
|
||||
|
||||
def test_get_available_providers_includes_builtin(self):
|
||||
"""Test that built-in providers are included in available list."""
|
||||
providers = get_available_providers()
|
||||
assert "noop" in providers
|
||||
assert "langdetect" in providers
|
||||
|
||||
def test_register_custom_provider(self):
|
||||
"""Test custom provider registration."""
|
||||
register_translation_provider("mock", MockTranslationProvider)
|
||||
providers = get_available_providers()
|
||||
assert "mock" in providers
|
||||
|
||||
# Test provider can be retrieved
|
||||
provider = _get_provider("mock")
|
||||
assert isinstance(provider, MockTranslationProvider)
|
||||
|
||||
def test_provider_name_normalization(self):
|
||||
"""Test provider names are normalized to lowercase."""
|
||||
register_translation_provider("CUSTOM_PROVIDER", MockTranslationProvider)
|
||||
providers = get_available_providers()
|
||||
assert "custom_provider" in providers
|
||||
|
||||
# Should be retrievable with different casing
|
||||
provider1 = _get_provider("CUSTOM_PROVIDER")
|
||||
provider2 = _get_provider("custom_provider")
|
||||
assert provider1.__class__ is provider2.__class__
|
||||
|
||||
def test_unknown_provider_raises(self):
|
||||
"""Test unknown providers raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
_get_provider("nonexistent_provider")
|
||||
|
||||
|
||||
class TestNoOpProvider:
|
||||
"""Test NoOp provider functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_language_ascii(self):
|
||||
"""Test language detection for ASCII text."""
|
||||
provider = NoOpProvider()
|
||||
lang, conf = await provider.detect_language("Hello world")
|
||||
assert lang is None
|
||||
assert conf == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_language_unicode(self):
|
||||
"""Test language detection for Unicode text."""
|
||||
provider = NoOpProvider()
|
||||
lang, conf = await provider.detect_language("Hëllo wörld")
|
||||
assert lang is None
|
||||
assert conf == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_translate_returns_original(self):
|
||||
"""Test translation returns original text with zero confidence."""
|
||||
provider = NoOpProvider()
|
||||
text = "Test text"
|
||||
translated, conf = await provider.translate(text, "es")
|
||||
assert translated == text
|
||||
assert conf == 0.0
|
||||
|
||||
|
||||
class TestTranslationModels:
|
||||
"""Test Pydantic models for translation data."""
|
||||
|
||||
def test_translated_content_validation(self):
|
||||
"""Test TranslatedContent model validation."""
|
||||
content = TranslatedContent(
|
||||
original_chunk_id="chunk_1",
|
||||
original_text="Hello",
|
||||
translated_text="Hola",
|
||||
source_language="en",
|
||||
target_language="es",
|
||||
translation_provider="test",
|
||||
confidence_score=0.9
|
||||
)
|
||||
assert content.original_chunk_id == "chunk_1"
|
||||
assert content.confidence_score == 0.9
|
||||
|
||||
def test_translated_content_confidence_validation(self):
|
||||
"""Test confidence score validation bounds."""
|
||||
# Valid confidence scores
|
||||
TranslatedContent(
|
||||
original_chunk_id="test",
|
||||
original_text="test",
|
||||
translated_text="test",
|
||||
source_language="en",
|
||||
confidence_score=0.0
|
||||
)
|
||||
TranslatedContent(
|
||||
original_chunk_id="test",
|
||||
original_text="test",
|
||||
translated_text="test",
|
||||
source_language="en",
|
||||
confidence_score=1.0
|
||||
)
|
||||
|
||||
# Invalid confidence scores should raise validation error
|
||||
with pytest.raises(ValidationError):
|
||||
TranslatedContent(
|
||||
original_chunk_id="test",
|
||||
original_text="test",
|
||||
translated_text="test",
|
||||
source_language="en",
|
||||
confidence_score=-0.1
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
TranslatedContent(
|
||||
original_chunk_id="test",
|
||||
original_text="test",
|
||||
translated_text="test",
|
||||
source_language="en",
|
||||
confidence_score=1.1
|
||||
)
|
||||
|
||||
def test_language_metadata_validation(self):
|
||||
"""Test LanguageMetadata model validation."""
|
||||
metadata = LanguageMetadata(
|
||||
content_id="chunk_1",
|
||||
detected_language="es",
|
||||
language_confidence=0.95,
|
||||
requires_translation=True,
|
||||
character_count=100
|
||||
)
|
||||
assert metadata.content_id == "chunk_1"
|
||||
assert metadata.requires_translation is True
|
||||
assert metadata.character_count == 100
|
||||
|
||||
def test_language_metadata_character_count_validation(self):
|
||||
"""Test character count cannot be negative."""
|
||||
with pytest.raises(ValidationError):
|
||||
LanguageMetadata(
|
||||
content_id="test",
|
||||
detected_language="en",
|
||||
character_count=-1
|
||||
)
|
||||
|
||||
|
||||
class TestTranslateContentFunction:
    """End-to-end tests for the top-level translate_content task."""

    @pytest.mark.asyncio
    async def test_noop_provider_processing(self):
        """The noop provider adds language metadata but never a translation."""
        chunks = [
            MockDocumentChunk("Hello world", "chunk_1"),
            MockDocumentChunk("Test content", "chunk_2"),
        ]

        result = await translate_content(
            chunks,
            target_language="en",
            translation_provider="noop",
            confidence_threshold=0.8,
        )

        assert len(result) == 2
        for chunk in result:
            assert "language" in chunk.metadata
            assert chunk.metadata["language"]["detected_language"] == "unknown"
            # No translation should occur with noop provider
            assert "translation" not in chunk.metadata

    @pytest.mark.asyncio
    async def test_translation_with_custom_provider(self):
        """A custom registered provider translates and records full metadata."""
        register_translation_provider("test_provider", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hola mundo", "chunk_1")]

        result = await translate_content(
            chunks,
            target_language="en",
            translation_provider="test_provider",
            confidence_threshold=0.8,
        )

        translated = result[0]
        assert "language" in translated.metadata
        assert "translation" in translated.metadata

        # Language metadata reflects the mock detection.
        lang_meta = translated.metadata["language"]
        assert lang_meta["detected_language"] == "es"
        assert lang_meta["requires_translation"] is True

        # Translation metadata carries source/target details and provider name.
        trans_meta = translated.metadata["translation"]
        assert trans_meta["original_text"] == "Hola mundo"
        assert "[MOCK TRANSLATED]" in trans_meta["translated_text"]
        assert trans_meta["source_language"] == "es"
        assert trans_meta["target_language"] == "en"
        assert trans_meta["translation_provider"] == "test_provider"

        # The chunk text itself is replaced with the translation.
        assert "[MOCK TRANSLATED]" in translated.text

    @pytest.mark.asyncio
    async def test_low_confidence_no_translation(self):
        """Detection below the confidence threshold suppresses translation."""
        register_translation_provider("low_conf", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hello world", "chunk_1")]  # English text

        result = await translate_content(
            chunks,
            target_language="en",
            translation_provider="low_conf",
            confidence_threshold=0.9,  # High threshold
        )

        chunk = result[0]
        assert "language" in chunk.metadata
        # High threshold plus English detection means no translation happens.
        assert "translation" not in chunk.metadata

    @pytest.mark.asyncio
    async def test_error_handling_in_detection(self):
        """A provider whose detection raises yields 'unknown' language metadata."""

        class FailingProvider:
            async def detect_language(self, _text: str) -> Tuple[str, float]:
                """Always raise to emulate a detection failure."""
                raise TestDetectionError()

            async def translate(self, text: str, _target_language: str) -> Tuple[str, float]:
                """Passthrough translate with zero confidence."""
                return text, 0.0

        register_translation_provider("failing", FailingProvider)

        chunks = [MockDocumentChunk("Test text", "chunk_1")]

        # Disable 'langdetect' fallback to force unknown
        saved_fallback = tr._provider_registry.pop("langdetect", None)
        try:
            result = await translate_content(chunks, translation_provider="failing")
        finally:
            if saved_fallback is not None:
                tr._provider_registry["langdetect"] = saved_fallback

        chunk = result[0]
        assert "language" in chunk.metadata
        # Detection failed, so language falls back to unknown / 0.0.
        lang_meta = chunk.metadata["language"]
        assert lang_meta["detected_language"] == "unknown"
        assert lang_meta["language_confidence"] == 0.0

    @pytest.mark.asyncio
    async def test_error_handling_in_translation(self):
        """A provider that detects but fails to translate keeps the original text."""

        class PartialProvider:
            async def detect_language(self, _text: str) -> Tuple[str, float]:
                """Pretend everything is Spanish with high confidence."""
                return "es", 0.9

            async def translate(self, _text: str, _target_language: str) -> Tuple[str, float]:
                """Always raise to emulate a translation failure."""
                raise TestTranslationError()

        register_translation_provider("partial", PartialProvider)

        chunks = [MockDocumentChunk("Hola", "chunk_1")]

        result = await translate_content(
            chunks,
            translation_provider="partial",
            confidence_threshold=0.8,
        )

        chunk = result[0]
        # Spanish was detected, but translation failed...
        assert chunk.metadata["language"]["detected_language"] == "es"
        # ...so translation metadata still exists, holding the original text.
        assert "translation" in chunk.metadata
        trans_meta = chunk.metadata["translation"]
        assert trans_meta["translated_text"] == "Hola"  # Original text due to failure
        assert trans_meta["confidence_score"] == 0.0

    @pytest.mark.asyncio
    async def test_no_translation_when_same_language(self):
        """No translation occurs when source and target languages match."""
        register_translation_provider("same_lang", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hello world", "chunk_1")]

        result = await translate_content(
            chunks,
            target_language="en",  # Same as detected language
            translation_provider="same_lang",
        )

        chunk = result[0]
        assert "language" in chunk.metadata
        # No translation should occur for same language
        assert "translation" not in chunk.metadata

    @pytest.mark.asyncio
    async def test_metadata_serialization(self):
        """Attached metadata is stored as plain dicts, not Pydantic models."""
        register_translation_provider("serialize_test", MockTranslationProvider)

        chunks = [MockDocumentChunk("Hola", "chunk_1")]

        result = await translate_content(
            chunks,
            translation_provider="serialize_test",
            confidence_threshold=0.8,
        )

        chunk = result[0]
        assert isinstance(chunk.metadata["language"], dict)
        if "translation" in chunk.metadata:
            assert isinstance(chunk.metadata["translation"], dict)

    def test_model_serialization_compatibility(self):
        """TranslatedContent round-trips through model_dump() and JSON."""
        content = TranslatedContent(
            original_chunk_id="test",
            original_text="Hello",
            translated_text="Hola",
            source_language="en",
            target_language="es",
        )

        dumped = content.model_dump()
        assert isinstance(dumped, dict)
        assert dumped["original_chunk_id"] == "test"
        assert "translation_timestamp" in dumped
        assert "metadata" in dumped

        # The dump must survive a JSON round trip without losing the id.
        import json
        round_tripped = json.loads(json.dumps(dumped))
        assert round_tripped["original_chunk_id"] == "test"
||||
|
||||
|
||||
660
cognee/tasks/translation/translate_content.py
Normal file
660
cognee/tasks/translation/translate_content.py
Normal file
|
|
@ -0,0 +1,660 @@
|
|||
# pylint: disable=R0903, W0221
|
||||
"""This module provides content translation capabilities for the Cognee framework."""
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Type, Protocol, Tuple, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from .models import TranslatedContent, LanguageMetadata
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Custom exceptions for better error handling
class TranslationDependencyError(ImportError):
    """Base error for a missing optional translation dependency."""
|
||||
class LangDetectError(TranslationDependencyError):
    """Raised when the optional 'langdetect' package is not installed."""
||||
class OpenAIError(TranslationDependencyError):
    """Raised when the optional 'openai' SDK is not installed."""
||||
class GoogleTranslateError(TranslationDependencyError):
    """Raised when the optional 'googletrans' package is not installed."""
||||
class AzureTranslateError(TranslationDependencyError):
    """Raised when the optional Azure AI Translation SDK is not installed."""
||||
class AzureConfigError(ValueError):
    """Raised when required Azure Translator configuration is missing."""
||||
# Environment-driven defaults for the translation task.
TARGET_LANGUAGE = os.getenv("COGNEE_TRANSLATION_TARGET_LANGUAGE", "en")

# Parse the confidence threshold defensively: a malformed value falls back
# to 0.80 with a warning instead of crashing at import time.
_raw_threshold = os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD", "0.80")
try:
    CONFIDENCE_THRESHOLD = float(_raw_threshold)
except (TypeError, ValueError):
    logger.warning(
        "Invalid float for COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD=%r; defaulting to 0.80",
        os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD"),
    )
    CONFIDENCE_THRESHOLD = 0.80
|
||||
@dataclass
class TranslationContext:
    """Per-chunk state shared by the steps of one translation operation."""

    provider: "TranslationProvider"
    chunk: Any
    text: str
    target_language: str
    confidence_threshold: float
    provider_name: str
    # Derived in __post_init__ from the chunk's identifying attributes.
    content_id: str = field(init=False)
    detected_language: str = "unknown"
    detection_confidence: float = 0.0
    requires_translation: bool = False

    def __post_init__(self):
        """Derive content_id from the chunk: prefer `id`, then `chunk_index`, else "unknown"."""
        fallback = getattr(self.chunk, "chunk_index", "unknown")
        self.content_id = getattr(self.chunk, "id", fallback)
||||
|
||||
class TranslationProvider(Protocol):
    """Structural interface every translation provider must satisfy."""

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """Return (language_code, confidence in [0, 1]) for *text*, or None on failure."""

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """Return (translated_text, confidence in [0, 1]), or None when translation fails."""
||||
# Registry mapping lowercase provider names to provider classes. Populated by
# register_translation_provider(); tests snapshot/restore it via
# snapshot_registry()/restore_registry().
_provider_registry: Dict[str, Type[TranslationProvider]] = {}
|
||||
def register_translation_provider(name: str, provider: Type[TranslationProvider]):
    """Register *provider* in the global registry under the lowercased *name*.

    Lookup is case-insensitive; re-registering an existing name replaces the
    previously stored provider class. Instances are constructed lazily when
    the provider is resolved.
    """
    key = name.lower()
    _provider_registry[key] = provider
||||
def get_available_providers():
    """Return the registered provider names, sorted alphabetically."""
    return sorted(_provider_registry)
||||
def _get_provider(translation_provider: str) -> TranslationProvider:
    """Resolve and instantiate a registered provider by name (case-insensitive).

    Raises:
        ValueError: when no provider is registered under the given name;
            the message lists the known providers.
    """
    key = translation_provider.lower()
    provider_class = _provider_registry.get(key)
    if provider_class is None:
        available = ', '.join(get_available_providers())
        msg = f"Unknown translation provider: {translation_provider}. Available providers: {available}"
        raise ValueError(msg)
    return provider_class()
||||
# Helpers
|
||||
def _normalize_lang_code(code: Optional[str]) -> str:
|
||||
"""
|
||||
Normalize a language code to a canonical form or return "unknown".
|
||||
|
||||
Normalizes common language code formats:
|
||||
- Two-letter codes (e.g., "en", "EN", " en ") -> "en"
|
||||
- Locale codes with region (e.g., "en-us", "en_US", "EN-us") -> "en-US"
|
||||
- Returns "unknown" for empty, non-string, or unrecognized inputs.
|
||||
|
||||
Parameters:
|
||||
code (Optional[str]): Language code or locale string to normalize.
|
||||
|
||||
Returns:
|
||||
str: Normalized language code in either "xx" or "xx-YY" form, or "unknown" if input is invalid.
|
||||
"""
|
||||
if not isinstance(code, str) or not code.strip():
|
||||
return "unknown"
|
||||
c = code.strip().replace("_", "-")
|
||||
parts = c.split("-")
|
||||
if len(parts) == 1 and len(parts[0]) == 2 and parts[0].isalpha():
|
||||
return parts[0].lower()
|
||||
if len(parts) >= 2 and len(parts[0]) == 2 and parts[1]:
|
||||
return f"{parts[0].lower()}-{parts[1][:2].upper()}"
|
||||
return "unknown"
|
||||
|
||||
def _provider_name(provider: TranslationProvider) -> str:
    """Map a provider instance back to its registry key.

    Falls back to a name derived from the class ("FooProvider" -> "foo")
    when the instance does not match any registered class.
    """
    for name, cls in _provider_registry.items():
        if isinstance(provider, cls):
            return name
    return provider.__class__.__name__.replace("Provider", "").lower()
||||
async def _detect_language_with_fallback(provider: TranslationProvider, text: str, content_id: str) -> Tuple[str, float]:
    """Detect the language of *text*, retrying with the "langdetect" provider.

    The primary provider's detect_language is tried first; when it raises or
    returns None and a distinct "langdetect" provider is registered, that
    fallback is tried next. Exceptions are logged, never propagated.

    Parameters:
        text: Text to analyze.
        content_id: Identifier used in log messages.

    Returns:
        (normalized_language_code, confidence) with confidence coerced to a
        float, NaN mapped to 0.0, and clamped to [0.0, 1.0]. Returns
        ("unknown", 0.0) when all detection attempts fail.
    """
    detection = None
    try:
        detection = await provider.detect_language(text)
    except Exception:
        logger.exception("Language detection failed for content_id=%s", content_id)

    if detection is None:
        fallback_cls = _provider_registry.get("langdetect")
        # Only fall back when the fallback is a *different* provider class.
        if fallback_cls is not None and not isinstance(provider, fallback_cls):
            try:
                detection = await fallback_cls().detect_language(text)
            except Exception:
                logger.exception("Fallback language detection failed for content_id=%s", content_id)
                detection = None

    if detection is None:
        return "unknown", 0.0

    lang_code, confidence = detection
    normalized = _normalize_lang_code(lang_code)
    try:
        confidence = float(confidence)
    except (TypeError, ValueError):
        confidence = 0.0
    if math.isnan(confidence):
        confidence = 0.0
    return normalized, max(0.0, min(1.0, confidence))
||||
def _decide_if_translation_is_required(ctx: TranslationContext) -> None:
    """Decide whether translation should run and set ctx.requires_translation.

    Translation is required only when the provider can actually translate
    (not "noop" or "langdetect"), and either:
      - the detected language is "unknown" and the text is non-empty, or
      - the normalized detected language differs from the normalized target
        and the detection confidence meets ctx.confidence_threshold.

    Mutates the context in place; returns nothing.
    """
    # Normalize to align with detected_language normalization and model regex.
    target_language = _normalize_lang_code(ctx.target_language)
    # Detection-only / no-op providers can never produce a translation.
    can_translate = ctx.provider_name not in ("noop", "langdetect")

    if ctx.detected_language == "unknown":
        ctx.requires_translation = can_translate and bool(ctx.text.strip())
    else:
        # BUG FIX: can_translate previously applied only to the "unknown"
        # branch, so detection-only providers could be asked to translate.
        ctx.requires_translation = (
            can_translate
            and ctx.detected_language != target_language
            and ctx.detection_confidence >= ctx.confidence_threshold
        )
||||
def _attach_language_metadata(ctx: TranslationContext) -> None:
    """Serialize detection results into chunk.metadata["language"].

    Ensures the chunk has a metadata dict, then stores a dumped
    LanguageMetadata record (content id, detected language/confidence,
    translation decision, character count of the analyzed text).
    """
    if not getattr(ctx.chunk, "metadata", None):
        ctx.chunk.metadata = {}
    language_record = LanguageMetadata(
        content_id=str(ctx.content_id),
        detected_language=ctx.detected_language,
        language_confidence=ctx.detection_confidence,
        requires_translation=ctx.requires_translation,
        character_count=len(ctx.text),
    )
    ctx.chunk.metadata["language"] = language_record.model_dump()
||||
async def _translate_and_update(ctx: TranslationContext) -> None:
    """Translate ctx.text and, on success, rewrite the chunk and its metadata.

    Behavior by provider outcome:
      - changed, non-empty text: chunk.text is replaced, chunk.chunk_size is
        refreshed when present, and a "translation" metadata record is attached;
      - failure (exception or None): the original text is kept and recorded
        with confidence 0.0;
      - unchanged text: logged and skipped, nothing is attached.
    """
    try:
        outcome = await ctx.provider.translate(ctx.text, ctx.target_language)
    except Exception:
        logger.exception("Translation failed for content_id=%s", ctx.content_id)
        outcome = None

    provider_used = _provider_name(ctx.provider)
    target_for_meta = _normalize_lang_code(ctx.target_language)
    translation_confidence = 0.0

    if outcome and isinstance(outcome[0], str) and outcome[0].strip() and outcome[0] != ctx.text:
        translated_text, translation_confidence = outcome
        ctx.chunk.text = translated_text
        if hasattr(ctx.chunk, "chunk_size"):
            try:
                # Keep the word-count-based size in sync with the new text.
                ctx.chunk.chunk_size = len(translated_text.split())
            except (AttributeError, ValueError, TypeError):
                logger.debug(
                    "Could not update chunk_size for content_id=%s",
                    ctx.content_id,
                    exc_info=True,
                )
    elif outcome is None:
        # Translation failed, keep original text
        translated_text = ctx.text
    else:
        # Provider returned unchanged text
        logger.info("Provider returned unchanged text; skipping translation metadata (content_id=%s)", ctx.content_id)
        return

    record = TranslatedContent(
        original_chunk_id=str(ctx.content_id),
        original_text=ctx.text,
        translated_text=translated_text,
        source_language=ctx.detected_language,
        target_language=target_for_meta,
        translation_provider=provider_used,
        confidence_score=translation_confidence or 0.0,
    )
    ctx.chunk.metadata["translation"] = record.model_dump()
||||
|
||||
# Test helpers for registry isolation
def snapshot_registry() -> Dict[str, Type[TranslationProvider]]:
    """Return a shallow copy of the provider registry (for test isolation)."""
    return {**_provider_registry}
|
||||
def restore_registry(snapshot: Dict[str, Type[TranslationProvider]]) -> None:
    """Reset the global provider registry to *snapshot* (clear, then update).

    Typically paired with snapshot_registry() so tests can restore provider
    registration state after mutating it.
    """
    _provider_registry.clear()
    _provider_registry.update(snapshot)
||||
def validate_provider(name: str) -> None:
    """Raise if *name* cannot be resolved and instantiated as a provider."""
    _get_provider(name)
||||
# Built-in Providers
class NoOpProvider:
    """Disables translation: detection yields nothing, translate echoes input."""

    async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]:
        """Always return None; this provider performs no detection."""
        return None

    async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
        """Echo *text* back unchanged with a fixed confidence of 0.0."""
        return text, 0.0
|
||||
class LangDetectProvider:
    """Offline language detection via the optional 'langdetect' package.

    Detection only — translate() is a passthrough that never changes text.
    """

    def __init__(self):
        """Bind langdetect.detect_langs; raise LangDetectError when unavailable."""
        try:
            from langdetect import detect_langs  # type: ignore[import-untyped]
        except ImportError as e:
            raise LangDetectError() from e
        self._detect_langs = detect_langs

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """Return (language, probability) of the best detection, or None on failure.

        Runs the synchronous detector in a worker thread; errors are logged
        and mapped to None.
        """
        try:
            detections = await asyncio.to_thread(self._detect_langs, text)
        except Exception:
            logger.exception("Error during language detection")
            return None

        if not detections:
            return None
        top = detections[0]
        return top.lang, top.prob

    async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
        """Passthrough: return the original text with a confidence of 0.0."""
        return text, 0.0
||||
class OpenAIProvider:
    """Translate text via OpenAI's chat completions API."""

    def __init__(self):
        """Create an AsyncOpenAI client from environment configuration.

        Environment variables:
            OPENAI_API_KEY: API key for authentication.
            OPENAI_TRANSLATE_MODEL: model name (default "gpt-4o-mini").
            OPENAI_TIMEOUT: request timeout in seconds (default 30).

        Raises:
            OpenAIError: if the openai SDK cannot be imported.
        """
        try:
            from openai import AsyncOpenAI  # type: ignore[import-untyped]
            self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            self.model = os.getenv("OPENAI_TRANSLATE_MODEL", "gpt-4o-mini")
            self.timeout = float(os.getenv("OPENAI_TIMEOUT", "30"))
        except ImportError as e:
            raise OpenAIError() from e

    async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]:
        """Always None: OpenAI exposes no standalone detection endpoint.

        Pair this provider with the registered "langdetect" fallback, or fold
        detection into the translation prompt if needed.
        """
        return None

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """Translate *text* to *target_language* via chat completions.

        Returns:
            (translated_text, 0.0) — no calibrated confidence is available —
            or None when the request fails or yields no text content.
        """
        try:
            response = await self.client.with_options(timeout=self.timeout).chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": f"Translate the following text to {target_language}."},
                    {"role": "user", "content": text},
                ],
                temperature=0.0,
            )
        except Exception:
            logger.exception("Error during OpenAI translation (model=%s)", self.model)
            return None

        # BUG FIX: message.content can be None (e.g. refusals / tool calls);
        # the old unconditional .strip() raised AttributeError outside the try.
        content = response.choices[0].message.content
        if not content:
            return None
        return content.strip(), 0.0  # No calibrated confidence available.
||||
class GoogleTranslateProvider:
    """Translate text via the optional 'googletrans' library."""

    def __init__(self):
        """Instantiate googletrans.Translator; raise GoogleTranslateError if missing."""
        try:
            from googletrans import Translator  # type: ignore[import-untyped]
        except ImportError as e:
            raise GoogleTranslateError() from e
        self.translator = Translator()

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """Return (language, confidence) from the translator's detect API, or None.

        The synchronous detect call runs in a worker thread; a missing or
        unparsable confidence becomes 0.0.
        """
        try:
            detection = await asyncio.to_thread(self.translator.detect, text)
        except Exception:
            logger.exception("Error during Google Translate language detection")
            return None

        try:
            confidence = float(detection.confidence) if detection.confidence is not None else 0.0
        except (TypeError, ValueError):
            confidence = 0.0
        return detection.lang, confidence

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """Return (translated_text, 0.0) — googletrans reports no confidence — or None on failure."""
        try:
            translation = await asyncio.to_thread(self.translator.translate, text, dest=target_language)
        except Exception:
            logger.exception("Error during Google Translate translation")
            return None

        return translation.text, 0.0  # Confidence not provided.
||||
class AzureTranslatorProvider:
    """A provider that uses Azure's Translator service."""

    def __init__(self):
        """
        Set up the Azure Translator client.

        Lazily imports the Azure SDK, reads AZURE_TRANSLATOR_KEY,
        AZURE_TRANSLATOR_ENDPOINT, and AZURE_TRANSLATOR_REGION from the
        environment, requires the key to be present, and constructs a
        TextTranslationClient authenticated with an AzureKeyCredential.

        Raises:
            AzureConfigError: if AZURE_TRANSLATOR_KEY is not set.
            AzureTranslateError: if the Azure SDK packages are unavailable.
        """
        try:
            from azure.core.credentials import AzureKeyCredential  # type: ignore[import-untyped]
            from azure.ai.translation.text import TextTranslationClient  # type: ignore[import-untyped]

            self.key = os.getenv("AZURE_TRANSLATOR_KEY")
            self.endpoint = os.getenv(
                "AZURE_TRANSLATOR_ENDPOINT", "https://api.cognitive.microsofttranslator.com/"
            )
            self.region = os.getenv("AZURE_TRANSLATOR_REGION", "global")

            if not self.key:
                raise AzureConfigError()

            self.client = TextTranslationClient(
                endpoint=self.endpoint,
                credential=AzureKeyCredential(self.key),
            )
        except ImportError as e:
            raise AzureTranslateError() from e

    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
        """
        Detect the language of `text` through the Azure client's detect call.

        A country hint is passed only when the configured region looks like an
        ISO 3166-1 alpha-2 code (exactly two characters); otherwise it is
        omitted. The blocking client call runs in a worker thread.

        NOTE(review): this relies on the client exposing a `detect` method —
        confirm against the installed azure-ai-translation-text SDK version.

        Parameters:
            text (str): The text to detect language for.

        Returns:
            Optional[Tuple[str, float]]: (language code, confidence score),
            or None when the call fails.
        """
        try:
            # Only two-letter regions qualify as ISO 3166-1 alpha-2 hints.
            country_hint = None
            if isinstance(self.region, str) and len(self.region) == 2:
                country_hint = self.region.lower()
            response = await asyncio.to_thread(
                self.client.detect, content=[text], country_hint=country_hint
            )
        except Exception:
            logger.exception("Error during Azure language detection")
            return None

        primary = response[0].primary_language
        return primary.language, primary.score

    async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
        """
        Translate `text` into `target_language` via the Azure Translator client.

        Parameters:
            text (str): Plain text to translate.
            target_language (str): BCP-47 or ISO code of the target language.

        Returns:
            Optional[Tuple[str, float]]: (translated_text, confidence), or
            None on error. Azure does not surface a numeric confidence for
            translations, so the confidence is always 0.0.
        """
        try:
            response = await asyncio.to_thread(
                self.client.translate, content=[text], to=[target_language]
            )
        except Exception:
            logger.exception("Error during Azure translation")
            return None

        first = response[0].translations[0]
        return first.text, 0.0  # Confidence not provided.
|
||||
# Register built-in providers with the shared provider registry,
# keyed by the lowercase name callers pass to translate_content.
for _provider_name, _provider_cls in (
    ("noop", NoOpProvider),
    ("langdetect", LangDetectProvider),
    ("openai", OpenAIProvider),
    ("google", GoogleTranslateProvider),
    ("azure", AzureTranslatorProvider),
):
    register_translation_provider(_provider_name, _provider_cls)
|
||||
async def translate_content(  # pylint: disable=too-many-locals,too-many-branches
    *data_chunks,
    target_language: str = TARGET_LANGUAGE,
    translation_provider: str = "noop",
    confidence_threshold: float = CONFIDENCE_THRESHOLD,
):
    """
    Translate content chunks into a target language and attach language /
    translation metadata.

    Chunks may be passed either as varargs or as one single list. For each
    chunk the function:
      - Resolves the named translation provider.
      - Detects the chunk's language (with a fallback detector when available).
      - Decides whether translation is required from the detected language,
        the confidence threshold, and the provider name.
      - Attaches language metadata to the chunk.
      - Translates and updates the chunk text/metadata when required.

    Parameters:
        *data_chunks: One or more chunk objects, or a single list of them.
            Each chunk is expected to expose a `text` attribute.
        target_language (str): Language code to translate into
            (defaults to TARGET_LANGUAGE).
        translation_provider (str): Registered provider name used for
            detection/translation (defaults to "noop").
        confidence_threshold (float): Minimum detection confidence required
            to skip translation (defaults to CONFIDENCE_THRESHOLD).

    Returns:
        list: The processed chunk objects (the same objects, possibly
        modified in place with detection and translation metadata).
    """
    provider = _get_provider(translation_provider)
    provider_key = translation_provider.lower()

    # Accept either varargs of chunks or one single list of chunks.
    if len(data_chunks) == 1 and isinstance(data_chunks[0], list):
        chunks = data_chunks[0]
    else:
        chunks = list(data_chunks)

    processed = []
    for chunk in chunks:
        context = TranslationContext(
            provider=provider,
            chunk=chunk,
            text=getattr(chunk, "text", "") or "",
            target_language=target_language,
            confidence_threshold=confidence_threshold,
            provider_name=provider_key,
        )

        (
            context.detected_language,
            context.detection_confidence,
        ) = await _detect_language_with_fallback(
            context.provider, context.text, str(context.content_id)
        )

        _decide_if_translation_is_required(context)
        _attach_language_metadata(context)

        if context.requires_translation:
            await _translate_and_update(context)

        processed.append(context.chunk)

    return processed
||||
86
examples/python/translation_example.py
Normal file
86
examples/python/translation_example.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import asyncio
|
||||
import os
|
||||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.api.v1.cognify.cognify import get_default_tasks_with_translation
|
||||
from cognee.modules.pipelines.operations.pipeline import run_pipeline
|
||||
|
||||
# Prerequisites:
|
||||
# 1. Set up your environment with API keys for your chosen translation provider.
|
||||
# - For OpenAI: OPENAI_API_KEY
|
||||
# - For Azure: AZURE_TRANSLATOR_KEY, AZURE_TRANSLATOR_ENDPOINT, AZURE_TRANSLATOR_REGION
|
||||
# 2. Specify the translation provider via an environment variable (optional, defaults to "noop"):
|
||||
# COGNEE_TRANSLATION_PROVIDER="openai" # Or "google", "azure", "langdetect"
|
||||
# 3. Install any required libraries for your provider:
|
||||
# - pip install langdetect googletrans==4.0.0rc1 azure-ai-translation-text
|
||||
|
||||
async def main():
    """
    Run a translation-enabled Cognify demo end to end.

    Steps:
      1. Reset the demo workspace by pruning stored data and system metadata.
      2. Seed three multilingual documents, build translation-enabled
         Cognify tasks for the provider named by COGNEE_TRANSLATION_PROVIDER
         (defaults to "noop"), and run the pipeline. A ValueError or
         ImportError from provider setup is reported and aborts the demo.
      3. Issue an English search (SearchType.INSIGHTS) against the processed
         index and print any result texts.

    Side effects:
        Mutates persistent Cognee state (prune, add, pipeline execution)
        and prints status/result messages to stdout. Returns nothing.
    """
    # 1. Set up cognee and add multilingual content
    print("Setting up demo environment...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    sample_texts = [
        "El procesamiento de lenguaje natural (PLN) es un subcampo de la IA.",
        "Le traitement automatique du langage naturel (TALN) est un sous-domaine de l'IA.",
        "Natural language processing (NLP) is a subfield of AI.",
    ]

    print("Adding multilingual texts...")
    for sample in sample_texts:
        await cognee.add(sample)
    print("Texts added successfully.\n")

    # 2. Run the cognify pipeline with translation enabled
    provider = os.getenv("COGNEE_TRANSLATION_PROVIDER", "noop").lower()
    print(f"Running cognify with translation provider: {provider}")

    try:
        # Build translation-enabled tasks and execute the pipeline
        tasks = get_default_tasks_with_translation(translation_provider=provider)
        async for _ in run_pipeline(tasks=tasks):
            pass
        print("Cognify pipeline with translation completed successfully.")
    except (ValueError, ImportError) as e:
        print(f"Error during cognify: {e}")
        print("Please ensure the selected provider is installed and configured correctly.")
        return

    # 3. Search for content in English
    query_text = "Tell me about NLP"
    print(f"\nSearching for: '{query_text}'")

    # All documents were translated during cognify, so the English query
    # should surface results from every seeded text.
    search_results = await cognee.search(query_text, query_type=SearchType.INSIGHTS)

    print("\nSearch Results:")
    if not search_results:
        print("No results found.")
    else:
        for result in search_results:
            print(f"- {result.text}")
|
||||
# Script entry point: drive the async demo to completion with asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
Loading…
Add table
Reference in a new issue