fix: some new fixes
This commit is contained in:
parent
a5494513d7
commit
abf1ef9d29
3 changed files with 25 additions and 28 deletions
|
|
@ -162,63 +162,58 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
)
|
||||
break
|
||||
|
||||
relevant_queries = [
|
||||
rel_query
|
||||
for rel_query in finished_queries_states.keys()
|
||||
if not finished_queries_states[rel_query].finished_extending_context
|
||||
]
|
||||
|
||||
prev_sizes = [
|
||||
len(batched_query_state.triplets)
|
||||
for batched_query_state in finished_queries_states.values()
|
||||
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
|
||||
]
|
||||
|
||||
completions = await asyncio.gather(
|
||||
*[
|
||||
generate_completion(
|
||||
query=batched_query,
|
||||
context=batched_query_state.context_text,
|
||||
query=rel_query,
|
||||
context=finished_queries_states[rel_query].context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
for batched_query, batched_query_state in finished_queries_states.items()
|
||||
if not batched_query_state.finished_extending_context
|
||||
for rel_query in relevant_queries
|
||||
],
|
||||
)
|
||||
|
||||
# Get new triplets, and merge them with existing ones, filtering out duplicates
|
||||
new_triplets_batch = await self.get_context(query_batch=completions)
|
||||
for batched_query, batched_new_triplets in zip(
|
||||
finished_queries_states.keys(), new_triplets_batch
|
||||
):
|
||||
finished_queries_states[batched_query].triplets = list(
|
||||
for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch):
|
||||
finished_queries_states[rel_query].triplets = list(
|
||||
dict.fromkeys(
|
||||
finished_queries_states[batched_query].triplets + batched_new_triplets
|
||||
finished_queries_states[rel_query].triplets + batched_new_triplets
|
||||
)
|
||||
)
|
||||
|
||||
# Resolve new triplets to text
|
||||
context_text_batch = await asyncio.gather(
|
||||
*[
|
||||
self.resolve_edges_to_text(batched_query_state.triplets)
|
||||
for batched_query_state in finished_queries_states.values()
|
||||
if not batched_query_state.finished_extending_context
|
||||
self.resolve_edges_to_text(finished_queries_states[rel_query].triplets)
|
||||
for rel_query in relevant_queries
|
||||
]
|
||||
)
|
||||
|
||||
# Update context_texts in query states
|
||||
for batched_query, batched_context_text in zip(
|
||||
finished_queries_states.keys(), context_text_batch
|
||||
):
|
||||
if not finished_queries_states[batched_query].finished_extending_context:
|
||||
finished_queries_states[batched_query].context_text = batched_context_text
|
||||
for rel_query, batched_context_text in zip(relevant_queries, context_text_batch):
|
||||
finished_queries_states[rel_query].context_text = batched_context_text
|
||||
|
||||
new_sizes = [
|
||||
len(batched_query_state.triplets)
|
||||
for batched_query_state in finished_queries_states.values()
|
||||
len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries
|
||||
]
|
||||
|
||||
for batched_query, prev_size, new_size in zip(
|
||||
finished_queries_states.keys(), prev_sizes, new_sizes
|
||||
):
|
||||
for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes):
|
||||
# Mark done queries accordingly
|
||||
if prev_size == new_size:
|
||||
finished_queries_states[batched_query].finished_extending_context = True
|
||||
finished_queries_states[rel_query].finished_extending_context = True
|
||||
|
||||
logger.info(
|
||||
f"Context extension: round {round_idx} - "
|
||||
|
|
|
|||
|
|
@ -245,7 +245,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
]
|
||||
)
|
||||
|
||||
for batched_query, batched_reasoning in zip(query_batch, reasoning_batch):
|
||||
for batched_query, batched_reasoning in zip(
|
||||
query_state_tracker.keys(), reasoning_batch
|
||||
):
|
||||
query_state_tracker[batched_query].reasoning = batched_reasoning
|
||||
|
||||
for batched_query, batched_query_state in query_state_tracker.items():
|
||||
|
|
@ -274,7 +276,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
)
|
||||
|
||||
for batched_query, batched_followup_question in zip(
|
||||
query_batch, followup_question_batch
|
||||
query_state_tracker.keys(), followup_question_batch
|
||||
):
|
||||
query_state_tracker[batched_query].followup_question = batched_followup_question
|
||||
|
||||
|
|
|
|||
|
|
@ -806,7 +806,7 @@ async def test_get_completion_batch_queries_empty_context(mock_edge):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
|
||||
"""Test get_completion retrieves context when not provided."""
|
||||
"""Test get_completion batch queries with duplicate queries."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue