fix: some new fixes

This commit is contained in:
Andrej Milicevic 2026-01-20 16:26:17 +01:00
parent a5494513d7
commit abf1ef9d29
3 changed files with 25 additions and 28 deletions

View file

@ -162,63 +162,58 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
)
break
relevant_queries = [
rel_query
for rel_query in finished_queries_states.keys()
if not finished_queries_states[rel_query].finished_extending_context
]
prev_sizes = [
len(batched_query_state.triplets)
for batched_query_state in finished_queries_states.values()
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
]
completions = await asyncio.gather(
*[
generate_completion(
query=batched_query,
context=batched_query_state.context_text,
query=rel_query,
context=finished_queries_states[rel_query].context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
for batched_query, batched_query_state in finished_queries_states.items()
if not batched_query_state.finished_extending_context
for rel_query in relevant_queries
],
)
# Get new triplets, and merge them with existing ones, filtering out duplicates
new_triplets_batch = await self.get_context(query_batch=completions)
for batched_query, batched_new_triplets in zip(
finished_queries_states.keys(), new_triplets_batch
):
finished_queries_states[batched_query].triplets = list(
for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch):
finished_queries_states[rel_query].triplets = list(
dict.fromkeys(
finished_queries_states[batched_query].triplets + batched_new_triplets
finished_queries_states[rel_query].triplets + batched_new_triplets
)
)
# Resolve new triplets to text
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_query_state.triplets)
for batched_query_state in finished_queries_states.values()
if not batched_query_state.finished_extending_context
self.resolve_edges_to_text(finished_queries_states[rel_query].triplets)
for rel_query in relevant_queries
]
)
# Update context_texts in query states
for batched_query, batched_context_text in zip(
finished_queries_states.keys(), context_text_batch
):
if not finished_queries_states[batched_query].finished_extending_context:
finished_queries_states[batched_query].context_text = batched_context_text
for rel_query, batched_context_text in zip(relevant_queries, context_text_batch):
finished_queries_states[rel_query].context_text = batched_context_text
new_sizes = [
len(batched_query_state.triplets)
for batched_query_state in finished_queries_states.values()
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
]
for batched_query, prev_size, new_size in zip(
finished_queries_states.keys(), prev_sizes, new_sizes
):
for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes):
# Mark done queries accordingly
if prev_size == new_size:
finished_queries_states[batched_query].finished_extending_context = True
finished_queries_states[rel_query].finished_extending_context = True
logger.info(
f"Context extension: round {round_idx} - "

View file

@ -245,7 +245,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
]
)
for batched_query, batched_reasoning in zip(query_batch, reasoning_batch):
for batched_query, batched_reasoning in zip(
query_state_tracker.keys(), reasoning_batch
):
query_state_tracker[batched_query].reasoning = batched_reasoning
for batched_query, batched_query_state in query_state_tracker.items():
@ -274,7 +276,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
)
for batched_query, batched_followup_question in zip(
query_batch, followup_question_batch
query_state_tracker.keys(), followup_question_batch
):
query_state_tracker[batched_query].followup_question = batched_followup_question

View file

@ -806,7 +806,7 @@ async def test_get_completion_batch_queries_empty_context(mock_edge):
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
"""Test get_completion retrieves context when not provided."""
"""Test get_completion batch queries with duplicate queries."""
mock_graph_engine = AsyncMock()
mock_graph_engine.is_empty = AsyncMock(return_value=False)