feat: add support to pass custom parameters in llm adapters during cognify (#1802)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
A user raised this issue/feature request:
https://github.com/topoteretes/cognee/issues/1784
This PR addresses it by letting `cognify` accept arbitrary keyword arguments and forwarding them through the default task pipeline and `LLMGateway` down to the LLM adapters' structured-output calls. I may have misunderstood the request, but if I read it correctly, supporting custom LLM parameters seems worth doing for users who need it.
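
As a rough usage sketch (the parameter names below are illustrative; which ones are actually honored depends on the configured provider and adapter), the extra keyword arguments passed to `cognify` ride along to the structured-output LLM calls:

```python
import asyncio

import cognee


async def main():
    await cognee.add("Ada Lovelace worked with Charles Babbage on the Analytical Engine.")
    # Extra keyword arguments are forwarded through the default task pipeline and
    # LLMGateway down to the active adapter's acreate_structured_output call.
    await cognee.cognify(temperature=0.2, top_p=0.9)  # example parameters only


asyncio.run(main())
```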

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
10 changed files with 26 additions and 16 deletions

View file

@@ -53,6 +53,7 @@ async def cognify(
custom_prompt: Optional[str] = None,
temporal_cognify: bool = False,
data_per_batch: int = 20,
**kwargs
):
"""
Transform ingested data into a structured knowledge graph.
@@ -223,6 +224,7 @@ async def cognify(
config=config,
custom_prompt=custom_prompt,
chunks_per_batch=chunks_per_batch,
**kwargs,
)
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
@@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config: Config = None,
custom_prompt: Optional[str] = None,
chunks_per_batch: int = 100,
**kwargs,
) -> list[Task]:
if config is None:
ontology_config = get_ontology_env_config()
@@ -288,6 +291,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config=config,
custom_prompt=custom_prompt,
task_config={"batch_size": chunks_per_batch},
**kwargs,
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
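
For readers unfamiliar with the pipeline, the change above follows a plain accept-and-forward pattern: `cognify` collects the extra keyword arguments and `get_default_tasks` binds them into the graph-extraction task. A minimal, self-contained sketch of that pattern (not the real `Task`/pipeline machinery, just the shape of the forwarding):

```python
# Simplified sketch of the kwargs forwarding shown in the diff above; the real
# cognee Task and pipeline classes are richer, this only demonstrates the binding.
import asyncio
from functools import partial


async def extract_graph_from_data(data_chunks, graph_model, **kwargs):
    print("extra LLM parameters received:", kwargs)


def get_default_tasks(graph_model, **kwargs):
    # Extra kwargs captured here become part of every invocation of the task.
    return [partial(extract_graph_from_data, graph_model=graph_model, **kwargs)]


async def cognify(**kwargs):
    for task in get_default_tasks(graph_model=dict, **kwargs):
        await task(data_chunks=[])


asyncio.run(cognify(temperature=0.2, seed=42))
# -> extra LLM parameters received: {'temperature': 0.2, 'seed': 42}
```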

View file

@@ -11,7 +11,7 @@ class LLMGateway:
@staticmethod
def acreate_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel]
text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> Coroutine:
llm_config = get_llm_config()
if llm_config.structured_output_framework.upper() == "BAML":
@@ -31,7 +31,7 @@ class LLMGateway:
llm_client = get_llm_client()
return llm_client.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model
text_input=text_input, system_prompt=system_prompt, response_model=response_model, **kwargs
)
@staticmethod
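
The gateway itself stays a thin dispatcher: it only has to hand the keyword arguments to whichever client `get_llm_client()` returns. A hedged, stand-alone sketch of that shape (the stub adapter is mine; the real gateway also routes to the BAML framework when configured):

```python
# Dispatcher sketch only; _StubAdapter stands in for whatever get_llm_client() returns.
from typing import Coroutine, Type

from pydantic import BaseModel


class _StubAdapter:
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
    ) -> BaseModel:
        # kwargs carries the caller-supplied provider parameters through unchanged.
        return response_model.model_construct()  # pydantic v2; builds without validation


class GatewaySketch:
    @staticmethod
    def acreate_structured_output(
        text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
    ) -> Coroutine:
        llm_client = _StubAdapter()
        return llm_client.acreate_structured_output(
            text_input=text_input,
            system_prompt=system_prompt,
            response_model=response_model,
            **kwargs,
        )
```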

View file

@@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
async def extract_content_graph(
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
):
if custom_prompt:
system_prompt = custom_prompt
@@ -30,7 +30,7 @@ async def extract_content_graph(
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model
content, system_prompt, response_model, **kwargs
)
return content_graph

View file

@@ -52,7 +52,7 @@ class AnthropicAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.

View file

@@ -80,7 +80,7 @@ class GeminiAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.

View file

@@ -80,7 +80,7 @@ class GenericAPIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.

View file

@@ -69,7 +69,7 @@ class MistralAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from the user query.

View file

@@ -76,7 +76,7 @@ class OllamaAPIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a structured output from the LLM using the provided text and system prompt.
@@ -123,7 +123,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input_file: str) -> str:
async def create_transcript(self, input_file: str, **kwargs) -> str:
"""
Generate an audio transcript from a user query.
@@ -162,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input_file: str) -> str:
async def transcribe_image(self, input_file: str, **kwargs) -> str:
"""
Transcribe content from an image using base64 encoding.

View file

@@ -112,7 +112,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
@@ -154,6 +154,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
except (
ContentFilterFinishReasonError,
@@ -180,6 +181,7 @@ class OpenAIAdapter(LLMInterface):
# api_base=self.fallback_endpoint,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
except (
ContentFilterFinishReasonError,
@@ -205,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True,
)
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
@@ -245,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
@retry(
@@ -254,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input):
async def create_transcript(self, input, **kwargs):
"""
Generate an audio transcript from a user query.
@@ -281,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
api_base=self.endpoint,
api_version=self.api_version,
max_retries=self.MAX_RETRIES,
**kwargs,
)
return transcription
@@ -292,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> BaseModel:
async def transcribe_image(self, input, **kwargs) -> BaseModel:
"""
Generate a transcription of an image from a user query.
@@ -337,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
**kwargs,
)
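
One detail worth noting about the OpenAI adapter changes: the forwarded kwargs are merged into calls that already set some keywords explicitly, so a caller who passes one of those same names (for example `max_retries` or `response_model`) would hit a duplicate-keyword `TypeError`. A small stand-alone illustration of the merge and of that pitfall (the completion call is faked; parameter names are examples):

```python
import asyncio


async def fake_completion(**params):
    # Stand-in for the adapter's underlying structured-output call.
    return params


async def acreate_structured_output(text_input, system_prompt, response_model, **kwargs):
    return await fake_completion(
        model="example-model",        # fixed by the adapter
        response_model=response_model,
        max_retries=5,                # fixed by the adapter
        **kwargs,                     # caller-supplied extras, e.g. temperature=0.2
    )


params = asyncio.run(acreate_structured_output("text", "prompt", dict, temperature=0.2))
assert params["temperature"] == 0.2

# Passing a name the adapter already sets collides with the explicit keyword:
try:
    asyncio.run(acreate_structured_output("text", "prompt", dict, max_retries=10))
except TypeError as error:
    print(error)  # ... got multiple values for keyword argument 'max_retries'
```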

View file

@@ -97,6 +97,7 @@ async def extract_graph_from_data(
graph_model: Type[BaseModel],
config: Config = None,
custom_prompt: Optional[str] = None,
**kwargs,
) -> List[DocumentChunk]:
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@@ -111,7 +112,7 @@ async def extract_graph_from_data(
chunk_graphs = await asyncio.gather(
*[
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
for chunk in data_chunks
]
)
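
Since the checklist above mentions tests, here is a hedged sketch of the kind of regression test that could cover this change. It uses a local stub instead of patching cognee internals (a real test would monkeypatch `LLMGateway.acreate_structured_output`), so it only demonstrates the assertion style:

```python
# Propagation test sketch: assert the extra kwargs survive the chunk -> graph chain.
import asyncio

captured_kwargs = {}


async def fake_gateway(text_input, system_prompt, response_model, **kwargs):
    captured_kwargs.update(kwargs)
    return response_model


async def extract_content_graph(content, response_model, custom_prompt=None, **kwargs):
    # Stand-in for the helper shown earlier in this diff.
    system_prompt = custom_prompt or "default graph extraction prompt"
    return await fake_gateway(content, system_prompt, response_model, **kwargs)


def test_custom_llm_parameters_reach_the_gateway():
    asyncio.run(extract_content_graph("some chunk text", dict, temperature=0.1, seed=7))
    assert captured_kwargs == {"temperature": 0.1, "seed": 7}


test_custom_llm_parameters_reach_the_gateway()
```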