feat: adds session id to get_completion methods

This commit is contained in:
hajdul88 2025-10-16 16:26:58 +02:00
parent 7149f8c45b
commit 91a22e8bc4
5 changed files with 23 additions and 5 deletions

View file

@ -26,6 +26,7 @@ async def search(
last_k: Optional[int] = 1,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -114,6 +115,8 @@ async def search(
save_interaction: Save interaction (query, context, answer connected to triplet endpoints) results into the graph or not
session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.
Returns:
list: Search results in format determined by query_type:
@ -192,6 +195,7 @@ async def search(
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
)
return filtered_search_results

View file

@ -13,6 +13,8 @@ class BaseGraphRetriever(ABC):
pass
@abstractmethod
async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
) -> str:
"""Generates a response using the query and optional context (triplets)."""
pass

View file

@ -11,6 +11,8 @@ class BaseRetriever(ABC):
pass
@abstractmethod
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""Generates a response using the query and optional context."""
pass

View file

@ -19,6 +19,7 @@ async def no_access_control_search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@ -38,7 +39,7 @@ async def no_access_control_search(
return None, await get_context(query_text), []
context = await get_context(query_text)
result = await get_completion(query_text, context)
result = await get_completion(query_text, context, session_id=session_id)
else:
unknown_tool = search_tools[0]
result = await unknown_tool(query_text)

View file

@ -42,6 +42,7 @@ async def search(
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@ -77,6 +78,7 @@ async def search(
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
)
else:
search_results = [
@ -91,6 +93,7 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
session_id=session_id,
)
]
@ -195,6 +198,7 @@ async def authorized_search(
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@ -221,6 +225,7 @@ async def authorized_search(
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
session_id=session_id,
)
context = {}
@ -263,7 +268,7 @@ async def authorized_search(
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context)
completion = await get_completion(query_text, combined_context, session_id=session_id)
return completion, combined_context, datasets
@ -280,6 +285,7 @@ async def authorized_search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
session_id=session_id,
)
return search_results
@ -298,6 +304,7 @@ async def search_in_datasets_context(
last_k: Optional[int] = None,
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -317,6 +324,7 @@ async def search_in_datasets_context(
last_k: Optional[int] = None,
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@ -340,7 +348,7 @@ async def search_in_datasets_context(
return None, await get_context(query_text), [dataset]
search_context = context or await get_context(query_text)
search_result = await get_completion(query_text, search_context)
search_result = await get_completion(query_text, search_context, session_id=session_id)
return search_result, search_context, [dataset]
else:
@ -365,6 +373,7 @@ async def search_in_datasets_context(
last_k=last_k,
only_context=only_context,
context=context,
session_id=session_id,
)
)