Merge branch 'dev' into multi-tenancy

This commit is contained in:
Igor Ilic 2025-11-05 12:17:49 +01:00 committed by GitHub
commit 6a7d8ba106
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 592 additions and 1 deletion

View file

@ -0,0 +1,55 @@
from typing import Optional, List
from cognee import memify
from cognee.context_global_variables import (
set_database_global_context_variables,
set_session_user_context_variable,
)
from cognee.exceptions import CogneeValidationError
from cognee.modules.data.methods import get_authorized_existing_datasets
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.models import User
from cognee.tasks.memify import extract_user_sessions, cognify_session
logger = get_logger("persist_sessions_in_knowledge_graph")
async def persist_sessions_in_knowledge_graph_pipeline(
    user: User,
    session_ids: Optional[List[str]] = None,
    dataset: str = "main_dataset",
    run_in_background: bool = False,
):
    """Persist cached Q&A sessions for *user* into the knowledge graph.

    Verifies that the user has write access to *dataset*, binds the database
    context to that dataset, then runs a memify pipeline that extracts the
    cached sessions and cognifies them.

    Args:
        user: Authenticated user whose sessions are persisted.
        session_ids: Optional list of session IDs to persist.
        dataset: Name of the target dataset (defaults to "main_dataset").
        run_in_background: Whether memify should run without blocking.

    Raises:
        CogneeValidationError: If the user lacks write access to the dataset.

    Returns:
        The result object produced by the memify pipeline.
    """
    await set_session_user_context_variable(user)

    authorized = await get_authorized_existing_datasets(
        user=user, datasets=[dataset], permission_type="write"
    )
    if not authorized:
        raise CogneeValidationError(
            message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}",
            log=False,
        )

    target = authorized[0]
    await set_database_global_context_variables(target.id, target.owner_id)

    pipeline_result = await memify(
        extraction_tasks=[Task(extract_user_sessions, session_ids=session_ids)],
        enrichment_tasks=[Task(cognify_session)],
        dataset=target.id,
        data=[{}],
        run_in_background=run_in_background,
    )

    logger.info("Session persistence pipeline completed")
    return pipeline_result

View file

@ -1,2 +1,4 @@
from .extract_subgraph import extract_subgraph
from .extract_subgraph_chunks import extract_subgraph_chunks
from .cognify_session import cognify_session
from .extract_user_sessions import extract_user_sessions

View file

@ -0,0 +1,40 @@
import cognee
from cognee.exceptions import CogneeValidationError, CogneeSystemError
from cognee.shared.logging_utils import get_logger
logger = get_logger("cognify_session")
async def cognify_session(data):
    """
    Process and cognify session data into the knowledge graph.

    Adds session content to cognee under the "user_sessions_from_cache"
    node set, then triggers the cognify pipeline to extract entities and
    relationships from the session data.

    Args:
        data: Session string containing Question, Context, and Answer information.

    Raises:
        CogneeValidationError: If data is None or empty.
        CogneeSystemError: If cognee operations fail.
    """
    try:
        # Reject None, non-string falsy values, and whitespace-only strings.
        if not data or (isinstance(data, str) and not data.strip()):
            logger.warning("Empty session data provided to cognify_session task, skipping")
            raise CogneeValidationError(message="Session data cannot be empty", log=False)

        logger.info("Processing session data for cognification")

        await cognee.add(data, node_set=["user_sessions_from_cache"])
        # Log message kept in sync with the node_set actually passed above.
        logger.debug("Session data added to cognee with node_set: user_sessions_from_cache")

        await cognee.cognify()
        logger.info("Session data successfully cognified")
    except CogneeValidationError:
        # Validation failures are deliberate and must propagate unwrapped.
        raise
    except Exception as e:
        logger.error(f"Error cognifying session data: {str(e)}")
        raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False)

View file

@ -0,0 +1,73 @@
from typing import Optional, List
from cognee.context_global_variables import session_user
from cognee.exceptions import CogneeSystemError
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
logger = get_logger("extract_user_sessions")
async def extract_user_sessions(
    data,
    session_ids: Optional[List[str]] = None,
):
    """
    Extract Q&A sessions for the current user from cache.

    Retrieves all Q&A triplets from the specified session IDs and yields one
    formatted string per session containing its question/answer pairs.
    Note: the "context" field of each triplet is NOT included in the output;
    only question and answer are formatted.

    Args:
        data: Data passed from memify. If empty dict ({}), no external data is provided.
        session_ids: Optional list of specific session IDs to extract.

    Yields:
        String containing session ID and all Q&A pairs formatted.

    Raises:
        CogneeSystemError: If cache engine is unavailable or extraction fails.
    """
    try:
        # Only extract when memify passed no real upstream data.
        # NOTE(review): non-empty `data` silently yields nothing — confirm intended.
        if not data or data == [{}]:
            logger.info("Fetching session metadata for current user")
            # User is read from the session context set by the pipeline entry point.
            user: User = session_user.get()
            if not user:
                raise CogneeSystemError(message="No authenticated user found in context", log=False)
            user_id = str(user.id)
            cache_engine = get_cache_engine()
            if cache_engine is None:
                raise CogneeSystemError(
                    message="Cache engine not available for session extraction, please enable caching in order to have sessions to save",
                    log=False,
                )
            if session_ids:
                for session_id in session_ids:
                    try:
                        qa_data = await cache_engine.get_all_qas(user_id, session_id)
                        # Sessions with no cached Q&A pairs are skipped silently.
                        if qa_data:
                            logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs")
                            session_string = f"Session ID: {session_id}\n\n"
                            for qa_pair in qa_data:
                                question = qa_pair.get("question", "")
                                answer = qa_pair.get("answer", "")
                                session_string += f"Question: {question}\n\nAnswer: {answer}\n\n"
                            yield session_string
                    except Exception as e:
                        # A failure for one session must not abort the others.
                        logger.warning(f"Failed to extract session {session_id}: {str(e)}")
                        continue
            else:
                logger.info(
                    "No specific session_ids provided. Please specify which sessions to extract."
                )
    except CogneeSystemError:
        raise
    except Exception as e:
        logger.error(f"Error extracting user sessions: {str(e)}")
        raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False)

View file

@ -16,9 +16,11 @@ import cognee
import pathlib
from cognee.infrastructure.databases.cache import get_cache_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from collections import Counter
logger = get_logger()
@ -188,7 +190,6 @@ async def main():
f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}"
)
# Verify saved
history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10)
our_qa_summary = [
h for h in history_summary if h["question"] == "What are the key points about TechCorp?"
@ -228,6 +229,46 @@ async def main():
assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix"
assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix"
from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import (
persist_sessions_in_knowledge_graph_pipeline,
)
logger.info("Starting persist_sessions_in_knowledge_graph tests")
await persist_sessions_in_knowledge_graph_pipeline(
user=user,
session_ids=[session_id_1, session_id_2],
dataset=dataset_name,
run_in_background=False,
)
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
"Tests the correct number of NodeSet nodes after session persistence"
assert type_counts.get("NodeSet", 0) == 1, (
f"Number of NodeSets in the graph is incorrect, found {type_counts.get('NodeSet', 0)} but there should be exactly 1."
)
"Tests the correct number of DocumentChunk nodes after session persistence"
assert type_counts.get("DocumentChunk", 0) == 4, (
f"Number of DocumentChunk ndoes in the graph is incorrect, found {type_counts.get('DocumentChunk', 0)} but there should be exactly 4 (2 original documents, 2 sessions)."
)
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
vector_engine = get_vector_engine()
collection_size = await vector_engine.search(
collection_name="DocumentChunk_text",
query_text="test",
limit=1000,
)
assert len(collection_size) == 4, (
f"DocumentChunk_text collection should have exactly 4 embeddings, found {len(collection_size)}"
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -0,0 +1,107 @@
import pytest
from unittest.mock import AsyncMock, patch
from cognee.tasks.memify.cognify_session import cognify_session
from cognee.exceptions import CogneeValidationError, CogneeSystemError
@pytest.mark.asyncio
async def test_cognify_session_success():
    """Valid session text is added verbatim under the cache node set and cognified."""
    payload = (
        "Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence"
    )

    with (
        patch("cognee.add", new_callable=AsyncMock) as add_mock,
        patch("cognee.cognify", new_callable=AsyncMock) as cognify_mock,
    ):
        await cognify_session(payload)

        add_mock.assert_called_once_with(payload, node_set=["user_sessions_from_cache"])
        cognify_mock.assert_called_once()
@pytest.mark.asyncio
async def test_cognify_session_empty_string():
    """An empty string is rejected with a validation error."""
    with pytest.raises(CogneeValidationError) as raised:
        await cognify_session("")

    assert "Session data cannot be empty" in str(raised.value)
@pytest.mark.asyncio
async def test_cognify_session_whitespace_string():
    """Whitespace-only input counts as empty and is rejected."""
    with pytest.raises(CogneeValidationError) as raised:
        await cognify_session(" \n\t ")

    assert "Session data cannot be empty" in str(raised.value)
@pytest.mark.asyncio
async def test_cognify_session_none_data():
    """None input counts as empty and is rejected."""
    with pytest.raises(CogneeValidationError) as raised:
        await cognify_session(None)

    assert "Session data cannot be empty" in str(raised.value)
@pytest.mark.asyncio
async def test_cognify_session_add_failure():
    """A failure in cognee.add is wrapped into CogneeSystemError."""
    payload = "Session ID: test\n\nQuestion: test?"

    with (
        patch("cognee.add", new_callable=AsyncMock) as add_mock,
        patch("cognee.cognify", new_callable=AsyncMock),
    ):
        add_mock.side_effect = Exception("Add operation failed")

        with pytest.raises(CogneeSystemError) as raised:
            await cognify_session(payload)

        message = str(raised.value)
        assert "Failed to cognify session data" in message
        assert "Add operation failed" in message
@pytest.mark.asyncio
async def test_cognify_session_cognify_failure():
    """A failure in cognee.cognify is wrapped into CogneeSystemError."""
    payload = "Session ID: test\n\nQuestion: test?"

    with (
        patch("cognee.add", new_callable=AsyncMock),
        patch("cognee.cognify", new_callable=AsyncMock) as cognify_mock,
    ):
        cognify_mock.side_effect = Exception("Cognify operation failed")

        with pytest.raises(CogneeSystemError) as raised:
            await cognify_session(payload)

        message = str(raised.value)
        assert "Failed to cognify session data" in message
        assert "Cognify operation failed" in message
@pytest.mark.asyncio
async def test_cognify_session_re_raises_validation_error():
    """CogneeValidationError must propagate to the caller unwrapped."""
    empty_payload = ""
    with pytest.raises(CogneeValidationError):
        await cognify_session(empty_payload)
@pytest.mark.asyncio
async def test_cognify_session_with_special_characters():
    """Non-ASCII and punctuation-heavy session text is passed through unchanged."""
    payload = "Session: test™ © Question: What's special? Answer: Cognee is special!"

    with (
        patch("cognee.add", new_callable=AsyncMock) as add_mock,
        patch("cognee.cognify", new_callable=AsyncMock) as cognify_mock,
    ):
        await cognify_session(payload)

        add_mock.assert_called_once_with(payload, node_set=["user_sessions_from_cache"])
        cognify_mock.assert_called_once()

View file

@ -0,0 +1,175 @@
import sys
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from cognee.tasks.memify.extract_user_sessions import extract_user_sessions
from cognee.exceptions import CogneeSystemError
from cognee.modules.users.models import User
# Get the actual module object (not the function) for patching:
# `from ... import extract_user_sessions` binds the function name locally,
# so patch.object must target the defining module's globals instead.
extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"]
@pytest.fixture
def mock_user():
    """Provide a stand-in User with a fixed id for context patching."""
    fake_user = MagicMock(spec=User)
    fake_user.id = "test-user-123"
    return fake_user
@pytest.fixture
def mock_qa_data():
    """Provide two cached Q&A triplets in the cache engine's dict format."""
    rows = [
        (
            "What is cognee?",
            "context about cognee",
            "Cognee is a knowledge graph solution",
            "2025-01-01T12:00:00",
        ),
        (
            "How does it work?",
            "how it works context",
            "It processes data and creates graphs",
            "2025-01-01T12:05:00",
        ),
    ]
    return [
        {"question": q, "context": c, "answer": a, "time": t}
        for q, c, a, t in rows
    ]
@pytest.mark.asyncio
async def test_extract_user_sessions_success(mock_user, mock_qa_data):
    """One session id yields one formatted string with all Q&A pairs."""
    cache_stub = AsyncMock()
    cache_stub.get_all_qas.return_value = mock_qa_data

    with (
        patch.object(extract_user_sessions_module, "session_user") as session_user_stub,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_stub
        ),
    ):
        session_user_stub.get.return_value = mock_user

        collected = [
            chunk async for chunk in extract_user_sessions([{}], session_ids=["test_session"])
        ]

        assert len(collected) == 1
        (session_text,) = collected
        assert "Session ID: test_session" in session_text
        assert "Question: What is cognee?" in session_text
        assert "Answer: Cognee is a knowledge graph solution" in session_text
        assert "Question: How does it work?" in session_text
        assert "Answer: It processes data and creates graphs" in session_text
@pytest.mark.asyncio
async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data):
    """Each requested session id triggers one cache lookup and one yield."""
    cache_stub = AsyncMock()
    cache_stub.get_all_qas.return_value = mock_qa_data

    with (
        patch.object(extract_user_sessions_module, "session_user") as session_user_stub,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_stub
        ),
    ):
        session_user_stub.get.return_value = mock_user

        collected = [
            chunk
            async for chunk in extract_user_sessions([{}], session_ids=["session1", "session2"])
        ]

        assert len(collected) == 2
        assert cache_stub.get_all_qas.call_count == 2
@pytest.mark.asyncio
async def test_extract_user_sessions_no_data(mock_user, mock_qa_data):
    """`data=None` is treated the same as the empty memify payload."""
    cache_stub = AsyncMock()
    cache_stub.get_all_qas.return_value = mock_qa_data

    with (
        patch.object(extract_user_sessions_module, "session_user") as session_user_stub,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_stub
        ),
    ):
        session_user_stub.get.return_value = mock_user

        collected = [
            chunk async for chunk in extract_user_sessions(None, session_ids=["test_session"])
        ]

        assert len(collected) == 1
@pytest.mark.asyncio
async def test_extract_user_sessions_no_session_ids(mock_user):
    """Without session ids nothing is yielded and the cache is never queried."""
    cache_stub = AsyncMock()

    with (
        patch.object(extract_user_sessions_module, "session_user") as session_user_stub,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_stub
        ),
    ):
        session_user_stub.get.return_value = mock_user

        collected = [chunk async for chunk in extract_user_sessions([{}], session_ids=None)]

        assert collected == []
        cache_stub.get_all_qas.assert_not_called()
@pytest.mark.asyncio
async def test_extract_user_sessions_empty_qa_data(mock_user):
    """Sessions whose cache holds no Q&A pairs are skipped silently."""
    cache_stub = AsyncMock()
    cache_stub.get_all_qas.return_value = []

    with (
        patch.object(extract_user_sessions_module, "session_user") as session_user_stub,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_stub
        ),
    ):
        session_user_stub.get.return_value = mock_user

        collected = [
            chunk async for chunk in extract_user_sessions([{}], session_ids=["empty_session"])
        ]

        assert collected == []
@pytest.mark.asyncio
async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data):
    """A cache failure for one session does not abort the remaining sessions."""
    cache_stub = AsyncMock()
    # Second lookup fails; first and third succeed.
    cache_stub.get_all_qas.side_effect = [
        mock_qa_data,
        Exception("Cache error"),
        mock_qa_data,
    ]

    with (
        patch.object(extract_user_sessions_module, "session_user") as session_user_stub,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_stub
        ),
    ):
        session_user_stub.get.return_value = mock_user

        collected = [
            chunk
            async for chunk in extract_user_sessions(
                [{}], session_ids=["session1", "session2", "session3"]
            )
        ]

        assert len(collected) == 2

View file

@ -0,0 +1,98 @@
import asyncio
import cognee
from cognee import visualize_graph
from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import (
persist_sessions_in_knowledge_graph_pipeline,
)
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
logger = get_logger("conversation_session_persistence_example")
async def main():
    """Demo: chat across two search sessions, then persist both into the graph."""
    # NOTE: CACHING has to be enabled for this example to work
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    text_1 = "Cognee is a solution that can build knowledge graph from text, creating an AI memory system"
    text_2 = "Germany is a country located next to the Netherlands"

    await cognee.add([text_1, text_2])
    await cognee.cognify()

    # (question, session_id) pairs; None means the implicit default session.
    conversation = [
        ("What can I use to create a knowledge graph?", None),
        ("You sure about that?", None),
        ("This is awesome!", None),
        ("Where is Germany?", "different_session"),
        ("Next to which country again?", "different_session"),
        ("So you remember everything I asked from you?", "different_session"),
    ]

    for question, session_id in conversation:
        if session_id is None:
            search_results = await cognee.search(
                query_type=SearchType.GRAPH_COMPLETION,
                query_text=question,
            )
            label = "default_session"
        else:
            search_results = await cognee.search(
                query_type=SearchType.GRAPH_COMPLETION,
                query_text=question,
                session_id=session_id,
            )
            label = session_id

        print(f"\nSession ID: {label}")
        print(f"Question: {question}")
        print(f"Answer: {search_results}\n")

    # Persist both cached conversations into the knowledge graph.
    session_ids_to_persist = ["default_session", "different_session"]
    default_user = await get_default_user()

    await persist_sessions_in_knowledge_graph_pipeline(
        user=default_user,
        session_ids=session_ids_to_persist,
    )

    await visualize_graph()
if __name__ == "__main__":
    # asyncio.run() creates the loop, runs main(), shuts down async
    # generators, and closes the loop — the manual loop variant never
    # called loop.close(), leaking the event loop on exit.
    asyncio.run(main())