refactor: add structured output to completion retrievers
parent 8d7c4d5384
commit 7e3c24100b
9 changed files with 67 additions and 90 deletions
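Summary: the separate get_structured_completion entry points are folded into get_completion, which now accepts a response_model parameter (default str) and threads it through to the shared completion utilities, so plain-text and structured output share one code path, and results always come back list-wrapped. A minimal usage sketch of the unified API (AnswerModel is a hypothetical Pydantic model for illustration, not part of this commit; a configured cognee instance is assumed):

    import asyncio

    from pydantic import BaseModel

    from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


    class AnswerModel(BaseModel):
        answer: str


    async def main():
        retriever = GraphCompletionRetriever()

        # Default response_model=str: a list of plain-text answers.
        texts = await retriever.get_completion("Who works at Figma?")
        print(texts[0])

        # Pass a Pydantic model type to get structured output instead.
        structured = await retriever.get_completion(
            "Who works at Figma?", response_model=AnswerModel
        )
        print(structured[0].answer)


    asyncio.run(main())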
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Type
 from cognee.shared.logging_utils import get_logger

 from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor

@@ -85,7 +85,11 @@ class EntityCompletionRetriever(BaseRetriever):
         return None

     async def get_completion(
-        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
     ) -> List[str]:
         """
         Generate completion using provided context or fetch new context.

@@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
             fetched if not provided. (default None)
         - session_id (Optional[str]): Optional session identifier for caching. If None,
             defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

         Returns:
         --------

@@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
                     user_prompt_path=self.user_prompt_path,
                     system_prompt_path=self.system_prompt_path,
                     conversation_history=conversation_history,
+                    response_model=response_model,
                 ),
             )
         else:

@@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
                 context=context,
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
+                response_model=response_model,
             )

         if session_save:
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Optional
+from typing import Any, Optional, Type

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine

@@ -75,7 +75,11 @@ class CompletionRetriever(BaseRetriever):
            raise NoDataError("No data found in the system, please add data first.") from error

     async def get_completion(
-        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
     ) -> str:
         """
         Generates an LLM completion using the context.

@@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
             completion; if None, it retrieves the context for the query. (default None)
         - session_id (Optional[str]): Optional session identifier for caching. If None,
             defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

         Returns:
         --------

@@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
                     system_prompt_path=self.system_prompt_path,
                     system_prompt=self.system_prompt,
                     conversation_history=conversation_history,
+                    response_model=response_model,
                 ),
             )
         else:

@@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
                 system_prompt=self.system_prompt,
+                response_model=response_model,
             )

         if session_save:

@@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
                 session_id=session_id,
             )

-        return completion
+        return [completion]
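Note the behavioral change in the final hunk above: get_completion now returns [completion] instead of the bare string (the declared -> str annotation is left unchanged by the commit), so callers should index the first element. A hedged sketch; the import path is inferred from the class name and a configured cognee instance is assumed:

    from cognee.modules.retrieval.completion_retriever import CompletionRetriever


    async def ask(question: str) -> str:
        results = await CompletionRetriever().get_completion(question)
        # After this commit the single completion arrives wrapped in a list.
        return results[0]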
@@ -56,6 +56,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         context: Optional[List[Edge]] = None,
         session_id: Optional[str] = None,
         context_extension_rounds=4,
+        response_model: Type = str,
     ) -> List[str]:
         """
         Extends the context for a given query by retrieving related triplets and generating new

@@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
             defaults to 'default_session'. (default None)
         - context_extension_rounds: The maximum number of rounds to extend the context with
             new triplets before halting. (default 4)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

         Returns:
         --------

@@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
                     system_prompt_path=self.system_prompt_path,
                     system_prompt=self.system_prompt,
                     conversation_history=conversation_history,
+                    response_model=response_model,
                 ),
             )
         else:

@@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
                 system_prompt=self.system_prompt,
+                response_model=response_model,
             )

         if self.save_interaction and context_text and triplets and completion:
@@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger

 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.retrieval.utils.completion import (
-    generate_structured_completion,
+    generate_completion,
     summarize_text,
 )
 from cognee.modules.retrieval.utils.session_cache import (

@@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
     questions based on reasoning. The public methods are:

     - get_completion
-    - get_structured_completion

     Instance variables include:
     - validation_system_prompt_path

@@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
             triplets += await self.get_context(followup_question)
             context_text = await self.resolve_edges_to_text(list(set(triplets)))

-        completion = await generate_structured_completion(
+        completion = await generate_completion(
             query=query,
             context=context_text,
             user_prompt_path=self.user_prompt_path,

@@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):

         return completion, context_text, triplets

-    async def get_structured_completion(
+    async def get_completion(
         self,
         query: str,
         context: Optional[List[Edge]] = None,
         session_id: Optional[str] = None,
-        max_iter: int = 4,
+        max_iter=4,
         response_model: Type = str,
-    ) -> Any:
+    ) -> List[str]:
         """
-        Generate structured completion responses based on a user query and contextual information.
+        Generate completion responses based on a user query and contextual information.

-        This method applies the same chain-of-thought logic as get_completion but returns
+        This method interacts with a language model client to retrieve a structured response,
+        using a series of iterations to refine the answers and generate follow-up questions
+        based on reasoning derived from previous outputs. It raises exceptions if the context
+        retrieval fails or if the model encounters issues in generating outputs. It returns
         structured output using the provided response model.

         Parameters:
         -----------

         - query (str): The user's query to be processed and answered.
-        - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
+        - context (Optional[Any]): Optional context that may assist in answering the query.
             If not provided, it will be fetched based on the query. (default None)
         - session_id (Optional[str]): Optional session identifier for caching. If None,
             defaults to 'default_session'. (default None)

@@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):

         Returns:
         --------
-        - Any: The generated structured completion based on the response model.
+        - List[str]: A list containing the generated answer to the user's query.
         """
         # Check if session saving is enabled
         cache_config = CacheConfig()

@@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 session_id=session_id,
             )

-        return completion
-
-    async def get_completion(
-        self,
-        query: str,
-        context: Optional[List[Edge]] = None,
-        session_id: Optional[str] = None,
-        max_iter=4,
-    ) -> List[str]:
-        """
-        Generate completion responses based on a user query and contextual information.
-
-        This method interacts with a language model client to retrieve a structured response,
-        using a series of iterations to refine the answers and generate follow-up questions
-        based on reasoning derived from previous outputs. It raises exceptions if the context
-        retrieval fails or if the model encounters issues in generating outputs.
-
-        Parameters:
-        -----------
-
-        - query (str): The user's query to be processed and answered.
-        - context (Optional[Any]): Optional context that may assist in answering the query.
-            If not provided, it will be fetched based on the query. (default None)
-        - session_id (Optional[str]): Optional session identifier for caching. If None,
-            defaults to 'default_session'. (default None)
-        - max_iter: The maximum number of iterations to refine the answer and generate
-            follow-up questions. (default 4)
-
-        Returns:
-        --------
-
-        - List[str]: A list containing the generated answer to the user's query.
-        """
-        completion = await self.get_structured_completion(
-            query=query,
-            context=context,
-            session_id=session_id,
-            max_iter=max_iter,
-            response_model=str,
-        )
-
-        return [completion]
+        return [completion]
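Net effect in this retriever: get_structured_completion is renamed to get_completion (absorbing response_model), and the old thin get_completion wrapper is deleted. Call sites migrate roughly as follows (a sketch; TestAnswer stands in for any Pydantic model):

    async def migrated_calls(retriever, TestAnswer):
        # Before: retriever.get_structured_completion(query, response_model=TestAnswer)
        # returned a bare model instance, while get_completion(query) returned List[str].
        # After: one method, results always list-wrapped.
        texts = await retriever.get_completion("Who works at Figma?")
        objects = await retriever.get_completion(
            "Who works at Figma?", response_model=TestAnswer
        )
        return texts[0], objects[0].answer  # objects[0] is a TestAnswer instance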
@@ -146,6 +146,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
         query: str,
         context: Optional[List[Edge]] = None,
         session_id: Optional[str] = None,
+        response_model: Type = str,
     ) -> List[str]:
         """
         Generates a completion using graph connections context based on a query.

@@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
                     system_prompt_path=self.system_prompt_path,
                     system_prompt=self.system_prompt,
                     conversation_history=conversation_history,
+                    response_model=response_model,
                 ),
             )
         else:

@@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
                 system_prompt=self.system_prompt,
+                response_model=response_model,
             )

         if self.save_interaction and context and triplets and completion:
@@ -146,7 +146,11 @@ class TemporalRetriever(GraphCompletionRetriever):
         return self.descriptions_to_string(top_k_events)

     async def get_completion(
-        self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
+        self,
+        query: str,
+        context: Optional[str] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
     ) -> List[str]:
         """
         Generates a response using the query and optional context.

@@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
             retrieved based on the query. (default None)
         - session_id (Optional[str]): Optional session identifier for caching. If None,
             defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

         Returns:
         --------

@@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
                     user_prompt_path=self.user_prompt_path,
                     system_prompt_path=self.system_prompt_path,
                     conversation_history=conversation_history,
+                    response_model=response_model,
                 ),
             )
         else:

@@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
             context=context,
             user_prompt_path=self.user_prompt_path,
             system_prompt_path=self.system_prompt_path,
+            response_model=response_model,
         )

         if session_save:
@@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt


-async def generate_structured_completion(
+async def generate_completion(
     query: str,
     context: str,
     user_prompt_path: str,

@@ -11,8 +11,8 @@ async def generate_structured_completion(
     system_prompt: Optional[str] = None,
     conversation_history: Optional[str] = None,
     response_model: Type = str,
-) -> Any:
-    """Generates a structured completion using LLM with given context and prompts."""
+) -> str:
+    """Generates a completion using LLM with given context and prompts."""
     args = {"question": query, "context": context}
     user_prompt = render_prompt(user_prompt_path, args)
     system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)

@@ -28,26 +28,6 @@ async def generate_structured_completion(
     )


-async def generate_completion(
-    query: str,
-    context: str,
-    user_prompt_path: str,
-    system_prompt_path: str,
-    system_prompt: Optional[str] = None,
-    conversation_history: Optional[str] = None,
-) -> str:
-    """Generates a completion using LLM with given context and prompts."""
-    return await generate_structured_completion(
-        query=query,
-        context=context,
-        user_prompt_path=user_prompt_path,
-        system_prompt_path=system_prompt_path,
-        system_prompt=system_prompt,
-        conversation_history=conversation_history,
-        response_model=str,
-    )
-
-
 async def summarize_text(
     text: str,
     system_prompt_path: str = "summarize_search_results.txt",
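With the old wrapper removed, generate_completion is the single low-level helper and takes response_model directly (note the return annotation now reads -> str even though a non-str response_model yields a model instance). A hedged sketch of a direct call; the prompt file names are illustrative placeholders, not paths from this commit:

    from pydantic import BaseModel

    from cognee.modules.retrieval.utils.completion import generate_completion


    class Verdict(BaseModel):
        answer: str


    async def structured_answer(context_text: str) -> Verdict:
        return await generate_completion(
            query="Who works at Figma?",
            context=context_text,
            user_prompt_path="context_for_question.txt",  # illustrative placeholder
            system_prompt_path="answer_simple_question.txt",  # illustrative placeholder
            response_model=Verdict,
        )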
@@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
     )

     retrieved_context = await retriever.get_context(query_text)
-    completion = await retriever.get_structured_completion(
+    completion = await retriever.get_completion(
         query=query_text,
         context=retrieved_context,
         response_model=ImprovedAnswerResponse,

@@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
     new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

     if completion:
-        enrichment.improved_answer = completion.answer
+        enrichment.improved_answer = completion[0].answer
         enrichment.new_context = new_context_text
-        enrichment.explanation = completion.explanation
+        enrichment.explanation = completion[0].explanation
         return enrichment
     else:
         logger.warning(
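Since get_completion now returns a list, this call site switches to completion[0]. ImprovedAnswerResponse is defined elsewhere in the codebase; judging by the attribute access here, it is shaped roughly like this sketch (not the actual definition):

    from pydantic import BaseModel


    class ImprovedAnswerResponse(BaseModel):
        answer: str
        explanation: str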
@@ -206,16 +206,22 @@ class TestGraphCompletionCoTRetriever:
         retriever = GraphCompletionCotRetriever()

         # Test with string response model (default)
-        string_answer = await retriever.get_structured_completion("Who works at Figma?")
-        assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
-        assert string_answer.strip(), "Answer should not be empty"
+        string_answer = await retriever.get_completion("Who works at Figma?")
+        assert isinstance(string_answer, list), f"Expected list, got {type(string_answer).__name__}"
+        assert all(isinstance(item, str) and item.strip() for item in string_answer), (
+            "Answer should not be empty"
+        )

         # Test with structured response model
-        structured_answer = await retriever.get_structured_completion(
+        structured_answer = await retriever.get_completion(
             "Who works at Figma?", response_model=TestAnswer
         )
-        assert isinstance(structured_answer, TestAnswer), (
+        assert isinstance(structured_answer, list), (
+            f"Expected list, got {type(structured_answer).__name__}"
+        )
+        assert all(isinstance(item, TestAnswer) for item in structured_answer), (
             f"Expected TestAnswer, got {type(structured_answer).__name__}"
         )
-        assert structured_answer.answer.strip(), "Answer field should not be empty"
-        assert structured_answer.explanation.strip(), "Explanation field should not be empty"
+        assert structured_answer[0].answer.strip(), "Answer field should not be empty"
+        assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"