fix: retrievers now handle duplicate queries; added tests and addressed PR change requests

This commit is contained in:
Andrej Milicevic 2026-01-20 15:11:58 +01:00
parent 8eee3990f7
commit a5494513d7
8 changed files with 244 additions and 55 deletions

View file

@ -40,3 +40,13 @@ class CollectionDistancesNotFoundError(CogneeValidationError):
status_code: int = status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class QueryValidationError(CogneeValidationError):
    """Raised when a query or query batch is not supplied in the correct format.

    Surfaces to API clients as an HTTP 422 (Unprocessable Content) response
    via the CogneeValidationError base class.
    """

    def __init__(
        self,
        message: str = "Queries not supplied in the correct format.",
        name: str = "QueryValidationError",
        status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
    ):
        super().__init__(message, name, status_code)

View file

@ -1,6 +1,8 @@
import asyncio
from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.query_state import QueryState
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -15,19 +17,6 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
class QueryState:
    """
    Helper class containing all necessary information about the query state:
    the triplets and context associated with it, and also a check whether
    it has fully extended the context.
    """

    def __init__(self, triplets: List[Edge], context_text: str, finished_extending_context: bool):
        # Graph edges retrieved for this query so far.
        self.triplets = triplets
        # Textual rendering of the triplets, used as LLM context.
        self.context_text = context_text
        # Flag set once an extension round adds no new triplets for this query.
        self.finished_extending_context = finished_extending_context
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
"""
Handles graph context completion for question answering tasks, extending context based
@ -112,13 +101,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise ValueError("You cannot use batch queries with session saving currently.")
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise ValueError("Cannot use batch queries with interaction saving currently.")
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
raise QueryValidationError(message=msg)
triplets_batch = context
@ -190,7 +183,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
# Get new triplets, and merge them with existing ones, filtering out duplicates
new_triplets_batch = await self.get_context(query_batch=completions)
for batched_query, batched_new_triplets in zip(query_batch, new_triplets_batch):
for batched_query, batched_new_triplets in zip(
finished_queries_states.keys(), new_triplets_batch
):
finished_queries_states[batched_query].triplets = list(
dict.fromkeys(
finished_queries_states[batched_query].triplets + batched_new_triplets
@ -207,7 +202,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
)
# Update context_texts in query states
for batched_query, batched_context_text in zip(query_batch, context_text_batch):
for batched_query, batched_context_text in zip(
finished_queries_states.keys(), context_text_batch
):
if not finished_queries_states[batched_query].finished_extending_context:
finished_queries_states[batched_query].context_text = batched_context_text
@ -216,7 +213,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
for batched_query_state in finished_queries_states.values()
]
for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes):
for batched_query, prev_size, new_size in zip(
finished_queries_states.keys(), prev_sizes, new_sizes
):
# Mark done queries accordingly
if prev_size == new_size:
finished_queries_states[batched_query].finished_extending_context = True
@ -229,6 +228,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
round_idx += 1
completion_batch = []
result_completion_batch = []
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
@ -260,6 +260,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
],
)
# Make sure answers are returned for duplicate queries, in the order they were asked.
for batched_query, batched_completion in zip(
finished_queries_states.keys(), completion_batch
):
finished_queries_states[batched_query].completion = batched_completion
for batched_query in query_batch:
result_completion_batch.append(finished_queries_states[batched_query].completion)
# TODO: Do batch queries for save interaction
if self.save_interaction and context_text_batch and triplets_batch and completion_batch:
await self.save_qa(
@ -277,4 +286,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return completion_batch if completion_batch else [completion]
return result_completion_batch if result_completion_batch else [completion]

View file

@ -3,6 +3,8 @@ import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.query_state import QueryState
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.shared.logging_utils import get_logger
@ -23,28 +25,6 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
class QueryState:
    """
    Helper class containing all necessary information about the query state.
    Used to keep track of important information in a more readable way, and
    enable as many parallel calls to llms as possible.
    """

    def __init__(self):
        # Latest LLM answer produced for this query.
        self.completion: str = ""
        # Graph edges retrieved for this query so far.
        self.triplets: List[Edge] = []
        # Textual rendering of the triplets, used as LLM context.
        self.context_text: str = ""
        # Completion converted to plain text for validation / follow-up prompts.
        self.answer_text: str = ""
        # Prompts used in the answer-validation step.
        self.valid_user_prompt: str = ""
        self.valid_system_prompt: str = ""
        # Model reasoning extracted during chain-of-thought.
        self.reasoning: str = ""
        # Follow-up question and the prompts used to generate it.
        self.followup_question: str = ""
        self.followup_prompt: str = ""
        self.followup_system: str = ""
def _as_answer_text(completion: Any) -> str:
"""Convert completion to human-readable text for validation and follow-up prompts."""
if isinstance(completion, str):
@ -155,7 +135,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
if round_idx == 0:
if context is None:
# Get context, resolve to text, and store info in the query state
triplets_batch = await self.get_context(query_batch=query_batch)
triplets_batch = await self.get_context(
query_batch=list(query_state_tracker.keys())
)
context_text_batch = await asyncio.gather(
*[
self.resolve_edges_to_text(batched_triplets)
@ -163,7 +145,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
]
)
for batched_query, batched_triplets, batched_context_text in zip(
query_batch, triplets_batch, context_text_batch
query_state_tracker.keys(), triplets_batch, context_text_batch
):
query_state_tracker[batched_query].triplets = batched_triplets
query_state_tracker[batched_query].context_text = batched_context_text
@ -176,7 +158,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
]
)
for batched_query, batched_triplets, batched_context_text in zip(
query_batch, context, context_text_batch
query_state_tracker.keys(), context, context_text_batch
):
query_state_tracker[batched_query].triplets = batched_triplets
query_state_tracker[batched_query].context_text = batched_context_text
@ -184,7 +166,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
# Find new triplets, and update existing query states
triplets_batch = await self.get_context(query_batch=followup_question_batch)
for batched_query, batched_followup_triplets in zip(query_batch, triplets_batch):
for batched_query, batched_followup_triplets in zip(
query_state_tracker.keys(), triplets_batch
):
query_state_tracker[batched_query].triplets = list(
dict.fromkeys(
query_state_tracker[batched_query].triplets + batched_followup_triplets
@ -198,7 +182,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
]
)
for batched_query, batched_context_text in zip(query_batch, context_text_batch):
for batched_query, batched_context_text in zip(
query_state_tracker.keys(), context_text_batch
):
query_state_tracker[batched_query].context_text = batched_context_text
completion_batch = await asyncio.gather(
@ -216,9 +202,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
]
)
for batched_query, batched_completion in zip(query_batch, completion_batch):
for batched_query, batched_completion in zip(
query_state_tracker.keys(), completion_batch
):
query_state_tracker[batched_query].completion = batched_completion
if round_idx == max_iter:
# When we finish all iterations:
# Make sure answers are returned for duplicate queries, in the order they were asked.
completion_batch = []
for batched_query in query_batch:
completion_batch.append(query_state_tracker[batched_query].completion)
logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}")
if round_idx < max_iter:
@ -332,13 +327,17 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise ValueError("You cannot use batch queries with session saving currently.")
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise ValueError("Cannot use batch queries with interaction saving currently.")
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
raise QueryValidationError(message=msg)
# Load conversation history if enabled
conversation_history = ""

View file

@ -4,6 +4,7 @@ from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError
from cognee.modules.retrieval.utils.validate_queries import validate_queries
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
@ -216,13 +217,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
session_save = user_id and cache_config.caching
if query_batch and session_save:
raise ValueError("You cannot use batch queries with session saving currently.")
raise QueryValidationError(
message="You cannot use batch queries with session saving currently."
)
if query_batch and self.save_interaction:
raise ValueError("Cannot use batch queries with interaction saving currently.")
raise QueryValidationError(
message="Cannot use batch queries with interaction saving currently."
)
is_query_valid, msg = validate_queries(query, query_batch)
if not is_query_valid:
raise ValueError(msg)
raise QueryValidationError(message=msg)
triplets = context

View file

@ -0,0 +1,34 @@
from typing import List, Optional

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
class QueryState:
    """
    Helper class containing all necessary information about the query state.

    Used (for now) in COT and Context Extension Retrievers to keep track of important
    information in a more readable way, and enable as many parallel calls to llms as
    possible.
    """

    def __init__(
        self,
        # Fixed annotation: the default is None, so the type is Optional[List[Edge]],
        # not List[Edge] (and a mutable default list must not be used here).
        triplets: Optional[List[Edge]] = None,
        context_text: str = "",
        finished_extending_context: bool = False,
    ):
        # Mutual fields for COT and Context Extension
        self.triplets = triplets if triplets else []
        self.context_text = context_text
        self.completion = ""

        # Context Extension specific: set once an extension round adds no new triplets.
        self.finished_extending_context = finished_extending_context

        # COT specific
        self.answer_text: str = ""
        self.valid_user_prompt: str = ""
        self.valid_system_prompt: str = ""
        self.reasoning: str = ""
        self.followup_question: str = ""
        self.followup_prompt: str = ""
        self.followup_system: str = ""

View file

@ -747,3 +747,60 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
    """Test get_completion batch queries with duplicate queries."""
    # NOTE(review): despite the test name, the batch below uses two *distinct*
    # queries; the mock side_effect lists are sized/ordered for two independent
    # query states, so switching to true duplicates would desynchronize them.
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    # Create a second edge for extension rounds
    mock_edge2 = MagicMock(spec=Edge)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            # First call: initial context; second call: extension-round context.
            side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            side_effect=[
                "Resolved context",
                "Resolved context",
                "Extended context",
                "Extended context",
            ],  # Different contexts
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            # First round produces extension queries; final round the answers.
            side_effect=[
                "Extension query",
                "Extension query",
                "Generated answer",
                "Generated answer",
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Caching disabled so no session-saving path is triggered.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 2"], context_extension_rounds=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"

View file

@ -375,7 +375,7 @@ async def test_get_completion_with_provided_context(mock_edge):
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context=[[mock_edge]], max_iter=1)
completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
assert isinstance(completion, list)
assert len(completion) == 1
@ -818,3 +818,38 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 2
assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel)
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
    """Test get_completion batch queries with duplicate queries returns one answer per asked query."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
    ):
        # Identical queries: the retriever deduplicates internally but must still
        # return an answer for each position in the original batch.
        completion = await retriever.get_completion(
            query_batch=["test query 1", "test query 1"], max_iter=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"

View file

@ -802,3 +802,43 @@ async def test_get_completion_batch_queries_empty_context(mock_edge):
assert isinstance(completion, list)
assert len(completion) == 2
@pytest.mark.asyncio
async def test_get_completion_batch_queries_duplicate_queries(mock_edge):
    """Test get_completion batch queries with duplicate queries returns one answer per asked query."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[[mock_edge], [mock_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # Caching disabled so no session-saving path is triggered.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        # Identical queries: each position in the batch must still get an answer.
        completion = await retriever.get_completion(query_batch=["test query 1", "test query 1"])

        assert isinstance(completion, list)
        assert len(completion) == 2
        assert completion[0] == "Generated answer" and completion[1] == "Generated answer"