diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index bdb8d9a8c..d8c25341f 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -77,7 +77,7 @@ class EntityCompletionRetriever(BaseRetriever): logger.error(f"Context retrieval failed: {str(e)}") return None - async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]: + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]: """ Generate completion using provided context or fetch new context. @@ -91,6 +91,8 @@ class EntityCompletionRetriever(BaseRetriever): - query (str): The query string for which completion is being generated. - context (Optional[Any]): Optional context to be used for generating completion; fetched if not provided. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index 084b50cdb..086473805 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -61,7 +61,7 @@ class ChunksRetriever(BaseRetriever): logger.info(f"Returning {len(chunk_payloads)} chunk payloads") return chunk_payloads - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: """ Generates a completion using document chunks context. @@ -74,6 +74,8 @@ class ChunksRetriever(BaseRetriever): - query (str): The query string to be used for generating a completion. - context (Optional[Any]): Optional pre-fetched context to use for generating the completion; if None, it retrieves the context for the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py index d178e7ddd..0390e1ff5 100644 --- a/cognee/modules/retrieval/code_retriever.py +++ b/cognee/modules/retrieval/code_retriever.py @@ -207,8 +207,24 @@ class CodeRetriever(BaseRetriever): logger.info(f"Returning {len(result)} code file contexts") return result - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: - """Returns the code files context.""" + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + """ + Returns the code files context. + + Parameters: + ----------- + + - query (str): The query string to retrieve code context for. + - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves + the context for the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) + + Returns: + -------- + + - Any: The code files context, either provided or retrieved. + """ if context is None: context = await self.get_context(query) return context diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 44f27bece..da371dfdf 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -67,7 +67,7 @@ class CompletionRetriever(BaseRetriever): logger.error("DocumentChunk_text collection not found") raise NoDataError("No data found in the system, please add data first.") from error - async def get_completion(self, query: str, context: Optional[Any] = None) -> str: + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str: """ Generates an LLM completion using the context. @@ -80,6 +80,8 @@ class CompletionRetriever(BaseRetriever): - query (str): The query string to be used for generating a completion. - context (Optional[Any]): Optional pre-fetched context to use for generating the completion; if None, it retrieves the context for the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/cypher_search_retriever.py b/cognee/modules/retrieval/cypher_search_retriever.py index b885891e8..d02d41f34 100644 --- a/cognee/modules/retrieval/cypher_search_retriever.py +++ b/cognee/modules/retrieval/cypher_search_retriever.py @@ -50,7 +50,7 @@ class CypherSearchRetriever(BaseRetriever): raise CypherSearchError() from e return result - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: """ Returns the graph connections context. @@ -62,6 +62,8 @@ class CypherSearchRetriever(BaseRetriever): - query (str): The query to retrieve context. - context (Optional[Any]): Optional context to use, otherwise fetched using the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 6aca21808..3398d3fc2 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -47,6 +47,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): self, query: str, context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, context_extension_rounds=4, ) -> List[str]: """ @@ -64,6 +65,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - query (str): The input query for which the completion is generated. - context (Optional[Any]): The existing context to use for enhancing the query; if None, it will be initialized from triplets generated for the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) - context_extension_rounds: The maximum number of rounds to extend the context with new triplets before halting. (default 4) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 4602dd59d..715a1406c 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -58,6 +58,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self, query: str, context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, max_iter=4, ) -> List[str]: """ @@ -74,6 +75,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - query (str): The user's query to be processed and answered. - context (Optional[Any]): Optional context that may assist in answering the query. If not provided, it will be fetched based on the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) - max_iter: The maximum number of iterations to refine the answer and generate follow-up questions. (default 4) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 60609ba56..64c2be2ae 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -148,6 +148,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): - query (str): The query string for which a completion is generated. - context (Optional[Any]): Optional context to use for generating the completion; if not provided, context is retrieved based on the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/lexical_retriever.py b/cognee/modules/retrieval/lexical_retriever.py index fa80fe320..5c6e5dc20 100644 --- a/cognee/modules/retrieval/lexical_retriever.py +++ b/cognee/modules/retrieval/lexical_retriever.py @@ -116,8 +116,24 @@ class LexicalRetriever(BaseRetriever): else: return [self.payloads[chunk_id] for chunk_id, _ in top_results] - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: - """Returns context for the given query (retrieves if not provided).""" + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + """ + Returns context for the given query (retrieves if not provided). + + Parameters: + ----------- + + - query (str): The query string to retrieve context for. + - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves + the context for the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) + + Returns: + -------- + + - Any: The context, either provided or retrieved. + """ if context is None: context = await self.get_context(query) return context diff --git a/cognee/modules/retrieval/natural_language_retriever.py b/cognee/modules/retrieval/natural_language_retriever.py index 35c9f6666..992da326e 100644 --- a/cognee/modules/retrieval/natural_language_retriever.py +++ b/cognee/modules/retrieval/natural_language_retriever.py @@ -125,7 +125,7 @@ class NaturalLanguageRetriever(BaseRetriever): return await self._execute_cypher_query(query, graph_engine) - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: """ Returns a completion based on the query and context. @@ -139,6 +139,8 @@ class NaturalLanguageRetriever(BaseRetriever): - query (str): The natural language query to get a completion from. - context (Optional[Any]): The context in which to base the completion; if not provided, it will be retrieved using the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index df35cdc51..df12aa80f 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever): logger.info(f"Returning {len(summary_payloads)} summary payloads") return summary_payloads - async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any: """ Generates a completion using summaries context. @@ -75,6 +75,8 @@ class SummariesRetriever(BaseRetriever): - query (str): The search query for generating the completion. - context (Optional[Any]): Optional context for the completion; if not provided, will be retrieved based on the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) Returns: -------- diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 7fcb68755..9b59ba1be 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -137,8 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever): return self.descriptions_to_string(top_k_events) - async def get_completion(self, query: str, context: Optional[str] = None) -> List[str]: - """Generates a response using the query and optional context.""" + async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]: + """ + Generates a response using the query and optional context. + + Parameters: + ----------- + + - query (str): The query string for which a completion is generated. + - context (Optional[str]): Optional context to use; if None, it will be + retrieved based on the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) + + Returns: + -------- + + - List[str]: A list containing the generated completion. + """ if not context: context = await self.get_context(query=query)