fix: fix failing tests
This commit is contained in:
parent
40667e63c9
commit
2655df9b21
3 changed files with 22 additions and 16 deletions
|
|
@ -95,16 +95,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
if query:
|
||||
# This is done mostly to avoid duplicating a lot of code unnecessarily
|
||||
query_batch = [query]
|
||||
query = None
|
||||
if triplets:
|
||||
triplets = [triplets]
|
||||
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query, query_batch)
|
||||
triplets = await self.get_context(query_batch=query_batch)
|
||||
|
||||
context_text = ""
|
||||
context_texts = ""
|
||||
if isinstance(triplets[0], list):
|
||||
if triplets and isinstance(triplets[0], list):
|
||||
context_texts = await asyncio.gather(
|
||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
|
|
@ -123,6 +122,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||
)
|
||||
|
||||
# Filter out the queries that cannot be extended further, and their associated contexts
|
||||
query_batch = [query for query in query_batch if query]
|
||||
triplets = [triplet_element for triplet_element in triplets if triplet_element]
|
||||
context_texts = [context_text for context_text in context_texts if context_text]
|
||||
|
|
@ -181,8 +181,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
# Reset variables for the final generations. They contain the final state
|
||||
# of triplets and contexts for each query, after all extension iterations.
|
||||
query_batch = original_query_batch
|
||||
context_texts = saved_context_texts
|
||||
triplets = saved_triplets
|
||||
context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
|
||||
triplets = saved_triplets if len(saved_triplets) > 0 else triplets
|
||||
|
||||
if len(query_batch) == 1:
|
||||
triplets = [] if not triplets else triplets[0]
|
||||
context_text = "" if not context_texts else context_texts[0]
|
||||
|
||||
# Check if we need to generate context summary for caching
|
||||
cache_config = CacheConfig()
|
||||
|
|
@ -221,8 +225,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
)
|
||||
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
if isinstance(completion, list):
|
||||
completion = completion[0]
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
question=query, answer=completion[0], context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
if session_save:
|
||||
|
|
@ -233,4 +239,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
return completion if isinstance(completion, list) else [completion]
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
context_text = ""
|
||||
context_texts = ""
|
||||
if isinstance(triplets[0], list):
|
||||
if triplets and isinstance(triplets[0], list):
|
||||
context_texts = await asyncio.gather(
|
||||
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ async def test_get_completion_without_context(mock_edge):
|
|||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||
return_value=[mock_edge],
|
||||
return_value=[[mock_edge]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
|
|
@ -157,7 +157,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
|
|||
retriever,
|
||||
"get_context",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[mock_edge], [mock_edge2]],
|
||||
side_effect=[[[mock_edge]], [[mock_edge2]]],
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
|
|
@ -194,7 +194,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
|
|||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
with (
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
|
|
@ -240,7 +240,7 @@ async def test_get_completion_with_session(mock_edge):
|
|||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
|
|
@ -304,7 +304,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
|
|
@ -361,7 +361,7 @@ async def test_get_completion_with_response_model(mock_edge):
|
|||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
|
|
@ -403,7 +403,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
|||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
|
|
@ -446,7 +446,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
|
|||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||
patch(
|
||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||
return_value="Resolved context",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue