feat: adds session_id to all retrievers + updates docstrings

This commit is contained in:
hajdul88 2025-10-15 18:01:13 +02:00
parent 0aa64403c5
commit b36772e8bf
12 changed files with 80 additions and 12 deletions

View file

@ -77,7 +77,7 @@ class EntityCompletionRetriever(BaseRetriever):
logger.error(f"Context retrieval failed: {str(e)}") logger.error(f"Context retrieval failed: {str(e)}")
return None return None
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]:
""" """
Generate completion using provided context or fetch new context. Generate completion using provided context or fetch new context.
@ -91,6 +91,8 @@ class EntityCompletionRetriever(BaseRetriever):
- query (str): The query string for which completion is being generated. - query (str): The query string for which completion is being generated.
- context (Optional[Any]): Optional context to be used for generating completion; - context (Optional[Any]): Optional context to be used for generating completion;
fetched if not provided. (default None) fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -61,7 +61,7 @@ class ChunksRetriever(BaseRetriever):
logger.info(f"Returning {len(chunk_payloads)} chunk payloads") logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
""" """
Generates a completion using document chunks context. Generates a completion using document chunks context.
@ -74,6 +74,8 @@ class ChunksRetriever(BaseRetriever):
- query (str): The query string to be used for generating a completion. - query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the - context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None) completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -207,8 +207,24 @@ class CodeRetriever(BaseRetriever):
logger.info(f"Returning {len(result)} code file contexts") logger.info(f"Returning {len(result)} code file contexts")
return result return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""Returns the code files context.""" """
Returns the code files context.
Parameters:
-----------
- query (str): The query string to retrieve code context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The code files context, either provided or retrieved.
"""
if context is None: if context is None:
context = await self.get_context(query) context = await self.get_context(query)
return context return context

View file

@ -67,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found") logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> str: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str:
""" """
Generates an LLM completion using the context. Generates an LLM completion using the context.
@ -80,6 +80,8 @@ class CompletionRetriever(BaseRetriever):
- query (str): The query string to be used for generating a completion. - query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the - context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None) completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -50,7 +50,7 @@ class CypherSearchRetriever(BaseRetriever):
raise CypherSearchError() from e raise CypherSearchError() from e
return result return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
""" """
Returns the graph connections context. Returns the graph connections context.
@ -62,6 +62,8 @@ class CypherSearchRetriever(BaseRetriever):
- query (str): The query to retrieve context. - query (str): The query to retrieve context.
- context (Optional[Any]): Optional context to use, otherwise fetched using the - context (Optional[Any]): Optional context to use, otherwise fetched using the
query. (default None) query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -47,6 +47,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
self, self,
query: str, query: str,
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4, context_extension_rounds=4,
) -> List[str]: ) -> List[str]:
""" """
@ -64,6 +65,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- query (str): The input query for which the completion is generated. - query (str): The input query for which the completion is generated.
- context (Optional[Any]): The existing context to use for enhancing the query; if - context (Optional[Any]): The existing context to use for enhancing the query; if
None, it will be initialized from triplets generated for the query. (default None) None, it will be initialized from triplets generated for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with - context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4) new triplets before halting. (default 4)

View file

@ -58,6 +58,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self, self,
query: str, query: str,
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4, max_iter=4,
) -> List[str]: ) -> List[str]:
""" """
@ -74,6 +75,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- query (str): The user's query to be processed and answered. - query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query. - context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None) If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate - max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4) follow-up questions. (default 4)

View file

@ -148,6 +148,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
- query (str): The query string for which a completion is generated. - query (str): The query string for which a completion is generated.
- context (Optional[Any]): Optional context to use for generating the completion; if - context (Optional[Any]): Optional context to use for generating the completion; if
not provided, context is retrieved based on the query. (default None) not provided, context is retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -116,8 +116,24 @@ class LexicalRetriever(BaseRetriever):
else: else:
return [self.payloads[chunk_id] for chunk_id, _ in top_results] return [self.payloads[chunk_id] for chunk_id, _ in top_results]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
"""Returns context for the given query (retrieves if not provided).""" """
Returns context for the given query (retrieves if not provided).
Parameters:
-----------
- query (str): The query string to retrieve context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The context, either provided or retrieved.
"""
if context is None: if context is None:
context = await self.get_context(query) context = await self.get_context(query)
return context return context

View file

@ -125,7 +125,7 @@ class NaturalLanguageRetriever(BaseRetriever):
return await self._execute_cypher_query(query, graph_engine) return await self._execute_cypher_query(query, graph_engine)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
""" """
Returns a completion based on the query and context. Returns a completion based on the query and context.
@ -139,6 +139,8 @@ class NaturalLanguageRetriever(BaseRetriever):
- query (str): The natural language query to get a completion from. - query (str): The natural language query to get a completion from.
- context (Optional[Any]): The context in which to base the completion; if not - context (Optional[Any]): The context in which to base the completion; if not
provided, it will be retrieved using the query. (default None) provided, it will be retrieved using the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads") logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any:
""" """
Generates a completion using summaries context. Generates a completion using summaries context.
@ -75,6 +75,8 @@ class SummariesRetriever(BaseRetriever):
- query (str): The search query for generating the completion. - query (str): The search query for generating the completion.
- context (Optional[Any]): Optional context for the completion; if not provided, - context (Optional[Any]): Optional context for the completion; if not provided,
will be retrieved based on the query. (default None) will be retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns: Returns:
-------- --------

View file

@ -137,8 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events) return self.descriptions_to_string(top_k_events)
async def get_completion(self, query: str, context: Optional[str] = None) -> List[str]: async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]:
"""Generates a response using the query and optional context.""" """
Generates a response using the query and optional context.
Parameters:
-----------
- query (str): The query string for which a completion is generated.
- context (Optional[str]): Optional context to use; if None, it will be
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- List[str]: A list containing the generated completion.
"""
if not context: if not context:
context = await self.get_context(query=query) context = await self.get_context(query=query)