test: fix completion tests
This commit is contained in:
parent
7e3c24100b
commit
33b0516381
2 changed files with 60 additions and 1 deletions
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
|
@ -12,6 +13,11 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnswer(BaseModel):
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
class TestGraphCompletionWithContextExtensionRetriever:
|
class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_extension_context_simple(self):
|
async def test_graph_completion_extension_context_simple(self):
|
||||||
|
|
@ -175,3 +181,56 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
"Answer must contain only non-empty strings"
|
"Answer must contain only non-empty strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_structured_completion_extension_context(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".cognee_system/test_get_structured_completion_extension_context",
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".data_storage/test_get_structured_completion_extension_context",
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
|
|
||||||
|
entities = [company1, person1]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionContextExtensionRetriever()
|
||||||
|
|
||||||
|
# Test with string response model (default)
|
||||||
|
string_answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
|
||||||
|
"Answer should not be empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with structured response model
|
||||||
|
structured_answer = await retriever.get_completion(
|
||||||
|
"Who works at Figma?", response_model=TestAnswer
|
||||||
|
)
|
||||||
|
assert isinstance(structured_answer, list), (
|
||||||
|
f"Expected list, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
|
||||||
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
|
||||||
|
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"
|
||||||
|
|
|
||||||
|
|
@ -219,7 +219,7 @@ class TestGraphCompletionCoTRetriever:
|
||||||
assert isinstance(structured_answer, list), (
|
assert isinstance(structured_answer, list), (
|
||||||
f"Expected list, got {type(structured_answer).__name__}"
|
f"Expected list, got {type(structured_answer).__name__}"
|
||||||
)
|
)
|
||||||
assert all(isinstance(item, TestAnswer) for item in string_answer), (
|
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
|
||||||
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue