feat: add kwargs to cognify and related tasks

This commit is contained in:
Andrej Milicevic 2025-11-27 17:05:37 +01:00
parent c649900042
commit aa8afefe8a
4 changed files with 10 additions and 5 deletions

View file

@ -53,6 +53,7 @@ async def cognify(
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
temporal_cognify: bool = False, temporal_cognify: bool = False,
data_per_batch: int = 20, data_per_batch: int = 20,
**kwargs
): ):
""" """
Transform ingested data into a structured knowledge graph. Transform ingested data into a structured knowledge graph.
@ -224,6 +225,7 @@ async def cognify(
config=config, config=config,
custom_prompt=custom_prompt, custom_prompt=custom_prompt,
chunks_per_batch=chunks_per_batch, chunks_per_batch=chunks_per_batch,
**kwargs,
) )
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config: Config = None, config: Config = None,
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
chunks_per_batch: int = 100, chunks_per_batch: int = 100,
**kwargs,
) -> list[Task]: ) -> list[Task]:
if config is None: if config is None:
ontology_config = get_ontology_env_config() ontology_config = get_ontology_env_config()
@ -286,6 +289,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config=config, config=config,
custom_prompt=custom_prompt, custom_prompt=custom_prompt,
task_config={"batch_size": chunks_per_batch}, task_config={"batch_size": chunks_per_batch},
**kwargs,
), # Generate knowledge graphs from the document chunks. ), # Generate knowledge graphs from the document chunks.
Task( Task(
summarize_text, summarize_text,

View file

@ -11,7 +11,7 @@ class LLMGateway:
@staticmethod @staticmethod
def acreate_structured_output( def acreate_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel] text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> Coroutine: ) -> Coroutine:
llm_config = get_llm_config() llm_config = get_llm_config()
if llm_config.structured_output_framework.upper() == "BAML": if llm_config.structured_output_framework.upper() == "BAML":
@ -31,7 +31,7 @@ class LLMGateway:
llm_client = get_llm_client() llm_client = get_llm_client()
return llm_client.acreate_structured_output( return llm_client.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model text_input=text_input, system_prompt=system_prompt, response_model=response_model, **kwargs
) )
@staticmethod @staticmethod

View file

@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
async def extract_content_graph( async def extract_content_graph(
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
): ):
if custom_prompt: if custom_prompt:
system_prompt = custom_prompt system_prompt = custom_prompt
@ -30,7 +30,7 @@ async def extract_content_graph(
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory) system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output( content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model content, system_prompt, response_model, **kwargs
) )
return content_graph return content_graph

View file

@ -99,6 +99,7 @@ async def extract_graph_from_data(
graph_model: Type[BaseModel], graph_model: Type[BaseModel],
config: Config = None, config: Config = None,
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
**kwargs,
) -> List[DocumentChunk]: ) -> List[DocumentChunk]:
""" """
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model. Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@ -113,7 +114,7 @@ async def extract_graph_from_data(
chunk_graphs = await asyncio.gather( chunk_graphs = await asyncio.gather(
*[ *[
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt) extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
for chunk in data_chunks for chunk in data_chunks
] ]
) )