Merge branch 'dev' of github.com:topoteretes/cognee into dev

This commit is contained in:
vasilije 2025-04-19 10:07:02 +02:00
commit b35e04735f
40 changed files with 1127 additions and 824 deletions

View file

@ -8,12 +8,30 @@ on:
type: string type: string
default: '3.11.x' default: '3.11.x'
secrets: secrets:
OPENAI_API_KEY:
required: true
GRAPHISTRY_USERNAME: GRAPHISTRY_USERNAME:
required: true required: true
GRAPHISTRY_PASSWORD: GRAPHISTRY_PASSWORD:
required: true required: true
LLM_PROVIDER:
required: true
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_PROVIDER:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
@ -60,6 +78,18 @@ jobs:
unit-tests: unit-tests:
name: Run Unit Tests name: Run Unit Tests
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
env:
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps: steps:
- name: Check out repository - name: Check out repository
uses: actions/checkout@v4 uses: actions/checkout@v4
@ -95,10 +125,20 @@ jobs:
name: Run Simple Examples name: Run Simple Examples
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
env: env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps: steps:
- name: Check out repository - name: Check out repository
uses: actions/checkout@v4 uses: actions/checkout@v4
@ -117,10 +157,20 @@ jobs:
name: Run Basic Graph Tests name: Run Basic Graph Tests
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
env: env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps: steps:
- name: Check out repository - name: Check out repository
uses: actions/checkout@v4 uses: actions/checkout@v4

View file

@ -1,4 +1,4 @@
name: Reusable Vector DB Tests name: Reusable Graph DB Tests
on: on:
workflow_call: workflow_call:

View file

@ -8,26 +8,30 @@ on:
type: string type: string
default: '["3.10.x", "3.11.x", "3.12.x"]' default: '["3.10.x", "3.11.x", "3.12.x"]'
secrets: secrets:
OPENAI_API_KEY:
required: true
GRAPHISTRY_USERNAME: GRAPHISTRY_USERNAME:
required: true required: true
GRAPHISTRY_PASSWORD: GRAPHISTRY_PASSWORD:
required: true required: true
LLM_PROVIDER:
required: true
LLM_MODEL: LLM_MODEL:
required: false required: true
LLM_ENDPOINT: LLM_ENDPOINT:
required: false required: true
LLM_API_KEY:
required: true
LLM_API_VERSION: LLM_API_VERSION:
required: false required: true
EMBEDDING_PROVIDER:
required: true
EMBEDDING_MODEL: EMBEDDING_MODEL:
required: false required: true
EMBEDDING_ENDPOINT: EMBEDDING_ENDPOINT:
required: false required: true
EMBEDDING_API_KEY: EMBEDDING_API_KEY:
required: false required: true
EMBEDDING_API_VERSION: EMBEDDING_API_VERSION:
required: false required: true
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
@ -55,6 +59,18 @@ jobs:
- name: Run unit tests - name: Run unit tests
run: poetry run pytest cognee/tests/unit/ run: poetry run pytest cognee/tests/unit/
env:
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
- name: Run integration tests - name: Run integration tests
if: ${{ !contains(matrix.os, 'windows') }} if: ${{ !contains(matrix.os, 'windows') }}
@ -62,13 +78,16 @@ jobs:
- name: Run default basic pipeline - name: Run default basic pipeline
env: env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }} GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }} GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}

View file

@ -8,6 +8,8 @@ on:
type: string type: string
default: '3.11.x' default: '3.11.x'
secrets: secrets:
LLM_PROVIDER:
required: true
LLM_MODEL: LLM_MODEL:
required: true required: true
LLM_ENDPOINT: LLM_ENDPOINT:
@ -16,6 +18,8 @@ on:
required: true required: true
LLM_API_VERSION: LLM_API_VERSION:
required: true required: true
EMBEDDING_PROVIDER:
required: true
EMBEDDING_MODEL: EMBEDDING_MODEL:
required: true required: true
EMBEDDING_ENDPOINT: EMBEDDING_ENDPOINT:
@ -24,12 +28,6 @@ on:
required: true required: true
EMBEDDING_API_VERSION: EMBEDDING_API_VERSION:
required: true required: true
OPENAI_API_KEY:
required: true
GRAPHISTRY_USERNAME:
required: true
GRAPHISTRY_PASSWORD:
required: true
jobs: jobs:
run-relational-db-migration-test-networkx: run-relational-db-migration-test-networkx:
@ -81,10 +79,13 @@ jobs:
- name: Run relational db test - name: Run relational db test
env: env:
ENV: 'dev' ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -141,10 +142,14 @@ jobs:
env: env:
ENV: 'dev' ENV: 'dev'
GRAPH_DATABASE_PROVIDER: 'kuzu' GRAPH_DATABASE_PROVIDER: 'kuzu'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -204,10 +209,14 @@ jobs:
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }} GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }} GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j" GRAPH_DATABASE_USERNAME: "neo4j"
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}

View file

@ -1,28 +0,0 @@
"""Add default user
Revision ID: 482cd6517ce4
Revises: 8057ae7329c2
Create Date: 2024-10-16 22:17:18.634638
"""
from typing import Sequence, Union
from sqlalchemy.util import await_only
from cognee.modules.users.methods import create_default_user, delete_user
# revision identifiers, used by Alembic.
revision: str = "482cd6517ce4"
down_revision: Union[str, None] = "8057ae7329c2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
def upgrade() -> None:
await_only(create_default_user())
def downgrade() -> None:
await_only(delete_user("default_user@example.com"))

View file

@ -1,27 +0,0 @@
"""Initial migration
Revision ID: 8057ae7329c2
Revises:
Create Date: 2024-10-02 12:55:20.989372
"""
from typing import Sequence, Union
from sqlalchemy.util import await_only
from cognee.infrastructure.databases.relational import get_relational_engine
# revision identifiers, used by Alembic.
revision: str = "8057ae7329c2"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
db_engine = get_relational_engine()
await_only(db_engine.create_database())
def downgrade() -> None:
db_engine = get_relational_engine()
await_only(db_engine.delete_database())

View file

@ -1,8 +1,7 @@
from typing import Union from typing import Union
from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function from cognee.modules.search.methods import search as search_function
@ -22,9 +21,6 @@ async def search(
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
if user is None:
raise UserNotFoundError
filtered_search_results = await search_function( filtered_search_results = await search_function(
query_text, query_text,
query_type, query_type,

View file

@ -10,4 +10,5 @@ from .exceptions import (
ServiceError, ServiceError,
InvalidValueError, InvalidValueError,
InvalidAttributeError, InvalidAttributeError,
CriticalError,
) )

View file

@ -53,3 +53,7 @@ class InvalidAttributeError(CogneeApiError):
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class CriticalError(CogneeApiError):
pass

View file

@ -7,4 +7,5 @@ This module defines a set of exceptions for handling various database errors
from .exceptions import ( from .exceptions import (
EntityNotFoundError, EntityNotFoundError,
EntityAlreadyExistsError, EntityAlreadyExistsError,
DatabaseNotCreatedError,
) )

View file

@ -1,5 +1,15 @@
from cognee.exceptions import CogneeApiError
from fastapi import status from fastapi import status
from cognee.exceptions import CogneeApiError, CriticalError
class DatabaseNotCreatedError(CriticalError):
def __init__(
self,
message: str = "The database has not been created yet. Please call `await setup()` first.",
name: str = "DatabaseNotCreatedError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
class EntityNotFoundError(CogneeApiError): class EntityNotFoundError(CogneeApiError):

View file

@ -1,13 +1,13 @@
from typing import Protocol, Optional, Dict, Any, List, Tuple
from abc import abstractmethod, ABC
from uuid import UUID, uuid5, NAMESPACE_DNS
from cognee.modules.graph.relationship_manager import create_relationship
from functools import wraps
import inspect import inspect
from functools import wraps
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
from cognee.shared.logging_utils import get_logger
from datetime import datetime, timezone
logger = get_logger() logger = get_logger()
@ -44,20 +44,16 @@ def record_graph_changes(func):
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
if func.__name__ == "add_nodes": if func.__name__ == "add_nodes":
nodes = args[0] nodes: List[DataPoint] = args[0]
for node in nodes: for node in nodes:
try: try:
node_id = ( node_id = UUID(str(node.id))
UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
)
relationship = GraphRelationshipLedger( relationship = GraphRelationshipLedger(
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"), id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
source_node_id=node_id, source_node_id=node_id,
destination_node_id=node_id, destination_node_id=node_id,
creator_function=f"{creator}.node", creator_function=f"{creator}.node",
node_label=node[1].get("type") node_label=getattr(node, "name", None) or str(node.id),
if isinstance(node, tuple)
else type(node).__name__,
) )
session.add(relationship) session.add(relationship)
await session.flush() await session.flush()
@ -74,7 +70,7 @@ def record_graph_changes(func):
target_id = UUID(str(edge[1])) target_id = UUID(str(edge[1]))
rel_type = str(edge[2]) rel_type = str(edge[2])
relationship = GraphRelationshipLedger( relationship = GraphRelationshipLedger(
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"), id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
source_node_id=source_id, source_node_id=source_id,
destination_node_id=target_id, destination_node_id=target_id,
creator_function=f"{creator}.{rel_type}", creator_function=f"{creator}.{rel_type}",

View file

@ -0,0 +1 @@
from .exceptions import CollectionNotFoundError

View file

@ -0,0 +1,12 @@
from fastapi import status
from cognee.exceptions import CriticalError
class CollectionNotFoundError(CriticalError):
def __init__(
self,
message,
name: str = "DatabaseNotCreatedError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)

View file

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Generic, List, Optional, TypeVar, get_type_hints from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import lancedb import lancedb
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
@ -10,6 +10,7 @@ from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
@ -79,7 +80,6 @@ class LanceDBAdapter(VectorDBInterface):
connection = await self.get_connection() connection = await self.get_connection()
payload_schema = type(data_points[0]) payload_schema = type(data_points[0])
payload_schema = self.get_data_point_schema(payload_schema)
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( await self.create_collection(
@ -194,12 +194,19 @@ class LanceDBAdapter(VectorDBInterface):
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
connection = await self.get_connection() connection = await self.get_connection()
collection = await connection.open_table(collection_name)
try:
collection = await connection.open_table(collection_name)
except ValueError:
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
results = await collection.vector_search(query_vector).limit(limit).to_pandas() results = await collection.vector_search(query_vector).limit(limit).to_pandas()
result_values = list(results.to_dict("index").values()) result_values = list(results.to_dict("index").values())
if not result_values:
return []
normalized_values = normalize_distances(result_values) normalized_values = normalize_distances(result_values)
return [ return [
@ -288,11 +295,33 @@ class LanceDBAdapter(VectorDBInterface):
if self.url.startswith("/"): if self.url.startswith("/"):
LocalStorage.remove_all(self.url) LocalStorage.remove_all(self.url)
def get_data_point_schema(self, model_type): def get_data_point_schema(self, model_type: BaseModel):
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():
if hasattr(field_config, "model_fields"):
related_models_fields.append(field_name)
elif hasattr(field_config.annotation, "model_fields"):
related_models_fields.append(field_name)
elif (
get_origin(field_config.annotation) == Union
or get_origin(field_config.annotation) is list
):
models_list = get_args(field_config.annotation)
if any(hasattr(model, "model_fields") for model in models_list):
related_models_fields.append(field_name)
elif get_origin(field_config.annotation) == Optional:
model = get_args(field_config.annotation)
if hasattr(model, "model_fields"):
related_models_fields.append(field_name)
return copy_model( return copy_model(
model_type, model_type,
include_fields={ include_fields={
"id": (str, ...), "id": (str, ...),
}, },
exclude_fields=["metadata"], exclude_fields=["metadata"] + related_models_fields,
) )

View file

@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
@ -183,7 +184,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if collection_name in metadata.tables: if collection_name in metadata.tables:
return metadata.tables[collection_name] return metadata.tables[collection_name]
else: else:
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.") raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]):
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
@ -244,6 +245,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
except EntityNotFoundError: except EntityNotFoundError:
# Ignore if collection does not exist # Ignore if collection does not exist
return [] return []
except CollectionNotFoundError:
# Ignore if collection does not exist
return []
async def search( async def search(
self, self,

View file

@ -1,8 +1,8 @@
import os
from typing import Optional from typing import Optional
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator, Field from pydantic import model_validator
import os
class LLMConfig(BaseSettings): class LLMConfig(BaseSettings):

View file

@ -1,7 +1,6 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import uuid5, NAMESPACE_DNS from uuid import uuid5, NAMESPACE_OID
from sqlalchemy import UUID, Column, DateTime, String, Index from sqlalchemy import UUID, Column, DateTime, String, Index
from sqlalchemy.orm import relationship
from cognee.infrastructure.databases.relational import Base from cognee.infrastructure.databases.relational import Base
@ -12,7 +11,7 @@ class GraphRelationshipLedger(Base):
id = Column( id = Column(
UUID, UUID,
primary_key=True, primary_key=True,
default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"), default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
) )
source_node_id = Column(UUID, nullable=False) source_node_id = Column(UUID, nullable=False)
destination_node_id = Column(UUID, nullable=False) destination_node_id = Column(UUID, nullable=False)

View file

@ -111,9 +111,6 @@ class CogneeGraph(CogneeAbstractGraph):
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
print(f"Error projecting graph: {e}") print(f"Error projecting graph: {e}")
raise e raise e
except Exception as ex:
print(f"Unexpected error: {ex}")
raise ex
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
for category, scored_results in node_distances.items(): for category, scored_results in node_distances.items():

View file

@ -2,15 +2,28 @@ from typing import Any, Optional
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class ChunksRetriever(BaseRetriever): class ChunksRetriever(BaseRetriever):
"""Retriever for handling document chunk-based searches.""" """Retriever for handling document chunk-based searches."""
def __init__(
self,
top_k: Optional[int] = 5,
):
self.top_k = top_k
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> Any:
"""Retrieves document chunks context based on the query.""" """Retrieves document chunks context based on the query."""
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
return [result.payload for result in found_chunks] return [result.payload for result in found_chunks]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

View file

@ -1,9 +1,10 @@
from typing import Any, Optional from typing import Any, Optional
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
class CompletionRetriever(BaseRetriever): class CompletionRetriever(BaseRetriever):
@ -20,15 +21,21 @@ class CompletionRetriever(BaseRetriever):
self.system_prompt_path = system_prompt_path self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1 self.top_k = top_k if top_k is not None else 1
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> str:
"""Retrieves relevant document chunks as context.""" """Retrieves relevant document chunks as context."""
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
if len(found_chunks) == 0: try:
raise NoRelevantDataFound found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] if len(found_chunks) == 0:
return "\n".join(chunks_payload) return ""
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
return "\n".join(chunks_payload)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Generates an LLM completion using the context.""" """Generates an LLM completion using the context."""

View file

@ -1,5 +1,5 @@
from cognee.exceptions import CogneeApiError
from fastapi import status from fastapi import status
from cognee.exceptions import CogneeApiError, CriticalError
class CollectionDistancesNotFoundError(CogneeApiError): class CollectionDistancesNotFoundError(CogneeApiError):
@ -30,3 +30,7 @@ class CypherSearchError(CogneeApiError):
status_code: int = status.HTTP_400_BAD_REQUEST, status_code: int = status.HTTP_400_BAD_REQUEST,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NoDataError(CriticalError):
message: str = "No data found in the system, please add data first."

View file

@ -3,12 +3,12 @@ from collections import Counter
import string import string
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.tasks.completion.exceptions import NoRelevantDataFound
class GraphCompletionRetriever(BaseRetriever): class GraphCompletionRetriever(BaseRetriever):
@ -72,14 +72,18 @@ class GraphCompletionRetriever(BaseRetriever):
query, top_k=self.top_k, collections=vector_index_collections or None query, top_k=self.top_k, collections=vector_index_collections or None
) )
if len(found_triplets) == 0:
raise NoRelevantDataFound
return found_triplets return found_triplets
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> str:
"""Retrieves and resolves graph triplets into context.""" """Retrieves and resolves graph triplets into context."""
triplets = await self.get_triplets(query) try:
triplets = await self.get_triplets(query)
except EntityNotFoundError:
return ""
if len(triplets) == 0:
return ""
return await self.resolve_edges_to_text(triplets) return await self.resolve_edges_to_text(triplets)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

View file

@ -4,6 +4,8 @@ from typing import Any, Optional
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class InsightsRetriever(BaseRetriever): class InsightsRetriever(BaseRetriever):
@ -14,7 +16,7 @@ class InsightsRetriever(BaseRetriever):
self.exploration_levels = exploration_levels self.exploration_levels = exploration_levels
self.top_k = top_k self.top_k = top_k
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> list:
"""Find the neighbours of a given node in the graph.""" """Find the neighbours of a given node in the graph."""
if query is None: if query is None:
return [] return []
@ -27,10 +29,15 @@ class InsightsRetriever(BaseRetriever):
node_connections = await graph_engine.get_connections(str(exact_node["id"])) node_connections = await graph_engine.get_connections(str(exact_node["id"]))
else: else:
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
results = await asyncio.gather(
vector_engine.search("Entity_name", query_text=query, limit=self.top_k), try:
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k), results = await asyncio.gather(
) vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
results = [*results[0], *results[1]] results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][: self.top_k] relevant_results = [result for result in results if result.score < 0.5][: self.top_k]

View file

@ -2,6 +2,8 @@ from typing import Any, Optional
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class SummariesRetriever(BaseRetriever): class SummariesRetriever(BaseRetriever):
@ -14,7 +16,14 @@ class SummariesRetriever(BaseRetriever):
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> Any:
"""Retrieves summary context based on the query.""" """Retrieves summary context based on the query."""
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit)
try:
summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.limit
)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
return [summary.payload for summary in summaries_results] return [summary.payload for summary in summaries_results]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

View file

@ -82,9 +82,6 @@ async def brute_force_triplet_search(
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
retrieved_results = await brute_force_search( retrieved_results = await brute_force_search(
query, query,
user, user,
@ -174,4 +171,4 @@ async def brute_force_search(
send_telemetry( send_telemetry(
"cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)} "cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
) )
raise RuntimeError("An error occurred during brute force search") from error raise error

View file

@ -20,9 +20,6 @@ async def code_description_to_code_part_search(
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs) retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
return retrieved_codeparts return retrieved_codeparts

View file

@ -2,9 +2,11 @@ from types import SimpleNamespace
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlalchemy.future import select from sqlalchemy.future import select
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.base_config import get_base_config
from cognee.modules.users.exceptions.exceptions import UserNotFoundError
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods.create_default_user import create_default_user from cognee.modules.users.methods.create_default_user import create_default_user
from cognee.base_config import get_base_config
async def get_default_user() -> SimpleNamespace: async def get_default_user() -> SimpleNamespace:
@ -12,16 +14,24 @@ async def get_default_user() -> SimpleNamespace:
base_config = get_base_config() base_config = get_base_config()
default_email = base_config.default_user_email or "default_user@example.com" default_email = base_config.default_user_email or "default_user@example.com"
async with db_engine.get_async_session() as session: try:
query = select(User).options(selectinload(User.roles)).where(User.email == default_email) async with db_engine.get_async_session() as session:
query = (
select(User).options(selectinload(User.roles)).where(User.email == default_email)
)
result = await session.execute(query) result = await session.execute(query)
user = result.scalars().first() user = result.scalars().first()
if user is None: if user is None:
return await create_default_user() return await create_default_user()
# We return a SimpleNamespace to have the same user type as our SaaS # We return a SimpleNamespace to have the same user type as our SaaS
# SimpleNamespace is just a dictionary which can be accessed through attributes # SimpleNamespace is just a dictionary which can be accessed through attributes
auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[]) auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
return auth_data return auth_data
except Exception as error:
if "principals" in str(error.args):
raise DatabaseNotCreatedError() from error
raise UserNotFoundError(f"Failed to retrieve default user: {default_email}") from error

View file

@ -1 +1 @@
from cognee.tasks.completion.exceptions import NoRelevantDataFound from cognee.tasks.completion.exceptions import NoRelevantDataError

View file

@ -5,5 +5,5 @@ This module defines a set of exceptions for handling various compute errors
""" """
from .exceptions import ( from .exceptions import (
NoRelevantDataFound, NoRelevantDataError,
) )

View file

@ -2,11 +2,11 @@ from cognee.exceptions import CogneeApiError
from fastapi import status from fastapi import status
class NoRelevantDataFound(CogneeApiError): class NoRelevantDataError(CogneeApiError):
def __init__( def __init__(
self, self,
message: str = "Search did not find any data.", message: str = "Search did not find any data.",
name: str = "NoRelevantDataFound", name: str = "NoRelevantDataError",
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)

View file

@ -1,8 +1,5 @@
import asyncio
import time import time
import os import asyncio
from unittest.mock import patch, MagicMock
from functools import lru_cache
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.rate_limiter import ( from cognee.infrastructure.llm.rate_limiter import (
sleep_and_retry_sync, sleep_and_retry_sync,

View file

@ -1,120 +1,195 @@
import uuid import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
class TestChunksRetriever: class TestChunksRetriever:
@pytest.fixture @pytest.mark.asyncio
def mock_retriever(self): async def test_chunk_context_simple(self):
return ChunksRetriever() system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
retriever = ChunksRetriever()
context = await retriever.get_context("Mike")
assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") async def test_chunk_context_complex(self):
async def test_get_completion(self, mock_get_vector_engine, mock_retriever): system_directory_path = os.path.join(
# Setup pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
query = "test query" )
doc_id1 = str(uuid.uuid4()) cognee.config.system_root_directory(system_directory_path)
doc_id2 = str(uuid.uuid4()) data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
)
cognee.config.data_root_directory(data_directory_path)
# Mock search results await cognee.prune.prune_data()
mock_result_1 = MagicMock() await cognee.prune.prune_system(metadata=True)
mock_result_1.payload = { await setup()
"id": str(uuid.uuid4()),
"text": "This is the first chunk result.",
"document_id": doc_id1,
"metadata": {"title": "Document 1"},
}
mock_result_2 = MagicMock() document1 = TextDocument(
mock_result_2.payload = { name="Employee List",
"id": str(uuid.uuid4()), raw_data_location="somewhere",
"text": "This is the second chunk result.", external_metadata="",
"document_id": doc_id2, mime_type="text/plain",
"metadata": {"title": "Document 2"}, )
}
mock_search_results = [mock_result_1, mock_result_2] document2 = TextDocument(
mock_vector_engine = AsyncMock() name="Car List",
mock_vector_engine.search.return_value = mock_search_results raw_data_location="somewhere",
mock_get_vector_engine.return_value = mock_vector_engine external_metadata="",
mime_type="text/plain",
)
# Execute chunk1 = DocumentChunk(
results = await mock_retriever.get_completion(query) text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
# Verify chunk4 = DocumentChunk(
assert len(results) == 2 text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
# Check first result entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
assert results[0]["text"] == "This is the first chunk result."
assert results[0]["document_id"] == doc_id1
assert results[0]["metadata"]["title"] == "Document 1"
# Check second result await add_data_points(entities)
assert results[1]["text"] == "This is the second chunk result."
assert results[1]["document_id"] == doc_id2
assert results[1]["metadata"]["title"] == "Document 2"
# Verify search was called correctly retriever = ChunksRetriever(top_k=20)
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
context = await retriever.get_context("Christina")
assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") async def test_chunk_context_on_empty_graph(self):
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): system_directory_path = os.path.join(
# Setup pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
query = "test query with no results" )
mock_search_results = [] cognee.config.system_root_directory(system_directory_path)
mock_vector_engine = AsyncMock() data_directory_path = os.path.join(
mock_vector_engine.search.return_value = mock_search_results pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
mock_get_vector_engine.return_value = mock_vector_engine )
cognee.config.data_root_directory(data_directory_path)
# Execute await cognee.prune.prune_data()
results = await mock_retriever.get_completion(query) await cognee.prune.prune_system(metadata=True)
# Verify retriever = ChunksRetriever()
assert len(results) == 0
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
@pytest.mark.asyncio with pytest.raises(NoDataError):
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine") await retriever.get_context("Christina Mayer")
async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with incomplete data"
# Mock search results vector_engine = get_vector_engine()
mock_result_1 = MagicMock() await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"text": "This chunk has no document_id.",
# Missing document_id and metadata
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
# Missing text
"document_id": str(uuid.uuid4()),
"metadata": {"title": "Document with missing text"},
}
mock_search_results = [mock_result_1, mock_result_2] context = await retriever.get_context("Christina Mayer")
mock_vector_engine = AsyncMock() assert len(context) == 0, "Found chunks when none should exist"
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute
results = await mock_retriever.get_completion(query)
# Verify if __name__ == "__main__":
assert len(results) == 2 from asyncio import run
# First result should have content but no document_id test = TestChunksRetriever()
assert results[0]["text"] == "This chunk has no document_id."
assert "document_id" not in results[0]
assert "metadata" not in results[0]
# Second result should have document_id and metadata but no content run(test.test_chunk_context_simple())
assert "text" not in results[1] run(test.test_chunk_context_complex())
assert "document_id" in results[1] run(test.test_chunk_context_on_empty_graph())
assert results[1]["metadata"]["title"] == "Document with missing text"

View file

@ -1,84 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestCompletionRetriever:
@pytest.fixture
def mock_retriever(self):
return CompletionRetriever(system_prompt_path="test_prompt.txt")
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
async def test_get_completion(
self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever
):
# Setup
query = "test query"
# Mock render_prompt
mock_render_prompt.return_value = "Rendered prompt with context"
mock_search_results = [MagicMock()]
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Mock LLM client
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = "Generated completion response"
mock_get_llm_client.return_value = mock_llm_client
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0] == "Generated completion response"
# Verify prompt was rendered
mock_render_prompt.assert_called_once()
# Verify LLM client was called
mock_llm_client.acreate_structured_output.assert_called_once_with(
text_input="Rendered prompt with context", system_prompt=None, response_model=str
)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.completion_retriever.generate_completion")
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
async def test_get_completion_with_custom_prompt(
self, mock_get_vector_engine, mock_generate_completion, mock_retriever
):
# Setup
query = "test query with custom prompt"
mock_search_results = [MagicMock()]
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
mock_generate_completion.return_value = "Custom prompt completion response"
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0] == "Custom prompt completion response"
assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt"
assert (
mock_generate_completion.call_args[1]["system_prompt_path"]
== "custom_system_prompt.txt"
)

View file

@ -1,236 +1,159 @@
from unittest.mock import AsyncMock, MagicMock, patch import os
import pytest import pytest
import pathlib
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.tasks.completion.exceptions import NoRelevantDataFound
class TestGraphCompletionRetriever: class TestGraphCompletionRetriever:
@pytest.fixture
def mock_retriever(self):
return GraphCompletionRetriever(system_prompt_path="test_prompt.txt")
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search") async def test_graph_completion_context_simple(self):
async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever): system_directory_path = os.path.join(
mock_brute_force_triplet_search.return_value = [ pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
AsyncMock(
node1=AsyncMock(attributes={"text": "Node A"}),
attributes={"relationship_type": "connects"},
node2=AsyncMock(attributes={"text": "Node B"}),
)
]
result = await mock_retriever.get_triplets("test query")
assert isinstance(result, list)
assert len(result) > 0
assert result[0].attributes["relationship_type"] == "connects"
mock_brute_force_triplet_search.assert_called_once()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever):
mock_brute_force_triplet_search.return_value = []
with pytest.raises(NoRelevantDataFound):
await mock_retriever.get_triplets("test query")
@pytest.mark.asyncio
async def test_resolve_edges_to_text(self, mock_retriever):
node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})
triplets = [
AsyncMock(
node1=node_a,
attributes={"relationship_type": "connects"},
node2=node_b,
),
AsyncMock(
node1=node_a,
attributes={"relationship_type": "links"},
node2=node_c,
),
]
with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
result = await mock_retriever.resolve_edges_to_text(triplets)
assert "Nodes:" in result
assert "Connections:" in result
assert "Node: Test Title" in result
assert "__node_content_start__" in result
assert "Node A text content" in result
assert "__node_content_end__" in result
assert "Node: Node C" in result
assert "Test Title --[connects]--> Test Title" in result
assert "Test Title --[links]--> Node C" in result
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets",
new_callable=AsyncMock,
)
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
new_callable=AsyncMock,
)
async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever):
"""Test get_context calls get_triplets and resolve_edges_to_text."""
mock_get_triplets.return_value = ["mock_triplet"]
mock_resolve_edges_to_text.return_value = "Mock Context"
result = await mock_retriever.get_context("test query")
assert result == "Mock Context"
mock_get_triplets.assert_called_once_with("test query")
mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"])
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
)
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
async def test_get_completion_without_context(
self, mock_generate_completion, mock_get_context, mock_retriever
):
"""Test get_completion when no context is provided (calls get_context)."""
mock_get_context.return_value = "Mock Context"
mock_generate_completion.return_value = "Generated Completion"
result = await mock_retriever.get_completion("test query")
assert result == ["Generated Completion"]
mock_get_context.assert_called_once_with("test query")
mock_generate_completion.assert_called_once()
@pytest.mark.asyncio
@patch(
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
)
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
async def test_get_completion_with_context(
self, mock_generate_completion, mock_get_context, mock_retriever
):
"""Test get_completion when context is provided (does not call get_context)."""
mock_generate_completion.return_value = "Generated Completion"
result = await mock_retriever.get_completion("test query", context="Provided Context")
assert result == ["Generated Completion"]
mock_get_context.assert_not_called()
mock_generate_completion.assert_called_once()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
async def test_get_completion_with_empty_graph(
self,
mock_get_default_user,
mock_get_graph_engine,
mock_get_llm_client,
mock_retriever,
):
query = "test query with empty graph"
mock_graph_engine = MagicMock()
mock_graph_engine.get_graph_data = AsyncMock()
mock_graph_engine.get_graph_data.return_value = ([], [])
mock_get_graph_engine.return_value = mock_graph_engine
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = (
"Generated graph completion response"
) )
mock_get_llm_client.return_value = mock_llm_client cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
)
cognee.config.data_root_directory(data_directory_path)
with pytest.raises(EntityNotFoundError): await cognee.prune.prune_data()
await mock_retriever.get_completion(query) await cognee.prune.prune_system(metadata=True)
await setup()
mock_graph_engine.get_graph_data.assert_called_once() class Company(DataPoint):
name: str
def test_top_n_words(self, mock_retriever): class Person(DataPoint):
"""Test extraction of top frequent words from text.""" name: str
text = "The quick brown fox jumps over the lazy dog. The fox is quick." works_for: Company
result = mock_retriever._top_n_words(text) company1 = Company(name="Figma")
assert len(result.split(", ")) <= 3 company2 = Company(name="Canva")
assert "fox" in result person1 = Person(name="Steve Rodger", works_for=company1)
assert "quick" in result person2 = Person(name="Ike Loma", works_for=company1)
person3 = Person(name="Jason Statham", works_for=company1)
person4 = Person(name="Mike Broski", works_for=company2)
person5 = Person(name="Christina Mayer", works_for=company2)
result = mock_retriever._top_n_words(text, top_n=2) entities = [company1, company2, person1, person2, person3, person4, person5]
assert len(result.split(", ")) <= 2
result = mock_retriever._top_n_words(text, separator=" | ") await add_data_points(entities)
assert " | " in result
result = mock_retriever._top_n_words(text, stop_words={"fox", "quick"}) retriever = GraphCompletionRetriever()
assert "fox" not in result
assert "quick" not in result
def test_get_title(self, mock_retriever): context = await retriever.get_context("Who works at Canva?")
"""Test title generation from text."""
text = "This is a long paragraph about various topics that should generate a title. The main topics are AI, programming and data science."
title = mock_retriever._get_title(text) assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "..." in title assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
assert "[" in title and "]" in title
title = mock_retriever._get_title(text, first_n_words=3) @pytest.mark.asyncio
first_part = title.split("...")[0].strip() async def test_graph_completion_context_complex(self):
assert len(first_part.split()) == 3 system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
)
cognee.config.data_root_directory(data_directory_path)
title = mock_retriever._get_title(text, top_n_words=2) await cognee.prune.prune_data()
top_part = title.split("[")[1].split("]")[0] await cognee.prune.prune_system(metadata=True)
assert len(top_part.split(", ")) <= 2 await setup()
def test_get_nodes(self, mock_retriever): class Company(DataPoint):
"""Test node processing and deduplication.""" name: str
node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"}) metadata: dict = {"index_fields": ["name"]}
node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"})
node_without_attrs = AsyncMock(id="empty_node", attributes={})
edges = [ class Car(DataPoint):
AsyncMock( brand: str
node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"} model: str
), year: int
AsyncMock(
node1=node_with_text, class Location(DataPoint):
node2=node_without_attrs, country: str
attributes={"relationship_type": "rel2"}, city: str
),
AsyncMock( class Home(DataPoint):
node1=node_with_name, location: Location
node2=node_without_attrs, rooms: int
attributes={"relationship_type": "rel3"}, sqm: int
),
class Person(DataPoint):
name: str
works_for: Company
owns: Optional[list[Union[Car, Home]]] = None
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Mike Rodger", works_for=company1)
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
person2 = Person(name="Ike Loma", works_for=company1)
person2.owns = [
Car(brand="Tesla", model="Model S", year=2021),
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
] ]
with patch.object(mock_retriever, "_get_title", return_value="Generated Title"): person3 = Person(name="Jason Statham", works_for=company1)
nodes = mock_retriever._get_nodes(edges)
assert len(nodes) == 3 person4 = Person(name="Mike Broski", works_for=company2)
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
for node_id, info in nodes.items(): person5 = Person(name="Christina Mayer", works_for=company2)
assert "node" in info person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
assert "name" in info
assert "content" in info
text_node_info = nodes[node_with_text.id] entities = [company1, company2, person1, person2, person3, person4, person5]
assert text_node_info["name"] == "Generated Title"
assert text_node_info["content"] == "This is a text node"
name_node_info = nodes[node_with_name.id] await add_data_points(entities)
assert name_node_info["name"] == "Named Node"
assert name_node_info["content"] == "Named Node"
empty_node_info = nodes[node_without_attrs.id] retriever = GraphCompletionRetriever(top_k=20)
assert empty_node_info["name"] == "Unnamed Node"
context = await retriever.get_context("Who works at Figma?")
print(context)
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
retriever = GraphCompletionRetriever()
with pytest.raises(DatabaseNotCreatedError):
await retriever.get_context("Who works at Figma?")
await setup()
context = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
if __name__ == "__main__":
from asyncio import run
test = TestGraphCompletionRetriever()
run(test.test_graph_completion_context_simple())
run(test.test_graph_completion_context_complex())
run(test.test_get_graph_completion_context_on_empty_graph())

View file

@ -1,80 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
class TestGraphSummaryCompletionRetriever:
    """Unit tests for GraphSummaryCompletionRetriever, with all external services mocked."""

    @pytest.fixture
    def mock_retriever(self):
        # Fresh retriever per test, pointed at a dummy system prompt file.
        return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt")

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.utils.completion.get_llm_client")
    @patch("cognee.modules.retrieval.utils.completion.read_query_prompt")
    @patch("cognee.modules.retrieval.utils.completion.render_prompt")
    @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
    async def test_get_completion_with_custom_system_prompt(
        self,
        mock_get_default_user,
        mock_render_prompt,
        mock_read_query_prompt,
        mock_get_llm_client,
        mock_retriever,
    ):
        """Custom prompt paths set on the retriever are the ones handed to the prompt helpers."""
        query_text = "test query with custom prompt"

        # Point the retriever at non-default prompt files.
        mock_retriever.user_prompt_path = "custom_user_prompt.txt"
        mock_retriever.system_prompt_path = "custom_system_prompt.txt"

        # Stub the LLM client so no real completion call is made.
        llm_client = MagicMock(
            acreate_structured_output=AsyncMock(
                return_value="Generated graph summary completion response"
            )
        )
        mock_get_llm_client.return_value = llm_client

        results = await mock_retriever.get_completion(query_text, context="test context")

        assert len(results) == 1

        # The user prompt is rendered from the custom user prompt path...
        mock_render_prompt.assert_called_once()
        assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt"
        # ...and the system prompt is read from the custom system prompt path.
        mock_read_query_prompt.assert_called_once()
        assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt"
        llm_client.acreate_structured_output.assert_called_once()

    @pytest.mark.asyncio
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text"
    )
    @patch(
        "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
        new_callable=AsyncMock,
    )
    async def test_resolve_edges_to_text_calls_super_and_summarizes(
        self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever
    ):
        """resolve_edges_to_text delegates to the parent implementation, then summarizes its output."""
        mock_resolve_edges_to_text.return_value = "Raw graph edges text"
        mock_summarize_text.return_value = "Summarized graph text"

        summary = await mock_retriever.resolve_edges_to_text(["mock_edge"])

        # The parent method receives the edges unchanged.
        mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"])
        # Its raw text is summarized with the retriever's configured prompt.
        mock_summarize_text.assert_called_once_with(
            "Raw graph edges text", mock_retriever.summarize_prompt_path
        )
        assert summary == "Summarized graph text"

View file

@ -1,103 +1,216 @@
import uuid import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pathlib
from cognee.modules.retrieval.insights_retriever import InsightsRetriever import cognee
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph from cognee.low_level import setup
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine from cognee.tasks.storage import add_data_points
import unittest from cognee.modules.engine.models import Entity, EntityType
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
class TestInsightsRetriever: class TestInsightsRetriever:
@pytest.fixture @pytest.mark.asyncio
def mock_retriever(self): async def test_insights_context_simple(self):
return InsightsRetriever() system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
entityTypePerson = EntityType(
name="Person",
description="An individual",
)
person1 = Entity(
name="Steve Rodger",
is_a=entityTypePerson,
description="An American actor, comedian, and filmmaker",
)
person2 = Entity(
name="Mike Broski",
is_a=entityTypePerson,
description="Financial advisor and philanthropist",
)
person3 = Entity(
name="Christina Mayer",
is_a=entityTypePerson,
description="Maker of next generation of iconic American music videos",
)
entityTypeCompany = EntityType(
name="Company",
description="An organization that operates on an annual basis",
)
company1 = Entity(
name="Apple",
is_a=entityTypeCompany,
description="An American multinational technology company headquartered in Cupertino, California",
)
company2 = Entity(
name="Google",
is_a=entityTypeCompany,
description="An American multinational technology company that specializes in Internet-related services and products",
)
company3 = Entity(
name="Facebook",
is_a=entityTypeCompany,
description="An American social media, messaging, and online platform",
)
entities = [person1, person2, person3, company1, company2, company3]
await add_data_points(entities)
retriever = InsightsRetriever()
context = await retriever.get_context("Mike")
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine") async def test_insights_context_complex(self):
async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever): system_directory_path = os.path.join(
"""Test get_context when node exists in graph.""" pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_complex"
mock_graph = AsyncMock() )
mock_get_graph_engine.return_value = mock_graph cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
# Mock graph response await cognee.prune.prune_data()
mock_graph.extract_node.return_value = {"id": "123"} await cognee.prune.prune_system(metadata=True)
mock_graph.get_connections.return_value = [ await setup()
({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})
]
result = await mock_retriever.get_context("123") entityTypePerson = EntityType(
name="Person",
description="An individual",
)
assert isinstance(result, list) person1 = Entity(
assert len(result) == 1 name="Steve Rodger",
assert result[0][0]["id"] == "123" is_a=entityTypePerson,
assert result[0][1]["relationship_name"] == "linked_to" description="An American actor, comedian, and filmmaker",
assert result[0][2]["id"] == "456" )
mock_graph.extract_node.assert_called_once_with("123")
mock_graph.get_connections.assert_called_once_with("123") person2 = Entity(
name="Mike Broski",
is_a=entityTypePerson,
description="Financial advisor and philanthropist",
)
person3 = Entity(
name="Christina Mayer",
is_a=entityTypePerson,
description="Maker of next generation of iconic American music videos",
)
person4 = Entity(
name="Jason Statham",
is_a=entityTypePerson,
description="An American actor",
)
person5 = Entity(
name="Mike Tyson",
is_a=entityTypePerson,
description="A former professional boxer from the United States",
)
entityTypeCompany = EntityType(
name="Company",
description="An organization that operates on an annual basis",
)
company1 = Entity(
name="Apple",
is_a=entityTypeCompany,
description="An American multinational technology company headquartered in Cupertino, California",
)
company2 = Entity(
name="Google",
is_a=entityTypeCompany,
description="An American multinational technology company that specializes in Internet-related services and products",
)
company3 = Entity(
name="Facebook",
is_a=entityTypeCompany,
description="An American social media, messaging, and online platform",
)
entities = [person1, person2, person3, company1, company2, company3]
await add_data_points(entities)
graph_engine = await get_graph_engine()
await graph_engine.add_edges(
[
(person1.id, company1.id, "works_for"),
(person2.id, company2.id, "works_for"),
(person3.id, company3.id, "works_for"),
(person4.id, company1.id, "works_for"),
(person5.id, company1.id, "works_for"),
]
)
retriever = InsightsRetriever(top_k=20)
context = await retriever.get_context("Christina")
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine") async def test_insights_context_on_empty_graph(self):
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever): system_directory_path = os.path.join(
# Setup pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty"
query = "test query with no results" )
mock_search_results = [] cognee.config.system_root_directory(system_directory_path)
mock_vector_engine = AsyncMock() data_directory_path = os.path.join(
mock_vector_engine.search.return_value = mock_search_results pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty"
mock_get_vector_engine.return_value = mock_vector_engine )
cognee.config.data_root_directory(data_directory_path)
# Execute await cognee.prune.prune_data()
results = await mock_retriever.get_completion(query) await cognee.prune.prune_system(metadata=True)
# Verify retriever = InsightsRetriever()
assert len(results) == 0
@pytest.mark.asyncio with pytest.raises(NoDataError):
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine") await retriever.get_context("Christina Mayer")
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
async def test_get_context_with_no_exact_node(
self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever
):
"""Test get_context when node does not exist in the graph and vector search is used."""
mock_graph = AsyncMock()
mock_get_graph_engine.return_value = mock_graph
mock_graph.extract_node.return_value = None # Node does not exist
mock_vector = AsyncMock() vector_engine = get_vector_engine()
mock_get_vector_engine.return_value = mock_vector await vector_engine.create_collection("Entity_name", payload_schema=Entity)
await vector_engine.create_collection("EntityType_name", payload_schema=EntityType)
mock_vector.search.side_effect = [ context = await retriever.get_context("Christina Mayer")
[AsyncMock(id="vec_1", score=0.4)], # Entity_name search assert context == [], "Returned context should be empty on an empty graph"
[AsyncMock(id="vec_2", score=0.3)], # EntityType_name search
]
mock_graph.get_connections.side_effect = lambda node_id: [
({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"})
]
result = await mock_retriever.get_context("non_existing_query") if __name__ == "__main__":
from asyncio import run
assert isinstance(result, list) test = TestInsightsRetriever()
assert len(result) == 2
assert result[0][0]["id"] == "vec_1"
assert result[0][1]["relationship_name"] == "related_to"
assert result[0][2]["id"] == "456"
assert result[1][0]["id"] == "vec_2" run(test.test_insights_context_simple())
assert result[1][1]["relationship_name"] == "related_to" run(test.test_insights_context_complex())
assert result[1][2]["id"] == "456" run(test.test_insights_context_on_empty_graph())
@pytest.mark.asyncio
async def test_get_context_with_none_query(self, mock_retriever):
"""Test get_context with a None query (should return empty list)."""
result = await mock_retriever.get_context(None)
assert result == []
@pytest.mark.asyncio
async def test_get_completion_with_context(self, mock_retriever):
"""Test get_completion when context is already provided."""
test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})]
result = await mock_retriever.get_completion("test_query", context=test_context)
assert result == test_context

View file

@ -0,0 +1,196 @@
import os
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestRAGCompletionRetriever:
    """Integration tests for CompletionRetriever backed by real (local) storage engines."""

    async def _isolate_storage(self, test_name):
        """Route cognee's system and data storage into per-test directories and wipe prior state.

        Each test gets its own directory name: the original code reused
        ".cognee_system/test_graph_completion_context" for two of these tests,
        silently sharing on-disk state with the graph-completion test suite.
        """
        base_dir = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(os.path.join(base_dir, f".cognee_system/{test_name}"))
        cognee.config.data_root_directory(os.path.join(base_dir, f".data_storage/{test_name}"))
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

    @staticmethod
    def _make_chunk(text, index, document):
        """Build a minimal two-token DocumentChunk for `document` at chunk position `index`."""
        return DocumentChunk(
            text=text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

    @pytest.mark.asyncio
    async def test_rag_completion_context_simple(self):
        """A query retrieves the text of the single most relevant chunk."""
        await self._isolate_storage("test_rag_context")
        await setup()

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        entities = [
            self._make_chunk("Steve Rodger", 0, document),
            self._make_chunk("Mike Broski", 1, document),
            self._make_chunk("Christina Mayer", 2, document),
        ]
        await add_data_points(entities)

        retriever = CompletionRetriever()
        context = await retriever.get_context("Mike")
        assert context == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_rag_completion_context_complex(self):
        """With multiple documents present, the query still surfaces the matching chunk first."""
        await self._isolate_storage("test_rag_context_complex")
        await setup()

        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        entities = [
            self._make_chunk("Steve Rodger", 0, document1),
            self._make_chunk("Mike Broski", 1, document1),
            self._make_chunk("Christina Mayer", 2, document1),
            self._make_chunk("Range Rover", 0, document2),
            self._make_chunk("Hyundai", 1, document2),
            self._make_chunk("Chrysler", 2, document2),
        ]
        await add_data_points(entities)

        # TODO: top_k doesn't affect the output, it should be fixed.
        retriever = CompletionRetriever(top_k=20)
        context = await retriever.get_context("Christina")
        assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_get_rag_completion_context_on_empty_graph(self):
        """Without data: retrieval raises NoDataError; with an empty collection it returns ""."""
        await self._isolate_storage("test_rag_context_empty")

        retriever = CompletionRetriever()
        # No collections exist yet, so retrieval must fail explicitly.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        # Create the expected (empty) collection; retrieval now succeeds with no context.
        vector_engine = get_vector_engine()
        await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)
        context = await retriever.get_context("Christina Mayer")
        assert context == "", "Returned context should be empty on an empty graph"
if __name__ == "__main__":
    # Allow running this test module directly, outside of pytest.
    import asyncio

    suite = TestRAGCompletionRetriever()
    asyncio.run(suite.test_rag_completion_context_simple())
    asyncio.run(suite.test_rag_completion_context_complex())
    asyncio.run(suite.test_get_rag_completion_context_on_empty_graph())

View file

@ -1,122 +1,168 @@
import uuid import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TestSummariesRetriever: class TextSummariesRetriever:
@pytest.fixture @pytest.mark.asyncio
def mock_retriever(self): async def test_chunk_context(self):
return SummariesRetriever() system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk1_summary = TextSummary(
text="S.R.",
made_from=chunk1,
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2_summary = TextSummary(
text="M.B.",
made_from=chunk2,
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3_summary = TextSummary(
text="C.M.",
made_from=chunk3,
)
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk4_summary = TextSummary(
text="R.R.",
made_from=chunk4,
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5_summary = TextSummary(
text="H.Y.",
made_from=chunk5,
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6_summary = TextSummary(
text="C.H.",
made_from=chunk6,
)
entities = [
chunk1_summary,
chunk2_summary,
chunk3_summary,
chunk4_summary,
chunk5_summary,
chunk6_summary,
]
await add_data_points(entities)
retriever = SummariesRetriever(limit=20)
context = await retriever.get_context("Christina")
assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine") async def test_chunk_context_on_empty_graph(self):
async def test_get_completion(self, mock_get_vector_engine, mock_retriever): system_directory_path = os.path.join(
# Setup pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
query = "test query" )
doc_id1 = str(uuid.uuid4()) cognee.config.system_root_directory(system_directory_path)
doc_id2 = str(uuid.uuid4()) data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
)
cognee.config.data_root_directory(data_directory_path)
# Mock search results await cognee.prune.prune_data()
mock_result_1 = MagicMock() await cognee.prune.prune_system(metadata=True)
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"score": 0.95,
"payload": {
"text": "This is the first summary.",
"document_id": doc_id1,
"metadata": {"title": "Document 1"},
},
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
"score": 0.85,
"payload": {
"text": "This is the second summary.",
"document_id": doc_id2,
"metadata": {"title": "Document 2"},
},
}
mock_search_results = [mock_result_1, mock_result_2] retriever = SummariesRetriever()
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute with pytest.raises(NoDataError):
results = await mock_retriever.get_completion(query) await retriever.get_context("Christina Mayer")
# Verify vector_engine = get_vector_engine()
assert len(results) == 2 await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)
# Check first result context = await retriever.get_context("Christina Mayer")
assert results[0]["payload"]["text"] == "This is the first summary." assert context == [], "Returned context should be empty on an empty graph"
assert results[0]["payload"]["document_id"] == doc_id1
assert results[0]["payload"]["metadata"]["title"] == "Document 1"
assert results[0]["score"] == 0.95
# Check second result
assert results[1]["payload"]["text"] == "This is the second summary."
assert results[1]["payload"]["document_id"] == doc_id2
assert results[1]["payload"]["metadata"]["title"] == "Document 2"
assert results[1]["score"] == 0.85
# Verify search was called correctly if __name__ == "__main__":
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5) from asyncio import run
@pytest.mark.asyncio test = TextSummariesRetriever()
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Execute run(test.test_chunk_context())
results = await mock_retriever.get_completion(query) run(test.test_chunk_context_on_empty_graph())
# Verify
assert len(results) == 0
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with custom limit"
doc_id = str(uuid.uuid4())
# Mock search results
mock_result = MagicMock()
mock_result.payload = {
"id": str(uuid.uuid4()),
"score": 0.95,
"payload": {
"text": "This is a summary.",
"document_id": doc_id,
"metadata": {"title": "Document 1"},
},
}
mock_search_results = [mock_result]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Set custom limit
mock_retriever.limit = 10
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0]["payload"]["text"] == "This is a summary."
# Verify search was called with custom limit
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10)

View file

@ -1,11 +1,11 @@
import pytest import pytest
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError from unittest.mock import AsyncMock, patch
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
from cognee.modules.retrieval.utils.brute_force_triplet_search import ( from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_search, brute_force_search,
brute_force_triplet_search, brute_force_triplet_search,
) )
from unittest.mock import AsyncMock, patch
@pytest.mark.asyncio @pytest.mark.asyncio
@ -20,13 +20,11 @@ async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
mock_vector_engine.get_distance_from_collection_elements.return_value = [] mock_vector_engine.get_distance_from_collection_elements.return_value = []
mock_get_vector_engine.return_value = mock_vector_engine mock_get_vector_engine.return_value = mock_vector_engine
with pytest.raises(Exception) as exc_info: with pytest.raises(CollectionDistancesNotFoundError):
await brute_force_search( await brute_force_search(
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
) )
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine") @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
@ -40,9 +38,7 @@ async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_e
mock_vector_engine.get_distance_from_collection_elements.return_value = [] mock_vector_engine.get_distance_from_collection_elements.return_value = []
mock_get_vector_engine.return_value = mock_vector_engine mock_get_vector_engine.return_value = mock_vector_engine
with pytest.raises(Exception) as exc_info: with pytest.raises(CollectionDistancesNotFoundError):
await brute_force_triplet_search( await brute_force_triplet_search(
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
) )
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)