adds 100% coverage to completion
This commit is contained in:
parent
22af6a050a
commit
5f6560d6a7
1 changed files with 351 additions and 0 deletions
351
cognee/tests/unit/modules/retrieval/test_completion.py
Normal file
351
cognee/tests/unit/modules/retrieval/test_completion.py
Normal file
|
|
@ -0,0 +1,351 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateCompletion:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_completion_with_system_prompt(self):
|
||||||
|
"""Test generate_completion with provided system_prompt."""
|
||||||
|
mock_llm_response = "Generated answer"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.render_prompt",
|
||||||
|
return_value="User prompt text",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
||||||
|
result = await generate_completion(
|
||||||
|
query="What is AI?",
|
||||||
|
context="AI is artificial intelligence",
|
||||||
|
user_prompt_path="user_prompt.txt",
|
||||||
|
system_prompt_path="system_prompt.txt",
|
||||||
|
system_prompt="Custom system prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="User prompt text",
|
||||||
|
system_prompt="Custom system prompt",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_completion_without_system_prompt(self):
|
||||||
|
"""Test generate_completion reads system_prompt from file when not provided."""
|
||||||
|
mock_llm_response = "Generated answer"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.render_prompt",
|
||||||
|
return_value="User prompt text",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="System prompt from file",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
||||||
|
result = await generate_completion(
|
||||||
|
query="What is AI?",
|
||||||
|
context="AI is artificial intelligence",
|
||||||
|
user_prompt_path="user_prompt.txt",
|
||||||
|
system_prompt_path="system_prompt.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="User prompt text",
|
||||||
|
system_prompt="System prompt from file",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_completion_with_conversation_history(self):
|
||||||
|
"""Test generate_completion includes conversation_history in system_prompt."""
|
||||||
|
mock_llm_response = "Generated answer"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.render_prompt",
|
||||||
|
return_value="User prompt text",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="System prompt from file",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
||||||
|
result = await generate_completion(
|
||||||
|
query="What is AI?",
|
||||||
|
context="AI is artificial intelligence",
|
||||||
|
user_prompt_path="user_prompt.txt",
|
||||||
|
system_prompt_path="system_prompt.txt",
|
||||||
|
conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
expected_system_prompt = (
|
||||||
|
"Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
|
||||||
|
+ "\nTASK:"
|
||||||
|
+ "System prompt from file"
|
||||||
|
)
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="User prompt text",
|
||||||
|
system_prompt=expected_system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self):
|
||||||
|
"""Test generate_completion includes conversation_history with custom system_prompt."""
|
||||||
|
mock_llm_response = "Generated answer"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.render_prompt",
|
||||||
|
return_value="User prompt text",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
||||||
|
result = await generate_completion(
|
||||||
|
query="What is AI?",
|
||||||
|
context="AI is artificial intelligence",
|
||||||
|
user_prompt_path="user_prompt.txt",
|
||||||
|
system_prompt_path="system_prompt.txt",
|
||||||
|
system_prompt="Custom system prompt",
|
||||||
|
conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
expected_system_prompt = (
|
||||||
|
"Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
|
||||||
|
+ "\nTASK:"
|
||||||
|
+ "Custom system prompt"
|
||||||
|
)
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="User prompt text",
|
||||||
|
system_prompt=expected_system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_completion_with_response_model(self):
|
||||||
|
"""Test generate_completion with custom response_model."""
|
||||||
|
mock_response_model = MagicMock()
|
||||||
|
mock_llm_response = {"answer": "Generated answer"}
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.render_prompt",
|
||||||
|
return_value="User prompt text",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="System prompt from file",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
||||||
|
result = await generate_completion(
|
||||||
|
query="What is AI?",
|
||||||
|
context="AI is artificial intelligence",
|
||||||
|
user_prompt_path="user_prompt.txt",
|
||||||
|
system_prompt_path="system_prompt.txt",
|
||||||
|
response_model=mock_response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="User prompt text",
|
||||||
|
system_prompt="System prompt from file",
|
||||||
|
response_model=mock_response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_completion_render_prompt_args(self):
|
||||||
|
"""Test generate_completion passes correct args to render_prompt."""
|
||||||
|
mock_llm_response = "Generated answer"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.render_prompt",
|
||||||
|
return_value="User prompt text",
|
||||||
|
) as mock_render,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="System prompt from file",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
||||||
|
await generate_completion(
|
||||||
|
query="What is AI?",
|
||||||
|
context="AI is artificial intelligence",
|
||||||
|
user_prompt_path="user_prompt.txt",
|
||||||
|
system_prompt_path="system_prompt.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_render.assert_called_once_with(
|
||||||
|
"user_prompt.txt", {"question": "What is AI?", "context": "AI is artificial intelligence"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarizeText:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_summarize_text_with_system_prompt(self):
|
||||||
|
"""Test summarize_text with provided system_prompt."""
|
||||||
|
mock_llm_response = "Summary text"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm:
|
||||||
|
from cognee.modules.retrieval.utils.completion import summarize_text
|
||||||
|
|
||||||
|
result = await summarize_text(
|
||||||
|
text="Long text to summarize",
|
||||||
|
system_prompt_path="summarize_search_results.txt",
|
||||||
|
system_prompt="Custom summary prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="Long text to summarize",
|
||||||
|
system_prompt="Custom summary prompt",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_summarize_text_without_system_prompt(self):
|
||||||
|
"""Test summarize_text reads system_prompt from file when not provided."""
|
||||||
|
mock_llm_response = "Summary text"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="System prompt from file",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import summarize_text
|
||||||
|
|
||||||
|
result = await summarize_text(
|
||||||
|
text="Long text to summarize",
|
||||||
|
system_prompt_path="summarize_search_results.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="Long text to summarize",
|
||||||
|
system_prompt="System prompt from file",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_summarize_text_default_prompt_path(self):
|
||||||
|
"""Test summarize_text uses default prompt path when not provided."""
|
||||||
|
mock_llm_response = "Summary text"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="Default system prompt",
|
||||||
|
) as mock_read,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import summarize_text
|
||||||
|
|
||||||
|
result = await summarize_text(text="Long text to summarize")
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_read.assert_called_once_with("summarize_search_results.txt")
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="Long text to summarize",
|
||||||
|
system_prompt="Default system prompt",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_summarize_text_custom_prompt_path(self):
|
||||||
|
"""Test summarize_text uses custom prompt path when provided."""
|
||||||
|
mock_llm_response = "Summary text"
|
||||||
|
mock_llm_gateway = AsyncMock(return_value=mock_llm_response)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
||||||
|
return_value="Custom system prompt",
|
||||||
|
) as mock_read,
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_llm_response,
|
||||||
|
) as mock_llm,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.completion import summarize_text
|
||||||
|
|
||||||
|
result = await summarize_text(
|
||||||
|
text="Long text to summarize",
|
||||||
|
system_prompt_path="custom_prompt.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == mock_llm_response
|
||||||
|
mock_read.assert_called_once_with("custom_prompt.txt")
|
||||||
|
mock_llm.assert_awaited_once_with(
|
||||||
|
text_input="Long text to summarize",
|
||||||
|
system_prompt="Custom system prompt",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
Loading…
Add table
Reference in a new issue