ruff formatting

This commit is contained in:
hajdul88 2025-10-15 18:02:10 +02:00
parent b36772e8bf
commit 66280442ac
13 changed files with 60 additions and 34 deletions

View file

@ -16,9 +16,11 @@ session_user = ContextVar("session_user", default=None)
soup_crawler_config = ContextVar("soup_crawler_config", default=None)
tavily_config = ContextVar("tavily_config", default=None)
async def set_session_user_context_variable(user):
session_user.set(user)
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
"""
If backend access control is enabled this function will ensure all datasets have their own databases,

View file

@ -43,13 +43,13 @@ class CacheDBInterface(ABC):
@abstractmethod
async def add_qa(
self,
user_id: str,
session_id: str,
question: str,
context: str,
answer: str,
ttl: int | None = 86400,
self,
user_id: str,
session_id: str,
question: str,
context: str,
answer: str,
ttl: int | None = 86400,
):
"""
Add a Q/A/context triplet to a cache session.

View file

@ -166,7 +166,7 @@ async def main():
session_id,
"What is Redis?",
"Database context",
"Redis is an in-memory data store."
"Redis is an in-memory data store.",
)
await adapter.add_qa(
@ -174,7 +174,7 @@ async def main():
session_id,
"Who created Redis?",
"History context",
"Salvatore Sanfilippo (antirez)."
"Salvatore Sanfilippo (antirez).",
)
print(await adapter.get_all_qas(user_id, session_id))

View file

@ -77,7 +77,9 @@ class EntityCompletionRetriever(BaseRetriever):
logger.error(f"Context retrieval failed: {str(e)}")
return None
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> List[str]:
"""
Generate completion using provided context or fetch new context.

View file

@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever):
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Generates a completion using document chunks context.

View file

@ -207,22 +207,24 @@ class CodeRetriever(BaseRetriever):
logger.info(f"Returning {len(result)} code file contexts")
return result
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns the code files context.
Parameters:
-----------
- query (str): The query string to retrieve code context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The code files context, either provided or retrieved.
"""
if context is None:

View file

@ -67,7 +67,9 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> str:
"""
Generates an LLM completion using the context.

View file

@ -50,7 +50,9 @@ class CypherSearchRetriever(BaseRetriever):
raise CypherSearchError() from e
return result
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns the graph connections context.

View file

@ -177,7 +177,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
),
)
else:
completion = await generate_completion(
@ -195,11 +195,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
if session_save:
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
cache_engine = get_cache_engine()
if session_id is None:
session_id = 'default_session'
await cache_engine.add_qa(str(user_id), session_id=session_id, question=query, context=context_summary, answer=completion)
session_id = "default_session"
await cache_engine.add_qa(
str(user_id),
session_id=session_id,
question=query,
context=context_summary,
answer=completion,
)
return [completion]

View file

@ -116,22 +116,24 @@ class LexicalRetriever(BaseRetriever):
else:
return [self.payloads[chunk_id] for chunk_id, _ in top_results]
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns context for the given query (retrieves if not provided).
Parameters:
-----------
- query (str): The query string to retrieve context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The context, either provided or retrieved.
"""
if context is None:

View file

@ -125,7 +125,9 @@ class NaturalLanguageRetriever(BaseRetriever):
return await self._execute_cypher_query(query, graph_engine)
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns a completion based on the query and context.

View file

@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any:
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs
) -> Any:
"""
Generates a completion using summaries context.

View file

@ -137,22 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]:
async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
) -> List[str]:
"""
Generates a response using the query and optional context.
Parameters:
-----------
- query (str): The query string for which a completion is generated.
- context (Optional[str]): Optional context to use; if None, it will be
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- List[str]: A list containing the generated completion.
"""
if not context: