feat: increasing coverage of temporal retriever
This commit is contained in:
parent
29f94e39b9
commit
21a84d3100
1 changed files with 339 additions and 84 deletions
|
|
@ -141,90 +141,6 @@ async def test_filter_top_k_events_error_handling():
|
|||
await tr.filter_top_k_events([{}], [])
|
||||
|
||||
|
||||
class _FakeRetriever(TemporalRetriever):
    """TemporalRetriever test double backed entirely by canned, in-memory data.

    Time extraction is driven by marker substrings in the query, the graph and
    vector lookups return fixed payloads, and the triplet fallback records its
    calls in ``self._calls`` so tests can assert which path was taken.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Chronological log of (method_name, argument) tuples.
        self._calls = []

    async def extract_time_from_query(self, query: str):
        """Return a canned ``(time_from, time_to)`` pair selected by a marker in ``query``."""
        canned_ranges = (
            ("both", ("2024-01-01", "2024-12-31")),
            ("from_only", ("2024-01-01", None)),
            ("to_only", (None, "2024-12-31")),
        )
        # Markers are checked top-down; the first one found in the query wins.
        for marker, time_range in canned_ranges:
            if marker in query:
                return time_range
        return None, None

    async def get_triplets(self, query: str):
        """Record the call and return a single fixed triplet."""
        self._calls.append(("get_triplets", query))
        return [{"s": "a", "p": "b", "o": "c"}]

    async def resolve_edges_to_text(self, triplets):
        """Record how many triplets were passed and return a fixed text."""
        self._calls.append(("resolve_edges_to_text", len(triplets)))
        return "edges->text"

    async def _fake_graph_collect_ids(self, **kwargs):
        """Return two fixed event ids; the keyword arguments are ignored."""
        return ["e1", "e2"]

    async def _fake_graph_collect_events(self, ids):
        """Return one fixed event group with three events; ``ids`` is ignored."""
        events = [
            {"id": "e1", "description": "E1"},
            {"id": "e2", "description": "E2"},
            {"id": "e3", "description": "E3"},
        ]
        return [{"events": events}]

    async def _fake_vector_embed(self, texts):
        """Check that ``texts`` is a non-empty list and return one fixed embedding."""
        assert isinstance(texts, list)
        assert texts
        return [[0.0, 1.0, 2.0]]

    async def _fake_vector_search(self, **kwargs):
        """Return two fixed scored hits, ``e2`` (0.05) before ``e1`` (0.10)."""
        best_hit = SimpleNamespace(payload={"id": "e2"}, score=0.05)
        second_hit = SimpleNamespace(payload={"id": "e1"}, score=0.10)
        return [best_hit, second_hit]

    async def get_context(self, query: str):
        """Build the context for ``query`` using only the fake collaborators above.

        When at least one time bound is extracted, the fixed events are ranked
        against the fixed vector hits; otherwise the triplet fallback is used.
        """
        time_from, time_to = await self.extract_time_from_query(query)

        if time_from or time_to:
            event_ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
            candidate_events = await self._fake_graph_collect_events(event_ids)

            # The embedding itself is discarded; the call only exercises the input check.
            _ = await self._fake_vector_embed([query])
            scored_hits = await self._fake_vector_search(
                collection_name="Event_name", query_vector=[0.0], limit=0
            )

            # filter_top_k_events / descriptions_to_string are not defined on this
            # class; presumably inherited from TemporalRetriever.
            ranked_events = await self.filter_top_k_events(candidate_events, scored_hits)
            return self.descriptions_to_string(ranked_events)

        # No time bound at all: fall back to the triplet path.
        fallback_triplets = await self.get_triplets(query)
        return await self.resolve_edges_to_text(fallback_triplets)
|
||||
|
||||
|
||||
# Test get_context fallback to triplets when no time is extracted
|
||||
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
    """Without an extracted time range, get_context must use the triplet path."""
    fake_retriever = _FakeRetriever(top_k=2)

    context = await fake_retriever.get_context("no_time")

    assert context == "edges->text"
    # The fallback is expected to fetch triplets first, then render them as text.
    recorded_calls = fake_retriever._calls
    assert recorded_calls[0][0] == "get_triplets"
    assert recorded_calls[1][0] == "resolve_edges_to_text"
|
||||
|
||||
|
||||
# Test get_context when time is extracted and vector ranking is applied
|
||||
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
    """With a time range, the context lists the top-k events ordered by vector score."""
    fake_retriever = _FakeRetriever(top_k=2)

    context = await fake_retriever.get_context("both time")

    # E2 carries the lowest score in the fake search results, so it must come first.
    assert context.startswith("E2")
    assert "#####################" in context
    # E3 has no vector hit and top_k is 2, so only E1 and E2 may appear.
    assert "E1" in context
    assert "E3" not in context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_engine():
|
||||
"""Create a mock graph engine."""
|
||||
|
|
@ -343,3 +259,342 @@ async def test_get_context_no_events_found(mock_graph_engine):
|
|||
assert context == "triplet text"
|
||||
mock_get_triplets.assert_awaited_once_with("test query")
|
||||
mock_resolve.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
    """get_context returns the event description when only ``time_from`` is extracted."""
    retriever = TemporalRetriever(top_k=5)

    # Graph side: one matching time id that resolves to a single event.
    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    # Vector side: one scored hit pointing at that same event.
    scored_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [scored_hit]

    # Only the lower bound is set; the upper bound stays None.
    time_patch = patch.object(
        retriever, "extract_time_from_query", return_value=("2024-01-01", None)
    )
    graph_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
        return_value=mock_graph_engine,
    )
    vector_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )

    with time_patch, graph_patch, vector_patch:
        context = await retriever.get_context("What happened after 2024?")

    assert isinstance(context, str)
    assert "Event 1" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
    """get_context returns the event description when only ``time_to`` is extracted."""
    retriever = TemporalRetriever(top_k=5)

    # Graph side: one matching time id that resolves to a single event.
    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    # Vector side: one scored hit pointing at that same event.
    scored_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [scored_hit]

    # Only the upper bound is set; the lower bound stays None.
    time_patch = patch.object(
        retriever, "extract_time_from_query", return_value=(None, "2024-12-31")
    )
    graph_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
        return_value=mock_graph_engine,
    )
    vector_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )

    with time_patch, graph_patch, vector_patch:
        context = await retriever.get_context("What happened before 2024?")

    assert isinstance(context, str)
    assert "Event 1" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
    """get_completion called without ``context`` still yields the generated answer."""
    retriever = TemporalRetriever()

    # Graph side: one matching time id that resolves to a single event.
    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    # Vector side: one scored hit pointing at that same event.
    scored_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [scored_hit]

    time_patch = patch.object(
        retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
    )
    graph_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
        return_value=mock_graph_engine,
    )
    vector_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    completion_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.generate_completion",
        return_value="Generated answer",
    )
    # CacheConfig() hands back a config object with caching switched off.
    cache_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.CacheConfig",
        return_value=MagicMock(caching=False),
    )

    with time_patch, graph_patch, vector_patch, completion_patch, cache_patch:
        completion = await retriever.get_completion("What happened in 2024?")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_provided_context():
    """get_completion returns the generated answer when the caller passes ``context``."""
    retriever = TemporalRetriever()

    completion_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.generate_completion",
        return_value="Generated answer",
    )
    # CacheConfig() hands back a config object with caching switched off.
    cache_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.CacheConfig",
        return_value=MagicMock(caching=False),
    )

    # No graph or vector engine is patched in this test; the context is passed in directly.
    with completion_patch, cache_patch:
        completion = await retriever.get_completion("test query", context="Provided context")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
    """With caching enabled and a session user, the conversation history is saved once."""
    retriever = TemporalRetriever()

    # Graph side: one matching time id that resolves to a single event.
    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    # Vector side: one scored hit pointing at that same event.
    scored_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [scored_hit]

    session_owner = MagicMock(id="test-user-id")

    time_patch = patch.object(
        retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
    )
    graph_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
        return_value=mock_graph_engine,
    )
    vector_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    history_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_conversation_history",
        return_value="Previous conversation",
    )
    summary_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.summarize_text",
        return_value="Context summary",
    )
    completion_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.generate_completion",
        return_value="Generated answer",
    )
    save_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.save_conversation_history",
    )
    # CacheConfig() hands back a config object with caching switched on.
    cache_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.CacheConfig",
        return_value=MagicMock(caching=True),
    )
    user_patch = patch("cognee.modules.retrieval.temporal_retriever.session_user")

    with (
        time_patch,
        graph_patch,
        vector_patch,
        history_patch,
        summary_patch,
        completion_patch,
        save_patch as mock_save,
        cache_patch,
        user_patch as mock_session_user,
    ):
        mock_session_user.get.return_value = session_owner

        completion = await retriever.get_completion(
            "What happened in 2024?", session_id="test_session"
        )

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
    mock_save.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
    """Test get_completion with caching enabled in the config but no session user.

    ``session_user.get()`` is made to return ``None`` while ``CacheConfig().caching``
    is ``True``. The test checks that a single-item list holding the patched
    ``generate_completion`` result is still returned.
    """
    retriever = TemporalRetriever()

    # Graph side: one matching time id that resolves to a single event.
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]

    # Vector side: one scored hit pointing at that same event.
    mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [mock_result]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
    ):
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user

        completion = await retriever.get_completion("What happened in 2024?")

    assert isinstance(completion, list)
    assert len(completion) == 1
    # Pin the returned value as well, as the sibling get_completion tests do;
    # checking only the length would let a wrong or empty completion pass.
    assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
    """get_completion raises when the ranked event list comes back empty.

    ``filter_top_k_events`` is patched to return ``[]`` and the call is expected
    to raise ``UnboundLocalError`` or ``NameError``.
    """
    # NOTE(review): this pins an unbound-variable failure rather than a graceful
    # result; it looks like a latent bug in get_completion - confirm it is intended.
    retriever = TemporalRetriever()

    # Hand-built vector engine: embedding works, but the search finds nothing.
    empty_vector_engine = AsyncMock()
    empty_vector_engine.embedding_engine = AsyncMock()
    empty_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    empty_vector_engine.search = AsyncMock(return_value=[])

    # Graph side: a single event whose description is empty.
    blank_event = {"id": "e1", "description": ""}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [blank_event]}]

    time_patch = patch.object(
        retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
    )
    graph_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
        return_value=mock_graph_engine,
    )
    vector_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        return_value=empty_vector_engine,
    )
    ranking_patch = patch.object(retriever, "filter_top_k_events", return_value=[])

    with time_patch, graph_patch, vector_patch, ranking_patch:
        with pytest.raises((UnboundLocalError, NameError)):
            await retriever.get_completion("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
    """get_completion returns the structured object produced for a custom ``response_model``."""
    from pydantic import BaseModel

    class AnswerModel(BaseModel):
        answer: str

    retriever = TemporalRetriever()

    # Graph side: one matching time id that resolves to a single event.
    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    # Vector side: one scored hit pointing at that same event.
    scored_hit = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [scored_hit]

    time_patch = patch.object(
        retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
    )
    graph_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
        return_value=mock_graph_engine,
    )
    vector_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    # The patched generator returns a model instance instead of a plain string.
    completion_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.generate_completion",
        return_value=AnswerModel(answer="Test answer"),
    )
    # CacheConfig() hands back a config object with caching switched off.
    cache_patch = patch(
        "cognee.modules.retrieval.temporal_retriever.CacheConfig",
        return_value=MagicMock(caching=False),
    )

    with time_patch, graph_patch, vector_patch, completion_patch, cache_patch:
        completion = await retriever.get_completion(
            "What happened in 2024?", response_model=AnswerModel
        )

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert isinstance(completion[0], AnswerModel)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue