test: add retriever tests
This commit is contained in:
parent
33b0516381
commit
215ef7f3c2
6 changed files with 271 additions and 6 deletions
|
|
@ -0,0 +1,65 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.modules.engine.models import Entity, EntityType
|
||||||
|
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
|
||||||
|
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
|
||||||
|
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnswer(BaseModel):
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Add more tests, similar to other retrievers.
|
||||||
|
# TODO: For the tests, one needs to define an Entity Extractor and a Context Provider.
|
||||||
|
class TestEntityCompletionRetriever:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_entity_structured_completion(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_get_entity_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_get_entity_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
entity_type = EntityType(name="Person", description="A human individual")
|
||||||
|
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
||||||
|
|
||||||
|
entities = [entity]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())
|
||||||
|
|
||||||
|
# Test with string response model (default)
|
||||||
|
string_answer = await retriever.get_completion("Who is Albert Einstein?")
|
||||||
|
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
|
||||||
|
"Answer should not be empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with structured response model
|
||||||
|
structured_answer = await retriever.get_completion(
|
||||||
|
"Who is Albert Einstein?", response_model=TestAnswer
|
||||||
|
)
|
||||||
|
assert isinstance(structured_answer, list), (
|
||||||
|
f"Expected list, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
|
||||||
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
|
||||||
|
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"
|
||||||
|
|
@ -183,15 +183,15 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_structured_completion_extension_context(self):
|
async def test_get_graph_structured_completion_extension_context(self):
|
||||||
system_directory_path = os.path.join(
|
system_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent,
|
pathlib.Path(__file__).parent,
|
||||||
".cognee_system/test_get_structured_completion_extension_context",
|
".cognee_system/test_get_graph_structured_completion_extension_context",
|
||||||
)
|
)
|
||||||
cognee.config.system_root_directory(system_directory_path)
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
data_directory_path = os.path.join(
|
data_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent,
|
pathlib.Path(__file__).parent,
|
||||||
".data_storage/test_get_structured_completion_extension_context",
|
".data_storage/test_get_graph_structured_completion_extension_context",
|
||||||
)
|
)
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -176,13 +176,13 @@ class TestGraphCompletionCoTRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_structured_completion(self):
|
async def test_get_graph_structured_completion_cot(self):
|
||||||
system_directory_path = os.path.join(
|
system_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
pathlib.Path(__file__).parent, ".cognee_system/test_get_graph_structured_completion_cot"
|
||||||
)
|
)
|
||||||
cognee.config.system_root_directory(system_directory_path)
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
data_directory_path = os.path.join(
|
data_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
pathlib.Path(__file__).parent, ".data_storage/test_get_graph_structured_completion_cot"
|
||||||
)
|
)
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
|
@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnswer(BaseModel):
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
class TestGraphCompletionRetriever:
|
class TestGraphCompletionRetriever:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_context_simple(self):
|
async def test_graph_completion_context_simple(self):
|
||||||
|
|
@ -221,3 +227,54 @@ class TestGraphCompletionRetriever:
|
||||||
|
|
||||||
context = await retriever.get_context("Who works at Figma?")
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
assert context == [], "Context should be empty on an empty graph"
|
assert context == [], "Context should be empty on an empty graph"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_structured_completion(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_get_graph_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_get_graph_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
|
|
||||||
|
entities = [company1, person1]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionRetriever()
|
||||||
|
|
||||||
|
# Test with string response model (default)
|
||||||
|
string_answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
|
||||||
|
"Answer should not be empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with structured response model
|
||||||
|
structured_answer = await retriever.get_completion(
|
||||||
|
"Who works at Figma?", response_model=TestAnswer
|
||||||
|
)
|
||||||
|
assert isinstance(structured_answer, list), (
|
||||||
|
f"Expected list, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
|
||||||
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
|
||||||
|
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from typing import List
|
||||||
import pytest
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
import cognee
|
import cognee
|
||||||
|
from pydantic import BaseModel
|
||||||
from cognee.low_level import setup
|
from cognee.low_level import setup
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
@ -26,6 +27,11 @@ class DocumentChunkWithEntities(DataPoint):
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnswer(BaseModel):
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
class TestRAGCompletionRetriever:
|
class TestRAGCompletionRetriever:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rag_completion_context_simple(self):
|
async def test_rag_completion_context_simple(self):
|
||||||
|
|
@ -202,3 +208,76 @@ class TestRAGCompletionRetriever:
|
||||||
|
|
||||||
context = await retriever.get_context("Christina Mayer")
|
context = await retriever.get_context("Christina Mayer")
|
||||||
assert context == "", "Returned context should be empty on an empty graph"
|
assert context == "", "Returned context should be empty on an empty graph"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_rag_structured_completion(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_get_rag_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_get_rag_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document = TextDocument(
|
||||||
|
name="Steve Rodger's career",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = CompletionRetriever()
|
||||||
|
|
||||||
|
# Test with string response model (default)
|
||||||
|
string_answer = await retriever.get_completion("Where does Steve work?")
|
||||||
|
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
|
||||||
|
"Answer should not be empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with structured response model
|
||||||
|
structured_answer = await retriever.get_completion(
|
||||||
|
"Where does Steve work?", response_model=TestAnswer
|
||||||
|
)
|
||||||
|
assert isinstance(structured_answer, list), (
|
||||||
|
f"Expected list, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
|
||||||
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
|
||||||
|
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import cognee
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -141,6 +147,64 @@ async def test_filter_top_k_events_error_handling():
|
||||||
await tr.filter_top_k_events([{}], [])
|
await tr.filter_top_k_events([{}], [])
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnswer(BaseModel):
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_temporal_structured_completion():
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_get_temporal_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_get_temporal_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
works_since: int
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
|
||||||
|
|
||||||
|
entities = [company1, person1]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = TemporalRetriever()
|
||||||
|
|
||||||
|
# Test with string response model (default)
|
||||||
|
string_answer = await retriever.get_completion("When did Steve start working at Figma?")
|
||||||
|
assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
|
||||||
|
"Answer should not be empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with structured response model
|
||||||
|
structured_answer = await retriever.get_completion(
|
||||||
|
"When did Steve start working at Figma??", response_model=TestAnswer
|
||||||
|
)
|
||||||
|
assert isinstance(structured_answer, list), (
|
||||||
|
f"Expected list, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
|
||||||
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
|
||||||
|
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"
|
||||||
|
|
||||||
|
|
||||||
class _FakeRetriever(TemporalRetriever):
|
class _FakeRetriever(TemporalRetriever):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue