Add top k [COG-1862] (#743)
## Description

Add the ability to define top-k for the Cognee search types Insights, RAG Completion, and Graph Completion.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Parent: f13607cf18
Commit: da332e85fe

5 changed files with 26 additions and 9 deletions
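For orientation, a minimal usage sketch of the new parameter (not part of the diff). It assumes the module-level `cognee.search` entry point and a top-level `SearchType` export; the keyword names themselves are taken from the hunks below.

```python
import asyncio

import cognee
from cognee import SearchType  # assumed export path for the enum


async def main():
    # Request the 5 best-matching results instead of the default 10.
    # top_k bounds how many vector-search hits are folded into the
    # completion context for INSIGHTS, RAG_COMPLETION and GRAPH_COMPLETION.
    results = await cognee.search(
        query_text="What does Cognee do?",
        query_type=SearchType.GRAPH_COMPLETION,
        top_k=5,
    )
    print(results)


asyncio.run(main())
```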
```diff
@@ -13,6 +13,7 @@ async def search(
     user: User = None,
     datasets: Union[list[str], str, None] = None,
     system_prompt_path: str = "answer_simple_question.txt",
+    top_k: int = 10,
 ) -> list:
     # We use lists from now on for datasets
     if isinstance(datasets, str):
```
```diff
@@ -25,7 +26,12 @@ async def search(
         raise UserNotFoundError
 
     filtered_search_results = await search_function(
-        query_text, query_type, datasets, user, system_prompt_path=system_prompt_path
+        query_text,
+        query_type,
+        datasets,
+        user,
+        system_prompt_path=system_prompt_path,
+        top_k=top_k,
     )
 
     return filtered_search_results
```
```diff
@@ -26,7 +26,9 @@ class CompletionRetriever(BaseRetriever):
         found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
         if len(found_chunks) == 0:
             raise NoRelevantDataFound
-        return found_chunks[0].payload["text"]
+        # Combine the text of all chunks returned from vector search (the number of chunks is determined by top_k)
+        chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
+        return "\n".join(chunks_payload)
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """Generates an LLM completion using the context."""
```
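The behavioral change in this hunk: the RAG context grows from the single best chunk to all top_k chunks joined with newlines. A standalone sketch of the before/after difference, using hypothetical dict-shaped results that mirror the payload layout in the test fixtures further down:

```python
# Hypothetical search results; each carries its chunk text under payload["text"].
found_chunks = [
    {"payload": {"text": "Chunk one."}},
    {"payload": {"text": "Chunk two."}},
    {"payload": {"text": "Chunk three."}},
]

# Before: only the best-scoring chunk reached the prompt.
old_context = found_chunks[0]["payload"]["text"]

# After: all top_k chunks are concatenated into one context string.
new_context = "\n".join(chunk["payload"]["text"] for chunk in found_chunks)

assert old_context == "Chunk one."
assert new_context == "Chunk one.\nChunk two.\nChunk three."
```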
```diff
@@ -28,12 +28,13 @@ async def search(
     datasets: list[str],
     user: User,
     system_prompt_path="answer_simple_question.txt",
+    top_k: int = 10,
 ):
     query = await log_query(query_text, query_type.value, user.id)
 
     own_document_ids = await get_document_ids_for_user(user.id, datasets)
     search_results = await specific_search(
-        query_type, query_text, user, system_prompt_path=system_prompt_path
+        query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
     )
 
     filtered_search_results = []
```
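Just past this hunk's visible context, results are filtered down to documents the user owns; the test comment at the end of this page ("doc_id3 is filtered out") implies logic along these lines, sketched here with hypothetical result shapes:

```python
# Hypothetical shapes: each search result carries the id of the document
# it came from; only results from documents the user owns are kept.
own_document_ids = {"doc_id1", "doc_id2"}
search_results = [
    {"document_id": "doc_id1", "text": "..."},
    {"document_id": "doc_id2", "text": "..."},
    {"document_id": "doc_id3", "text": "..."},  # not owned, filtered out
]

filtered_search_results = [
    result for result in search_results if result["document_id"] in own_document_ids
]

assert len(filtered_search_results) == 2
```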
```diff
@@ -51,20 +52,26 @@ async def search(
 
 
 async def specific_search(
-    query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt"
+    query_type: SearchType,
+    query: str,
+    user: User,
+    system_prompt_path="answer_simple_question.txt",
+    top_k: int = 10,
 ) -> list:
     search_tasks: dict[SearchType, Callable] = {
         SearchType.SUMMARIES: SummariesRetriever().get_completion,
-        SearchType.INSIGHTS: InsightsRetriever().get_completion,
+        SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
         SearchType.CHUNKS: ChunksRetriever().get_completion,
         SearchType.RAG_COMPLETION: CompletionRetriever(
-            system_prompt_path=system_prompt_path
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
         ).get_completion,
         SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
-            system_prompt_path=system_prompt_path
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
         ).get_completion,
         SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
-            system_prompt_path=system_prompt_path
+            system_prompt_path=system_prompt_path,
         ).get_completion,
         SearchType.CODE: CodeRetriever().get_completion,
         SearchType.CYPHER: CypherSearchRetriever().get_completion,
```
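The dispatch table maps each `SearchType` to a bound `get_completion` coroutine, and `top_k` is passed only to the Insights, RAG, and Graph completion retrievers, matching the PR description. A reduced, self-contained sketch of this pattern with stand-in retriever classes (all names here are hypothetical simplifications, not Cognee's real classes):

```python
import asyncio
from enum import Enum
from typing import Callable


class SearchType(Enum):
    # Reduced stand-in for the real enum.
    RAG_COMPLETION = "rag_completion"
    SUMMARIES = "summaries"


class CompletionRetriever:
    # Mirrors the constructor shape shown in the diff above.
    def __init__(self, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10):
        self.top_k = top_k

    async def get_completion(self, query: str) -> list:
        return [f"completion for {query!r} built from {self.top_k} chunks"]


class SummariesRetriever:
    # Does not take top_k: the parameter is threaded only to retrievers that use it.
    async def get_completion(self, query: str) -> list:
        return [f"summary for {query!r}"]


async def specific_search(query_type: SearchType, query: str, top_k: int = 10) -> list:
    # Each value is a bound coroutine; constructing the retriever up front
    # lets per-call options like top_k be baked into the dispatch entry.
    search_tasks: dict[SearchType, Callable] = {
        SearchType.RAG_COMPLETION: CompletionRetriever(top_k=top_k).get_completion,
        SearchType.SUMMARIES: SummariesRetriever().get_completion,
    }
    return await search_tasks[query_type](query)


print(asyncio.run(specific_search(SearchType.RAG_COMPLETION, "what is cognee?", top_k=5)))
```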
```diff
@@ -24,6 +24,7 @@ class TestCompletionRetriever:
         mock_render_prompt.return_value = "Rendered prompt with context"
 
         mock_search_results = [MagicMock()]
+        mock_search_results[0].payload = {"text": "This is a sample document chunk."}
         mock_vector_engine = AsyncMock()
         mock_vector_engine.search.return_value = mock_search_results
         mock_get_vector_engine.return_value = mock_vector_engine
```
```diff
@@ -59,6 +60,7 @@ class TestCompletionRetriever:
         query = "test query with custom prompt"
 
         mock_search_results = [MagicMock()]
+        mock_search_results[0].payload = {"text": "This is a sample document chunk."}
         mock_vector_engine = AsyncMock()
         mock_vector_engine.search.return_value = mock_search_results
         mock_get_vector_engine.return_value = mock_vector_engine
```
```diff
@@ -68,7 +68,7 @@ async def test_search(
     mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
     mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
     mock_specific_search.assert_called_once_with(
-        query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt"
+        query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10
     )
 
     # Only the first two results should be included (doc_id3 is filtered out)
```
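A natural follow-up test (not part of this commit) would assert the multi-chunk join directly. A sketch in the same mock style; the `get_context` method name, the module path, and the `get_vector_engine` patch target are assumptions here, and only the `vector_engine.search` call shape is taken from the diff:

```python
import pytest
from unittest.mock import AsyncMock, MagicMock, patch


@pytest.mark.asyncio
async def test_context_joins_top_k_chunks():
    # Two mocked vector-search hits, shaped like the fixtures above.
    mock_search_results = [MagicMock(), MagicMock()]
    mock_search_results[0].payload = {"text": "First chunk."}
    mock_search_results[1].payload = {"text": "Second chunk."}

    mock_vector_engine = AsyncMock()
    mock_vector_engine.search.return_value = mock_search_results

    with patch(
        # Assumed patch target; adjust to wherever the retriever imports it from.
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        from cognee.modules.retrieval.completion_retriever import CompletionRetriever

        retriever = CompletionRetriever(top_k=2)
        context = await retriever.get_context("test query")  # assumed method name

    # The search call shape matches the diff: collection, query, limit=top_k.
    mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", "test query", limit=2)
    # Both chunks should be newline-joined into one context string.
    assert context == "First chunk.\nSecond chunk."
```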