feat: add batch queries to save interaction
This commit is contained in:
parent
abf1ef9d29
commit
5c8475a92a
6 changed files with 244 additions and 29 deletions
|
|
@ -104,10 +104,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
raise QueryValidationError(
|
||||
message="You cannot use batch queries with session saving currently."
|
||||
)
|
||||
if query_batch and self.save_interaction:
|
||||
raise QueryValidationError(
|
||||
message="Cannot use batch queries with interaction saving currently."
|
||||
)
|
||||
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
|
|
@ -264,13 +260,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
for batched_query in query_batch:
|
||||
result_completion_batch.append(finished_queries_states[batched_query].completion)
|
||||
|
||||
# TODO: Do batch queries for save interaction
|
||||
if self.save_interaction and context_text_batch and triplets_batch and completion_batch:
|
||||
await self.save_qa(
|
||||
question=query,
|
||||
answer=completion_batch[0],
|
||||
context=context_text_batch[0],
|
||||
triplets=triplets_batch[0],
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.save_qa(
|
||||
question=batched_query,
|
||||
answer=finished_queries_states[batched_query].completion,
|
||||
context=finished_queries_states[batched_query].context_text,
|
||||
triplets=finished_queries_states[batched_query].triplets,
|
||||
)
|
||||
for batched_query in query_batch
|
||||
]
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
|
|||
|
|
@ -332,10 +332,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
raise QueryValidationError(
|
||||
message="You cannot use batch queries with session saving currently."
|
||||
)
|
||||
if query_batch and self.save_interaction:
|
||||
raise QueryValidationError(
|
||||
message="Cannot use batch queries with interaction saving currently."
|
||||
)
|
||||
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
|
|
@ -355,14 +351,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
response_model=response_model,
|
||||
)
|
||||
|
||||
# TODO: Handle save interaction for batch queries
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query,
|
||||
answer=str(completion[0]),
|
||||
context=context_text[0],
|
||||
triplets=triplets[0],
|
||||
)
|
||||
if query_batch:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.save_qa(
|
||||
question=batched_query,
|
||||
answer=str(batched_completion),
|
||||
context=batched_context_text,
|
||||
triplets=batched_triplet,
|
||||
)
|
||||
for batched_query, batched_completion, batched_context_text, batched_triplet in zip(
|
||||
query_batch, completion, context_text, triplets
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
await self.save_qa(
|
||||
question=query,
|
||||
answer=str(completion[0]),
|
||||
context=context_text[0],
|
||||
triplets=triplets[0],
|
||||
)
|
||||
|
||||
# TODO: Handle session save interaction for batch queries
|
||||
# Save to session cache if enabled
|
||||
|
|
|
|||
|
|
@ -220,10 +220,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
raise QueryValidationError(
|
||||
message="You cannot use batch queries with session saving currently."
|
||||
)
|
||||
if query_batch and self.save_interaction:
|
||||
raise QueryValidationError(
|
||||
message="Cannot use batch queries with interaction saving currently."
|
||||
)
|
||||
|
||||
is_query_valid, msg = validate_queries(query, query_batch)
|
||||
if not is_query_valid:
|
||||
|
|
@ -236,7 +232,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
context_text = ""
|
||||
context_text_batch = []
|
||||
if triplets and isinstance(triplets[0], list):
|
||||
if query_batch:
|
||||
context_text_batch = await asyncio.gather(
|
||||
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
|
|
@ -284,9 +280,24 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
if query:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
await self.save_qa(
|
||||
question=batched_query,
|
||||
answer=batched_completion,
|
||||
context=batched_context_text,
|
||||
triplets=batched_triplets,
|
||||
)
|
||||
for batched_query, batched_completion, batched_context_text, batched_triplets in zip(
|
||||
query_batch, completion, context_text_batch, triplets
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if session_save:
|
||||
await save_conversation_history(
|
||||
|
|
|
|||
|
|
@ -804,3 +804,72 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
|
|||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_save_interaction(mock_edge):
    """Batch get_completion on the context-extension retriever with save_interaction on.

    Feeds a two-query batch with an explicit context through one extension
    round and checks that a list of two completions comes back and that
    add_data_points was awaited.
    """
    graph_engine = AsyncMock()
    graph_engine.is_empty = AsyncMock(return_value=False)
    graph_engine.add_edges = AsyncMock()

    retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)

    # Give the shared edge fixture two endpoint nodes.
    mock_edge.node1 = MagicMock()
    mock_edge.node2 = MagicMock()

    # Extension query for each batched query, then the final answer for each.
    completion_outputs = ["Extension query"] * 2 + ["Generated answer"] * 2

    first_id = "550e8400-e29b-41d4-a716-446655440000"
    second_id = "550e8400-e29b-41d4-a716-446655440001"
    node_ids = [UUID(raw) for raw in (first_id, first_id, second_id, second_id)]

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=completion_outputs,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=node_ids,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as add_data_points_mock,
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as cache_config_cls,
    ):
        # Report caching as disabled through the patched CacheConfig.
        cache_config_cls.return_value = MagicMock(caching=False)

        answers = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            context_extension_rounds=1,
        )

    assert isinstance(answers, list)
    assert len(answers) == 2
    add_data_points_mock.assert_awaited()
|
||||
|
|
|
|||
|
|
@ -853,3 +853,87 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
|
|||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_save_interaction(mock_edge):
    """Test get_completion batch queries with save_interaction enabled.

    Runs a two-query batch through GraphCompletionCotRetriever with an explicit
    context and max_iter=1, then checks that a list of two completions is
    returned and that add_data_points was awaited.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()

    retriever = GraphCompletionCotRetriever(save_interaction=True)

    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    with (
        # Install the graph-engine mock at the lookup point in
        # graph_completion_retriever so saving the interaction cannot fall
        # through to a real engine (assumes the save path resolves the engine
        # via this name -- confirm against save_qa).
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=[
                "validation_result",
                "validation_result",
                "followup_question",
                "followup_question",
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Report caching as disabled through the patched CacheConfig.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        # Pass context so save_interaction condition is met
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
            max_iter=1,
        )

    assert isinstance(completion, list)
    assert len(completion) == 2
    mock_add_data.assert_awaited()
|
||||
|
|
|
|||
|
|
@ -842,3 +842,44 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
|
|||
assert isinstance(completion, list)
|
||||
assert len(completion) == 2
|
||||
assert completion[0] == "Generated answer" and completion[1] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_batch_queries_with_save_interaction(mock_edge):
    """Test get_completion batch queries with save_interaction.

    Runs a two-query batch through GraphCompletionRetriever with an explicit
    context so the interaction-saving branch is actually taken, then checks
    that a list of two completions is returned and that add_data_points was
    awaited.
    """
    # Local import keeps this test self-contained with respect to UUID.
    from uuid import UUID

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()

    retriever = GraphCompletionRetriever(save_interaction=True)

    # Give the shared edge fixture two endpoint nodes for the save path.
    mock_edge.node1 = MagicMock()
    mock_edge.node2 = MagicMock()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Report caching as disabled through the patched CacheConfig.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        # Pass context explicitly: the interaction is only saved when a
        # context was supplied, so context=None would skip the code under test.
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"],
            context=[[mock_edge], [mock_edge]],
        )

    assert isinstance(completion, list)
    assert len(completion) == 2
    # The point of this test: each batched interaction must reach storage.
    mock_add_data.assert_awaited()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue