From ac87e62adb55803cc2335889b21bcc3777d3d833 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 28 Aug 2025 10:52:08 +0200 Subject: [PATCH] feat: Save search flag progress --- .../modules/retrieval/completion_retriever.py | 17 ++++++++++++-- ..._completion_context_extension_retriever.py | 13 ++++++++++- .../graph_completion_cot_retriever.py | 15 +++++++++++-- .../retrieval/graph_completion_retriever.py | 12 +++++++++- cognee/modules/retrieval/utils/completion.py | 22 +++++++++++++------ cognee/modules/search/methods/search.py | 7 +++++- 6 files changed, 72 insertions(+), 14 deletions(-) diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 655a9010d..e9c8331a1 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -65,7 +65,14 @@ class CompletionRetriever(BaseRetriever): logger.error("DocumentChunk_text collection not found") raise NoDataError("No data found in the system, please add data first.") from error - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + ) -> Any: """ Generates an LLM completion using the context. @@ -88,6 +95,12 @@ class CompletionRetriever(BaseRetriever): context = await self.get_context(query) completion = await generate_completion( - query, context, self.user_prompt_path, self.system_prompt_path + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, + only_context=only_context, ) return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index d05e6b4fa..f25edb4a7 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -41,7 +41,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) async def get_completion( - self, query: str, context: Optional[Any] = None, context_extension_rounds=4 + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + context_extension_rounds=4, ) -> List[str]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -86,6 +92,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, ) triplets += await self.get_triplets(completion) @@ -112,6 +120,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, + only_context=only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 032dccf9e..63ab6b3b7 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -51,7 +51,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self.followup_user_prompt_path = followup_user_prompt_path async def get_completion( - self, query: str, context: Optional[Any] = None, max_iter=4 + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + max_iter=4, ) -> List[str]: """ Generate completion responses based on a user query and contextual information. @@ -92,6 +98,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: @@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context, triplets=triplets ) - return [completion] + if only_context: + return [context] + else: + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index fb3cf4885..d88252054 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -151,7 +151,14 @@ class GraphCompletionRetriever(BaseRetriever): return context, triplets - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, + query: str, + context: Optional[Any] = None, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, + ) -> Any: """ Generates a completion using graph connections context based on a query. @@ -177,6 +184,9 @@ class GraphCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + user_prompt=user_prompt, + system_prompt=system_prompt, + only_context=only_context, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index ca0b30c18..69381d647 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -6,18 +6,26 @@ async def generate_completion( context: str, user_prompt_path: str, system_prompt_path: str, + user_prompt: str = None, + system_prompt: str = None, + only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} - user_prompt = LLMGateway.render_prompt(user_prompt_path, args) - system_prompt = LLMGateway.read_query_prompt(system_prompt_path) - - return await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, + user_prompt = LLMGateway.render_prompt(user_prompt if user_prompt else user_prompt_path, args) + system_prompt = LLMGateway.read_query_prompt( + system_prompt if system_prompt else system_prompt_path ) + if only_context: + return context + else: + return await LLMGateway.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) + async def summarize_text( text: str, diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index f5f2a793a..3e5d6ffcd 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -101,11 +101,14 @@ async def specific_search( query: str, user: User, system_prompt_path="answer_simple_question.txt", + user_prompt: str = None, + system_prompt: str = None, top_k: int = 10, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, + only_context: bool = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, @@ -159,7 +162,9 @@ async def specific_search( send_telemetry("cognee.search EXECUTION STARTED", user.id) - results = await search_task(query) + results = await search_task( + query=query, system_prompt=system_prompt, user_prompt=user_prompt, only_context=only_context + ) send_telemetry("cognee.search EXECUTION COMPLETED", user.id)