diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index b1d54c150..2ecf9b8d3 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -16,9 +16,11 @@ session_user = ContextVar("session_user", default=None) soup_crawler_config = ContextVar("soup_crawler_config", default=None) tavily_config = ContextVar("tavily_config", default=None) + async def set_session_user_context_variable(user): session_user.set(user) + async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): """ If backend access control is enabled this function will ensure all datasets have their own databases, diff --git a/cognee/infrastructure/databases/cache/cache_db_interface.py b/cognee/infrastructure/databases/cache/cache_db_interface.py index f98103ea4..57e9a2ecd 100644 --- a/cognee/infrastructure/databases/cache/cache_db_interface.py +++ b/cognee/infrastructure/databases/cache/cache_db_interface.py @@ -43,13 +43,13 @@ class CacheDBInterface(ABC): @abstractmethod async def add_qa( - self, - user_id: str, - session_id: str, - question: str, - context: str, - answer: str, - ttl: int | None = 86400, + self, + user_id: str, + session_id: str, + question: str, + context: str, + answer: str, + ttl: int | None = 86400, ): """ Add a Q/A/context triplet to a cache session. diff --git a/cognee/infrastructure/databases/cache/redis/RedisAdapter.py b/cognee/infrastructure/databases/cache/redis/RedisAdapter.py index 4d9097b62..3d656ac90 100644 --- a/cognee/infrastructure/databases/cache/redis/RedisAdapter.py +++ b/cognee/infrastructure/databases/cache/redis/RedisAdapter.py @@ -166,7 +166,7 @@ async def main(): session_id, "What is Redis?", "Database context", - "Redis is an in-memory data store." + "Redis is an in-memory data store.", ) await adapter.add_qa( @@ -174,7 +174,7 @@ async def main(): session_id, "Who created Redis?", "History context", - "Salvatore Sanfilippo (antirez)." + "Salvatore Sanfilippo (antirez).", ) print(await adapter.get_all_qas(user_id, session_id)) diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index d8c25341f..ed66038c2 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -77,7 +77,9 @@ class EntityCompletionRetriever(BaseRetriever): logger.error(f"Context retrieval failed: {str(e)}") return None - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> List[str]: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> List[str]: """ Generate completion using provided context or fetch new context. diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index 086473805..94b9d3fb9 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever): logger.info(f"Returning {len(chunk_payloads)} chunk payloads") return chunk_payloads - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> Any: """ Generates a completion using document chunks context. diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py index 0390e1ff5..06f0bc266 100644 --- a/cognee/modules/retrieval/code_retriever.py +++ b/cognee/modules/retrieval/code_retriever.py @@ -207,22 +207,24 @@ class CodeRetriever(BaseRetriever): logger.info(f"Returning {len(result)} code file contexts") return result - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> Any: """ Returns the code files context. - + Parameters: ----------- - + - query (str): The query string to retrieve code context for. - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves the context for the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) - + Returns: -------- - + - Any: The code files context, either provided or retrieved. """ if context is None: diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index da371dfdf..dcddccef8 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -67,7 +67,9 @@ class CompletionRetriever(BaseRetriever): logger.error("DocumentChunk_text collection not found") raise NoDataError("No data found in the system, please add data first.") from error - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> str: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> str: """ Generates an LLM completion using the context. diff --git a/cognee/modules/retrieval/cypher_search_retriever.py b/cognee/modules/retrieval/cypher_search_retriever.py index d02d41f34..1b4efb88d 100644 --- a/cognee/modules/retrieval/cypher_search_retriever.py +++ b/cognee/modules/retrieval/cypher_search_retriever.py @@ -50,7 +50,9 @@ class CypherSearchRetriever(BaseRetriever): raise CypherSearchError() from e return result - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> Any: """ Returns the graph connections context. diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 64c2be2ae..83109fbd2 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -177,7 +177,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, - ) + ), ) else: completion = await generate_completion( @@ -195,11 +195,17 @@ class GraphCompletionRetriever(BaseGraphRetriever): if session_save: from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine + cache_engine = get_cache_engine() if session_id is None: - session_id = 'default_session' - await cache_engine.add_qa(str(user_id), session_id=session_id, question=query, context=context_summary, answer=completion) - + session_id = "default_session" + await cache_engine.add_qa( + str(user_id), + session_id=session_id, + question=query, + context=context_summary, + answer=completion, + ) return [completion] diff --git a/cognee/modules/retrieval/lexical_retriever.py b/cognee/modules/retrieval/lexical_retriever.py index 5c6e5dc20..71b50a0b3 100644 --- a/cognee/modules/retrieval/lexical_retriever.py +++ b/cognee/modules/retrieval/lexical_retriever.py @@ -116,22 +116,24 @@ class LexicalRetriever(BaseRetriever): else: return [self.payloads[chunk_id] for chunk_id, _ in top_results] - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> Any: """ Returns context for the given query (retrieves if not provided). - + Parameters: ----------- - + - query (str): The query string to retrieve context for. - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves the context for the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) - + Returns: -------- - + - Any: The context, either provided or retrieved. """ if context is None: diff --git a/cognee/modules/retrieval/natural_language_retriever.py b/cognee/modules/retrieval/natural_language_retriever.py index 992da326e..056a62163 100644 --- a/cognee/modules/retrieval/natural_language_retriever.py +++ b/cognee/modules/retrieval/natural_language_retriever.py @@ -125,7 +125,9 @@ class NaturalLanguageRetriever(BaseRetriever): return await self._execute_cypher_query(query, graph_engine) - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + ) -> Any: """ Returns a completion based on the query and context. diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index df12aa80f..87b224946 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever): logger.info(f"Returning {len(summary_payloads)} summary payloads") return summary_payloads - async def get_completion(self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs) -> Any: + async def get_completion( + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs + ) -> Any: """ Generates a completion using summaries context. diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 9b59ba1be..81f0deaee 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -137,22 +137,24 @@ class TemporalRetriever(GraphCompletionRetriever): return self.descriptions_to_string(top_k_events) - async def get_completion(self, query: str, context: Optional[str] = None, session_id: Optional[str] = None) -> List[str]: + async def get_completion( + self, query: str, context: Optional[str] = None, session_id: Optional[str] = None + ) -> List[str]: """ Generates a response using the query and optional context. - + Parameters: ----------- - + - query (str): The query string for which a completion is generated. - context (Optional[str]): Optional context to use; if None, it will be retrieved based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) - + Returns: -------- - + - List[str]: A list containing the generated completion. """ if not context: