Add top k [COG-1862] (#743)

<!-- .github/pull_request_template.md -->

## Description
Add ability to define top-k for Cognee search types Insights, RAG and
GRAPH Completion

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Igor Ilic 2025-04-17 14:01:35 +02:00 committed by GitHub
parent f13607cf18
commit da332e85fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 26 additions and 9 deletions

View file

@ -13,6 +13,7 @@ async def search(
user: User = None,
datasets: Union[list[str], str, None] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
) -> list:
# We use lists from now on for datasets
if isinstance(datasets, str):
@ -25,7 +26,12 @@ async def search(
raise UserNotFoundError
filtered_search_results = await search_function(
query_text, query_type, datasets, user, system_prompt_path=system_prompt_path
query_text,
query_type,
datasets,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
)
return filtered_search_results

View file

@ -26,7 +26,9 @@ class CompletionRetriever(BaseRetriever):
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
if len(found_chunks) == 0:
raise NoRelevantDataFound
return found_chunks[0].payload["text"]
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
return "\n".join(chunks_payload)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Generates an LLM completion using the context."""

View file

@ -28,12 +28,13 @@ async def search(
datasets: list[str],
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
):
query = await log_query(query_text, query_type.value, user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
)
filtered_search_results = []
@ -51,20 +52,26 @@ async def search(
async def specific_search(
query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt"
query_type: SearchType,
query: str,
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever().get_completion,
SearchType.INSIGHTS: InsightsRetriever().get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever().get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path
system_prompt_path=system_prompt_path,
).get_completion,
SearchType.CODE: CodeRetriever().get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,

View file

@ -24,6 +24,7 @@ class TestCompletionRetriever:
mock_render_prompt.return_value = "Rendered prompt with context"
mock_search_results = [MagicMock()]
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
@ -59,6 +60,7 @@ class TestCompletionRetriever:
query = "test query with custom prompt"
mock_search_results = [MagicMock()]
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine

View file

@ -68,7 +68,7 @@ async def test_search(
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
mock_specific_search.assert_called_once_with(
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt"
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10
)
# Only the first two results should be included (doc_id3 is filtered out)