refactor: add structured output to completion retrievers
This commit is contained in:
parent
8d7c4d5384
commit
7e3c24100b
9 changed files with 67 additions and 90 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, Optional, List, Type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
||||
|
|
@ -85,7 +85,11 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
return None
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion using provided context or fetch new context.
|
||||
|
|
@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
fetched if not provided. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
|
@ -75,7 +75,11 @@ class CompletionRetriever(BaseRetriever):
|
|||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> str:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
|
@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
completion; if None, it retrieves the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
context_extension_rounds=4,
|
||||
response_model: Type = str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
|
|
@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
defaults to 'default_session'. (default None)
|
||||
- context_extension_rounds: The maximum number of rounds to extend the context with
|
||||
new triplets before halting. (default 4)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import (
|
||||
generate_structured_completion,
|
||||
generate_completion,
|
||||
summarize_text,
|
||||
)
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
|
|
@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
questions based on reasoning. The public methods are:
|
||||
|
||||
- get_completion
|
||||
- get_structured_completion
|
||||
|
||||
Instance variables include:
|
||||
- validation_system_prompt_path
|
||||
|
|
@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
triplets += await self.get_context(followup_question)
|
||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
completion = await generate_structured_completion(
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
|
|
@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
return completion, context_text, triplets
|
||||
|
||||
async def get_structured_completion(
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter: int = 4,
|
||||
max_iter=4,
|
||||
response_model: Type = str,
|
||||
) -> Any:
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate structured completion responses based on a user query and contextual information.
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
This method applies the same chain-of-thought logic as get_completion but returns
|
||||
This method interacts with a language model client to retrieve a structured response,
|
||||
using a series of iterations to refine the answers and generate follow-up questions
|
||||
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||
retrieval fails or if the model encounters issues in generating outputs. It returns
|
||||
structured output using the provided response model.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
|
@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
Returns:
|
||||
--------
|
||||
- Any: The generated structured completion based on the response model.
|
||||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
# Check if session saving is enabled
|
||||
cache_config = CacheConfig()
|
||||
|
|
@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
This method interacts with a language model client to retrieve a structured response,
|
||||
using a series of iterations to refine the answers and generate follow-up questions
|
||||
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||
retrieval fails or if the model encounters issues in generating outputs.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||
follow-up questions. (default 4)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
completion = await self.get_structured_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
session_id=session_id,
|
||||
max_iter=max_iter,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -146,6 +146,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
|
@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -146,7 +146,11 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generates a response using the query and optional context.
|
||||
|
|
@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
retrieved based on the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
conversation_history=conversation_history,
|
||||
response_model=response_model,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
|
||||
|
||||
async def generate_structured_completion(
|
||||
async def generate_completion(
|
||||
query: str,
|
||||
context: str,
|
||||
user_prompt_path: str,
|
||||
|
|
@ -11,8 +11,8 @@ async def generate_structured_completion(
|
|||
system_prompt: Optional[str] = None,
|
||||
conversation_history: Optional[str] = None,
|
||||
response_model: Type = str,
|
||||
) -> Any:
|
||||
"""Generates a structured completion using LLM with given context and prompts."""
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = render_prompt(user_prompt_path, args)
|
||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||
|
|
@ -28,26 +28,6 @@ async def generate_structured_completion(
|
|||
)
|
||||
|
||||
|
||||
async def generate_completion(
|
||||
query: str,
|
||||
context: str,
|
||||
user_prompt_path: str,
|
||||
system_prompt_path: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
conversation_history: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
return await generate_structured_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
system_prompt_path: str = "summarize_search_results.txt",
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
|
|||
)
|
||||
|
||||
retrieved_context = await retriever.get_context(query_text)
|
||||
completion = await retriever.get_structured_completion(
|
||||
completion = await retriever.get_completion(
|
||||
query=query_text,
|
||||
context=retrieved_context,
|
||||
response_model=ImprovedAnswerResponse,
|
||||
|
|
@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
|
|||
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
|
||||
|
||||
if completion:
|
||||
enrichment.improved_answer = completion.answer
|
||||
enrichment.improved_answer = completion[0].answer
|
||||
enrichment.new_context = new_context_text
|
||||
enrichment.explanation = completion.explanation
|
||||
enrichment.explanation = completion[0].explanation
|
||||
return enrichment
|
||||
else:
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -206,16 +206,22 @@ class TestGraphCompletionCoTRetriever:
|
|||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
# Test with string response model (default)
|
||||
string_answer = await retriever.get_structured_completion("Who works at Figma?")
|
||||
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
|
||||
assert string_answer.strip(), "Answer should not be empty"
|
||||
string_answer = await retriever.get_completion("Who works at Figma?")
|
||||
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
|
||||
"Answer should not be empty"
|
||||
)
|
||||
|
||||
# Test with structured response model
|
||||
structured_answer = await retriever.get_structured_completion(
|
||||
structured_answer = await retriever.get_completion(
|
||||
"Who works at Figma?", response_model=TestAnswer
|
||||
)
|
||||
assert isinstance(structured_answer, TestAnswer), (
|
||||
assert isinstance(structured_answer, list), (
|
||||
f"Expected list, got {type(structured_answer).__name__}"
|
||||
)
|
||||
assert all(isinstance(item, TestAnswer) for item in string_answer), (
|
||||
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||
)
|
||||
assert structured_answer.answer.strip(), "Answer field should not be empty"
|
||||
assert structured_answer.explanation.strip(), "Explanation field should not be empty"
|
||||
|
||||
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
|
||||
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue