From ac87e62adb55803cc2335889b21bcc3777d3d833 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 28 Aug 2025 10:52:08 +0200 Subject: [PATCH 1/7] feat: Save search flag progress --- .../modules/retrieval/completion_retriever.py | 17 ++++++++++++-- ..._completion_context_extension_retriever.py | 13 ++++++++++- .../graph_completion_cot_retriever.py | 15 +++++++++++-- .../retrieval/graph_completion_retriever.py | 12 +++++++++- cognee/modules/retrieval/utils/completion.py | 22 +++++++++++++------ cognee/modules/search/methods/search.py | 7 +++++- 6 files changed, 72 insertions(+), 14 deletions(-) diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 655a9010d..e9c8331a1 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -65,7 +65,14 @@ class CompletionRetriever(BaseRetriever): logger.error("DocumentChunk_text collection not found") raise NoDataError("No data found in the system, please add data first.") from error - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + ) -> Any: """ Generates an LLM completion using the context. 
@@ -88,6 +95,12 @@ class CompletionRetriever(BaseRetriever): context = await self.get_context(query) completion = await generate_completion( - query, context, self.user_prompt_path, self.system_prompt_path + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, + only_context=only_context, ) return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index d05e6b4fa..f25edb4a7 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -41,7 +41,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) async def get_completion( - self, query: str, context: Optional[Any] = None, context_extension_rounds=4 + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + context_extension_rounds=4, ) -> List[str]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -86,6 +92,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, ) triplets += await self.get_triplets(completion) @@ -112,6 +120,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, + only_context=only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py 
b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 032dccf9e..63ab6b3b7 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -51,7 +51,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self.followup_user_prompt_path = followup_user_prompt_path async def get_completion( - self, query: str, context: Optional[Any] = None, max_iter=4 + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + max_iter=4, ) -> List[str]: """ Generate completion responses based on a user query and contextual information. @@ -92,6 +98,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: @@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index fb3cf4885..d88252054 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -151,7 +151,14 @@ class GraphCompletionRetriever(BaseRetriever): return context, triplets - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + ) -> Any: """ Generates a completion using graph connections context based on a 
query. @@ -177,6 +184,9 @@ class GraphCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, + only_context=only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index ca0b30c18..69381d647 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -6,18 +6,26 @@ async def generate_completion( context: str, user_prompt_path: str, system_prompt_path: str, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} - user_prompt = LLMGateway.render_prompt(user_prompt_path, args) - system_prompt = LLMGateway.read_query_prompt(system_prompt_path) - - return await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, + user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args) + system_prompt = LLMGateway.read_query_prompt( + system_prompt if system_prompt else system_prompt_path ) + if only_context: + return context + else: + return await LLMGateway.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) + async def summarize_text( text: str, diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index f5f2a793a..3e5d6ffcd 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -101,11 +101,14 @@ async def specific_search( query: str, user: User, system_prompt_path="answer_simple_question.txt", + user_prompt: str = None, + system_prompt: str = None, top_k: int = 
10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, @@ -159,7 +162,9 @@ async def specific_search( send_telemetry("cognee.search EXECUTION STARTED", user.id) - results = await search_task(query) + results = await search_task( + query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context + ) send_telemetry("cognee.search EXECUTION COMPLETED", user.id) From 2915698d601f8ce84d5d63458d0e8da51794fa67 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 28 Aug 2025 13:43:37 +0200 Subject: [PATCH 2/7] feat: Add only_context and system prompt flags for search --- .../v1/search/routers/get_search_router.py | 6 + cognee/api/v1/search/search.py | 4 + .../modules/retrieval/completion_retriever.py | 18 ++- ..._completion_context_extension_retriever.py | 20 +-- .../graph_completion_cot_retriever.py | 12 +- .../retrieval/graph_completion_retriever.py | 12 +- .../graph_summary_completion_retriever.py | 4 +- .../modules/retrieval/summaries_retriever.py | 2 +- cognee/modules/retrieval/utils/completion.py | 18 +-- cognee/modules/search/methods/search.py | 117 +++++++++++++----- 10 files changed, 140 insertions(+), 73 deletions(-) diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 0ceeb1abb..b141c6bdc 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -20,7 +20,9 @@ class SearchPayloadDTO(InDTO): datasets: Optional[list[str]] = Field(default=None) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) query: str = Field(default="What is in the document?") + system_prompt: Optional[str] = Field(default=None) top_k: Optional[int] = Field(default=10) + 
only_context: bool = Field(default=False) def get_search_router() -> APIRouter: @@ -102,7 +104,9 @@ def get_search_router() -> APIRouter: "datasets": payload.datasets, "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []], "query": payload.query, + "system_prompt": payload.system_prompt, "top_k": payload.top_k, + "only_context": payload.only_context, }, ) @@ -115,7 +119,9 @@ def get_search_router() -> APIRouter: user=user, datasets=payload.datasets, dataset_ids=payload.dataset_ids, + system_prompt=payload.system_prompt, top_k=payload.top_k, + only_context=payload.only_context, ) return results diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index f37f8ba6d..113d33557 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -16,11 +16,13 @@ async def search( datasets: Optional[Union[list[str], str]] = None, dataset_ids: Optional[Union[list[UUID], UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + only_context: bool = False, ) -> list: """ Search and query the knowledge graph for insights, information, and connections. 
@@ -183,11 +185,13 @@ async def search( dataset_ids=dataset_ids if dataset_ids else datasets, user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) return filtered_search_results diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index e9c8331a1..4d34dfdbe 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever): self, user_prompt_path: str = "context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: str = None, top_k: Optional[int] = 1, + only_context: bool = False, ): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 1 + self.system_prompt = system_prompt + self.only_context = only_context async def get_context(self, query: str) -> str: """ @@ -65,14 +69,7 @@ class CompletionRetriever(BaseRetriever): logger.error("DocumentChunk_text collection not found") raise NoDataError("No data found in the system, please add data first.") from error - async def get_completion( - self, - query: str, - context: Optional[Any] = None, - user_prompt: str = None, - system_prompt: str = None, - only_context: bool = False, - ) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ Generates an LLM completion using the context. 
@@ -99,8 +96,7 @@ class CompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, - user_prompt=user_prompt, - system_prompt=system_prompt, - only_context=only_context, + system_prompt=self.system_prompt, + only_context=self.only_context, ) return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index f25edb4a7..8bdf5f1a0 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -38,15 +40,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ) async def get_completion( self, query: str, context: Optional[Any] = None, - user_prompt: str = None, - system_prompt: str = None, - only_context: bool = False, context_extension_rounds=4, ) -> List[str]: """ @@ -92,8 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, - user_prompt=user_prompt, - system_prompt=system_prompt, + system_prompt=self.system_prompt, ) triplets += await self.get_triplets(completion) @@ -120,9 +120,8 @@ class 
GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, - user_prompt=user_prompt, - system_prompt=system_prompt, - only_context=only_context, + system_prompt=self.system_prompt, + only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: @@ -130,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if self.only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 63ab6b3b7..86ff8555b 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): validation_system_prompt_path: str = "cot_validation_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt", + system_prompt: str = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + only_context=only_context, top_k=top_k, node_type=node_type, node_name=node_name, @@ -54,9 +58,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self, query: str, context: Optional[Any] = None, - user_prompt: str = None, - system_prompt: str = None, - only_context: bool = False, max_iter=4, ) -> List[str]: """ @@ -98,8 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): context=context, 
user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, - user_prompt=user_prompt, - system_prompt=system_prompt, + system_prompt=self.system_prompt, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: @@ -136,7 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - if only_context: + if self.only_context: return [context] else: return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index d88252054..6a5193c56 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + system_prompt: str = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + only_context: bool = False, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path + self.system_prompt = system_prompt + self.only_context = only_context self.top_k = top_k if top_k is not None else 5 self.node_type = node_type self.node_name = node_name @@ -155,9 +159,6 @@ class GraphCompletionRetriever(BaseRetriever): self, query: str, context: Optional[Any] = None, - user_prompt: str = None, - system_prompt: str = None, - only_context: bool = False, ) -> Any: """ Generates a completion using graph connections context based on a query. 
@@ -184,9 +185,8 @@ class GraphCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, - user_prompt=user_prompt, - system_prompt=system_prompt, - only_context=only_context, + system_prompt=self.system_prompt, + only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index d344ebd26..051f39b22 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", summarize_prompt_path: str = "summarize_search_results.txt", + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, @@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, ) self.summarize_prompt_path = summarize_prompt_path @@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): - str: A summary string representing the content of the retrieved edges. 
""" direct_text = await super().resolve_edges_to_text(retrieved_edges) - return await summarize_text(direct_text, self.summarize_prompt_path) + return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt) diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 56f414013..df35cdc51 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever): logger.info(f"Returning {len(summary_payloads)} summary payloads") return summary_payloads - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any: """ Generates a completion using summaries context. diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index 69381d647..4c2639517 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -1,3 +1,4 @@ +from typing import Optional from cognee.infrastructure.llm.LLMGateway import LLMGateway @@ -6,15 +7,15 @@ async def generate_completion( context: str, user_prompt_path: str, system_prompt_path: str, - user_prompt: str = None, - system_prompt: str = None, + user_prompt: Optional[str] = None, + system_prompt: Optional[str] = None, only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} - user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args) - system_prompt = LLMGateway.read_query_prompt( - system_prompt if system_prompt else system_prompt_path + user_prompt = user_prompt if user_prompt else LLMGateway.render_prompt(user_prompt_path, args) + system_prompt = ( + system_prompt if system_prompt else 
LLMGateway.read_query_prompt(system_prompt_path) ) if only_context: @@ -29,10 +30,13 @@ async def generate_completion( async def summarize_text( text: str, - prompt_path: str = "summarize_search_results.txt", + system_prompt_path: str = "summarize_search_results.txt", + system_prompt: str = None, ) -> str: """Summarizes text using LLM with the specified prompt.""" - system_prompt = LLMGateway.read_query_prompt(prompt_path) + system_prompt = ( + system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) + ) return await LLMGateway.acreate_structured_output( text_input=text, diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 3e5d6ffcd..465d0cbb3 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -37,11 +37,13 @@ async def search( dataset_ids: Union[list[UUID], None], user: User, system_prompt_path="answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = False, ): """ @@ -61,28 +63,34 @@ async def search( # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return await authorized_search( - query_text=query_text, query_type=query_type, + query_text=query_text, user=user, dataset_ids=dataset_ids, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) query = await log_query(query_text, query_type.value, user.id) search_results = await specific_search( - query_type, - query_text, - user, + query_type=query_type, + query_text=query_text, + user=user, system_prompt_path=system_prompt_path, + 
system_prompt=system_prompt, top_k=top_k, node_type=node_type, node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) await log_result( @@ -98,11 +106,10 @@ async def search( async def specific_search( query_type: SearchType, - query: str, + query_text: str, user: User, - system_prompt_path="answer_simple_question.txt", - user_prompt: str = None, - system_prompt: str = None, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, @@ -115,7 +122,10 @@ async def specific_search( SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.RAG_COMPLETION: CompletionRetriever( - system_prompt_path=system_prompt_path, top_k=top_k + system_prompt_path=system_prompt_path, + top_k=top_k, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -123,6 +133,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -130,6 +142,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -137,6 +151,8 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, + only_context=only_context, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: 
GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -144,6 +160,7 @@ async def specific_search( node_type=node_type, node_name=node_name, save_interaction=save_interaction, + system_prompt=system_prompt, ).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, @@ -153,7 +170,7 @@ async def specific_search( # If the query type is FEELING_LUCKY, select the search type intelligently if query_type is SearchType.FEELING_LUCKY: - query_type = await select_search_type(query) + query_type = await select_search_type(query_text) search_task = search_tasks.get(query_type) @@ -162,9 +179,7 @@ async def specific_search( send_telemetry("cognee.search EXECUTION STARTED", user.id) - results = await search_task( - query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context - ) + results = await search_task(query=query_text) send_telemetry("cognee.search EXECUTION COMPLETED", user.id) @@ -172,14 +187,18 @@ async def specific_search( async def authorized_search( - query_text: str, query_type: SearchType, - user: User = None, + query_text: str, + user: User, dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, top_k: int = 10, - save_interaction: bool = False, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. 
@@ -193,14 +212,18 @@ async def authorized_search( # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await specific_search_by_context( - search_datasets, - query_text, - query_type, - user, - system_prompt_path, - top_k, - save_interaction, + search_datasets=search_datasets, + query_type=query_type, + query_text=query_text, + user=user, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id) @@ -210,13 +233,17 @@ async def authorized_search( async def specific_search_by_context( search_datasets: list[Dataset], - query_text: str, query_type: SearchType, + query_text: str, user: User, - system_prompt_path: str, - top_k: int, - save_interaction: bool = False, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ): """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. 
@@ -224,18 +251,33 @@ async def specific_search_by_context( """ async def _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset: Dataset, + query_type: SearchType, + query_text: str, + user: User, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, + save_interaction: Optional[bool] = False, + last_k: Optional[int] = None, + only_context: bool = None, ): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) search_results = await specific_search( - query_type, - query_text, - user, + query_type=query_type, + query_text=query_text, + user=user, system_prompt_path=system_prompt_path, + system_prompt=system_prompt, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, + only_context=only_context, ) return { "search_result": search_results, @@ -248,7 +290,18 @@ async def specific_search_by_context( for dataset in search_datasets: tasks.append( _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset=dataset, + query_type=query_type, + query_text=query_text, + user=user, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + last_k=last_k, + only_context=only_context, ) ) From 7fd5e1e0104c061e056c5e97a4b0ea04effa45dd Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 28 Aug 2025 13:53:08 +0200 Subject: [PATCH 3/7] fix: Make custom_prompt be empty string by default --- cognee/api/v1/cognify/routers/get_cognify_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py
b/cognee/api/v1/cognify/routers/get_cognify_router.py index 6809f089a..d40345f8e 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO): dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]]) run_in_background: Optional[bool] = Field(default=False) custom_prompt: Optional[str] = Field( - default=None, description="Custom prompt for entity extraction and graph generation" + default="", description="Custom prompt for entity extraction and graph generation" ) From 966e676d610a38b1607ce415ec8b9d620cf5cec2 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 28 Aug 2025 17:23:15 +0200 Subject: [PATCH 4/7] refactor: Have search prompt be empty string by default --- cognee/api/v1/search/routers/get_search_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index b141c6bdc..39a896dd8 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -20,7 +20,7 @@ class SearchPayloadDTO(InDTO): datasets: Optional[list[str]] = Field(default=None) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) query: str = Field(default="What is in the document?") - system_prompt: Optional[str] = Field(default=None) + system_prompt: Optional[str] = Field(default="") top_k: Optional[int] = Field(default=10) only_context: bool = Field(default=False) From 5bfae7a36b10b746c167a4895d108130f9a62a2a Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 29 Aug 2025 10:30:49 +0200 Subject: [PATCH 5/7] refactor: Resolve unit tests failing for search --- cognee/modules/search/methods/search.py | 2 +- .../unit/modules/search/search_methods_test.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/cognee/modules/search/methods/search.py 
b/cognee/modules/search/methods/search.py index 465d0cbb3..2db105d71 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -179,7 +179,7 @@ async def specific_search( send_telemetry("cognee.search EXECUTION STARTED", user.id) - results = await search_task(query=query_text) + results = await search_task(query_text) send_telemetry("cognee.search EXECUTION COMPLETED", user.id) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 46995d087..9833a770b 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -58,15 +58,17 @@ async def test_search( # Verify mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) mock_specific_search.assert_called_once_with( - query_type, - query_text, - mock_user, + query_type=query_type, + query_text=query_text, + user=mock_user, system_prompt_path="answer_simple_question.txt", + system_prompt=None, top_k=10, node_type=None, node_name=None, save_interaction=False, last_k=None, + only_context=False, ) # Verify result logging @@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky( if retriever_name == "CompletionRetriever": mock_retriever_class.assert_called_once_with( - system_prompt_path="answer_simple_question.txt", top_k=top_k + system_prompt_path="answer_simple_question.txt", + top_k=top_k, + system_prompt=None, + only_context=None, ) else: mock_retriever_class.assert_called_once_with(top_k=top_k) From c3f5840bff1a9623066718d3a6ab14994bd4b0fe Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 29 Aug 2025 12:24:15 +0200 Subject: [PATCH 6/7] refactor: Remove unused argument --- cognee/modules/retrieval/utils/completion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index 
4c2639517..81e636aad 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -7,13 +7,12 @@ async def generate_completion( context: str, user_prompt_path: str, system_prompt_path: str, - user_prompt: Optional[str] = None, system_prompt: Optional[str] = None, only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} - user_prompt = user_prompt if user_prompt else LLMGateway.render_prompt(user_prompt_path, args) + user_prompt = LLMGateway.render_prompt(user_prompt_path, args) system_prompt = ( system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) ) From 614055c850661fcbb816a9bf77b2e61324a83f69 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 29 Aug 2025 14:16:18 +0200 Subject: [PATCH 7/7] refactor: Add docs for new search arguments --- cognee/api/v1/search/routers/get_search_router.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 39a896dd8..f9f4e4764 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -1,9 +1,11 @@ from uuid import UUID +import pathlib from typing import Optional from datetime import datetime from pydantic import Field from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse + from cognee.modules.search.types import SearchType from cognee.api.DTO import InDTO, OutDTO from cognee.modules.users.exceptions.exceptions import PermissionDeniedError @@ -20,7 +22,9 @@ class SearchPayloadDTO(InDTO): datasets: Optional[list[str]] = Field(default=None) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) query: str = Field(default="What is in the document?") - system_prompt: Optional[str] = Field(default="") + system_prompt: 
Optional[str] = Field( + default="Answer the question using the provided context. Be as brief as possible." + ) top_k: Optional[int] = Field(default=10) only_context: bool = Field(default=False) @@ -81,7 +85,9 @@ def get_search_router() -> APIRouter: - **datasets** (Optional[List[str]]): List of dataset names to search within - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within - **query** (str): The search query string + - **system_prompt** (Optional[str]): System prompt to be used for Completion type searches in Cognee - **top_k** (Optional[int]): Maximum number of results to return (default: 10) + - **only_context** (bool): Set to true to return only the context Cognee would send to the LLM in Completion type searches; the context is returned instead of making the LLM call. ## Response Returns a list of search results containing relevant nodes from the graph.