ruff formatting
This commit is contained in:
parent
b36772e8bf
commit
66280442ac
13 changed files with 60 additions and 34 deletions
|
|
@ -16,9 +16,11 @@ session_user = ContextVar("session_user", default=None)
|
|||
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
|
||||
tavily_config = ContextVar("tavily_config", default=None)
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
"""
|
||||
If backend access control is enabled this function will ensure all datasets have their own databases,
|
||||
|
|
|
|||
|
|
@ -43,13 +43,13 @@ class CacheDBInterface(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def add_qa(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
answer: str,
|
||||
ttl: int | None = 86400,
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
answer: str,
|
||||
ttl: int | None = 86400,
|
||||
):
|
||||
"""
|
||||
Add a Q/A/context triplet to a cache session.
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ async def main():
|
|||
session_id,
|
||||
"What is Redis?",
|
||||
"Database context",
|
||||
"Redis is an in-memory data store."
|
||||
"Redis is an in-memory data store.",
|
||||
)
|
||||
|
||||
await adapter.add_qa(
|
||||
|
|
@ -174,7 +174,7 @@ async def main():
|
|||
session_id,
|
||||
"Who created Redis?",
|
||||
"History context",
|
||||
"Salvatore Sanfilippo (antirez)."
|
||||
"Salvatore Sanfilippo (antirez).",
|
||||
)
|
||||
|
||||
print(await adapter.get_all_qas(user_id, session_id))
|
||||
|
|
|
|||
|
|
@ -77,7 +77,9 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
logger.error(f"Context retrieval failed: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion using provided context or fetch new context.
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
|
||||
return chunk_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using document chunks context.
|
||||
|
||||
|
|
|
|||
|
|
@ -207,22 +207,24 @@ class CodeRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(result)} code file contexts")
|
||||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns the code files context.
|
||||
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
|
||||
- query (str): The query string to retrieve code context for.
|
||||
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
|
||||
the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
|
||||
- Any: The code files context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
|
|
|
|||
|
|
@ -67,7 +67,9 @@ class CompletionRetriever(BaseRetriever):
|
|||
logger.error("DocumentChunk_text collection not found")
|
||||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,9 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
raise CypherSearchError() from e
|
||||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns the graph connections context.
|
||||
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
completion = await generate_completion(
|
||||
|
|
@ -195,11 +195,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
if session_save:
|
||||
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||
|
||||
cache_engine = get_cache_engine()
|
||||
if session_id is None:
|
||||
session_id = 'default_session'
|
||||
await cache_engine.add_qa(str(user_id), session_id=session_id, question=query, context=context_summary, answer=completion)
|
||||
|
||||
session_id = "default_session"
|
||||
await cache_engine.add_qa(
|
||||
str(user_id),
|
||||
session_id=session_id,
|
||||
question=query,
|
||||
context=context_summary,
|
||||
answer=completion,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
||||
|
|
|
|||
|
|
@ -116,22 +116,24 @@ class LexicalRetriever(BaseRetriever):
|
|||
else:
|
||||
return [self.payloads[chunk_id] for chunk_id, _ in top_results]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns context for the given query (retrieves if not provided).
|
||||
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
|
||||
- query (str): The query string to retrieve context for.
|
||||
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
|
||||
the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
|
||||
- Any: The context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
|
|
|
|||
|
|
@ -125,7 +125,9 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
|
||||
return await self._execute_cypher_query(query, graph_engine)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns a completion based on the query and context.
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
||||
return summary_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using summaries context.
|
||||
|
||||
|
|
|
|||
|
|
@ -137,22 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
|
||||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]:
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generates a response using the query and optional context.
|
||||
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
|
||||
- query (str): The query string for which a completion is generated.
|
||||
- context (Optional[str]): Optional context to use; if None, it will be
|
||||
retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
|
||||
- List[str]: A list containing the generated completion.
|
||||
"""
|
||||
if not context:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue