import asyncio
import json
from dataclasses import dataclass, field
from typing import Optional, List, Type, Any

from pydantic import BaseModel

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger

from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
    generate_completion,
    summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
    save_conversation_history,
    get_conversation_history,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig

logger = get_logger()


@dataclass
class QueryState:
    """
    Helper class containing all necessary information about the query state.

    Used to keep track of important information in a more readable way, and to
    enable as many parallel LLM calls as possible. Declared as a dataclass so the
    list default is created per instance instead of being shared at class level.
    """

    completion: Any = ""
    triplets: List[Edge] = field(default_factory=list)
    context_text: str = ""

    answer_text: str = ""
    valid_user_prompt: str = ""
    valid_system_prompt: str = ""
    reasoning: str = ""

    followup_question: str = ""
    followup_prompt: str = ""
    followup_system: str = ""

def _as_answer_text(completion: Any) -> str:
    """Convert completion to human-readable text for validation and follow-up prompts."""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, BaseModel):
        # Add notice that this is a structured response
        json_str = completion.model_dump_json(indent=2)
        return f"[Structured Response]\n{json_str}"
    try:
        return json.dumps(completion, indent=2)
    except TypeError:
        return str(completion)
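
# Expected behavior, illustrated (the Answer model below is hypothetical):
#
#     class Answer(BaseModel):
#         text: str
#
#     _as_answer_text("plain text")       -> "plain text"
#     _as_answer_text(Answer(text="hi"))  -> '[Structured Response]\n{\n  "text": "hi"\n}'
#     _as_answer_text({"k": 1})           -> '{\n  "k": 1\n}'
#     _as_answer_text(object())           -> str() fallback for non-JSON-serializable values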


class GraphCompletionCotRetriever(GraphCompletionRetriever):
    """
    Handles graph completion by generating responses based on a series of interactions with
    a language model. This class extends GraphCompletionRetriever and is designed to
    manage the retrieval and validation process for user queries, integrating follow-up
    questions based on reasoning. The public methods are:

    - get_completion

    Instance variables include:
    - validation_system_prompt_path
    - validation_user_prompt_path
    - followup_system_prompt_path
    - followup_user_prompt_path
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        validation_user_prompt_path: str = "cot_validation_user_prompt.txt",
        validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
        followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
        followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
        system_prompt: Optional[str] = None,
        top_k: Optional[int] = 5,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        save_interaction: bool = False,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            system_prompt=system_prompt,
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
            save_interaction=save_interaction,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        self.validation_system_prompt_path = validation_system_prompt_path
        self.validation_user_prompt_path = validation_user_prompt_path
        self.followup_system_prompt_path = followup_system_prompt_path
        self.followup_user_prompt_path = followup_user_prompt_path
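
    # Illustrative construction (a sketch; the filenames shown are this class's
    # defaults, and any overrides must exist in cognee's prompts directory):
    #
    #     retriever = GraphCompletionCotRetriever(
    #         validation_user_prompt_path="cot_validation_user_prompt.txt",
    #         followup_user_prompt_path="cot_followup_user_prompt.txt",
    #         top_k=10,
    #     )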

    async def _run_cot_completion(
        self,
        query: Optional[str] = None,
        query_batch: Optional[List[str]] = None,
        context: Optional[List[Edge] | List[List[Edge]]] = None,
        conversation_history: str = "",
        max_iter: int = 4,
        response_model: Type = str,
    ) -> tuple[List[Any], List[str], List[List[Edge]]]:
        """
        Run chain-of-thought completion with optional structured output.

        Parameters:
        -----------
        - query: Single user query (treated internally as a batch of one)
        - query_batch: Optional batch of user queries
        - context: Optional pre-fetched context edges
        - conversation_history: Optional conversation history string
        - max_iter: Maximum number of CoT refinement rounds
        - response_model: Type for structured output (str for plain text)

        Returns:
        --------
        - completion_batch: The generated completions (strings or structured models), one per query
        - context_text_batch: The resolved context texts, one per query
        - triplets_batch: The lists of triplets used, one per query
        """
        followup_question_batch = []
        completion_batch = []
        context_text_batch = []

        if query:
            # Treat a single query as a batch of queries, mainly to avoid massive code duplication
            query_batch = [query]
            if context:
                context = [context]

        triplets_batch = context

        # Dict containing query -> QueryState key-value pairs.
        # For every query we save the necessary data, so requests can be executed in parallel.
        query_state_tracker = {}
        for batched_query in query_batch:
            query_state_tracker[batched_query] = QueryState()

        for round_idx in range(max_iter + 1):
            if round_idx == 0:
                if context is None:
                    # Get context, resolve it to text, and store both in the query state
                    triplets_batch = await self.get_context(query_batch=query_batch)
                    context_text_batch = await asyncio.gather(
                        *[
                            self.resolve_edges_to_text(batched_triplets)
                            for batched_triplets in triplets_batch
                        ]
                    )
                    for batched_query, batched_triplets, batched_context_text in zip(
                        query_batch, triplets_batch, context_text_batch
                    ):
                        query_state_tracker[batched_query].triplets = batched_triplets
                        query_state_tracker[batched_query].context_text = batched_context_text
                else:
                    # Context was supplied, so just resolve it to text and save it to the query state
                    context_text_batch = await asyncio.gather(
                        *[
                            self.resolve_edges_to_text(batched_context)
                            for batched_context in context
                        ]
                    )
                    for batched_query, batched_triplets, batched_context_text in zip(
                        query_batch, context, context_text_batch
                    ):
                        query_state_tracker[batched_query].triplets = batched_triplets
                        query_state_tracker[batched_query].context_text = batched_context_text
            else:
                # Find new triplets and update the existing query states
                followup_triplets_batch = await self.get_context(
                    query_batch=followup_question_batch
                )
                for batched_query, batched_followup_triplets in zip(
                    query_batch, followup_triplets_batch
                ):
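                    # Order-preserving dedup: dict.fromkeys keeps first occurrences,
                    # e.g. list(dict.fromkeys([e1, e2, e1])) == [e1, e2]. This
                    # assumes Edge instances are hashable.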
                    query_state_tracker[batched_query].triplets = list(
                        dict.fromkeys(
                            query_state_tracker[batched_query].triplets + batched_followup_triplets
                        )
                    )

                context_text_batch = await asyncio.gather(
                    *[
                        self.resolve_edges_to_text(batched_query_state.triplets)
                        for batched_query_state in query_state_tracker.values()
                    ]
                )

                for batched_query, batched_context_text in zip(query_batch, context_text_batch):
                    query_state_tracker[batched_query].context_text = batched_context_text

            completion_batch = await asyncio.gather(
                *[
                    generate_completion(
                        query=batched_query,
                        context=batched_query_state.context_text,
                        user_prompt_path=self.user_prompt_path,
                        system_prompt_path=self.system_prompt_path,
                        system_prompt=self.system_prompt,
                        conversation_history=conversation_history if conversation_history else None,
                        response_model=response_model,
                    )
                    for batched_query, batched_query_state in query_state_tracker.items()
                ]
            )

            for batched_query, batched_completion in zip(query_batch, completion_batch):
                query_state_tracker[batched_query].completion = batched_completion

            logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}")

            if round_idx < max_iter:
                for batched_query, batched_query_state in query_state_tracker.items():
                    batched_query_state.answer_text = _as_answer_text(
                        batched_query_state.completion
                    )
                    valid_args = {
                        "query": batched_query,
                        "answer": batched_query_state.answer_text,
                        "context": batched_query_state.context_text,
                    }
                    batched_query_state.valid_user_prompt = render_prompt(
                        filename=self.validation_user_prompt_path,
                        context=valid_args,
                    )
                    batched_query_state.valid_system_prompt = read_query_prompt(
                        prompt_file_name=self.validation_system_prompt_path
                    )

                reasoning_batch = await asyncio.gather(
                    *[
                        LLMGateway.acreate_structured_output(
                            text_input=batched_query_state.valid_user_prompt,
                            system_prompt=batched_query_state.valid_system_prompt,
                            response_model=str,
                        )
                        for batched_query_state in query_state_tracker.values()
                    ]
                )

                for batched_query, batched_reasoning in zip(query_batch, reasoning_batch):
                    query_state_tracker[batched_query].reasoning = batched_reasoning

                for batched_query, batched_query_state in query_state_tracker.items():
                    followup_args = {
                        # Use the per-query value here; `query` is None for batch calls
                        "query": batched_query,
                        "answer": batched_query_state.answer_text,
                        "reasoning": batched_query_state.reasoning,
                    }
                    batched_query_state.followup_prompt = render_prompt(
                        filename=self.followup_user_prompt_path,
                        context=followup_args,
                    )
                    batched_query_state.followup_system = read_query_prompt(
                        prompt_file_name=self.followup_system_prompt_path
                    )

                followup_question_batch = await asyncio.gather(
                    *[
                        LLMGateway.acreate_structured_output(
                            text_input=batched_query_state.followup_prompt,
                            system_prompt=batched_query_state.followup_system,
                            response_model=str,
                        )
                        for batched_query_state in query_state_tracker.values()
                    ]
                )
                logger.info(
                    f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}"
                )

        return completion_batch, context_text_batch, triplets_batch
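
    # Shape note (illustrative): for query_batch=["q1", "q2"] the method returns
    # ([answer_q1, answer_q2], [context_text_q1, context_text_q2],
    #  [triplets_q1, triplets_q2]); a single `query` behaves as a batch of one.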
    async def get_completion(
        self,
        query: Optional[str] = None,
        context: Optional[List[Edge] | List[List[Edge]]] = None,
        session_id: Optional[str] = None,
        max_iter: int = 4,
        response_model: Type = str,
        query_batch: Optional[List[str]] = None,
    ) -> List[Any]:
        """
        Generate completion responses based on a user query and contextual information.

        This method interacts with a language model client to retrieve a structured response,
        using a series of iterations to refine the answers and generate follow-up questions
        based on reasoning derived from previous outputs. It raises exceptions if the context
        retrieval fails or if the model encounters issues in generating outputs. It returns
        structured output using the provided response model.

        Parameters:
        -----------

        - query (Optional[str]): The user's query to be processed and answered.
        - context (Optional[List[Edge] | List[List[Edge]]]): Optional context that may assist
          in answering the query. If not provided, it will be fetched based on the query.
          (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
          defaults to 'default_session'. (default None)
        - max_iter (int): The maximum number of iterations to refine the answer and generate
          follow-up questions. (default 4)
        - response_model (Type): The Pydantic model type for structured output. (default str)
        - query_batch (Optional[List[str]]): Optional batch of queries; exactly one of
          `query` and `query_batch` must be provided. (default None)

        Returns:
        --------

        - List[Any]: A list containing the generated answer for each query.
        """
        is_query_valid, msg = validate_queries(query, query_batch)
        if not is_query_valid:
            raise ValueError(msg)

        # Check if session saving is enabled
        cache_config = CacheConfig()
        user = session_user.get()
        user_id = getattr(user, "id", None)
        session_save = user_id and cache_config.caching

        # Load conversation history if enabled
        conversation_history = ""
        if session_save:
            conversation_history = await get_conversation_history(session_id=session_id)

        completion, context_text, triplets = await self._run_cot_completion(
            query=query,
            query_batch=query_batch,
            context=context,
            conversation_history=conversation_history,
            max_iter=max_iter,
            response_model=response_model,
        )

        # TODO: Handle save interaction for batch queries
        if self.save_interaction and context and triplets and completion:
            await self.save_qa(
                question=query,
                answer=str(completion[0]),
                context=context_text[0],
                triplets=triplets[0],
            )

        # TODO: Handle session save interaction for batch queries
        # Save to session cache if enabled
        if session_save:
            context_summary = await summarize_text(context_text)
            await save_conversation_history(
                query=query,
                context_summary=context_summary,
                answer=str(completion),
                session_id=session_id,
            )

        return completion
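

# Illustrative usage — a sketch, not part of the module's public surface. It assumes
# a populated cognee knowledge graph and a configured LLM provider; the query text
# is hypothetical.
if __name__ == "__main__":

    async def _demo():
        retriever = GraphCompletionCotRetriever(top_k=5)
        # get_completion returns one answer per query; a single query yields a
        # one-element list
        answers = await retriever.get_completion(query="Who founded the company?", max_iter=2)
        print(answers[0])

    asyncio.run(_demo())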