From 9df440c02040f0b18a6b8df420168dcc42e31feb Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:18:29 +0200 Subject: [PATCH] feat: adds time extraction + unit tests for temporal retriever --- .../modules/retrieval/temporal_retriever.py | 1 - cognee/tests/test_temporal_graph.py | 18 ++ .../retrieval/temporal_retriever_test.py | 223 ++++++++++++++++++ 3 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 cognee/tests/unit/modules/retrieval/temporal_retriever_test.py diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 61881bf7e..edd38489c 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -40,7 +40,6 @@ class TemporalRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, - save_interaction: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, diff --git a/cognee/tests/test_temporal_graph.py b/cognee/tests/test_temporal_graph.py index 998b780f7..9a9b2a93e 100644 --- a/cognee/tests/test_temporal_graph.py +++ b/cognee/tests/test_temporal_graph.py @@ -1,11 +1,14 @@ import asyncio import cognee +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever from cognee.shared.logging_utils import setup_logging, INFO +from cognee.tasks.temporal_graph.models import Timestamp from cognee.api.v1.search import SearchType from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from collections import Counter +from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int logger = get_logger() @@ -138,6 +141,21 @@ async def main(): "Expected the same amount of time_to and interval objects in the graph" ) + retriever = TemporalRetriever() + + result_before = await retriever.extract_time_from_query("What happened before 1890?") + + assert result_before[0] == None + + result_after = await retriever.extract_time_from_query("What happened after 1891?") + + assert result_after[1] == None + + result_between = await retriever.extract_time_from_query("What happened between 1890 and 1900?") + + assert result_between[1] + assert result_between[0] + if __name__ == "__main__": logger = setup_logging(log_level=INFO) diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py new file mode 100644 index 000000000..954dc398e --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -0,0 +1,223 @@ +import asyncio +from types import SimpleNamespace +import pytest + +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever + + +# Test TemporalRetriever initialization defaults and overrides +def test_init_defaults_and_overrides(): + tr = TemporalRetriever() + assert tr.top_k == 5 + assert tr.user_prompt_path == "graph_context_for_question.txt" + assert tr.system_prompt_path == "answer_simple_question.txt" + assert tr.time_extraction_prompt_path == "extract_query_time.txt" + + tr2 = TemporalRetriever( + top_k=3, + user_prompt_path="u.txt", + system_prompt_path="s.txt", + time_extraction_prompt_path="t.txt", + ) + assert tr2.top_k == 3 + assert tr2.user_prompt_path == "u.txt" + assert tr2.system_prompt_path == "s.txt" + assert tr2.time_extraction_prompt_path == "t.txt" + + +# Test descriptions_to_string with basic and empty results +def test_descriptions_to_string_basic_and_empty(): + tr = TemporalRetriever() + + results = [ + {"description": " First "}, + {"nope": "no description"}, + {"description": "Second"}, + {"description": ""}, + {"description": " Third line "}, + ] + + s = tr.descriptions_to_string(results) + assert s == "First\n#####################\nSecond\n#####################\nThird line" + + assert tr.descriptions_to_string([]) == "" + + +# Test filter_top_k_events sorts and limits correctly +@pytest.mark.asyncio +async def test_filter_top_k_events_sorts_and_limits(): + tr = TemporalRetriever(top_k=2) + + relevant_events = [ + { + "events": [ + {"id": "e1", "description": "E1"}, + {"id": "e2", "description": "E2"}, + {"id": "e3", "description": "E3 - not in vector results"}, + ] + } + ] + + scored_results = [ + SimpleNamespace(payload={"id": "e2"}, score=0.10), + SimpleNamespace(payload={"id": "e1"}, score=0.20), + ] + + top = await tr.filter_top_k_events(relevant_events, scored_results) + + assert [e["id"] for e in top] == ["e2", "e1"] + assert all("score" in e for e in top) + assert top[0]["score"] == 0.10 + assert top[1]["score"] == 0.20 + + +# Test filter_top_k_events handles unknown ids as infinite scores +@pytest.mark.asyncio +async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): + tr = TemporalRetriever(top_k=2) + + relevant_events = [ + { + "events": [ + {"id": "known1", "description": "Known 1"}, + {"id": "unknown", "description": "Unknown"}, + {"id": "known2", "description": "Known 2"}, + ] + } + ] + + scored_results = [ + SimpleNamespace(payload={"id": "known2"}, score=0.05), + SimpleNamespace(payload={"id": "known1"}, score=0.50), + ] + + top = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in top] == ["known2", "known1"] + assert all(e["score"] != float("inf") for e in top) + + +# Test descriptions_to_string with unicode and newlines +def test_descriptions_to_string_unicode_and_newlines(): + tr = TemporalRetriever() + results = [ + {"description": "Line A\nwith newline"}, + {"description": "This is a description"}, + ] + s = tr.descriptions_to_string(results) + assert "Line A\nwith newline" in s + assert "This is a description" in s + assert s.count("#####################") == 1 + + +# Test filter_top_k_events when top_k is larger than available events +@pytest.mark.asyncio +async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): + tr = TemporalRetriever(top_k=10) + relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] + scored_results = [ + SimpleNamespace(payload={"id": "a"}, score=0.1), + SimpleNamespace(payload={"id": "b"}, score=0.2), + ] + out = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in out] == ["a", "b"] + + +# Test filter_top_k_events when scored_results is empty +@pytest.mark.asyncio +async def test_filter_top_k_events_handles_empty_scored_results(): + tr = TemporalRetriever(top_k=2) + relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] + scored_results = [] + out = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in out] == ["x", "y"] + assert all(e["score"] == float("inf") for e in out) + + +# Test filter_top_k_events error handling for missing structure +@pytest.mark.asyncio +async def test_filter_top_k_events_error_handling(): + tr = TemporalRetriever(top_k=2) + with pytest.raises((KeyError, TypeError)): + await tr.filter_top_k_events([{}], []) + + +class _FakeRetriever(TemporalRetriever): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._calls = [] + + async def extract_time_from_query(self, query: str): + if "both" in query: + return "2024-01-01", "2024-12-31" + if "from_only" in query: + return "2024-01-01", None + if "to_only" in query: + return None, "2024-12-31" + return None, None + + async def get_triplets(self, query: str): + self._calls.append(("get_triplets", query)) + return [{"s": "a", "p": "b", "o": "c"}] + + async def resolve_edges_to_text(self, triplets): + self._calls.append(("resolve_edges_to_text", len(triplets))) + return "edges->text" + + async def _fake_graph_collect_ids(self, **kwargs): + return ["e1", "e2"] + + async def _fake_graph_collect_events(self, ids): + return [{"events": [ + {"id": "e1", "description": "E1"}, + {"id": "e2", "description": "E2"}, + {"id": "e3", "description": "E3"}, + ]}] + + async def _fake_vector_embed(self, texts): + assert isinstance(texts, list) and texts + return [[0.0, 1.0, 2.0]] + + async def _fake_vector_search(self, **kwargs): + return [ + SimpleNamespace(payload={"id": "e2"}, score=0.05), + SimpleNamespace(payload={"id": "e1"}, score=0.10), + ] + + async def get_context(self, query: str): + time_from, time_to = await self.extract_time_from_query(query) + + if not (time_from or time_to): + triplets = await self.get_triplets(query) + return await self.resolve_edges_to_text(triplets) + + ids = await self._fake_graph_collect_ids( + time_from=time_from, time_to=time_to + ) + relevant_events = await self._fake_graph_collect_events(ids) + + _ = await self._fake_vector_embed([query]) + vector_search_results = await self._fake_vector_search( + collection_name="Event_name", query_vector=[0.0], limit=0 + ) + top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) + return self.descriptions_to_string(top_k_events) + + +# Test get_context fallback to triplets when no time is extracted +@pytest.mark.asyncio +async def test_fake_get_context_falls_back_to_triplets_when_no_time(): + tr = _FakeRetriever(top_k=2) + ctx = await tr.get_context("no_time") + assert ctx == "edges->text" + assert tr._calls[0][0] == "get_triplets" + assert tr._calls[1][0] == "resolve_edges_to_text" + + +# Test get_context when time is extracted and vector ranking is applied +@pytest.mark.asyncio +async def test_fake_get_context_with_time_filters_and_vector_ranking(): + tr = _FakeRetriever(top_k=2) + ctx = await tr.get_context("both time") + assert ctx.startswith("E2") + assert "#####################" in ctx + assert "E1" in ctx and "E3" not in ctx