Merge branch 'dev' into multi-tenancy

This commit is contained in:
Igor Ilic 2025-11-06 18:55:18 +01:00 committed by GitHub
commit 5dbfea5084
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 899 additions and 103 deletions

View file

@ -169,8 +169,9 @@ REQUIRE_AUTHENTICATION=False
# Vector: LanceDB
# Graph: KuzuDB
#
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
ENABLE_BACKEND_ACCESS_CONTROL=False
# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
# Disable mode when using not supported graph/vector databases.
ENABLE_BACKEND_ACCESS_CONTROL=True
################################################################################
# ☁️ Cloud Sync Settings

View file

@ -447,3 +447,44 @@ jobs:
DB_USERNAME: cognee
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_conversation_history.py
test-load:
name: Test Load
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "aws"
- name: Set File Descriptor Limit
run: sudo prlimit --pid $$ --nofile=4096:4096
- name: Verify File Descriptor Limit
run: ulimit -n
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Load Test
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: True
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
STORAGE_BACKEND: s3
AWS_REGION: eu-west-1
AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
run: uv run python ./cognee/tests/test_load.py

View file

@ -84,6 +84,7 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
@ -135,6 +136,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'pgvector'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
@ -197,6 +199,7 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432

View file

@ -4,6 +4,8 @@ from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
async def set_session_user_context_variable(user):
session_user.set(user)
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
)
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
if backend_access_control is None:
# If backend access control is not defined in environment variables,
# enable it by default if graph and vector DBs can support it, otherwise disable it
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
return True
return False
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
"""
If backend access control is enabled this function will ensure all datasets have their own databases,
@ -40,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config()
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if not backend_access_control_enabled():
return
user = await get_user(user_id)

View file

@ -40,7 +40,7 @@ async def persist_sessions_in_knowledge_graph_pipeline(
extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)]
enrichment_tasks = [
Task(cognify_session),
Task(cognify_session, dataset_id=dataset_to_write[0].id),
]
result = await memify(

View file

@ -0,0 +1,124 @@
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk
logger = get_logger()
class TextChunkerWithOverlap(Chunker):
def __init__(
self,
document,
get_text: callable,
max_chunk_size: int,
chunk_overlap_ratio: float = 0.0,
get_chunk_data: callable = None,
):
super().__init__(document, get_text, max_chunk_size)
self._accumulated_chunk_data = []
self._accumulated_size = 0
self.chunk_overlap_ratio = chunk_overlap_ratio
self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
if get_chunk_data is not None:
self.get_chunk_data = get_chunk_data
elif chunk_overlap_ratio > 0:
paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
self.get_chunk_data = lambda text: chunk_by_paragraph(
text, paragraph_max_size, batch_paragraphs=True
)
else:
self.get_chunk_data = lambda text: chunk_by_paragraph(
text, self.max_chunk_size, batch_paragraphs=True
)
def _accumulation_overflows(self, chunk_data):
"""Check if adding chunk_data would exceed max_chunk_size."""
return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
def _accumulate_chunk_data(self, chunk_data):
"""Add chunk_data to the current accumulation."""
self._accumulated_chunk_data.append(chunk_data)
self._accumulated_size += chunk_data["chunk_size"]
def _clear_accumulation(self):
"""Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
if self.chunk_overlap == 0:
self._accumulated_chunk_data = []
self._accumulated_size = 0
return
# Keep chunk_data from the end that fit in overlap
overlap_chunk_data = []
overlap_size = 0
for chunk_data in reversed(self._accumulated_chunk_data):
if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
overlap_chunk_data.insert(0, chunk_data)
overlap_size += chunk_data["chunk_size"]
else:
break
self._accumulated_chunk_data = overlap_chunk_data
self._accumulated_size = overlap_size
def _create_chunk(self, text, size, cut_type, chunk_id=None):
"""Create a DocumentChunk with standard metadata."""
try:
return DocumentChunk(
id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text=text,
chunk_size=size,
is_part_of=self.document,
chunk_index=self.chunk_index,
cut_type=cut_type,
contains=[],
metadata={"index_fields": ["text"]},
)
except Exception as e:
logger.error(e)
raise e
def _create_chunk_from_accumulation(self):
"""Create a DocumentChunk from current accumulated chunk_data."""
chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
return self._create_chunk(
text=chunk_text,
size=self._accumulated_size,
cut_type=self._accumulated_chunk_data[-1]["cut_type"],
)
def _emit_chunk(self, chunk_data):
"""Emit a chunk when accumulation overflows."""
if len(self._accumulated_chunk_data) > 0:
chunk = self._create_chunk_from_accumulation()
self._clear_accumulation()
self._accumulate_chunk_data(chunk_data)
else:
# Handle single chunk_data exceeding max_chunk_size
chunk = self._create_chunk(
text=chunk_data["text"],
size=chunk_data["chunk_size"],
cut_type=chunk_data["cut_type"],
chunk_id=chunk_data["chunk_id"],
)
self.chunk_index += 1
return chunk
async def read(self):
async for content_text in self.get_text():
for chunk_data in self.get_chunk_data(content_text):
if not self._accumulation_overflows(chunk_data):
self._accumulate_chunk_data(chunk_data)
continue
yield self._emit_chunk(chunk_data)
if len(self._accumulated_chunk_data) == 0:
return
yield self._create_chunk_from_accumulation()

View file

@ -1,4 +1,3 @@
import os
import json
import asyncio
from uuid import UUID
@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +74,7 @@ async def search(
)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if backend_access_control_enabled():
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
@ -156,7 +156,7 @@ async def search(
)
else:
# This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)

View file

@ -5,6 +5,7 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import backend_access_control_enabled
logger = get_logger("get_authenticated_user")
@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
or backend_access_control_enabled()
)
fastapi_users = get_fastapi_users()

View file

@ -6,7 +6,7 @@ from cognee.shared.logging_utils import get_logger
logger = get_logger("cognify_session")
async def cognify_session(data):
async def cognify_session(data, dataset_id=None):
"""
Process and cognify session data into the knowledge graph.
@ -16,6 +16,7 @@ async def cognify_session(data):
Args:
data: Session string containing Question, Context, and Answer information.
dataset_name: Name of dataset.
Raises:
CogneeValidationError: If data is None or empty.
@ -28,9 +29,9 @@ async def cognify_session(data):
logger.info("Processing session data for cognification")
await cognee.add(data, node_set=["user_sessions_from_cache"])
await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"])
logger.debug("Session data added to cognee with node_set: user_sessions")
await cognee.cognify()
await cognee.cognify(datasets=[dataset_id])
logger.info("Session data successfully cognified")
except CogneeValidationError:

View file

@ -39,12 +39,12 @@ async def main():
answer = await cognee.search("Do programmers change light bulbs?")
assert len(answer) != 0
lowercase_answer = answer[0].lower()
lowercase_answer = answer[0]["search_result"][0].lower()
assert ("no" in lowercase_answer) or ("none" in lowercase_answer)
answer = await cognee.search("What colours are there in the presentation table?")
assert len(answer) != 0
lowercase_answer = answer[0].lower()
lowercase_answer = answer[0]["search_result"][0].lower()
assert (
("red" in lowercase_answer)
and ("blue" in lowercase_answer)

View file

@ -56,10 +56,10 @@ async def main():
"""DataCo is a data analytics company. They help businesses make sense of their data."""
)
await cognee.add(text_1, dataset_name)
await cognee.add(text_2, dataset_name)
await cognee.add(data=text_1, dataset_name=dataset_name)
await cognee.add(data=text_2, dataset_name=dataset_name)
await cognee.cognify([dataset_name])
await cognee.cognify(datasets=[dataset_name])
user = await get_default_user()

View file

@ -133,7 +133,7 @@ async def main():
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}],
dataset="feedback_enrichment_test_memify",
dataset=dataset_name,
)
nodes_after, edges_after = await graph_engine.get_graph_data()

View file

@ -90,15 +90,17 @@ async def main():
)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?"
query_type=SearchType.GRAPH_COMPLETION,
query_text="What information do you contain?",
dataset_ids=[pipeline_run_obj.dataset_id],
)
assert "Mark" in search_results[0], (
assert "Mark" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Mark in search results"
)
assert "Cindy" in search_results[0], (
assert "Cindy" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Cindy in search results"
)
assert "Artificial intelligence" not in search_results[0], (
assert "Artificial intelligence" not in search_results[0]["search_result"][0], (
"Failed to update document, Artificial intelligence still mentioned in search results"
)

62
cognee/tests/test_load.py Normal file
View file

@ -0,0 +1,62 @@
import os
import pathlib
import asyncio
import time
import cognee
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def process_and_search(num_of_searches):
start_time = time.time()
await cognee.cognify()
await asyncio.gather(
*[
cognee.search(
query_text="Tell me about the document", query_type=SearchType.GRAPH_COMPLETION
)
for _ in range(num_of_searches)
]
)
end_time = time.time()
return end_time - start_time
async def main():
data_directory_path = os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_load")
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_load")
cognee.config.system_root_directory(cognee_directory_path)
num_of_pdfs = 10
num_of_reps = 5
upper_boundary_minutes = 10
average_minutes = 8
recorded_times = []
for _ in range(num_of_reps):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
s3_input = "s3://cognee-test-load-s3-bucket"
await cognee.add(s3_input)
recorded_times.append(await process_and_search(num_of_pdfs))
average_recorded_time = sum(recorded_times) / len(recorded_times)
assert average_recorded_time <= average_minutes * 60
assert all(rec_time <= upper_boundary_minutes * 60 for rec_time in recorded_times)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -27,6 +27,9 @@ def normalize_node_name(node_name: str) -> str:
async def setup_test_db():
# Disable backend access control to migrate relational data
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -146,7 +146,13 @@ async def main():
assert len(search_results) == 1, (
f"{name}: expected single-element list, got {len(search_results)}"
)
text = search_results[0]
from cognee.context_global_variables import backend_access_control_enabled
if backend_access_control_enabled():
text = search_results[0]["search_result"][0]
else:
text = search_results[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (

View file

@ -1,3 +1,4 @@
import os
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from uuid import uuid4
@ -5,8 +6,6 @@ from fastapi.testclient import TestClient
from types import SimpleNamespace
import importlib
from cognee.api.client import app
# Fixtures for reuse across test classes
@pytest.fixture
@ -32,6 +31,10 @@ def mock_authenticated_user():
)
# To turn off authentication we need to set the environment variable before importing the module
# Also both require_authentication and backend access control must be false
os.environ["REQUIRE_AUTHENTICATION"] = "false"
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints:
@pytest.fixture
def client(self):
from cognee.api.client import app
"""Create a test client."""
return TestClient(app)
@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior:
@pytest.fixture
def client(self):
from cognee.api.client import app
return TestClient(app)
@pytest.mark.parametrize(
@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling:
@pytest.fixture
def client(self):
from cognee.api.client import app
return TestClient(app)
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling:
# The exact error message may vary depending on the actual database connection
# The important thing is that we get a 500 error when user creation fails
def test_current_environment_configuration(self):
def test_current_environment_configuration(self, client):
"""Test that current environment configuration is working properly."""
# This tests the actual module state without trying to change it
from cognee.modules.users.methods.get_authenticated_user import (

View file

@ -0,0 +1,248 @@
"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence."""
import pytest
from uuid import uuid4
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
from cognee.modules.data.processing.document_types import Document
@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"])
def chunker_class(request):
"""Parametrize tests to run against both implementations."""
return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap
@pytest.fixture
def make_text_generator():
"""Factory for async text generators."""
def _factory(*texts):
async def gen():
for text in texts:
yield text
return gen
return _factory
async def collect_chunks(chunker):
"""Consume async generator and return list of chunks."""
chunks = []
async for chunk in chunker.read():
chunks.append(chunk)
return chunks
@pytest.mark.asyncio
async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator):
"""Empty input should yield no chunks."""
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator("")
chunker = chunker_class(document, get_text, max_chunk_size=512)
chunks = await collect_chunks(chunker)
assert len(chunks) == 0, "Empty input should produce no chunks"
@pytest.mark.asyncio
async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator):
"""Whitespace-only input should produce exactly one chunk with unchanged text."""
whitespace_text = " \n\t \r\n "
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(whitespace_text)
chunker = chunker_class(document, get_text, max_chunk_size=512)
chunks = await collect_chunks(chunker)
assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk"
assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
@pytest.mark.asyncio
async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator):
"""Single paragraph below limit should emit exactly one chunk."""
text = "This is a short paragraph."
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=512)
chunks = await collect_chunks(chunker)
assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk"
assert chunks[0].text == text, "Chunk text should match input"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[0].chunk_size > 0, "Chunk should have positive size"
@pytest.mark.asyncio
async def test_oversized_paragraph_gets_emitted_as_a_single_chunk(
chunker_class, make_text_generator
):
"""Oversized paragraph from chunk_by_paragraph should be emitted as single chunk."""
text = ("A" * 1500) + ". Next sentence."
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=50)
chunks = await collect_chunks(chunker)
assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)"
assert chunks[0].chunk_size > 50, "First chunk should be oversized"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
@pytest.mark.asyncio
async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator):
"""First paragraph near limit plus small paragraph should produce two separate chunks."""
first_para = " ".join(["word"] * 5)
second_para = "Short text."
text = first_para + " " + second_para
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=10)
chunks = await collect_chunks(chunker)
assert len(chunks) == 2, "Should produce 2 chunks due to overflow"
assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph"
assert chunks[1].text.strip() == second_para, (
"Second chunk should contain only second paragraph"
)
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
@pytest.mark.asyncio
async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator):
"""Multiple small paragraphs should batch together with joiner spaces counted."""
paragraphs = [" ".join(["word"] * 12) for _ in range(40)]
text = " ".join(paragraphs)
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=49)
chunks = await collect_chunks(chunker)
assert len(chunks) == 20, (
"Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)"
)
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
"Chunk indices should be sequential"
)
all_text = " ".join(chunk.text.strip() for chunk in chunks)
expected_text = " ".join(paragraphs)
assert all_text == expected_text, "All paragraph text should be preserved with correct spacing"
@pytest.mark.asyncio
async def test_alternating_large_and_small_paragraphs_dont_batch(
chunker_class, make_text_generator
):
"""Alternating near-max and small paragraphs should each become separate chunks."""
large1 = "word" * 15 + "."
small1 = "Short."
large2 = "word" * 15 + "."
small2 = "Tiny."
text = large1 + " " + small1 + " " + large2 + " " + small2
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
max_chunk_size = 10
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size)
chunks = await collect_chunks(chunker)
assert len(chunks) == 4, "Should produce multiple chunks"
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
"Chunk indices should be sequential"
)
assert chunks[0].chunk_size > max_chunk_size, (
"First chunk should be oversized (large paragraph)"
)
assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)"
assert chunks[2].chunk_size > max_chunk_size, (
"Third chunk should be oversized (large paragraph)"
)
assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)"
@pytest.mark.asyncio
async def test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator):
"""Running chunker twice on identical input should produce identical indices and IDs."""
sentence1 = "one " * 4 + ". "
sentence2 = "two " * 4 + ". "
sentence3 = "one " * 4 + ". "
sentence4 = "two " * 4 + ". "
text = sentence1 + sentence2 + sentence3 + sentence4
doc_id = uuid4()
max_chunk_size = 20
document1 = Document(
id=doc_id,
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text1 = make_text_generator(text)
chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size)
chunks1 = await collect_chunks(chunker1)
document2 = Document(
id=doc_id,
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text2 = make_text_generator(text)
chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size)
chunks2 = await collect_chunks(chunker2)
assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]"
assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]"
assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic"
assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic"
assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run"

View file

@ -0,0 +1,324 @@
"""Unit tests for TextChunkerWithOverlap overlap behavior."""
import sys
import pytest
from uuid import uuid4
from unittest.mock import patch
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
from cognee.modules.data.processing.document_types import Document
from cognee.tasks.chunks import chunk_by_paragraph
@pytest.fixture
def make_text_generator():
"""Factory for async text generators."""
def _factory(*texts):
async def gen():
for text in texts:
yield text
return gen
return _factory
@pytest.fixture
def make_controlled_chunk_data():
"""Factory for controlled chunk_data generators."""
def _factory(*sentences, chunk_size_per_sentence=10):
def _chunk_data(text):
return [
{
"text": sentence,
"chunk_size": chunk_size_per_sentence,
"cut_type": "sentence",
"chunk_id": uuid4(),
}
for sentence in sentences
]
return _chunk_data
return _factory
@pytest.mark.asyncio
async def test_half_overlap_preserves_content_across_chunks(
make_text_generator, make_controlled_chunk_data
):
"""With 50% overlap, consecutive chunks should share half their content."""
s1 = "one"
s2 = "two"
s3 = "three"
s4 = "four"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.5,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)"
assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]"
assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2"
assert "two" in chunks[1].text and "three" in chunks[1].text, (
"Chunk 1 should contain s2 (overlap) and s3"
)
assert "three" in chunks[2].text and "four" in chunks[2].text, (
"Chunk 2 should contain s3 (overlap) and s4"
)
@pytest.mark.asyncio
async def test_zero_overlap_produces_no_duplicate_content(
make_text_generator, make_controlled_chunk_data
):
"""With 0% overlap, no content should appear in multiple chunks."""
s1 = "one"
s2 = "two"
s3 = "three"
s4 = "four"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.0,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)"
assert "one" in chunks[0].text and "two" in chunks[0].text, (
"First chunk should contain s1 and s2"
)
assert "three" in chunks[1].text and "four" in chunks[1].text, (
"Second chunk should contain s3 and s4"
)
assert "two" not in chunks[1].text and "three" not in chunks[0].text, (
"No overlap: end of chunk 0 should not appear in chunk 1"
)
@pytest.mark.asyncio
async def test_small_overlap_ratio_creates_minimal_overlap(
make_text_generator, make_controlled_chunk_data
):
"""With 25% overlap ratio, chunks should have minimal overlap."""
s1 = "alpha"
s2 = "beta"
s3 = "gamma"
s4 = "delta"
s5 = "epsilon"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=40,
chunk_overlap_ratio=0.25,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 2, "Should produce exactly 2 chunks"
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
"Chunk 0 should contain s1 through s4"
)
assert s4 in chunks[1].text and s5 in chunks[1].text, (
"Chunk 1 should contain overlap s4 and new content s5"
)
@pytest.mark.asyncio
async def test_high_overlap_ratio_creates_significant_overlap(
make_text_generator, make_controlled_chunk_data
):
"""With 75% overlap ratio, consecutive chunks should share most content."""
s1 = "red"
s2 = "blue"
s3 = "green"
s4 = "yellow"
s5 = "purple"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.75,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap"
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
"Chunk 0 should contain s1, s2, s3, s4"
)
assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), (
"Chunk 1 should contain s2, s3, s4 (overlap) and s5"
)
@pytest.mark.asyncio
async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data):
"""Text that fits in one chunk should produce exactly one chunk, no overlap artifact."""
s1 = "alpha"
s2 = "beta"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.5,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 1, (
"Should produce exactly 1 chunk when content fits within max_chunk_size"
)
assert chunks[0].chunk_index == 0, "Single chunk should have index 0"
assert "alpha" in chunks[0].text and "beta" in chunks[0].text, (
"Single chunk should contain all content"
)
@pytest.mark.asyncio
async def test_paragraph_chunking_with_overlap(make_text_generator):
"""Test that chunk_by_paragraph integration produces 25% overlap between chunks."""
def mock_get_embedding_engine():
class MockEngine:
tokenizer = None
return MockEngine()
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
max_chunk_size = 20
overlap_ratio = 0.25 # 5 token overlap
paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size) # = 2
text = (
"A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9)
"B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19)
"C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29)
"D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39)
"E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." # 10 tokens (40-49)
)
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
def get_chunk_data(text_input):
return chunk_by_paragraph(
text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True
)
with patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
):
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=max_chunk_size,
chunk_overlap_ratio=overlap_ratio,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
assert chunks[2].chunk_index == 2, "Third chunk should have index 2"
assert "A0" in chunks[0].text, "Chunk 0 should start with A0"
assert "A9" in chunks[0].text, "Chunk 0 should contain A9"
assert "B0" in chunks[0].text, "Chunk 0 should contain B0"
assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)"
assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section"
assert "C" in chunks[1].text, "Chunk 1 should contain C section"
assert "D" in chunks[1].text, "Chunk 1 should contain D section"
assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section"
assert "E0" in chunks[2].text, "Chunk 2 should contain E0"
assert "E9" in chunks[2].text, "Chunk 2 should end with E9"
chunk_0_end_words = chunks[0].text.split()[-4:]
chunk_1_words = chunks[1].text.split()
overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words)
assert overlap_0_1, (
f"No overlap detected between chunks 0 and 1. "
f"Chunk 0 ends with: {chunk_0_end_words}, "
f"Chunk 1 starts with: {chunk_1_words[:6]}"
)
chunk_1_end_words = chunks[1].text.split()[-4:]
chunk_2_words = chunks[2].text.split()
overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words)
assert overlap_1_2, (
f"No overlap detected between chunks 1 and 2. "
f"Chunk 1 ends with: {chunk_1_end_words}, "
f"Chunk 2 starts with: {chunk_2_words[:6]}"
)

View file

@ -16,9 +16,11 @@ async def test_cognify_session_success():
patch("cognee.add", new_callable=AsyncMock) as mock_add,
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
):
await cognify_session(session_data)
await cognify_session(session_data, dataset_id="123")
mock_add.assert_called_once_with(session_data, node_set=["user_sessions_from_cache"])
mock_add.assert_called_once_with(
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
)
mock_cognify.assert_called_once()
@ -101,7 +103,9 @@ async def test_cognify_session_with_special_characters():
patch("cognee.add", new_callable=AsyncMock) as mock_add,
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
):
await cognify_session(session_data)
await cognify_session(session_data, dataset_id="123")
mock_add.assert_called_once_with(session_data, node_set=["user_sessions_from_cache"])
mock_add.assert_called_once_with(
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
)
mock_cognify.assert_called_once()

View file

@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
# REQUIRE_AUTHENTICATION should be a boolean
assert isinstance(REQUIRE_AUTHENTICATION, bool)
# Currently should be False (optional authentication)
assert not REQUIRE_AUTHENTICATION
class TestConditionalAuthenticationEnvironmentVariables:
"""Test environment variable handling."""
def test_require_authentication_default_false(self):
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
with patch.dict(os.environ, {}, clear=True):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment - module will see empty environment
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
importlib.invalidate_caches()
assert not REQUIRE_AUTHENTICATION
def test_require_authentication_true(self):
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
assert REQUIRE_AUTHENTICATION
def test_require_authentication_false_explicit(self):
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
assert not REQUIRE_AUTHENTICATION
def test_require_authentication_case_insensitive(self):
"""Test that environment variable parsing is case insensitive when imported."""
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
for case in test_cases:
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
expected = case.lower() == "true"
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
def test_current_require_authentication_value(self):
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
# The module-level variable should currently be False (set at import time)
assert isinstance(REQUIRE_AUTHENTICATION, bool)
assert not REQUIRE_AUTHENTICATION
class TestConditionalAuthenticationEdgeCases:
"""Test edge cases and error scenarios."""

View file

@ -168,7 +168,7 @@ async def run_procurement_example():
for q in questions:
print(f"Question: \n{q}")
results = await procurement_system.search_memory(q, search_categories=[category])
top_answer = results[category][0]
top_answer = results[category][0]["search_result"][0]
print(f"Answer: \n{top_answer.strip()}\n")
research_notes[category].append({"question": q, "answer": top_answer})

View file

@ -1,5 +1,7 @@
import argparse
import asyncio
import os
import cognee
from cognee import SearchType
from cognee.shared.logging_utils import setup_logging, ERROR
@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
async def main(repo_path, include_docs):
# Disable permissions feature for this example
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
run_status = False
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
run_status = run_status

View file

@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}], # A placeholder to prevent fetching the entire graph
dataset="feedback_enrichment_minimal",
)

View file

@ -89,7 +89,7 @@ async def main():
)
print("Coding rules created by memify:")
for coding_rule in coding_rules:
for coding_rule in coding_rules[0]["search_result"][0]:
print("- " + coding_rule)
# Visualize new graph with added memify context

View file

@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main():
# Disable backend access control to migrate relational data
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
# Clean all data stored in Cognee
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -59,14 +59,6 @@ async def main():
for result_text in search_results:
print(result_text)
# Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == "__main__":
logger = setup_logging(log_level=ERROR)