diff --git a/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py new file mode 100644 index 000000000..c0ba0a4d9 --- /dev/null +++ b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py @@ -0,0 +1,55 @@ +from typing import Optional, List + +from cognee import memify +from cognee.context_global_variables import ( + set_database_global_context_variables, + set_session_user_context_variable, +) +from cognee.exceptions import CogneeValidationError +from cognee.modules.data.methods import get_authorized_existing_datasets +from cognee.shared.logging_utils import get_logger +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.users.models import User +from cognee.tasks.memify import extract_user_sessions, cognify_session + + +logger = get_logger("persist_sessions_in_knowledge_graph") + + +async def persist_sessions_in_knowledge_graph_pipeline( + user: User, + session_ids: Optional[List[str]] = None, + dataset: str = "main_dataset", + run_in_background: bool = False, +): + await set_session_user_context_variable(user) + dataset_to_write = await get_authorized_existing_datasets( + user=user, datasets=[dataset], permission_type="write" + ) + + if not dataset_to_write: + raise CogneeValidationError( + message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}", + log=False, + ) + + await set_database_global_context_variables( + dataset_to_write[0].id, dataset_to_write[0].owner_id + ) + + extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)] + + enrichment_tasks = [ + Task(cognify_session), + ] + + result = await memify( + extraction_tasks=extraction_tasks, + enrichment_tasks=enrichment_tasks, + dataset=dataset_to_write[0].id, + data=[{}], + run_in_background=run_in_background, + ) + + logger.info("Session persistence pipeline completed") + return result diff --git a/cognee/tasks/memify/__init__.py b/cognee/tasks/memify/__init__.py index 692bac443..7e590ed47 100644 --- a/cognee/tasks/memify/__init__.py +++ b/cognee/tasks/memify/__init__.py @@ -1,2 +1,4 @@ from .extract_subgraph import extract_subgraph from .extract_subgraph_chunks import extract_subgraph_chunks +from .cognify_session import cognify_session +from .extract_user_sessions import extract_user_sessions diff --git a/cognee/tasks/memify/cognify_session.py b/cognee/tasks/memify/cognify_session.py new file mode 100644 index 000000000..7c276169a --- /dev/null +++ b/cognee/tasks/memify/cognify_session.py @@ -0,0 +1,40 @@ +import cognee + +from cognee.exceptions import CogneeValidationError, CogneeSystemError +from cognee.shared.logging_utils import get_logger + +logger = get_logger("cognify_session") + + +async def cognify_session(data): + """ + Process and cognify session data into the knowledge graph. + + Adds session content to cognee with a dedicated "user_sessions" node set, + then triggers the cognify pipeline to extract entities and relationships + from the session data. + + Args: + data: Session string containing Question, Context, and Answer information. + + Raises: + CogneeValidationError: If data is None or empty. + CogneeSystemError: If cognee operations fail. + """ + try: + if not data or (isinstance(data, str) and not data.strip()): + logger.warning("Empty session data provided to cognify_session task, skipping") + raise CogneeValidationError(message="Session data cannot be empty", log=False) + + logger.info("Processing session data for cognification") + + await cognee.add(data, node_set=["user_sessions_from_cache"]) + logger.debug("Session data added to cognee with node_set: user_sessions") + await cognee.cognify() + logger.info("Session data successfully cognified") + + except CogneeValidationError: + raise + except Exception as e: + logger.error(f"Error cognifying session data: {str(e)}") + raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False) diff --git a/cognee/tasks/memify/extract_user_sessions.py b/cognee/tasks/memify/extract_user_sessions.py new file mode 100644 index 000000000..9779a363e --- /dev/null +++ b/cognee/tasks/memify/extract_user_sessions.py @@ -0,0 +1,73 @@ +from typing import Optional, List + +from cognee.context_global_variables import session_user +from cognee.exceptions import CogneeSystemError +from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine +from cognee.shared.logging_utils import get_logger +from cognee.modules.users.models import User + +logger = get_logger("extract_user_sessions") + + +async def extract_user_sessions( + data, + session_ids: Optional[List[str]] = None, +): + """ + Extract Q&A sessions for the current user from cache. + + Retrieves all Q&A triplets from specified session IDs and yields them + as formatted strings combining question, context, and answer. + + Args: + data: Data passed from memify. If empty dict ({}), no external data is provided. + session_ids: Optional list of specific session IDs to extract. + + Yields: + String containing session ID and all Q&A pairs formatted. + + Raises: + CogneeSystemError: If cache engine is unavailable or extraction fails. + """ + try: + if not data or data == [{}]: + logger.info("Fetching session metadata for current user") + + user: User = session_user.get() + if not user: + raise CogneeSystemError(message="No authenticated user found in context", log=False) + + user_id = str(user.id) + + cache_engine = get_cache_engine() + if cache_engine is None: + raise CogneeSystemError( + message="Cache engine not available for session extraction, please enable caching in order to have sessions to save", + log=False, + ) + + if session_ids: + for session_id in session_ids: + try: + qa_data = await cache_engine.get_all_qas(user_id, session_id) + if qa_data: + logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs") + session_string = f"Session ID: {session_id}\n\n" + for qa_pair in qa_data: + question = qa_pair.get("question", "") + answer = qa_pair.get("answer", "") + session_string += f"Question: {question}\n\nAnswer: {answer}\n\n" + yield session_string + except Exception as e: + logger.warning(f"Failed to extract session {session_id}: {str(e)}") + continue + else: + logger.info( + "No specific session_ids provided. Please specify which sessions to extract." + ) + + except CogneeSystemError: + raise + except Exception as e: + logger.error(f"Error extracting user sessions: {str(e)}") + raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False) diff --git a/cognee/tests/test_conversation_history.py b/cognee/tests/test_conversation_history.py index 30bb54ef1..6b5b737f1 100644 --- a/cognee/tests/test_conversation_history.py +++ b/cognee/tests/test_conversation_history.py @@ -16,9 +16,11 @@ import cognee import pathlib from cognee.infrastructure.databases.cache import get_cache_engine +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.search.types import SearchType from cognee.shared.logging_utils import get_logger from cognee.modules.users.methods import get_default_user +from collections import Counter logger = get_logger() @@ -188,7 +190,6 @@ async def main(): f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}" ) - # Verify saved history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10) our_qa_summary = [ h for h in history_summary if h["question"] == "What are the key points about TechCorp?" @@ -228,6 +229,46 @@ async def main(): assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix" assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix" + from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import ( + persist_sessions_in_knowledge_graph_pipeline, + ) + + logger.info("Starting persist_sessions_in_knowledge_graph tests") + + await persist_sessions_in_knowledge_graph_pipeline( + user=user, + session_ids=[session_id_1, session_id_2], + dataset=dataset_name, + run_in_background=False, + ) + + graph_engine = await get_graph_engine() + graph = await graph_engine.get_graph_data() + + type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) + + "Tests the correct number of NodeSet nodes after session persistence" + assert type_counts.get("NodeSet", 0) == 1, ( + f"Number of NodeSets in the graph is incorrect, found {type_counts.get('NodeSet', 0)} but there should be exactly 1." + ) + + "Tests the correct number of DocumentChunk nodes after session persistence" + assert type_counts.get("DocumentChunk", 0) == 4, ( + f"Number of DocumentChunk ndoes in the graph is incorrect, found {type_counts.get('DocumentChunk', 0)} but there should be exactly 4 (2 original documents, 2 sessions)." + ) + + from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine + + vector_engine = get_vector_engine() + collection_size = await vector_engine.search( + collection_name="DocumentChunk_text", + query_text="test", + limit=1000, + ) + assert len(collection_size) == 4, ( + f"DocumentChunk_text collection should have exactly 4 embeddings, found {len(collection_size)}" + ) + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py new file mode 100644 index 000000000..c23640fbd --- /dev/null +++ b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py @@ -0,0 +1,107 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.tasks.memify.cognify_session import cognify_session +from cognee.exceptions import CogneeValidationError, CogneeSystemError + + +@pytest.mark.asyncio +async def test_cognify_session_success(): + """Test successful cognification of session data.""" + session_data = ( + "Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence" + ) + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + await cognify_session(session_data) + + mock_add.assert_called_once_with(session_data, node_set=["user_sessions_from_cache"]) + mock_cognify.assert_called_once() + + +@pytest.mark.asyncio +async def test_cognify_session_empty_string(): + """Test cognification fails with empty string.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session("") + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_whitespace_string(): + """Test cognification fails with whitespace-only string.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session(" \n\t ") + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_none_data(): + """Test cognification fails with None data.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session(None) + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_add_failure(): + """Test cognification handles cognee.add failure.""" + session_data = "Session ID: test\n\nQuestion: test?" + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock), + ): + mock_add.side_effect = Exception("Add operation failed") + + with pytest.raises(CogneeSystemError) as exc_info: + await cognify_session(session_data) + + assert "Failed to cognify session data" in str(exc_info.value) + assert "Add operation failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_cognify_failure(): + """Test cognification handles cognify failure.""" + session_data = "Session ID: test\n\nQuestion: test?" + + with ( + patch("cognee.add", new_callable=AsyncMock), + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + mock_cognify.side_effect = Exception("Cognify operation failed") + + with pytest.raises(CogneeSystemError) as exc_info: + await cognify_session(session_data) + + assert "Failed to cognify session data" in str(exc_info.value) + assert "Cognify operation failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_re_raises_validation_error(): + """Test that CogneeValidationError is re-raised as-is.""" + with pytest.raises(CogneeValidationError): + await cognify_session("") + + +@pytest.mark.asyncio +async def test_cognify_session_with_special_characters(): + """Test cognification with special characters.""" + session_data = "Session: test™ © Question: What's special? Answer: Cognee is special!" + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + await cognify_session(session_data) + + mock_add.assert_called_once_with(session_data, node_set=["user_sessions_from_cache"]) + mock_cognify.assert_called_once() diff --git a/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py new file mode 100644 index 000000000..8cb27fef3 --- /dev/null +++ b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py @@ -0,0 +1,175 @@ +import sys +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from cognee.tasks.memify.extract_user_sessions import extract_user_sessions +from cognee.exceptions import CogneeSystemError +from cognee.modules.users.models import User + +# Get the actual module object (not the function) for patching +extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"] + + +@pytest.fixture +def mock_user(): + """Create a mock user.""" + user = MagicMock(spec=User) + user.id = "test-user-123" + return user + + +@pytest.fixture +def mock_qa_data(): + """Create mock Q&A data.""" + return [ + { + "question": "What is cognee?", + "context": "context about cognee", + "answer": "Cognee is a knowledge graph solution", + "time": "2025-01-01T12:00:00", + }, + { + "question": "How does it work?", + "context": "how it works context", + "answer": "It processes data and creates graphs", + "time": "2025-01-01T12:05:00", + }, + ] + + +@pytest.mark.asyncio +async def test_extract_user_sessions_success(mock_user, mock_qa_data): + """Test successful extraction of sessions.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["test_session"]): + sessions.append(session) + + assert len(sessions) == 1 + assert "Session ID: test_session" in sessions[0] + assert "Question: What is cognee?" in sessions[0] + assert "Answer: Cognee is a knowledge graph solution" in sessions[0] + assert "Question: How does it work?" in sessions[0] + assert "Answer: It processes data and creates graphs" in sessions[0] + + +@pytest.mark.asyncio +async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data): + """Test extraction of multiple sessions.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["session1", "session2"]): + sessions.append(session) + + assert len(sessions) == 2 + assert mock_cache_engine.get_all_qas.call_count == 2 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_no_data(mock_user, mock_qa_data): + """Test extraction handles empty data parameter.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions(None, session_ids=["test_session"]): + sessions.append(session) + + assert len(sessions) == 1 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_no_session_ids(mock_user): + """Test extraction handles no session IDs provided.""" + mock_cache_engine = AsyncMock() + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=None): + sessions.append(session) + + assert len(sessions) == 0 + mock_cache_engine.get_all_qas.assert_not_called() + + +@pytest.mark.asyncio +async def test_extract_user_sessions_empty_qa_data(mock_user): + """Test extraction handles empty Q&A data.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = [] + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["empty_session"]): + sessions.append(session) + + assert len(sessions) == 0 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data): + """Test extraction continues on cache error for specific session.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.side_effect = [ + mock_qa_data, + Exception("Cache error"), + mock_qa_data, + ] + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions( + [{}], session_ids=["session1", "session2", "session3"] + ): + sessions.append(session) + + assert len(sessions) == 2 diff --git a/examples/python/conversation_session_persistence_example.py b/examples/python/conversation_session_persistence_example.py new file mode 100644 index 000000000..5346f5012 --- /dev/null +++ b/examples/python/conversation_session_persistence_example.py @@ -0,0 +1,98 @@ +import asyncio + +import cognee +from cognee import visualize_graph +from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import ( + persist_sessions_in_knowledge_graph_pipeline, +) +from cognee.modules.search.types import SearchType +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger + +logger = get_logger("conversation_session_persistence_example") + + +async def main(): + # NOTE: CACHING has to be enabled for this example to work + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + text_1 = "Cognee is a solution that can build knowledge graph from text, creating an AI memory system" + text_2 = "Germany is a country located next to the Netherlands" + + await cognee.add([text_1, text_2]) + await cognee.cognify() + + question = "What can I use to create a knowledge graph?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "You sure about that?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=question + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "This is awesome!" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=question + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "Where is Germany?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "Next to which country again?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "So you remember everything I asked from you?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + session_ids_to_persist = ["default_session", "different_session"] + default_user = await get_default_user() + + await persist_sessions_in_knowledge_graph_pipeline( + user=default_user, + session_ids=session_ids_to_persist, + ) + + await visualize_graph() + + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens())