refactor: unify structured and str completion
This commit is contained in:
parent 66a8242cec
commit ecae650a28
2 changed files with 71 additions and 57 deletions
cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -6,7 +6,10 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
-from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.completion import (
+    generate_structured_completion,
+    summarize_text,
+)
 from cognee.modules.retrieval.utils.session_cache import (
     save_conversation_history,
     get_conversation_history,
@@ -82,12 +85,20 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         self,
         query: str,
         context: Optional[List[Edge]] = None,
-        session_id: Optional[str] = None,
+        conversation_history: str = "",
         max_iter: int = 4,
         response_model: Type = str,
     ) -> tuple[Any, str, List[Edge]]:
         """
-        Run chain-of-thought completion with optional structured output and session caching.
+        Run chain-of-thought completion with optional structured output.
 
+        Parameters:
+        -----------
+        - query: User query
+        - context: Optional pre-fetched context edges
+        - conversation_history: Optional conversation history string
+        - max_iter: Maximum CoT iterations
+        - response_model: Type for structured output (str for plain text)
+
         Returns:
         --------
@@ -99,16 +110,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         triplets = []
         completion = ""
 
-        # Retrieve conversation history if session saving is enabled
-        cache_config = CacheConfig()
-        user = session_user.get()
-        user_id = getattr(user, "id", None)
-        session_save = user_id and cache_config.caching
-
-        conversation_history = ""
-        if session_save:
-            conversation_history = await get_conversation_history(session_id=session_id)
-
         for round_idx in range(max_iter + 1):
             if round_idx == 0:
                 if context is None:
@@ -120,29 +121,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 triplets += await self.get_context(followup_question)
             context_text = await self.resolve_edges_to_text(list(set(triplets)))
 
-            if response_model is str:
-                completion = await generate_completion(
-                    query=query,
-                    context=context_text,
-                    user_prompt_path=self.user_prompt_path,
-                    system_prompt_path=self.system_prompt_path,
-                    system_prompt=self.system_prompt,
-                    conversation_history=conversation_history if session_save else None,
-                )
-            else:
-                args = {"question": query, "context": context_text}
-                user_prompt = render_prompt(self.user_prompt_path, args)
-                system_prompt = (
-                    self.system_prompt
-                    if self.system_prompt
-                    else read_query_prompt(self.system_prompt_path)
-                )
-
-                completion = await LLMGateway.acreate_structured_output(
-                    text_input=user_prompt,
-                    system_prompt=system_prompt,
-                    response_model=response_model,
-                )
+            completion = await generate_structured_completion(
+                query=query,
+                context=context_text,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
+                conversation_history=conversation_history if conversation_history else None,
+                response_model=response_model,
+            )
 
             logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
 
@@ -176,16 +163,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
             )
 
-        # Save to session cache
-        if session_save:
-            context_summary = await summarize_text(context_text)
-            await save_conversation_history(
-                query=query,
-                context_summary=context_summary,
-                answer=str(completion),
-                session_id=session_id,
-            )
-
         return completion, context_text, triplets
 
     async def get_structured_completion(
@@ -217,10 +194,21 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         --------
         - Any: The generated structured completion based on the response model.
         """
+        # Check if session saving is enabled
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+
+        # Load conversation history if enabled
+        conversation_history = ""
+        if session_save:
+            conversation_history = await get_conversation_history(session_id=session_id)
+
         completion, context_text, triplets = await self._run_cot_completion(
             query=query,
             context=context,
-            session_id=session_id,
+            conversation_history=conversation_history,
             max_iter=max_iter,
             response_model=response_model,
         )
@@ -230,6 +218,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 question=query, answer=str(completion), context=context_text, triplets=triplets
             )
 
+        # Save to session cache if enabled
+        if session_save:
+            context_summary = await summarize_text(context_text)
+            await save_conversation_history(
+                query=query,
+                context_summary=context_summary,
+                answer=str(completion),
+                session_id=session_id,
+            )
+
         return completion
 
     async def get_completion(
@@ -263,7 +261,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
 
         - List[str]: A list containing the generated answer to the user's query.
         """
-        completion, context_text, triplets = await self._run_cot_completion(
+        completion = await self.get_structured_completion(
            query=query,
            context=context,
            session_id=session_id,
@@ -271,9 +269,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
             response_model=str,
         )
 
-        if self.save_interaction and context and triplets and completion:
-            await self.save_qa(
-                question=query, answer=completion, context=context_text, triplets=triplets
-            )
-
         return [completion]
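Taken together, the retriever-side changes move session handling out of _run_cot_completion (which now just receives a conversation_history string) and into get_structured_completion, while get_completion becomes a thin wrapper that requests response_model=str. A minimal usage sketch, not part of the commit: the import path and no-argument constructor are assumptions, and the Answer model is hypothetical.

import asyncio

from pydantic import BaseModel

from cognee.modules.retrieval.graph_completion_cot_retriever import (
    GraphCompletionCotRetriever,  # assumed module path, inferred from the class name above
)


class Answer(BaseModel):
    # Hypothetical response model, for illustration only.
    answer: str
    reasoning: str


async def demo() -> None:
    retriever = GraphCompletionCotRetriever()  # assumes constructor defaults are sufficient

    # Plain-text path: get_completion() now delegates to
    # get_structured_completion() with response_model=str and returns [completion].
    texts = await retriever.get_completion(query="How are Edge objects linked to nodes?")

    # Structured path: the same chain-of-thought loop, but the LLM output is
    # parsed into the given model by generate_structured_completion().
    structured = await retriever.get_structured_completion(
        query="How are Edge objects linked to nodes?",
        response_model=Answer,
    )
    print(texts, structured)


if __name__ == "__main__":
    asyncio.run(demo())  # requires a configured cognee environment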
cognee/modules/retrieval/utils/completion.py
@@ -1,17 +1,18 @@
-from typing import Optional
+from typing import Optional, Type, Any
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
 
 
-async def generate_completion(
+async def generate_structured_completion(
     query: str,
     context: str,
     user_prompt_path: str,
     system_prompt_path: str,
     system_prompt: Optional[str] = None,
     conversation_history: Optional[str] = None,
-) -> str:
-    """Generates a completion using LLM with given context and prompts."""
+    response_model: Type = str,
+) -> Any:
+    """Generates a structured completion using LLM with given context and prompts."""
     args = {"question": query, "context": context}
     user_prompt = render_prompt(user_prompt_path, args)
     system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@@ -23,6 +24,26 @@ async def generate_completion(
     return await LLMGateway.acreate_structured_output(
         text_input=user_prompt,
         system_prompt=system_prompt,
+        response_model=response_model,
+    )
+
+
+async def generate_completion(
+    query: str,
+    context: str,
+    user_prompt_path: str,
+    system_prompt_path: str,
+    system_prompt: Optional[str] = None,
+    conversation_history: Optional[str] = None,
+) -> str:
+    """Generates a completion using LLM with given context and prompts."""
+    return await generate_structured_completion(
+        query=query,
+        context=context,
+        user_prompt_path=user_prompt_path,
+        system_prompt_path=system_prompt_path,
+        system_prompt=system_prompt,
+        conversation_history=conversation_history,
         response_model=str,
     )
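On the utils side, the old generate_completion body now lives in generate_structured_completion, and generate_completion remains as a str-typed wrapper, so existing call sites keep working. A sketch of the two call shapes under the new API; the prompt template names and the Verdict model are placeholders, not taken from this commit.

import asyncio

from pydantic import BaseModel

from cognee.modules.retrieval.utils.completion import (
    generate_completion,
    generate_structured_completion,
)


class Verdict(BaseModel):
    # Placeholder model for illustration only.
    label: str
    confidence: float


async def demo() -> None:
    # str path: unchanged signature, now implemented as a call to
    # generate_structured_completion(..., response_model=str).
    text = await generate_completion(
        query="Is the report consistent?",
        context="Q3 figures match the ledger.",
        user_prompt_path="user_prompt.txt",      # placeholder template name
        system_prompt_path="system_prompt.txt",  # placeholder template name
    )

    # Structured path: same prompt rendering, but the result is parsed into
    # Verdict via LLMGateway.acreate_structured_output.
    verdict = await generate_structured_completion(
        query="Is the report consistent?",
        context="Q3 figures match the ledger.",
        user_prompt_path="user_prompt.txt",
        system_prompt_path="system_prompt.txt",
        response_model=Verdict,
    )
    print(text, verdict)


if __name__ == "__main__":
    asyncio.run(demo())  # requires a configured LLM provider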