Merge branch 'dev' into multi-tenancy
Commit 9fb7f2c4cf
15 changed files with 279 additions and 147 deletions

@@ -1,5 +1,5 @@
import asyncio
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger

from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor

@@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
        return None

    async def get_completion(
-        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
-    ) -> List[str]:
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
        """
        Generate completion using provided context or fetch new context.

@@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
            fetched if not provided. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
            defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

@@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
                    user_prompt_path=self.user_prompt_path,
                    system_prompt_path=self.system_prompt_path,
                    conversation_history=conversation_history,
+                    response_model=response_model,
                ),
            )
        else:

@@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
+                response_model=response_model,
            )

        if session_save:

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge

@@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):

    @abstractmethod
    async def get_completion(
-        self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
-    ) -> str:
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
        """Generates a response using the query and optional context (triplets)."""
        pass

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Optional, Type, List


class BaseRetriever(ABC):

@@ -12,7 +12,11 @@ class BaseRetriever(ABC):

    @abstractmethod
    async def get_completion(
-        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
-    ) -> Any:
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
        """Generates a response using the query and optional context."""
        pass

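With the abstract signature above, every retriever now accepts an optional response_model and returns a list rather than a bare value. Below is a minimal sketch of what a concrete implementation of the updated contract has to look like; the EchoRetriever class and its trivial behaviour are illustrative only and are not part of this commit.

from typing import Any, List, Optional, Type


class EchoRetriever:
    """Illustrative stub that mirrors the updated BaseRetriever.get_completion contract."""

    async def get_context(self, query: str) -> Any:
        # A real retriever would query the vector or graph store here.
        return f"context for: {query}"

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        session_id: Optional[str] = None,
        response_model: Type = str,
    ) -> List[Any]:
        context = context or await self.get_context(query)
        if response_model is str:
            # Plain-text mode still returns a list, matching the new API shape.
            return [f"answer to {query!r} based on {context!r}"]
        # Structured mode would forward response_model to the LLM gateway.
        raise NotImplementedError("structured output needs a real LLM call")
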
@@ -1,5 +1,5 @@
import asyncio
-from typing import Any, Optional
+from typing import Any, Optional, Type, List

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine

@@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
            raise NoDataError("No data found in the system, please add data first.") from error

    async def get_completion(
-        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
-    ) -> str:
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
        """
        Generates an LLM completion using the context.

@@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
            completion; if None, it retrieves the context for the query. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
            defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

@@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
                    system_prompt_path=self.system_prompt_path,
                    system_prompt=self.system_prompt,
                    conversation_history=conversation_history,
+                    response_model=response_model,
                ),
            )
        else:

@@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
+                response_model=response_model,
            )

        if session_save:

@@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
                session_id=session_id,
            )

-        return completion
+        return [completion]

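From the caller's side this change shows up twice: the new response_model argument and the list-shaped return value (note the return [completion] above). The following is a hedged usage sketch that mirrors the new unit test added later in this diff; QAAnswer is an illustrative model, and the snippet assumes data has already been added and cognified.

import asyncio

from pydantic import BaseModel

from cognee.modules.retrieval.completion_retriever import CompletionRetriever


class QAAnswer(BaseModel):
    answer: str
    explanation: str


async def main():
    retriever = CompletionRetriever()

    # Default response_model=str: the result is a list with one plain string.
    plain = await retriever.get_completion("Where does Steve work?")

    # Passing a Pydantic model yields a list of validated model instances.
    structured = await retriever.get_completion(
        "Where does Steve work?", response_model=QAAnswer
    )
    print(plain[0], structured[0].answer)


asyncio.run(main())
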
@@ -1,5 +1,5 @@
import asyncio
-from typing import Optional, List, Type
+from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever

@@ -56,7 +56,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
        context: Optional[List[Edge]] = None,
        session_id: Optional[str] = None,
        context_extension_rounds=4,
-    ) -> List[str]:
+        response_model: Type = str,
+    ) -> List[Any]:
        """
        Extends the context for a given query by retrieving related triplets and generating new
        completions based on them.

@@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
            defaults to 'default_session'. (default None)
        - context_extension_rounds: The maximum number of rounds to extend the context with
            new triplets before halting. (default 4)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

@@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
                    system_prompt_path=self.system_prompt_path,
                    system_prompt=self.system_prompt,
                    conversation_history=conversation_history,
+                    response_model=response_model,
                ),
            )
        else:

@@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
+                response_model=response_model,
            )

        if self.save_interaction and context_text and triplets and completion:

@@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger

from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
-    generate_structured_completion,
+    generate_completion,
    summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (

@@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
    questions based on reasoning. The public methods are:

    - get_completion
-    - get_structured_completion

    Instance variables include:
    - validation_system_prompt_path

@@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
            triplets += await self.get_context(followup_question)
            context_text = await self.resolve_edges_to_text(list(set(triplets)))

-            completion = await generate_structured_completion(
+            completion = await generate_completion(
                query=query,
                context=context_text,
                user_prompt_path=self.user_prompt_path,

@@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):

        return completion, context_text, triplets

-    async def get_structured_completion(
+    async def get_completion(
        self,
        query: str,
        context: Optional[List[Edge]] = None,
        session_id: Optional[str] = None,
-        max_iter: int = 4,
+        max_iter=4,
        response_model: Type = str,
-    ) -> Any:
+    ) -> List[Any]:
        """
-        Generate structured completion responses based on a user query and contextual information.
+        Generate completion responses based on a user query and contextual information.

-        This method applies the same chain-of-thought logic as get_completion but returns
+        This method interacts with a language model client to retrieve a structured response,
+        using a series of iterations to refine the answers and generate follow-up questions
+        based on reasoning derived from previous outputs. It raises exceptions if the context
+        retrieval fails or if the model encounters issues in generating outputs. It returns
+        structured output using the provided response model.

        Parameters:
        -----------

        - query (str): The user's query to be processed and answered.
-        - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
+        - context (Optional[Any]): Optional context that may assist in answering the query.
            If not provided, it will be fetched based on the query. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
            defaults to 'default_session'. (default None)

@@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):

        Returns:
        --------
-        - Any: The generated structured completion based on the response model.
+
+        - List[str]: A list containing the generated answer to the user's query.
        """
        # Check if session saving is enabled
        cache_config = CacheConfig()

@@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                session_id=session_id,
            )

        return completion
-
-    async def get_completion(
-        self,
-        query: str,
-        context: Optional[List[Edge]] = None,
-        session_id: Optional[str] = None,
-        max_iter=4,
-    ) -> List[str]:
-        """
-        Generate completion responses based on a user query and contextual information.
-
-        This method interacts with a language model client to retrieve a structured response,
-        using a series of iterations to refine the answers and generate follow-up questions
-        based on reasoning derived from previous outputs. It raises exceptions if the context
-        retrieval fails or if the model encounters issues in generating outputs.
-
-        Parameters:
-        -----------
-
-        - query (str): The user's query to be processed and answered.
-        - context (Optional[Any]): Optional context that may assist in answering the query.
-            If not provided, it will be fetched based on the query. (default None)
-        - session_id (Optional[str]): Optional session identifier for caching. If None,
-            defaults to 'default_session'. (default None)
-        - max_iter: The maximum number of iterations to refine the answer and generate
-            follow-up questions. (default 4)
-
-        Returns:
-        --------
-
-        - List[str]: A list containing the generated answer to the user's query.
-        """
-        completion = await self.get_structured_completion(
-            query=query,
-            context=context,
-            session_id=session_id,
-            max_iter=max_iter,
-            response_model=str,
-        )
-
-        return [completion]

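Since get_structured_completion is folded into get_completion here, existing call sites migrate as sketched below; the query text and AnswerModel are illustrative, and the commented-out line shows the pre-change call for contrast.

from pydantic import BaseModel

from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


class AnswerModel(BaseModel):
    answer: str
    explanation: str


async def cot_example():
    retriever = GraphCompletionCotRetriever()

    # Before this change (method removed above):
    # structured = await retriever.get_structured_completion(
    #     "Who works at Figma?", response_model=AnswerModel
    # )

    # After: a single method covers both plain and structured output,
    # and always returns a list.
    structured = await retriever.get_completion(
        "Who works at Figma?", response_model=AnswerModel
    )
    plain = await retriever.get_completion("Who works at Figma?")  # response_model defaults to str
    return structured[0], plain[0]
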
@@ -146,7 +146,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
        query: str,
        context: Optional[List[Edge]] = None,
        session_id: Optional[str] = None,
-    ) -> List[str]:
+        response_model: Type = str,
+    ) -> List[Any]:
        """
        Generates a completion using graph connections context based on a query.

@@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
                    system_prompt_path=self.system_prompt_path,
                    system_prompt=self.system_prompt,
                    conversation_history=conversation_history,
+                    response_model=response_model,
                ),
            )
        else:

@@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
+                response_model=response_model,
            )

        if self.save_interaction and context and triplets and completion:

@@ -146,8 +146,12 @@ class TemporalRetriever(GraphCompletionRetriever):
        return self.descriptions_to_string(top_k_events)

    async def get_completion(
-        self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
-    ) -> List[str]:
+        self,
+        query: str,
+        context: Optional[str] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
        """
        Generates a response using the query and optional context.

@@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
            retrieved based on the query. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
            defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

@@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
                    user_prompt_path=self.user_prompt_path,
                    system_prompt_path=self.system_prompt_path,
                    conversation_history=conversation_history,
+                    response_model=response_model,
                ),
            )
        else:

@@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
+                response_model=response_model,
            )

        if session_save:

@@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt


-async def generate_structured_completion(
+async def generate_completion(
    query: str,
    context: str,
    user_prompt_path: str,

@@ -12,7 +12,7 @@ async def generate_structured_completion(
    conversation_history: Optional[str] = None,
    response_model: Type = str,
) -> Any:
-    """Generates a structured completion using LLM with given context and prompts."""
+    """Generates a completion using LLM with given context and prompts."""
    args = {"question": query, "context": context}
    user_prompt = render_prompt(user_prompt_path, args)
    system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)

@@ -28,26 +28,6 @@ async def generate_structured_completion(
    )


-async def generate_completion(
-    query: str,
-    context: str,
-    user_prompt_path: str,
-    system_prompt_path: str,
-    system_prompt: Optional[str] = None,
-    conversation_history: Optional[str] = None,
-) -> str:
-    """Generates a completion using LLM with given context and prompts."""
-    return await generate_structured_completion(
-        query=query,
-        context=context,
-        user_prompt_path=user_prompt_path,
-        system_prompt_path=system_prompt_path,
-        system_prompt=system_prompt,
-        conversation_history=conversation_history,
-        response_model=str,
-    )
-
-
async def summarize_text(
    text: str,
    system_prompt_path: str = "summarize_search_results.txt",

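Code that uses the helper directly now goes through the single generate_completion entry point, with response_model deciding between plain text and structured output. A hedged sketch follows; the prompt file names are placeholders rather than files verified to exist in the repo.

from pydantic import BaseModel

from cognee.modules.retrieval.utils.completion import generate_completion


class Verdict(BaseModel):
    answer: str
    explanation: str


async def helper_example():
    # With the default response_model=str this behaves like the old plain-text helper.
    text = await generate_completion(
        query="Where does Steve work?",
        context="Steve Rodger works for Figma.",
        user_prompt_path="context_for_question.txt",  # placeholder prompt name
        system_prompt_path="answer_simple_question.txt",  # placeholder prompt name
    )

    # A Pydantic model requests structured output from the LLM gateway instead.
    verdict = await generate_completion(
        query="Where does Steve work?",
        context="Steve Rodger works for Figma.",
        user_prompt_path="context_for_question.txt",  # placeholder prompt name
        system_prompt_path="answer_simple_question.txt",  # placeholder prompt name
        response_model=Verdict,
    )
    return text, verdict
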
@@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
    )

    retrieved_context = await retriever.get_context(query_text)
-    completion = await retriever.get_structured_completion(
+    completion = await retriever.get_completion(
        query=query_text,
        context=retrieved_context,
        response_model=ImprovedAnswerResponse,

@@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
    new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

    if completion:
-        enrichment.improved_answer = completion.answer
+        enrichment.improved_answer = completion[0].answer
        enrichment.new_context = new_context_text
-        enrichment.explanation = completion.explanation
+        enrichment.explanation = completion[0].explanation
        return enrichment
    else:
        logger.warning(

@@ -2,7 +2,6 @@ import os
import pytest
import pathlib
from typing import Optional, Union
-from pydantic import BaseModel

import cognee
from cognee.low_level import setup, DataPoint

@@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


-class TestAnswer(BaseModel):
-    answer: str
-    explanation: str
-
-
class TestGraphCompletionCoTRetriever:
    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_simple(self):

@@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever:
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )
-
-    @pytest.mark.asyncio
-    async def test_get_structured_completion(self):
-        system_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
-        )
-        cognee.config.system_root_directory(system_directory_path)
-        data_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
-        )
-        cognee.config.data_root_directory(data_directory_path)
-
-        await cognee.prune.prune_data()
-        await cognee.prune.prune_system(metadata=True)
-        await setup()
-
-        class Company(DataPoint):
-            name: str
-
-        class Person(DataPoint):
-            name: str
-            works_for: Company
-
-        company1 = Company(name="Figma")
-        person1 = Person(name="Steve Rodger", works_for=company1)
-
-        entities = [company1, person1]
-        await add_data_points(entities)
-
-        retriever = GraphCompletionCotRetriever()
-
-        # Test with string response model (default)
-        string_answer = await retriever.get_structured_completion("Who works at Figma?")
-        assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
-        assert string_answer.strip(), "Answer should not be empty"
-
-        # Test with structured response model
-        structured_answer = await retriever.get_structured_completion(
-            "Who works at Figma?", response_model=TestAnswer
-        )
-        assert isinstance(structured_answer, TestAnswer), (
-            f"Expected TestAnswer, got {type(structured_answer).__name__}"
-        )
-        assert structured_answer.answer.strip(), "Answer field should not be empty"
-        assert structured_answer.explanation.strip(), "Explanation field should not be empty"

@@ -3,6 +3,7 @@ from typing import List
import pytest
import pathlib
+import cognee

from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine

cognee/tests/unit/modules/retrieval/structured_output_test.py (new file, 204 lines added)
@@ -0,0 +1,204 @@
import asyncio

import pytest
import cognee
import pathlib
import os

from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
    GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever


class TestAnswer(BaseModel):
    answer: str
    explanation: str


def _assert_string_answer(answer: list[str]):
    assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}"
    assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings"
    assert all(item.strip() for item in answer), "Items should not be empty"


def _assert_structured_answer(answer: list[TestAnswer]):
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer"
    assert all(x.answer.strip() for x in answer), "Answer text should not be empty"
    assert all(x.explanation.strip() for x in answer), "Explanation should not be empty"


async def _test_get_structured_graph_completion_cot():
    retriever = GraphCompletionCotRetriever()

    # Test with string response model (default)
    string_answer = await retriever.get_completion("Who works at Figma?")
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(
        "Who works at Figma?", response_model=TestAnswer
    )
    _assert_structured_answer(structured_answer)


async def _test_get_structured_graph_completion():
    retriever = GraphCompletionRetriever()

    # Test with string response model (default)
    string_answer = await retriever.get_completion("Who works at Figma?")
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(
        "Who works at Figma?", response_model=TestAnswer
    )
    _assert_structured_answer(structured_answer)


async def _test_get_structured_graph_completion_temporal():
    retriever = TemporalRetriever()

    # Test with string response model (default)
    string_answer = await retriever.get_completion("When did Steve start working at Figma?")
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(
        "When did Steve start working at Figma??", response_model=TestAnswer
    )
    _assert_structured_answer(structured_answer)


async def _test_get_structured_graph_completion_rag():
    retriever = CompletionRetriever()

    # Test with string response model (default)
    string_answer = await retriever.get_completion("Where does Steve work?")
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(
        "Where does Steve work?", response_model=TestAnswer
    )
    _assert_structured_answer(structured_answer)


async def _test_get_structured_graph_completion_context_extension():
    retriever = GraphCompletionContextExtensionRetriever()

    # Test with string response model (default)
    string_answer = await retriever.get_completion("Who works at Figma?")
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(
        "Who works at Figma?", response_model=TestAnswer
    )
    _assert_structured_answer(structured_answer)


async def _test_get_structured_entity_completion():
    retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())

    # Test with string response model (default)
    string_answer = await retriever.get_completion("Who is Albert Einstein?")
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(
        "Who is Albert Einstein?", response_model=TestAnswer
    )
    _assert_structured_answer(structured_answer)


class TestStructuredOutputCompletion:
    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company
            works_since: int

        company1 = Company(name="Figma")
        person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)

        entities = [company1, person1]
        await add_data_points(entities)

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3]
        await add_data_points(entities)

        entity_type = EntityType(name="Person", description="A human individual")
        entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")

        entities = [entity]
        await add_data_points(entities)

        await _test_get_structured_graph_completion_cot()
        await _test_get_structured_graph_completion()
        await _test_get_structured_graph_completion_temporal()
        await _test_get_structured_graph_completion_rag()
        await _test_get_structured_graph_completion_context_extension()
        await _test_get_structured_entity_completion()

@@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever


-class TextSummariesRetriever:
+class TestSummariesRetriever:
    @pytest.mark.asyncio
    async def test_chunk_context(self):
        system_directory_path = os.path.join(

@@ -1,4 +1,3 @@
-import asyncio
from types import SimpleNamespace
import pytest
