feat: Added custom prompt to cognify (#1278)

## Description
Adds an optional `custom_prompt` parameter to `cognify`. When supplied, the prompt replaces the default prompts used for entity extraction and knowledge graph generation. The parameter is threaded through the MCP `cognify` tool, the cognify API router, the default pipeline tasks, and both the BAML and litellm/instructor extraction paths.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Vasilije 2025-08-27 14:10:21 +02:00 committed by GitHub
commit 62afced9a5
9 changed files with 99 additions and 33 deletions


@@ -121,7 +121,9 @@ async def cognee_add_developer_rules(
@mcp.tool()
async def cognify(data: str, graph_model_file: str = None, graph_model_name: str = None) -> list:
async def cognify(
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
) -> list:
"""
Transform ingested data into a structured knowledge graph.
@@ -169,6 +171,12 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
Required if graph_model_file is specified.
Default is None, which uses the default KnowledgeGraph class.
custom_prompt : str, optional
Custom prompt string to use for entity extraction and graph generation.
If provided, this prompt will be used instead of the default prompts for
knowledge graph extraction. The prompt should guide the LLM on how to
extract entities and relationships from the text content.
Returns
-------
list
@@ -224,7 +232,10 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
"""
async def cognify_task(
data: str, graph_model_file: str = None, graph_model_name: str = None
data: str,
graph_model_file: str = None,
graph_model_name: str = None,
custom_prompt: str = None,
) -> str:
"""Build knowledge graph from the input text"""
# NOTE: MCP uses stdout to communicate, we must redirect all output
@@ -239,7 +250,7 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
await cognee.add(data)
try:
await cognee.cognify(graph_model=graph_model)
await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
logger.info("Cognify process finished.")
except Exception as e:
logger.error("Cognify process failed.")
@@ -250,6 +261,7 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
data=data,
graph_model_file=graph_model_file,
graph_model_name=graph_model_name,
custom_prompt=custom_prompt,
)
)

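For orientation, here is a minimal usage sketch of the top-level API that the MCP tool wraps, assuming a configured cognee installation (LLM credentials etc.); the text and prompt are illustrative, not part of this PR:

```python
# Minimal sketch, assuming a configured cognee installation; text and prompt are illustrative.
import asyncio
import cognee

async def main():
    await cognee.add("Cognee turns unstructured text into a knowledge graph.")
    # The new keyword replaces the default graph extraction prompts.
    await cognee.cognify(
        custom_prompt=(
            "Extract software components, the people who maintain them, and the "
            "relationships between them."
        )
    )

asyncio.run(main())
```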

@@ -39,6 +39,7 @@ async def cognify(
graph_db_config: dict = None,
run_in_background: bool = False,
incremental_loading: bool = True,
custom_prompt: Optional[str] = None,
):
"""
Transform ingested data into a structured knowledge graph.
@@ -101,6 +102,10 @@ async def cognify(
If False, waits for completion before returning.
Background mode recommended for large datasets (>100MB).
Use pipeline_run_id from return value to monitor progress.
custom_prompt: Optional custom prompt string to use for entity extraction and graph generation.
If provided, this prompt will be used instead of the default prompts for
knowledge graph extraction. The prompt should guide the LLM on how to
extract entities and relationships from the text content.
Returns:
Union[dict, list[PipelineRunInfo]]:
@@ -177,7 +182,9 @@ async def cognify(
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
"""
tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path)
tasks = await get_default_tasks(
user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
)
# get_pipeline_executor returns either a function that runs run_pipeline in the background or one that we must await until it completes
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
@@ -201,6 +208,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
chunker=TextChunker,
chunk_size: int = None,
ontology_file_path: Optional[str] = None,
custom_prompt: Optional[str] = None,
) -> list[Task]:
default_tasks = [
Task(classify_documents),
@@ -214,6 +222,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
custom_prompt=custom_prompt,
task_config={"batch_size": 10},
), # Generate knowledge graphs from the document chunks.
Task(

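A hedged sketch of a direct call to this `cognify()` entry point follows; the positional dataset and user arguments mirror how the API router later in this PR calls `cognee_cognify()`, and the dataset name and prompt text are illustrative:

```python
# Sketch only: the datasets/user arguments follow the router call shown further down
# in this PR; the dataset name and prompt text are illustrative.
async def example(user):
    return await cognify(
        ["research_papers"],
        user,
        run_in_background=False,
        custom_prompt=(
            "Extract entities and relationships, favouring technical concepts, "
            "methodologies, and the systems that implement them."
        ),
    )
```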

@@ -37,6 +37,9 @@ class CognifyPayloadDTO(InDTO):
datasets: Optional[List[str]] = Field(default=None)
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
run_in_background: Optional[bool] = Field(default=False)
custom_prompt: Optional[str] = Field(
default=None, description="Custom prompt for entity extraction and graph generation"
)
def get_cognify_router() -> APIRouter:
@@ -63,6 +66,7 @@ def get_cognify_router() -> APIRouter:
- **datasets** (Optional[List[str]]): List of dataset names to process. Dataset names are resolved to datasets owned by the authenticated user.
- **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
- **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
## Response
- **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@@ -76,7 +80,8 @@ def get_cognify_router() -> APIRouter:
```json
{
"datasets": ["research_papers", "documentation"],
"run_in_background": false
"run_in_background": false,
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
}
```
@@ -106,7 +111,10 @@ def get_cognify_router() -> APIRouter:
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
cognify_run = await cognee_cognify(
datasets, user, run_in_background=payload.run_in_background
datasets,
user,
run_in_background=payload.run_in_background,
custom_prompt=payload.custom_prompt,
)
# If any cognify run errored return JSONResponse with proper error status code

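A client-side sketch of the route above; the base URL and auth header are assumptions about a typical deployment rather than something fixed by this PR, while the JSON body follows `CognifyPayloadDTO`:

```python
# Sketch only: POSTing to the cognify route with a custom prompt. The base URL and
# auth header are assumptions; the body fields match CognifyPayloadDTO above.
import requests

response = requests.post(
    "http://localhost:8000/api/v1/cognify",  # assumed mount point
    headers={"Authorization": "Bearer <token>"},
    json={
        "datasets": ["research_papers", "documentation"],
        "run_in_background": False,
        "custom_prompt": "Extract entities focusing on technical concepts and their relationships.",
    },
)
print(response.status_code, response.json())
```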

@@ -49,6 +49,10 @@ DEFAULT_TOOLS = [
"type": "string",
"description": "Path to a custom ontology file",
},
"custom_prompt": {
"type": "string",
"description": "Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts.",
},
},
"required": ["text"],
},


@@ -88,11 +88,16 @@ async def handle_cognify(arguments: Dict[str, Any], user) -> str:
"""Handle cognify function call"""
text = arguments.get("text")
ontology_file_path = arguments.get("ontology_file_path")
custom_prompt = arguments.get("custom_prompt")
if text:
await add(data=text, user=user)
await cognify(user=user, ontology_file_path=ontology_file_path if ontology_file_path else None)
await cognify(
user=user,
ontology_file_path=ontology_file_path if ontology_file_path else None,
custom_prompt=custom_prompt,
)
return (
"Text successfully converted into knowledge graph."

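A small sketch of how this handler might be exercised with the new argument; `current_user` stands in for whatever user object the surrounding server resolves, and the text and prompt are illustrative:

```python
# Sketch only: invoking handle_cognify() with a custom prompt. current_user is a
# placeholder for the user object resolved by the surrounding server.
async def example(current_user):
    return await handle_cognify(
        {
            "text": "GraphRAG combines retrieval-augmented generation with knowledge graphs.",
            "custom_prompt": "Extract technologies, techniques, and how they relate to each other.",
        },
        current_user,
    )
```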

@@ -1,6 +1,5 @@
from typing import Type
from typing import Type, Optional, Coroutine
from pydantic import BaseModel
from typing import Coroutine
from cognee.infrastructure.llm import get_llm_config
@@ -79,7 +78,10 @@ class LLMGateway:
@staticmethod
def extract_content_graph(
content: str, response_model: Type[BaseModel], mode: str = "simple"
content: str,
response_model: Type[BaseModel],
mode: str = "simple",
custom_prompt: Optional[str] = None,
) -> Coroutine:
llm_config = get_llm_config()
if llm_config.structured_output_framework.upper() == "BAML":
@@ -87,13 +89,20 @@ class LLMGateway:
extract_content_graph,
)
return extract_content_graph(content=content, response_model=response_model, mode=mode)
return extract_content_graph(
content=content,
response_model=response_model,
mode=mode,
custom_prompt=custom_prompt,
)
else:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
extract_content_graph,
)
return extract_content_graph(content=content, response_model=response_model)
return extract_content_graph(
content=content, response_model=response_model, custom_prompt=custom_prompt
)
@staticmethod
def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:

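Lower in the stack, the gateway can also be exercised directly; a sketch assuming cognee's built-in `KnowledgeGraph` pydantic model as the response model (its import path is an assumption, and any pydantic graph model would work):

```python
# Sketch only: calling the gateway directly with a custom prompt, inside an async
# context. The KnowledgeGraph import path is an assumption; the LLMGateway path is
# taken from this diff.
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.shared.data_models import KnowledgeGraph  # assumed model path

async def example():
    graph = await LLMGateway.extract_content_graph(
        content="Alice works at Acme Corp and reports to Bob.",
        response_model=KnowledgeGraph,
        custom_prompt="Extract people, organisations, and reporting relationships.",
    )
    return graph
```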

@@ -1,4 +1,4 @@
from typing import Type
from typing import Type, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm.config import get_llm_config
from cognee.shared.logging_utils import get_logger, setup_logging
@@ -6,7 +6,10 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.asyn
async def extract_content_graph(
content: str, response_model: Type[BaseModel], mode: str = "simple"
content: str,
response_model: Type[BaseModel],
mode: str = "simple",
custom_prompt: Optional[str] = None,
):
config = get_llm_config()
setup_logging()
@@ -26,8 +29,16 @@ async def extract_content_graph(
# return graph
# else:
graph = await b.ExtractContentGraphGeneric(
content, mode=mode, baml_options={"client_registry": config.baml_registry}
)
if custom_prompt:
graph = await b.ExtractContentGraphGeneric(
content,
mode="custom",
custom_prompt_content=custom_prompt,
baml_options={"client_registry": config.baml_registry},
)
else:
graph = await b.ExtractContentGraphGeneric(
content, mode=mode, baml_options={"client_registry": config.baml_registry}
)
return graph


@@ -1,5 +1,5 @@
import os
from typing import Type
from typing import Type, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway
@@ -8,21 +8,25 @@ from cognee.infrastructure.llm.config import (
)
async def extract_content_graph(content: str, response_model: Type[BaseModel]):
llm_config = get_llm_config()
prompt_path = llm_config.graph_prompt_path
# Check if the prompt path is an absolute path or just a filename
if os.path.isabs(prompt_path):
# directory containing the file
base_directory = os.path.dirname(prompt_path)
# just the filename itself
prompt_path = os.path.basename(prompt_path)
async def extract_content_graph(
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
):
if custom_prompt:
system_prompt = custom_prompt
else:
base_directory = None
llm_config = get_llm_config()
prompt_path = llm_config.graph_prompt_path
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
# Check if the prompt path is an absolute path or just a filename
if os.path.isabs(prompt_path):
# directory containing the file
base_directory = os.path.dirname(prompt_path)
# just the filename itself
prompt_path = os.path.basename(prompt_path)
else:
base_directory = None
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model

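Since this path passes the custom prompt straight through as the system prompt, it can be any instruction text. An illustrative sketch, using the extraction import path shown in LLMGateway above (the `KnowledgeGraph` import path is an assumption and the content is made up):

```python
# Illustrative sketch: a custom prompt replacing the default graph extraction prompt.
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
    extract_content_graph,
)
from cognee.shared.data_models import KnowledgeGraph  # assumed model path

CUSTOM_GRAPH_PROMPT = """You are a knowledge graph extraction assistant.
From the provided text, identify the key entities (people, organisations, systems,
concepts) and the relationships between them, and return them in the structure
required by the response model."""

async def example():
    return await extract_content_graph(
        content="Kafka streams events from the ingestion service to the analytics store.",
        response_model=KnowledgeGraph,
        custom_prompt=CUSTOM_GRAPH_PROMPT,
    )
```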

@@ -1,5 +1,5 @@
import asyncio
from typing import Type, List
from typing import Type, List, Optional
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
@@ -71,6 +71,7 @@ async def extract_graph_from_data(
data_chunks: List[DocumentChunk],
graph_model: Type[BaseModel],
ontology_adapter: OntologyResolver = None,
custom_prompt: Optional[str] = None,
) -> List[DocumentChunk]:
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@@ -84,7 +85,10 @@ async def extract_graph_from_data(
raise InvalidGraphModelError(graph_model)
chunk_graphs = await asyncio.gather(
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
*[
LLMGateway.extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
for chunk in data_chunks
]
)
# Note: Filter edges with missing source or target nodes