diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index 6086977ce..1f1ddad0a 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional, List +from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor @@ -85,7 +85,11 @@ class EntityCompletionRetriever(BaseRetriever): return None async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, ) -> List[str]: """ Generate completion using provided context or fetch new context. @@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever): fetched if not provided. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index bb568924d..f071e41de 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional +from typing import Any, Optional, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine @@ -75,7 +75,11 @@ class CompletionRetriever(BaseRetriever): raise NoDataError("No data found in the system, please add data first.") from error async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, ) -> str: """ Generates an LLM completion using the context. @@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever): completion; if None, it retrieves the context for the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if session_save: @@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever): session_id=session_id, ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 58b6b586f..6b2c6a9e6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -56,6 +56,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context: Optional[List[Edge]] = None, session_id: Optional[str] = None, context_extension_rounds=4, + response_model: Type = str, ) -> List[str]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): defaults to 'default_session'. (default None) - context_extension_rounds: The maximum number of rounds to extend the context with new triplets before halting. (default 4) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context_text and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 299db6855..39255fe68 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import ( - generate_structured_completion, + generate_completion, summarize_text, ) from cognee.modules.retrieval.utils.session_cache import ( @@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): questions based on reasoning. The public methods are: - get_completion - - get_structured_completion Instance variables include: - validation_system_prompt_path @@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_context(followup_question) context_text = await self.resolve_edges_to_text(list(set(triplets))) - completion = await generate_structured_completion( + completion = await generate_completion( query=query, context=context_text, user_prompt_path=self.user_prompt_path, @@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): return completion, context_text, triplets - async def get_structured_completion( + async def get_completion( self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, - max_iter: int = 4, + max_iter=4, response_model: Type = str, - ) -> Any: + ) -> List[str]: """ - Generate structured completion responses based on a user query and contextual information. + Generate completion responses based on a user query and contextual information. - This method applies the same chain-of-thought logic as get_completion but returns + This method interacts with a language model client to retrieve a structured response, + using a series of iterations to refine the answers and generate follow-up questions + based on reasoning derived from previous outputs. It raises exceptions if the context + retrieval fails or if the model encounters issues in generating outputs. It returns structured output using the provided response model. Parameters: ----------- + - query (str): The user's query to be processed and answered. - - context (Optional[List[Edge]]): Optional context that may assist in answering the query. + - context (Optional[Any]): Optional context that may assist in answering the query. If not provided, it will be fetched based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) @@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): Returns: -------- - - Any: The generated structured completion based on the response model. + + - List[str]: A list containing the generated answer to the user's query. """ # Check if session saving is enabled cache_config = CacheConfig() @@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion - - async def get_completion( - self, - query: str, - context: Optional[List[Edge]] = None, - session_id: Optional[str] = None, - max_iter=4, - ) -> List[str]: - """ - Generate completion responses based on a user query and contextual information. - - This method interacts with a language model client to retrieve a structured response, - using a series of iterations to refine the answers and generate follow-up questions - based on reasoning derived from previous outputs. It raises exceptions if the context - retrieval fails or if the model encounters issues in generating outputs. - - Parameters: - ----------- - - - query (str): The user's query to be processed and answered. - - context (Optional[Any]): Optional context that may assist in answering the query. - If not provided, it will be fetched based on the query. (default None) - - session_id (Optional[str]): Optional session identifier for caching. If None, - defaults to 'default_session'. (default None) - - max_iter: The maximum number of iterations to refine the answer and generate - follow-up questions. (default 4) - - Returns: - -------- - - - List[str]: A list containing the generated answer to the user's query. - """ - completion = await self.get_structured_completion( - query=query, - context=context, - session_id=session_id, - max_iter=max_iter, - response_model=str, - ) - return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index b7ab4edae..b544e8ead 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -146,6 +146,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, + response_model: Type = str, ) -> List[str]: """ Generates a completion using graph connections context based on a query. @@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index ec68d37bb..38d69ec80 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -146,7 +146,11 @@ class TemporalRetriever(GraphCompletionRetriever): return self.descriptions_to_string(top_k_events) async def get_completion( - self, query: str, context: Optional[str] = None, session_id: Optional[str] = None + self, + query: str, + context: Optional[str] = None, + session_id: Optional[str] = None, + response_model: Type = str, ) -> List[str]: """ Generates a response using the query and optional context. @@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever): retrieved based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index db7a10252..b77d7ef90 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt -async def generate_structured_completion( +async def generate_completion( query: str, context: str, user_prompt_path: str, @@ -11,8 +11,8 @@ async def generate_structured_completion( system_prompt: Optional[str] = None, conversation_history: Optional[str] = None, response_model: Type = str, -) -> Any: - """Generates a structured completion using LLM with given context and prompts.""" +) -> str: + """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = render_prompt(user_prompt_path, args) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) @@ -28,26 +28,6 @@ async def generate_structured_completion( ) -async def generate_completion( - query: str, - context: str, - user_prompt_path: str, - system_prompt_path: str, - system_prompt: Optional[str] = None, - conversation_history: Optional[str] = None, -) -> str: - """Generates a completion using LLM with given context and prompts.""" - return await generate_structured_completion( - query=query, - context=context, - user_prompt_path=user_prompt_path, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - conversation_history=conversation_history, - response_model=str, - ) - - async def summarize_text( text: str, system_prompt_path: str = "summarize_search_results.txt", diff --git a/cognee/tasks/feedback/generate_improved_answers.py b/cognee/tasks/feedback/generate_improved_answers.py index e439cf9e5..d2b143d29 100644 --- a/cognee/tasks/feedback/generate_improved_answers.py +++ b/cognee/tasks/feedback/generate_improved_answers.py @@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction( ) retrieved_context = await retriever.get_context(query_text) - completion = await retriever.get_structured_completion( + completion = await retriever.get_completion( query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse, @@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction( new_context_text = await retriever.resolve_edges_to_text(retrieved_context) if completion: - enrichment.improved_answer = completion.answer + enrichment.improved_answer = completion[0].answer enrichment.new_context = new_context_text - enrichment.explanation = completion.explanation + enrichment.explanation = completion[0].explanation return enrichment else: logger.warning( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 7fcfe0d6b..bf10dc023 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -206,16 +206,22 @@ class TestGraphCompletionCoTRetriever: retriever = GraphCompletionCotRetriever() # Test with string response model (default) - string_answer = await retriever.get_structured_completion("Who works at Figma?") - assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}" - assert string_answer.strip(), "Answer should not be empty" + string_answer = await retriever.get_completion("Who works at Figma?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) # Test with structured response model - structured_answer = await retriever.get_structured_completion( + structured_answer = await retriever.get_completion( "Who works at Figma?", response_model=TestAnswer ) - assert isinstance(structured_answer, TestAnswer), ( + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in string_answer), ( f"Expected TestAnswer, got {type(structured_answer).__name__}" ) - assert structured_answer.answer.strip(), "Answer field should not be empty" - assert structured_answer.explanation.strip(), "Explanation field should not be empty" + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"