diff --git a/.github/workflows/basic_tests.yml b/.github/workflows/basic_tests.yml index 1b9e232c5..13ea829da 100644 --- a/.github/workflows/basic_tests.yml +++ b/.github/workflows/basic_tests.yml @@ -8,12 +8,30 @@ on: type: string default: '3.11.x' secrets: - OPENAI_API_KEY: - required: true GRAPHISTRY_USERNAME: required: true GRAPHISTRY_PASSWORD: required: true + LLM_PROVIDER: + required: true + LLM_MODEL: + required: true + LLM_ENDPOINT: + required: true + LLM_API_KEY: + required: true + LLM_API_VERSION: + required: true + EMBEDDING_PROVIDER: + required: true + EMBEDDING_MODEL: + required: true + EMBEDDING_ENDPOINT: + required: true + EMBEDDING_API_KEY: + required: true + EMBEDDING_API_VERSION: + required: true env: RUNTIME__LOG_LEVEL: ERROR @@ -60,6 +78,18 @@ jobs: unit-tests: name: Run Unit Tests runs-on: ubuntu-22.04 + env: + LLM_PROVIDER: openai + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} steps: - name: Check out repository uses: actions/checkout@v4 @@ -95,10 +125,20 @@ jobs: name: Run Simple Examples runs-on: ubuntu-22.04 env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} + + LLM_PROVIDER: openai + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} steps: - name: Check out repository uses: actions/checkout@v4 @@ -117,10 +157,20 @@ jobs: name: Run Basic Graph Tests runs-on: ubuntu-22.04 env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} + + LLM_PROVIDER: openai + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} steps: - name: Check out repository uses: actions/checkout@v4 diff --git a/.github/workflows/graph_db_tests.yml b/.github/workflows/graph_db_tests.yml index 2b3a4dd58..3c6b7610e 100644 --- a/.github/workflows/graph_db_tests.yml +++ b/.github/workflows/graph_db_tests.yml @@ -1,4 +1,4 @@ -name: Reusable Vector DB Tests +name: Reusable Graph DB Tests on: workflow_call: diff --git a/.github/workflows/python_version_tests.yml b/.github/workflows/python_version_tests.yml index 506e65f62..75b5bfa18 100644 --- a/.github/workflows/python_version_tests.yml +++ b/.github/workflows/python_version_tests.yml @@ -8,26 +8,30 @@ on: type: string default: '["3.10.x", "3.11.x", "3.12.x"]' secrets: - OPENAI_API_KEY: - required: true GRAPHISTRY_USERNAME: required: true GRAPHISTRY_PASSWORD: required: true + LLM_PROVIDER: + required: true LLM_MODEL: - required: false + required: true LLM_ENDPOINT: - required: false + required: true + LLM_API_KEY: + required: true LLM_API_VERSION: - required: false + required: true + EMBEDDING_PROVIDER: + required: true EMBEDDING_MODEL: - required: false + required: true EMBEDDING_ENDPOINT: - required: false + required: true EMBEDDING_API_KEY: - required: false + required: true EMBEDDING_API_VERSION: - required: false + required: true env: RUNTIME__LOG_LEVEL: ERROR @@ -55,6 +59,18 @@ jobs: - name: Run unit tests run: poetry run pytest cognee/tests/unit/ + env: + LLM_PROVIDER: openai + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - name: Run integration tests if: ${{ !contains(matrix.os, 'windows') }} @@ -62,13 +78,16 @@ jobs: - name: Run default basic pipeline env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} + + LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} diff --git a/.github/workflows/relational_db_migration_tests.yml b/.github/workflows/relational_db_migration_tests.yml index 2c5688313..5630bbc12 100644 --- a/.github/workflows/relational_db_migration_tests.yml +++ b/.github/workflows/relational_db_migration_tests.yml @@ -8,6 +8,8 @@ on: type: string default: '3.11.x' secrets: + LLM_PROVIDER: + required: true LLM_MODEL: required: true LLM_ENDPOINT: @@ -16,6 +18,8 @@ on: required: true LLM_API_VERSION: required: true + EMBEDDING_PROVIDER: + required: true EMBEDDING_MODEL: required: true EMBEDDING_ENDPOINT: @@ -24,12 +28,6 @@ on: required: true EMBEDDING_API_VERSION: required: true - OPENAI_API_KEY: - required: true - GRAPHISTRY_USERNAME: - required: true - GRAPHISTRY_PASSWORD: - required: true jobs: run-relational-db-migration-test-networkx: @@ -81,10 +79,13 @@ jobs: - name: Run relational db test env: ENV: 'dev' + LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} @@ -141,10 +142,14 @@ jobs: env: ENV: 'dev' GRAPH_DATABASE_PROVIDER: 'kuzu' + + LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} @@ -204,10 +209,14 @@ jobs: GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }} GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }} GRAPH_DATABASE_USERNAME: "neo4j" + + LLM_PROVIDER: openai LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + + EMBEDDING_PROVIDER: openai EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} diff --git a/alembic/versions/482cd6517ce4_add_default_user.py b/alembic/versions/482cd6517ce4_add_default_user.py deleted file mode 100644 index 92429e1e4..000000000 --- a/alembic/versions/482cd6517ce4_add_default_user.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Add default user - -Revision ID: 482cd6517ce4 -Revises: 8057ae7329c2 -Create Date: 2024-10-16 22:17:18.634638 - -""" - -from typing import Sequence, Union - -from sqlalchemy.util import await_only - -from cognee.modules.users.methods import create_default_user, delete_user - - -# revision identifiers, used by Alembic. -revision: str = "482cd6517ce4" -down_revision: Union[str, None] = "8057ae7329c2" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = "8057ae7329c2" - - -def upgrade() -> None: - await_only(create_default_user()) - - -def downgrade() -> None: - await_only(delete_user("default_user@example.com")) diff --git a/alembic/versions/8057ae7329c2_initial_migration.py b/alembic/versions/8057ae7329c2_initial_migration.py deleted file mode 100644 index 48e795327..000000000 --- a/alembic/versions/8057ae7329c2_initial_migration.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Initial migration - -Revision ID: 8057ae7329c2 -Revises: -Create Date: 2024-10-02 12:55:20.989372 - -""" - -from typing import Sequence, Union -from sqlalchemy.util import await_only -from cognee.infrastructure.databases.relational import get_relational_engine - -# revision identifiers, used by Alembic. -revision: str = "8057ae7329c2" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - db_engine = get_relational_engine() - await_only(db_engine.create_database()) - - -def downgrade() -> None: - db_engine = get_relational_engine() - await_only(db_engine.delete_database()) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 8d8aa2be4..723f41bdb 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,8 +1,7 @@ from typing import Union -from cognee.modules.search.types import SearchType -from cognee.modules.users.exceptions import UserNotFoundError from cognee.modules.users.models import User +from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user from cognee.modules.search.methods import search as search_function @@ -22,9 +21,6 @@ async def search( if user is None: user = await get_default_user() - if user is None: - raise UserNotFoundError - filtered_search_results = await search_function( query_text, query_type, diff --git a/cognee/exceptions/__init__.py b/cognee/exceptions/__init__.py index 1432afcc8..d1d4ecbf5 100644 --- a/cognee/exceptions/__init__.py +++ b/cognee/exceptions/__init__.py @@ -10,4 +10,5 @@ from .exceptions import ( ServiceError, InvalidValueError, InvalidAttributeError, + CriticalError, ) diff --git a/cognee/exceptions/exceptions.py b/cognee/exceptions/exceptions.py index e983a2462..614cb29ac 100644 --- a/cognee/exceptions/exceptions.py +++ b/cognee/exceptions/exceptions.py @@ -53,3 +53,7 @@ class InvalidAttributeError(CogneeApiError): status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, ): super().__init__(message, name, status_code) + + +class CriticalError(CogneeApiError): + pass diff --git a/cognee/infrastructure/databases/exceptions/__init__.py b/cognee/infrastructure/databases/exceptions/__init__.py index 7c74db3df..9d8d18567 100644 --- a/cognee/infrastructure/databases/exceptions/__init__.py +++ b/cognee/infrastructure/databases/exceptions/__init__.py @@ -7,4 +7,5 @@ This module defines a set of exceptions for handling various database errors from .exceptions import ( EntityNotFoundError, EntityAlreadyExistsError, + DatabaseNotCreatedError, ) diff --git a/cognee/infrastructure/databases/exceptions/exceptions.py b/cognee/infrastructure/databases/exceptions/exceptions.py index f3e945d4e..eacfc4095 100644 --- a/cognee/infrastructure/databases/exceptions/exceptions.py +++ b/cognee/infrastructure/databases/exceptions/exceptions.py @@ -1,5 +1,15 @@ -from cognee.exceptions import CogneeApiError from fastapi import status +from cognee.exceptions import CogneeApiError, CriticalError + + +class DatabaseNotCreatedError(CriticalError): + def __init__( + self, + message: str = "The database has not been created yet. Please call `await setup()` first.", + name: str = "DatabaseNotCreatedError", + status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) class EntityNotFoundError(CogneeApiError): diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index b26ed66f3..94a6335a7 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -1,13 +1,13 @@ -from typing import Protocol, Optional, Dict, Any, List, Tuple -from abc import abstractmethod, ABC -from uuid import UUID, uuid5, NAMESPACE_DNS -from cognee.modules.graph.relationship_manager import create_relationship -from functools import wraps import inspect +from functools import wraps +from abc import abstractmethod, ABC +from datetime import datetime, timezone +from typing import Optional, Dict, Any, List, Tuple +from uuid import NAMESPACE_OID, UUID, uuid5 +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.engine import DataPoint from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine -from cognee.shared.logging_utils import get_logger -from datetime import datetime, timezone logger = get_logger() @@ -44,20 +44,16 @@ def record_graph_changes(func): async with db_engine.get_async_session() as session: if func.__name__ == "add_nodes": - nodes = args[0] + nodes: List[DataPoint] = args[0] for node in nodes: try: - node_id = ( - UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id)) - ) + node_id = UUID(str(node.id)) relationship = GraphRelationshipLedger( - id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"), + id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), source_node_id=node_id, destination_node_id=node_id, creator_function=f"{creator}.node", - node_label=node[1].get("type") - if isinstance(node, tuple) - else type(node).__name__, + node_label=getattr(node, "name", None) or str(node.id), ) session.add(relationship) await session.flush() @@ -74,7 +70,7 @@ def record_graph_changes(func): target_id = UUID(str(edge[1])) rel_type = str(edge[2]) relationship = GraphRelationshipLedger( - id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"), + id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), source_node_id=source_id, destination_node_id=target_id, creator_function=f"{creator}.{rel_type}", diff --git a/cognee/infrastructure/databases/vector/exceptions/__init__.py b/cognee/infrastructure/databases/vector/exceptions/__init__.py new file mode 100644 index 000000000..da7eb1499 --- /dev/null +++ b/cognee/infrastructure/databases/vector/exceptions/__init__.py @@ -0,0 +1 @@ +from .exceptions import CollectionNotFoundError diff --git a/cognee/infrastructure/databases/vector/exceptions/exceptions.py b/cognee/infrastructure/databases/vector/exceptions/exceptions.py new file mode 100644 index 000000000..8f7ff17cc --- /dev/null +++ b/cognee/infrastructure/databases/vector/exceptions/exceptions.py @@ -0,0 +1,12 @@ +from fastapi import status +from cognee.exceptions import CriticalError + + +class CollectionNotFoundError(CriticalError): + def __init__( + self, + message, + name: str = "DatabaseNotCreatedError", + status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 959259632..bb512596b 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -1,5 +1,5 @@ import asyncio -from typing import Generic, List, Optional, TypeVar, get_type_hints +from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import lancedb from lancedb.pydantic import LanceModel, Vector @@ -10,6 +10,7 @@ from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.files.storage import LocalStorage from cognee.modules.storage.utils import copy_model, get_own_properties +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..models.ScoredResult import ScoredResult @@ -79,7 +80,6 @@ class LanceDBAdapter(VectorDBInterface): connection = await self.get_connection() payload_schema = type(data_points[0]) - payload_schema = self.get_data_point_schema(payload_schema) if not await self.has_collection(collection_name): await self.create_collection( @@ -194,12 +194,19 @@ class LanceDBAdapter(VectorDBInterface): query_vector = (await self.embedding_engine.embed_text([query_text]))[0] connection = await self.get_connection() - collection = await connection.open_table(collection_name) + + try: + collection = await connection.open_table(collection_name) + except ValueError: + raise CollectionNotFoundError(f"Collection '{collection_name}' not found!") results = await collection.vector_search(query_vector).limit(limit).to_pandas() result_values = list(results.to_dict("index").values()) + if not result_values: + return [] + normalized_values = normalize_distances(result_values) return [ @@ -288,11 +295,33 @@ class LanceDBAdapter(VectorDBInterface): if self.url.startswith("/"): LocalStorage.remove_all(self.url) - def get_data_point_schema(self, model_type): + def get_data_point_schema(self, model_type: BaseModel): + related_models_fields = [] + + for field_name, field_config in model_type.model_fields.items(): + if hasattr(field_config, "model_fields"): + related_models_fields.append(field_name) + + elif hasattr(field_config.annotation, "model_fields"): + related_models_fields.append(field_name) + + elif ( + get_origin(field_config.annotation) == Union + or get_origin(field_config.annotation) is list + ): + models_list = get_args(field_config.annotation) + if any(hasattr(model, "model_fields") for model in models_list): + related_models_fields.append(field_name) + + elif get_origin(field_config.annotation) == Optional: + model = get_args(field_config.annotation) + if hasattr(model, "model_fields"): + related_models_fields.append(field_name) + return copy_model( model_type, include_fields={ "id": (str, ...), }, - exclude_fields=["metadata"], + exclude_fields=["metadata"] + related_models_fields, ) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 4badb0a97..07745ac77 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from cognee.exceptions import InvalidValueError from cognee.infrastructure.databases.exceptions import EntityNotFoundError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.databases.relational import get_relational_engine @@ -183,7 +184,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): if collection_name in metadata.tables: return metadata.tables[collection_name] else: - raise EntityNotFoundError(message=f"Table '{collection_name}' not found.") + raise CollectionNotFoundError(f"Collection '{collection_name}' not found!") async def retrieve(self, collection_name: str, data_point_ids: List[str]): # Get PGVectorDataPoint Table from database @@ -244,6 +245,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): except EntityNotFoundError: # Ignore if collection does not exist return [] + except CollectionNotFoundError: + # Ignore if collection does not exist + return [] async def search( self, diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index ccf0fa3c3..f31e18308 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -1,8 +1,8 @@ +import os from typing import Optional from functools import lru_cache from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import model_validator, Field -import os +from pydantic import model_validator class LLMConfig(BaseSettings): diff --git a/cognee/modules/data/models/graph_relationship_ledger.py b/cognee/modules/data/models/graph_relationship_ledger.py index 0dd62ef06..2027714db 100644 --- a/cognee/modules/data/models/graph_relationship_ledger.py +++ b/cognee/modules/data/models/graph_relationship_ledger.py @@ -1,7 +1,6 @@ from datetime import datetime, timezone -from uuid import uuid5, NAMESPACE_DNS +from uuid import uuid5, NAMESPACE_OID from sqlalchemy import UUID, Column, DateTime, String, Index -from sqlalchemy.orm import relationship from cognee.infrastructure.databases.relational import Base @@ -12,7 +11,7 @@ class GraphRelationshipLedger(Base): id = Column( UUID, primary_key=True, - default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"), + default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), ) source_node_id = Column(UUID, nullable=False) destination_node_id = Column(UUID, nullable=False) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index ce227f296..fa1edca1b 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -111,9 +111,6 @@ class CogneeGraph(CogneeAbstractGraph): except (ValueError, TypeError) as e: print(f"Error projecting graph: {e}") raise e - except Exception as ex: - print(f"Unexpected error: {ex}") - raise ex async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: for category, scored_results in node_distances.items(): diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index 61427b6f9..db17f2e57 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -2,15 +2,28 @@ from typing import Any, Optional from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError class ChunksRetriever(BaseRetriever): """Retriever for handling document chunk-based searches.""" + def __init__( + self, + top_k: Optional[int] = 5, + ): + self.top_k = top_k + async def get_context(self, query: str) -> Any: """Retrieves document chunks context based on the query.""" vector_engine = get_vector_engine() - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5) + + try: + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + except CollectionNotFoundError as error: + raise NoDataError("No data found in the system, please add data first.") from error + return [result.payload for result in found_chunks] async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index fba011cf5..1ec3a0eb3 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -1,9 +1,10 @@ from typing import Any, Optional from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.utils.completion import generate_completion -from cognee.tasks.completion.exceptions import NoRelevantDataFound +from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError class CompletionRetriever(BaseRetriever): @@ -20,15 +21,21 @@ class CompletionRetriever(BaseRetriever): self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 1 - async def get_context(self, query: str) -> Any: + async def get_context(self, query: str) -> str: """Retrieves relevant document chunks as context.""" vector_engine = get_vector_engine() - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) - if len(found_chunks) == 0: - raise NoRelevantDataFound - # Combine all chunks text returned from vector search (number of chunks is determined by top_k - chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] - return "\n".join(chunks_payload) + + try: + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + + if len(found_chunks) == 0: + return "" + + # Combine all chunks text returned from vector search (number of chunks is determined by top_k + chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] + return "\n".join(chunks_payload) + except CollectionNotFoundError as error: + raise NoDataError("No data found in the system, please add data first.") from error async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """Generates an LLM completion using the context.""" diff --git a/cognee/modules/retrieval/exceptions/exceptions.py b/cognee/modules/retrieval/exceptions/exceptions.py index 7e33e3a5f..acd9c1fa8 100644 --- a/cognee/modules/retrieval/exceptions/exceptions.py +++ b/cognee/modules/retrieval/exceptions/exceptions.py @@ -1,5 +1,5 @@ -from cognee.exceptions import CogneeApiError from fastapi import status +from cognee.exceptions import CogneeApiError, CriticalError class CollectionDistancesNotFoundError(CogneeApiError): @@ -30,3 +30,7 @@ class CypherSearchError(CogneeApiError): status_code: int = status.HTTP_400_BAD_REQUEST, ): super().__init__(message, name, status_code) + + +class NoDataError(CriticalError): + message: str = "No data found in the system, please add data first." diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 16b22eab7..3d41444c2 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -3,12 +3,12 @@ from collections import Counter import string from cognee.infrastructure.engine import DataPoint +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS -from cognee.tasks.completion.exceptions import NoRelevantDataFound class GraphCompletionRetriever(BaseRetriever): @@ -72,14 +72,18 @@ class GraphCompletionRetriever(BaseRetriever): query, top_k=self.top_k, collections=vector_index_collections or None ) - if len(found_triplets) == 0: - raise NoRelevantDataFound - return found_triplets - async def get_context(self, query: str) -> Any: + async def get_context(self, query: str) -> str: """Retrieves and resolves graph triplets into context.""" - triplets = await self.get_triplets(query) + try: + triplets = await self.get_triplets(query) + except EntityNotFoundError: + return "" + + if len(triplets) == 0: + return "" + return await self.resolve_edges_to_text(triplets) async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: diff --git a/cognee/modules/retrieval/insights_retriever.py b/cognee/modules/retrieval/insights_retriever.py index 021b39f95..e34280f99 100644 --- a/cognee/modules/retrieval/insights_retriever.py +++ b/cognee/modules/retrieval/insights_retriever.py @@ -4,6 +4,8 @@ from typing import Any, Optional from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError class InsightsRetriever(BaseRetriever): @@ -14,7 +16,7 @@ class InsightsRetriever(BaseRetriever): self.exploration_levels = exploration_levels self.top_k = top_k - async def get_context(self, query: str) -> Any: + async def get_context(self, query: str) -> list: """Find the neighbours of a given node in the graph.""" if query is None: return [] @@ -27,10 +29,15 @@ class InsightsRetriever(BaseRetriever): node_connections = await graph_engine.get_connections(str(exact_node["id"])) else: vector_engine = get_vector_engine() - results = await asyncio.gather( - vector_engine.search("Entity_name", query_text=query, limit=self.top_k), - vector_engine.search("EntityType_name", query_text=query, limit=self.top_k), - ) + + try: + results = await asyncio.gather( + vector_engine.search("Entity_name", query_text=query, limit=self.top_k), + vector_engine.search("EntityType_name", query_text=query, limit=self.top_k), + ) + except CollectionNotFoundError as error: + raise NoDataError("No data found in the system, please add data first.") from error + results = [*results[0], *results[1]] relevant_results = [result for result in results if result.score < 0.5][: self.top_k] diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 7356563e1..8c0aac0dd 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -2,6 +2,8 @@ from typing import Any, Optional from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError class SummariesRetriever(BaseRetriever): @@ -14,7 +16,14 @@ class SummariesRetriever(BaseRetriever): async def get_context(self, query: str) -> Any: """Retrieves summary context based on the query.""" vector_engine = get_vector_engine() - summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit) + + try: + summaries_results = await vector_engine.search( + "TextSummary_text", query, limit=self.limit + ) + except CollectionNotFoundError as error: + raise NoDataError("No data found in the system, please add data first.") from error + return [summary.payload for summary in summaries_results] async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index bef4493b4..fcec0edd3 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -82,9 +82,6 @@ async def brute_force_triplet_search( if user is None: user = await get_default_user() - if user is None: - raise PermissionError("No user found in the system. Please create a user.") - retrieved_results = await brute_force_search( query, user, @@ -174,4 +171,4 @@ async def brute_force_search( send_telemetry( "cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)} ) - raise RuntimeError("An error occurred during brute force search") from error + raise error diff --git a/cognee/modules/retrieval/utils/description_to_codepart_search.py b/cognee/modules/retrieval/utils/description_to_codepart_search.py index 4b87131ef..1aaf084d4 100644 --- a/cognee/modules/retrieval/utils/description_to_codepart_search.py +++ b/cognee/modules/retrieval/utils/description_to_codepart_search.py @@ -20,9 +20,6 @@ async def code_description_to_code_part_search( if user is None: user = await get_default_user() - if user is None: - raise PermissionError("No user found in the system. Please create a user.") - retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs) return retrieved_codeparts diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py index cb49b2c24..2c7d64176 100644 --- a/cognee/modules/users/methods/get_default_user.py +++ b/cognee/modules/users/methods/get_default_user.py @@ -2,9 +2,11 @@ from types import SimpleNamespace from sqlalchemy.orm import selectinload from sqlalchemy.future import select from cognee.modules.users.models import User +from cognee.base_config import get_base_config +from cognee.modules.users.exceptions.exceptions import UserNotFoundError +from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.methods.create_default_user import create_default_user -from cognee.base_config import get_base_config async def get_default_user() -> SimpleNamespace: @@ -12,16 +14,24 @@ async def get_default_user() -> SimpleNamespace: base_config = get_base_config() default_email = base_config.default_user_email or "default_user@example.com" - async with db_engine.get_async_session() as session: - query = select(User).options(selectinload(User.roles)).where(User.email == default_email) + try: + async with db_engine.get_async_session() as session: + query = ( + select(User).options(selectinload(User.roles)).where(User.email == default_email) + ) - result = await session.execute(query) - user = result.scalars().first() + result = await session.execute(query) + user = result.scalars().first() - if user is None: - return await create_default_user() + if user is None: + return await create_default_user() - # We return a SimpleNamespace to have the same user type as our SaaS - # SimpleNamespace is just a dictionary which can be accessed through attributes - auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[]) - return auth_data + # We return a SimpleNamespace to have the same user type as our SaaS + # SimpleNamespace is just a dictionary which can be accessed through attributes + auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[]) + return auth_data + except Exception as error: + if "principals" in str(error.args): + raise DatabaseNotCreatedError() from error + + raise UserNotFoundError(f"Failed to retrieve default user: {default_email}") from error diff --git a/cognee/tasks/completion/__init__.py b/cognee/tasks/completion/__init__.py index 93901e0d7..9d208c07f 100644 --- a/cognee/tasks/completion/__init__.py +++ b/cognee/tasks/completion/__init__.py @@ -1 +1 @@ -from cognee.tasks.completion.exceptions import NoRelevantDataFound +from cognee.tasks.completion.exceptions import NoRelevantDataError diff --git a/cognee/tasks/completion/exceptions/__init__.py b/cognee/tasks/completion/exceptions/__init__.py index 1530bf3f2..4ecfb0374 100644 --- a/cognee/tasks/completion/exceptions/__init__.py +++ b/cognee/tasks/completion/exceptions/__init__.py @@ -5,5 +5,5 @@ This module defines a set of exceptions for handling various compute errors """ from .exceptions import ( - NoRelevantDataFound, + NoRelevantDataError, ) diff --git a/cognee/tasks/completion/exceptions/exceptions.py b/cognee/tasks/completion/exceptions/exceptions.py index bb4bcb0c8..aebece145 100644 --- a/cognee/tasks/completion/exceptions/exceptions.py +++ b/cognee/tasks/completion/exceptions/exceptions.py @@ -2,11 +2,11 @@ from cognee.exceptions import CogneeApiError from fastapi import status -class NoRelevantDataFound(CogneeApiError): +class NoRelevantDataError(CogneeApiError): def __init__( self, message: str = "Search did not find any data.", - name: str = "NoRelevantDataFound", + name: str = "NoRelevantDataError", status_code=status.HTTP_404_NOT_FOUND, ): super().__init__(message, name, status_code) diff --git a/cognee/tests/unit/infrastructure/test_rate_limiting_retry.py b/cognee/tests/unit/infrastructure/test_rate_limiting_retry.py index cf8270e18..4338ab2ea 100644 --- a/cognee/tests/unit/infrastructure/test_rate_limiting_retry.py +++ b/cognee/tests/unit/infrastructure/test_rate_limiting_retry.py @@ -1,8 +1,5 @@ -import asyncio import time -import os -from unittest.mock import patch, MagicMock -from functools import lru_cache +import asyncio from cognee.shared.logging_utils import get_logger from cognee.infrastructure.llm.rate_limiter import ( sleep_and_retry_sync, diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 85b33060d..0050228e9 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -1,120 +1,195 @@ -import uuid -from unittest.mock import AsyncMock, MagicMock, patch - +import os import pytest +import pathlib +import cognee +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.chunks_retriever import ChunksRetriever class TestChunksRetriever: - @pytest.fixture - def mock_retriever(self): - return ChunksRetriever() + @pytest.mark.asyncio + async def test_chunk_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_rag_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_rag_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + retriever = ChunksRetriever() + + context = await retriever.get_context("Mike") + + assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski" @pytest.mark.asyncio - @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") - async def test_get_completion(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query" - doc_id1 = str(uuid.uuid4()) - doc_id2 = str(uuid.uuid4()) + async def test_chunk_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" + ) + cognee.config.data_root_directory(data_directory_path) - # Mock search results - mock_result_1 = MagicMock() - mock_result_1.payload = { - "id": str(uuid.uuid4()), - "text": "This is the first chunk result.", - "document_id": doc_id1, - "metadata": {"title": "Document 1"}, - } + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() - mock_result_2 = MagicMock() - mock_result_2.payload = { - "id": str(uuid.uuid4()), - "text": "This is the second chunk result.", - "document_id": doc_id2, - "metadata": {"title": "Document 2"}, - } + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) - mock_search_results = [mock_result_1, mock_result_2] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) - # Execute - results = await mock_retriever.get_completion(query) + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) - # Verify - assert len(results) == 2 + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) - # Check first result - assert results[0]["text"] == "This is the first chunk result." - assert results[0]["document_id"] == doc_id1 - assert results[0]["metadata"]["title"] == "Document 1" + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] - # Check second result - assert results[1]["text"] == "This is the second chunk result." - assert results[1]["document_id"] == doc_id2 - assert results[1]["metadata"]["title"] == "Document 2" + await add_data_points(entities) - # Verify search was called correctly - mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5) + retriever = ChunksRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" @pytest.mark.asyncio - @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") - async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query with no results" - mock_search_results = [] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine + async def test_chunk_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" + ) + cognee.config.data_root_directory(data_directory_path) - # Execute - results = await mock_retriever.get_completion(query) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) - # Verify - assert len(results) == 0 - mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5) + retriever = ChunksRetriever() - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") - async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query with incomplete data" + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") - # Mock search results - mock_result_1 = MagicMock() - mock_result_1.payload = { - "id": str(uuid.uuid4()), - "text": "This chunk has no document_id.", - # Missing document_id and metadata - } - mock_result_2 = MagicMock() - mock_result_2.payload = { - "id": str(uuid.uuid4()), - # Missing text - "document_id": str(uuid.uuid4()), - "metadata": {"title": "Document with missing text"}, - } + vector_engine = get_vector_engine() + await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk) - mock_search_results = [mock_result_1, mock_result_2] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine + context = await retriever.get_context("Christina Mayer") + assert len(context) == 0, "Found chunks when none should exist" - # Execute - results = await mock_retriever.get_completion(query) - # Verify - assert len(results) == 2 +if __name__ == "__main__": + from asyncio import run - # First result should have content but no document_id - assert results[0]["text"] == "This chunk has no document_id." - assert "document_id" not in results[0] - assert "metadata" not in results[0] + test = TestChunksRetriever() - # Second result should have document_id and metadata but no content - assert "text" not in results[1] - assert "document_id" in results[1] - assert results[1]["metadata"]["title"] == "Document with missing text" + run(test.test_chunk_context_simple()) + run(test.test_chunk_context_complex()) + run(test.test_chunk_context_on_empty_graph()) diff --git a/cognee/tests/unit/modules/retrieval/completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/completion_retriever_test.py deleted file mode 100644 index 0da518008..000000000 --- a/cognee/tests/unit/modules/retrieval/completion_retriever_test.py +++ /dev/null @@ -1,84 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from cognee.modules.retrieval.completion_retriever import CompletionRetriever - - -class TestCompletionRetriever: - @pytest.fixture - def mock_retriever(self): - return CompletionRetriever(system_prompt_path="test_prompt.txt") - - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.utils.completion.get_llm_client") - @patch("cognee.modules.retrieval.utils.completion.render_prompt") - @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine") - async def test_get_completion( - self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever - ): - # Setup - query = "test query" - - # Mock render_prompt - mock_render_prompt.return_value = "Rendered prompt with context" - - mock_search_results = [MagicMock()] - mock_search_results[0].payload = {"text": "This is a sample document chunk."} - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine - - # Mock LLM client - mock_llm_client = MagicMock() - mock_llm_client.acreate_structured_output = AsyncMock() - mock_llm_client.acreate_structured_output.return_value = "Generated completion response" - mock_get_llm_client.return_value = mock_llm_client - - # Execute - results = await mock_retriever.get_completion(query) - - # Verify - assert len(results) == 1 - assert results[0] == "Generated completion response" - - # Verify prompt was rendered - mock_render_prompt.assert_called_once() - - # Verify LLM client was called - mock_llm_client.acreate_structured_output.assert_called_once_with( - text_input="Rendered prompt with context", system_prompt=None, response_model=str - ) - - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.completion_retriever.generate_completion") - @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine") - async def test_get_completion_with_custom_prompt( - self, mock_get_vector_engine, mock_generate_completion, mock_retriever - ): - # Setup - query = "test query with custom prompt" - - mock_search_results = [MagicMock()] - mock_search_results[0].payload = {"text": "This is a sample document chunk."} - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine - - mock_retriever.user_prompt_path = "custom_user_prompt.txt" - mock_retriever.system_prompt_path = "custom_system_prompt.txt" - - mock_generate_completion.return_value = "Custom prompt completion response" - - # Execute - results = await mock_retriever.get_completion(query) - - # Verify - assert len(results) == 1 - assert results[0] == "Custom prompt completion response" - - assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt" - assert ( - mock_generate_completion.call_args[1]["system_prompt_path"] - == "custom_system_prompt.txt" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 7befa8243..ad9a6eb52 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -1,236 +1,159 @@ -from unittest.mock import AsyncMock, MagicMock, patch - +import os import pytest +import pathlib +from typing import Optional, Union +import cognee +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.graph.exceptions import EntityNotFoundError -from cognee.tasks.completion.exceptions import NoRelevantDataFound class TestGraphCompletionRetriever: - @pytest.fixture - def mock_retriever(self): - return GraphCompletionRetriever(system_prompt_path="test_prompt.txt") - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search") - async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever): - mock_brute_force_triplet_search.return_value = [ - AsyncMock( - node1=AsyncMock(attributes={"text": "Node A"}), - attributes={"relationship_type": "connects"}, - node2=AsyncMock(attributes={"text": "Node B"}), - ) - ] - - result = await mock_retriever.get_triplets("test query") - - assert isinstance(result, list) - assert len(result) > 0 - assert result[0].attributes["relationship_type"] == "connects" - mock_brute_force_triplet_search.assert_called_once() - - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search") - async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever): - mock_brute_force_triplet_search.return_value = [] - - with pytest.raises(NoRelevantDataFound): - await mock_retriever.get_triplets("test query") - - @pytest.mark.asyncio - async def test_resolve_edges_to_text(self, mock_retriever): - node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"}) - node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"}) - node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"}) - - triplets = [ - AsyncMock( - node1=node_a, - attributes={"relationship_type": "connects"}, - node2=node_b, - ), - AsyncMock( - node1=node_a, - attributes={"relationship_type": "links"}, - node2=node_c, - ), - ] - - with patch.object(mock_retriever, "_get_title", return_value="Test Title"): - result = await mock_retriever.resolve_edges_to_text(triplets) - - assert "Nodes:" in result - assert "Connections:" in result - - assert "Node: Test Title" in result - assert "__node_content_start__" in result - assert "Node A text content" in result - assert "__node_content_end__" in result - assert "Node: Node C" in result - - assert "Test Title --[connects]--> Test Title" in result - assert "Test Title --[links]--> Node C" in result - - @pytest.mark.asyncio - @patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets", - new_callable=AsyncMock, - ) - @patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", - new_callable=AsyncMock, - ) - async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever): - """Test get_context calls get_triplets and resolve_edges_to_text.""" - mock_get_triplets.return_value = ["mock_triplet"] - mock_resolve_edges_to_text.return_value = "Mock Context" - - result = await mock_retriever.get_context("test query") - - assert result == "Mock Context" - mock_get_triplets.assert_called_once_with("test query") - mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"]) - - @pytest.mark.asyncio - @patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context" - ) - @patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion") - async def test_get_completion_without_context( - self, mock_generate_completion, mock_get_context, mock_retriever - ): - """Test get_completion when no context is provided (calls get_context).""" - mock_get_context.return_value = "Mock Context" - mock_generate_completion.return_value = "Generated Completion" - - result = await mock_retriever.get_completion("test query") - - assert result == ["Generated Completion"] - mock_get_context.assert_called_once_with("test query") - mock_generate_completion.assert_called_once() - - @pytest.mark.asyncio - @patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context" - ) - @patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion") - async def test_get_completion_with_context( - self, mock_generate_completion, mock_get_context, mock_retriever - ): - """Test get_completion when context is provided (does not call get_context).""" - mock_generate_completion.return_value = "Generated Completion" - - result = await mock_retriever.get_completion("test query", context="Provided Context") - - assert result == ["Generated Completion"] - mock_get_context.assert_not_called() - mock_generate_completion.assert_called_once() - - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.utils.completion.get_llm_client") - @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine") - @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user") - async def test_get_completion_with_empty_graph( - self, - mock_get_default_user, - mock_get_graph_engine, - mock_get_llm_client, - mock_retriever, - ): - query = "test query with empty graph" - - mock_graph_engine = MagicMock() - mock_graph_engine.get_graph_data = AsyncMock() - mock_graph_engine.get_graph_data.return_value = ([], []) - mock_get_graph_engine.return_value = mock_graph_engine - - mock_llm_client = MagicMock() - mock_llm_client.acreate_structured_output = AsyncMock() - mock_llm_client.acreate_structured_output.return_value = ( - "Generated graph completion response" + async def test_graph_completion_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_context" ) - mock_get_llm_client.return_value = mock_llm_client + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_context" + ) + cognee.config.data_root_directory(data_directory_path) - with pytest.raises(EntityNotFoundError): - await mock_retriever.get_completion(query) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() - mock_graph_engine.get_graph_data.assert_called_once() + class Company(DataPoint): + name: str - def test_top_n_words(self, mock_retriever): - """Test extraction of top frequent words from text.""" - text = "The quick brown fox jumps over the lazy dog. The fox is quick." + class Person(DataPoint): + name: str + works_for: Company - result = mock_retriever._top_n_words(text) - assert len(result.split(", ")) <= 3 - assert "fox" in result - assert "quick" in result + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) - result = mock_retriever._top_n_words(text, top_n=2) - assert len(result.split(", ")) <= 2 + entities = [company1, company2, person1, person2, person3, person4, person5] - result = mock_retriever._top_n_words(text, separator=" | ") - assert " | " in result + await add_data_points(entities) - result = mock_retriever._top_n_words(text, stop_words={"fox", "quick"}) - assert "fox" not in result - assert "quick" not in result + retriever = GraphCompletionRetriever() - def test_get_title(self, mock_retriever): - """Test title generation from text.""" - text = "This is a long paragraph about various topics that should generate a title. The main topics are AI, programming and data science." + context = await retriever.get_context("Who works at Canva?") - title = mock_retriever._get_title(text) - assert "..." in title - assert "[" in title and "]" in title + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - title = mock_retriever._get_title(text, first_n_words=3) - first_part = title.split("...")[0].strip() - assert len(first_part.split()) == 3 + @pytest.mark.asyncio + async def test_graph_completion_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) - title = mock_retriever._get_title(text, top_n_words=2) - top_part = title.split("[")[1].split("]")[0] - assert len(top_part.split(", ")) <= 2 + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() - def test_get_nodes(self, mock_retriever): - """Test node processing and deduplication.""" - node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"}) - node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"}) - node_without_attrs = AsyncMock(id="empty_node", attributes={}) + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} - edges = [ - AsyncMock( - node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"} - ), - AsyncMock( - node1=node_with_text, - node2=node_without_attrs, - attributes={"relationship_type": "rel2"}, - ), - AsyncMock( - node1=node_with_name, - node2=node_without_attrs, - attributes={"relationship_type": "rel3"}, - ), + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), ] - with patch.object(mock_retriever, "_get_title", return_value="Generated Title"): - nodes = mock_retriever._get_nodes(edges) + person3 = Person(name="Jason Statham", works_for=company1) - assert len(nodes) == 3 + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - for node_id, info in nodes.items(): - assert "node" in info - assert "name" in info - assert "content" in info + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - text_node_info = nodes[node_with_text.id] - assert text_node_info["name"] == "Generated Title" - assert text_node_info["content"] == "This is a text node" + entities = [company1, company2, person1, person2, person3, person4, person5] - name_node_info = nodes[node_with_name.id] - assert name_node_info["name"] == "Named Node" - assert name_node_info["content"] == "Named Node" + await add_data_points(entities) - empty_node_info = nodes[node_without_attrs.id] - assert empty_node_info["name"] == "Unnamed Node" + retriever = GraphCompletionRetriever(top_k=20) + + context = await retriever.get_context("Who works at Figma?") + + print(context) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + @pytest.mark.asyncio + async def test_get_graph_completion_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + retriever = GraphCompletionRetriever() + + with pytest.raises(DatabaseNotCreatedError): + await retriever.get_context("Who works at Figma?") + + await setup() + + context = await retriever.get_context("Who works at Figma?") + assert context == "", "Context should be empty on an empty graph" + + +if __name__ == "__main__": + from asyncio import run + + test = TestGraphCompletionRetriever() + + run(test.test_graph_completion_context_simple()) + run(test.test_graph_completion_context_complex()) + run(test.test_get_graph_completion_context_on_empty_graph()) diff --git a/cognee/tests/unit/modules/retrieval/graph_summary_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_summary_completion_retriever_test.py deleted file mode 100644 index e35842d86..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_summary_completion_retriever_test.py +++ /dev/null @@ -1,80 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from cognee.modules.retrieval.graph_summary_completion_retriever import ( - GraphSummaryCompletionRetriever, -) - - -class TestGraphSummaryCompletionRetriever: - @pytest.fixture - def mock_retriever(self): - return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt") - - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.utils.completion.get_llm_client") - @patch("cognee.modules.retrieval.utils.completion.read_query_prompt") - @patch("cognee.modules.retrieval.utils.completion.render_prompt") - @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user") - async def test_get_completion_with_custom_system_prompt( - self, - mock_get_default_user, - mock_render_prompt, - mock_read_query_prompt, - mock_get_llm_client, - mock_retriever, - ): - # Setup - query = "test query with custom prompt" - - # Set custom system prompt - mock_retriever.user_prompt_path = "custom_user_prompt.txt" - mock_retriever.system_prompt_path = "custom_system_prompt.txt" - - mock_llm_client = MagicMock() - mock_llm_client.acreate_structured_output = AsyncMock() - mock_llm_client.acreate_structured_output.return_value = ( - "Generated graph summary completion response" - ) - mock_get_llm_client.return_value = mock_llm_client - - # Execute - results = await mock_retriever.get_completion(query, context="test context") - - # Verify - assert len(results) == 1 - - # Verify render_prompt was called with custom prompt path - mock_render_prompt.assert_called_once() - assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt" - - mock_read_query_prompt.assert_called_once() - assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt" - - mock_llm_client.acreate_structured_output.assert_called_once() - - @pytest.mark.asyncio - @patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text" - ) - @patch( - "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", - new_callable=AsyncMock, - ) - async def test_resolve_edges_to_text_calls_super_and_summarizes( - self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever - ): - """Test resolve_edges_to_text calls the parent method and summarizes the result.""" - - mock_resolve_edges_to_text.return_value = "Raw graph edges text" - mock_summarize_text.return_value = "Summarized graph text" - - result = await mock_retriever.resolve_edges_to_text(["mock_edge"]) - - mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"]) - mock_summarize_text.assert_called_once_with( - "Raw graph edges text", mock_retriever.summarize_prompt_path - ) - - assert result == "Summarized graph text" diff --git a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py index 76a3506ad..06efafb46 100644 --- a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py @@ -1,103 +1,216 @@ -import uuid -from unittest.mock import AsyncMock, MagicMock, patch - +import os import pytest +import pathlib -from cognee.modules.retrieval.insights_retriever import InsightsRetriever -from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph -from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine -import unittest +import cognee +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.modules.engine.models import Entity, EntityType from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.insights_retriever import InsightsRetriever class TestInsightsRetriever: - @pytest.fixture - def mock_retriever(self): - return InsightsRetriever() + @pytest.mark.asyncio + async def test_insights_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_simple" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_insights_context_simple" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + entityTypePerson = EntityType( + name="Person", + description="An individual", + ) + + person1 = Entity( + name="Steve Rodger", + is_a=entityTypePerson, + description="An American actor, comedian, and filmmaker", + ) + + person2 = Entity( + name="Mike Broski", + is_a=entityTypePerson, + description="Financial advisor and philanthropist", + ) + + person3 = Entity( + name="Christina Mayer", + is_a=entityTypePerson, + description="Maker of next generation of iconic American music videos", + ) + + entityTypeCompany = EntityType( + name="Company", + description="An organization that operates on an annual basis", + ) + + company1 = Entity( + name="Apple", + is_a=entityTypeCompany, + description="An American multinational technology company headquartered in Cupertino, California", + ) + + company2 = Entity( + name="Google", + is_a=entityTypeCompany, + description="An American multinational technology company that specializes in Internet-related services and products", + ) + + company3 = Entity( + name="Facebook", + is_a=entityTypeCompany, + description="An American social media, messaging, and online platform", + ) + + entities = [person1, person2, person3, company1, company2, company3] + + await add_data_points(entities) + + retriever = InsightsRetriever() + + context = await retriever.get_context("Mike") + + assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski" @pytest.mark.asyncio - @patch("cognee.modules.retrieval.insights_retriever.get_graph_engine") - async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever): - """Test get_context when node exists in graph.""" - mock_graph = AsyncMock() - mock_get_graph_engine.return_value = mock_graph + async def test_insights_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_complex" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_insights_context_complex" + ) + cognee.config.data_root_directory(data_directory_path) - # Mock graph response - mock_graph.extract_node.return_value = {"id": "123"} - mock_graph.get_connections.return_value = [ - ({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"}) - ] + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() - result = await mock_retriever.get_context("123") + entityTypePerson = EntityType( + name="Person", + description="An individual", + ) - assert isinstance(result, list) - assert len(result) == 1 - assert result[0][0]["id"] == "123" - assert result[0][1]["relationship_name"] == "linked_to" - assert result[0][2]["id"] == "456" - mock_graph.extract_node.assert_called_once_with("123") - mock_graph.get_connections.assert_called_once_with("123") + person1 = Entity( + name="Steve Rodger", + is_a=entityTypePerson, + description="An American actor, comedian, and filmmaker", + ) + + person2 = Entity( + name="Mike Broski", + is_a=entityTypePerson, + description="Financial advisor and philanthropist", + ) + + person3 = Entity( + name="Christina Mayer", + is_a=entityTypePerson, + description="Maker of next generation of iconic American music videos", + ) + + person4 = Entity( + name="Jason Statham", + is_a=entityTypePerson, + description="An American actor", + ) + + person5 = Entity( + name="Mike Tyson", + is_a=entityTypePerson, + description="A former professional boxer from the United States", + ) + + entityTypeCompany = EntityType( + name="Company", + description="An organization that operates on an annual basis", + ) + + company1 = Entity( + name="Apple", + is_a=entityTypeCompany, + description="An American multinational technology company headquartered in Cupertino, California", + ) + + company2 = Entity( + name="Google", + is_a=entityTypeCompany, + description="An American multinational technology company that specializes in Internet-related services and products", + ) + + company3 = Entity( + name="Facebook", + is_a=entityTypeCompany, + description="An American social media, messaging, and online platform", + ) + + entities = [person1, person2, person3, company1, company2, company3] + + await add_data_points(entities) + + graph_engine = await get_graph_engine() + + await graph_engine.add_edges( + [ + (person1.id, company1.id, "works_for"), + (person2.id, company2.id, "works_for"), + (person3.id, company3.id, "works_for"), + (person4.id, company1.id, "works_for"), + (person5.id, company1.id, "works_for"), + ] + ) + + retriever = InsightsRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer" @pytest.mark.asyncio - @patch("cognee.modules.retrieval.insights_retriever.get_vector_engine") - async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query with no results" - mock_search_results = [] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine + async def test_insights_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty" + ) + cognee.config.data_root_directory(data_directory_path) - # Execute - results = await mock_retriever.get_completion(query) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) - # Verify - assert len(results) == 0 + retriever = InsightsRetriever() - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.insights_retriever.get_graph_engine") - @patch("cognee.modules.retrieval.insights_retriever.get_vector_engine") - async def test_get_context_with_no_exact_node( - self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever - ): - """Test get_context when node does not exist in the graph and vector search is used.""" - mock_graph = AsyncMock() - mock_get_graph_engine.return_value = mock_graph - mock_graph.extract_node.return_value = None # Node does not exist + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") - mock_vector = AsyncMock() - mock_get_vector_engine.return_value = mock_vector + vector_engine = get_vector_engine() + await vector_engine.create_collection("Entity_name", payload_schema=Entity) + await vector_engine.create_collection("EntityType_name", payload_schema=EntityType) - mock_vector.search.side_effect = [ - [AsyncMock(id="vec_1", score=0.4)], # Entity_name search - [AsyncMock(id="vec_2", score=0.3)], # EntityType_name search - ] + context = await retriever.get_context("Christina Mayer") + assert context == [], "Returned context should be empty on an empty graph" - mock_graph.get_connections.side_effect = lambda node_id: [ - ({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"}) - ] - result = await mock_retriever.get_context("non_existing_query") +if __name__ == "__main__": + from asyncio import run - assert isinstance(result, list) - assert len(result) == 2 - assert result[0][0]["id"] == "vec_1" - assert result[0][1]["relationship_name"] == "related_to" - assert result[0][2]["id"] == "456" + test = TestInsightsRetriever() - assert result[1][0]["id"] == "vec_2" - assert result[1][1]["relationship_name"] == "related_to" - assert result[1][2]["id"] == "456" - - @pytest.mark.asyncio - async def test_get_context_with_none_query(self, mock_retriever): - """Test get_context with a None query (should return empty list).""" - result = await mock_retriever.get_context(None) - assert result == [] - - @pytest.mark.asyncio - async def test_get_completion_with_context(self, mock_retriever): - """Test get_completion when context is already provided.""" - test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})] - result = await mock_retriever.get_completion("test_query", context=test_context) - assert result == test_context + run(test.test_insights_context_simple()) + run(test.test_insights_context_complex()) + run(test.test_insights_context_on_empty_graph()) diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py new file mode 100644 index 000000000..5304160c8 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -0,0 +1,196 @@ +import os +import pytest +import pathlib + +import cognee +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestRAGCompletionRetriever: + @pytest.mark.asyncio + async def test_rag_completion_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_rag_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_rag_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + retriever = CompletionRetriever() + + context = await retriever.get_context("Mike") + + assert context == "Mike Broski", "Failed to get Mike Broski" + + @pytest.mark.asyncio + async def test_rag_completion_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await add_data_points(entities) + + # TODO: top_k doesn't affect the output, it should be fixed. + retriever = CompletionRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" + + @pytest.mark.asyncio + async def test_get_rag_completion_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + retriever = CompletionRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk) + + context = await retriever.get_context("Christina Mayer") + assert context == "", "Returned context should be empty on an empty graph" + + +if __name__ == "__main__": + from asyncio import run + + test = TestRAGCompletionRetriever() + + run(test.test_rag_completion_context_simple()) + run(test.test_rag_completion_context_complex()) + run(test.test_get_rag_completion_context_on_empty_graph()) diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index f62d81292..69a53194a 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -1,122 +1,168 @@ -import uuid -from unittest.mock import AsyncMock, MagicMock, patch - +import os import pytest +import pathlib +import cognee +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.tasks.summarization.models import TextSummary +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -class TestSummariesRetriever: - @pytest.fixture - def mock_retriever(self): - return SummariesRetriever() +class TextSummariesRetriever: + @pytest.mark.asyncio + async def test_chunk_context(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_summary_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_summary_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk1_summary = TextSummary( + text="S.R.", + made_from=chunk1, + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2_summary = TextSummary( + text="M.B.", + made_from=chunk2, + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3_summary = TextSummary( + text="C.M.", + made_from=chunk3, + ) + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk4_summary = TextSummary( + text="R.R.", + made_from=chunk4, + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5_summary = TextSummary( + text="H.Y.", + made_from=chunk5, + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6_summary = TextSummary( + text="C.H.", + made_from=chunk6, + ) + + entities = [ + chunk1_summary, + chunk2_summary, + chunk3_summary, + chunk4_summary, + chunk5_summary, + chunk6_summary, + ] + + await add_data_points(entities) + + retriever = SummariesRetriever(limit=20) + + context = await retriever.get_context("Christina") + + assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" @pytest.mark.asyncio - @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") - async def test_get_completion(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query" - doc_id1 = str(uuid.uuid4()) - doc_id2 = str(uuid.uuid4()) + async def test_chunk_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_summary_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_summary_context" + ) + cognee.config.data_root_directory(data_directory_path) - # Mock search results - mock_result_1 = MagicMock() - mock_result_1.payload = { - "id": str(uuid.uuid4()), - "score": 0.95, - "payload": { - "text": "This is the first summary.", - "document_id": doc_id1, - "metadata": {"title": "Document 1"}, - }, - } - mock_result_2 = MagicMock() - mock_result_2.payload = { - "id": str(uuid.uuid4()), - "score": 0.85, - "payload": { - "text": "This is the second summary.", - "document_id": doc_id2, - "metadata": {"title": "Document 2"}, - }, - } + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) - mock_search_results = [mock_result_1, mock_result_2] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine + retriever = SummariesRetriever() - # Execute - results = await mock_retriever.get_completion(query) + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") - # Verify - assert len(results) == 2 + vector_engine = get_vector_engine() + await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) - # Check first result - assert results[0]["payload"]["text"] == "This is the first summary." - assert results[0]["payload"]["document_id"] == doc_id1 - assert results[0]["payload"]["metadata"]["title"] == "Document 1" - assert results[0]["score"] == 0.95 + context = await retriever.get_context("Christina Mayer") + assert context == [], "Returned context should be empty on an empty graph" - # Check second result - assert results[1]["payload"]["text"] == "This is the second summary." - assert results[1]["payload"]["document_id"] == doc_id2 - assert results[1]["payload"]["metadata"]["title"] == "Document 2" - assert results[1]["score"] == 0.85 - # Verify search was called correctly - mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5) +if __name__ == "__main__": + from asyncio import run - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") - async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query with no results" - mock_search_results = [] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine + test = TextSummariesRetriever() - # Execute - results = await mock_retriever.get_completion(query) - - # Verify - assert len(results) == 0 - mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5) - - @pytest.mark.asyncio - @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") - async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever): - # Setup - query = "test query with custom limit" - doc_id = str(uuid.uuid4()) - - # Mock search results - mock_result = MagicMock() - mock_result.payload = { - "id": str(uuid.uuid4()), - "score": 0.95, - "payload": { - "text": "This is a summary.", - "document_id": doc_id, - "metadata": {"title": "Document 1"}, - }, - } - - mock_search_results = [mock_result] - mock_vector_engine = AsyncMock() - mock_vector_engine.search.return_value = mock_search_results - mock_get_vector_engine.return_value = mock_vector_engine - - # Set custom limit - mock_retriever.limit = 10 - - # Execute - results = await mock_retriever.get_completion(query) - - # Verify - assert len(results) == 1 - assert results[0]["payload"]["text"] == "This is a summary." - - # Verify search was called with custom limit - mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10) + run(test.test_chunk_context()) + run(test.test_chunk_context_on_empty_graph()) diff --git a/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py b/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py index 9af0af42a..cb4d8de8c 100644 --- a/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py +++ b/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py @@ -1,11 +1,11 @@ import pytest -from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError +from unittest.mock import AsyncMock, patch from cognee.modules.users.models import User +from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError from cognee.modules.retrieval.utils.brute_force_triplet_search import ( brute_force_search, brute_force_triplet_search, ) -from unittest.mock import AsyncMock, patch @pytest.mark.asyncio @@ -20,13 +20,11 @@ async def test_brute_force_search_collection_not_found(mock_get_vector_engine): mock_vector_engine.get_distance_from_collection_elements.return_value = [] mock_get_vector_engine.return_value = mock_vector_engine - with pytest.raises(Exception) as exc_info: + with pytest.raises(CollectionDistancesNotFoundError): await brute_force_search( query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment ) - assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError) - @pytest.mark.asyncio @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine") @@ -40,9 +38,7 @@ async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_e mock_vector_engine.get_distance_from_collection_elements.return_value = [] mock_get_vector_engine.return_value = mock_vector_engine - with pytest.raises(Exception) as exc_info: + with pytest.raises(CollectionDistancesNotFoundError): await brute_force_triplet_search( query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment ) - - assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)