From 91a22e8bc4585064947eea917628ab7e7b5cd1cf Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:26:58 +0200 Subject: [PATCH] feat: adds session id to get_completion methods --- cognee/api/v1/search/search.py | 4 ++++ cognee/modules/retrieval/base_graph_retriever.py | 4 +++- cognee/modules/retrieval/base_retriever.py | 4 +++- .../search/methods/no_access_control_search.py | 3 ++- cognee/modules/search/methods/search.py | 13 +++++++++++-- 5 files changed, 23 insertions(+), 5 deletions(-) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 0ebef4e84..1e438867b 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -26,6 +26,7 @@ async def search( last_k: Optional[int] = 1, only_context: bool = False, use_combined_context: bool = False, + session_id: Optional[str] = None, ) -> Union[List[SearchResult], CombinedSearchResult]: """ Search and query the knowledge graph for insights, information, and connections. @@ -114,6 +115,8 @@ async def search( save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not + session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None. + Returns: list: Search results in format determined by query_type: @@ -192,6 +195,7 @@ async def search( last_k=last_k, only_context=only_context, use_combined_context=use_combined_context, + session_id=session_id, ) return filtered_search_results diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py index 2aaf3468f..b0abc2991 100644 --- a/cognee/modules/retrieval/base_graph_retriever.py +++ b/cognee/modules/retrieval/base_graph_retriever.py @@ -13,6 +13,8 @@ class BaseGraphRetriever(ABC): pass @abstractmethod - async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str: + async def get_completion( + self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None + ) -> str: """Generates a response using the query and optional context (triplets).""" pass diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 88313b253..1533dd44f 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -11,6 +11,8 @@ class BaseRetriever(ABC): pass @abstractmethod - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> Any: """Generates a response using the query and optional context.""" pass diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py index a93fce067..c43105ca0 100644 --- a/cognee/modules/search/methods/no_access_control_search.py +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -19,6 +19,7 @@ async def no_access_control_search( save_interaction: bool = False, last_k: Optional[int] = None, only_context: bool = False, + session_id: Optional[str] = None, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: search_tools = await get_search_type_tools( query_type=query_type, @@ -38,7 +39,7 @@ async def no_access_control_search( return None, await get_context(query_text), [] context = await get_context(query_text) - result = await get_completion(query_text, context) + result = await get_completion(query_text, context, session_id=session_id) else: unknown_tool = search_tools[0] result = await unknown_tool(query_text) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 958af6444..c3c59feac 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -42,6 +42,7 @@ async def search( last_k: Optional[int] = None, only_context: bool = False, use_combined_context: bool = False, + session_id: Optional[str] = None, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -77,6 +78,7 @@ async def search( last_k=last_k, only_context=only_context, use_combined_context=use_combined_context, + session_id=session_id, ) else: search_results = [ @@ -91,6 +93,7 @@ async def search( save_interaction=save_interaction, last_k=last_k, only_context=only_context, + session_id=session_id, ) ] @@ -195,6 +198,7 @@ async def authorized_search( last_k: Optional[int] = None, only_context: bool = False, use_combined_context: bool = False, + session_id: Optional[str] = None, ) -> Union[ Tuple[Any, Union[List[Edge], str], List[Dataset]], List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], @@ -221,6 +225,7 @@ async def authorized_search( save_interaction=save_interaction, last_k=last_k, only_context=True, + session_id=session_id, ) context = {} @@ -263,7 +268,7 @@ async def authorized_search( return combined_context combined_context = prepare_combined_context(context) - completion = await get_completion(query_text, combined_context) + completion = await get_completion(query_text, combined_context, session_id=session_id) return completion, combined_context, datasets @@ -280,6 +285,7 @@ async def authorized_search( save_interaction=save_interaction, last_k=last_k, only_context=only_context, + session_id=session_id, ) return search_results @@ -298,6 +304,7 @@ async def search_in_datasets_context( last_k: Optional[int] = None, only_context: bool = False, context: Optional[Any] = None, + session_id: Optional[str] = None, ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -317,6 +324,7 @@ async def search_in_datasets_context( last_k: Optional[int] = None, only_context: bool = False, context: Optional[Any] = None, + session_id: Optional[str] = None, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -340,7 +348,7 @@ async def search_in_datasets_context( return None, await get_context(query_text), [dataset] search_context = context or await get_context(query_text) - search_result = await get_completion(query_text, search_context) + search_result = await get_completion(query_text, search_context, session_id=session_id) return search_result, search_context, [dataset] else: @@ -365,6 +373,7 @@ async def search_in_datasets_context( last_k=last_k, only_context=only_context, context=context, + session_id=session_id, ) )