feat: adds time extraction + unit tests for temporal retriever
parent a4e59b7583
commit 9df440c020

3 changed files with 241 additions and 1 deletion
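In brief, the commit adds a time-extraction step to TemporalRetriever and covers it with unit tests. A minimal usage sketch follows, assuming extract_time_from_query returns a (time_from, time_to) pair (inferred from the index-based assertions added below) and that a cognee LLM backend is already configured; the demo() wrapper and the printed output are illustrative only, not part of the commit:

import asyncio

from cognee.modules.retrieval.temporal_retriever import TemporalRetriever


async def demo():
    retriever = TemporalRetriever()

    # For "between" queries both bounds are expected to be set; for
    # "before"/"after" queries one side is expected to come back as None
    # (this mirrors the assertions added in the example script below).
    time_from, time_to = await retriever.extract_time_from_query(
        "What happened between 1890 and 1900?"
    )
    print(time_from, time_to)


if __name__ == "__main__":
    asyncio.run(demo())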
@@ -40,7 +40,6 @@ class TemporalRetriever(GraphCompletionRetriever):

        top_k: Optional[int] = 5,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        save_interaction: bool = False,
    ):
        super().__init__(
            user_prompt_path=user_prompt_path,

@@ -1,11 +1,14 @@

import asyncio
import cognee
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever

from cognee.shared.logging_utils import setup_logging, INFO
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.api.v1.search import SearchType
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from collections import Counter
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int

logger = get_logger()

@@ -138,6 +141,21 @@ async def main():

        "Expected the same amount of time_to and interval objects in the graph"
    )

    retriever = TemporalRetriever()

    result_before = await retriever.extract_time_from_query("What happened before 1890?")

    assert result_before[0] == None

    result_after = await retriever.extract_time_from_query("What happened after 1891?")

    assert result_after[1] == None

    result_between = await retriever.extract_time_from_query("What happened between 1890 and 1900?")

    assert result_between[1]
    assert result_between[0]


if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

cognee/tests/unit/modules/retrieval/temporal_retriever_test.py (new file, 223 additions)

@@ -0,0 +1,223 @@

import asyncio
from types import SimpleNamespace
import pytest

from cognee.modules.retrieval.temporal_retriever import TemporalRetriever


# Test TemporalRetriever initialization defaults and overrides
def test_init_defaults_and_overrides():
    tr = TemporalRetriever()
    assert tr.top_k == 5
    assert tr.user_prompt_path == "graph_context_for_question.txt"
    assert tr.system_prompt_path == "answer_simple_question.txt"
    assert tr.time_extraction_prompt_path == "extract_query_time.txt"

    tr2 = TemporalRetriever(
        top_k=3,
        user_prompt_path="u.txt",
        system_prompt_path="s.txt",
        time_extraction_prompt_path="t.txt",
    )
    assert tr2.top_k == 3
    assert tr2.user_prompt_path == "u.txt"
    assert tr2.system_prompt_path == "s.txt"
    assert tr2.time_extraction_prompt_path == "t.txt"


# Test descriptions_to_string with basic and empty results
def test_descriptions_to_string_basic_and_empty():
    tr = TemporalRetriever()

    results = [
        {"description": " First "},
        {"nope": "no description"},
        {"description": "Second"},
        {"description": ""},
        {"description": " Third line "},
    ]

    s = tr.descriptions_to_string(results)
    assert s == "First\n#####################\nSecond\n#####################\nThird line"

    assert tr.descriptions_to_string([]) == ""


# Test filter_top_k_events sorts and limits correctly
@pytest.mark.asyncio
async def test_filter_top_k_events_sorts_and_limits():
    tr = TemporalRetriever(top_k=2)

    relevant_events = [
        {
            "events": [
                {"id": "e1", "description": "E1"},
                {"id": "e2", "description": "E2"},
                {"id": "e3", "description": "E3 - not in vector results"},
            ]
        }
    ]

    scored_results = [
        SimpleNamespace(payload={"id": "e2"}, score=0.10),
        SimpleNamespace(payload={"id": "e1"}, score=0.20),
    ]

    top = await tr.filter_top_k_events(relevant_events, scored_results)

    assert [e["id"] for e in top] == ["e2", "e1"]
    assert all("score" in e for e in top)
    assert top[0]["score"] == 0.10
    assert top[1]["score"] == 0.20


# Test filter_top_k_events handles unknown ids as infinite scores
@pytest.mark.asyncio
async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
    tr = TemporalRetriever(top_k=2)

    relevant_events = [
        {
            "events": [
                {"id": "known1", "description": "Known 1"},
                {"id": "unknown", "description": "Unknown"},
                {"id": "known2", "description": "Known 2"},
            ]
        }
    ]

    scored_results = [
        SimpleNamespace(payload={"id": "known2"}, score=0.05),
        SimpleNamespace(payload={"id": "known1"}, score=0.50),
    ]

    top = await tr.filter_top_k_events(relevant_events, scored_results)
    assert [e["id"] for e in top] == ["known2", "known1"]
    assert all(e["score"] != float("inf") for e in top)


# Test descriptions_to_string with unicode and newlines
def test_descriptions_to_string_unicode_and_newlines():
    tr = TemporalRetriever()
    results = [
        {"description": "Line A\nwith newline"},
        {"description": "This is a description"},
    ]
    s = tr.descriptions_to_string(results)
    assert "Line A\nwith newline" in s
    assert "This is a description" in s
    assert s.count("#####################") == 1


# Test filter_top_k_events when top_k is larger than available events
@pytest.mark.asyncio
async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
    tr = TemporalRetriever(top_k=10)
    relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
    scored_results = [
        SimpleNamespace(payload={"id": "a"}, score=0.1),
        SimpleNamespace(payload={"id": "b"}, score=0.2),
    ]
    out = await tr.filter_top_k_events(relevant_events, scored_results)
    assert [e["id"] for e in out] == ["a", "b"]


# Test filter_top_k_events when scored_results is empty
@pytest.mark.asyncio
async def test_filter_top_k_events_handles_empty_scored_results():
    tr = TemporalRetriever(top_k=2)
    relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}]
    scored_results = []
    out = await tr.filter_top_k_events(relevant_events, scored_results)
    assert [e["id"] for e in out] == ["x", "y"]
    assert all(e["score"] == float("inf") for e in out)


# Test filter_top_k_events error handling for missing structure
@pytest.mark.asyncio
async def test_filter_top_k_events_error_handling():
    tr = TemporalRetriever(top_k=2)
    with pytest.raises((KeyError, TypeError)):
        await tr.filter_top_k_events([{}], [])


class _FakeRetriever(TemporalRetriever):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._calls = []

    async def extract_time_from_query(self, query: str):
        if "both" in query:
            return "2024-01-01", "2024-12-31"
        if "from_only" in query:
            return "2024-01-01", None
        if "to_only" in query:
            return None, "2024-12-31"
        return None, None

    async def get_triplets(self, query: str):
        self._calls.append(("get_triplets", query))
        return [{"s": "a", "p": "b", "o": "c"}]

    async def resolve_edges_to_text(self, triplets):
        self._calls.append(("resolve_edges_to_text", len(triplets)))
        return "edges->text"

    async def _fake_graph_collect_ids(self, **kwargs):
        return ["e1", "e2"]

    async def _fake_graph_collect_events(self, ids):
        return [{"events": [
            {"id": "e1", "description": "E1"},
            {"id": "e2", "description": "E2"},
            {"id": "e3", "description": "E3"},
        ]}]

    async def _fake_vector_embed(self, texts):
        assert isinstance(texts, list) and texts
        return [[0.0, 1.0, 2.0]]

    async def _fake_vector_search(self, **kwargs):
        return [
            SimpleNamespace(payload={"id": "e2"}, score=0.05),
            SimpleNamespace(payload={"id": "e1"}, score=0.10),
        ]

    async def get_context(self, query: str):
        time_from, time_to = await self.extract_time_from_query(query)

        if not (time_from or time_to):
            triplets = await self.get_triplets(query)
            return await self.resolve_edges_to_text(triplets)

        ids = await self._fake_graph_collect_ids(
            time_from=time_from, time_to=time_to
        )
        relevant_events = await self._fake_graph_collect_events(ids)

        _ = await self._fake_vector_embed([query])
        vector_search_results = await self._fake_vector_search(
            collection_name="Event_name", query_vector=[0.0], limit=0
        )
        top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
        return self.descriptions_to_string(top_k_events)


# Test get_context fallback to triplets when no time is extracted
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
    tr = _FakeRetriever(top_k=2)
    ctx = await tr.get_context("no_time")
    assert ctx == "edges->text"
    assert tr._calls[0][0] == "get_triplets"
    assert tr._calls[1][0] == "resolve_edges_to_text"


# Test get_context when time is extracted and vector ranking is applied
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
    tr = _FakeRetriever(top_k=2)
    ctx = await tr.get_context("both time")
    assert ctx.startswith("E2")
    assert "#####################" in ctx
    assert "E1" in ctx and "E3" not in ctx