diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 6809f089a..d40345f8e 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO): dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]]) run_in_background: Optional[bool] = Field(default=False) custom_prompt: Optional[str] = Field( - default=None, description="Custom prompt for entity extraction and graph generation" + default="", description="Custom prompt for entity extraction and graph generation" ) diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 003df7cd4..2cf087b25 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -1,9 +1,11 @@ from uuid import UUID +import pathlib from typing import Optional from datetime import datetime from pydantic import Field from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse + from cognee.modules.search.types import SearchType from cognee.api.DTO import InDTO, OutDTO from cognee.modules.users.exceptions.exceptions import PermissionDeniedError @@ -20,8 +22,12 @@ class SearchPayloadDTO(InDTO): datasets: Optional[list[str]] = Field(default=None) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) query: str = Field(default="What is in the document?") + system_prompt: Optional[str] = Field( + default="Answer the question using the provided context. Be as brief as possible." 
+ ) node_name: Optional[list[str]] = Field(default=None, example=[]) top_k: Optional[int] = Field(default=10) + only_context: bool = Field(default=False) def get_search_router() -> APIRouter: @@ -80,8 +86,10 @@ def get_search_router() -> APIRouter: - **datasets** (Optional[List[str]]): List of dataset names to search within - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within - **query** (str): The search query string + - **system_prompt** (Optional[str]): System prompt to use for completion-type searches; when provided, it takes precedence over the prompt loaded from the system prompt file - **node_name** Optional[list[str]]: Filter results to specific node_sets defined in the add pipeline (for targeted search). - **top_k** (Optional[int]): Maximum number of results to return (default: 10) + - **only_context** (bool): When true, completion-type searches return the context Cognee would send to the LLM instead of the LLM-generated answer (default: false) ## Response Returns a list of search results containing relevant nodes from the graph. 
@@ -104,8 +112,10 @@ def get_search_router() -> APIRouter: "datasets": payload.datasets, "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []], "query": payload.query, + "system_prompt": payload.system_prompt, "node_name": payload.node_name, "top_k": payload.top_k, + "only_context": payload.only_context, }, ) @@ -118,8 +128,10 @@ def get_search_router() -> APIRouter: user=user, datasets=payload.datasets, dataset_ids=payload.dataset_ids, + system_prompt=payload.system_prompt, node_name=payload.node_name, top_k=payload.top_k, + only_context=payload.only_context, ) return results diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 344e763ae..49f7aee51 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -17,11 +17,13 @@ async def search( datasets: Optional[Union[list[str], str]] = None, dataset_ids: Optional[Union[list[UUID], UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + only_context: bool = False, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. 
@@ -184,11 +186,13 @@ async def search( dataset_ids=dataset_ids if dataset_ids else datasets, user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) return filtered_search_results diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 655a9010d..4d34dfdbe 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever): self, user_prompt_path: str = "context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: str = None, top_k: Optional[int] = 1, + only_context: bool = False, ): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 1 + self.system_prompt = system_prompt + self.only_context = only_context async def get_context(self, query: str) -> str: """ @@ -88,6 +92,11 @@ class CompletionRetriever(BaseRetriever): context = await self.get_context(query) completion = await generate_completion( - query, context, self.user_prompt_path, self.system_prompt_path + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + only_context=self.only_context, ) return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index d05e6b4fa..8bdf5f1a0 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -26,10 +26,12 @@ class 
GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -38,10 +40,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ) async def get_completion( - self, query: str, context: Optional[Any] = None, context_extension_rounds=4 + self, + query: str, + context: Optional[Any] = None, + context_extension_rounds=4, ) -> List[str]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -86,6 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, ) triplets += await self.get_triplets(completion) @@ -112,6 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: @@ -119,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if self.only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 
032dccf9e..86ff8555b 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): validation_system_prompt_path: str = "cot_validation_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt", + system_prompt: str = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + only_context=only_context, top_k=top_k, node_type=node_type, node_name=node_name, @@ -51,7 +55,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self.followup_user_prompt_path = followup_user_prompt_path async def get_completion( - self, query: str, context: Optional[Any] = None, max_iter=4 + self, + query: str, + context: Optional[Any] = None, + max_iter=4, ) -> List[str]: """ Generate completion responses based on a user query and contextual information. 
@@ -92,6 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: @@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if self.only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index fb3cf4885..6a5193c56 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: str = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path + self.system_prompt = system_prompt + self.only_context = only_context self.top_k = top_k if top_k is not None else 5 self.node_type = node_type self.node_name = node_name @@ -151,7 +155,11 @@ class GraphCompletionRetriever(BaseRetriever): return context, triplets - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, + query: str, + context: Optional[Any] = None, + ) -> Any: """ Generates a completion using graph connections context based on a query. 
@@ -177,6 +185,8 @@ class GraphCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index d344ebd26..051f39b22 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", summarize_prompt_path: str = "summarize_search_results.txt", + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, @@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, ) self.summarize_prompt_path = summarize_prompt_path @@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): - str: A summary string representing the content of the retrieved edges. 
""" direct_text = await super().resolve_edges_to_text(retrieved_edges) - return await summarize_text(direct_text, self.summarize_prompt_path) + return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt) diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 56f414013..df35cdc51 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever): logger.info(f"Returning {len(summary_payloads)} summary payloads") return summary_payloads - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any: """ Generates a completion using summaries context. diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index ca0b30c18..81e636aad 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -1,3 +1,4 @@ +from typing import Optional from cognee.infrastructure.llm.LLMGateway import LLMGateway @@ -6,25 +7,35 @@ async def generate_completion( context: str, user_prompt_path: str, system_prompt_path: str, + system_prompt: Optional[str] = None, + only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = LLMGateway.render_prompt(user_prompt_path, args) - system_prompt = LLMGateway.read_query_prompt(system_prompt_path) - - return await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, + system_prompt = ( + system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) ) + if only_context: + return context + else: + return await LLMGateway.acreate_structured_output( + 
text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) + async def summarize_text( text: str, - prompt_path: str = "summarize_search_results.txt", + system_prompt_path: str = "summarize_search_results.txt", + system_prompt: str = None, ) -> str: """Summarizes text using LLM with the specified prompt.""" - system_prompt = LLMGateway.read_query_prompt(prompt_path) + system_prompt = ( + system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) + ) return await LLMGateway.acreate_structured_output( text_input=text, diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 74ef2a6ad..71bf61d6b 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -38,11 +38,13 @@ async def search( dataset_ids: Union[list[UUID], None], user: User, system_prompt_path="answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = False, ): """ @@ -62,30 +64,34 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text=query_text, query_type=query_type, + query_text=query_text, user=user, dataset_ids=dataset_ids, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) query = await log_query(query_text, query_type.value, user.id) search_results = await specific_search( - query_type, - query_text, - user, + query_type=query_type, + query_text=query_text, + user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, 
node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) await log_result( @@ -101,21 +107,26 @@ async def search( async def specific_search( query_type: SearchType, - query: str, + query_text: str, user: User, - system_prompt_path="answer_simple_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.RAG_COMPLETION: CompletionRetriever( - system_prompt_path=system_prompt_path, top_k=top_k + system_prompt_path=system_prompt_path, + top_k=top_k, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -123,6 +134,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -130,6 +143,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -137,6 +152,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + 
only_context=only_context, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -144,6 +161,7 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, ).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, @@ -153,7 +171,7 @@ async def specific_search( # If the query type is FEELING_LUCKY, select the search type intelligently if query_type is SearchType.FEELING_LUCKY: - query_type = await select_search_type(query) + query_type = await select_search_type(query_text) search_task = search_tasks.get(query_type) @@ -162,7 +180,7 @@ async def specific_search( send_telemetry("cognee.search EXECUTION STARTED", user.id) - results = await search_task(query) + results = await search_task(query_text) send_telemetry("cognee.search EXECUTION COMPLETED", user.id) @@ -170,16 +188,18 @@ async def specific_search( async def authorized_search( - query_text: str, query_type: SearchType, - user: User = None, + query_text: str, + user: User, dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, - save_interaction: bool = False, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. 
@@ -193,16 +213,18 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, - query_text, - query_type, - user, - system_prompt_path, - top_k, + search_datasets=search_datasets, + query_type=query_type, + query_text=query_text, + user=user, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -212,15 +234,17 @@ async def authorized_search( async def specific_search_by_context( search_datasets: list[Dataset], - query_text: str, query_type: SearchType, + query_text: str, user: User, - system_prompt_path: str, - top_k: int, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, - save_interaction: bool = False, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. 
@@ -228,28 +252,33 @@ async def specific_search_by_context( """ async def _search_by_context( - dataset, - user, - query_type, - query_text, - system_prompt_path, - top_k, + dataset: Dataset, + query_type: SearchType, + query_text: str, + user: User, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( - query_type, - query_text, - user, + query_type=query_type, + query_text=query_text, + user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) return { "search_result": search_results, @@ -262,15 +291,18 @@ async def specific_search_by_context( for dataset in search_datasets: tasks.append( _search_by_context( - dataset, - user, - query_type, - query_text, - system_prompt_path, - top_k, + dataset=dataset, + query_type=query_type, + query_text=query_text, + user=user, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, node_type=node_type, node_name=node_name, + save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) ) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 004e1fca3..3a6bdc51e 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -58,15 +58,17 @@ async def test_search( # Verify mock_log_query.assert_called_once_with(query_text, query_type.value, 
mock_user.id) mock_specific_search.assert_called_once_with( - query_type, - query_text, - mock_user, + query_type=query_type, + query_text=query_text, + user=mock_user, system_prompt_path="answer_simple_question.txt", + system_prompt=None, top_k=10, node_type=NodeSet, node_name=None, save_interaction=False, last_k=None, + only_context=False, ) # Verify result logging @@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky( if retriever_name == "CompletionRetriever": mock_retriever_class.assert_called_once_with( - system_prompt_path="answer_simple_question.txt", top_k=top_k + system_prompt_path="answer_simple_question.txt", + top_k=top_k, + system_prompt=None, + only_context=None, ) else: mock_retriever_class.assert_called_once_with(top_k=top_k)