Compare commits

Comparing main...configurab (5 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 4a8102717f |  |
|  | efdd683504 |  |
|  | ff603b5490 |  |
|  | a20643470f |  |
|  | d430ab82b4 |  |
7 changed files with 24 additions and 12 deletions
LLM client base module (`get_extraction_language_instruction` and the `LLMClient` base class):

```diff
@@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
 
-def get_extraction_language_instruction() -> str:
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
     """Returns instruction for language extraction behavior.
 
     Override this function to customize language extraction:
     - Return empty string to disable multilingual instructions
     - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
 
+    Args:
+        group_id: Optional partition identifier for the graph
+
     Returns:
         str: Language instruction to append to system messages
```
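The docstring above invites overriding the function; as a rough illustration, a deployment could swap in a group-aware version along these lines (the group names and the mapping below are made-up examples, not part of this change):

```python
# Hypothetical override of get_extraction_language_instruction that varies
# the instruction by graph partition. Group names are illustrative.
def get_extraction_language_instruction(group_id: str | None = None) -> str:
    instructions_by_group = {
        'tenant-fr': '\n\nWrite all extracted entities and facts in French.',
        'tenant-ja': '\n\nWrite all extracted entities and facts in Japanese.',
    }
    # Unknown or missing groups get no extra instruction.
    return instructions_by_group.get(group_id or '', '')
```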
`LLMClient.generate_response`, in the same module:

```diff
@@ -142,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -155,7 +160,7 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
```
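Callers can now scope the language instruction per request. A minimal sketch of the new call shape, assuming an existing client instance, a prepared message list, and some Pydantic response model (all three names here are placeholders, not part of the diff):

```python
async def run_extraction(llm_client, messages, MyResponseModel):
    # llm_client: any LLMClient subclass; messages: prepared Message list;
    # MyResponseModel: any Pydantic model. All three are assumed inputs.
    return await llm_client.generate_response(
        messages,
        response_model=MyResponseModel,
        group_id='tenant-fr',  # scopes the language instruction to this partition
    )
```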
`GeminiClient`:

```diff
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try:
```
`BaseOpenAIClient`:

```diff
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """Generate a response with retry logic and error handling."""
         if max_tokens is None:
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
         last_error = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
```
`OpenAIGenericClient`:

```diff
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
```
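Every client applies the instruction the same way: it is appended to the system message before the request goes out. A toy illustration of that one step (the `Message` dataclass below is a stand-in for the library's own message type, and `get_extraction_language_instruction` is assumed importable from the module above):

```python
from dataclasses import dataclass

# Stand-in for the library's message type, for illustration only.
@dataclass
class Message:
    role: str
    content: str

messages = [Message(role='system', content='You extract entities and facts.')]
# Each client above runs the equivalent of this line before calling the model:
messages[0].content += get_extraction_language_instruction(group_id='tenant-fr')
```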
`extract_edges`:

```diff
@@ -139,6 +139,7 @@ async def extract_edges(
             prompt_library.extract_edges.edge(context),
             response_model=ExtractedEdges,
             max_tokens=extract_edges_max_tokens,
+            group_id=group_id,
         )
         edges_data = ExtractedEdges(**llm_response).edges
 
@@ -150,6 +151,7 @@ async def extract_edges(
             prompt_library.extract_edges.reflexion(context),
             response_model=MissingFacts,
             max_tokens=extract_edges_max_tokens,
+            group_id=group_id,
         )
 
         missing_facts = reflexion_response.get('missing_facts', [])
```
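For context, the `response_model=` / `ExtractedEdges(**llm_response)` pair above is a plain Pydantic round-trip: the client returns a dict, which is re-validated into the model. A self-contained sketch with hypothetical field names (the real `ExtractedEdges` schema lives elsewhere in the library):

```python
from pydantic import BaseModel

# Hypothetical stand-ins for the real models; field names are assumptions.
class Edge(BaseModel):
    source_entity: str
    target_entity: str
    fact: str

class ExtractedEdges(BaseModel):
    edges: list[Edge]

# The LLM client returns a plain dict...
llm_response = {
    'edges': [{'source_entity': 'Alice', 'target_entity': 'Acme', 'fact': 'Alice works at Acme'}]
}
# ...which is validated back into typed objects, as in extract_edges:
edges_data = ExtractedEdges(**llm_response).edges
```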
`extract_nodes_reflexion`, `extract_nodes`, and `extract_attributes_from_node`:

```diff
@@ -64,6 +64,7 @@ async def extract_nodes_reflexion(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    group_id: str | None = None,
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -73,7 +74,7 @@ async def extract_nodes_reflexion(
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context), MissedEntities
+        prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
     )
     missed_entities = llm_response.get('missed_entities', [])
 
@@ -129,16 +130,19 @@ async def extract_nodes(
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_message(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.text:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_text(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.json:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_json(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
 
     response_object = ExtractedEntities(**llm_response)
@@ -152,6 +156,7 @@ async def extract_nodes(
         episode,
         previous_episodes,
         [entity.name for entity in extracted_entities],
+        episode.group_id,
     )
 
     entities_missed = len(missing_entities) != 0
@@ -510,6 +515,7 @@ async def extract_attributes_from_node(
             prompt_library.extract_nodes.extract_attributes(attributes_context),
             response_model=entity_type,
             model_size=ModelSize.small,
+            group_id=node.group_id,
         )
     )
     if has_entity_attributes
@@ -527,6 +533,7 @@ async def extract_attributes_from_node(
         prompt_library.extract_nodes.extract_summary(summary_context),
         response_model=EntitySummary,
         model_size=ModelSize.small,
+        group_id=node.group_id,
     )
     node.summary = summary_response.get('summary', '')
 
```
uv.lock (generated file; 6 changed lines):
```diff
@@ -803,7 +803,6 @@ anthropic = [
 ]
 dev = [
     { name = "anthropic" },
-    { name = "boto3" },
     { name = "diskcache-stubs" },
     { name = "falkordb" },
     { name = "google-genai" },
@@ -812,11 +811,9 @@ dev = [
     { name = "jupyterlab" },
     { name = "kuzu" },
     { name = "langchain-anthropic" },
-    { name = "langchain-aws" },
     { name = "langchain-openai" },
     { name = "langgraph" },
     { name = "langsmith" },
-    { name = "opensearch-py" },
     { name = "pyright" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
@@ -858,7 +855,6 @@ voyageai = [
 requires-dist = [
     { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
     { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
-    { name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" },
     { name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
     { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
     { name = "diskcache", specifier = ">=5.6.3" },
@@ -874,7 +870,6 @@ requires-dist = [
     { name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
     { name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
     { name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
-    { name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" },
     { name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
     { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
     { name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
@@ -882,7 +877,6 @@ requires-dist = [
     { name = "neo4j", specifier = ">=5.26.0" },
     { name = "numpy", specifier = ">=1.0.0" },
     { name = "openai", specifier = ">=1.91.0" },
-    { name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" },
     { name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
     { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
     { name = "posthog", specifier = ">=3.0.0" },
```