ruff formatting

This commit is contained in:
hajdul88 2025-10-15 18:02:10 +02:00
parent b36772e8bf
commit 66280442ac
13 changed files with 60 additions and 34 deletions

View file

@ -16,9 +16,11 @@ session_user = ContextVar("session_user", default=None)
soup_crawler_config = ContextVar("soup_crawler_config", default=None) soup_crawler_config = ContextVar("soup_crawler_config", default=None)
tavily_config = ContextVar("tavily_config", default=None) tavily_config = ContextVar("tavily_config", default=None)
async def set_session_user_context_variable(user): async def set_session_user_context_variable(user):
session_user.set(user) session_user.set(user)
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
""" """
If backend access control is enabled this function will ensure all datasets have their own databases, If backend access control is enabled this function will ensure all datasets have their own databases,

View file

@ -43,13 +43,13 @@ class CacheDBInterface(ABC):
@abstractmethod @abstractmethod
async def add_qa( async def add_qa(
self, self,
user_id: str, user_id: str,
session_id: str, session_id: str,
question: str, question: str,
context: str, context: str,
answer: str, answer: str,
ttl: int | None = 86400, ttl: int | None = 86400,
): ):
""" """
Add a Q/A/context triplet to a cache session. Add a Q/A/context triplet to a cache session.

View file

@ -166,7 +166,7 @@ async def main():
session_id, session_id,
"What is Redis?", "What is Redis?",
"Database context", "Database context",
"Redis is an in-memory data store." "Redis is an in-memory data store.",
) )
await adapter.add_qa( await adapter.add_qa(
@ -174,7 +174,7 @@ async def main():
session_id, session_id,
"Who created Redis?", "Who created Redis?",
"History context", "History context",
"Salvatore Sanfilippo (antirez)." "Salvatore Sanfilippo (antirez).",
) )
print(await adapter.get_all_qas(user_id, session_id)) print(await adapter.get_all_qas(user_id, session_id))

View file

@ -77,7 +77,9 @@ class EntityCompletionRetriever(BaseRetriever):
logger.error(f"Context retrieval failed: {str(e)}") logger.error(f"Context retrieval failed: {str(e)}")
return None return None
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> List[str]:
""" """
Generate completion using provided context or fetch new context. Generate completion using provided context or fetch new context.

View file

@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever):
logger.info(f"Returning {len(chunk_payloads)} chunk payloads") logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
""" """
Generates a completion using document chunks context. Generates a completion using document chunks context.

View file

@ -207,7 +207,9 @@ class CodeRetriever(BaseRetriever):
logger.info(f"Returning {len(result)} code file contexts") logger.info(f"Returning {len(result)} code file contexts")
return result return result
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
""" """
Returns the code files context. Returns the code files context.

View file

@ -67,7 +67,9 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found") logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> str:
""" """
Generates an LLM completion using the context. Generates an LLM completion using the context.

View file

@ -50,7 +50,9 @@ class CypherSearchRetriever(BaseRetriever):
raise CypherSearchError() from e raise CypherSearchError() from e
return result return result
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
""" """
Returns the graph connections context. Returns the graph connections context.

View file

@ -177,7 +177,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
) ),
) )
else: else:
completion = await generate_completion( completion = await generate_completion(
@ -195,11 +195,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
if session_save: if session_save:
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
cache_engine = get_cache_engine() cache_engine = get_cache_engine()
if session_id is None: if session_id is None:
session_id = 'default_session' session_id = "default_session"
await cache_engine.add_qa(str(user_id), session_id=session_id, question=query, context=context_summary, answer=completion) await cache_engine.add_qa(
str(user_id),
session_id=session_id,
question=query,
context=context_summary,
answer=completion,
)
return [completion] return [completion]

View file

@ -116,7 +116,9 @@ class LexicalRetriever(BaseRetriever):
else: else:
return [self.payloads[chunk_id] for chunk_id, _ in top_results] return [self.payloads[chunk_id] for chunk_id, _ in top_results]
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
""" """
Returns context for the given query (retrieves if not provided). Returns context for the given query (retrieves if not provided).

View file

@ -125,7 +125,9 @@ class NaturalLanguageRetriever(BaseRetriever):
return await self._execute_cypher_query(query, graph_engine) return await self._execute_cypher_query(query, graph_engine)
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
""" """
Returns a completion based on the query and context. Returns a completion based on the query and context.

View file

@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever):
logger.info(f"Returning {len(summary_payloads)} summary payloads") logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any: async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs
) -> Any:
""" """
Generates a completion using summaries context. Generates a completion using summaries context.

View file

@ -137,7 +137,9 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events) return self.descriptions_to_string(top_k_events)
async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]: async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
) -> List[str]:
""" """
Generates a response using the query and optional context. Generates a response using the query and optional context.