Add top k [COG-1862] (#743)
<!-- .github/pull_request_template.md --> ## Description Add the ability to define top-k for the Cognee search types Insights, RAG Completion, and Graph Completion ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
f13607cf18
commit
da332e85fe
5 changed files with 26 additions and 9 deletions
|
|
@ -13,6 +13,7 @@ async def search(
|
||||||
user: User = None,
|
user: User = None,
|
||||||
datasets: Union[list[str], str, None] = None,
|
datasets: Union[list[str], str, None] = None,
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
top_k: int = 10,
|
||||||
) -> list:
|
) -> list:
|
||||||
# We use lists from now on for datasets
|
# We use lists from now on for datasets
|
||||||
if isinstance(datasets, str):
|
if isinstance(datasets, str):
|
||||||
|
|
@ -25,7 +26,12 @@ async def search(
|
||||||
raise UserNotFoundError
|
raise UserNotFoundError
|
||||||
|
|
||||||
filtered_search_results = await search_function(
|
filtered_search_results = await search_function(
|
||||||
query_text, query_type, datasets, user, system_prompt_path=system_prompt_path
|
query_text,
|
||||||
|
query_type,
|
||||||
|
datasets,
|
||||||
|
user,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
return filtered_search_results
|
return filtered_search_results
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,9 @@ class CompletionRetriever(BaseRetriever):
|
||||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
||||||
if len(found_chunks) == 0:
|
if len(found_chunks) == 0:
|
||||||
raise NoRelevantDataFound
|
raise NoRelevantDataFound
|
||||||
return found_chunks[0].payload["text"]
|
# Combine the text of all chunks returned from the vector search (the number of chunks is determined by top_k)
|
||||||
|
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
|
||||||
|
return "\n".join(chunks_payload)
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Generates an LLM completion using the context."""
|
"""Generates an LLM completion using the context."""
|
||||||
|
|
|
||||||
|
|
@ -28,12 +28,13 @@ async def search(
|
||||||
datasets: list[str],
|
datasets: list[str],
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
top_k: int = 10,
|
||||||
):
|
):
|
||||||
query = await log_query(query_text, query_type.value, user.id)
|
query = await log_query(query_text, query_type.value, user.id)
|
||||||
|
|
||||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||||
search_results = await specific_search(
|
search_results = await specific_search(
|
||||||
query_type, query_text, user, system_prompt_path=system_prompt_path
|
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
filtered_search_results = []
|
filtered_search_results = []
|
||||||
|
|
@ -51,20 +52,26 @@ async def search(
|
||||||
|
|
||||||
|
|
||||||
async def specific_search(
|
async def specific_search(
|
||||||
query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt"
|
query_type: SearchType,
|
||||||
|
query: str,
|
||||||
|
user: User,
|
||||||
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
top_k: int = 10,
|
||||||
) -> list:
|
) -> list:
|
||||||
search_tasks: dict[SearchType, Callable] = {
|
search_tasks: dict[SearchType, Callable] = {
|
||||||
SearchType.SUMMARIES: SummariesRetriever().get_completion,
|
SearchType.SUMMARIES: SummariesRetriever().get_completion,
|
||||||
SearchType.INSIGHTS: InsightsRetriever().get_completion,
|
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
|
||||||
SearchType.CHUNKS: ChunksRetriever().get_completion,
|
SearchType.CHUNKS: ChunksRetriever().get_completion,
|
||||||
SearchType.RAG_COMPLETION: CompletionRetriever(
|
SearchType.RAG_COMPLETION: CompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path
|
system_prompt_path=system_prompt_path,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.CODE: CodeRetriever().get_completion,
|
SearchType.CODE: CodeRetriever().get_completion,
|
||||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ class TestCompletionRetriever:
|
||||||
mock_render_prompt.return_value = "Rendered prompt with context"
|
mock_render_prompt.return_value = "Rendered prompt with context"
|
||||||
|
|
||||||
mock_search_results = [MagicMock()]
|
mock_search_results = [MagicMock()]
|
||||||
|
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
|
||||||
mock_vector_engine = AsyncMock()
|
mock_vector_engine = AsyncMock()
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
mock_vector_engine.search.return_value = mock_search_results
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
mock_get_vector_engine.return_value = mock_vector_engine
|
||||||
|
|
@ -59,6 +60,7 @@ class TestCompletionRetriever:
|
||||||
query = "test query with custom prompt"
|
query = "test query with custom prompt"
|
||||||
|
|
||||||
mock_search_results = [MagicMock()]
|
mock_search_results = [MagicMock()]
|
||||||
|
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
|
||||||
mock_vector_engine = AsyncMock()
|
mock_vector_engine = AsyncMock()
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
mock_vector_engine.search.return_value = mock_search_results
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
mock_get_vector_engine.return_value = mock_vector_engine
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ async def test_search(
|
||||||
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
||||||
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
|
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
|
||||||
mock_specific_search.assert_called_once_with(
|
mock_specific_search.assert_called_once_with(
|
||||||
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt"
|
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only the first two results should be included (doc_id3 is filtered out)
|
# Only the first two results should be included (doc_id3 is filtered out)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue