fix: fix failing tests
This commit is contained in:
parent
40667e63c9
commit
2655df9b21
3 changed files with 22 additions and 16 deletions
|
|
@ -95,16 +95,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
if query:
|
if query:
|
||||||
# This is done mostly to avoid duplicating a lot of code unnecessarily
|
# This is done mostly to avoid duplicating a lot of code unnecessarily
|
||||||
query_batch = [query]
|
query_batch = [query]
|
||||||
query = None
|
|
||||||
if triplets:
|
if triplets:
|
||||||
triplets = [triplets]
|
triplets = [triplets]
|
||||||
|
|
||||||
if triplets is None:
|
if triplets is None:
|
||||||
triplets = await self.get_context(query, query_batch)
|
triplets = await self.get_context(query_batch=query_batch)
|
||||||
|
|
||||||
context_text = ""
|
context_text = ""
|
||||||
context_texts = ""
|
context_texts = ""
|
||||||
if isinstance(triplets[0], list):
|
if triplets and isinstance(triplets[0], list):
|
||||||
context_texts = await asyncio.gather(
|
context_texts = await asyncio.gather(
|
||||||
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||||
)
|
)
|
||||||
|
|
@ -123,6 +122,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
f"Context extension: round {round_idx} - generating next graph locational query."
|
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Filter out the queries that cannot be extended further, and their associated contexts
|
||||||
query_batch = [query for query in query_batch if query]
|
query_batch = [query for query in query_batch if query]
|
||||||
triplets = [triplet_element for triplet_element in triplets if triplet_element]
|
triplets = [triplet_element for triplet_element in triplets if triplet_element]
|
||||||
context_texts = [context_text for context_text in context_texts if context_text]
|
context_texts = [context_text for context_text in context_texts if context_text]
|
||||||
|
|
@ -181,8 +181,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
# Reset variables for the final generations. They contain the final state
|
# Reset variables for the final generations. They contain the final state
|
||||||
# of triplets and contexts for each query, after all extension iterations.
|
# of triplets and contexts for each query, after all extension iterations.
|
||||||
query_batch = original_query_batch
|
query_batch = original_query_batch
|
||||||
context_texts = saved_context_texts
|
context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
|
||||||
triplets = saved_triplets
|
triplets = saved_triplets if len(saved_triplets) > 0 else triplets
|
||||||
|
|
||||||
|
if len(query_batch) == 1:
|
||||||
|
triplets = [] if not triplets else triplets[0]
|
||||||
|
context_text = "" if not context_texts else context_texts[0]
|
||||||
|
|
||||||
# Check if we need to generate context summary for caching
|
# Check if we need to generate context summary for caching
|
||||||
cache_config = CacheConfig()
|
cache_config = CacheConfig()
|
||||||
|
|
@ -221,8 +225,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context_text and triplets and completion:
|
if self.save_interaction and context_text and triplets and completion:
|
||||||
|
if isinstance(completion, list):
|
||||||
|
completion = completion[0]
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion[0], context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
if session_save:
|
if session_save:
|
||||||
|
|
@ -233,4 +239,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [completion]
|
return completion if isinstance(completion, list) else [completion]
|
||||||
|
|
|
||||||
|
|
@ -197,7 +197,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
|
|
||||||
context_text = ""
|
context_text = ""
|
||||||
context_texts = ""
|
context_texts = ""
|
||||||
if isinstance(triplets[0], list):
|
if triplets and isinstance(triplets[0], list):
|
||||||
context_texts = await asyncio.gather(
|
context_texts = await asyncio.gather(
|
||||||
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ async def test_get_completion_without_context(mock_edge):
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
||||||
return_value=[mock_edge],
|
return_value=[[mock_edge]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -157,7 +157,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
|
||||||
retriever,
|
retriever,
|
||||||
"get_context",
|
"get_context",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=[[mock_edge], [mock_edge2]],
|
side_effect=[[[mock_edge]], [[mock_edge2]]],
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
|
|
@ -194,7 +194,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
|
||||||
retriever = GraphCompletionContextExtensionRetriever()
|
retriever = GraphCompletionContextExtensionRetriever()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
return_value="Resolved context",
|
return_value="Resolved context",
|
||||||
|
|
@ -240,7 +240,7 @@ async def test_get_completion_with_session(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
return_value=mock_graph_engine,
|
return_value=mock_graph_engine,
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
return_value="Resolved context",
|
return_value="Resolved context",
|
||||||
|
|
@ -304,7 +304,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
return_value=mock_graph_engine,
|
return_value=mock_graph_engine,
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
return_value="Resolved context",
|
return_value="Resolved context",
|
||||||
|
|
@ -361,7 +361,7 @@ async def test_get_completion_with_response_model(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
return_value=mock_graph_engine,
|
return_value=mock_graph_engine,
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
return_value="Resolved context",
|
return_value="Resolved context",
|
||||||
|
|
@ -403,7 +403,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
return_value=mock_graph_engine,
|
return_value=mock_graph_engine,
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
return_value="Resolved context",
|
return_value="Resolved context",
|
||||||
|
|
@ -446,7 +446,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
||||||
return_value=mock_graph_engine,
|
return_value=mock_graph_engine,
|
||||||
),
|
),
|
||||||
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
|
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
|
||||||
patch(
|
patch(
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
||||||
return_value="Resolved context",
|
return_value="Resolved context",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue