feat: adds e2e conversation history test
parent a3bbeb1c10
commit e9f4e2000f
2 changed files with 241 additions and 1 deletion
cognee/tests/conversation_history.py  (new file, 240 lines)

@@ -0,0 +1,240 @@
"""
|
||||
End-to-end integration test for conversation history feature.
|
||||
|
||||
Tests all retrievers that save conversation history to Redis cache:
|
||||
1. GRAPH_COMPLETION
|
||||
2. RAG_COMPLETION
|
||||
3. GRAPH_COMPLETION_COT
|
||||
4. GRAPH_COMPLETION_CONTEXT_EXTENSION
|
||||
5. GRAPH_SUMMARY_COMPLETION
|
||||
6. TEMPORAL
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import cognee
|
||||
import pathlib
|
||||
|
||||
from cognee.infrastructure.databases.cache import get_cache_engine
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
    data_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent,
                ".data_storage/test_conversation_history",
            )
        ).resolve()
    )
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent,
                ".cognee_system/test_conversation_history",
            )
        ).resolve()
    )

    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

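    # Ingest two small documents into an isolated dataset and build the knowledge graph before querying.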
    dataset_name = "conversation_history_test"

    text_1 = """TechCorp is a technology company based in San Francisco. They specialize in artificial intelligence and machine learning."""
    text_2 = (
        """DataCo is a data analytics company. They help businesses make sense of their data."""
    )

    await cognee.add(text_1, dataset_name)
    await cognee.add(text_2, dataset_name)

    await cognee.cognify([dataset_name])

    user = await get_default_user()

    cache_engine = get_cache_engine()
    assert cache_engine is not None, "Cache engine should be available for testing"

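    # GRAPH_COMPLETION: first query in a fresh session; its Q&A pair should be cached.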
    session_id_1 = "test_session_graph"

    await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What is TechCorp?",
        session_id=session_id_1,
    )

    history1 = await cache_engine.get_latest_qa(str(user.id), session_id_1, last_n=10)
    assert len(history1) == 1, f"Expected exactly 1 Q&A in history, got {len(history1)}"
    our_qa = [h for h in history1 if h["question"] == "What is TechCorp?"]
    assert len(our_qa) >= 1, "Expected to find 'What is TechCorp?' in history"
    assert "answer" in our_qa[0] and "context" in our_qa[0], (
        "Q&A should contain answer and context fields"
    )

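    # Follow-up query in the same session; the history should now hold both Q&A pairs.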
    result2 = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Tell me more about it",
        session_id=session_id_1,
    )

    assert isinstance(result2, list) and len(result2) > 0, (
        f"Second query should return non-empty list, got: {result2!r}"
    )

    history2 = await cache_engine.get_latest_qa(str(user.id), session_id_1, last_n=10)
    our_questions = [
        h for h in history2 if h["question"] in ["What is TechCorp?", "Tell me more about it"]
    ]
    assert len(our_questions) == 2, (
        f"Expected 2 Q&A pairs in history after 2 queries, got {len(our_questions)}"
    )

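    # A different session id must keep its own, isolated history.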
    session_id_2 = "test_session_separate"

    result3 = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What is DataCo?",
        session_id=session_id_2,
    )

    assert isinstance(result3, list) and len(result3) > 0, (
        f"Different session should return non-empty list, got: {result3!r}"
    )

    history3 = await cache_engine.get_latest_qa(str(user.id), session_id_2, last_n=10)
    our_qa_session2 = [h for h in history3 if h["question"] == "What is DataCo?"]
    assert len(our_qa_session2) == 1, "Session 2 should have 'What is DataCo?' question"

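    # session_id=None should land in the "default_session" history bucket.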
    result4 = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Test default session",
        session_id=None,
    )

    assert isinstance(result4, list) and len(result4) > 0, (
        f"Default session should return non-empty list, got: {result4!r}"
    )

    history_default = await cache_engine.get_latest_qa(str(user.id), "default_session", last_n=10)
    our_qa_default = [h for h in history_default if h["question"] == "Test default session"]
    assert len(our_qa_default) == 1, "Should find 'Test default session' in default_session"

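    # RAG_COMPLETION should persist history the same way.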
    session_id_rag = "test_session_rag"

    result_rag = await cognee.search(
        query_type=SearchType.RAG_COMPLETION,
        query_text="What companies are mentioned?",
        session_id=session_id_rag,
    )

    assert isinstance(result_rag, list) and len(result_rag) > 0, (
        f"RAG_COMPLETION should return non-empty list, got: {result_rag!r}"
    )

    history_rag = await cache_engine.get_latest_qa(str(user.id), session_id_rag, last_n=10)
    our_qa_rag = [h for h in history_rag if h["question"] == "What companies are mentioned?"]
    assert len(our_qa_rag) == 1, "Should find RAG question in history"

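    # GRAPH_COMPLETION_COT should persist history as well.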
    session_id_cot = "test_session_cot"

    result_cot = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION_COT,
        query_text="What do you know about TechCorp?",
        session_id=session_id_cot,
    )

    assert isinstance(result_cot, list) and len(result_cot) > 0, (
        f"GRAPH_COMPLETION_COT should return non-empty list, got: {result_cot!r}"
    )

    history_cot = await cache_engine.get_latest_qa(str(user.id), session_id_cot, last_n=10)
    our_qa_cot = [h for h in history_cot if h["question"] == "What do you know about TechCorp?"]
    assert len(our_qa_cot) == 1, "Should find CoT question in history"

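    # GRAPH_COMPLETION_CONTEXT_EXTENSION should persist history as well.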
    session_id_ext = "test_session_ext"

    result_ext = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
        query_text="Tell me about DataCo",
        session_id=session_id_ext,
    )

    assert isinstance(result_ext, list) and len(result_ext) > 0, (
        f"GRAPH_COMPLETION_CONTEXT_EXTENSION should return non-empty list, got: {result_ext!r}"
    )

    history_ext = await cache_engine.get_latest_qa(str(user.id), session_id_ext, last_n=10)
    our_qa_ext = [h for h in history_ext if h["question"] == "Tell me about DataCo"]
    assert len(our_qa_ext) == 1, "Should find Context Extension question in history"

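    # GRAPH_SUMMARY_COMPLETION should persist history as well.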
    session_id_summary = "test_session_summary"

    result_summary = await cognee.search(
        query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
        query_text="What are the key points about TechCorp?",
        session_id=session_id_summary,
    )

    assert isinstance(result_summary, list) and len(result_summary) > 0, (
        f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}"
    )

    # Verify saved
    history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10)
    our_qa_summary = [
        h for h in history_summary if h["question"] == "What are the key points about TechCorp?"
    ]
    assert len(our_qa_summary) == 1, "Should find Summary question in history"

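    # TEMPORAL should persist history as well.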
    session_id_temporal = "test_session_temporal"

    result_temporal = await cognee.search(
        query_type=SearchType.TEMPORAL,
        query_text="Tell me about the companies",
        session_id=session_id_temporal,
    )

    assert isinstance(result_temporal, list) and len(result_temporal) > 0, (
        f"TEMPORAL should return non-empty list, got: {result_temporal!r}"
    )

    history_temporal = await cache_engine.get_latest_qa(
        str(user.id), session_id_temporal, last_n=10
    )
    our_qa_temporal = [
        h for h in history_temporal if h["question"] == "Tell me about the companies"
    ]
    assert len(our_qa_temporal) == 1, "Should find Temporal question in history"

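    # Finally, check the formatted history string that gets injected into prompts.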
    from cognee.modules.retrieval.utils.session_cache import (
        get_conversation_history,
    )

    formatted_history = await get_conversation_history(session_id=session_id_1)

    assert "Previous conversation:" in formatted_history, (
        "Formatted history should contain 'Previous conversation:' header"
    )
    assert "QUESTION:" in formatted_history, "Formatted history should contain 'QUESTION:' prefix"
    assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix"
    assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix"

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    logger.info("All conversation history tests passed successfully")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
@@ -53,7 +53,7 @@ class TestConversationHistoryUtils:
        ]
        mock_cache = create_mock_cache_engine(mock_history)

-       # ✅ Import the real module to patch safely
+       # Import the real module to patch safely
        cache_module = importlib.import_module(
            "cognee.infrastructure.databases.cache.get_cache_engine"
        )
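For reference, a minimal usage sketch of the session-history flow this test exercises (assuming the same cognee configuration and a reachable Redis cache; answer_followup is a hypothetical helper name, not part of this commit):

import asyncio

import cognee
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.utils.session_cache import get_conversation_history


async def answer_followup(session_id: str):
    # First question: the answer and its context are cached under session_id.
    await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What is TechCorp?",
        session_id=session_id,
    )
    # The cached Q&A pairs can be read back as a formatted string
    # ("Previous conversation:" / "QUESTION:" / "CONTEXT:" / "ANSWER:").
    print(await get_conversation_history(session_id=session_id))
    # A follow-up in the same session can then lean on that cached history.
    return await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Tell me more about it",
        session_id=session_id,
    )


if __name__ == "__main__":
    asyncio.run(answer_followup("demo_session"))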