From d69669b527d6f5f9009d7231a58c682636a755a5 Mon Sep 17 00:00:00 2001
From: vasilije
Date: Fri, 22 Aug 2025 12:37:51 +0200
Subject: [PATCH] added ability to send custom prompts to cognify

---
 cognee-mcp/src/server.py                      | 18 +++++++++--
 cognee/api/v1/cognify/cognify.py              | 11 ++++++-
 .../v1/cognify/routers/get_cognify_router.py  | 12 +++++--
 cognee/api/v1/responses/default_tools.py      |  4 +++
 cognee/api/v1/responses/dispatch_function.py  |  7 +++-
 cognee/infrastructure/llm/LLMGateway.py       | 19 ++++++++---
 .../knowledge_graph/extract_content_graph.py  | 21 +++++++++---
 .../knowledge_graph/extract_content_graph.py  | 32 +++++++++++--------
 cognee/tasks/graph/extract_graph_from_data.py |  8 +++--
 9 files changed, 99 insertions(+), 33 deletions(-)

diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py
index 3e65a5eb7..9e55b9707 100755
--- a/cognee-mcp/src/server.py
+++ b/cognee-mcp/src/server.py
@@ -121,7 +121,9 @@ async def cognee_add_developer_rules(
 
 
 @mcp.tool()
-async def cognify(data: str, graph_model_file: str = None, graph_model_name: str = None) -> list:
+async def cognify(
+    data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
+) -> list:
     """
     Transform ingested data into a structured knowledge graph.
 
@@ -169,6 +171,12 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
         Required if graph_model_file is specified.
         Default is None, which uses the default KnowledgeGraph class.
 
+    custom_prompt : str, optional
+        Custom prompt string to use for entity extraction and graph generation.
+        If provided, this prompt will be used instead of the default prompts for
+        knowledge graph extraction. The prompt should guide the LLM on how to
+        extract entities and relationships from the text content.
+
     Returns
     -------
     list
@@ -224,7 +232,10 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
     """
 
     async def cognify_task(
-        data: str, graph_model_file: str = None, graph_model_name: str = None
+        data: str,
+        graph_model_file: str = None,
+        graph_model_name: str = None,
+        custom_prompt: str = None,
     ) -> str:
         """Build knowledge graph from the input text"""
         # NOTE: MCP uses stdout to communicate, we must redirect all output
@@ -239,7 +250,7 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
             await cognee.add(data)
 
             try:
-                await cognee.cognify(graph_model=graph_model)
+                await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
                 logger.info("Cognify process finished.")
             except Exception as e:
                 logger.error("Cognify process failed.")
@@ -250,6 +261,7 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
             data=data,
             graph_model_file=graph_model_file,
             graph_model_name=graph_model_name,
+            custom_prompt=custom_prompt,
         )
     )
 
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 90c3c469e..c499de2f3 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -40,6 +40,7 @@ async def cognify(
     graph_db_config: dict = None,
     run_in_background: bool = False,
     incremental_loading: bool = True,
+    custom_prompt: Optional[str] = None,
 ):
     """
     Transform ingested data into a structured knowledge graph.
@@ -102,6 +103,10 @@ async def cognify(
             If False, waits for completion before returning.
             Background mode recommended for large datasets (>100MB).
             Use pipeline_run_id from return value to monitor progress.
+        custom_prompt: Optional custom prompt string to use for entity extraction and graph generation.
+            If provided, this prompt will be used instead of the default prompts for
+            knowledge graph extraction. The prompt should guide the LLM on how to
+            extract entities and relationships from the text content.
 
     Returns:
         Union[dict, list[PipelineRunInfo]]:
@@ -178,7 +183,9 @@ async def cognify(
         - LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
         - LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
     """
-    tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path)
+    tasks = await get_default_tasks(
+        user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
+    )
 
     if run_in_background:
         return await run_cognify_as_background_process(
@@ -295,6 +302,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
     chunker=TextChunker,
     chunk_size: int = None,
     ontology_file_path: Optional[str] = None,
+    custom_prompt: Optional[str] = None,
 ) -> list[Task]:
     default_tasks = [
         Task(classify_documents),
@@ -308,6 +316,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
             extract_graph_from_data,
             graph_model=graph_model,
             ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
+            custom_prompt=custom_prompt,
             task_config={"batch_size": 10},
         ),  # Generate knowledge graphs from the document chunks.
         Task(
diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py
index b63238966..cf19eaa43 100644
--- a/cognee/api/v1/cognify/routers/get_cognify_router.py
+++ b/cognee/api/v1/cognify/routers/get_cognify_router.py
@@ -37,6 +37,9 @@ class CognifyPayloadDTO(InDTO):
     datasets: Optional[List[str]] = Field(default=None)
     dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
     run_in_background: Optional[bool] = Field(default=False)
+    custom_prompt: Optional[str] = Field(
+        default=None, description="Custom prompt for entity extraction and graph generation"
+    )
 
 
 def get_cognify_router() -> APIRouter:
@@ -63,6 +66,7 @@ def get_cognify_router() -> APIRouter:
        - **datasets** (Optional[List[str]]): List of dataset names to process. Dataset names are resolved to datasets owned by the authenticated user.
        - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
        - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
+       - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
 
        ## Response
        - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@@ -76,7 +80,8 @@ def get_cognify_router() -> APIRouter:
        ```json
        {
            "datasets": ["research_papers", "documentation"],
-           "run_in_background": false
+           "run_in_background": false,
+           "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
        }
        ```
 
@@ -106,7 +111,10 @@ def get_cognify_router() -> APIRouter:
            datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
 
            cognify_run = await cognee_cognify(
-               datasets, user, run_in_background=payload.run_in_background
+               datasets,
+               user,
+               run_in_background=payload.run_in_background,
+               custom_prompt=payload.custom_prompt,
            )
 
            # If any cognify run errored return JSONResponse with proper error status code
diff --git a/cognee/api/v1/responses/default_tools.py b/cognee/api/v1/responses/default_tools.py
index abf222d59..295d132f1 100644
--- a/cognee/api/v1/responses/default_tools.py
+++ b/cognee/api/v1/responses/default_tools.py
@@ -49,6 +49,10 @@ DEFAULT_TOOLS = [
                    "type": "string",
                    "description": "Path to a custom ontology file",
                },
+               "custom_prompt": {
+                   "type": "string",
+                   "description": "Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts.",
+               },
            },
            "required": ["text"],
        },
diff --git a/cognee/api/v1/responses/dispatch_function.py b/cognee/api/v1/responses/dispatch_function.py
index 564f3b889..85388b564 100644
--- a/cognee/api/v1/responses/dispatch_function.py
+++ b/cognee/api/v1/responses/dispatch_function.py
@@ -88,11 +88,16 @@ async def handle_cognify(arguments: Dict[str, Any], user) -> str:
     """Handle cognify function call"""
     text = arguments.get("text")
     ontology_file_path = arguments.get("ontology_file_path")
+    custom_prompt = arguments.get("custom_prompt")
 
     if text:
         await add(data=text, user=user)
 
-    await cognify(user=user, ontology_file_path=ontology_file_path if ontology_file_path else None)
+    await cognify(
+        user=user,
+        ontology_file_path=ontology_file_path if ontology_file_path else None,
+        custom_prompt=custom_prompt,
+    )
 
     return (
         "Text successfully converted into knowledge graph."
diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py
index a88cfb85d..c1bc7ac79 100644
--- a/cognee/infrastructure/llm/LLMGateway.py
+++ b/cognee/infrastructure/llm/LLMGateway.py
@@ -1,6 +1,5 @@
-from typing import Type
+from typing import Type, Optional, Coroutine
 from pydantic import BaseModel
-from typing import Coroutine
 
 from cognee.infrastructure.llm import get_llm_config
 
@@ -79,7 +78,10 @@ class LLMGateway:
 
     @staticmethod
     def extract_content_graph(
-        content: str, response_model: Type[BaseModel], mode: str = "simple"
+        content: str,
+        response_model: Type[BaseModel],
+        mode: str = "simple",
+        custom_prompt: Optional[str] = None,
     ) -> Coroutine:
         llm_config = get_llm_config()
         if llm_config.structured_output_framework.upper() == "BAML":
@@ -87,13 +89,20 @@ class LLMGateway:
                 extract_content_graph,
             )
 
-            return extract_content_graph(content=content, response_model=response_model, mode=mode)
+            return extract_content_graph(
+                content=content,
+                response_model=response_model,
+                mode=mode,
+                custom_prompt=custom_prompt,
+            )
         else:
             from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
                 extract_content_graph,
             )
 
-            return extract_content_graph(content=content, response_model=response_model)
+            return extract_content_graph(
+                content=content, response_model=response_model, custom_prompt=custom_prompt
+            )
 
     @staticmethod
     def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:
diff --git a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py
index d98112434..abff07e09 100644
--- a/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py
+++ b/cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.shared.logging_utils import get_logger, setup_logging
@@ -6,7 +6,10 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.asyn
 
 
 async def extract_content_graph(
-    content: str, response_model: Type[BaseModel], mode: str = "simple"
+    content: str,
+    response_model: Type[BaseModel],
+    mode: str = "simple",
+    custom_prompt: Optional[str] = None,
 ):
     config = get_llm_config()
     setup_logging()
@@ -26,8 +29,16 @@ async def extract_content_graph(
     # return graph
     # else:
 
-    graph = await b.ExtractContentGraphGeneric(
-        content, mode=mode, baml_options={"client_registry": config.baml_registry}
-    )
+    if custom_prompt:
+        graph = await b.ExtractContentGraphGeneric(
+            content,
+            mode="custom",
+            custom_prompt_content=custom_prompt,
+            baml_options={"client_registry": config.baml_registry},
+        )
+    else:
+        graph = await b.ExtractContentGraphGeneric(
+            content, mode=mode, baml_options={"client_registry": config.baml_registry}
+        )
 
     return graph
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py
index 9b945d167..e5fc61634 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py
@@ -1,5 +1,5 @@
 import os
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
@@ -8,21 +8,25 @@ from cognee.infrastructure.llm.config import (
 )
 
 
-async def extract_content_graph(content: str, response_model: Type[BaseModel]):
-    llm_config = get_llm_config()
-
-    prompt_path = llm_config.graph_prompt_path
-
-    # Check if the prompt path is an absolute path or just a filename
-    if os.path.isabs(prompt_path):
-        # directory containing the file
-        base_directory = os.path.dirname(prompt_path)
-        # just the filename itself
-        prompt_path = os.path.basename(prompt_path)
+async def extract_content_graph(
+    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+):
+    if custom_prompt:
+        system_prompt = custom_prompt
     else:
-        base_directory = None
+        llm_config = get_llm_config()
+        prompt_path = llm_config.graph_prompt_path
 
-    system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
+        # Check if the prompt path is an absolute path or just a filename
+        if os.path.isabs(prompt_path):
+            # directory containing the file
+            base_directory = os.path.dirname(prompt_path)
+            # just the filename itself
+            prompt_path = os.path.basename(prompt_path)
+        else:
+            base_directory = None
+
+        system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
 
     content_graph = await LLMGateway.acreate_structured_output(
         content, system_prompt, response_model
diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py
index 019e9e4a1..d81516206 100644
--- a/cognee/tasks/graph/extract_graph_from_data.py
+++ b/cognee/tasks/graph/extract_graph_from_data.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Type, List
+from typing import Type, List, Optional
 from pydantic import BaseModel
 
 from cognee.infrastructure.databases.graph import get_graph_engine
@@ -71,6 +71,7 @@ async def extract_graph_from_data(
     data_chunks: List[DocumentChunk],
     graph_model: Type[BaseModel],
     ontology_adapter: OntologyResolver = None,
+    custom_prompt: Optional[str] = None,
 ) -> List[DocumentChunk]:
     """
     Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@@ -84,7 +85,10 @@ async def extract_graph_from_data(
         raise InvalidGraphModelError(graph_model)
 
     chunk_graphs = await asyncio.gather(
-        *[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
+        *[
+            LLMGateway.extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
+            for chunk in data_chunks
+        ]
    )
 
    # Note: Filter edges with missing source or target nodes
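
Not part of the patch itself: a minimal usage sketch of the new `custom_prompt` parameter, assuming the patch is applied and `cognee` is importable as in `cognee-mcp/src/server.py`. The sample input text is illustrative; the prompt string is taken from the request example in `get_cognify_router.py` above.

```python
import asyncio

import cognee


async def main():
    # Ingest some sample text first (mirrors the add -> cognify flow in cognee-mcp/src/server.py).
    await cognee.add("Transformers rely on self-attention to relate tokens across a sequence.")

    # Prompt text borrowed from the request example in get_cognify_router.py.
    prompt = (
        "Extract entities focusing on technical concepts and their relationships. "
        "Identify key technologies, methodologies, and their interconnections."
    )

    # With custom_prompt set, extract_content_graph uses this prompt instead of the
    # default graph prompts (BAML "custom" mode, or the system prompt in the
    # litellm_instructor path).
    await cognee.cognify(custom_prompt=prompt)


if __name__ == "__main__":
    asyncio.run(main())
```

Over HTTP, the same prompt goes in the new `custom_prompt` field of the cognify payload (`CognifyPayloadDTO`); the MCP `cognify` tool and the `cognify` function tool in `default_tools.py` accept it as well.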