test: fix completion tests

This commit is contained in:
Andrej Milicevic 2025-11-04 15:27:03 +01:00
parent 7e3c24100b
commit 33b0516381
2 changed files with 60 additions and 1 deletions

View file

@ -2,6 +2,7 @@ import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@ -12,6 +13,11 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
)
class TestAnswer(BaseModel):
answer: str
explanation: str
class TestGraphCompletionWithContextExtensionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(self):
@ -175,3 +181,56 @@ class TestGraphCompletionWithContextExtensionRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio
async def test_get_structured_completion_extension_context(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_get_structured_completion_extension_context",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_get_structured_completion_extension_context",
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1)
entities = [company1, person1]
await add_data_points(entities)
retriever = GraphCompletionContextExtensionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
"Answer should not be empty"
)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
assert isinstance(structured_answer, list), (
f"Expected list, got {type(structured_answer).__name__}"
)
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}"
)
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"

View file

@ -219,7 +219,7 @@ class TestGraphCompletionCoTRetriever:
assert isinstance(structured_answer, list), (
f"Expected list, got {type(structured_answer).__name__}"
)
assert all(isinstance(item, TestAnswer) for item in string_answer), (
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}"
)