feat: adds integration test for temporal retriever
This commit is contained in:
parent
a14dacdc0f
commit
4791d255be
1 changed files with 306 additions and 0 deletions
306
cognee/tests/integration/retrieval/test_temporal_retriever.py
Normal file
306
cognee/tests/integration/retrieval/test_temporal_retriever.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.modules.engine.models.Event import Event
|
||||
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||
from cognee.modules.engine.models.Interval import Interval
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_test_environment_with_events():
|
||||
"""Set up a clean test environment with temporal events."""
|
||||
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events")
|
||||
data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events")
|
||||
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
# Create timestamps for events
|
||||
timestamp1 = Timestamp(
|
||||
time_at=1609459200, # 2021-01-01 00:00:00
|
||||
year=2021,
|
||||
month=1,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
timestamp_str="2021-01-01T00:00:00",
|
||||
)
|
||||
|
||||
timestamp2 = Timestamp(
|
||||
time_at=1612137600, # 2021-02-01 00:00:00
|
||||
year=2021,
|
||||
month=2,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
timestamp_str="2021-02-01T00:00:00",
|
||||
)
|
||||
|
||||
timestamp3 = Timestamp(
|
||||
time_at=1614556800, # 2021-03-01 00:00:00
|
||||
year=2021,
|
||||
month=3,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
timestamp_str="2021-03-01T00:00:00",
|
||||
)
|
||||
|
||||
timestamp4 = Timestamp(
|
||||
time_at=1625097600, # 2021-07-01 00:00:00
|
||||
year=2021,
|
||||
month=7,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
timestamp_str="2021-07-01T00:00:00",
|
||||
)
|
||||
|
||||
timestamp5 = Timestamp(
|
||||
time_at=1633046400, # 2021-10-01 00:00:00
|
||||
year=2021,
|
||||
month=10,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
timestamp_str="2021-10-01T00:00:00",
|
||||
)
|
||||
|
||||
# Create interval for event spanning multiple timestamps
|
||||
interval1 = Interval(time_from=timestamp2, time_to=timestamp3)
|
||||
|
||||
# Create events with timestamps
|
||||
event1 = Event(
|
||||
name="Project Alpha Launch",
|
||||
description="Launched Project Alpha at the beginning of 2021",
|
||||
at=timestamp1,
|
||||
location="San Francisco",
|
||||
)
|
||||
|
||||
event2 = Event(
|
||||
name="Team Meeting",
|
||||
description="Monthly team meeting discussing Q1 goals",
|
||||
during=interval1,
|
||||
location="New York",
|
||||
)
|
||||
|
||||
event3 = Event(
|
||||
name="Product Release",
|
||||
description="Released new product features in July",
|
||||
at=timestamp4,
|
||||
location="Remote",
|
||||
)
|
||||
|
||||
event4 = Event(
|
||||
name="Company Retreat",
|
||||
description="Annual company retreat in October",
|
||||
at=timestamp5,
|
||||
location="Lake Tahoe",
|
||||
)
|
||||
|
||||
entities = [event1, event2, event3, event4]
|
||||
|
||||
await add_data_points(entities)
|
||||
|
||||
yield
|
||||
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_test_environment_with_graph_data():
|
||||
"""Set up a clean test environment with graph data (for fallback to triplets)."""
|
||||
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph")
|
||||
data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph")
|
||||
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
description: str
|
||||
works_for: Company
|
||||
|
||||
company1 = Company(name="Figma", description="Figma is a company")
|
||||
person1 = Person(
|
||||
name="Steve Rodger",
|
||||
description="This is description about Steve Rodger",
|
||||
works_for=company1,
|
||||
)
|
||||
|
||||
entities = [company1, person1]
|
||||
|
||||
await add_data_points(entities)
|
||||
|
||||
yield
|
||||
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_test_environment_empty():
|
||||
"""Set up a clean test environment without data."""
|
||||
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_empty")
|
||||
data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty")
|
||||
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
yield
|
||||
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events):
|
||||
"""Integration test: verify TemporalRetriever can retrieve events within time range."""
|
||||
retriever = TemporalRetriever(top_k=5)
|
||||
|
||||
context = await retriever.get_context("What happened in January 2021?")
|
||||
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
assert len(context) > 0, "Context should not be empty"
|
||||
assert "Project Alpha" in context or "Launch" in context, (
|
||||
"Should retrieve Project Alpha Launch event from January 2021"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events):
|
||||
"""Integration test: verify TemporalRetriever can retrieve events at specific time."""
|
||||
retriever = TemporalRetriever(top_k=5)
|
||||
|
||||
context = await retriever.get_context("What happened in July 2021?")
|
||||
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
assert len(context) > 0, "Context should not be empty"
|
||||
assert "Product Release" in context or "July" in context, (
|
||||
"Should retrieve Product Release event from July 2021"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_context_fallback_to_triplets(
|
||||
setup_test_environment_with_graph_data,
|
||||
):
|
||||
"""Integration test: verify TemporalRetriever falls back to triplets when no time extracted."""
|
||||
retriever = TemporalRetriever(top_k=5)
|
||||
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
assert len(context) > 0, "Context should not be empty"
|
||||
assert "Steve" in context or "Figma" in context, (
|
||||
"Should retrieve graph data via triplet search fallback"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty):
|
||||
"""Integration test: verify TemporalRetriever handles empty graph correctly."""
|
||||
retriever = TemporalRetriever()
|
||||
|
||||
context = await retriever.get_context("What happened?")
|
||||
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
assert len(context) >= 0, "Context should be a string (possibly empty)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_get_completion(setup_test_environment_with_events):
|
||||
"""Integration test: verify TemporalRetriever can generate completions."""
|
||||
retriever = TemporalRetriever()
|
||||
|
||||
completion = await retriever.get_completion("What happened in January 2021?")
|
||||
|
||||
assert isinstance(completion, list), "Completion should be a list"
|
||||
assert len(completion) > 0, "Completion should not be empty"
|
||||
assert all(isinstance(item, str) and item.strip() for item in completion), (
|
||||
"Completion items should be non-empty strings"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data):
|
||||
"""Integration test: verify TemporalRetriever get_completion works with triplet fallback."""
|
||||
retriever = TemporalRetriever()
|
||||
|
||||
completion = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(completion, list), "Completion should be a list"
|
||||
assert len(completion) > 0, "Completion should not be empty"
|
||||
assert all(isinstance(item, str) and item.strip() for item in completion), (
|
||||
"Completion items should be non-empty strings"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events):
|
||||
"""Integration test: verify TemporalRetriever respects top_k parameter."""
|
||||
retriever = TemporalRetriever(top_k=2)
|
||||
|
||||
context = await retriever.get_context("What happened in 2021?")
|
||||
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
separator_count = context.count("#####################")
|
||||
assert separator_count <= 1, "Should respect top_k limit of 2 events"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_retriever_multiple_events(setup_test_environment_with_events):
|
||||
"""Integration test: verify TemporalRetriever can retrieve multiple events."""
|
||||
retriever = TemporalRetriever(top_k=10)
|
||||
|
||||
context = await retriever.get_context("What events occurred in 2021?")
|
||||
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
assert len(context) > 0, "Context should not be empty"
|
||||
|
||||
assert (
|
||||
"Project Alpha" in context
|
||||
or "Team Meeting" in context
|
||||
or "Product Release" in context
|
||||
or "Company Retreat" in context
|
||||
), "Should retrieve at least one event from 2021"
|
||||
Loading…
Add table
Reference in a new issue