diff --git a/cognee/tests/conversation_history.py b/cognee/tests/conversation_history.py
new file mode 100644
index 000000000..b20c315f9
--- /dev/null
+++ b/cognee/tests/conversation_history.py
@@ -0,0 +1,259 @@
+"""
+End-to-end integration test for conversation history feature.
+
+Tests all retrievers that save conversation history to Redis cache:
+1. GRAPH_COMPLETION
+2. RAG_COMPLETION
+3. GRAPH_COMPLETION_COT
+4. GRAPH_COMPLETION_CONTEXT_EXTENSION
+5. GRAPH_SUMMARY_COMPLETION
+6. TEMPORAL
+"""
+
+import os
+import shutil
+import cognee
+import pathlib
+
+from cognee.infrastructure.databases.cache import get_cache_engine
+from cognee.modules.search.types import SearchType
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.users.methods import get_default_user
+
+logger = get_logger()
+
+
+async def main():
+    """Run the conversation-history scenario end to end.
+
+    Adds two short texts to a dataset, cognifies it, then issues searches for
+    each retriever type listed in the module docstring and asserts that every
+    question can be read back through ``cache_engine.get_latest_qa`` under the
+    corresponding session id.
+    """
+    data_directory_path = str(
+        pathlib.Path(
+            os.path.join(
+                pathlib.Path(__file__).parent,
+                ".data_storage/test_conversation_history",
+            )
+        ).resolve()
+    )
+    cognee_directory_path = str(
+        pathlib.Path(
+            os.path.join(
+                pathlib.Path(__file__).parent,
+                ".cognee_system/test_conversation_history",
+            )
+        ).resolve()
+    )
+
+    # Point cognee at test-local directories and prune before ingesting.
+    cognee.config.data_root_directory(data_directory_path)
+    cognee.config.system_root_directory(cognee_directory_path)
+
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    dataset_name = "conversation_history_test"
+
+    text_1 = """TechCorp is a technology company based in San Francisco. They specialize in artificial intelligence and machine learning."""
+    text_2 = (
+        """DataCo is a data analytics company. They help businesses make sense of their data."""
+    )
+
+    await cognee.add(text_1, dataset_name)
+    await cognee.add(text_2, dataset_name)
+
+    await cognee.cognify([dataset_name])
+
+    user = await get_default_user()
+
+    cache_engine = get_cache_engine()
+    assert cache_engine is not None, "Cache engine should be available for testing"
+
+    # GRAPH_COMPLETION: first question asked under session_id_1.
+    session_id_1 = "test_session_graph"
+
+    await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="What is TechCorp?",
+        session_id=session_id_1,
+    )
+
+    history1 = await cache_engine.get_latest_qa(str(user.id), session_id_1, last_n=10)
+    assert len(history1) == 1, f"Expected exactly 1 Q&A in history, got {len(history1)}"
+    our_qa = [h for h in history1 if h["question"] == "What is TechCorp?"]
+    assert len(our_qa) >= 1, "Expected to find 'What is TechCorp?' in history"
+    assert "answer" in our_qa[0] and "context" in our_qa[0], (
+        "Q&A should contain answer and context fields"
+    )
+
+    # Follow-up question in the same session.
+    result2 = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="Tell me more about it",
+        session_id=session_id_1,
+    )
+
+    assert isinstance(result2, list) and len(result2) > 0, (
+        f"Second query should return non-empty list, got: {result2!r}"
+    )
+
+    history2 = await cache_engine.get_latest_qa(str(user.id), session_id_1, last_n=10)
+    our_questions = [
+        h for h in history2 if h["question"] in ["What is TechCorp?", "Tell me more about it"]
+    ]
+    assert len(our_questions) == 2, (
+        f"Expected exactly 2 Q&A pairs in history after 2 queries, got {len(our_questions)}"
+    )
+
+    # Same retriever, different session id.
+    session_id_2 = "test_session_separate"
+
+    result3 = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="What is DataCo?",
+        session_id=session_id_2,
+    )
+
+    assert isinstance(result3, list) and len(result3) > 0, (
+        f"Different session should return non-empty list, got: {result3!r}"
+    )
+
+    history3 = await cache_engine.get_latest_qa(str(user.id), session_id_2, last_n=10)
+    our_qa_session2 = [h for h in history3 if h["question"] == "What is DataCo?"]
+    assert len(our_qa_session2) == 1, "Session 2 should have 'What is DataCo?' question"
+
+    # session_id=None: the Q&A is looked up under the "default_session" key.
+    result4 = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="Test default session",
+        session_id=None,
+    )
+
+    assert isinstance(result4, list) and len(result4) > 0, (
+        f"Default session should return non-empty list, got: {result4!r}"
+    )
+
+    history_default = await cache_engine.get_latest_qa(str(user.id), "default_session", last_n=10)
+    our_qa_default = [h for h in history_default if h["question"] == "Test default session"]
+    assert len(our_qa_default) == 1, "Should find 'Test default session' in default_session"
+
+    # RAG_COMPLETION
+    session_id_rag = "test_session_rag"
+
+    result_rag = await cognee.search(
+        query_type=SearchType.RAG_COMPLETION,
+        query_text="What companies are mentioned?",
+        session_id=session_id_rag,
+    )
+
+    assert isinstance(result_rag, list) and len(result_rag) > 0, (
+        f"RAG_COMPLETION should return non-empty list, got: {result_rag!r}"
+    )
+
+    history_rag = await cache_engine.get_latest_qa(str(user.id), session_id_rag, last_n=10)
+    our_qa_rag = [h for h in history_rag if h["question"] == "What companies are mentioned?"]
+    assert len(our_qa_rag) == 1, "Should find RAG question in history"
+
+    # GRAPH_COMPLETION_COT
+    session_id_cot = "test_session_cot"
+
+    result_cot = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION_COT,
+        query_text="What do you know about TechCorp?",
+        session_id=session_id_cot,
+    )
+
+    assert isinstance(result_cot, list) and len(result_cot) > 0, (
+        f"GRAPH_COMPLETION_COT should return non-empty list, got: {result_cot!r}"
+    )
+
+    history_cot = await cache_engine.get_latest_qa(str(user.id), session_id_cot, last_n=10)
+    our_qa_cot = [h for h in history_cot if h["question"] == "What do you know about TechCorp?"]
+    assert len(our_qa_cot) == 1, "Should find CoT question in history"
+
+    # GRAPH_COMPLETION_CONTEXT_EXTENSION
+    session_id_ext = "test_session_ext"
+
+    result_ext = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
+        query_text="Tell me about DataCo",
+        session_id=session_id_ext,
+    )
+
+    assert isinstance(result_ext, list) and len(result_ext) > 0, (
+        f"GRAPH_COMPLETION_CONTEXT_EXTENSION should return non-empty list, got: {result_ext!r}"
+    )
+
+    history_ext = await cache_engine.get_latest_qa(str(user.id), session_id_ext, last_n=10)
+    our_qa_ext = [h for h in history_ext if h["question"] == "Tell me about DataCo"]
+    assert len(our_qa_ext) == 1, "Should find Context Extension question in history"
+
+    # GRAPH_SUMMARY_COMPLETION
+    session_id_summary = "test_session_summary"
+
+    result_summary = await cognee.search(
+        query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
+        query_text="What are the key points about TechCorp?",
+        session_id=session_id_summary,
+    )
+
+    assert isinstance(result_summary, list) and len(result_summary) > 0, (
+        f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}"
+    )
+
+    # Verify saved
+    history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10)
+    our_qa_summary = [
+        h for h in history_summary if h["question"] == "What are the key points about TechCorp?"
+    ]
+    assert len(our_qa_summary) == 1, "Should find Summary question in history"
+
+    # TEMPORAL
+    session_id_temporal = "test_session_temporal"
+
+    result_temporal = await cognee.search(
+        query_type=SearchType.TEMPORAL,
+        query_text="Tell me about the companies",
+        session_id=session_id_temporal,
+    )
+
+    assert isinstance(result_temporal, list) and len(result_temporal) > 0, (
+        f"TEMPORAL should return non-empty list, got: {result_temporal!r}"
+    )
+
+    history_temporal = await cache_engine.get_latest_qa(
+        str(user.id), session_id_temporal, last_n=10
+    )
+    our_qa_temporal = [
+        h for h in history_temporal if h["question"] == "Tell me about the companies"
+    ]
+    assert len(our_qa_temporal) == 1, "Should find Temporal question in history"
+
+    # The formatted history for session_id_1 must carry the expected labels.
+    from cognee.modules.retrieval.utils.session_cache import (
+        get_conversation_history,
+    )
+
+    formatted_history = await get_conversation_history(session_id=session_id_1)
+
+    assert "Previous conversation:" in formatted_history, (
+        "Formatted history should contain 'Previous conversation:' header"
+    )
+    assert "QUESTION:" in formatted_history, "Formatted history should contain 'QUESTION:' prefix"
+    assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix"
+    assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix"
+
+    # Prune again so the run leaves no test data behind.
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    logger.info("All conversation history tests passed successfully")
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py
index ca54c3f16..efbafa2a4 100644
--- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py
+++ b/cognee/tests/unit/modules/retrieval/conversation_history_test.py
@@ -53,7 +53,7 @@ class TestConversationHistoryUtils:
         ]
         mock_cache = create_mock_cache_engine(mock_history)
 
-        # ✅ Import the real module to patch safely
+        # Import the real module to patch safely
         cache_module = importlib.import_module(
             "cognee.infrastructure.databases.cache.get_cache_engine"
         )