diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py
index b06f870d..874365a5 100644
--- a/graphiti_core/llm_client/client.py
+++ b/graphiti_core/llm_client/client.py
@@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
 
-def get_extraction_language_instruction() -> str:
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
     """Returns instruction for language extraction behavior.
 
     Override this function to customize language extraction:
     - Return empty string to disable multilingual instructions
     - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
 
     Returns:
         str: Language instruction to append to system messages
@@ -142,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -155,7 +160,7 @@ class LLMClient(ABC):
             )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
diff --git a/graphiti_core/llm_client/gemini_client.py b/graphiti_core/llm_client/gemini_client.py
index f80b2f7e..72f94ba5 100644
--- a/graphiti_core/llm_client/gemini_client.py
+++ b/graphiti_core/llm_client/gemini_client.py
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try:
diff --git a/graphiti_core/llm_client/openai_base_client.py b/graphiti_core/llm_client/openai_base_client.py
index cdda179c..7eb0a378 100644
--- a/graphiti_core/llm_client/openai_base_client.py
+++ b/graphiti_core/llm_client/openai_base_client.py
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """Generate a response with retry logic and error handling."""
         if max_tokens is None:
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
         last_error = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
diff --git a/graphiti_core/llm_client/openai_generic_client.py b/graphiti_core/llm_client/openai_generic_client.py
index c4f4d212..298e7334 100644
--- a/graphiti_core/llm_client/openai_generic_client.py
+++ b/graphiti_core/llm_client/openai_generic_client.py
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
             )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 354c5058..ec1f618c 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -139,6 +139,7 @@ async def extract_edges(
         prompt_library.extract_edges.edge(context),
         response_model=ExtractedEdges,
         max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
     )
     edges_data = ExtractedEdges(**llm_response).edges
 
@@ -150,6 +151,7 @@ async def extract_edges(
             prompt_library.extract_edges.reflexion(context),
             response_model=MissingFacts,
             max_tokens=extract_edges_max_tokens,
+            group_id=group_id,
         )
         missing_facts = reflexion_response.get('missing_facts', [])
 
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index b5ace9be..56f0a1e2 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -64,6 +64,7 @@ async def extract_nodes_reflexion(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    group_id: str | None = None,
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -73,7 +74,7 @@ async def extract_nodes_reflexion(
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context), MissedEntities
+        prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
     )
 
     missed_entities = llm_response.get('missed_entities', [])
@@ -129,16 +130,19 @@ async def extract_nodes(
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_message(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.text:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_text(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.json:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_json(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
 
     response_object = ExtractedEntities(**llm_response)
@@ -152,6 +156,7 @@ async def extract_nodes(
         episode,
         previous_episodes,
         [entity.name for entity in extracted_entities],
+        episode.group_id,
     )
 
     entities_missed = len(missing_entities) != 0
@@ -511,6 +516,7 @@ async def extract_attributes_from_node(
                 prompt_library.extract_nodes.extract_attributes(attributes_context),
                 response_model=entity_type,
                 model_size=ModelSize.small,
+                group_id=node.group_id,
             )
         )
         if has_entity_attributes
@@ -528,6 +534,7 @@ async def extract_attributes_from_node(
             prompt_library.extract_nodes.extract_summary(summary_context),
             response_model=EntitySummary,
             model_size=ModelSize.small,
+            group_id=node.group_id,
         )
 
     node.summary = summary_response.get('summary', '')
diff --git a/uv.lock b/uv.lock
index a0bff556..4f8c8127 100644
--- a/uv.lock
+++ b/uv.lock
@@ -803,7 +803,6 @@ anthropic = [
 ]
 dev = [
     { name = "anthropic" },
-    { name = "boto3" },
     { name = "diskcache-stubs" },
     { name = "falkordb" },
     { name = "google-genai" },
@@ -812,11 +811,9 @@ dev = [
     { name = "jupyterlab" },
     { name = "kuzu" },
     { name = "langchain-anthropic" },
-    { name = "langchain-aws" },
     { name = "langchain-openai" },
     { name = "langgraph" },
     { name = "langsmith" },
-    { name = "opensearch-py" },
     { name = "pyright" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
@@ -858,7 +855,6 @@ voyageai = [
 requires-dist = [
     { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
     { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
-    { name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" },
     { name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
     { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
     { name = "diskcache", specifier = ">=5.6.3" },
@@ -874,7 +870,6 @@ requires-dist = [
     { name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
     { name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
     { name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
-    { name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" },
     { name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
     { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
     { name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
@@ -882,7 +877,6 @@ requires-dist = [
     { name = "neo4j", specifier = ">=5.26.0" },
     { name = "numpy", specifier = ">=1.0.0" },
     { name = "openai", specifier = ">=1.91.0" },
-    { name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" },
     { name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
     { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
     { name = "posthog", specifier = ">=3.0.0" },
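Note (not part of the patch above): the new group_id parameter is threaded through to the overridable get_extraction_language_instruction hook so that different graph partitions can receive different language instructions. Below is a minimal sketch of such an override, assuming you replace the hook in a fork or rebind it wherever your deployment resolves it; the partition-to-language mapping and the instruction wording are hypothetical.

# Sketch only: a group_id-aware replacement for the get_extraction_language_instruction
# hook. The GROUP_LANGUAGE mapping and the instruction text below are hypothetical.
GROUP_LANGUAGE: dict[str, str] = {
    'tenant-mx': 'Spanish',
    'tenant-jp': 'Japanese',
}


def get_extraction_language_instruction(group_id: str | None = None) -> str:
    """Return a language instruction for the given graph partition."""
    language = GROUP_LANGUAGE.get(group_id or '')
    if language is None:
        # Unknown partition: keep extraction in the episode's own language.
        return '\n\nAny extracted information should be in the same language as the episode.'
    return f'\n\nAny extracted information should be written in {language}.'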