import asyncio
|
|
from typing import Any, Optional, Type, List
|
|
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
|
from cognee.modules.retrieval.utils.session_cache import (
|
|
save_conversation_history,
|
|
get_conversation_history,
|
|
)
|
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
from cognee.context_global_variables import session_user
|
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
|
|
|
# Module-level logger for this retriever; name matches the class for easy filtering.
logger = get_logger("CompletionRetriever")
|
|
|
|
|
|
class CompletionRetriever(BaseRetriever):
    """
    Retriever for handling LLM-based completion searches.

    Public methods:

    - get_context(query: str) -> str
    - get_completion(query: str, context: Optional[Any] = None) -> Any
    """

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        system_prompt: Optional[str] = None,
        top_k: Optional[int] = 1,
    ):
        """Initialize retriever with optional custom prompt paths."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.system_prompt = system_prompt
        # Fall back to a single chunk when no explicit top_k was supplied.
        self.top_k = 1 if top_k is None else top_k

    async def get_context(self, query: str) -> str:
        """
        Retrieve relevant document chunks and merge them into a single context string.

        Searches the "DocumentChunk_text" collection of the vector engine (at most
        ``self.top_k`` hits) and joins the text of every hit with newlines. Returns
        an empty string when nothing matched.

        Parameters:
        -----------

            - query (str): The query string used to search for relevant document chunks.

        Returns:
        --------

            - str: The combined text of the retrieved document chunks, or an empty
              string if none are found.

        Raises:
        -------

            - NoDataError: If the chunk collection does not exist in the system yet.
        """
        engine = get_vector_engine()

        try:
            hits = await engine.search("DocumentChunk_text", query, limit=self.top_k)
        except CollectionNotFoundError as error:
            logger.error("DocumentChunk_text collection not found")
            raise NoDataError("No data found in the system, please add data first.") from error

        if not hits:
            return ""

        # Stitch together every retrieved chunk's text; top_k bounds the hit count.
        return "\n".join(hit.payload["text"] for hit in hits)

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        session_id: Optional[str] = None,
        response_model: Type = str,
    ) -> List[Any]:
        """
        Generate an LLM completion for the query using retrieved context.

        When no context is supplied, it is fetched via ``get_context``. If session
        caching applies (an identified user plus caching enabled in the cache
        config), the conversation history is loaded, a context summary is produced
        concurrently with the completion, and the exchange is persisted afterwards.

        Parameters:
        -----------

            - query (str): The query string to be used for generating a completion.
            - context (Optional[Any]): Optional pre-fetched context to use for generating
              the completion; if None, it retrieves the context for the query. (default None)
            - session_id (Optional[str]): Optional session identifier for caching. If None,
              defaults to 'default_session'. (default None)
            - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

            - List[Any]: A single-element list holding the generated completion.
        """
        if context is None:
            context = await self.get_context(query)

        # Session history is only persisted for an identified user with caching enabled.
        cache_settings = CacheConfig()
        current_user = session_user.get()
        use_session_cache = getattr(current_user, "id", None) and cache_settings.caching

        if not use_session_cache:
            answer = await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
                response_model=response_model,
            )
            return [answer]

        history = await get_conversation_history(session_id=session_id)

        # Summarize the context for the cache concurrently with answering the query.
        summary, answer = await asyncio.gather(
            summarize_text(context),
            generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
                conversation_history=history,
                response_model=response_model,
            ),
        )

        await save_conversation_history(
            query=query,
            context_summary=summary,
            answer=answer,
            session_id=session_id,
        )

        return [answer]
|