Add group_id parameter to get_extraction_language_instruction

Enable consumers to provide group-specific language extraction
instructions by passing group_id through the call chain.

Changes:
- Add optional group_id parameter to get_extraction_language_instruction()
- Add group_id parameter to all LLMClient.generate_response() methods
- Pass group_id through to language instruction function
- Maintain backward compatibility via a default value of None

Users can now customize extraction per group:
```python
def custom_instruction(group_id: str | None = None) -> str:
    if group_id == 'spanish-users':
        return '\n\nExtract in Spanish.'
    return '\n\nExtract in original language.'

# The clients call the module-level function, so install the override on the
# module that defines it, not on a client instance (import path is the usual
# location; adjust if it differs in your version):
import graphiti_core.llm_client.client as llm_client_module

llm_client_module.get_extraction_language_instruction = custom_instruction
```
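Callers reach this by threading the new keyword through `generate_response`. A hedged sketch of a call site; `llm_client` is any configured client instance, and the `messages` parameter, `Message` import path, and async signature sit outside the diff context below, so they are assumptions:

```python
# Hypothetical wiring: the Message import path may differ by version.
from graphiti_core.prompts.models import Message

response = await llm_client.generate_response(
    messages=[Message(role='system', content='You are an entity extractor.')],
    group_id='spanish-users',  # forwarded to get_extraction_language_instruction
)
```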

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Daniel Chalef, 2025-09-30 09:03:25 -07:00
Parent: b28bd92c16
Commit: d430ab82b4
5 changed files with 15 additions and 6 deletions


```diff
@@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
 
-def get_extraction_language_instruction() -> str:
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
     """Returns instruction for language extraction behavior.
 
     Override this function to customize language extraction:
     - Return empty string to disable multilingual instructions
     - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
 
     Returns:
         str: Language instruction to append to system messages
@@ -142,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -155,7 +160,7 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
```
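Two details worth noting in the hunks above: the instruction is appended to `messages[0].content` before the cache key is computed, so group-specific instructions produce distinct cache entries; and since `group_id` is passed positionally at the call site, an override should accept the parameter (its `None` default keeps no-argument calls working). A minimal table-driven override sketch, where the group names and instruction strings are illustrative rather than taken from the source:

```python
# Hypothetical per-group instruction table; keys and text are illustrative.
GROUP_INSTRUCTIONS: dict[str, str] = {
    'spanish-users': '\n\nExtract entities and summaries in Spanish.',
    'jp-tenant': '\n\nExtract entities and summaries in Japanese.',
}

def per_group_instruction(group_id: str | None = None) -> str:
    # Unknown or missing groups fall back to language-preserving extraction.
    return GROUP_INSTRUCTIONS.get(group_id or '', '\n\nExtract in the original language.')
```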


```diff
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try:
```


```diff
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """Generate a response with retry logic and error handling."""
         if max_tokens is None:
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
         last_error = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
```


```diff
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
```
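The same mechanical change lands in each built-in client, and any third-party `LLMClient` subclass needs the new keyword to stay call-compatible with callers that pass it. A minimal sketch of the pattern; the subclass and its `_post` transport helper are hypothetical, while `LLMClient`, `Message`, `ModelSize`, and `get_extraction_language_instruction` come from the library:

```python
import typing

from pydantic import BaseModel

class MyHTTPClient(LLMClient):  # hypothetical subclass
    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
    ) -> dict[str, typing.Any]:
        # Mirror the built-in clients: append the (possibly group-specific)
        # instruction before dispatching to the backend.
        messages[0].content += get_extraction_language_instruction(group_id)
        return await self._post(messages, max_tokens or self.max_tokens)  # hypothetical transport
```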

uv.lock (generated)

```diff
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.10, <4"
 resolution-markers = [
     "python_full_version >= '3.14'",
```