feat: adds session_id to all retrievers + updates docstrings

This commit is contained in:
hajdul88 2025-10-15 18:01:13 +02:00
parent 0aa64403c5
commit b36772e8bf
12 changed files with 80 additions and 12 deletions

View file

@ -77,7 +77,7 @@ class EntityCompletionRetriever(BaseRetriever):
logger.error(f"Context retrieval failed: {str(e)}")
return None
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]:
"""
Generate completion using provided context or fetch new context.
@ -91,6 +91,8 @@ class EntityCompletionRetriever(BaseRetriever):
- query (str): The query string for which completion is being generated.
- context (Optional[Any]): Optional context to be used for generating completion;
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -61,7 +61,7 @@ class ChunksRetriever(BaseRetriever):
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""
Generates a completion using document chunks context.
@ -74,6 +74,8 @@ class ChunksRetriever(BaseRetriever):
- query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -207,8 +207,24 @@ class CodeRetriever(BaseRetriever):
logger.info(f"Returning {len(result)} code file contexts")
return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Returns the code files context."""
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""
Returns the code files context.
Parameters:
-----------
- query (str): The query string to retrieve code context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The code files context, either provided or retrieved.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@ -67,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str:
"""
Generates an LLM completion using the context.
@ -80,6 +80,8 @@ class CompletionRetriever(BaseRetriever):
- query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -50,7 +50,7 @@ class CypherSearchRetriever(BaseRetriever):
raise CypherSearchError() from e
return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""
Returns the graph connections context.
@ -62,6 +62,8 @@ class CypherSearchRetriever(BaseRetriever):
- query (str): The query to retrieve context.
- context (Optional[Any]): Optional context to use, otherwise fetched using the
query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -47,6 +47,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
) -> List[str]:
"""
@ -64,6 +65,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- query (str): The input query for which the completion is generated.
- context (Optional[Any]): The existing context to use for enhancing the query; if
None, it will be initialized from triplets generated for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)

View file

@ -58,6 +58,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
@ -74,6 +75,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)

View file

@ -148,6 +148,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- query (str): The query string for which a completion is generated.
- context (Optional[Any]): Optional context to use for generating the completion; if
not provided, context is retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -116,8 +116,24 @@ class LexicalRetriever(BaseRetriever):
else:
return [self.payloads[chunk_id] for chunk_id, _ in top_results]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Returns context for the given query (retrieves if not provided)."""
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""
Returns context for the given query (retrieves if not provided).
Parameters:
-----------
- query (str): The query string to retrieve context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The context, either provided or retrieved.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@ -125,7 +125,7 @@ class NaturalLanguageRetriever(BaseRetriever):
return await self._execute_cypher_query(query, graph_engine)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""
Returns a completion based on the query and context.
@ -139,6 +139,8 @@ class NaturalLanguageRetriever(BaseRetriever):
- query (str): The natural language query to get a completion from.
- context (Optional[Any]): The context in which to base the completion; if not
provided, it will be retrieved using the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any:
"""
Generates a completion using summaries context.
@ -75,6 +75,8 @@ class SummariesRetriever(BaseRetriever):
- query (str): The search query for generating the completion.
- context (Optional[Any]): Optional context for the completion; if not provided,
will be retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------

View file

@ -137,8 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(self, query: str, context: Optional[str] = None) -> List[str]:
"""Generates a response using the query and optional context."""
async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]:
"""
Generates a response using the query and optional context.
Parameters:
-----------
- query (str): The query string for which a completion is generated.
- context (Optional[str]): Optional context to use; if None, it will be
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- List[str]: A list containing the generated completion.
"""
if not context:
context = await self.get_context(query=query)