diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py
index 8afd8545c..8d8aa2be4 100644
--- a/cognee/api/v1/search/search.py
+++ b/cognee/api/v1/search/search.py
@@ -13,6 +13,7 @@ async def search(
     user: User = None,
     datasets: Union[list[str], str, None] = None,
     system_prompt_path: str = "answer_simple_question.txt",
+    top_k: int = 10,
 ) -> list:
     # We use lists from now on for datasets
     if isinstance(datasets, str):
@@ -25,7 +26,12 @@ async def search(
         raise UserNotFoundError
 
     filtered_search_results = await search_function(
-        query_text, query_type, datasets, user, system_prompt_path=system_prompt_path
+        query_text,
+        query_type,
+        datasets,
+        user,
+        system_prompt_path=system_prompt_path,
+        top_k=top_k,
     )
 
     return filtered_search_results
diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py
index cf8600f27..fba011cf5 100644
--- a/cognee/modules/retrieval/completion_retriever.py
+++ b/cognee/modules/retrieval/completion_retriever.py
@@ -26,7 +26,9 @@ class CompletionRetriever(BaseRetriever):
         found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
         if len(found_chunks) == 0:
             raise NoRelevantDataFound
-        return found_chunks[0].payload["text"]
+        # Combine all chunk texts returned from the vector search (the number of chunks is determined by top_k)
+        chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
+        return "\n".join(chunks_payload)
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """Generates an LLM completion using the context."""
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 4a79a29a8..8466686a2 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -28,12 +28,13 @@ async def search(
     datasets: list[str],
     user: User,
     system_prompt_path="answer_simple_question.txt",
+    top_k: int = 10,
 ):
     query = await log_query(query_text, query_type.value, user.id)
 
     own_document_ids = await get_document_ids_for_user(user.id, datasets)
     search_results = await specific_search(
-        query_type, query_text, user, system_prompt_path=system_prompt_path
+        query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
     )
 
     filtered_search_results = []
@@ -51,20 +52,26 @@ async def search(
 
 
 async def specific_search(
-    query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt"
+    query_type: SearchType,
+    query: str,
+    user: User,
+    system_prompt_path="answer_simple_question.txt",
+    top_k: int = 10,
 ) -> list:
     search_tasks: dict[SearchType, Callable] = {
         SearchType.SUMMARIES: SummariesRetriever().get_completion,
-        SearchType.INSIGHTS: InsightsRetriever().get_completion,
+        SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
         SearchType.CHUNKS: ChunksRetriever().get_completion,
         SearchType.RAG_COMPLETION: CompletionRetriever(
-            system_prompt_path=system_prompt_path
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
         ).get_completion,
         SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
-            system_prompt_path=system_prompt_path
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
         ).get_completion,
         SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
-            system_prompt_path=system_prompt_path
+            system_prompt_path=system_prompt_path,
         ).get_completion,
         SearchType.CODE: CodeRetriever().get_completion,
         SearchType.CYPHER: CypherSearchRetriever().get_completion,
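Taken together, these changes thread a `top_k` parameter from the public `search` entry point through `specific_search` into the retrievers, and `CompletionRetriever.get_context` now joins the text of every returned chunk instead of using only the first hit. A minimal usage sketch follows; the `search` import path matches the file touched above, but the `SearchType` import location and the ability to rely on a default user are assumptions, not verified against the repository.

```python
# Hypothetical usage sketch, not part of the diff above.
import asyncio

from cognee.api.v1.search.search import search
from cognee.modules.search.types import SearchType  # assumed location of SearchType


async def main():
    # top_k bounds the number of DocumentChunk_text hits retrieved; for
    # RAG_COMPLETION those hits are now joined into a single context string.
    results = await search(
        query_text="What does the report say about revenue?",
        query_type=SearchType.RAG_COMPLETION,
        top_k=5,  # new parameter; defaults to 10
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```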
diff --git a/cognee/tests/unit/modules/retrieval/completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/completion_retriever_test.py
index 1eace3cf1..0da518008 100644
--- a/cognee/tests/unit/modules/retrieval/completion_retriever_test.py
+++ b/cognee/tests/unit/modules/retrieval/completion_retriever_test.py
@@ -24,6 +24,7 @@ class TestCompletionRetriever:
         mock_render_prompt.return_value = "Rendered prompt with context"
 
         mock_search_results = [MagicMock()]
+        mock_search_results[0].payload = {"text": "This is a sample document chunk."}
         mock_vector_engine = AsyncMock()
         mock_vector_engine.search.return_value = mock_search_results
         mock_get_vector_engine.return_value = mock_vector_engine
@@ -59,6 +60,7 @@ class TestCompletionRetriever:
         query = "test query with custom prompt"
 
         mock_search_results = [MagicMock()]
+        mock_search_results[0].payload = {"text": "This is a sample document chunk."}
         mock_vector_engine = AsyncMock()
         mock_vector_engine.search.return_value = mock_search_results
         mock_get_vector_engine.return_value = mock_vector_engine
diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py
index a077178f8..f8e440ca4 100644
--- a/cognee/tests/unit/modules/search/search_methods_test.py
+++ b/cognee/tests/unit/modules/search/search_methods_test.py
@@ -68,7 +68,7 @@ async def test_search(
     mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
     mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
     mock_specific_search.assert_called_once_with(
-        query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt"
+        query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10
     )
 
     # Only the first two results should be included (doc_id3 is filtered out)
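For reference, the behavior the updated retriever tests now exercise can be illustrated with a self-contained sketch: given several vector hits, the context is assembled by joining each hit's `payload["text"]` with newlines rather than returning only the first hit. The `SimpleHit` class below is a stand-in for whatever result objects `vector_engine.search` returns; only the `payload` attribute used in the hunks above is assumed.

```python
# Illustrative sketch only; SimpleHit is a hypothetical stand-in for the
# vector engine's hit objects, which expose a payload dict per the diff above.
from dataclasses import dataclass


@dataclass
class SimpleHit:
    payload: dict


def combine_chunk_texts(found_chunks: list[SimpleHit]) -> str:
    # Mirrors the new CompletionRetriever.get_context behavior: join the text
    # of every returned chunk (count bounded by top_k) into one context string.
    return "\n".join(chunk.payload["text"] for chunk in found_chunks)


hits = [
    SimpleHit(payload={"text": "Revenue grew 12% year over year."}),
    SimpleHit(payload={"text": "Operating costs stayed flat."}),
]
assert combine_chunk_texts(hits) == (
    "Revenue grew 12% year over year.\nOperating costs stayed flat."
)
```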