Compare commits

...
Sign in to create a new pull request.

5 commits

Author SHA1 Message Date
Daniel Chalef
4a8102717f Reset uv.lock to main branch version 2025-10-03 06:37:34 -07:00
Daniel Chalef
efdd683504 Remove ensure_ascii parameter and uv.lock file 2025-10-02 23:01:15 -07:00
Daniel Chalef
ff603b5490 Fix type inconsistency in extract_nodes_reflexion parameter
Change group_id parameter from str = '' to str | None = None to match
the pattern used throughout the codebase and align with the optional
nature of group_id in generate_response().

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-02 15:39:29 -07:00
Daniel Chalef
a20643470f Pass group_id to generate_response in extraction operations
Thread group_id parameter through all extraction-related generate_response()
calls where it's naturally available (via episode.group_id or node.group_id).
This enables consumers to override get_extraction_language_instruction() with
group-specific language preferences.

Changes:
- edge_operations.py: Pass group_id in extract_edges()
- node_operations.py: Pass episode.group_id in extract_nodes() and
  node.group_id in extract_attributes_from_node()
- node_operations.py: Add group_id parameter to extract_nodes_reflexion()

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-02 15:39:29 -07:00
Daniel Chalef
d430ab82b4 Add group_id parameter to get_extraction_language_instruction
Enable consumers to provide group-specific language extraction
instructions by passing group_id through the call chain.

Changes:
- Add optional group_id parameter to get_extraction_language_instruction()
- Add group_id parameter to all LLMClient.generate_response() methods
- Pass group_id through to language instruction function
- Maintain backward compatibility with default None value

Users can now customize extraction per group:
```python
def custom_instruction(group_id: str | None = None) -> str:
    if group_id == 'spanish-users':
        return '\n\nExtract in Spanish.'
    return '\n\nExtract in original language.'

client.get_extraction_language_instruction = custom_instruction
```

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-02 15:38:50 -07:00
7 changed files with 24 additions and 12 deletions

View file

@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
DEFAULT_CACHE_DIR = './llm_cache'
def get_extraction_language_instruction() -> str:
def get_extraction_language_instruction(group_id: str | None = None) -> str:
"""Returns instruction for language extraction behavior.
Override this function to customize language extraction:
- Return empty string to disable multilingual instructions
- Return custom instructions for specific language requirements
- Use group_id to provide different instructions per group/partition
Args:
group_id: Optional partition identifier for the graph
Returns:
str: Language instruction to append to system messages
@ -142,6 +146,7 @@ class LLMClient(ABC):
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
) -> dict[str, typing.Any]:
if max_tokens is None:
max_tokens = self.max_tokens
@ -155,7 +160,7 @@ class LLMClient(ABC):
)
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction()
messages[0].content += get_extraction_language_instruction(group_id)
if self.cache_enabled and self.cache_dir is not None:
cache_key = self._get_cache_key(messages)

View file

@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model with retry logic and error handling.
@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int | None): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
group_id (str | None): Optional partition identifier for the graph.
Returns:
dict[str, typing.Any]: The response from the language model.
@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
last_output = None
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction()
messages[0].content += get_extraction_language_instruction(group_id)
while retry_count < self.MAX_RETRIES:
try:

View file

@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
) -> dict[str, typing.Any]:
"""Generate a response with retry logic and error handling."""
if max_tokens is None:
@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
last_error = None
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction()
messages[0].content += get_extraction_language_instruction(group_id)
while retry_count <= self.MAX_RETRIES:
try:

View file

@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
) -> dict[str, typing.Any]:
if max_tokens is None:
max_tokens = self.max_tokens
@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
)
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction()
messages[0].content += get_extraction_language_instruction(group_id)
while retry_count <= self.MAX_RETRIES:
try:

View file

@ -139,6 +139,7 @@ async def extract_edges(
prompt_library.extract_edges.edge(context),
response_model=ExtractedEdges,
max_tokens=extract_edges_max_tokens,
group_id=group_id,
)
edges_data = ExtractedEdges(**llm_response).edges
@ -150,6 +151,7 @@ async def extract_edges(
prompt_library.extract_edges.reflexion(context),
response_model=MissingFacts,
max_tokens=extract_edges_max_tokens,
group_id=group_id,
)
missing_facts = reflexion_response.get('missing_facts', [])

View file

@ -64,6 +64,7 @@ async def extract_nodes_reflexion(
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
node_names: list[str],
group_id: str | None = None,
) -> list[str]:
# Prepare context for LLM
context = {
@ -73,7 +74,7 @@ async def extract_nodes_reflexion(
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.reflexion(context), MissedEntities
prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
)
missed_entities = llm_response.get('missed_entities', [])
@ -129,16 +130,19 @@ async def extract_nodes(
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_message(context),
response_model=ExtractedEntities,
group_id=episode.group_id,
)
elif episode.source == EpisodeType.text:
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_text(context),
response_model=ExtractedEntities,
group_id=episode.group_id,
)
elif episode.source == EpisodeType.json:
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_json(context),
response_model=ExtractedEntities,
group_id=episode.group_id,
)
response_object = ExtractedEntities(**llm_response)
@ -152,6 +156,7 @@ async def extract_nodes(
episode,
previous_episodes,
[entity.name for entity in extracted_entities],
episode.group_id,
)
entities_missed = len(missing_entities) != 0
@ -510,6 +515,7 @@ async def extract_attributes_from_node(
prompt_library.extract_nodes.extract_attributes(attributes_context),
response_model=entity_type,
model_size=ModelSize.small,
group_id=node.group_id,
)
)
if has_entity_attributes
@ -527,6 +533,7 @@ async def extract_attributes_from_node(
prompt_library.extract_nodes.extract_summary(summary_context),
response_model=EntitySummary,
model_size=ModelSize.small,
group_id=node.group_id,
)
node.summary = summary_response.get('summary', '')

6
uv.lock generated
View file

@ -803,7 +803,6 @@ anthropic = [
]
dev = [
{ name = "anthropic" },
{ name = "boto3" },
{ name = "diskcache-stubs" },
{ name = "falkordb" },
{ name = "google-genai" },
@ -812,11 +811,9 @@ dev = [
{ name = "jupyterlab" },
{ name = "kuzu" },
{ name = "langchain-anthropic" },
{ name = "langchain-aws" },
{ name = "langchain-openai" },
{ name = "langgraph" },
{ name = "langsmith" },
{ name = "opensearch-py" },
{ name = "pyright" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
@ -858,7 +855,6 @@ voyageai = [
requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
{ name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" },
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
{ name = "diskcache", specifier = ">=5.6.3" },
@ -874,7 +870,6 @@ requires-dist = [
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
{ name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" },
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
{ name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
@ -882,7 +877,6 @@ requires-dist = [
{ name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.91.0" },
{ name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" },
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
{ name = "posthog", specifier = ">=3.0.0" },