Add group_id parameter to get_extraction_language_instruction
Enable consumers to provide group-specific language extraction
instructions by passing group_id through the call chain.
Changes:
- Add optional group_id parameter to get_extraction_language_instruction()
- Add group_id parameter to all LLMClient.generate_response() methods
- Pass group_id through to language instruction function
- Maintain backward compatibility with default None value
Users can now customize extraction per group:
```python
# `client` here is the module that defines the hook
# (assumed path: graphiti_core.llm_client.client).
from graphiti_core.llm_client import client

def custom_instruction(group_id: str | None = None) -> str:
    if group_id == 'spanish-users':
        return '\n\nExtract in Spanish.'
    return '\n\nExtract in original language.'

client.get_extraction_language_instruction = custom_instruction
```
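For end-to-end context, a minimal usage sketch follows. The import paths, the `OpenAIClient` construction, and the message contents are illustrative assumptions; only the `generate_response` signature is taken from the diffs below.

```python
import asyncio

# Assumed import paths, for illustration only.
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.prompts.models import Message


async def main() -> None:
    client = OpenAIClient()  # assumes OPENAI_API_KEY is set in the environment

    # group_id is forwarded to get_extraction_language_instruction(), which
    # appends the group-specific instruction to the system message.
    response = await client.generate_response(
        messages=[
            Message(role='system', content='You are an entity extractor.'),
            Message(role='user', content='Juan vive en Madrid.'),
        ],
        group_id='spanish-users',
    )
    print(response)


asyncio.run(main())
```

The override is a module-level hook, so install it once at startup, before any extraction calls are made.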
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
parent b28bd92c16
commit d430ab82b4

5 changed files with 15 additions and 6 deletions
```diff
@@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
 
-def get_extraction_language_instruction() -> str:
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
     """Returns instruction for language extraction behavior.
 
     Override this function to customize language extraction:
     - Return empty string to disable multilingual instructions
     - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
 
     Returns:
         str: Language instruction to append to system messages
@@ -142,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -155,7 +160,7 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
```
```diff
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try:
```
```diff
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """Generate a response with retry logic and error handling."""
         if max_tokens is None:
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
         last_error = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
```
```diff
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
```
uv.lock (generated)
```diff
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.10, <4"
 resolution-markers = [
     "python_full_version >= '3.14'",
```