feat: add batch queries to save interaction

This commit is contained in:
Andrej Milicevic 2026-01-20 18:59:18 +01:00
parent abf1ef9d29
commit 5c8475a92a
6 changed files with 244 additions and 29 deletions

View file

@@ -104,10 +104,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
@@ -264,13 +260,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
for batched_query in query_batch:
result_completion_batch.append(finished_queries_states[batched_query].completion)
# TODO: Do batch queries for save interaction
if self.save_interaction and context_text_batch and triplets_batch and completion_batch:
await self.save_qa(
question=query,
answer=completion_batch[0],
context=context_text_batch[0],
triplets=triplets_batch[0],
await asyncio.gather(
*[
self.save_qa(
question=batched_query,
answer=finished_queries_states[batched_query].completion,
context=finished_queries_states[batched_query].context_text,
triplets=finished_queries_states[batched_query].triplets,
)
for batched_query in query_batch
]
)
if session_save:

View file

@@ -332,10 +332,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
@@ -355,14 +351,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
response_model=response_model,
)
# TODO: Handle save interaction for batch queries
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query,
answer=str(completion[0]),
context=context_text[0],
triplets=triplets[0],
)
if query_batch:
await asyncio.gather(
*[
self.save_qa(
question=batched_query,
answer=str(batched_completion),
context=batched_context_text,
triplets=batched_triplet,
)
for batched_query, batched_completion, batched_context_text, batched_triplet in zip(
query_batch, completion, context_text, triplets
)
]
)
else:
await self.save_qa(
question=query,
answer=str(completion[0]),
context=context_text[0],
triplets=triplets[0],
)
# TODO: Handle session save interaction for batch queries
# Save to session cache if enabled

View file

@@ -220,10 +220,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
@@ -236,7 +232,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
context_text = ""
context_text_batch = []
if triplets and isinstance(triplets[0], list):
if query_batch:
context_text_batch = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
@@ -284,9 +280,24 @@ class GraphCompletionRetriever(BaseGraphRetriever):
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
if query:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
else:
await asyncio.gather(
*[
await self.save_qa(
question=batched_query,
answer=batched_completion,
context=batched_context_text,
triplets=batched_triplets,
)
for batched_query, batched_completion, batched_context_text, batched_triplets in zip(
query_batch, completion, context_text_batch, triplets
)
]
)
if session_save:
await save_conversation_history(

View file

@@ -804,3 +804,72 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_save_interaction(mock_edge):
    """Batch queries with save_interaction=True on the context-extension retriever.

    Verifies that a two-query batch yields one completion per query and that the
    interaction (question/answer plus context triplets) is persisted via
    add_data_points.
    """
    # Graph engine stub: report a non-empty graph; add_edges is stubbed because
    # interaction saving writes edges to the graph.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()
    retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)
    # The shared mock_edge fixture needs node endpoints, since saving resolves
    # UUIDs from edge nodes (see extract_uuid_from_node patch below).
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        # One context (list of edges) per batched query.
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        # Two LLM calls per query with context_extension_rounds=1:
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Extension query",
                "Generated answer",
                "Generated answer",
            ],  # Extension query, then final answer
        ),
        # Four UUIDs — presumably node1/node2 of each context edge consumed
        # while saving the interaction; TODO confirm consumption order.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable answer caching so generate_completion's side_effect list is
        # consumed exactly as laid out above.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            context_extension_rounds=1,
        )
        # One answer per batched query.
        assert isinstance(completion, list)
        assert len(completion) == 2
        # Interaction saving must have written the Q/A data points.
        mock_add_data.assert_awaited()

View file

@@ -853,3 +853,87 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_save_interaction(mock_edge):
    """Batch queries with save_interaction=True on the chain-of-thought retriever.

    Verifies that a two-query batch yields one completion per query and that the
    interaction is persisted via add_data_points.
    """
    # Graph engine stub: non-empty graph; add_edges stubbed because interaction
    # saving writes edges to the graph.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()
    retriever = GraphCompletionCotRetriever(save_interaction=True)
    # Saving the interaction resolves UUIDs from edge endpoints, so the shared
    # mock_edge fixture needs node1/node2 set.
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        # One context (list of edges) per batched query.
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        # CoT loop with max_iter=1: one validation + one follow-up per query.
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=[
                "validation_result",
                "validation_result",
                "followup_question",
                "followup_question",
            ],
        ),
        # Four UUIDs — presumably node1/node2 of each context edge consumed
        # while saving the interaction; TODO confirm consumption order.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable answer caching so the side_effect lists above are consumed
        # exactly as laid out.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        # Pass context so save_interaction condition is met
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            max_iter=1,
        )
        # One answer per batched query.
        assert isinstance(completion, list)
        assert len(completion) == 2
        # Interaction saving must have written the Q/A data points.
        mock_add_data.assert_awaited()

View file

@@ -842,3 +842,44 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 2
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_save_interaction(mock_edge):
    """Batch queries with save_interaction=True on the base graph retriever.

    Uses a duplicate query in the batch and no caller-supplied context, so the
    retriever fetches context via brute-force triplet search itself.
    """
    # Graph engine stub: report a non-empty graph so retrieval proceeds.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    retriever = GraphCompletionRetriever(save_interaction=True)
    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        # One edge list per batched query, as context=None forces a search.
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Disable answer caching so each query hits generate_completion.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config
        # NOTE(review): the batch deliberately repeats the same query —
        # presumably to cover the duplicate-query path; confirm intent.
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 1"], context=None
        )
        # One answer per batched query, duplicates included.
        assert isinstance(completion, list)
        assert len(completion) == 2