diff --git a/.env.template b/.env.template index 7dcd4f346..ae2cb1338 100644 --- a/.env.template +++ b/.env.template @@ -169,8 +169,9 @@ REQUIRE_AUTHENTICATION=False # Vector: LanceDB # Graph: KuzuDB # -# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset -ENABLE_BACKEND_ACCESS_CONTROL=False +# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers. +# Disable mode when using not supported graph/vector databases. +ENABLE_BACKEND_ACCESS_CONTROL=True ################################################################################ # ☁️ Cloud Sync Settings diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 70a4b56e6..0596f22d3 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -447,3 +447,44 @@ jobs: DB_USERNAME: cognee DB_PASSWORD: cognee run: uv run python ./cognee/tests/test_conversation_history.py + + test-load: + name: Test Load + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "aws" + + - name: Set File Descriptor Limit + run: sudo prlimit --pid $$ --nofile=4096:4096 + + - name: Verify File Descriptor Limit + run: ulimit -n + + - name: Dependencies already installed + run: echo "Dependencies already installed in setup" + + - name: Run Load Test + env: + ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: True + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + STORAGE_BACKEND: s3 + AWS_REGION: eu-west-1 
+ AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }} + run: uv run python ./cognee/tests/test_load.py \ No newline at end of file diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 57bc88157..36953e259 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -210,6 +210,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./examples/python/memify_coding_agent_example.py + test-custom-pipeline: + name: Run Custom Pipeline Example + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Custom Pipeline Example + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./examples/python/run_custom_pipeline_example.py + test-permissions-example: name: Run Permissions Example runs-on: ubuntu-22.04 diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml index e3e46dd97..118c1c06c 100644 --- a/.github/workflows/search_db_tests.yml +++ b/.github/workflows/search_db_tests.yml @@ -84,6 +84,7 @@ jobs: GRAPH_DATABASE_PROVIDER: 'neo4j' VECTOR_DB_PROVIDER: 'lancedb' DB_PROVIDER: 'sqlite' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username 
}} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} @@ -135,6 +136,7 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} GRAPH_DATABASE_PROVIDER: 'kuzu' VECTOR_DB_PROVIDER: 'pgvector' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' DB_PROVIDER: 'postgres' DB_NAME: 'cognee_db' DB_HOST: '127.0.0.1' @@ -197,6 +199,7 @@ jobs: GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} + ENABLE_BACKEND_ACCESS_CONTROL: 'false' DB_NAME: cognee_db DB_HOST: 127.0.0.1 DB_PORT: 5432 diff --git a/.github/workflows/test_different_operating_systems.yml b/.github/workflows/test_different_operating_systems.yml index 64f1a14f9..02651b474 100644 --- a/.github/workflows/test_different_operating_systems.yml +++ b/.github/workflows/test_different_operating_systems.yml @@ -10,6 +10,10 @@ on: required: false type: string default: '["3.10.x", "3.12.x", "3.13.x"]' + os: + required: false + type: string + default: '["ubuntu-22.04", "macos-15", "windows-latest"]' secrets: LLM_PROVIDER: required: true @@ -40,10 +44,11 @@ jobs: run-unit-tests: name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ubuntu-22.04, macos-15, windows-latest] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -76,10 +81,11 @@ jobs: run-integration-tests: name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -112,10 +118,11 @@ jobs: run-library-test: name: Library test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ 
matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -148,10 +155,11 @@ jobs: run-build-test: name: Build test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -177,10 +185,11 @@ jobs: run-soft-deletion-test: name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out @@ -214,10 +223,11 @@ jobs: run-hard-deletion-test: name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: matrix: python-version: ${{ fromJSON(inputs.python-versions) }} - os: [ ubuntu-22.04, macos-15, windows-latest ] + os: ${{ fromJSON(inputs.os) }} fail-fast: false steps: - name: Check out diff --git a/.github/workflows/test_ollama.yml b/.github/workflows/test_ollama.yml index fbd687319..686545c70 100644 --- a/.github/workflows/test_ollama.yml +++ b/.github/workflows/test_ollama.yml @@ -75,7 +75,7 @@ jobs: { "role": "user", "content": "Whatever I say, answer with Yes." 
} ] }' - curl -X POST http://127.0.0.1:11434/v1/embeddings \ + curl -X POST http://127.0.0.1:11434/api/embed \ -H "Content-Type: application/json" \ -d '{ "model": "avr/sfr-embedding-mistral:latest", @@ -98,7 +98,7 @@ jobs: LLM_MODEL: "phi4" EMBEDDING_PROVIDER: "ollama" EMBEDDING_MODEL: "avr/sfr-embedding-mistral:latest" - EMBEDDING_ENDPOINT: "http://localhost:11434/api/embeddings" + EMBEDDING_ENDPOINT: "http://localhost:11434/api/embed" EMBEDDING_DIMENSIONS: "4096" HUGGINGFACE_TOKENIZER: "Salesforce/SFR-Embedding-Mistral" run: uv run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_suites.yml b/.github/workflows/test_suites.yml index 5c1597a93..be1e354fc 100644 --- a/.github/workflows/test_suites.yml +++ b/.github/workflows/test_suites.yml @@ -1,4 +1,6 @@ name: Test Suites +permissions: + contents: read on: push: @@ -80,12 +82,22 @@ jobs: uses: ./.github/workflows/notebooks_tests.yml secrets: inherit - different-operating-systems-tests: - name: Operating System and Python Tests + different-os-tests-basic: + name: OS and Python Tests Ubuntu needs: [basic-tests, e2e-tests] uses: ./.github/workflows/test_different_operating_systems.yml with: python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]' + os: '["ubuntu-22.04"]' + secrets: inherit + + different-os-tests-extended: + name: OS and Python Tests Extended + needs: [basic-tests, e2e-tests] + uses: ./.github/workflows/test_different_operating_systems.yml + with: + python-versions: '["3.13.x"]' + os: '["macos-15", "windows-latest"]' secrets: inherit # Matrix-based vector database tests @@ -135,7 +147,8 @@ jobs: e2e-tests, graph-db-tests, notebook-tests, - different-operating-systems-tests, + different-os-tests-basic, + different-os-tests-extended, vector-db-tests, example-tests, llm-tests, @@ -155,7 +168,8 @@ jobs: cli-tests, graph-db-tests, notebook-tests, - different-operating-systems-tests, + different-os-tests-basic, + different-os-tests-extended, vector-db-tests, example-tests, 
db-examples-tests, @@ -176,7 +190,8 @@ jobs: "${{ needs.cli-tests.result }}" == "success" && "${{ needs.graph-db-tests.result }}" == "success" && "${{ needs.notebook-tests.result }}" == "success" && - "${{ needs.different-operating-systems-tests.result }}" == "success" && + "${{ needs.different-os-tests-basic.result }}" == "success" && + "${{ needs.different-os-tests-extended.result }}" == "success" && "${{ needs.vector-db-tests.result }}" == "success" && "${{ needs.example-tests.result }}" == "success" && "${{ needs.db-examples-tests.result }}" == "success" && diff --git a/cognee/__init__.py b/cognee/__init__.py index 6e4d2a903..4d150ce4e 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -19,6 +19,7 @@ from .api.v1.add import add from .api.v1.delete import delete from .api.v1.cognify import cognify from .modules.memify import memify +from .modules.run_custom_pipeline import run_custom_pipeline from .api.v1.update import update from .api.v1.config.config import config from .api.v1.datasets.datasets import datasets diff --git a/cognee/api/client.py b/cognee/api/client.py index 6766c12de..19a607ff0 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -39,6 +39,8 @@ from cognee.api.v1.users.routers import ( ) from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION +# Ensure application logging is configured for container stdout/stderr +setup_logging() logger = get_logger() if os.getenv("ENV", "prod") == "prod": @@ -74,6 +76,9 @@ async def lifespan(app: FastAPI): await get_default_user() + # Emit a clear startup message for docker logs + logger.info("Backend server has started") + yield diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index d52de4b4e..f17c9187a 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -4,6 +4,8 @@ from typing import Union from uuid import UUID from cognee.base_config import get_base_config +from 
cognee.infrastructure.databases.vector.config import get_vectordb_context_config +from cognee.infrastructure.databases.graph.config import get_graph_context_config from cognee.infrastructure.databases.utils import get_or_create_dataset_database from cognee.infrastructure.files.storage.config import file_storage_config from cognee.modules.users.methods import get_user @@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None) session_user = ContextVar("session_user", default=None) +vector_dbs_with_multi_user_support = ["lancedb"] +graph_dbs_with_multi_user_support = ["kuzu"] + async def set_session_user_context_variable(user): session_user.set(user) +def multi_user_support_possible(): + graph_db_config = get_graph_context_config() + vector_db_config = get_vectordb_context_config() + return ( + graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support + and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support + ) + + +def backend_access_control_enabled(): + backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None) + if backend_access_control is None: + # If backend access control is not defined in environment variables, + # enable it by default if graph and vector DBs can support it, otherwise disable it + return multi_user_support_possible() + elif backend_access_control.lower() == "true": + # If enabled, ensure that the current graph and vector DBs can support it + multi_user_support = multi_user_support_possible() + if not multi_user_support: + raise EnvironmentError( + "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." 
+ ) + return True + return False + + async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): """ If backend access control is enabled this function will ensure all datasets have their own databases, @@ -40,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ base_config = get_base_config() - if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": + if not backend_access_control_enabled(): return user = await get_user(user_id) diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index d1cf855d7..c54d94f6c 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -133,6 +133,6 @@ def create_vector_engine( else: raise EnvironmentError( - f"Unsupported graph database provider: {vector_db_provider}. " + f"Unsupported vector database provider: {vector_db_provider}. 
" f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}" ) diff --git a/cognee/infrastructure/engine/models/Edge.py b/cognee/infrastructure/engine/models/Edge.py index 5ad9c84dd..59f01a9ab 100644 --- a/cognee/infrastructure/engine/models/Edge.py +++ b/cognee/infrastructure/engine/models/Edge.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator from typing import Optional, Any, Dict @@ -18,9 +18,28 @@ class Edge(BaseModel): # Mixed usage has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item]) + + # With edge_text for rich embedding representation + contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity) """ weight: Optional[float] = None weights: Optional[Dict[str, float]] = None relationship_type: Optional[str] = None properties: Optional[Dict[str, Any]] = None
+    # validate_default=True is required here: pydantic skips field validators when a
+    # field falls back to its default, so ensure_edge_text would never run for an omitted edge_text.
+    edge_text: Optional[str] = Field(default=None, validate_default=True)
+
+    @field_validator("edge_text", mode="before")
+    @classmethod
+    def ensure_edge_text(cls, v, info):
+        """
+        Auto-populate edge_text from relationship_type if not explicitly provided.
+
+        Runs before type validation of edge_text. relationship_type is read from info.data,
+        so it is only picked up because that field is declared above edge_text.
+        """
+        if v is None and info.data.get("relationship_type"):
+            return info.data["relationship_type"]
+        return v
diff --git a/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py new file mode 100644 index 000000000..92d64c156 --- /dev/null +++ b/cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py @@ -0,0 +1,55 @@ +from typing import Optional, List + +from cognee import memify +from cognee.context_global_variables import ( + set_database_global_context_variables, + set_session_user_context_variable, +) +from cognee.exceptions import CogneeValidationError +from cognee.modules.data.methods import get_authorized_existing_datasets +from cognee.shared.logging_utils import get_logger +from 
cognee.modules.pipelines.tasks.task import Task +from cognee.modules.users.models import User +from cognee.tasks.memify import extract_user_sessions, cognify_session + + +logger = get_logger("persist_sessions_in_knowledge_graph") + + +async def persist_sessions_in_knowledge_graph_pipeline( + user: User, + session_ids: Optional[List[str]] = None, + dataset: str = "main_dataset", + run_in_background: bool = False, +): + await set_session_user_context_variable(user) + dataset_to_write = await get_authorized_existing_datasets( + user=user, datasets=[dataset], permission_type="write" + ) + + if not dataset_to_write: + raise CogneeValidationError( + message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}", + log=False, + ) + + await set_database_global_context_variables( + dataset_to_write[0].id, dataset_to_write[0].owner_id + ) + + extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)] + + enrichment_tasks = [ + Task(cognify_session, dataset_id=dataset_to_write[0].id), + ] + + result = await memify( + extraction_tasks=extraction_tasks, + enrichment_tasks=enrichment_tasks, + dataset=dataset_to_write[0].id, + data=[{}], + run_in_background=run_in_background, + ) + + logger.info("Session persistence pipeline completed") + return result diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index 601454802..a9fb08a9e 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -3,6 +3,7 @@ from typing import List, Union from pydantic import BaseModel, Field from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.Edge import Edge from cognee.modules.data.processing.document_types import Document from cognee.modules.engine.models import Entity from cognee.tasks.temporal_graph.models import Event @@ -24,7 +25,6 @@ class 
DocumentChunk(DataPoint): - cut_type: The type of cut that defined this chunk. - is_part_of: The document to which this chunk belongs. - contains: A list of entities or events contained within the chunk (default is None). - - last_accessed_at: The timestamp of the last time the chunk was accessed. - metadata: A dictionary to hold meta information related to the chunk, including index fields. """ @@ -34,5 +34,5 @@ class DocumentChunk(DataPoint): chunk_index: int cut_type: str is_part_of: Document - contains: List[Union[Entity, Event]] = None + contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None metadata: dict = {"index_fields": ["text"]} diff --git a/cognee/modules/chunking/text_chunker_with_overlap.py b/cognee/modules/chunking/text_chunker_with_overlap.py new file mode 100644 index 000000000..4b9c23079 --- /dev/null +++ b/cognee/modules/chunking/text_chunker_with_overlap.py @@ -0,0 +1,160 @@
+from typing import Callable, Optional
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.shared.logging_utils import get_logger
+from cognee.tasks.chunks import chunk_by_paragraph
+from cognee.modules.chunking.Chunker import Chunker
+from .models.DocumentChunk import DocumentChunk
+
+logger = get_logger()
+
+
+class TextChunkerWithOverlap(Chunker):
+    """
+    Chunker that accumulates small pieces of text into DocumentChunks of up to
+    max_chunk_size, carrying the tail of each emitted chunk over into the next one.
+    """
+
+    def __init__(
+        self,
+        document,
+        get_text: Callable,
+        max_chunk_size: int,
+        chunk_overlap_ratio: float = 0.0,
+        get_chunk_data: Optional[Callable] = None,
+    ):
+        """
+        Args:
+            document: Document the produced chunks are attached to.
+            get_text: Called without arguments by read(); must return an async iterable of text.
+            max_chunk_size: Size budget of a chunk, in the units of chunk_data["chunk_size"].
+            chunk_overlap_ratio: Fraction of max_chunk_size carried over between chunks.
+            get_chunk_data: Optional splitter turning text into an iterable of dicts with
+                "text", "chunk_size", "cut_type" and "chunk_id" keys. When omitted,
+                chunk_by_paragraph is used.
+
+        Raises:
+            ValueError: If chunk_overlap_ratio is not in the range [0, 1).
+        """
+        if not 0 <= chunk_overlap_ratio < 1:
+            # With an overlap of max_chunk_size or more, a full chunk's worth of data would be
+            # carried over on each emit, leaving no room for new data within the budget.
+            raise ValueError(
+                f"chunk_overlap_ratio must be in the range [0, 1), got {chunk_overlap_ratio}"
+            )
+
+        super().__init__(document, get_text, max_chunk_size)
+        self._accumulated_chunk_data = []
+        self._accumulated_size = 0
+        self.chunk_overlap_ratio = chunk_overlap_ratio
+        self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
+
+        if get_chunk_data is not None:
+            self.get_chunk_data = get_chunk_data
+        elif chunk_overlap_ratio > 0:
+            # Pieces are capped at half the overlap size (presumably so that at least
+            # one whole piece fits in the overlap window - confirm with the author).
+            paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
+            self.get_chunk_data = lambda text: chunk_by_paragraph(
+                text, paragraph_max_size, batch_paragraphs=True
+            )
+        else:
+            self.get_chunk_data = lambda text: chunk_by_paragraph(
+                text, self.max_chunk_size, batch_paragraphs=True
+            )
+
+    def _accumulation_overflows(self, chunk_data):
+        """Check if adding chunk_data would exceed max_chunk_size."""
+        return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
+
+    def _accumulate_chunk_data(self, chunk_data):
+        """Add chunk_data to the current accumulation."""
+        self._accumulated_chunk_data.append(chunk_data)
+        self._accumulated_size += chunk_data["chunk_size"]
+
+    def _clear_accumulation(self):
+        """Reset accumulation, keeping trailing chunk_data that fits in chunk_overlap."""
+        if self.chunk_overlap == 0:
+            self._accumulated_chunk_data = []
+            self._accumulated_size = 0
+            return
+
+        # Walk backwards from the newest piece and stop at the first one that no longer
+        # fits; append + reverse avoids the quadratic cost of insert(0, ...).
+        overlap_chunk_data = []
+        overlap_size = 0
+
+        for chunk_data in reversed(self._accumulated_chunk_data):
+            if overlap_size + chunk_data["chunk_size"] > self.chunk_overlap:
+                break
+            overlap_chunk_data.append(chunk_data)
+            overlap_size += chunk_data["chunk_size"]
+
+        overlap_chunk_data.reverse()
+        self._accumulated_chunk_data = overlap_chunk_data
+        self._accumulated_size = overlap_size
+
+    def _create_chunk(self, text, size, cut_type, chunk_id=None):
+        """Create a DocumentChunk with standard metadata."""
+        try:
+            return DocumentChunk(
+                id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
+                text=text,
+                chunk_size=size,
+                is_part_of=self.document,
+                chunk_index=self.chunk_index,
+                cut_type=cut_type,
+                contains=[],
+                metadata={"index_fields": ["text"]},
+            )
+        except Exception as e:
+            logger.error(e)
+            # Re-raise the active exception unchanged.
+            raise
+
+    def _create_chunk_from_accumulation(self):
+        """Create a DocumentChunk from current accumulated chunk_data."""
+        chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
+        return self._create_chunk(
+            text=chunk_text,
+            size=self._accumulated_size,
+            cut_type=self._accumulated_chunk_data[-1]["cut_type"],
+        )
+
+    def _emit_chunk(self, chunk_data):
+        """
+        Emit a chunk when accumulation overflows.
+
+        chunk_data is the piece that did not fit; it is added to the next accumulation
+        (after any kept overlap), or emitted on its own if nothing was accumulated.
+        """
+        if len(self._accumulated_chunk_data) > 0:
+            chunk = self._create_chunk_from_accumulation()
+            self._clear_accumulation()
+            self._accumulate_chunk_data(chunk_data)
+        else:
+            # Handle single chunk_data exceeding max_chunk_size
+            chunk = self._create_chunk(
+                text=chunk_data["text"],
+                size=chunk_data["chunk_size"],
+                cut_type=chunk_data["cut_type"],
+                chunk_id=chunk_data["chunk_id"],
+            )
+
+        self.chunk_index += 1
+        return chunk
+
+    async def read(self):
+        """Yield DocumentChunks for the document text, flushing the remainder at the end."""
+        async for content_text in self.get_text():
+            for chunk_data in self.get_chunk_data(content_text):
+                if not self._accumulation_overflows(chunk_data):
+                    self._accumulate_chunk_data(chunk_data)
+                    continue
+
+                yield self._emit_chunk(chunk_data)
+
+        if len(self._accumulated_chunk_data) == 0:
+            return
+
+        yield self._create_chunk_from_accumulation()
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 9703928f0..cb7562422 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -171,8 +171,10 @@ class CogneeGraph(CogneeAbstractGraph): embedding_map = {result.payload["text"]: result.score for result in edge_distances} for edge in self.edges: - relationship_type = edge.attributes.get("relationship_type") - distance = embedding_map.get(relationship_type, None) + edge_key = edge.attributes.get("edge_text") or edge.attributes.get( + "relationship_type" + ) + distance = embedding_map.get(edge_key, None) if distance is not None: edge.attributes["vector_distance"] = distance diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index 3b01f5af4..c68eb494d 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -1,5 +1,6 @@ from typing import Optional +from cognee.infrastructure.engine.models.Edge import Edge from cognee.modules.chunking.models import 
DocumentChunk from cognee.modules.engine.models import Entity, EntityType from cognee.modules.engine.utils import ( @@ -243,10 +244,26 @@ def _process_graph_nodes( ontology_relationships, ) - # Add entity to data chunk if data_chunk.contains is None: data_chunk.contains = [] - data_chunk.contains.append(entity_node) + + edge_text = "; ".join( + [ + "relationship_name: contains", + f"entity_name: {entity_node.name}", + f"entity_description: {entity_node.description}", + ] + ) + + data_chunk.contains.append( + ( + Edge( + relationship_type="contains", + edge_text=edge_text, + ), + entity_node, + ) + ) def _process_graph_edges( diff --git a/cognee/modules/graph/utils/resolve_edges_to_text.py b/cognee/modules/graph/utils/resolve_edges_to_text.py index eb5bedd2c..5deb13ba8 100644 --- a/cognee/modules/graph/utils/resolve_edges_to_text.py +++ b/cognee/modules/graph/utils/resolve_edges_to_text.py @@ -1,71 +1,70 @@ +import string from typing import List +from collections import Counter + from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS + + +def _get_top_n_frequent_words( + text: str, stop_words: set = None, top_n: int = 3, separator: str = ", " +) -> str: + """Concatenates the top N frequent words in text.""" + if stop_words is None: + stop_words = DEFAULT_STOP_WORDS + + words = [word.lower().strip(string.punctuation) for word in text.split()] + words = [word for word in words if word and word not in stop_words] + + top_words = [word for word, freq in Counter(words).most_common(top_n)] + return separator.join(top_words) + + +def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str: + """Creates a title by combining first words with most frequent words from the text.""" + first_words = text.split()[:first_n_words] + top_words = _get_top_n_frequent_words(text, top_n=top_n_words) + return f"{' '.join(first_words)}... 
[{top_words}]" + + +def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict: + """Creates a dictionary of nodes with their names and content.""" + nodes = {} + + for edge in retrieved_edges: + for node in (edge.node1, edge.node2): + if node.id in nodes: + continue + + text = node.attributes.get("text") + if text: + name = _create_title_from_text(text) + content = text + else: + name = node.attributes.get("name", "Unnamed Node") + content = node.attributes.get("description", name) + + nodes[node.id] = {"node": node, "name": name, "content": content} + + return nodes async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str: - """ - Converts retrieved graph edges into a human-readable string format. + """Converts retrieved graph edges into a human-readable string format.""" + nodes = _extract_nodes_from_edges(retrieved_edges) - Parameters: - ----------- - - - retrieved_edges (list): A list of edges retrieved from the graph. - - Returns: - -------- - - - str: A formatted string representation of the nodes and their connections. 
- """ - - def _get_nodes(retrieved_edges: List[Edge]) -> dict: - def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str: - def _top_n_words(text, stop_words=None, top_n=3, separator=", "): - """Concatenates the top N frequent words in text.""" - if stop_words is None: - from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS - - stop_words = DEFAULT_STOP_WORDS - - import string - - words = [word.lower().strip(string.punctuation) for word in text.split()] - - if stop_words: - words = [word for word in words if word and word not in stop_words] - - from collections import Counter - - top_words = [word for word, freq in Counter(words).most_common(top_n)] - - return separator.join(top_words) - - """Creates a title, by combining first words with most frequent words from the text.""" - first_words = text.split()[:first_n_words] - top_words = _top_n_words(text, top_n=first_n_words) - return f"{' '.join(first_words)}... [{top_words}]" - - """Creates a dictionary of nodes with their names and content.""" - nodes = {} - for edge in retrieved_edges: - for node in (edge.node1, edge.node2): - if node.id not in nodes: - text = node.attributes.get("text") - if text: - name = _get_title(text) - content = text - else: - name = node.attributes.get("name", "Unnamed Node") - content = node.attributes.get("description", name) - nodes[node.id] = {"node": node, "name": name, "content": content} - return nodes - - nodes = _get_nodes(retrieved_edges) node_section = "\n".join( f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n" for info in nodes.values() ) - connection_section = "\n".join( - f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}" - for edge in retrieved_edges - ) + + connections = [] + for edge in retrieved_edges: + source_name = nodes[edge.node1.id]["name"] + target_name = nodes[edge.node2.id]["name"] + edge_label = 
edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + connections.append(f"{source_name} --[{edge_label}]--> {target_name}") + + connection_section = "\n".join(connections) + return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}" diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 1ef7545c2..f8bdbb97d 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -71,7 +71,7 @@ async def get_memory_fragment( await memory_fragment.project_graph_from_db( graph_engine, node_properties_to_project=properties_to_project, - edge_properties_to_project=["relationship_name"], + edge_properties_to_project=["relationship_name", "edge_text"], node_type=node_type, node_name=node_name, ) diff --git a/cognee/modules/run_custom_pipeline/__init__.py b/cognee/modules/run_custom_pipeline/__init__.py new file mode 100644 index 000000000..2d30e2e0c --- /dev/null +++ b/cognee/modules/run_custom_pipeline/__init__.py @@ -0,0 +1 @@ +from .run_custom_pipeline import run_custom_pipeline diff --git a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py new file mode 100644 index 000000000..d3df1c060 --- /dev/null +++ b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py @@ -0,0 +1,71 @@
+from typing import Union, Optional, List, Type, Any
+from uuid import UUID
+
+from cognee.shared.logging_utils import get_logger
+
+from cognee.modules.pipelines import run_pipeline
+from cognee.modules.pipelines.tasks.task import Task
+from cognee.modules.users.models import User
+from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
+
+logger = get_logger()
+
+
+async def run_custom_pipeline(
+    tasks: Union[List[Task], List[str]] = None,
+    data: Any = None,
+    dataset: Union[str, UUID] = "main_dataset",
+    user: User = None,
+    vector_db_config: Optional[dict] = None,
+    graph_db_config: Optional[dict] = None,
+    data_per_batch: int = 20,
+    run_in_background: bool = False,
+    pipeline_name: str = "custom_pipeline",
+):
+    """
+    Custom pipeline in Cognee, can work with already built graphs. Data needs to be provided which can be processed
+    with provided tasks.
+
+    Provided tasks and data will be arranged to run the Cognee pipeline and execute graph enrichment/creation.
+    What is extracted or created is determined entirely by the provided tasks; incremental loading is
+    always disabled for custom pipelines.
+
+    Args:
+        tasks: List of Cognee Tasks to execute. None is treated as an empty task list.
+        data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used.
+              Data provided here will be forwarded to the first extraction task in the pipeline as input.
+        dataset: Dataset name or dataset uuid to process.
+        user: User context for authentication and data access. Uses default if None.
+        vector_db_config: Custom vector database configuration for embeddings storage.
+        graph_db_config: Custom graph database configuration for relationship storage.
+        data_per_batch: Number of data items to be processed in parallel.
+        run_in_background: If True, starts processing asynchronously and returns immediately.
+            If False, waits for completion before returning.
+            Background mode recommended for large datasets (>100MB).
+            Use pipeline_run_id from return value to monitor progress.
+        pipeline_name: Name passed through to the pipeline executor for this run.
+
+    Returns:
+        The result of the pipeline executor selected by run_in_background.
+    """
+
+    # `tasks` defaults to None, and unpacking None (`[*tasks]`) raises TypeError,
+    # so fall back to an empty list. Copying also keeps the caller's list untouched.
+    custom_tasks = list(tasks) if tasks is not None else []
+
+    # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
+    pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
+
+    # Run the run_pipeline in the background or blocking based on executor
+    return await pipeline_executor_func(
+        pipeline=run_pipeline,
+        tasks=custom_tasks,
+        user=user,
+        data=data,
+        datasets=dataset,
+        vector_db_config=vector_db_config,
+        graph_db_config=graph_db_config,
+        incremental_loading=False,
+        data_per_batch=data_per_batch,
+        pipeline_name=pipeline_name,
+    )
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index aab004924..5e465b239 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -1,4 +1,3 @@ -import os import json import asyncio from uuid import UUID @@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine from cognee.shared.logging_utils import get_logger from cognee.shared.utils import send_telemetry from cognee.context_global_variables import set_database_global_context_variables +from cognee.context_global_variables import backend_access_control_enabled from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -74,7 +74,7 @@ async def search( ) # Use search function filtered by permissions if access control is enabled - if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": + if backend_access_control_enabled(): search_results = await authorized_search( query_type=query_type, query_text=query_text, @@ -156,7 +156,7 @@ async def search( ) else: # This is for maintaining backwards compatibility - if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": + if backend_access_control_enabled(): return_value = [] for search_result in 
search_results: prepared_search_results = await prepare_search_result(search_result) diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index d78215892..d6d701737 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -5,6 +5,7 @@ from ..models import User from ..get_fastapi_users import get_fastapi_users from .get_default_user import get_default_user from cognee.shared.logging_utils import get_logger +from cognee.context_global_variables import backend_access_control_enabled logger = get_logger("get_authenticated_user") @@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user") # Check environment variable to determine authentication requirement REQUIRE_AUTHENTICATION = ( os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" - or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true" + or backend_access_control_enabled() ) fastapi_users = get_fastapi_users() diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py index 773545f8e..9e3940617 100644 --- a/cognee/modules/users/methods/get_default_user.py +++ b/cognee/modules/users/methods/get_default_user.py @@ -10,7 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.methods.create_default_user import create_default_user -async def get_default_user() -> SimpleNamespace: +async def get_default_user() -> User: db_engine = get_relational_engine() base_config = get_base_config() default_email = base_config.default_user_email or "default_user@example.com" diff --git a/cognee/tasks/memify/__init__.py b/cognee/tasks/memify/__init__.py index 692bac443..7e590ed47 100644 --- a/cognee/tasks/memify/__init__.py +++ b/cognee/tasks/memify/__init__.py @@ -1,2 +1,4 @@ from .extract_subgraph import extract_subgraph from 
.extract_subgraph_chunks import extract_subgraph_chunks +from .cognify_session import cognify_session +from .extract_user_sessions import extract_user_sessions diff --git a/cognee/tasks/memify/cognify_session.py b/cognee/tasks/memify/cognify_session.py new file mode 100644 index 000000000..f53f9afb1 --- /dev/null +++ b/cognee/tasks/memify/cognify_session.py @@ -0,0 +1,41 @@ +import cognee + +from cognee.exceptions import CogneeValidationError, CogneeSystemError +from cognee.shared.logging_utils import get_logger + +logger = get_logger("cognify_session") + + +async def cognify_session(data, dataset_id=None): + """ + Process and cognify session data into the knowledge graph. + + Adds session content to cognee with a dedicated "user_sessions" node set, + then triggers the cognify pipeline to extract entities and relationships + from the session data. + + Args: + data: Session string containing Question, Context, and Answer information. + dataset_name: Name of dataset. + + Raises: + CogneeValidationError: If data is None or empty. + CogneeSystemError: If cognee operations fail. 
+ """ + try: + if not data or (isinstance(data, str) and not data.strip()): + logger.warning("Empty session data provided to cognify_session task, skipping") + raise CogneeValidationError(message="Session data cannot be empty", log=False) + + logger.info("Processing session data for cognification") + + await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"]) + logger.debug("Session data added to cognee with node_set: user_sessions") + await cognee.cognify(datasets=[dataset_id]) + logger.info("Session data successfully cognified") + + except CogneeValidationError: + raise + except Exception as e: + logger.error(f"Error cognifying session data: {str(e)}") + raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False) diff --git a/cognee/tasks/memify/extract_user_sessions.py b/cognee/tasks/memify/extract_user_sessions.py new file mode 100644 index 000000000..9779a363e --- /dev/null +++ b/cognee/tasks/memify/extract_user_sessions.py @@ -0,0 +1,73 @@ +from typing import Optional, List + +from cognee.context_global_variables import session_user +from cognee.exceptions import CogneeSystemError +from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine +from cognee.shared.logging_utils import get_logger +from cognee.modules.users.models import User + +logger = get_logger("extract_user_sessions") + + +async def extract_user_sessions( + data, + session_ids: Optional[List[str]] = None, +): + """ + Extract Q&A sessions for the current user from cache. + + Retrieves all Q&A triplets from specified session IDs and yields them + as formatted strings combining question, context, and answer. + + Args: + data: Data passed from memify. If empty dict ({}), no external data is provided. + session_ids: Optional list of specific session IDs to extract. + + Yields: + String containing session ID and all Q&A pairs formatted. + + Raises: + CogneeSystemError: If cache engine is unavailable or extraction fails. 
+ """ + try: + if not data or data == [{}]: + logger.info("Fetching session metadata for current user") + + user: User = session_user.get() + if not user: + raise CogneeSystemError(message="No authenticated user found in context", log=False) + + user_id = str(user.id) + + cache_engine = get_cache_engine() + if cache_engine is None: + raise CogneeSystemError( + message="Cache engine not available for session extraction, please enable caching in order to have sessions to save", + log=False, + ) + + if session_ids: + for session_id in session_ids: + try: + qa_data = await cache_engine.get_all_qas(user_id, session_id) + if qa_data: + logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs") + session_string = f"Session ID: {session_id}\n\n" + for qa_pair in qa_data: + question = qa_pair.get("question", "") + answer = qa_pair.get("answer", "") + session_string += f"Question: {question}\n\nAnswer: {answer}\n\n" + yield session_string + except Exception as e: + logger.warning(f"Failed to extract session {session_id}: {str(e)}") + continue + else: + logger.info( + "No specific session_ids provided. Please specify which sessions to extract." + ) + + except CogneeSystemError: + raise + except Exception as e: + logger.error(f"Error extracting user sessions: {str(e)}") + raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False) diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 902789c80..b0ec3a5b4 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -8,47 +8,58 @@ logger = get_logger("index_data_points") async def index_data_points(data_points: list[DataPoint]): - created_indexes = {} - index_points = {} + """Index data points in the vector engine by creating embeddings for specified fields. + + Process: + 1. Groups data points into a nested dict: {type_name: {field_name: [points]}} + 2. 
Creates vector indexes for each (type, field) combination on first encounter + 3. Batches points per (type, field) and creates async indexing tasks + 4. Executes all indexing tasks in parallel for efficient embedding generation + + Args: + data_points: List of DataPoint objects to index. Each DataPoint's metadata must + contain an 'index_fields' list specifying which fields to embed. + + Returns: + The original data_points list. + """ + data_points_by_type = {} vector_engine = get_vector_engine() for data_point in data_points: data_point_type = type(data_point) + type_name = data_point_type.__name__ for field_name in data_point.metadata["index_fields"]: if getattr(data_point, field_name, None) is None: continue - index_name = f"{data_point_type.__name__}_{field_name}" + if type_name not in data_points_by_type: + data_points_by_type[type_name] = {} - if index_name not in created_indexes: - await vector_engine.create_vector_index(data_point_type.__name__, field_name) - created_indexes[index_name] = True - - if index_name not in index_points: - index_points[index_name] = [] + if field_name not in data_points_by_type[type_name]: + await vector_engine.create_vector_index(type_name, field_name) + data_points_by_type[type_name][field_name] = [] indexed_data_point = data_point.model_copy() indexed_data_point.metadata["index_fields"] = [field_name] - index_points[index_name].append(indexed_data_point) + data_points_by_type[type_name][field_name].append(indexed_data_point) - tasks: list[asyncio.Task] = [] batch_size = vector_engine.embedding_engine.get_batch_size() - for index_name_and_field, points in index_points.items(): - first = index_name_and_field.index("_") - index_name = index_name_and_field[:first] - field_name = index_name_and_field[first + 1 :] + batches = ( + (type_name, field_name, points[i : i + batch_size]) + for type_name, fields in data_points_by_type.items() + for field_name, points in fields.items() + for i in range(0, len(points), batch_size) + ) - # 
Create embedding requests per batch to run in parallel later - for i in range(0, len(points), batch_size): - batch = points[i : i + batch_size] - tasks.append( - asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch)) - ) + tasks = [ + asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points)) + for type_name, field_name, batch_points in batches + ] - # Run all embedding requests in parallel await asyncio.gather(*tasks) return data_points diff --git a/cognee/tasks/storage/index_graph_edges.py b/cognee/tasks/storage/index_graph_edges.py index 4fa8cfc75..03b5a25a5 100644 --- a/cognee/tasks/storage/index_graph_edges.py +++ b/cognee/tasks/storage/index_graph_edges.py @@ -1,17 +1,44 @@ -import asyncio +from collections import Counter +from typing import Optional, Dict, Any, List, Tuple, Union from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.shared.logging_utils import get_logger -from collections import Counter -from typing import Optional, Dict, Any, List, Tuple, Union -from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.graph.models.EdgeType import EdgeType from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData +from cognee.tasks.storage.index_data_points import index_data_points logger = get_logger() +def _get_edge_text(item: dict) -> str: + """Extract edge text for embedding - prefers edge_text field with fallback.""" + if "edge_text" in item: + return item["edge_text"] + + if "relationship_name" in item: + return item["relationship_name"] + + return "" + + +def create_edge_type_datapoints(edges_data) -> list[EdgeType]: + """Transform raw edge data into EdgeType datapoints.""" + edge_texts = [ + _get_edge_text(item) + for edge in edges_data + for item in edge + if isinstance(item, dict) and "relationship_name" in item + ] + + edge_types = 
Counter(edge_texts) + + return [ + EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count) + for text, count in edge_types.items() + ] + + async def index_graph_edges( edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None, ): @@ -23,24 +50,17 @@ async def index_graph_edges( the `relationship_name` field. Steps: - 1. Initialize the vector engine and graph engine. - 2. Retrieve graph edge data and count relationship types (`relationship_name`). - 3. Create vector indexes for `relationship_name` if they don't exist. - 4. Transform the counted relationships into `EdgeType` objects. - 5. Index the transformed data points in the vector engine. + 1. Initialize the graph engine if needed and retrieve edge data. + 2. Transform edge data into EdgeType datapoints. + 3. Index the EdgeType datapoints using the standard indexing function. Raises: - RuntimeError: If initialization of the vector engine or graph engine fails. + RuntimeError: If initialization of the graph engine fails. 
Returns: None """ try: - created_indexes = {} - index_points = {} - - vector_engine = get_vector_engine() - if edges_data is None: graph_engine = await get_graph_engine() _, edges_data = await graph_engine.get_graph_data() @@ -51,47 +71,7 @@ async def index_graph_edges( logger.error("Failed to initialize engines: %s", e) raise RuntimeError("Initialization error") from e - edge_types = Counter( - item.get("relationship_name") - for edge in edges_data - for item in edge - if isinstance(item, dict) and "relationship_name" in item - ) - - for text, count in edge_types.items(): - edge = EdgeType( - id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count - ) - data_point_type = type(edge) - - for field_name in edge.metadata["index_fields"]: - index_name = f"{data_point_type.__name__}.{field_name}" - - if index_name not in created_indexes: - await vector_engine.create_vector_index(data_point_type.__name__, field_name) - created_indexes[index_name] = True - - if index_name not in index_points: - index_points[index_name] = [] - - indexed_data_point = edge.model_copy() - indexed_data_point.metadata["index_fields"] = [field_name] - index_points[index_name].append(indexed_data_point) - - # Get maximum batch size for embedding model - batch_size = vector_engine.embedding_engine.get_batch_size() - tasks: list[asyncio.Task] = [] - - for index_name, indexable_points in index_points.items(): - index_name, field_name = index_name.split(".") - - # Create embedding tasks to run in parallel later - for start in range(0, len(indexable_points), batch_size): - batch = indexable_points[start : start + batch_size] - - tasks.append(vector_engine.index_data_points(index_name, field_name, batch)) - - # Start all embedding tasks and wait for completion - await asyncio.gather(*tasks) + edge_type_datapoints = create_edge_type_datapoints(edges_data) + await index_data_points(edge_type_datapoints) return None diff --git a/cognee/tests/test_add_docling_document.py 
b/cognee/tests/test_add_docling_document.py index 2c82af66f..c5aa4e9d1 100644 --- a/cognee/tests/test_add_docling_document.py +++ b/cognee/tests/test_add_docling_document.py @@ -39,12 +39,12 @@ async def main(): answer = await cognee.search("Do programmers change light bulbs?") assert len(answer) != 0 - lowercase_answer = answer[0].lower() + lowercase_answer = answer[0]["search_result"][0].lower() assert ("no" in lowercase_answer) or ("none" in lowercase_answer) answer = await cognee.search("What colours are there in the presentation table?") assert len(answer) != 0 - lowercase_answer = answer[0].lower() + lowercase_answer = answer[0]["search_result"][0].lower() assert ( ("red" in lowercase_answer) and ("blue" in lowercase_answer) diff --git a/cognee/tests/test_conversation_history.py b/cognee/tests/test_conversation_history.py index 30bb54ef1..783baf563 100644 --- a/cognee/tests/test_conversation_history.py +++ b/cognee/tests/test_conversation_history.py @@ -16,9 +16,11 @@ import cognee import pathlib from cognee.infrastructure.databases.cache import get_cache_engine +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.search.types import SearchType from cognee.shared.logging_utils import get_logger from cognee.modules.users.methods import get_default_user +from collections import Counter logger = get_logger() @@ -54,10 +56,10 @@ async def main(): """DataCo is a data analytics company. 
They help businesses make sense of their data.""" ) - await cognee.add(text_1, dataset_name) - await cognee.add(text_2, dataset_name) + await cognee.add(data=text_1, dataset_name=dataset_name) + await cognee.add(data=text_2, dataset_name=dataset_name) - await cognee.cognify([dataset_name]) + await cognee.cognify(datasets=[dataset_name]) user = await get_default_user() @@ -188,7 +190,6 @@ async def main(): f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}" ) - # Verify saved history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10) our_qa_summary = [ h for h in history_summary if h["question"] == "What are the key points about TechCorp?" @@ -228,6 +229,46 @@ async def main(): assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix" assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix" + from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import ( + persist_sessions_in_knowledge_graph_pipeline, + ) + + logger.info("Starting persist_sessions_in_knowledge_graph tests") + + await persist_sessions_in_knowledge_graph_pipeline( + user=user, + session_ids=[session_id_1, session_id_2], + dataset=dataset_name, + run_in_background=False, + ) + + graph_engine = await get_graph_engine() + graph = await graph_engine.get_graph_data() + + type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) + + "Tests the correct number of NodeSet nodes after session persistence" + assert type_counts.get("NodeSet", 0) == 1, ( + f"Number of NodeSets in the graph is incorrect, found {type_counts.get('NodeSet', 0)} but there should be exactly 1." 
+ ) + + "Tests the correct number of DocumentChunk nodes after session persistence" + assert type_counts.get("DocumentChunk", 0) == 4, ( + f"Number of DocumentChunk ndoes in the graph is incorrect, found {type_counts.get('DocumentChunk', 0)} but there should be exactly 4 (2 original documents, 2 sessions)." + ) + + from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine + + vector_engine = get_vector_engine() + collection_size = await vector_engine.search( + collection_name="DocumentChunk_text", + query_text="test", + limit=1000, + ) + assert len(collection_size) == 4, ( + f"DocumentChunk_text collection should have exactly 4 embeddings, found {len(collection_size)}" + ) + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/cognee/tests/test_edge_ingestion.py b/cognee/tests/test_edge_ingestion.py index 5b23f7819..0d1407fab 100755 --- a/cognee/tests/test_edge_ingestion.py +++ b/cognee/tests/test_edge_ingestion.py @@ -52,6 +52,33 @@ async def test_edge_ingestion(): edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) + "Tests edge_text presence and format" + contains_edges = [edge for edge in graph[1] if edge[2] == "contains"] + assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification" + + edge_properties = contains_edges[0][3] + assert "edge_text" in edge_properties, "Expected edge_text in edge properties" + + edge_text = edge_properties["edge_text"] + assert "relationship_name: contains" in edge_text, ( + f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}" + ) + assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}" + assert "entity_description:" in edge_text, ( + f"Expected 'entity_description:' in edge_text, got: {edge_text}" + ) + + all_edge_texts = [ + edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3] + ] + expected_entities = ["dave", "ana", "bob", 
"dexter", "apples", "cognee"] + found_entity = any( + any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts + ) + assert found_entity, ( + f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}" + ) + "Tests the presence of basic nested edges" for basic_nested_edge in basic_nested_edges: assert edge_type_counts.get(basic_nested_edge, 0) >= 1, ( diff --git a/cognee/tests/test_feedback_enrichment.py b/cognee/tests/test_feedback_enrichment.py index 02d90db32..378cb0e45 100644 --- a/cognee/tests/test_feedback_enrichment.py +++ b/cognee/tests/test_feedback_enrichment.py @@ -133,7 +133,7 @@ async def main(): extraction_tasks=extraction_tasks, enrichment_tasks=enrichment_tasks, data=[{}], - dataset="feedback_enrichment_test_memify", + dataset=dataset_name, ) nodes_after, edges_after = await graph_engine.get_graph_data() diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 81f81ee61..893b836c0 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -90,15 +90,17 @@ async def main(): ) search_results = await cognee.search( - query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?" 
+ query_type=SearchType.GRAPH_COMPLETION, + query_text="What information do you contain?", + dataset_ids=[pipeline_run_obj.dataset_id], ) - assert "Mark" in search_results[0], ( + assert "Mark" in search_results[0]["search_result"][0], ( "Failed to update document, no mention of Mark in search results" ) - assert "Cindy" in search_results[0], ( + assert "Cindy" in search_results[0]["search_result"][0], ( "Failed to update document, no mention of Cindy in search results" ) - assert "Artificial intelligence" not in search_results[0], ( + assert "Artificial intelligence" not in search_results[0]["search_result"][0], ( "Failed to update document, Artificial intelligence still mentioned in search results" ) diff --git a/cognee/tests/test_load.py b/cognee/tests/test_load.py new file mode 100644 index 000000000..b38466bc7 --- /dev/null +++ b/cognee/tests/test_load.py @@ -0,0 +1,62 @@ +import os +import pathlib +import asyncio +import time + +import cognee +from cognee.modules.search.types import SearchType +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +async def process_and_search(num_of_searches): + start_time = time.time() + + await cognee.cognify() + + await asyncio.gather( + *[ + cognee.search( + query_text="Tell me about the document", query_type=SearchType.GRAPH_COMPLETION + ) + for _ in range(num_of_searches) + ] + ) + + end_time = time.time() + + return end_time - start_time + + +async def main(): + data_directory_path = os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_load") + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_load") + cognee.config.system_root_directory(cognee_directory_path) + + num_of_pdfs = 10 + num_of_reps = 5 + upper_boundary_minutes = 10 + average_minutes = 8 + + recorded_times = [] + for _ in range(num_of_reps): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) 
+ + s3_input = "s3://cognee-test-load-s3-bucket" + await cognee.add(s3_input) + + recorded_times.append(await process_and_search(num_of_pdfs)) + + average_recorded_time = sum(recorded_times) / len(recorded_times) + + assert average_recorded_time <= average_minutes * 60 + + assert all(rec_time <= upper_boundary_minutes * 60 for rec_time in recorded_times) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 2b69ce854..4557e9e2f 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -27,6 +27,9 @@ def normalize_node_name(node_name: str) -> str: async def setup_test_db(): + # Disable backend access control to migrate relational data + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index e24abd0f5..bd11dc62e 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -146,7 +146,13 @@ async def main(): assert len(search_results) == 1, ( f"{name}: expected single-element list, got {len(search_results)}" ) - text = search_results[0] + + from cognee.context_global_variables import backend_access_control_enabled + + if backend_access_control_enabled(): + text = search_results[0]["search_result"][0] + else: + text = search_results[0] assert isinstance(text, str), f"{name}: element should be a string" assert text.strip(), f"{name}: string should not be empty" assert "netherlands" in text.lower(), ( diff --git a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py index 2eabee91a..6cc37ef38 100644 --- a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py +++ b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py @@ 
-1,3 +1,4 @@ +import os import pytest from unittest.mock import patch, AsyncMock, MagicMock from uuid import uuid4 @@ -5,8 +6,6 @@ from fastapi.testclient import TestClient from types import SimpleNamespace import importlib -from cognee.api.client import app - # Fixtures for reuse across test classes @pytest.fixture @@ -32,6 +31,10 @@ def mock_authenticated_user(): ) +# To turn off authentication we need to set the environment variable before importing the module +# Also both require_authentication and backend access control must be false +os.environ["REQUIRE_AUTHENTICATION"] = "false" +os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") @@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints: @pytest.fixture def client(self): + from cognee.api.client import app + """Create a test client.""" return TestClient(app) @@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior: @pytest.fixture def client(self): + from cognee.api.client import app + return TestClient(app) @pytest.mark.parametrize( @@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling: @pytest.fixture def client(self): + from cognee.api.client import app + return TestClient(app) @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) @@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling: # The exact error message may vary depending on the actual database connection # The important thing is that we get a 500 error when user creation fails - def test_current_environment_configuration(self): + def test_current_environment_configuration(self, client): """Test that current environment configuration is working properly.""" # This tests the actual module state without trying to change it from cognee.modules.users.methods.get_authenticated_user import ( diff --git a/cognee/tests/unit/infrastructure/databases/test_index_data_points.py 
b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py new file mode 100644 index 000000000..21a5695de --- /dev/null +++ b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py @@ -0,0 +1,27 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from cognee.tasks.storage.index_data_points import index_data_points +from cognee.infrastructure.engine import DataPoint + + +class TestDataPoint(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + +@pytest.mark.asyncio +async def test_index_data_points_calls_vector_engine(): + """Test that index_data_points creates vector index and indexes data.""" + data_points = [TestDataPoint(name="test1")] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100) + + with patch.dict( + index_data_points.__globals__, + {"get_vector_engine": lambda: mock_vector_engine}, + ): + await index_data_points(data_points) + + assert mock_vector_engine.create_vector_index.await_count >= 1 + assert mock_vector_engine.index_data_points.await_count >= 1 diff --git a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py index 48bbc53e3..cee0896c2 100644 --- a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +++ b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py @@ -5,8 +5,7 @@ from cognee.tasks.storage.index_graph_edges import index_graph_edges @pytest.mark.asyncio async def test_index_graph_edges_success(): - """Test that index_graph_edges uses the index datapoints and creates vector index.""" - # Create the mocks for the graph and vector engines. 
+ """Test that index_graph_edges retrieves edges and delegates to index_data_points.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = ( None, @@ -15,26 +14,23 @@ async def test_index_graph_edges_success(): [{"relationship_name": "rel2"}], ], ) - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100) + mock_index_data_points = AsyncMock() - # Patch the globals of the function so that when it does: - # vector_engine = get_vector_engine() - # graph_engine = await get_graph_engine() - # it uses the mocked versions. with patch.dict( index_graph_edges.__globals__, { "get_graph_engine": AsyncMock(return_value=mock_graph_engine), - "get_vector_engine": lambda: mock_vector_engine, + "index_data_points": mock_index_data_points, }, ): await index_graph_edges() - # Assertions on the mock calls. mock_graph_engine.get_graph_data.assert_awaited_once() - assert mock_vector_engine.create_vector_index.await_count == 1 - assert mock_vector_engine.index_data_points.await_count == 1 + mock_index_data_points.assert_awaited_once() + + call_args = mock_index_data_points.call_args[0][0] + assert len(call_args) == 2 + assert all(hasattr(item, "relationship_name") for item in call_args) @pytest.mark.asyncio @@ -42,20 +38,22 @@ async def test_index_graph_edges_no_relationships(): """Test that index_graph_edges handles empty relationships correctly.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = (None, []) - mock_vector_engine = AsyncMock() + mock_index_data_points = AsyncMock() with patch.dict( index_graph_edges.__globals__, { "get_graph_engine": AsyncMock(return_value=mock_graph_engine), - "get_vector_engine": lambda: mock_vector_engine, + "index_data_points": mock_index_data_points, }, ): await index_graph_edges() mock_graph_engine.get_graph_data.assert_awaited_once() - mock_vector_engine.create_vector_index.assert_not_awaited() - 
mock_vector_engine.index_data_points.assert_not_awaited() + mock_index_data_points.assert_awaited_once() + + call_args = mock_index_data_points.call_args[0][0] + assert len(call_args) == 0 @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/chunking/test_text_chunker.py b/cognee/tests/unit/modules/chunking/test_text_chunker.py new file mode 100644 index 000000000..d535f74b0 --- /dev/null +++ b/cognee/tests/unit/modules/chunking/test_text_chunker.py @@ -0,0 +1,248 @@ +"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence.""" + +import pytest +from uuid import uuid4 + +from cognee.modules.chunking.TextChunker import TextChunker +from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap +from cognee.modules.data.processing.document_types import Document + + +@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"]) +def chunker_class(request): + """Parametrize tests to run against both implementations.""" + return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap + + +@pytest.fixture +def make_text_generator(): + """Factory for async text generators.""" + + def _factory(*texts): + async def gen(): + for text in texts: + yield text + + return gen + + return _factory + + +async def collect_chunks(chunker): + """Consume async generator and return list of chunks.""" + chunks = [] + async for chunk in chunker.read(): + chunks.append(chunk) + return chunks + + +@pytest.mark.asyncio +async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator): + """Empty input should yield no chunks.""" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator("") + chunker = chunker_class(document, get_text, max_chunk_size=512) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 0, "Empty input should produce no chunks" + + 
+@pytest.mark.asyncio +async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator): + """Whitespace-only input should produce exactly one chunk with unchanged text.""" + whitespace_text = " \n\t \r\n " + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(whitespace_text) + chunker = chunker_class(document, get_text, max_chunk_size=512) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk" + assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)" + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + + +@pytest.mark.asyncio +async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator): + """Single paragraph below limit should emit exactly one chunk.""" + text = "This is a short paragraph." + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=512) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk" + assert chunks[0].text == text, "Chunk text should match input" + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[0].chunk_size > 0, "Chunk should have positive size" + + +@pytest.mark.asyncio +async def test_oversized_paragraph_gets_emitted_as_a_single_chunk( + chunker_class, make_text_generator +): + """Oversized paragraph from chunk_by_paragraph should be emitted as single chunk.""" + text = ("A" * 1500) + ". Next sentence." 
+ document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=50) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)" + assert chunks[0].chunk_size > 50, "First chunk should be oversized" + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[1].chunk_index == 1, "Second chunk should have index 1" + + +@pytest.mark.asyncio +async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator): + """First paragraph near limit plus small paragraph should produce two separate chunks.""" + first_para = " ".join(["word"] * 5) + second_para = "Short text." + text = first_para + " " + second_para + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=10) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 2, "Should produce 2 chunks due to overflow" + assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph" + assert chunks[1].text.strip() == second_para, ( + "Second chunk should contain only second paragraph" + ) + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[1].chunk_index == 1, "Second chunk should have index 1" + + +@pytest.mark.asyncio +async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator): + """Multiple small paragraphs should batch together with joiner spaces counted.""" + paragraphs = [" ".join(["word"] * 12) for _ in range(40)] + text = " ".join(paragraphs) + document = Document( + id=uuid4(), + name="test_document", + 
raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=49) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 20, ( + "Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)" + ) + assert all(c.chunk_index == i for i, c in enumerate(chunks)), ( + "Chunk indices should be sequential" + ) + all_text = " ".join(chunk.text.strip() for chunk in chunks) + expected_text = " ".join(paragraphs) + assert all_text == expected_text, "All paragraph text should be preserved with correct spacing" + + +@pytest.mark.asyncio +async def test_alternating_large_and_small_paragraphs_dont_batch( + chunker_class, make_text_generator +): + """Alternating near-max and small paragraphs should each become separate chunks.""" + large1 = "word" * 15 + "." + small1 = "Short." + large2 = "word" * 15 + "." + small2 = "Tiny." + text = large1 + " " + small1 + " " + large2 + " " + small2 + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + max_chunk_size = 10 + get_text = make_text_generator(text) + chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size) + chunks = await collect_chunks(chunker) + + assert len(chunks) == 4, "Should produce multiple chunks" + assert all(c.chunk_index == i for i, c in enumerate(chunks)), ( + "Chunk indices should be sequential" + ) + assert chunks[0].chunk_size > max_chunk_size, ( + "First chunk should be oversized (large paragraph)" + ) + assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)" + assert chunks[2].chunk_size > max_chunk_size, ( + "Third chunk should be oversized (large paragraph)" + ) + assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)" + + +@pytest.mark.asyncio +async def 
test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator): + """Running chunker twice on identical input should produce identical indices and IDs.""" + sentence1 = "one " * 4 + ". " + sentence2 = "two " * 4 + ". " + sentence3 = "one " * 4 + ". " + sentence4 = "two " * 4 + ". " + text = sentence1 + sentence2 + sentence3 + sentence4 + doc_id = uuid4() + max_chunk_size = 20 + + document1 = Document( + id=doc_id, + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text1 = make_text_generator(text) + chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size) + chunks1 = await collect_chunks(chunker1) + + document2 = Document( + id=doc_id, + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text2 = make_text_generator(text) + chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size) + chunks2 = await collect_chunks(chunker2) + + assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)" + assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)" + assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]" + assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]" + assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic" + assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic" + assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run" diff --git a/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py b/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py new file mode 100644 index 000000000..9d7be6936 --- /dev/null +++ b/cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py @@ -0,0 +1,324 @@ +"""Unit tests for TextChunkerWithOverlap overlap 
behavior.""" + +import sys +import pytest +from uuid import uuid4 +from unittest.mock import patch + +from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap +from cognee.modules.data.processing.document_types import Document +from cognee.tasks.chunks import chunk_by_paragraph + + +@pytest.fixture +def make_text_generator(): + """Factory for async text generators.""" + + def _factory(*texts): + async def gen(): + for text in texts: + yield text + + return gen + + return _factory + + +@pytest.fixture +def make_controlled_chunk_data(): + """Factory for controlled chunk_data generators.""" + + def _factory(*sentences, chunk_size_per_sentence=10): + def _chunk_data(text): + return [ + { + "text": sentence, + "chunk_size": chunk_size_per_sentence, + "cut_type": "sentence", + "chunk_id": uuid4(), + } + for sentence in sentences + ] + + return _chunk_data + + return _factory + + +@pytest.mark.asyncio +async def test_half_overlap_preserves_content_across_chunks( + make_text_generator, make_controlled_chunk_data +): + """With 50% overlap, consecutive chunks should share half their content.""" + s1 = "one" + s2 = "two" + s3 = "three" + s4 = "four" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.5, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)" + assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]" + assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2" + assert "two" in chunks[1].text and "three" 
in chunks[1].text, ( + "Chunk 1 should contain s2 (overlap) and s3" + ) + assert "three" in chunks[2].text and "four" in chunks[2].text, ( + "Chunk 2 should contain s3 (overlap) and s4" + ) + + +@pytest.mark.asyncio +async def test_zero_overlap_produces_no_duplicate_content( + make_text_generator, make_controlled_chunk_data +): + """With 0% overlap, no content should appear in multiple chunks.""" + s1 = "one" + s2 = "two" + s3 = "three" + s4 = "four" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.0, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)" + assert "one" in chunks[0].text and "two" in chunks[0].text, ( + "First chunk should contain s1 and s2" + ) + assert "three" in chunks[1].text and "four" in chunks[1].text, ( + "Second chunk should contain s3 and s4" + ) + assert "two" not in chunks[1].text and "three" not in chunks[0].text, ( + "No overlap: end of chunk 0 should not appear in chunk 1" + ) + + +@pytest.mark.asyncio +async def test_small_overlap_ratio_creates_minimal_overlap( + make_text_generator, make_controlled_chunk_data +): + """With 25% overlap ratio, chunks should have minimal overlap.""" + s1 = "alpha" + s2 = "beta" + s3 = "gamma" + s4 = "delta" + s5 = "epsilon" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10) + chunker = 
TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=40, + chunk_overlap_ratio=0.25, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 2, "Should produce exactly 2 chunks" + assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]" + assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), ( + "Chunk 0 should contain s1 through s4" + ) + assert s4 in chunks[1].text and s5 in chunks[1].text, ( + "Chunk 1 should contain overlap s4 and new content s5" + ) + + +@pytest.mark.asyncio +async def test_high_overlap_ratio_creates_significant_overlap( + make_text_generator, make_controlled_chunk_data +): + """With 75% overlap ratio, consecutive chunks should share most content.""" + s1 = "red" + s2 = "blue" + s3 = "green" + s4 = "yellow" + s5 = "purple" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.75, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap" + assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]" + assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), ( + "Chunk 0 should contain s1, s2, s3, s4" + ) + assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), ( + "Chunk 1 should contain s2, s3, s4 (overlap) and s5" + ) + + +@pytest.mark.asyncio +async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data): + """Text that fits in one chunk should produce exactly one chunk, no overlap artifact.""" + s1 = 
"alpha" + s2 = "beta" + text = "dummy" + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + get_text = make_text_generator(text) + get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10) + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=20, + chunk_overlap_ratio=0.5, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 1, ( + "Should produce exactly 1 chunk when content fits within max_chunk_size" + ) + assert chunks[0].chunk_index == 0, "Single chunk should have index 0" + assert "alpha" in chunks[0].text and "beta" in chunks[0].text, ( + "Single chunk should contain all content" + ) + + +@pytest.mark.asyncio +async def test_paragraph_chunking_with_overlap(make_text_generator): + """Test that chunk_by_paragraph integration produces 25% overlap between chunks.""" + + def mock_get_embedding_engine(): + class MockEngine: + tokenizer = None + + return MockEngine() + + chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence") + + max_chunk_size = 20 + overlap_ratio = 0.25 # 5 token overlap + paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size) # = 2 + + text = ( + "A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9) + "B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19) + "C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29) + "D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39) + "E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." 
# 10 tokens (40-49) + ) + + document = Document( + id=uuid4(), + name="test_document", + raw_data_location="/test/path", + external_metadata=None, + mime_type="text/plain", + ) + + get_text = make_text_generator(text) + + def get_chunk_data(text_input): + return chunk_by_paragraph( + text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True + ) + + with patch.object( + chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine + ): + chunker = TextChunkerWithOverlap( + document, + get_text, + max_chunk_size=max_chunk_size, + chunk_overlap_ratio=overlap_ratio, + get_chunk_data=get_chunk_data, + ) + chunks = [chunk async for chunk in chunker.read()] + + assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}" + + assert chunks[0].chunk_index == 0, "First chunk should have index 0" + assert chunks[1].chunk_index == 1, "Second chunk should have index 1" + assert chunks[2].chunk_index == 2, "Third chunk should have index 2" + + assert "A0" in chunks[0].text, "Chunk 0 should start with A0" + assert "A9" in chunks[0].text, "Chunk 0 should contain A9" + assert "B0" in chunks[0].text, "Chunk 0 should contain B0" + assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)" + + assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section" + assert "C" in chunks[1].text, "Chunk 1 should contain C section" + assert "D" in chunks[1].text, "Chunk 1 should contain D section" + + assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section" + assert "E0" in chunks[2].text, "Chunk 2 should contain E0" + assert "E9" in chunks[2].text, "Chunk 2 should end with E9" + + chunk_0_end_words = chunks[0].text.split()[-4:] + chunk_1_words = chunks[1].text.split() + overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words) + assert overlap_0_1, ( + f"No overlap detected between chunks 0 and 1. 
" + f"Chunk 0 ends with: {chunk_0_end_words}, " + f"Chunk 1 starts with: {chunk_1_words[:6]}" + ) + + chunk_1_end_words = chunks[1].text.split()[-4:] + chunk_2_words = chunks[2].text.split() + overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words) + assert overlap_1_2, ( + f"No overlap detected between chunks 1 and 2. " + f"Chunk 1 ends with: {chunk_1_end_words}, " + f"Chunk 2 starts with: {chunk_2_words[:6]}" + ) diff --git a/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py new file mode 100644 index 000000000..8c2448287 --- /dev/null +++ b/cognee/tests/unit/modules/memify_tasks/test_cognify_session.py @@ -0,0 +1,111 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.tasks.memify.cognify_session import cognify_session +from cognee.exceptions import CogneeValidationError, CogneeSystemError + + +@pytest.mark.asyncio +async def test_cognify_session_success(): + """Test successful cognification of session data.""" + session_data = ( + "Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence" + ) + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + await cognify_session(session_data, dataset_id="123") + + mock_add.assert_called_once_with( + session_data, dataset_id="123", node_set=["user_sessions_from_cache"] + ) + mock_cognify.assert_called_once() + + +@pytest.mark.asyncio +async def test_cognify_session_empty_string(): + """Test cognification fails with empty string.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session("") + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_whitespace_string(): + """Test cognification fails with whitespace-only string.""" + with pytest.raises(CogneeValidationError) as exc_info: + await 
cognify_session(" \n\t ") + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_none_data(): + """Test cognification fails with None data.""" + with pytest.raises(CogneeValidationError) as exc_info: + await cognify_session(None) + + assert "Session data cannot be empty" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_add_failure(): + """Test cognification handles cognee.add failure.""" + session_data = "Session ID: test\n\nQuestion: test?" + + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock), + ): + mock_add.side_effect = Exception("Add operation failed") + + with pytest.raises(CogneeSystemError) as exc_info: + await cognify_session(session_data) + + assert "Failed to cognify session data" in str(exc_info.value) + assert "Add operation failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_cognify_failure(): + """Test cognification handles cognify failure.""" + session_data = "Session ID: test\n\nQuestion: test?" + + with ( + patch("cognee.add", new_callable=AsyncMock), + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + mock_cognify.side_effect = Exception("Cognify operation failed") + + with pytest.raises(CogneeSystemError) as exc_info: + await cognify_session(session_data) + + assert "Failed to cognify session data" in str(exc_info.value) + assert "Cognify operation failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cognify_session_re_raises_validation_error(): + """Test that CogneeValidationError is re-raised as-is.""" + with pytest.raises(CogneeValidationError): + await cognify_session("") + + +@pytest.mark.asyncio +async def test_cognify_session_with_special_characters(): + """Test cognification with special characters.""" + session_data = "Session: test™ © Question: What's special? Answer: Cognee is special!" 
+ + with ( + patch("cognee.add", new_callable=AsyncMock) as mock_add, + patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify, + ): + await cognify_session(session_data, dataset_id="123") + + mock_add.assert_called_once_with( + session_data, dataset_id="123", node_set=["user_sessions_from_cache"] + ) + mock_cognify.assert_called_once() diff --git a/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py new file mode 100644 index 000000000..8cb27fef3 --- /dev/null +++ b/cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py @@ -0,0 +1,175 @@ +import sys +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from cognee.tasks.memify.extract_user_sessions import extract_user_sessions +from cognee.exceptions import CogneeSystemError +from cognee.modules.users.models import User + +# Get the actual module object (not the function) for patching +extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"] + + +@pytest.fixture +def mock_user(): + """Create a mock user.""" + user = MagicMock(spec=User) + user.id = "test-user-123" + return user + + +@pytest.fixture +def mock_qa_data(): + """Create mock Q&A data.""" + return [ + { + "question": "What is cognee?", + "context": "context about cognee", + "answer": "Cognee is a knowledge graph solution", + "time": "2025-01-01T12:00:00", + }, + { + "question": "How does it work?", + "context": "how it works context", + "answer": "It processes data and creates graphs", + "time": "2025-01-01T12:05:00", + }, + ] + + +@pytest.mark.asyncio +async def test_extract_user_sessions_success(mock_user, mock_qa_data): + """Test successful extraction of sessions.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + 
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["test_session"]): + sessions.append(session) + + assert len(sessions) == 1 + assert "Session ID: test_session" in sessions[0] + assert "Question: What is cognee?" in sessions[0] + assert "Answer: Cognee is a knowledge graph solution" in sessions[0] + assert "Question: How does it work?" in sessions[0] + assert "Answer: It processes data and creates graphs" in sessions[0] + + +@pytest.mark.asyncio +async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data): + """Test extraction of multiple sessions.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["session1", "session2"]): + sessions.append(session) + + assert len(sessions) == 2 + assert mock_cache_engine.get_all_qas.call_count == 2 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_no_data(mock_user, mock_qa_data): + """Test extraction handles empty data parameter.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = mock_qa_data + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions(None, session_ids=["test_session"]): + sessions.append(session) + + assert len(sessions) == 1 + + +@pytest.mark.asyncio 
+async def test_extract_user_sessions_no_session_ids(mock_user): + """Test extraction handles no session IDs provided.""" + mock_cache_engine = AsyncMock() + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=None): + sessions.append(session) + + assert len(sessions) == 0 + mock_cache_engine.get_all_qas.assert_not_called() + + +@pytest.mark.asyncio +async def test_extract_user_sessions_empty_qa_data(mock_user): + """Test extraction handles empty Q&A data.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.return_value = [] + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions([{}], session_ids=["empty_session"]): + sessions.append(session) + + assert len(sessions) == 0 + + +@pytest.mark.asyncio +async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data): + """Test extraction continues on cache error for specific session.""" + mock_cache_engine = AsyncMock() + mock_cache_engine.get_all_qas.side_effect = [ + mock_qa_data, + Exception("Cache error"), + mock_qa_data, + ] + + with ( + patch.object(extract_user_sessions_module, "session_user") as mock_session_user, + patch.object( + extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine + ), + ): + mock_session_user.get.return_value = mock_user + + sessions = [] + async for session in extract_user_sessions( + [{}], session_ids=["session1", "session2", "session3"] + ): + sessions.append(session) + + assert len(sessions) == 2 
diff --git a/cognee/tests/unit/modules/users/test_conditional_authentication.py b/cognee/tests/unit/modules/users/test_conditional_authentication.py index c4368d796..6568c3cb0 100644 --- a/cognee/tests/unit/modules/users/test_conditional_authentication.py +++ b/cognee/tests/unit/modules/users/test_conditional_authentication.py @@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration: # REQUIRE_AUTHENTICATION should be a boolean assert isinstance(REQUIRE_AUTHENTICATION, bool) - # Currently should be False (optional authentication) - assert not REQUIRE_AUTHENTICATION - class TestConditionalAuthenticationEnvironmentVariables: """Test environment variable handling.""" - def test_require_authentication_default_false(self): - """Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars.""" - with patch.dict(os.environ, {}, clear=True): - # Remove module from cache to force fresh import - module_name = "cognee.modules.users.methods.get_authenticated_user" - if module_name in sys.modules: - del sys.modules[module_name] - - # Import after patching environment - module will see empty environment - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - importlib.invalidate_caches() - assert not REQUIRE_AUTHENTICATION - def test_require_authentication_true(self): """Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported.""" with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}): @@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables: assert REQUIRE_AUTHENTICATION - def test_require_authentication_false_explicit(self): - """Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported.""" - with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}): - # Remove module from cache to force fresh import - module_name = "cognee.modules.users.methods.get_authenticated_user" - if module_name in sys.modules: - del sys.modules[module_name] - - 
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - assert not REQUIRE_AUTHENTICATION - - def test_require_authentication_case_insensitive(self): - """Test that environment variable parsing is case insensitive when imported.""" - test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"] - - for case in test_cases: - with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}): - # Remove module from cache to force fresh import - module_name = "cognee.modules.users.methods.get_authenticated_user" - if module_name in sys.modules: - del sys.modules[module_name] - - # Import after patching environment - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - expected = case.lower() == "true" - assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}" - - def test_current_require_authentication_value(self): - """Test that the current REQUIRE_AUTHENTICATION module value is as expected.""" - from cognee.modules.users.methods.get_authenticated_user import ( - REQUIRE_AUTHENTICATION, - ) - - # The module-level variable should currently be False (set at import time) - assert isinstance(REQUIRE_AUTHENTICATION, bool) - assert not REQUIRE_AUTHENTICATION - class TestConditionalAuthenticationEdgeCases: """Test edge cases and error scenarios.""" diff --git a/docker-compose.yml b/docker-compose.yml index 43d9b2607..472f24c21 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: - DEBUG=false # Change to true if debugging - HOST=0.0.0.0 - ENVIRONMENT=local - - LOG_LEVEL=ERROR + - LOG_LEVEL=INFO extra_hosts: # Allows the container to reach your local machine using "host.docker.internal" instead of "localhost" - "host.docker.internal:host-gateway" diff --git a/entrypoint.sh b/entrypoint.sh index bad9b7aa3..496825408 100755 --- a/entrypoint.sh +++ 
b/entrypoint.sh @@ -43,10 +43,10 @@ sleep 2 if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then if [ "$DEBUG" = "true" ]; then echo "Waiting for the debugger to attach..." - debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app + exec debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app else - gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app + exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app fi else - gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error cognee.api.client:app + exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error --access-logfile - --error-logfile - cognee.api.client:app fi diff --git a/examples/python/agentic_reasoning_procurement_example.py b/examples/python/agentic_reasoning_procurement_example.py index 5aa3caa70..4e9d2d7e4 100644 --- a/examples/python/agentic_reasoning_procurement_example.py +++ b/examples/python/agentic_reasoning_procurement_example.py @@ -168,7 +168,7 @@ async def run_procurement_example(): for q in questions: print(f"Question: \n{q}") results = await procurement_system.search_memory(q, search_categories=[category]) - top_answer = results[category][0] + top_answer = results[category][0]["search_result"][0] print(f"Answer: \n{top_answer.strip()}\n") research_notes[category].append({"question": q, "answer": top_answer}) diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py 
index 431069050..1b476a2c3 100644 --- a/examples/python/code_graph_example.py +++ b/examples/python/code_graph_example.py @@ -1,5 +1,7 @@ import argparse import asyncio +import os + import cognee from cognee import SearchType from cognee.shared.logging_utils import setup_logging, ERROR @@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline async def main(repo_path, include_docs): + # Disable permissions feature for this example + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" + run_status = False async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs): run_status = run_status diff --git a/examples/python/conversation_session_persistence_example.py b/examples/python/conversation_session_persistence_example.py new file mode 100644 index 000000000..5346f5012 --- /dev/null +++ b/examples/python/conversation_session_persistence_example.py @@ -0,0 +1,98 @@ +import asyncio + +import cognee +from cognee import visualize_graph +from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import ( + persist_sessions_in_knowledge_graph_pipeline, +) +from cognee.modules.search.types import SearchType +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger + +logger = get_logger("conversation_session_persistence_example") + + +async def main(): + # NOTE: CACHING has to be enabled for this example to work + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + text_1 = "Cognee is a solution that can build knowledge graph from text, creating an AI memory system" + text_2 = "Germany is a country located next to the Netherlands" + + await cognee.add([text_1, text_2]) + await cognee.cognify() + + question = "What can I use to create a knowledge graph?" 
+ search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "You sure about that?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=question + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "This is awesome!" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=question + ) + print("\nSession ID: default_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "Where is Germany?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "Next to which country again?" + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + question = "So you remember everything I asked from you?" 
+ search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=question, + session_id="different_session", + ) + print("\nSession ID: different_session") + print(f"Question: {question}") + print(f"Answer: {search_results}\n") + + session_ids_to_persist = ["default_session", "different_session"] + default_user = await get_default_user() + + await persist_sessions_in_knowledge_graph_pipeline( + user=default_user, + session_ids=session_ids_to_persist, + ) + + await visualize_graph() + + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/examples/python/feedback_enrichment_minimal_example.py b/examples/python/feedback_enrichment_minimal_example.py index 11ef20830..8954bd5f6 100644 --- a/examples/python/feedback_enrichment_minimal_example.py +++ b/examples/python/feedback_enrichment_minimal_example.py @@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5): extraction_tasks=extraction_tasks, enrichment_tasks=enrichment_tasks, data=[{}], # A placeholder to prevent fetching the entire graph - dataset="feedback_enrichment_minimal", ) diff --git a/examples/python/memify_coding_agent_example.py b/examples/python/memify_coding_agent_example.py index 1fd3b1528..4a087ba61 100644 --- a/examples/python/memify_coding_agent_example.py +++ b/examples/python/memify_coding_agent_example.py @@ -89,7 +89,7 @@ async def main(): ) print("Coding rules created by memify:") - for coding_rule in coding_rules: + for coding_rule in coding_rules[0]["search_result"][0]: print("- " + coding_rule) # Visualize new graph with added memify context diff --git a/examples/python/relational_database_migration_example.py b/examples/python/relational_database_migration_example.py index 7e87347bc..98482cb4b 100644 --- a/examples/python/relational_database_migration_example.py +++ 
b/examples/python/relational_database_migration_example.py @@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import ( async def main(): + # Disable backend access control to migrate relational data + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false" + # Clean all data stored in Cognee await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/examples/python/run_custom_pipeline_example.py b/examples/python/run_custom_pipeline_example.py new file mode 100644 index 000000000..1ca1b4402 --- /dev/null +++ b/examples/python/run_custom_pipeline_example.py @@ -0,0 +1,84 @@ +import asyncio +import cognee +from cognee.modules.engine.operations.setup import setup +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import setup_logging, INFO +from cognee.modules.pipelines import Task +from cognee.api.v1.search import SearchType + +# Prerequisites: +# 1. Copy `.env.template` and rename it to `.env`. +# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field: +# LLM_API_KEY = "your_key_here" + + +async def main(): + # Create a clean slate for cognee -- reset data and system state + print("Resetting cognee data...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("Data reset complete.\n") + + # Create relational database and tables + await setup() + + # cognee knowledge graph will be created based on this text + text = """ + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. 
+ """ + + print("Adding text to cognee:") + print(text.strip()) + + # Let's recreate the cognee add pipeline through the custom pipeline framework + from cognee.tasks.ingestion import ingest_data, resolve_data_directories + + user = await get_default_user() + + # Values for tasks need to be filled before calling the pipeline + add_tasks = [ + Task(resolve_data_directories, include_subdirectories=True), + Task( + ingest_data, + "main_dataset", + user, + ), + ] + # Forward tasks to custom pipeline along with data and user information + await cognee.run_custom_pipeline( + tasks=add_tasks, data=text, user=user, dataset="main_dataset", pipeline_name="add_pipeline" + ) + print("Text added successfully.\n") + + # Use LLMs and cognee to create knowledge graph + from cognee.api.v1.cognify.cognify import get_default_tasks + + cognify_tasks = await get_default_tasks(user=user) + print("Recreating existing cognify pipeline in custom pipeline to create knowledge graph...\n") + await cognee.run_custom_pipeline( + tasks=cognify_tasks, user=user, dataset="main_dataset", pipeline_name="cognify_pipeline" + ) + print("Cognify process complete.\n") + + query_text = "Tell me about NLP" + print(f"Searching cognee for insights with query: '{query_text}'") + # Query cognee for insights on the added text + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=query_text + ) + + print("Search results:") + # Display results + for result_text in search_results: + print(result_text) + + +if __name__ == "__main__": + logger = setup_logging(log_level=INFO) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/examples/python/simple_example.py b/examples/python/simple_example.py index c13e48f85..237a8295e 100644 --- a/examples/python/simple_example.py +++ b/examples/python/simple_example.py @@ -59,14 +59,6 @@ async def main(): for 
result_text in search_results: print(result_text) - # Example output: - # ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'}) - # (...) - # It represents nodes and relationships in the knowledge graph: - # - The first element is the source node (e.g., 'natural language processing'). - # - The second element is the relationship between nodes (e.g., 'is_a_subfield_of'). - # - The third element is the target node (e.g., 'computer science'). - if __name__ == "__main__": logger = setup_logging(log_level=ERROR)