feat: adds time extraction + unit tests for temporal retriever

This commit is contained in:
hajdul88 2025-09-01 15:18:29 +02:00
parent a4e59b7583
commit 9df440c020
3 changed files with 241 additions and 1 deletions

View file

@ -40,7 +40,6 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,

View file

@ -1,11 +1,14 @@
import asyncio
import cognee
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.api.v1.search import SearchType
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from collections import Counter
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
logger = get_logger()
@ -138,6 +141,21 @@ async def main():
"Expected the same amount of time_to and interval objects in the graph"
)
retriever = TemporalRetriever()
result_before = await retriever.extract_time_from_query("What happened before 1890?")
assert result_before[0] == None
result_after = await retriever.extract_time_from_query("What happened after 1891?")
assert result_after[1] == None
result_between = await retriever.extract_time_from_query("What happened between 1890 and 1900?")
assert result_between[1]
assert result_between[0]
if __name__ == "__main__":
logger = setup_logging(log_level=INFO)

View file

@ -0,0 +1,223 @@
import asyncio
from types import SimpleNamespace
import pytest
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
def test_init_defaults_and_overrides():
    """Check the constructor's default settings and that keyword overrides are stored."""
    default_retriever = TemporalRetriever()
    expected_defaults = {
        "top_k": 5,
        "user_prompt_path": "graph_context_for_question.txt",
        "system_prompt_path": "answer_simple_question.txt",
        "time_extraction_prompt_path": "extract_query_time.txt",
    }
    for attribute, expected in expected_defaults.items():
        assert getattr(default_retriever, attribute) == expected

    # The same four settings, each replaced through its constructor keyword.
    overrides = {
        "top_k": 3,
        "user_prompt_path": "u.txt",
        "system_prompt_path": "s.txt",
        "time_extraction_prompt_path": "t.txt",
    }
    custom_retriever = TemporalRetriever(**overrides)
    for attribute, expected in overrides.items():
        assert getattr(custom_retriever, attribute) == expected
def test_descriptions_to_string_basic_and_empty():
    """Check descriptions_to_string on a mixed list and on an empty list.

    The expected output keeps only the non-empty descriptions, trimmed of
    surrounding whitespace, joined by a line of hash characters.
    """
    retriever = TemporalRetriever()
    entries = [
        {"description": " First "},
        {"nope": "no description"},
        {"description": "Second"},
        {"description": ""},
        {"description": " Third line "},
    ]
    expected = "First\n#####################\nSecond\n#####################\nThird line"

    rendered = retriever.descriptions_to_string(entries)

    assert rendered == expected
    assert retriever.descriptions_to_string([]) == ""
@pytest.mark.asyncio
async def test_filter_top_k_events_sorts_and_limits():
    """Check that filter_top_k_events orders events by ascending score and keeps top_k."""
    retriever = TemporalRetriever(top_k=2)
    graph_events = [
        {
            "events": [
                {"id": "e1", "description": "E1"},
                {"id": "e2", "description": "E2"},
                {"id": "e3", "description": "E3 - not in vector results"},
            ]
        }
    ]
    # Only e1 and e2 have a vector score; e2 carries the lower one.
    vector_hits = [
        SimpleNamespace(payload={"id": "e2"}, score=0.10),
        SimpleNamespace(payload={"id": "e1"}, score=0.20),
    ]

    selected = await retriever.filter_top_k_events(graph_events, vector_hits)

    assert [event["id"] for event in selected] == ["e2", "e1"]
    for event in selected:
        assert "score" in event
    best, runner_up = selected[0], selected[1]
    assert best["score"] == 0.10
    assert runner_up["score"] == 0.20
@pytest.mark.asyncio
async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
    """Check that an event without a vector score does not make it into the top_k result."""
    retriever = TemporalRetriever(top_k=2)
    graph_events = [
        {
            "events": [
                {"id": "known1", "description": "Known 1"},
                {"id": "unknown", "description": "Unknown"},
                {"id": "known2", "description": "Known 2"},
            ]
        }
    ]
    # "unknown" is deliberately absent from the scored results.
    vector_hits = [
        SimpleNamespace(payload={"id": "known2"}, score=0.05),
        SimpleNamespace(payload={"id": "known1"}, score=0.50),
    ]

    selected = await retriever.filter_top_k_events(graph_events, vector_hits)

    assert [event["id"] for event in selected] == ["known2", "known1"]
    for event in selected:
        assert event["score"] != float("inf")
def test_descriptions_to_string_unicode_and_newlines():
    """Check that embedded newlines and non-ASCII text survive descriptions_to_string.

    The test name promises Unicode coverage, so the input includes a
    description with accented Latin and CJK characters in addition to the
    multi-line and plain ASCII entries.
    """
    tr = TemporalRetriever()
    unicode_description = "Zażółć gęślą jaźń — 日本語"
    results = [
        {"description": "Line A\nwith newline"},
        {"description": "This is a description"},
        {"description": unicode_description},
    ]

    s = tr.descriptions_to_string(results)

    assert "Line A\nwith newline" in s
    assert "This is a description" in s
    assert unicode_description in s
    # Three non-empty descriptions are separated by exactly two separators.
    assert s.count("#####################") == 2
@pytest.mark.asyncio
async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
    """Check that a top_k larger than the number of events returns every event, ordered."""
    retriever = TemporalRetriever(top_k=10)
    graph_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
    vector_hits = [
        SimpleNamespace(payload={"id": "a"}, score=0.1),
        SimpleNamespace(payload={"id": "b"}, score=0.2),
    ]

    selected = await retriever.filter_top_k_events(graph_events, vector_hits)

    returned_ids = [event["id"] for event in selected]
    assert returned_ids == ["a", "b"]
@pytest.mark.asyncio
async def test_filter_top_k_events_handles_empty_scored_results():
    """Check that with no vector scores, events keep their order and score infinity."""
    retriever = TemporalRetriever(top_k=2)
    graph_events = [{"events": [{"id": "x"}, {"id": "y"}]}]
    no_hits = []

    selected = await retriever.filter_top_k_events(graph_events, no_hits)

    assert [event["id"] for event in selected] == ["x", "y"]
    for event in selected:
        assert event["score"] == float("inf")
@pytest.mark.asyncio
async def test_filter_top_k_events_error_handling():
    """Check that an entry lacking the "events" key raises KeyError or TypeError."""
    retriever = TemporalRetriever(top_k=2)
    malformed_events = [{}]
    accepted_errors = (KeyError, TypeError)

    with pytest.raises(accepted_errors):
        await retriever.filter_top_k_events(malformed_events, [])
class _FakeRetriever(TemporalRetriever):
    """TemporalRetriever test double with canned time extraction, graph and vector results.

    ``get_context`` is re-implemented on top of the ``_fake_*`` helpers; the
    inherited ``filter_top_k_events`` and ``descriptions_to_string`` are used
    for the final ranking and rendering steps.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Records (method name, argument summary) for the triplet fallback path.
        self._calls = []

    async def extract_time_from_query(self, query: str):
        """Return a canned (time_from, time_to) pair chosen by a marker word in the query."""
        # Checked in this order; the first marker found in the query wins.
        canned_windows = (
            ("both", ("2024-01-01", "2024-12-31")),
            ("from_only", ("2024-01-01", None)),
            ("to_only", (None, "2024-12-31")),
        )
        for marker, window in canned_windows:
            if marker in query:
                return window
        return None, None

    async def get_triplets(self, query: str):
        """Record the call and return a single fixed triplet."""
        self._calls.append(("get_triplets", query))
        return [{"s": "a", "p": "b", "o": "c"}]

    async def resolve_edges_to_text(self, triplets):
        """Record the call with the triplet count and return a fixed marker string."""
        self._calls.append(("resolve_edges_to_text", len(triplets)))
        return "edges->text"

    async def _fake_graph_collect_ids(self, **kwargs):
        """Return two fixed event ids regardless of the keyword arguments."""
        return ["e1", "e2"]

    async def _fake_graph_collect_events(self, ids):
        """Return one fixed group of three events regardless of ``ids``."""
        fixed_events = [
            {"id": "e1", "description": "E1"},
            {"id": "e2", "description": "E2"},
            {"id": "e3", "description": "E3"},
        ]
        return [{"events": fixed_events}]

    async def _fake_vector_embed(self, texts):
        """Require a non-empty list of texts and return one fixed vector."""
        assert isinstance(texts, list)
        assert texts
        return [[0.0, 1.0, 2.0]]

    async def _fake_vector_search(self, **kwargs):
        """Return fixed scored hits for e2 and e1; e2 has the lower score."""
        return [
            SimpleNamespace(payload={"id": "e2"}, score=0.05),
            SimpleNamespace(payload={"id": "e1"}, score=0.10),
        ]

    async def get_context(self, query: str):
        """Build context from the fakes: triplet fallback without a time, ranked events with one."""
        time_from, time_to = await self.extract_time_from_query(query)

        # No time bound extracted: use the triplet-based context instead.
        if not time_from and not time_to:
            fallback_triplets = await self.get_triplets(query)
            return await self.resolve_edges_to_text(fallback_triplets)

        event_ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
        graph_events = await self._fake_graph_collect_events(event_ids)

        # The embedding is computed but not used; the search gets a fixed vector.
        await self._fake_vector_embed([query])
        vector_hits = await self._fake_vector_search(
            collection_name="Event_name",
            query_vector=[0.0],
            limit=0,
        )

        ranked_events = await self.filter_top_k_events(graph_events, vector_hits)
        return self.descriptions_to_string(ranked_events)
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
    """Check that a query with no time marker takes the triplet fallback path."""
    retriever = _FakeRetriever(top_k=2)

    context = await retriever.get_context("no_time")

    assert context == "edges->text"
    # The fallback path fetches triplets first, then resolves them to text.
    first_call = retriever._calls[0]
    assert first_call[0] == "get_triplets"
    second_call = retriever._calls[1]
    assert second_call[0] == "resolve_edges_to_text"
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
    """Check that a query with a time range yields the two best-scored events, E2 first."""
    retriever = _FakeRetriever(top_k=2)

    context = await retriever.get_context("both time")

    separator = "#####################"
    assert context.startswith("E2")
    assert separator in context
    assert "E1" in context
    assert "E3" not in context