feat: Added custom prompt to cognify (#1278)
## Description

Adds an optional `custom_prompt` parameter to `cognify`. The prompt is threaded through the MCP server tool, the core Python API, the REST router, the direct LLM tool definitions, `LLMGateway.extract_content_graph`, and both structured-output backends (BAML and litellm/instructor). When provided, the custom prompt is used instead of the default prompts for knowledge graph extraction.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Commit 62afced9a5. 9 changed files with 99 additions and 33 deletions.
In the MCP server, the `cognify` tool and its inner `cognify_task` gain a `custom_prompt` argument that is passed through to `cognee.cognify`:

```diff
@@ -121,7 +121,9 @@ async def cognee_add_developer_rules(
 @mcp.tool()
-async def cognify(data: str, graph_model_file: str = None, graph_model_name: str = None) -> list:
+async def cognify(
+    data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
+) -> list:
     """
     Transform ingested data into a structured knowledge graph.

@@ -169,6 +171,12 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
         Required if graph_model_file is specified.
         Default is None, which uses the default KnowledgeGraph class.

+    custom_prompt : str, optional
+        Custom prompt string to use for entity extraction and graph generation.
+        If provided, this prompt will be used instead of the default prompts for
+        knowledge graph extraction. The prompt should guide the LLM on how to
+        extract entities and relationships from the text content.
+
     Returns
     -------
     list
@@ -224,7 +232,10 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
     """

     async def cognify_task(
-        data: str, graph_model_file: str = None, graph_model_name: str = None
+        data: str,
+        graph_model_file: str = None,
+        graph_model_name: str = None,
+        custom_prompt: str = None,
     ) -> str:
         """Build knowledge graph from the input text"""
         # NOTE: MCP uses stdout to communicate, we must redirect all output
@@ -239,7 +250,7 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
         await cognee.add(data)

         try:
-            await cognee.cognify(graph_model=graph_model)
+            await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
             logger.info("Cognify process finished.")
         except Exception as e:
             logger.error("Cognify process failed.")
@@ -250,6 +261,7 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
             data=data,
             graph_model_file=graph_model_file,
             graph_model_name=graph_model_name,
+            custom_prompt=custom_prompt,
         )
     )
```
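For context, a minimal sketch of invoking the updated tool from an MCP client using the official Python SDK. The tool name and argument names come from the change above; the server launch command and sample text are assumptions for illustration.

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# How the cognee MCP server is launched is deployment-specific; this command is assumed.
server_params = StdioServerParameters(command="python", args=["src/server.py"])


async def main():
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # "data" and "custom_prompt" match the tool signature in the diff.
            result = await session.call_tool(
                "cognify",
                arguments={
                    "data": "Cognee links GraphRAG pipelines to graph and vector stores.",
                    "custom_prompt": "Extract entities focusing on technical concepts and their relationships.",
                },
            )
            print(result)


asyncio.run(main())
```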
In the core `cognify` API, `custom_prompt` is added to the signature, documented, and threaded through `get_default_tasks` into the graph extraction task:

```diff
@@ -39,6 +39,7 @@ async def cognify(
     graph_db_config: dict = None,
     run_in_background: bool = False,
     incremental_loading: bool = True,
+    custom_prompt: Optional[str] = None,
 ):
     """
     Transform ingested data into a structured knowledge graph.
@@ -101,6 +102,10 @@ async def cognify(
             If False, waits for completion before returning.
             Background mode recommended for large datasets (>100MB).
             Use pipeline_run_id from return value to monitor progress.
+        custom_prompt: Optional custom prompt string to use for entity extraction and graph generation.
+            If provided, this prompt will be used instead of the default prompts for
+            knowledge graph extraction. The prompt should guide the LLM on how to
+            extract entities and relationships from the text content.

     Returns:
         Union[dict, list[PipelineRunInfo]]:
@@ -177,7 +182,9 @@ async def cognify(
         - LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
         - LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
     """
-    tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path)
+    tasks = await get_default_tasks(
+        user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
+    )

     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
@@ -201,6 +208,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
     chunker=TextChunker,
     chunk_size: int = None,
     ontology_file_path: Optional[str] = None,
+    custom_prompt: Optional[str] = None,
 ) -> list[Task]:
     default_tasks = [
         Task(classify_documents),
@@ -214,6 +222,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
             extract_graph_from_data,
             graph_model=graph_model,
             ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
+            custom_prompt=custom_prompt,
             task_config={"batch_size": 10},
         ),  # Generate knowledge graphs from the document chunks.
         Task(
```
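A minimal sketch of the Python API with the new parameter. The `cognee.add` and `cognee.cognify` calls mirror the MCP server code above; the sample text and prompt are illustrative.

```python
import asyncio

import cognee


async def main():
    # Ingest some text, then build the graph with a caller-supplied prompt
    # instead of the default graph extraction prompts.
    await cognee.add("Cognee links GraphRAG pipelines to graph and vector stores.")
    await cognee.cognify(
        custom_prompt=(
            "Extract entities focusing on technical concepts and their relationships. "
            "Identify key technologies, methodologies, and their interconnections."
        )
    )


asyncio.run(main())
```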
The cognify router adds `custom_prompt` to `CognifyPayloadDTO`, documents it in the route docstring, and forwards it to `cognee_cognify`:

````diff
@@ -37,6 +37,9 @@ class CognifyPayloadDTO(InDTO):
     datasets: Optional[List[str]] = Field(default=None)
     dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
     run_in_background: Optional[bool] = Field(default=False)
+    custom_prompt: Optional[str] = Field(
+        default=None, description="Custom prompt for entity extraction and graph generation"
+    )


 def get_cognify_router() -> APIRouter:
@@ -63,6 +66,7 @@ def get_cognify_router() -> APIRouter:
     - **datasets** (Optional[List[str]]): List of dataset names to process. Dataset names are resolved to datasets owned by the authenticated user.
     - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
     - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
+    - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.

     ## Response
     - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@@ -76,7 +80,8 @@ def get_cognify_router() -> APIRouter:
     ```json
     {
         "datasets": ["research_papers", "documentation"],
-        "run_in_background": false
+        "run_in_background": false,
+        "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
     }
     ```

@@ -106,7 +111,10 @@ def get_cognify_router() -> APIRouter:
         datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets

         cognify_run = await cognee_cognify(
-            datasets, user, run_in_background=payload.run_in_background
+            datasets,
+            user,
+            run_in_background=payload.run_in_background,
+            custom_prompt=payload.custom_prompt,
         )

         # If any cognify run errored return JSONResponse with proper error status code
````
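A sketch of calling the endpoint over HTTP with the documented payload. The payload shape comes from the docstring example above; the endpoint path and the authentication header are assumptions and should be adjusted to how the router is mounted and secured in a given deployment.

```python
import httpx

payload = {
    "datasets": ["research_papers", "documentation"],
    "run_in_background": False,
    "custom_prompt": (
        "Extract entities focusing on technical concepts and their relationships. "
        "Identify key technologies, methodologies, and their interconnections."
    ),
}

response = httpx.post(
    "http://localhost:8000/api/v1/cognify",  # assumed mount point
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
    timeout=None,  # blocking runs can take a while
)
response.raise_for_status()
print(response.json())
```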
The `cognify` entry in `DEFAULT_TOOLS` exposes the new argument in its parameter schema:

```diff
@@ -49,6 +49,10 @@ DEFAULT_TOOLS = [
                     "type": "string",
                     "description": "Path to a custom ontology file",
                 },
+                "custom_prompt": {
+                    "type": "string",
+                    "description": "Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts.",
+                },
             },
             "required": ["text"],
         },
```
`handle_cognify` reads `custom_prompt` from the tool-call arguments and passes it on to `cognify`:

```diff
@@ -88,11 +88,16 @@ async def handle_cognify(arguments: Dict[str, Any], user) -> str:
     """Handle cognify function call"""
     text = arguments.get("text")
     ontology_file_path = arguments.get("ontology_file_path")
+    custom_prompt = arguments.get("custom_prompt")

     if text:
         await add(data=text, user=user)

-    await cognify(user=user, ontology_file_path=ontology_file_path if ontology_file_path else None)
+    await cognify(
+        user=user,
+        ontology_file_path=ontology_file_path if ontology_file_path else None,
+        custom_prompt=custom_prompt,
+    )

     return (
         "Text successfully converted into knowledge graph."
```
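A minimal sketch of dispatching the tool call with the new argument. Only the argument names and the `handle_cognify(arguments, user)` signature come from the diff; the import paths and the `get_default_user` helper are assumptions for illustration.

```python
import asyncio

from cognee.api.v1.responses.dispatch_function import handle_cognify  # assumed path
from cognee.modules.users.methods import get_default_user  # assumed helper


async def main():
    user = await get_default_user()
    arguments = {
        "text": "Cognee turns documents into a knowledge graph.",
        "custom_prompt": "Extract only technologies and the relationships between them.",
    }
    print(await handle_cognify(arguments, user))


asyncio.run(main())
```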
`LLMGateway.extract_content_graph` accepts `custom_prompt` and forwards it to whichever structured-output backend is configured:

```diff
@@ -1,6 +1,5 @@
-from typing import Type
+from typing import Type, Optional, Coroutine
 from pydantic import BaseModel
-from typing import Coroutine
 from cognee.infrastructure.llm import get_llm_config


@@ -79,7 +78,10 @@ class LLMGateway:

     @staticmethod
     def extract_content_graph(
-        content: str, response_model: Type[BaseModel], mode: str = "simple"
+        content: str,
+        response_model: Type[BaseModel],
+        mode: str = "simple",
+        custom_prompt: Optional[str] = None,
     ) -> Coroutine:
         llm_config = get_llm_config()
         if llm_config.structured_output_framework.upper() == "BAML":
@@ -87,13 +89,20 @@ class LLMGateway:
                 extract_content_graph,
             )

-            return extract_content_graph(content=content, response_model=response_model, mode=mode)
+            return extract_content_graph(
+                content=content,
+                response_model=response_model,
+                mode=mode,
+                custom_prompt=custom_prompt,
+            )
         else:
             from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
                 extract_content_graph,
             )

-            return extract_content_graph(content=content, response_model=response_model)
+            return extract_content_graph(
+                content=content, response_model=response_model, custom_prompt=custom_prompt
+            )

     @staticmethod
     def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:
```
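A sketch of calling the gateway directly with a custom prompt. The `LLMGateway` import path appears in the instructor module below; `MiniGraph` is a hypothetical response model used only for the example (real callers pass `KnowledgeGraph`-style models).

```python
import asyncio

from pydantic import BaseModel

from cognee.infrastructure.llm.LLMGateway import LLMGateway


class MiniGraph(BaseModel):
    # Hypothetical response model just for this sketch.
    nodes: list[str]
    edges: list[str]


async def main():
    # extract_content_graph returns a coroutine (note the -> Coroutine hint),
    # so the caller awaits it; custom_prompt is routed to the configured backend.
    graph = await LLMGateway.extract_content_graph(
        "Cognee turns documents into a knowledge graph.",
        MiniGraph,
        custom_prompt="List only technologies and the relationships between them.",
    )
    print(graph)


asyncio.run(main())
```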
The BAML extraction path switches to a dedicated `custom` mode with `custom_prompt_content` when a custom prompt is supplied:

```diff
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.shared.logging_utils import get_logger, setup_logging
@@ -6,7 +6,10 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.asyn


 async def extract_content_graph(
-    content: str, response_model: Type[BaseModel], mode: str = "simple"
+    content: str,
+    response_model: Type[BaseModel],
+    mode: str = "simple",
+    custom_prompt: Optional[str] = None,
 ):
     config = get_llm_config()
     setup_logging()
@@ -26,8 +29,16 @@ async def extract_content_graph(
     #     return graph

     # else:
-    graph = await b.ExtractContentGraphGeneric(
-        content, mode=mode, baml_options={"client_registry": config.baml_registry}
-    )
+    if custom_prompt:
+        graph = await b.ExtractContentGraphGeneric(
+            content,
+            mode="custom",
+            custom_prompt_content=custom_prompt,
+            baml_options={"client_registry": config.baml_registry},
+        )
+    else:
+        graph = await b.ExtractContentGraphGeneric(
+            content, mode=mode, baml_options={"client_registry": config.baml_registry}
+        )

     return graph
```
The litellm/instructor extraction path uses the custom prompt verbatim as the system prompt and only renders the configured graph prompt when none is given:

```diff
@@ -1,5 +1,5 @@
 import os
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel

 from cognee.infrastructure.llm.LLMGateway import LLMGateway
@@ -8,21 +8,25 @@ from cognee.infrastructure.llm.config import (
 )


-async def extract_content_graph(content: str, response_model: Type[BaseModel]):
-    llm_config = get_llm_config()
-
-    prompt_path = llm_config.graph_prompt_path
-
-    # Check if the prompt path is an absolute path or just a filename
-    if os.path.isabs(prompt_path):
-        # directory containing the file
-        base_directory = os.path.dirname(prompt_path)
-        # just the filename itself
-        prompt_path = os.path.basename(prompt_path)
-    else:
-        base_directory = None
-
-    system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
+async def extract_content_graph(
+    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+):
+    if custom_prompt:
+        system_prompt = custom_prompt
+    else:
+        llm_config = get_llm_config()
+        prompt_path = llm_config.graph_prompt_path
+
+        # Check if the prompt path is an absolute path or just a filename
+        if os.path.isabs(prompt_path):
+            # directory containing the file
+            base_directory = os.path.dirname(prompt_path)
+            # just the filename itself
+            prompt_path = os.path.basename(prompt_path)
+        else:
+            base_directory = None
+
+        system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)

     content_graph = await LLMGateway.acreate_structured_output(
         content, system_prompt, response_model
```
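For reference, a sketch of calling this extractor directly to contrast the two prompt paths. The extractor's import path and signature come from the diff; the `KnowledgeGraph` import location is an assumption.

```python
import asyncio

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
    extract_content_graph,
)
from cognee.shared.data_models import KnowledgeGraph  # assumed model location


async def main():
    text = "Cognee turns documents into a knowledge graph."
    # Default behavior: the prompt at llm_config.graph_prompt_path is rendered.
    default_graph = await extract_content_graph(text, KnowledgeGraph)
    # With custom_prompt set, the string is used verbatim as the system prompt.
    custom_graph = await extract_content_graph(
        text, KnowledgeGraph, custom_prompt="List only technologies and how they relate."
    )
    print(default_graph, custom_graph)


asyncio.run(main())
```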
Finally, the `extract_graph_from_data` task accepts `custom_prompt` and forwards it to `LLMGateway.extract_content_graph` for every document chunk:

```diff
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Type, List
+from typing import Type, List, Optional
 from pydantic import BaseModel

 from cognee.infrastructure.databases.graph import get_graph_engine
@@ -71,6 +71,7 @@ async def extract_graph_from_data(
     data_chunks: List[DocumentChunk],
     graph_model: Type[BaseModel],
     ontology_adapter: OntologyResolver = None,
+    custom_prompt: Optional[str] = None,
 ) -> List[DocumentChunk]:
     """
     Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@@ -84,7 +85,10 @@ async def extract_graph_from_data(
         raise InvalidGraphModelError(graph_model)

     chunk_graphs = await asyncio.gather(
-        *[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
+        *[
+            LLMGateway.extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
+            for chunk in data_chunks
+        ]
     )

     # Note: Filter edges with missing source or target nodes
```