feat: allow structured output in the cot retriever
This commit is contained in:
parent
97eb89386e
commit
1e1fac3261
2 changed files with 201 additions and 51 deletions
|
|
@ -1,4 +1,6 @@
|
|||
import json
|
||||
from typing import Optional, List, Type, Any
|
||||
from pydantic import BaseModel
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
|
@ -10,6 +12,20 @@ from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
|||
logger = get_logger()
|
||||
|
||||
|
||||
def _as_answer_text(completion: Any) -> str:
|
||||
"""Convert completion to human-readable text for validation and follow-up prompts."""
|
||||
if isinstance(completion, str):
|
||||
return completion
|
||||
if isinstance(completion, BaseModel):
|
||||
# Add notice that this is a structured response
|
||||
json_str = completion.model_dump_json(indent=2)
|
||||
return f"[Structured Response]\n{json_str}"
|
||||
try:
|
||||
return json.dumps(completion, indent=2)
|
||||
except TypeError:
|
||||
return str(completion)
|
||||
|
||||
|
||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||
"""
|
||||
Handles graph completion by generating responses based on a series of interactions with
|
||||
|
|
@ -54,6 +70,134 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
self.followup_system_prompt_path = followup_system_prompt_path
|
||||
self.followup_user_prompt_path = followup_user_prompt_path
|
||||
|
||||
async def _run_cot_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
max_iter: int = 4,
|
||||
response_model: Type = str,
|
||||
) -> tuple[Any, str, List[Edge]]:
|
||||
"""
|
||||
Run chain-of-thought completion with optional structured output.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- completion_result: The generated completion (string or structured model)
|
||||
- context_text: The resolved context text
|
||||
- triplets: The list of triplets used
|
||||
"""
|
||||
followup_question = ""
|
||||
triplets = []
|
||||
completion = ""
|
||||
|
||||
for round_idx in range(max_iter + 1):
|
||||
if round_idx == 0:
|
||||
if context is None:
|
||||
triplets = await self.get_context(query)
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
else:
|
||||
context_text = await self.resolve_edges_to_text(context)
|
||||
else:
|
||||
triplets += await self.get_context(followup_question)
|
||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
if response_model is str:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
else:
|
||||
args = {"question": query, "context": context_text}
|
||||
user_prompt = render_prompt(self.user_prompt_path, args)
|
||||
system_prompt = (
|
||||
self.system_prompt
|
||||
if self.system_prompt
|
||||
else read_query_prompt(self.system_prompt_path)
|
||||
)
|
||||
|
||||
completion = await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
|
||||
if round_idx < max_iter:
|
||||
answer_text = _as_answer_text(completion)
|
||||
valid_args = {"query": query, "answer": answer_text, "context": context_text}
|
||||
valid_user_prompt = render_prompt(
|
||||
filename=self.validation_user_prompt_path, context=valid_args
|
||||
)
|
||||
valid_system_prompt = read_query_prompt(
|
||||
prompt_file_name=self.validation_system_prompt_path
|
||||
)
|
||||
|
||||
reasoning = await LLMGateway.acreate_structured_output(
|
||||
text_input=valid_user_prompt,
|
||||
system_prompt=valid_system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
|
||||
followup_prompt = render_prompt(
|
||||
filename=self.followup_user_prompt_path, context=followup_args
|
||||
)
|
||||
followup_system = read_query_prompt(
|
||||
prompt_file_name=self.followup_system_prompt_path
|
||||
)
|
||||
|
||||
followup_question = await LLMGateway.acreate_structured_output(
|
||||
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
||||
)
|
||||
logger.info(
|
||||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||
)
|
||||
|
||||
return completion, context_text, triplets
|
||||
|
||||
async def get_structured_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[List[Edge]] = None,
|
||||
max_iter: int = 4,
|
||||
response_model: Type = str,
|
||||
) -> Any:
|
||||
"""
|
||||
Generate structured completion responses based on a user query and contextual information.
|
||||
|
||||
This method applies the same chain-of-thought logic as get_completion but returns
|
||||
structured output using the provided response model.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||
follow-up questions. (default 4)
|
||||
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- Any: The generated structured completion based on the response model.
|
||||
"""
|
||||
completion, context_text, triplets = await self._run_cot_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
max_iter=max_iter,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=str(completion), context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -82,57 +226,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
followup_question = ""
|
||||
triplets = []
|
||||
completion = ""
|
||||
|
||||
for round_idx in range(max_iter + 1):
|
||||
if round_idx == 0:
|
||||
if context is None:
|
||||
triplets = await self.get_context(query)
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
else:
|
||||
context_text = await self.resolve_edges_to_text(context)
|
||||
else:
|
||||
triplets += await self.get_context(followup_question)
|
||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
if round_idx < max_iter:
|
||||
valid_args = {"query": query, "answer": completion, "context": context_text}
|
||||
valid_user_prompt = render_prompt(
|
||||
filename=self.validation_user_prompt_path, context=valid_args
|
||||
)
|
||||
valid_system_prompt = read_query_prompt(
|
||||
prompt_file_name=self.validation_system_prompt_path
|
||||
)
|
||||
|
||||
reasoning = await LLMGateway.acreate_structured_output(
|
||||
text_input=valid_user_prompt,
|
||||
system_prompt=valid_system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
|
||||
followup_prompt = render_prompt(
|
||||
filename=self.followup_user_prompt_path, context=followup_args
|
||||
)
|
||||
followup_system = read_query_prompt(
|
||||
prompt_file_name=self.followup_system_prompt_path
|
||||
)
|
||||
|
||||
followup_question = await LLMGateway.acreate_structured_output(
|
||||
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
||||
)
|
||||
logger.info(
|
||||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||
)
|
||||
completion, context_text, triplets = await self._run_cot_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
max_iter=max_iter,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
|
|
@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
|
|||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
class TestAnswer(BaseModel):
|
||||
answer: str
|
||||
explanation: str
|
||||
|
||||
|
||||
class TestGraphCompletionCoTRetriever:
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_cot_context_simple(self):
|
||||
|
|
@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
|
|||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_structured_completion(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: Company
|
||||
|
||||
company1 = Company(name="Figma")
|
||||
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||
|
||||
entities = [company1, person1]
|
||||
await add_data_points(entities)
|
||||
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
# Test with string response model (default)
|
||||
string_answer = await retriever.get_structured_completion("Who works at Figma?")
|
||||
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
|
||||
assert string_answer.strip(), "Answer should not be empty"
|
||||
|
||||
# Test with structured response model
|
||||
structured_answer = await retriever.get_structured_completion(
|
||||
"Who works at Figma?", response_model=TestAnswer
|
||||
)
|
||||
assert isinstance(structured_answer, TestAnswer), (
|
||||
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||
)
|
||||
assert structured_answer.answer.strip(), "Answer field should not be empty"
|
||||
assert structured_answer.explanation.strip(), "Explanation field should not be empty"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue