Merge branch 'dev' into fix-cypher-search

Igor Ilic 2025-11-08 19:44:20 +01:00 committed by GitHub
commit 475e6a5b16
43 changed files with 1224 additions and 250 deletions


@ -169,8 +169,9 @@ REQUIRE_AUTHENTICATION=False
# Vector: LanceDB
# Graph: KuzuDB
#
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
ENABLE_BACKEND_ACCESS_CONTROL=False
# It enforces creation of databases per Cognee user + dataset. It does not work with some graph and vector database providers.
# Disable this mode when using unsupported graph/vector databases.
ENABLE_BACKEND_ACCESS_CONTROL=True
################################################################################
# ☁️ Cloud Sync Settings

.github/workflows/load_tests.yml (new file, 70 lines)

@ -0,0 +1,70 @@
name: Load tests
permissions:
contents: read
on:
workflow_dispatch:
workflow_call:
secrets:
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
jobs:
test-load:
name: Test Load
runs-on: ubuntu-22.04
timeout-minutes: 60
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "aws"
- name: Verify File Descriptor Limit
run: ulimit -n
- name: Run Load Test
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: True
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
STORAGE_BACKEND: s3
AWS_REGION: eu-west-1
AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
run: uv run python ./cognee/tests/test_load.py

.github/workflows/release_test.yml (new file, 17 lines)

@ -0,0 +1,17 @@
# Long-running, heavy and resource-consuming tests for release validation
name: Release Test Workflow
permissions:
contents: read
on:
workflow_dispatch:
pull_request:
branches:
- main
jobs:
load-tests:
name: Load Tests
uses: ./.github/workflows/load_tests.yml
secrets: inherit


@ -84,6 +84,7 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
@ -135,6 +136,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'pgvector'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1'
@ -197,6 +199,7 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432


@ -4,6 +4,8 @@ from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
async def set_session_user_context_variable(user):
session_user.set(user)
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
)
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
if backend_access_control is None:
# If backend access control is not defined in environment variables,
# enable it by default if graph and vector DBs can support it, otherwise disable it
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
return True
return False
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
"""
If backend access control is enabled, this function will ensure all datasets have their own databases,
@ -40,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config()
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if not backend_access_control_enabled():
return
user = await get_user(user_id)
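
A minimal sketch of how the new backend_access_control_enabled helper resolves the flag, assuming the default kuzu and lancedb providers (the snippet is illustrative, not part of the diff):

import os
from cognee.context_global_variables import backend_access_control_enabled

# Case 1: variable unset -> enabled only when both the graph provider ("kuzu")
# and the vector provider ("lancedb") support multi-user access control
os.environ.pop("ENABLE_BACKEND_ACCESS_CONTROL", None)
enabled = backend_access_control_enabled()

# Case 2: explicitly "true" with an unsupported provider -> EnvironmentError
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "true"
try:
    backend_access_control_enabled()
except EnvironmentError:
    pass  # raised when the configured graph/vector databases lack support

# Case 3: any other value ("false", "0", ...) -> disabled
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
assert backend_access_control_enabled() is False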


@ -40,7 +40,7 @@ async def persist_sessions_in_knowledge_graph_pipeline(
extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)]
enrichment_tasks = [
Task(cognify_session),
Task(cognify_session, dataset_id=dataset_to_write[0].id),
]
result = await memify(


@ -0,0 +1,124 @@
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk
logger = get_logger()
class TextChunkerWithOverlap(Chunker):
def __init__(
self,
document,
get_text: callable,
max_chunk_size: int,
chunk_overlap_ratio: float = 0.0,
get_chunk_data: callable = None,
):
super().__init__(document, get_text, max_chunk_size)
self._accumulated_chunk_data = []
self._accumulated_size = 0
self.chunk_overlap_ratio = chunk_overlap_ratio
self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
if get_chunk_data is not None:
self.get_chunk_data = get_chunk_data
elif chunk_overlap_ratio > 0:
paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
self.get_chunk_data = lambda text: chunk_by_paragraph(
text, paragraph_max_size, batch_paragraphs=True
)
else:
self.get_chunk_data = lambda text: chunk_by_paragraph(
text, self.max_chunk_size, batch_paragraphs=True
)
def _accumulation_overflows(self, chunk_data):
"""Check if adding chunk_data would exceed max_chunk_size."""
return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
def _accumulate_chunk_data(self, chunk_data):
"""Add chunk_data to the current accumulation."""
self._accumulated_chunk_data.append(chunk_data)
self._accumulated_size += chunk_data["chunk_size"]
def _clear_accumulation(self):
"""Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
if self.chunk_overlap == 0:
self._accumulated_chunk_data = []
self._accumulated_size = 0
return
# Keep chunk_data from the end that fit in overlap
overlap_chunk_data = []
overlap_size = 0
for chunk_data in reversed(self._accumulated_chunk_data):
if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
overlap_chunk_data.insert(0, chunk_data)
overlap_size += chunk_data["chunk_size"]
else:
break
self._accumulated_chunk_data = overlap_chunk_data
self._accumulated_size = overlap_size
def _create_chunk(self, text, size, cut_type, chunk_id=None):
"""Create a DocumentChunk with standard metadata."""
try:
return DocumentChunk(
id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text=text,
chunk_size=size,
is_part_of=self.document,
chunk_index=self.chunk_index,
cut_type=cut_type,
contains=[],
metadata={"index_fields": ["text"]},
)
except Exception as e:
logger.error(e)
raise e
def _create_chunk_from_accumulation(self):
"""Create a DocumentChunk from current accumulated chunk_data."""
chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
return self._create_chunk(
text=chunk_text,
size=self._accumulated_size,
cut_type=self._accumulated_chunk_data[-1]["cut_type"],
)
def _emit_chunk(self, chunk_data):
"""Emit a chunk when accumulation overflows."""
if len(self._accumulated_chunk_data) > 0:
chunk = self._create_chunk_from_accumulation()
self._clear_accumulation()
self._accumulate_chunk_data(chunk_data)
else:
# Handle single chunk_data exceeding max_chunk_size
chunk = self._create_chunk(
text=chunk_data["text"],
size=chunk_data["chunk_size"],
cut_type=chunk_data["cut_type"],
chunk_id=chunk_data["chunk_id"],
)
self.chunk_index += 1
return chunk
async def read(self):
async for content_text in self.get_text():
for chunk_data in self.get_chunk_data(content_text):
if not self._accumulation_overflows(chunk_data):
self._accumulate_chunk_data(chunk_data)
continue
yield self._emit_chunk(chunk_data)
if len(self._accumulated_chunk_data) == 0:
return
yield self._create_chunk_from_accumulation()
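
A minimal usage sketch of TextChunkerWithOverlap, mirroring the controlled fixtures in the unit tests later in this diff (the splitter and sizes here are illustrative):

from uuid import uuid4
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
from cognee.modules.data.processing.document_types import Document

async def demo_overlap_chunking():
    document = Document(
        id=uuid4(),
        name="demo_document",
        raw_data_location="/tmp/demo",
        external_metadata=None,
        mime_type="text/plain",
    )

    async def get_text():
        yield "First sentence. Second sentence. Third sentence."

    def get_chunk_data(text):
        # Controlled splitter: every sentence becomes one chunk_data of size 10
        return [
            {"text": s.strip() + ".", "chunk_size": 10, "cut_type": "sentence", "chunk_id": uuid4()}
            for s in text.split(".")
            if s.strip()
        ]

    # max_chunk_size=20 with a 0.5 overlap ratio: the first chunk holds
    # sentences 1-2, then sentence 2 (size 10 <= 20 * 0.5) is carried over
    # as overlap into the second chunk alongside sentence 3
    chunker = TextChunkerWithOverlap(
        document, get_text, max_chunk_size=20, chunk_overlap_ratio=0.5, get_chunk_data=get_chunk_data
    )
    return [chunk async for chunk in chunker.read()]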


@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
return None
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> List[str]:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generate completion using provided context or fetch new context.
@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:


@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional, Type
from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context (triplets)."""
pass


@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Optional, Type, List
class BaseRetriever(ABC):
@ -12,7 +12,11 @@ class BaseRetriever(ABC):
@abstractmethod
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""Generates a response using the query and optional context."""
pass


@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional
from typing import Any, Optional, Type, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> str:
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates an LLM completion using the context.
@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if session_save:
@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
session_id=session_id,
)
return completion
return [completion]
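
With this change get_completion always returns a list and accepts a response_model for structured output, as the new tests at the end of this diff exercise. A minimal sketch, assuming data has already been added and cognified (the Answer model is illustrative):

from pydantic import BaseModel
from cognee.modules.retrieval.completion_retriever import CompletionRetriever

class Answer(BaseModel):
    answer: str
    explanation: str

async def demo_completion():
    retriever = CompletionRetriever()
    # Default response_model=str: a single-element list of strings
    [text] = await retriever.get_completion("Where does Steve work?")
    # Structured output: a single-element list of Answer instances
    [structured] = await retriever.get_completion(
        "Where does Steve work?", response_model=Answer
    )
    return text, structured.answer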


@ -1,5 +1,5 @@
import asyncio
from typing import Optional, List, Type
from typing import Optional, List, Type, Any
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -56,7 +56,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
) -> List[str]:
response_model: Type = str,
) -> List[Any]:
"""
Extends the context for a given query by retrieving related triplets and generating new
completions based on them.
@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context_text and triplets and completion:


@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
generate_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_structured_completion(
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
return completion, context_text, triplets
async def get_structured_completion(
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
max_iter=4,
response_model: Type = str,
) -> Any:
) -> List[Any]:
"""
Generate structured completion responses based on a user query and contextual information.
Generate completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs. It returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
Returns:
--------
- Any: The generated structured completion based on the response model.
- List[str]: A list containing the generated answer to the user's query.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return completion
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
"""
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]


@ -146,7 +146,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
) -> List[str]:
response_model: Type = str,
) -> List[Any]:
"""
Generates a completion using graph connections context based on a query.
@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:


@ -146,8 +146,12 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
) -> List[str]:
self,
query: str,
context: Optional[str] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates a response using the query and optional context.
@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:


@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_structured_completion(
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
@ -12,7 +12,7 @@ async def generate_structured_completion(
conversation_history: Optional[str] = None,
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@ -28,26 +28,6 @@ async def generate_structured_completion(
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)
async def summarize_text(
text: str,
system_prompt_path: str = "summarize_search_results.txt",
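
generate_structured_completion is thus folded into generate_completion, with response_model defaulting to str. A sketch of a direct call, assuming LLM credentials are configured (the prompt file names are illustrative):

from cognee.modules.retrieval.utils.completion import generate_completion

async def demo_generate():
    # Plain string completion; pass a Pydantic model as response_model
    # to get structured output instead
    return await generate_completion(
        query="Who works at Figma?",
        context="Steve Rodger works for Figma.",
        user_prompt_path="context_for_question.txt",
        system_prompt_path="answer_simple_question.txt",
    )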


@ -1,4 +1,3 @@
import os
import json
import asyncio
from uuid import UUID
@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +74,7 @@ async def search(
)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if backend_access_control_enabled():
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
@ -156,7 +156,7 @@ async def search(
)
else:
# This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
if backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
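
As the updated tests in this diff show, when backend access control is active each search result is wrapped in a dict. A minimal sketch of unpacking the two shapes, assuming a cognified dataset exists:

import cognee
from cognee.modules.search.types import SearchType
from cognee.context_global_variables import backend_access_control_enabled

async def first_answer() -> str:
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What information do you contain?",
    )
    if backend_access_control_enabled():
        # Access-controlled shape: a list of dicts carrying a "search_result" list
        return results[0]["search_result"][0]
    # Legacy shape: a list of plain strings
    return results[0]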


@ -5,6 +5,7 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import backend_access_control_enabled
logger = get_logger("get_authenticated_user")
@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
or backend_access_control_enabled()
)
fastapi_users = get_fastapi_users()


@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
)
retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion(
completion = await retriever.get_completion(
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion:
enrichment.improved_answer = completion.answer
enrichment.improved_answer = completion[0].answer
enrichment.new_context = new_context_text
enrichment.explanation = completion.explanation
enrichment.explanation = completion[0].explanation
return enrichment
else:
logger.warning(


@ -6,7 +6,7 @@ from cognee.shared.logging_utils import get_logger
logger = get_logger("cognify_session")
async def cognify_session(data):
async def cognify_session(data, dataset_id=None):
"""
Process and cognify session data into the knowledge graph.
@ -16,6 +16,7 @@ async def cognify_session(data):
Args:
data: Session string containing Question, Context, and Answer information.
dataset_id: ID of the dataset to write the session data into.
Raises:
CogneeValidationError: If data is None or empty.
@ -28,9 +29,9 @@ async def cognify_session(data):
logger.info("Processing session data for cognification")
await cognee.add(data, node_set=["user_sessions_from_cache"])
await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"])
logger.debug("Session data added to cognee with node_set: user_sessions")
await cognee.cognify()
await cognee.cognify(datasets=[dataset_id])
logger.info("Session data successfully cognified")
except CogneeValidationError:


@ -39,12 +39,12 @@ async def main():
answer = await cognee.search("Do programmers change light bulbs?")
assert len(answer) != 0
lowercase_answer = answer[0].lower()
lowercase_answer = answer[0]["search_result"][0].lower()
assert ("no" in lowercase_answer) or ("none" in lowercase_answer)
answer = await cognee.search("What colours are there in the presentation table?")
assert len(answer) != 0
lowercase_answer = answer[0].lower()
lowercase_answer = answer[0]["search_result"][0].lower()
assert (
("red" in lowercase_answer)
and ("blue" in lowercase_answer)


@ -56,10 +56,10 @@ async def main():
"""DataCo is a data analytics company. They help businesses make sense of their data."""
)
await cognee.add(text_1, dataset_name)
await cognee.add(text_2, dataset_name)
await cognee.add(data=text_1, dataset_name=dataset_name)
await cognee.add(data=text_2, dataset_name=dataset_name)
await cognee.cognify([dataset_name])
await cognee.cognify(datasets=[dataset_name])
user = await get_default_user()


@ -133,7 +133,7 @@ async def main():
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}],
dataset="feedback_enrichment_test_memify",
dataset=dataset_name,
)
nodes_after, edges_after = await graph_engine.get_graph_data()


@ -90,15 +90,17 @@ async def main():
)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?"
query_type=SearchType.GRAPH_COMPLETION,
query_text="What information do you contain?",
dataset_ids=[pipeline_run_obj.dataset_id],
)
assert "Mark" in search_results[0], (
assert "Mark" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Mark in search results"
)
assert "Cindy" in search_results[0], (
assert "Cindy" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Cindy in search results"
)
assert "Artificial intelligence" not in search_results[0], (
assert "Artificial intelligence" not in search_results[0]["search_result"][0], (
"Failed to update document, Artificial intelligence still mentioned in search results"
)

cognee/tests/test_load.py (new file, 62 lines)

@ -0,0 +1,62 @@
import os
import pathlib
import asyncio
import time
import cognee
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def process_and_search(num_of_searches):
start_time = time.time()
await cognee.cognify()
await asyncio.gather(
*[
cognee.search(
query_text="Tell me about the document", query_type=SearchType.GRAPH_COMPLETION
)
for _ in range(num_of_searches)
]
)
end_time = time.time()
return end_time - start_time
async def main():
data_directory_path = os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_load")
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_load")
cognee.config.system_root_directory(cognee_directory_path)
num_of_pdfs = 10
num_of_reps = 5
upper_boundary_minutes = 10
average_minutes = 8
recorded_times = []
for _ in range(num_of_reps):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
s3_input = "s3://cognee-test-load-s3-bucket"
await cognee.add(s3_input)
recorded_times.append(await process_and_search(num_of_pdfs))
average_recorded_time = sum(recorded_times) / len(recorded_times)
assert average_recorded_time <= average_minutes * 60
assert all(rec_time <= upper_boundary_minutes * 60 for rec_time in recorded_times)
if __name__ == "__main__":
asyncio.run(main())


@ -27,6 +27,9 @@ def normalize_node_name(node_name: str) -> str:
async def setup_test_db():
# Disable backend access control to migrate relational data
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)


@ -146,7 +146,13 @@ async def main():
assert len(search_results) == 1, (
f"{name}: expected single-element list, got {len(search_results)}"
)
text = search_results[0]
from cognee.context_global_variables import backend_access_control_enabled
if backend_access_control_enabled():
text = search_results[0]["search_result"][0]
else:
text = search_results[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (


@ -1,3 +1,4 @@
import os
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from uuid import uuid4
@ -5,8 +6,6 @@ from fastapi.testclient import TestClient
from types import SimpleNamespace
import importlib
from cognee.api.client import app
# Fixtures for reuse across test classes
@pytest.fixture
@ -32,6 +31,10 @@ def mock_authenticated_user():
)
# To turn off authentication we need to set the environment variable before importing the module
# Both REQUIRE_AUTHENTICATION and ENABLE_BACKEND_ACCESS_CONTROL must be false
os.environ["REQUIRE_AUTHENTICATION"] = "false"
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints:
@pytest.fixture
def client(self):
"""Create a test client."""
from cognee.api.client import app
return TestClient(app)
@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior:
@pytest.fixture
def client(self):
from cognee.api.client import app
return TestClient(app)
@pytest.mark.parametrize(
@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling:
@pytest.fixture
def client(self):
from cognee.api.client import app
return TestClient(app)
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling:
# The exact error message may vary depending on the actual database connection
# The important thing is that we get a 500 error when user creation fails
def test_current_environment_configuration(self):
def test_current_environment_configuration(self, client):
"""Test that current environment configuration is working properly."""
# This tests the actual module state without trying to change it
from cognee.modules.users.methods.get_authenticated_user import (


@ -0,0 +1,248 @@
"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence."""
import pytest
from uuid import uuid4
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
from cognee.modules.data.processing.document_types import Document
@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"])
def chunker_class(request):
"""Parametrize tests to run against both implementations."""
return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap
@pytest.fixture
def make_text_generator():
"""Factory for async text generators."""
def _factory(*texts):
async def gen():
for text in texts:
yield text
return gen
return _factory
async def collect_chunks(chunker):
"""Consume async generator and return list of chunks."""
chunks = []
async for chunk in chunker.read():
chunks.append(chunk)
return chunks
@pytest.mark.asyncio
async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator):
"""Empty input should yield no chunks."""
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator("")
chunker = chunker_class(document, get_text, max_chunk_size=512)
chunks = await collect_chunks(chunker)
assert len(chunks) == 0, "Empty input should produce no chunks"
@pytest.mark.asyncio
async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator):
"""Whitespace-only input should produce exactly one chunk with unchanged text."""
whitespace_text = " \n\t \r\n "
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(whitespace_text)
chunker = chunker_class(document, get_text, max_chunk_size=512)
chunks = await collect_chunks(chunker)
assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk"
assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
@pytest.mark.asyncio
async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator):
"""Single paragraph below limit should emit exactly one chunk."""
text = "This is a short paragraph."
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=512)
chunks = await collect_chunks(chunker)
assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk"
assert chunks[0].text == text, "Chunk text should match input"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[0].chunk_size > 0, "Chunk should have positive size"
@pytest.mark.asyncio
async def test_oversized_paragraph_gets_emitted_as_a_single_chunk(
chunker_class, make_text_generator
):
"""Oversized paragraph from chunk_by_paragraph should be emitted as single chunk."""
text = ("A" * 1500) + ". Next sentence."
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=50)
chunks = await collect_chunks(chunker)
assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)"
assert chunks[0].chunk_size > 50, "First chunk should be oversized"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
@pytest.mark.asyncio
async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator):
"""First paragraph near limit plus small paragraph should produce two separate chunks."""
first_para = " ".join(["word"] * 5)
second_para = "Short text."
text = first_para + " " + second_para
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=10)
chunks = await collect_chunks(chunker)
assert len(chunks) == 2, "Should produce 2 chunks due to overflow"
assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph"
assert chunks[1].text.strip() == second_para, (
"Second chunk should contain only second paragraph"
)
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
@pytest.mark.asyncio
async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator):
"""Multiple small paragraphs should batch together with joiner spaces counted."""
paragraphs = [" ".join(["word"] * 12) for _ in range(40)]
text = " ".join(paragraphs)
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=49)
chunks = await collect_chunks(chunker)
assert len(chunks) == 20, (
"Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)"
)
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
"Chunk indices should be sequential"
)
all_text = " ".join(chunk.text.strip() for chunk in chunks)
expected_text = " ".join(paragraphs)
assert all_text == expected_text, "All paragraph text should be preserved with correct spacing"
@pytest.mark.asyncio
async def test_alternating_large_and_small_paragraphs_dont_batch(
chunker_class, make_text_generator
):
"""Alternating near-max and small paragraphs should each become separate chunks."""
large1 = "word" * 15 + "."
small1 = "Short."
large2 = "word" * 15 + "."
small2 = "Tiny."
text = large1 + " " + small1 + " " + large2 + " " + small2
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
max_chunk_size = 10
get_text = make_text_generator(text)
chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size)
chunks = await collect_chunks(chunker)
assert len(chunks) == 4, "Should produce multiple chunks"
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
"Chunk indices should be sequential"
)
assert chunks[0].chunk_size > max_chunk_size, (
"First chunk should be oversized (large paragraph)"
)
assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)"
assert chunks[2].chunk_size > max_chunk_size, (
"Third chunk should be oversized (large paragraph)"
)
assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)"
@pytest.mark.asyncio
async def test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator):
"""Running chunker twice on identical input should produce identical indices and IDs."""
sentence1 = "one " * 4 + ". "
sentence2 = "two " * 4 + ". "
sentence3 = "one " * 4 + ". "
sentence4 = "two " * 4 + ". "
text = sentence1 + sentence2 + sentence3 + sentence4
doc_id = uuid4()
max_chunk_size = 20
document1 = Document(
id=doc_id,
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text1 = make_text_generator(text)
chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size)
chunks1 = await collect_chunks(chunker1)
document2 = Document(
id=doc_id,
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text2 = make_text_generator(text)
chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size)
chunks2 = await collect_chunks(chunker2)
assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]"
assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]"
assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic"
assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic"
assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run"


@ -0,0 +1,324 @@
"""Unit tests for TextChunkerWithOverlap overlap behavior."""
import sys
import pytest
from uuid import uuid4
from unittest.mock import patch
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
from cognee.modules.data.processing.document_types import Document
from cognee.tasks.chunks import chunk_by_paragraph
@pytest.fixture
def make_text_generator():
"""Factory for async text generators."""
def _factory(*texts):
async def gen():
for text in texts:
yield text
return gen
return _factory
@pytest.fixture
def make_controlled_chunk_data():
"""Factory for controlled chunk_data generators."""
def _factory(*sentences, chunk_size_per_sentence=10):
def _chunk_data(text):
return [
{
"text": sentence,
"chunk_size": chunk_size_per_sentence,
"cut_type": "sentence",
"chunk_id": uuid4(),
}
for sentence in sentences
]
return _chunk_data
return _factory
@pytest.mark.asyncio
async def test_half_overlap_preserves_content_across_chunks(
make_text_generator, make_controlled_chunk_data
):
"""With 50% overlap, consecutive chunks should share half their content."""
s1 = "one"
s2 = "two"
s3 = "three"
s4 = "four"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.5,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)"
assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]"
assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2"
assert "two" in chunks[1].text and "three" in chunks[1].text, (
"Chunk 1 should contain s2 (overlap) and s3"
)
assert "three" in chunks[2].text and "four" in chunks[2].text, (
"Chunk 2 should contain s3 (overlap) and s4"
)
@pytest.mark.asyncio
async def test_zero_overlap_produces_no_duplicate_content(
make_text_generator, make_controlled_chunk_data
):
"""With 0% overlap, no content should appear in multiple chunks."""
s1 = "one"
s2 = "two"
s3 = "three"
s4 = "four"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.0,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)"
assert "one" in chunks[0].text and "two" in chunks[0].text, (
"First chunk should contain s1 and s2"
)
assert "three" in chunks[1].text and "four" in chunks[1].text, (
"Second chunk should contain s3 and s4"
)
assert "two" not in chunks[1].text and "three" not in chunks[0].text, (
"No overlap: end of chunk 0 should not appear in chunk 1"
)
@pytest.mark.asyncio
async def test_small_overlap_ratio_creates_minimal_overlap(
make_text_generator, make_controlled_chunk_data
):
"""With 25% overlap ratio, chunks should have minimal overlap."""
s1 = "alpha"
s2 = "beta"
s3 = "gamma"
s4 = "delta"
s5 = "epsilon"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=40,
chunk_overlap_ratio=0.25,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 2, "Should produce exactly 2 chunks"
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
"Chunk 0 should contain s1 through s4"
)
assert s4 in chunks[1].text and s5 in chunks[1].text, (
"Chunk 1 should contain overlap s4 and new content s5"
)
@pytest.mark.asyncio
async def test_high_overlap_ratio_creates_significant_overlap(
make_text_generator, make_controlled_chunk_data
):
"""With 75% overlap ratio, consecutive chunks should share most content."""
s1 = "red"
s2 = "blue"
s3 = "green"
s4 = "yellow"
s5 = "purple"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.75,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap"
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
"Chunk 0 should contain s1, s2, s3, s4"
)
assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), (
"Chunk 1 should contain s2, s3, s4 (overlap) and s5"
)
@pytest.mark.asyncio
async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data):
"""Text that fits in one chunk should produce exactly one chunk, no overlap artifact."""
s1 = "alpha"
s2 = "beta"
text = "dummy"
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10)
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=20,
chunk_overlap_ratio=0.5,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 1, (
"Should produce exactly 1 chunk when content fits within max_chunk_size"
)
assert chunks[0].chunk_index == 0, "Single chunk should have index 0"
assert "alpha" in chunks[0].text and "beta" in chunks[0].text, (
"Single chunk should contain all content"
)
@pytest.mark.asyncio
async def test_paragraph_chunking_with_overlap(make_text_generator):
"""Test that chunk_by_paragraph integration produces 25% overlap between chunks."""
def mock_get_embedding_engine():
class MockEngine:
tokenizer = None
return MockEngine()
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
max_chunk_size = 20
overlap_ratio = 0.25 # 5 token overlap
paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size) # = 2
text = (
"A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9)
"B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19)
"C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29)
"D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39)
"E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." # 10 tokens (40-49)
)
document = Document(
id=uuid4(),
name="test_document",
raw_data_location="/test/path",
external_metadata=None,
mime_type="text/plain",
)
get_text = make_text_generator(text)
def get_chunk_data(text_input):
return chunk_by_paragraph(
text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True
)
with patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
):
chunker = TextChunkerWithOverlap(
document,
get_text,
max_chunk_size=max_chunk_size,
chunk_overlap_ratio=overlap_ratio,
get_chunk_data=get_chunk_data,
)
chunks = [chunk async for chunk in chunker.read()]
assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}"
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
assert chunks[2].chunk_index == 2, "Third chunk should have index 2"
assert "A0" in chunks[0].text, "Chunk 0 should start with A0"
assert "A9" in chunks[0].text, "Chunk 0 should contain A9"
assert "B0" in chunks[0].text, "Chunk 0 should contain B0"
assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)"
assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section"
assert "C" in chunks[1].text, "Chunk 1 should contain C section"
assert "D" in chunks[1].text, "Chunk 1 should contain D section"
assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section"
assert "E0" in chunks[2].text, "Chunk 2 should contain E0"
assert "E9" in chunks[2].text, "Chunk 2 should end with E9"
chunk_0_end_words = chunks[0].text.split()[-4:]
chunk_1_words = chunks[1].text.split()
overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words)
assert overlap_0_1, (
f"No overlap detected between chunks 0 and 1. "
f"Chunk 0 ends with: {chunk_0_end_words}, "
f"Chunk 1 starts with: {chunk_1_words[:6]}"
)
chunk_1_end_words = chunks[1].text.split()[-4:]
chunk_2_words = chunks[2].text.split()
overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words)
assert overlap_1_2, (
f"No overlap detected between chunks 1 and 2. "
f"Chunk 1 ends with: {chunk_1_end_words}, "
f"Chunk 2 starts with: {chunk_2_words[:6]}"
)


@ -16,9 +16,11 @@ async def test_cognify_session_success():
patch("cognee.add", new_callable=AsyncMock) as mock_add,
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
):
await cognify_session(session_data)
await cognify_session(session_data, dataset_id="123")
mock_add.assert_called_once_with(session_data, node_set=["user_sessions_from_cache"])
mock_add.assert_called_once_with(
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
)
mock_cognify.assert_called_once()
@ -101,7 +103,9 @@ async def test_cognify_session_with_special_characters():
patch("cognee.add", new_callable=AsyncMock) as mock_add,
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
):
await cognify_session(session_data)
await cognify_session(session_data, dataset_id="123")
mock_add.assert_called_once_with(session_data, node_set=["user_sessions_from_cache"])
mock_add.assert_called_once_with(
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
)
mock_cognify.assert_called_once()


@ -2,7 +2,6 @@ import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
class TestAnswer(BaseModel):
answer: str
explanation: str
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1)
entities = [company1, person1]
await add_data_points(entities)
retriever = GraphCompletionCotRetriever()
# Test with string response model (default)
string_answer = await retriever.get_structured_completion("Who works at Figma?")
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
assert string_answer.strip(), "Answer should not be empty"
# Test with structured response model
structured_answer = await retriever.get_structured_completion(
"Who works at Figma?", response_model=TestAnswer
)
assert isinstance(structured_answer, TestAnswer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}"
)
assert structured_answer.answer.strip(), "Answer field should not be empty"
assert structured_answer.explanation.strip(), "Explanation field should not be empty"


@ -3,6 +3,7 @@ from typing import List
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine


@ -0,0 +1,204 @@
import asyncio
import pytest
import cognee
import pathlib
import os
from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestAnswer(BaseModel):
answer: str
explanation: str
def _assert_string_answer(answer: list[str]):
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) for item in answer), "Items should be strings"
assert all(item.strip() for item in answer), "Items should not be empty"
def _assert_structured_answer(answer: list[TestAnswer]):
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer"
assert all(x.answer.strip() for x in answer), "Answer text should not be empty"
assert all(x.explanation.strip() for x in answer), "Explanation should not be empty"
async def _test_get_structured_graph_completion_cot():
retriever = GraphCompletionCotRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion():
retriever = GraphCompletionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion_temporal():
retriever = TemporalRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("When did Steve start working at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"When did Steve start working at Figma??", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion_rag():
retriever = CompletionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Where does Steve work?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Where does Steve work?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_graph_completion_context_extension():
retriever = GraphCompletionContextExtensionRetriever()
# Test with string response model (default)
string_answer = await retriever.get_completion("Who works at Figma?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
async def _test_get_structured_entity_completion():
retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())
# Test with string response model (default)
string_answer = await retriever.get_completion("Who is Albert Einstein?")
_assert_string_answer(string_answer)
# Test with structured response model
structured_answer = await retriever.get_completion(
"Who is Albert Einstein?", response_model=TestAnswer
)
_assert_structured_answer(structured_answer)
class TestStructuredOutputCompletion:
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
works_since: int
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
entities = [company1, person1]
await add_data_points(entities)
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
entity_type = EntityType(name="Person", description="A human individual")
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
entities = [entity]
await add_data_points(entities)
await _test_get_structured_graph_completion_cot()
await _test_get_structured_graph_completion()
await _test_get_structured_graph_completion_temporal()
await _test_get_structured_graph_completion_rag()
await _test_get_structured_graph_completion_context_extension()
await _test_get_structured_entity_completion()
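
For running this suite outside pytest, a minimal sketch of a standalone entry point, assuming the same LLM and embedding configuration the tests expect (the runner itself is hypothetical; pytest normally drives this class):

import asyncio

if __name__ == "__main__":
    # Hypothetical standalone runner for the structured-output checks above.
    asyncio.run(TestStructuredOutputCompletion().test_get_structured_completion())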

View file

@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TextSummariesRetriever:
class TestSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
system_directory_path = os.path.join(

View file

@ -1,4 +1,3 @@
import asyncio
from types import SimpleNamespace
import pytest

View file

@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
# REQUIRE_AUTHENTICATION should be a boolean
assert isinstance(REQUIRE_AUTHENTICATION, bool)
# Currently should be False (optional authentication)
assert not REQUIRE_AUTHENTICATION
class TestConditionalAuthenticationEnvironmentVariables:
"""Test environment variable handling."""
def test_require_authentication_default_false(self):
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
with patch.dict(os.environ, {}, clear=True):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment - module will see empty environment
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
importlib.invalidate_caches()
assert not REQUIRE_AUTHENTICATION
def test_require_authentication_true(self):
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
assert REQUIRE_AUTHENTICATION
def test_require_authentication_false_explicit(self):
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
assert not REQUIRE_AUTHENTICATION
def test_require_authentication_case_insensitive(self):
"""Test that environment variable parsing is case insensitive when imported."""
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
for case in test_cases:
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
expected = case.lower() == "true"
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
def test_current_require_authentication_value(self):
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
# The module-level variable should currently be False (set at import time)
assert isinstance(REQUIRE_AUTHENTICATION, bool)
assert not REQUIRE_AUTHENTICATION
class TestConditionalAuthenticationEdgeCases:
"""Test edge cases and error scenarios."""

View file

@ -168,7 +168,7 @@ async def run_procurement_example():
for q in questions:
print(f"Question: \n{q}")
results = await procurement_system.search_memory(q, search_categories=[category])
top_answer = results[category][0]
top_answer = results[category][0]["search_result"][0]
print(f"Answer: \n{top_answer.strip()}\n")
research_notes[category].append({"question": q, "answer": top_answer})
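
The one-line change above implies that each search category now returns wrapped hits under a "search_result" key; the same shape appears in the memify example further below. A defensive unwrapping sketch, assuming that payload layout (the key name is inferred from this diff, not from documented API):

def top_answer_for(results: dict, category) -> str:
    # Inferred payload shape: results[category] is a list of dicts,
    # each carrying its ranked hits under "search_result".
    wrapped = results[category][0]
    hits = wrapped["search_result"] if isinstance(wrapped, dict) else [wrapped]
    if not hits:
        raise ValueError(f"No search results for category {category!r}")
    return str(hits[0]).strip()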

View file

@ -1,5 +1,7 @@
import argparse
import asyncio
import os
import cognee
from cognee import SearchType
from cognee.shared.logging_utils import setup_logging, ERROR
@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
async def main(repo_path, include_docs):
# Disable permissions feature for this example
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
run_status = False
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
run_status = run_status

View file

@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
data=[{}], # A placeholder to prevent fetching the entire graph
dataset="feedback_enrichment_minimal",
)

View file

@ -89,7 +89,7 @@ async def main():
)
print("Coding rules created by memify:")
for coding_rule in coding_rules:
for coding_rule in coding_rules[0]["search_result"][0]:
print("- " + coding_rule)
# Visualize new graph with added memify context

View file

@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main():
# Disable backend access control to migrate relational data
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
# Clean all data stored in Cognee
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -59,14 +59,6 @@ async def main():
for result_text in search_results:
print(result_text)
# Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == "__main__":
logger = setup_logging(log_level=ERROR)