feat: adds session_id to all retrievers + updates docstrings
This commit is contained in:
parent
0aa64403c5
commit
b36772e8bf
12 changed files with 80 additions and 12 deletions
|
|
@ -77,7 +77,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
logger.error(f"Context retrieval failed: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Generate completion using provided context or fetch new context.
|
||||
|
||||
|
|
@ -91,6 +91,8 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
- query (str): The query string for which completion is being generated.
|
||||
- context (Optional[Any]): Optional context to be used for generating completion;
|
||||
fetched if not provided. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class ChunksRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
|
||||
return chunk_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Generates a completion using document chunks context.
|
||||
|
||||
|
|
@ -74,6 +74,8 @@ class ChunksRetriever(BaseRetriever):
|
|||
- query (str): The query string to be used for generating a completion.
|
||||
- context (Optional[Any]): Optional pre-fetched context to use for generating the
|
||||
completion; if None, it retrieves the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -207,8 +207,24 @@ class CodeRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(result)} code file contexts")
|
||||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns the code files context."""
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Returns the code files context.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string to retrieve code context for.
|
||||
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
|
||||
the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The code files context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
logger.error("DocumentChunk_text collection not found")
|
||||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
|
|
@ -80,6 +80,8 @@ class CompletionRetriever(BaseRetriever):
|
|||
- query (str): The query string to be used for generating a completion.
|
||||
- context (Optional[Any]): Optional pre-fetched context to use for generating the
|
||||
completion; if None, it retrieves the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
raise CypherSearchError() from e
|
||||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Returns the graph connections context.
|
||||
|
||||
|
|
@ -62,6 +62,8 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
- query (str): The query to retrieve context.
|
||||
- context (Optional[Any]): Optional context to use, otherwise fetched using the
|
||||
query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
context_extension_rounds=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
|
@ -64,6 +65,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
- query (str): The input query for which the completion is generated.
|
||||
- context (Optional[Any]): The existing context to use for enhancing the query; if
|
||||
None, it will be initialized from triplets generated for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- context_extension_rounds: The maximum number of rounds to extend the context with
|
||||
new triplets before halting. (default 4)
|
||||
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
|
@ -74,6 +75,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||
follow-up questions. (default 4)
|
||||
|
||||
|
|
|
|||
|
|
@ -148,6 +148,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
- query (str): The query string for which a completion is generated.
|
||||
- context (Optional[Any]): Optional context to use for generating the completion; if
|
||||
not provided, context is retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -116,8 +116,24 @@ class LexicalRetriever(BaseRetriever):
|
|||
else:
|
||||
return [self.payloads[chunk_id] for chunk_id, _ in top_results]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns context for the given query (retrieves if not provided)."""
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Returns context for the given query (retrieves if not provided).
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string to retrieve context for.
|
||||
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
|
||||
the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
|
||||
return await self._execute_cypher_query(query, graph_engine)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Returns a completion based on the query and context.
|
||||
|
||||
|
|
@ -139,6 +139,8 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
- query (str): The natural language query to get a completion from.
|
||||
- context (Optional[Any]): The context in which to base the completion; if not
|
||||
provided, it will be retrieved using the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
||||
return summary_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any:
|
||||
"""
|
||||
Generates a completion using summaries context.
|
||||
|
||||
|
|
@ -75,6 +75,8 @@ class SummariesRetriever(BaseRetriever):
|
|||
- query (str): The search query for generating the completion.
|
||||
- context (Optional[Any]): Optional context for the completion; if not provided,
|
||||
will be retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
|
|||
|
|
@ -137,8 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
|
||||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[str] = None) -> List[str]:
|
||||
"""Generates a response using the query and optional context."""
|
||||
async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Generates a response using the query and optional context.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string for which a completion is generated.
|
||||
- context (Optional[str]): Optional context to use; if None, it will be
|
||||
retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated completion.
|
||||
"""
|
||||
if not context:
|
||||
context = await self.get_context(query=query)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue