Merge branch 'dev' of github.com:topoteretes/cognee into dev
This commit is contained in:
commit
b35e04735f
40 changed files with 1127 additions and 824 deletions
62
.github/workflows/basic_tests.yml
vendored
62
.github/workflows/basic_tests.yml
vendored
|
|
@ -8,12 +8,30 @@ on:
|
||||||
type: string
|
type: string
|
||||||
default: '3.11.x'
|
default: '3.11.x'
|
||||||
secrets:
|
secrets:
|
||||||
OPENAI_API_KEY:
|
|
||||||
required: true
|
|
||||||
GRAPHISTRY_USERNAME:
|
GRAPHISTRY_USERNAME:
|
||||||
required: true
|
required: true
|
||||||
GRAPHISTRY_PASSWORD:
|
GRAPHISTRY_PASSWORD:
|
||||||
required: true
|
required: true
|
||||||
|
LLM_PROVIDER:
|
||||||
|
required: true
|
||||||
|
LLM_MODEL:
|
||||||
|
required: true
|
||||||
|
LLM_ENDPOINT:
|
||||||
|
required: true
|
||||||
|
LLM_API_KEY:
|
||||||
|
required: true
|
||||||
|
LLM_API_VERSION:
|
||||||
|
required: true
|
||||||
|
EMBEDDING_PROVIDER:
|
||||||
|
required: true
|
||||||
|
EMBEDDING_MODEL:
|
||||||
|
required: true
|
||||||
|
EMBEDDING_ENDPOINT:
|
||||||
|
required: true
|
||||||
|
EMBEDDING_API_KEY:
|
||||||
|
required: true
|
||||||
|
EMBEDDING_API_VERSION:
|
||||||
|
required: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
RUNTIME__LOG_LEVEL: ERROR
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
|
@ -60,6 +78,18 @@ jobs:
|
||||||
unit-tests:
|
unit-tests:
|
||||||
name: Run Unit Tests
|
name: Run Unit Tests
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: openai
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
@ -95,10 +125,20 @@ jobs:
|
||||||
name: Run Simple Examples
|
name: Run Simple Examples
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
env:
|
env:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||||
|
|
||||||
|
LLM_PROVIDER: openai
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
@ -117,10 +157,20 @@ jobs:
|
||||||
name: Run Basic Graph Tests
|
name: Run Basic Graph Tests
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
env:
|
env:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||||
|
|
||||||
|
LLM_PROVIDER: openai
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
|
||||||
2
.github/workflows/graph_db_tests.yml
vendored
2
.github/workflows/graph_db_tests.yml
vendored
|
|
@ -1,4 +1,4 @@
|
||||||
name: Reusable Vector DB Tests
|
name: Reusable Graph DB Tests
|
||||||
|
|
||||||
on:
|
on:
|
||||||
workflow_call:
|
workflow_call:
|
||||||
|
|
|
||||||
41
.github/workflows/python_version_tests.yml
vendored
41
.github/workflows/python_version_tests.yml
vendored
|
|
@ -8,26 +8,30 @@ on:
|
||||||
type: string
|
type: string
|
||||||
default: '["3.10.x", "3.11.x", "3.12.x"]'
|
default: '["3.10.x", "3.11.x", "3.12.x"]'
|
||||||
secrets:
|
secrets:
|
||||||
OPENAI_API_KEY:
|
|
||||||
required: true
|
|
||||||
GRAPHISTRY_USERNAME:
|
GRAPHISTRY_USERNAME:
|
||||||
required: true
|
required: true
|
||||||
GRAPHISTRY_PASSWORD:
|
GRAPHISTRY_PASSWORD:
|
||||||
required: true
|
required: true
|
||||||
|
LLM_PROVIDER:
|
||||||
|
required: true
|
||||||
LLM_MODEL:
|
LLM_MODEL:
|
||||||
required: false
|
required: true
|
||||||
LLM_ENDPOINT:
|
LLM_ENDPOINT:
|
||||||
required: false
|
required: true
|
||||||
|
LLM_API_KEY:
|
||||||
|
required: true
|
||||||
LLM_API_VERSION:
|
LLM_API_VERSION:
|
||||||
required: false
|
required: true
|
||||||
|
EMBEDDING_PROVIDER:
|
||||||
|
required: true
|
||||||
EMBEDDING_MODEL:
|
EMBEDDING_MODEL:
|
||||||
required: false
|
required: true
|
||||||
EMBEDDING_ENDPOINT:
|
EMBEDDING_ENDPOINT:
|
||||||
required: false
|
required: true
|
||||||
EMBEDDING_API_KEY:
|
EMBEDDING_API_KEY:
|
||||||
required: false
|
required: true
|
||||||
EMBEDDING_API_VERSION:
|
EMBEDDING_API_VERSION:
|
||||||
required: false
|
required: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
RUNTIME__LOG_LEVEL: ERROR
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
|
@ -55,6 +59,18 @@ jobs:
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: poetry run pytest cognee/tests/unit/
|
run: poetry run pytest cognee/tests/unit/
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: openai
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
|
||||||
- name: Run integration tests
|
- name: Run integration tests
|
||||||
if: ${{ !contains(matrix.os, 'windows') }}
|
if: ${{ !contains(matrix.os, 'windows') }}
|
||||||
|
|
@ -62,13 +78,16 @@ jobs:
|
||||||
|
|
||||||
- name: Run default basic pipeline
|
- name: Run default basic pipeline
|
||||||
env:
|
env:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
|
||||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||||
|
|
||||||
|
LLM_PROVIDER: openai
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ on:
|
||||||
type: string
|
type: string
|
||||||
default: '3.11.x'
|
default: '3.11.x'
|
||||||
secrets:
|
secrets:
|
||||||
|
LLM_PROVIDER:
|
||||||
|
required: true
|
||||||
LLM_MODEL:
|
LLM_MODEL:
|
||||||
required: true
|
required: true
|
||||||
LLM_ENDPOINT:
|
LLM_ENDPOINT:
|
||||||
|
|
@ -16,6 +18,8 @@ on:
|
||||||
required: true
|
required: true
|
||||||
LLM_API_VERSION:
|
LLM_API_VERSION:
|
||||||
required: true
|
required: true
|
||||||
|
EMBEDDING_PROVIDER:
|
||||||
|
required: true
|
||||||
EMBEDDING_MODEL:
|
EMBEDDING_MODEL:
|
||||||
required: true
|
required: true
|
||||||
EMBEDDING_ENDPOINT:
|
EMBEDDING_ENDPOINT:
|
||||||
|
|
@ -24,12 +28,6 @@ on:
|
||||||
required: true
|
required: true
|
||||||
EMBEDDING_API_VERSION:
|
EMBEDDING_API_VERSION:
|
||||||
required: true
|
required: true
|
||||||
OPENAI_API_KEY:
|
|
||||||
required: true
|
|
||||||
GRAPHISTRY_USERNAME:
|
|
||||||
required: true
|
|
||||||
GRAPHISTRY_PASSWORD:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run-relational-db-migration-test-networkx:
|
run-relational-db-migration-test-networkx:
|
||||||
|
|
@ -81,10 +79,13 @@ jobs:
|
||||||
- name: Run relational db test
|
- name: Run relational db test
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
LLM_PROVIDER: openai
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
|
@ -141,10 +142,14 @@ jobs:
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||||
|
|
||||||
|
LLM_PROVIDER: openai
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
|
@ -204,10 +209,14 @@ jobs:
|
||||||
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
|
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
|
||||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
|
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
|
||||||
GRAPH_DATABASE_USERNAME: "neo4j"
|
GRAPH_DATABASE_USERNAME: "neo4j"
|
||||||
|
|
||||||
|
LLM_PROVIDER: openai
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
|
||||||
|
EMBEDDING_PROVIDER: openai
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
"""Add default user
|
|
||||||
|
|
||||||
Revision ID: 482cd6517ce4
|
|
||||||
Revises: 8057ae7329c2
|
|
||||||
Create Date: 2024-10-16 22:17:18.634638
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from sqlalchemy.util import await_only
|
|
||||||
|
|
||||||
from cognee.modules.users.methods import create_default_user, delete_user
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "482cd6517ce4"
|
|
||||||
down_revision: Union[str, None] = "8057ae7329c2"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
await_only(create_default_user())
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
await_only(delete_user("default_user@example.com"))
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
"""Initial migration
|
|
||||||
|
|
||||||
Revision ID: 8057ae7329c2
|
|
||||||
Revises:
|
|
||||||
Create Date: 2024-10-02 12:55:20.989372
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
from sqlalchemy.util import await_only
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "8057ae7329c2"
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
db_engine = get_relational_engine()
|
|
||||||
await_only(db_engine.create_database())
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
db_engine = get_relational_engine()
|
|
||||||
await_only(db_engine.delete_database())
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from cognee.modules.search.types import SearchType
|
|
||||||
from cognee.modules.users.exceptions import UserNotFoundError
|
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.search.methods import search as search_function
|
from cognee.modules.search.methods import search as search_function
|
||||||
|
|
||||||
|
|
@ -22,9 +21,6 @@ async def search(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise UserNotFoundError
|
|
||||||
|
|
||||||
filtered_search_results = await search_function(
|
filtered_search_results = await search_function(
|
||||||
query_text,
|
query_text,
|
||||||
query_type,
|
query_type,
|
||||||
|
|
|
||||||
|
|
@ -10,4 +10,5 @@ from .exceptions import (
|
||||||
ServiceError,
|
ServiceError,
|
||||||
InvalidValueError,
|
InvalidValueError,
|
||||||
InvalidAttributeError,
|
InvalidAttributeError,
|
||||||
|
CriticalError,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -53,3 +53,7 @@ class InvalidAttributeError(CogneeApiError):
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
|
class CriticalError(CogneeApiError):
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -7,4 +7,5 @@ This module defines a set of exceptions for handling various database errors
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
EntityNotFoundError,
|
EntityNotFoundError,
|
||||||
EntityAlreadyExistsError,
|
EntityAlreadyExistsError,
|
||||||
|
DatabaseNotCreatedError,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,15 @@
|
||||||
from cognee.exceptions import CogneeApiError
|
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
from cognee.exceptions import CogneeApiError, CriticalError
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseNotCreatedError(CriticalError):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "The database has not been created yet. Please call `await setup()` first.",
|
||||||
|
name: str = "DatabaseNotCreatedError",
|
||||||
|
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
):
|
||||||
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class EntityNotFoundError(CogneeApiError):
|
class EntityNotFoundError(CogneeApiError):
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from typing import Protocol, Optional, Dict, Any, List, Tuple
|
|
||||||
from abc import abstractmethod, ABC
|
|
||||||
from uuid import UUID, uuid5, NAMESPACE_DNS
|
|
||||||
from cognee.modules.graph.relationship_manager import create_relationship
|
|
||||||
from functools import wraps
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from functools import wraps
|
||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional, Dict, Any, List, Tuple
|
||||||
|
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
||||||
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -44,20 +44,16 @@ def record_graph_changes(func):
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
if func.__name__ == "add_nodes":
|
if func.__name__ == "add_nodes":
|
||||||
nodes = args[0]
|
nodes: List[DataPoint] = args[0]
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
try:
|
try:
|
||||||
node_id = (
|
node_id = UUID(str(node.id))
|
||||||
UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
|
|
||||||
)
|
|
||||||
relationship = GraphRelationshipLedger(
|
relationship = GraphRelationshipLedger(
|
||||||
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||||
source_node_id=node_id,
|
source_node_id=node_id,
|
||||||
destination_node_id=node_id,
|
destination_node_id=node_id,
|
||||||
creator_function=f"{creator}.node",
|
creator_function=f"{creator}.node",
|
||||||
node_label=node[1].get("type")
|
node_label=getattr(node, "name", None) or str(node.id),
|
||||||
if isinstance(node, tuple)
|
|
||||||
else type(node).__name__,
|
|
||||||
)
|
)
|
||||||
session.add(relationship)
|
session.add(relationship)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
@ -74,7 +70,7 @@ def record_graph_changes(func):
|
||||||
target_id = UUID(str(edge[1]))
|
target_id = UUID(str(edge[1]))
|
||||||
rel_type = str(edge[2])
|
rel_type = str(edge[2])
|
||||||
relationship = GraphRelationshipLedger(
|
relationship = GraphRelationshipLedger(
|
||||||
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||||
source_node_id=source_id,
|
source_node_id=source_id,
|
||||||
destination_node_id=target_id,
|
destination_node_id=target_id,
|
||||||
creator_function=f"{creator}.{rel_type}",
|
creator_function=f"{creator}.{rel_type}",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .exceptions import CollectionNotFoundError
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
from fastapi import status
|
||||||
|
from cognee.exceptions import CriticalError
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionNotFoundError(CriticalError):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message,
|
||||||
|
name: str = "DatabaseNotCreatedError",
|
||||||
|
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
):
|
||||||
|
super().__init__(message, name, status_code)
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Generic, List, Optional, TypeVar, get_type_hints
|
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
|
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
|
@ -79,7 +80,6 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
|
|
||||||
payload_schema = type(data_points[0])
|
payload_schema = type(data_points[0])
|
||||||
payload_schema = self.get_data_point_schema(payload_schema)
|
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
await self.create_collection(
|
await self.create_collection(
|
||||||
|
|
@ -194,12 +194,19 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||||
|
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
collection = await connection.open_table(collection_name)
|
|
||||||
|
try:
|
||||||
|
collection = await connection.open_table(collection_name)
|
||||||
|
except ValueError:
|
||||||
|
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||||
|
|
||||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||||
|
|
||||||
result_values = list(results.to_dict("index").values())
|
result_values = list(results.to_dict("index").values())
|
||||||
|
|
||||||
|
if not result_values:
|
||||||
|
return []
|
||||||
|
|
||||||
normalized_values = normalize_distances(result_values)
|
normalized_values = normalize_distances(result_values)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -288,11 +295,33 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
if self.url.startswith("/"):
|
if self.url.startswith("/"):
|
||||||
LocalStorage.remove_all(self.url)
|
LocalStorage.remove_all(self.url)
|
||||||
|
|
||||||
def get_data_point_schema(self, model_type):
|
def get_data_point_schema(self, model_type: BaseModel):
|
||||||
|
related_models_fields = []
|
||||||
|
|
||||||
|
for field_name, field_config in model_type.model_fields.items():
|
||||||
|
if hasattr(field_config, "model_fields"):
|
||||||
|
related_models_fields.append(field_name)
|
||||||
|
|
||||||
|
elif hasattr(field_config.annotation, "model_fields"):
|
||||||
|
related_models_fields.append(field_name)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
get_origin(field_config.annotation) == Union
|
||||||
|
or get_origin(field_config.annotation) is list
|
||||||
|
):
|
||||||
|
models_list = get_args(field_config.annotation)
|
||||||
|
if any(hasattr(model, "model_fields") for model in models_list):
|
||||||
|
related_models_fields.append(field_name)
|
||||||
|
|
||||||
|
elif get_origin(field_config.annotation) == Optional:
|
||||||
|
model = get_args(field_config.annotation)
|
||||||
|
if hasattr(model, "model_fields"):
|
||||||
|
related_models_fields.append(field_name)
|
||||||
|
|
||||||
return copy_model(
|
return copy_model(
|
||||||
model_type,
|
model_type,
|
||||||
include_fields={
|
include_fields={
|
||||||
"id": (str, ...),
|
"id": (str, ...),
|
||||||
},
|
},
|
||||||
exclude_fields=["metadata"],
|
exclude_fields=["metadata"] + related_models_fields,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
@ -183,7 +184,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
if collection_name in metadata.tables:
|
if collection_name in metadata.tables:
|
||||||
return metadata.tables[collection_name]
|
return metadata.tables[collection_name]
|
||||||
else:
|
else:
|
||||||
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
|
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
|
|
@ -244,6 +245,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
except EntityNotFoundError:
|
except EntityNotFoundError:
|
||||||
# Ignore if collection does not exist
|
# Ignore if collection does not exist
|
||||||
return []
|
return []
|
||||||
|
except CollectionNotFoundError:
|
||||||
|
# Ignore if collection does not exist
|
||||||
|
return []
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from pydantic import model_validator, Field
|
from pydantic import model_validator
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseSettings):
|
class LLMConfig(BaseSettings):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from uuid import uuid5, NAMESPACE_DNS
|
from uuid import uuid5, NAMESPACE_OID
|
||||||
from sqlalchemy import UUID, Column, DateTime, String, Index
|
from sqlalchemy import UUID, Column, DateTime, String, Index
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import Base
|
from cognee.infrastructure.databases.relational import Base
|
||||||
|
|
||||||
|
|
@ -12,7 +11,7 @@ class GraphRelationshipLedger(Base):
|
||||||
id = Column(
|
id = Column(
|
||||||
UUID,
|
UUID,
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||||
)
|
)
|
||||||
source_node_id = Column(UUID, nullable=False)
|
source_node_id = Column(UUID, nullable=False)
|
||||||
destination_node_id = Column(UUID, nullable=False)
|
destination_node_id = Column(UUID, nullable=False)
|
||||||
|
|
|
||||||
|
|
@ -111,9 +111,6 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
print(f"Error projecting graph: {e}")
|
print(f"Error projecting graph: {e}")
|
||||||
raise e
|
raise e
|
||||||
except Exception as ex:
|
|
||||||
print(f"Unexpected error: {ex}")
|
|
||||||
raise ex
|
|
||||||
|
|
||||||
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
|
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
|
||||||
for category, scored_results in node_distances.items():
|
for category, scored_results in node_distances.items():
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,28 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class ChunksRetriever(BaseRetriever):
|
class ChunksRetriever(BaseRetriever):
|
||||||
"""Retriever for handling document chunk-based searches."""
|
"""Retriever for handling document chunk-based searches."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
top_k: Optional[int] = 5,
|
||||||
|
):
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Retrieves document chunks context based on the query."""
|
"""Retrieves document chunks context based on the query."""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
|
|
||||||
|
try:
|
||||||
|
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
||||||
|
except CollectionNotFoundError as error:
|
||||||
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
return [result.payload for result in found_chunks]
|
return [result.payload for result in found_chunks]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class CompletionRetriever(BaseRetriever):
|
class CompletionRetriever(BaseRetriever):
|
||||||
|
|
@ -20,15 +21,21 @@ class CompletionRetriever(BaseRetriever):
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
self.top_k = top_k if top_k is not None else 1
|
self.top_k = top_k if top_k is not None else 1
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> str:
|
||||||
"""Retrieves relevant document chunks as context."""
|
"""Retrieves relevant document chunks as context."""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
|
||||||
if len(found_chunks) == 0:
|
try:
|
||||||
raise NoRelevantDataFound
|
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
||||||
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
|
|
||||||
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
|
if len(found_chunks) == 0:
|
||||||
return "\n".join(chunks_payload)
|
return ""
|
||||||
|
|
||||||
|
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
|
||||||
|
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
|
||||||
|
return "\n".join(chunks_payload)
|
||||||
|
except CollectionNotFoundError as error:
|
||||||
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Generates an LLM completion using the context."""
|
"""Generates an LLM completion using the context."""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from cognee.exceptions import CogneeApiError
|
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
from cognee.exceptions import CogneeApiError, CriticalError
|
||||||
|
|
||||||
|
|
||||||
class CollectionDistancesNotFoundError(CogneeApiError):
|
class CollectionDistancesNotFoundError(CogneeApiError):
|
||||||
|
|
@ -30,3 +30,7 @@ class CypherSearchError(CogneeApiError):
|
||||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
|
class NoDataError(CriticalError):
|
||||||
|
message: str = "No data found in the system, please add data first."
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@ from collections import Counter
|
||||||
import string
|
import string
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionRetriever(BaseRetriever):
|
class GraphCompletionRetriever(BaseRetriever):
|
||||||
|
|
@ -72,14 +72,18 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
query, top_k=self.top_k, collections=vector_index_collections or None
|
query, top_k=self.top_k, collections=vector_index_collections or None
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(found_triplets) == 0:
|
|
||||||
raise NoRelevantDataFound
|
|
||||||
|
|
||||||
return found_triplets
|
return found_triplets
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> str:
|
||||||
"""Retrieves and resolves graph triplets into context."""
|
"""Retrieves and resolves graph triplets into context."""
|
||||||
triplets = await self.get_triplets(query)
|
try:
|
||||||
|
triplets = await self.get_triplets(query)
|
||||||
|
except EntityNotFoundError:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if len(triplets) == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
return await self.resolve_edges_to_text(triplets)
|
return await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from typing import Any, Optional
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class InsightsRetriever(BaseRetriever):
|
class InsightsRetriever(BaseRetriever):
|
||||||
|
|
@ -14,7 +16,7 @@ class InsightsRetriever(BaseRetriever):
|
||||||
self.exploration_levels = exploration_levels
|
self.exploration_levels = exploration_levels
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> list:
|
||||||
"""Find the neighbours of a given node in the graph."""
|
"""Find the neighbours of a given node in the graph."""
|
||||||
if query is None:
|
if query is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -27,10 +29,15 @@ class InsightsRetriever(BaseRetriever):
|
||||||
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
|
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
|
||||||
else:
|
else:
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
results = await asyncio.gather(
|
|
||||||
vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
|
try:
|
||||||
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
|
results = await asyncio.gather(
|
||||||
)
|
vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
|
||||||
|
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
|
||||||
|
)
|
||||||
|
except CollectionNotFoundError as error:
|
||||||
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
results = [*results[0], *results[1]]
|
results = [*results[0], *results[1]]
|
||||||
relevant_results = [result for result in results if result.score < 0.5][: self.top_k]
|
relevant_results = [result for result in results if result.score < 0.5][: self.top_k]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class SummariesRetriever(BaseRetriever):
|
class SummariesRetriever(BaseRetriever):
|
||||||
|
|
@ -14,7 +16,14 @@ class SummariesRetriever(BaseRetriever):
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Retrieves summary context based on the query."""
|
"""Retrieves summary context based on the query."""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit)
|
|
||||||
|
try:
|
||||||
|
summaries_results = await vector_engine.search(
|
||||||
|
"TextSummary_text", query, limit=self.limit
|
||||||
|
)
|
||||||
|
except CollectionNotFoundError as error:
|
||||||
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
return [summary.payload for summary in summaries_results]
|
return [summary.payload for summary in summaries_results]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
|
|
|
||||||
|
|
@ -82,9 +82,6 @@ async def brute_force_triplet_search(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise PermissionError("No user found in the system. Please create a user.")
|
|
||||||
|
|
||||||
retrieved_results = await brute_force_search(
|
retrieved_results = await brute_force_search(
|
||||||
query,
|
query,
|
||||||
user,
|
user,
|
||||||
|
|
@ -174,4 +171,4 @@ async def brute_force_search(
|
||||||
send_telemetry(
|
send_telemetry(
|
||||||
"cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
|
"cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
|
||||||
)
|
)
|
||||||
raise RuntimeError("An error occurred during brute force search") from error
|
raise error
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,6 @@ async def code_description_to_code_part_search(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise PermissionError("No user found in the system. Please create a user.")
|
|
||||||
|
|
||||||
retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
|
retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
|
||||||
return retrieved_codeparts
|
return retrieved_codeparts
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@ from types import SimpleNamespace
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.modules.users.exceptions.exceptions import UserNotFoundError
|
||||||
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.users.methods.create_default_user import create_default_user
|
from cognee.modules.users.methods.create_default_user import create_default_user
|
||||||
from cognee.base_config import get_base_config
|
|
||||||
|
|
||||||
|
|
||||||
async def get_default_user() -> SimpleNamespace:
|
async def get_default_user() -> SimpleNamespace:
|
||||||
|
|
@ -12,16 +14,24 @@ async def get_default_user() -> SimpleNamespace:
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
default_email = base_config.default_user_email or "default_user@example.com"
|
default_email = base_config.default_user_email or "default_user@example.com"
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
try:
|
||||||
query = select(User).options(selectinload(User.roles)).where(User.email == default_email)
|
async with db_engine.get_async_session() as session:
|
||||||
|
query = (
|
||||||
|
select(User).options(selectinload(User.roles)).where(User.email == default_email)
|
||||||
|
)
|
||||||
|
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
user = result.scalars().first()
|
user = result.scalars().first()
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
return await create_default_user()
|
return await create_default_user()
|
||||||
|
|
||||||
# We return a SimpleNamespace to have the same user type as our SaaS
|
# We return a SimpleNamespace to have the same user type as our SaaS
|
||||||
# SimpleNamespace is just a dictionary which can be accessed through attributes
|
# SimpleNamespace is just a dictionary which can be accessed through attributes
|
||||||
auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
|
auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
|
||||||
return auth_data
|
return auth_data
|
||||||
|
except Exception as error:
|
||||||
|
if "principals" in str(error.args):
|
||||||
|
raise DatabaseNotCreatedError() from error
|
||||||
|
|
||||||
|
raise UserNotFoundError(f"Failed to retrieve default user: {default_email}") from error
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
from cognee.tasks.completion.exceptions import NoRelevantDataError
|
||||||
|
|
|
||||||
|
|
@ -5,5 +5,5 @@ This module defines a set of exceptions for handling various compute errors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
NoRelevantDataFound,
|
NoRelevantDataError,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,11 @@ from cognee.exceptions import CogneeApiError
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
class NoRelevantDataFound(CogneeApiError):
|
class NoRelevantDataError(CogneeApiError):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "Search did not find any data.",
|
message: str = "Search did not find any data.",
|
||||||
name: str = "NoRelevantDataFound",
|
name: str = "NoRelevantDataError",
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
import os
|
import asyncio
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from functools import lru_cache
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.llm.rate_limiter import (
|
from cognee.infrastructure.llm.rate_limiter import (
|
||||||
sleep_and_retry_sync,
|
sleep_and_retry_sync,
|
||||||
|
|
|
||||||
|
|
@ -1,120 +1,195 @@
|
||||||
import uuid
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||||
|
|
||||||
|
|
||||||
class TestChunksRetriever:
|
class TestChunksRetriever:
|
||||||
@pytest.fixture
|
@pytest.mark.asyncio
|
||||||
def mock_retriever(self):
|
async def test_chunk_context_simple(self):
|
||||||
return ChunksRetriever()
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document = TextDocument(
|
||||||
|
name="Steve Rodger's career",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = ChunksRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Mike")
|
||||||
|
|
||||||
|
assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
|
async def test_chunk_context_complex(self):
|
||||||
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
|
system_directory_path = os.path.join(
|
||||||
# Setup
|
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
|
||||||
query = "test query"
|
)
|
||||||
doc_id1 = str(uuid.uuid4())
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
doc_id2 = str(uuid.uuid4())
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
# Mock search results
|
await cognee.prune.prune_data()
|
||||||
mock_result_1 = MagicMock()
|
await cognee.prune.prune_system(metadata=True)
|
||||||
mock_result_1.payload = {
|
await setup()
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"text": "This is the first chunk result.",
|
|
||||||
"document_id": doc_id1,
|
|
||||||
"metadata": {"title": "Document 1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_result_2 = MagicMock()
|
document1 = TextDocument(
|
||||||
mock_result_2.payload = {
|
name="Employee List",
|
||||||
"id": str(uuid.uuid4()),
|
raw_data_location="somewhere",
|
||||||
"text": "This is the second chunk result.",
|
external_metadata="",
|
||||||
"document_id": doc_id2,
|
mime_type="text/plain",
|
||||||
"metadata": {"title": "Document 2"},
|
)
|
||||||
}
|
|
||||||
|
|
||||||
mock_search_results = [mock_result_1, mock_result_2]
|
document2 = TextDocument(
|
||||||
mock_vector_engine = AsyncMock()
|
name="Car List",
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
raw_data_location="somewhere",
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
# Execute
|
chunk1 = DocumentChunk(
|
||||||
results = await mock_retriever.get_completion(query)
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
# Verify
|
chunk4 = DocumentChunk(
|
||||||
assert len(results) == 2
|
text="Range Rover",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk5 = DocumentChunk(
|
||||||
|
text="Hyundai",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk6 = DocumentChunk(
|
||||||
|
text="Chrysler",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
# Check first result
|
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
|
||||||
assert results[0]["text"] == "This is the first chunk result."
|
|
||||||
assert results[0]["document_id"] == doc_id1
|
|
||||||
assert results[0]["metadata"]["title"] == "Document 1"
|
|
||||||
|
|
||||||
# Check second result
|
await add_data_points(entities)
|
||||||
assert results[1]["text"] == "This is the second chunk result."
|
|
||||||
assert results[1]["document_id"] == doc_id2
|
|
||||||
assert results[1]["metadata"]["title"] == "Document 2"
|
|
||||||
|
|
||||||
# Verify search was called correctly
|
retriever = ChunksRetriever(top_k=20)
|
||||||
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
|
|
||||||
|
context = await retriever.get_context("Christina")
|
||||||
|
|
||||||
|
assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
|
async def test_chunk_context_on_empty_graph(self):
|
||||||
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
|
system_directory_path = os.path.join(
|
||||||
# Setup
|
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
|
||||||
query = "test query with no results"
|
)
|
||||||
mock_search_results = []
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
mock_vector_engine = AsyncMock()
|
data_directory_path = os.path.join(
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
# Execute
|
await cognee.prune.prune_data()
|
||||||
results = await mock_retriever.get_completion(query)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
# Verify
|
retriever = ChunksRetriever()
|
||||||
assert len(results) == 0
|
|
||||||
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
with pytest.raises(NoDataError):
|
||||||
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
|
await retriever.get_context("Christina Mayer")
|
||||||
async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever):
|
|
||||||
# Setup
|
|
||||||
query = "test query with incomplete data"
|
|
||||||
|
|
||||||
# Mock search results
|
vector_engine = get_vector_engine()
|
||||||
mock_result_1 = MagicMock()
|
await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)
|
||||||
mock_result_1.payload = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"text": "This chunk has no document_id.",
|
|
||||||
# Missing document_id and metadata
|
|
||||||
}
|
|
||||||
mock_result_2 = MagicMock()
|
|
||||||
mock_result_2.payload = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
# Missing text
|
|
||||||
"document_id": str(uuid.uuid4()),
|
|
||||||
"metadata": {"title": "Document with missing text"},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_search_results = [mock_result_1, mock_result_2]
|
context = await retriever.get_context("Christina Mayer")
|
||||||
mock_vector_engine = AsyncMock()
|
assert len(context) == 0, "Found chunks when none should exist"
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await mock_retriever.get_completion(query)
|
|
||||||
|
|
||||||
# Verify
|
if __name__ == "__main__":
|
||||||
assert len(results) == 2
|
from asyncio import run
|
||||||
|
|
||||||
# First result should have content but no document_id
|
test = TestChunksRetriever()
|
||||||
assert results[0]["text"] == "This chunk has no document_id."
|
|
||||||
assert "document_id" not in results[0]
|
|
||||||
assert "metadata" not in results[0]
|
|
||||||
|
|
||||||
# Second result should have document_id and metadata but no content
|
run(test.test_chunk_context_simple())
|
||||||
assert "text" not in results[1]
|
run(test.test_chunk_context_complex())
|
||||||
assert "document_id" in results[1]
|
run(test.test_chunk_context_on_empty_graph())
|
||||||
assert results[1]["metadata"]["title"] == "Document with missing text"
|
|
||||||
|
|
|
||||||
|
|
@ -1,84 +0,0 @@
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompletionRetriever:
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_retriever(self):
|
|
||||||
return CompletionRetriever(system_prompt_path="test_prompt.txt")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
|
|
||||||
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
|
|
||||||
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
|
|
||||||
async def test_get_completion(
|
|
||||||
self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever
|
|
||||||
):
|
|
||||||
# Setup
|
|
||||||
query = "test query"
|
|
||||||
|
|
||||||
# Mock render_prompt
|
|
||||||
mock_render_prompt.return_value = "Rendered prompt with context"
|
|
||||||
|
|
||||||
mock_search_results = [MagicMock()]
|
|
||||||
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
|
|
||||||
mock_vector_engine = AsyncMock()
|
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
|
||||||
|
|
||||||
# Mock LLM client
|
|
||||||
mock_llm_client = MagicMock()
|
|
||||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
|
||||||
mock_llm_client.acreate_structured_output.return_value = "Generated completion response"
|
|
||||||
mock_get_llm_client.return_value = mock_llm_client
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await mock_retriever.get_completion(query)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0] == "Generated completion response"
|
|
||||||
|
|
||||||
# Verify prompt was rendered
|
|
||||||
mock_render_prompt.assert_called_once()
|
|
||||||
|
|
||||||
# Verify LLM client was called
|
|
||||||
mock_llm_client.acreate_structured_output.assert_called_once_with(
|
|
||||||
text_input="Rendered prompt with context", system_prompt=None, response_model=str
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("cognee.modules.retrieval.completion_retriever.generate_completion")
|
|
||||||
@patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
|
|
||||||
async def test_get_completion_with_custom_prompt(
|
|
||||||
self, mock_get_vector_engine, mock_generate_completion, mock_retriever
|
|
||||||
):
|
|
||||||
# Setup
|
|
||||||
query = "test query with custom prompt"
|
|
||||||
|
|
||||||
mock_search_results = [MagicMock()]
|
|
||||||
mock_search_results[0].payload = {"text": "This is a sample document chunk."}
|
|
||||||
mock_vector_engine = AsyncMock()
|
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
|
||||||
|
|
||||||
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
|
|
||||||
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
|
|
||||||
|
|
||||||
mock_generate_completion.return_value = "Custom prompt completion response"
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await mock_retriever.get_completion(query)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0] == "Custom prompt completion response"
|
|
||||||
|
|
||||||
assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt"
|
|
||||||
assert (
|
|
||||||
mock_generate_completion.call_args[1]["system_prompt_path"]
|
|
||||||
== "custom_system_prompt.txt"
|
|
||||||
)
|
|
||||||
|
|
@ -1,236 +1,159 @@
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.graph.exceptions import EntityNotFoundError
|
|
||||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
|
||||||
|
|
||||||
|
|
||||||
class TestGraphCompletionRetriever:
|
class TestGraphCompletionRetriever:
|
||||||
@pytest.fixture
|
|
||||||
def mock_retriever(self):
|
|
||||||
return GraphCompletionRetriever(system_prompt_path="test_prompt.txt")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
|
async def test_graph_completion_context_simple(self):
|
||||||
async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever):
|
system_directory_path = os.path.join(
|
||||||
mock_brute_force_triplet_search.return_value = [
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||||
AsyncMock(
|
|
||||||
node1=AsyncMock(attributes={"text": "Node A"}),
|
|
||||||
attributes={"relationship_type": "connects"},
|
|
||||||
node2=AsyncMock(attributes={"text": "Node B"}),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await mock_retriever.get_triplets("test query")
|
|
||||||
|
|
||||||
assert isinstance(result, list)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert result[0].attributes["relationship_type"] == "connects"
|
|
||||||
mock_brute_force_triplet_search.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
|
|
||||||
async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever):
|
|
||||||
mock_brute_force_triplet_search.return_value = []
|
|
||||||
|
|
||||||
with pytest.raises(NoRelevantDataFound):
|
|
||||||
await mock_retriever.get_triplets("test query")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resolve_edges_to_text(self, mock_retriever):
|
|
||||||
node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
|
|
||||||
node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
|
|
||||||
node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})
|
|
||||||
|
|
||||||
triplets = [
|
|
||||||
AsyncMock(
|
|
||||||
node1=node_a,
|
|
||||||
attributes={"relationship_type": "connects"},
|
|
||||||
node2=node_b,
|
|
||||||
),
|
|
||||||
AsyncMock(
|
|
||||||
node1=node_a,
|
|
||||||
attributes={"relationship_type": "links"},
|
|
||||||
node2=node_c,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
|
|
||||||
result = await mock_retriever.resolve_edges_to_text(triplets)
|
|
||||||
|
|
||||||
assert "Nodes:" in result
|
|
||||||
assert "Connections:" in result
|
|
||||||
|
|
||||||
assert "Node: Test Title" in result
|
|
||||||
assert "__node_content_start__" in result
|
|
||||||
assert "Node A text content" in result
|
|
||||||
assert "__node_content_end__" in result
|
|
||||||
assert "Node: Node C" in result
|
|
||||||
|
|
||||||
assert "Test Title --[connects]--> Test Title" in result
|
|
||||||
assert "Test Title --[links]--> Node C" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch(
|
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
)
|
|
||||||
@patch(
|
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
)
|
|
||||||
async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever):
|
|
||||||
"""Test get_context calls get_triplets and resolve_edges_to_text."""
|
|
||||||
mock_get_triplets.return_value = ["mock_triplet"]
|
|
||||||
mock_resolve_edges_to_text.return_value = "Mock Context"
|
|
||||||
|
|
||||||
result = await mock_retriever.get_context("test query")
|
|
||||||
|
|
||||||
assert result == "Mock Context"
|
|
||||||
mock_get_triplets.assert_called_once_with("test query")
|
|
||||||
mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch(
|
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
|
|
||||||
)
|
|
||||||
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
|
|
||||||
async def test_get_completion_without_context(
|
|
||||||
self, mock_generate_completion, mock_get_context, mock_retriever
|
|
||||||
):
|
|
||||||
"""Test get_completion when no context is provided (calls get_context)."""
|
|
||||||
mock_get_context.return_value = "Mock Context"
|
|
||||||
mock_generate_completion.return_value = "Generated Completion"
|
|
||||||
|
|
||||||
result = await mock_retriever.get_completion("test query")
|
|
||||||
|
|
||||||
assert result == ["Generated Completion"]
|
|
||||||
mock_get_context.assert_called_once_with("test query")
|
|
||||||
mock_generate_completion.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch(
|
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
|
|
||||||
)
|
|
||||||
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
|
|
||||||
async def test_get_completion_with_context(
|
|
||||||
self, mock_generate_completion, mock_get_context, mock_retriever
|
|
||||||
):
|
|
||||||
"""Test get_completion when context is provided (does not call get_context)."""
|
|
||||||
mock_generate_completion.return_value = "Generated Completion"
|
|
||||||
|
|
||||||
result = await mock_retriever.get_completion("test query", context="Provided Context")
|
|
||||||
|
|
||||||
assert result == ["Generated Completion"]
|
|
||||||
mock_get_context.assert_not_called()
|
|
||||||
mock_generate_completion.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
|
|
||||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine")
|
|
||||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
|
|
||||||
async def test_get_completion_with_empty_graph(
|
|
||||||
self,
|
|
||||||
mock_get_default_user,
|
|
||||||
mock_get_graph_engine,
|
|
||||||
mock_get_llm_client,
|
|
||||||
mock_retriever,
|
|
||||||
):
|
|
||||||
query = "test query with empty graph"
|
|
||||||
|
|
||||||
mock_graph_engine = MagicMock()
|
|
||||||
mock_graph_engine.get_graph_data = AsyncMock()
|
|
||||||
mock_graph_engine.get_graph_data.return_value = ([], [])
|
|
||||||
mock_get_graph_engine.return_value = mock_graph_engine
|
|
||||||
|
|
||||||
mock_llm_client = MagicMock()
|
|
||||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
|
||||||
mock_llm_client.acreate_structured_output.return_value = (
|
|
||||||
"Generated graph completion response"
|
|
||||||
)
|
)
|
||||||
mock_get_llm_client.return_value = mock_llm_client
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
with pytest.raises(EntityNotFoundError):
|
await cognee.prune.prune_data()
|
||||||
await mock_retriever.get_completion(query)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
mock_graph_engine.get_graph_data.assert_called_once()
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
def test_top_n_words(self, mock_retriever):
|
class Person(DataPoint):
|
||||||
"""Test extraction of top frequent words from text."""
|
name: str
|
||||||
text = "The quick brown fox jumps over the lazy dog. The fox is quick."
|
works_for: Company
|
||||||
|
|
||||||
result = mock_retriever._top_n_words(text)
|
company1 = Company(name="Figma")
|
||||||
assert len(result.split(", ")) <= 3
|
company2 = Company(name="Canva")
|
||||||
assert "fox" in result
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
assert "quick" in result
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
|
||||||
result = mock_retriever._top_n_words(text, top_n=2)
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
assert len(result.split(", ")) <= 2
|
|
||||||
|
|
||||||
result = mock_retriever._top_n_words(text, separator=" | ")
|
await add_data_points(entities)
|
||||||
assert " | " in result
|
|
||||||
|
|
||||||
result = mock_retriever._top_n_words(text, stop_words={"fox", "quick"})
|
retriever = GraphCompletionRetriever()
|
||||||
assert "fox" not in result
|
|
||||||
assert "quick" not in result
|
|
||||||
|
|
||||||
def test_get_title(self, mock_retriever):
|
context = await retriever.get_context("Who works at Canva?")
|
||||||
"""Test title generation from text."""
|
|
||||||
text = "This is a long paragraph about various topics that should generate a title. The main topics are AI, programming and data science."
|
|
||||||
|
|
||||||
title = mock_retriever._get_title(text)
|
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||||
assert "..." in title
|
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||||
assert "[" in title and "]" in title
|
|
||||||
|
|
||||||
title = mock_retriever._get_title(text, first_n_words=3)
|
@pytest.mark.asyncio
|
||||||
first_part = title.split("...")[0].strip()
|
async def test_graph_completion_context_complex(self):
|
||||||
assert len(first_part.split()) == 3
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
title = mock_retriever._get_title(text, top_n_words=2)
|
await cognee.prune.prune_data()
|
||||||
top_part = title.split("[")[1].split("]")[0]
|
await cognee.prune.prune_system(metadata=True)
|
||||||
assert len(top_part.split(", ")) <= 2
|
await setup()
|
||||||
|
|
||||||
def test_get_nodes(self, mock_retriever):
|
class Company(DataPoint):
|
||||||
"""Test node processing and deduplication."""
|
name: str
|
||||||
node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"})
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"})
|
|
||||||
node_without_attrs = AsyncMock(id="empty_node", attributes={})
|
|
||||||
|
|
||||||
edges = [
|
class Car(DataPoint):
|
||||||
AsyncMock(
|
brand: str
|
||||||
node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"}
|
model: str
|
||||||
),
|
year: int
|
||||||
AsyncMock(
|
|
||||||
node1=node_with_text,
|
class Location(DataPoint):
|
||||||
node2=node_without_attrs,
|
country: str
|
||||||
attributes={"relationship_type": "rel2"},
|
city: str
|
||||||
),
|
|
||||||
AsyncMock(
|
class Home(DataPoint):
|
||||||
node1=node_with_name,
|
location: Location
|
||||||
node2=node_without_attrs,
|
rooms: int
|
||||||
attributes={"relationship_type": "rel3"},
|
sqm: int
|
||||||
),
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
owns: Optional[list[Union[Car, Home]]] = None
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
|
||||||
|
person1 = Person(name="Mike Rodger", works_for=company1)
|
||||||
|
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
|
||||||
|
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person2.owns = [
|
||||||
|
Car(brand="Tesla", model="Model S", year=2021),
|
||||||
|
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch.object(mock_retriever, "_get_title", return_value="Generated Title"):
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
nodes = mock_retriever._get_nodes(edges)
|
|
||||||
|
|
||||||
assert len(nodes) == 3
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
|
||||||
|
|
||||||
for node_id, info in nodes.items():
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
assert "node" in info
|
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
|
||||||
assert "name" in info
|
|
||||||
assert "content" in info
|
|
||||||
|
|
||||||
text_node_info = nodes[node_with_text.id]
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
assert text_node_info["name"] == "Generated Title"
|
|
||||||
assert text_node_info["content"] == "This is a text node"
|
|
||||||
|
|
||||||
name_node_info = nodes[node_with_name.id]
|
await add_data_points(entities)
|
||||||
assert name_node_info["name"] == "Named Node"
|
|
||||||
assert name_node_info["content"] == "Named Node"
|
|
||||||
|
|
||||||
empty_node_info = nodes[node_without_attrs.id]
|
retriever = GraphCompletionRetriever(top_k=20)
|
||||||
assert empty_node_info["name"] == "Unnamed Node"
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
|
||||||
|
print(context)
|
||||||
|
|
||||||
|
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
|
||||||
|
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
|
||||||
|
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_completion_context_on_empty_graph(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
retriever = GraphCompletionRetriever()
|
||||||
|
|
||||||
|
with pytest.raises(DatabaseNotCreatedError):
|
||||||
|
await retriever.get_context("Who works at Figma?")
|
||||||
|
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
assert context == "", "Context should be empty on an empty graph"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from asyncio import run
|
||||||
|
|
||||||
|
test = TestGraphCompletionRetriever()
|
||||||
|
|
||||||
|
run(test.test_graph_completion_context_simple())
|
||||||
|
run(test.test_graph_completion_context_complex())
|
||||||
|
run(test.test_get_graph_completion_context_on_empty_graph())
|
||||||
|
|
|
||||||
|
|
@ -1,80 +0,0 @@
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|
||||||
GraphSummaryCompletionRetriever,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGraphSummaryCompletionRetriever:
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_retriever(self):
|
|
||||||
return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
|
|
||||||
@patch("cognee.modules.retrieval.utils.completion.read_query_prompt")
|
|
||||||
@patch("cognee.modules.retrieval.utils.completion.render_prompt")
|
|
||||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
|
|
||||||
async def test_get_completion_with_custom_system_prompt(
|
|
||||||
self,
|
|
||||||
mock_get_default_user,
|
|
||||||
mock_render_prompt,
|
|
||||||
mock_read_query_prompt,
|
|
||||||
mock_get_llm_client,
|
|
||||||
mock_retriever,
|
|
||||||
):
|
|
||||||
# Setup
|
|
||||||
query = "test query with custom prompt"
|
|
||||||
|
|
||||||
# Set custom system prompt
|
|
||||||
mock_retriever.user_prompt_path = "custom_user_prompt.txt"
|
|
||||||
mock_retriever.system_prompt_path = "custom_system_prompt.txt"
|
|
||||||
|
|
||||||
mock_llm_client = MagicMock()
|
|
||||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
|
||||||
mock_llm_client.acreate_structured_output.return_value = (
|
|
||||||
"Generated graph summary completion response"
|
|
||||||
)
|
|
||||||
mock_get_llm_client.return_value = mock_llm_client
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await mock_retriever.get_completion(query, context="test context")
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert len(results) == 1
|
|
||||||
|
|
||||||
# Verify render_prompt was called with custom prompt path
|
|
||||||
mock_render_prompt.assert_called_once()
|
|
||||||
assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt"
|
|
||||||
|
|
||||||
mock_read_query_prompt.assert_called_once()
|
|
||||||
assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt"
|
|
||||||
|
|
||||||
mock_llm_client.acreate_structured_output.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch(
|
|
||||||
"cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text"
|
|
||||||
)
|
|
||||||
@patch(
|
|
||||||
"cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
)
|
|
||||||
async def test_resolve_edges_to_text_calls_super_and_summarizes(
|
|
||||||
self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever
|
|
||||||
):
|
|
||||||
"""Test resolve_edges_to_text calls the parent method and summarizes the result."""
|
|
||||||
|
|
||||||
mock_resolve_edges_to_text.return_value = "Raw graph edges text"
|
|
||||||
mock_summarize_text.return_value = "Summarized graph text"
|
|
||||||
|
|
||||||
result = await mock_retriever.resolve_edges_to_text(["mock_edge"])
|
|
||||||
|
|
||||||
mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"])
|
|
||||||
mock_summarize_text.assert_called_once_with(
|
|
||||||
"Raw graph edges text", mock_retriever.summarize_prompt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == "Summarized graph text"
|
|
||||||
|
|
@ -1,103 +1,216 @@
|
||||||
import uuid
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
|
||||||
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
import cognee
|
||||||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph
|
from cognee.low_level import setup
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
from cognee.tasks.storage import add_data_points
|
||||||
import unittest
|
from cognee.modules.engine.models import Entity, EntityType
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
||||||
|
|
||||||
|
|
||||||
class TestInsightsRetriever:
|
class TestInsightsRetriever:
|
||||||
@pytest.fixture
|
@pytest.mark.asyncio
|
||||||
def mock_retriever(self):
|
async def test_insights_context_simple(self):
|
||||||
return InsightsRetriever()
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_simple"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_simple"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
entityTypePerson = EntityType(
|
||||||
|
name="Person",
|
||||||
|
description="An individual",
|
||||||
|
)
|
||||||
|
|
||||||
|
person1 = Entity(
|
||||||
|
name="Steve Rodger",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="An American actor, comedian, and filmmaker",
|
||||||
|
)
|
||||||
|
|
||||||
|
person2 = Entity(
|
||||||
|
name="Mike Broski",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="Financial advisor and philanthropist",
|
||||||
|
)
|
||||||
|
|
||||||
|
person3 = Entity(
|
||||||
|
name="Christina Mayer",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="Maker of next generation of iconic American music videos",
|
||||||
|
)
|
||||||
|
|
||||||
|
entityTypeCompany = EntityType(
|
||||||
|
name="Company",
|
||||||
|
description="An organization that operates on an annual basis",
|
||||||
|
)
|
||||||
|
|
||||||
|
company1 = Entity(
|
||||||
|
name="Apple",
|
||||||
|
is_a=entityTypeCompany,
|
||||||
|
description="An American multinational technology company headquartered in Cupertino, California",
|
||||||
|
)
|
||||||
|
|
||||||
|
company2 = Entity(
|
||||||
|
name="Google",
|
||||||
|
is_a=entityTypeCompany,
|
||||||
|
description="An American multinational technology company that specializes in Internet-related services and products",
|
||||||
|
)
|
||||||
|
|
||||||
|
company3 = Entity(
|
||||||
|
name="Facebook",
|
||||||
|
is_a=entityTypeCompany,
|
||||||
|
description="An American social media, messaging, and online platform",
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [person1, person2, person3, company1, company2, company3]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = InsightsRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Mike")
|
||||||
|
|
||||||
|
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
|
async def test_insights_context_complex(self):
|
||||||
async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever):
|
system_directory_path = os.path.join(
|
||||||
"""Test get_context when node exists in graph."""
|
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_complex"
|
||||||
mock_graph = AsyncMock()
|
)
|
||||||
mock_get_graph_engine.return_value = mock_graph
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_complex"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
# Mock graph response
|
await cognee.prune.prune_data()
|
||||||
mock_graph.extract_node.return_value = {"id": "123"}
|
await cognee.prune.prune_system(metadata=True)
|
||||||
mock_graph.get_connections.return_value = [
|
await setup()
|
||||||
({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await mock_retriever.get_context("123")
|
entityTypePerson = EntityType(
|
||||||
|
name="Person",
|
||||||
|
description="An individual",
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(result, list)
|
person1 = Entity(
|
||||||
assert len(result) == 1
|
name="Steve Rodger",
|
||||||
assert result[0][0]["id"] == "123"
|
is_a=entityTypePerson,
|
||||||
assert result[0][1]["relationship_name"] == "linked_to"
|
description="An American actor, comedian, and filmmaker",
|
||||||
assert result[0][2]["id"] == "456"
|
)
|
||||||
mock_graph.extract_node.assert_called_once_with("123")
|
|
||||||
mock_graph.get_connections.assert_called_once_with("123")
|
person2 = Entity(
|
||||||
|
name="Mike Broski",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="Financial advisor and philanthropist",
|
||||||
|
)
|
||||||
|
|
||||||
|
person3 = Entity(
|
||||||
|
name="Christina Mayer",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="Maker of next generation of iconic American music videos",
|
||||||
|
)
|
||||||
|
|
||||||
|
person4 = Entity(
|
||||||
|
name="Jason Statham",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="An American actor",
|
||||||
|
)
|
||||||
|
|
||||||
|
person5 = Entity(
|
||||||
|
name="Mike Tyson",
|
||||||
|
is_a=entityTypePerson,
|
||||||
|
description="A former professional boxer from the United States",
|
||||||
|
)
|
||||||
|
|
||||||
|
entityTypeCompany = EntityType(
|
||||||
|
name="Company",
|
||||||
|
description="An organization that operates on an annual basis",
|
||||||
|
)
|
||||||
|
|
||||||
|
company1 = Entity(
|
||||||
|
name="Apple",
|
||||||
|
is_a=entityTypeCompany,
|
||||||
|
description="An American multinational technology company headquartered in Cupertino, California",
|
||||||
|
)
|
||||||
|
|
||||||
|
company2 = Entity(
|
||||||
|
name="Google",
|
||||||
|
is_a=entityTypeCompany,
|
||||||
|
description="An American multinational technology company that specializes in Internet-related services and products",
|
||||||
|
)
|
||||||
|
|
||||||
|
company3 = Entity(
|
||||||
|
name="Facebook",
|
||||||
|
is_a=entityTypeCompany,
|
||||||
|
description="An American social media, messaging, and online platform",
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [person1, person2, person3, company1, company2, company3]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
await graph_engine.add_edges(
|
||||||
|
[
|
||||||
|
(person1.id, company1.id, "works_for"),
|
||||||
|
(person2.id, company2.id, "works_for"),
|
||||||
|
(person3.id, company3.id, "works_for"),
|
||||||
|
(person4.id, company1.id, "works_for"),
|
||||||
|
(person5.id, company1.id, "works_for"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever = InsightsRetriever(top_k=20)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Christina")
|
||||||
|
|
||||||
|
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
|
async def test_insights_context_on_empty_graph(self):
|
||||||
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
|
system_directory_path = os.path.join(
|
||||||
# Setup
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty"
|
||||||
query = "test query with no results"
|
)
|
||||||
mock_search_results = []
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
mock_vector_engine = AsyncMock()
|
data_directory_path = os.path.join(
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty"
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
# Execute
|
await cognee.prune.prune_data()
|
||||||
results = await mock_retriever.get_completion(query)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
# Verify
|
retriever = InsightsRetriever()
|
||||||
assert len(results) == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
with pytest.raises(NoDataError):
|
||||||
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
|
await retriever.get_context("Christina Mayer")
|
||||||
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
|
|
||||||
async def test_get_context_with_no_exact_node(
|
|
||||||
self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever
|
|
||||||
):
|
|
||||||
"""Test get_context when node does not exist in the graph and vector search is used."""
|
|
||||||
mock_graph = AsyncMock()
|
|
||||||
mock_get_graph_engine.return_value = mock_graph
|
|
||||||
mock_graph.extract_node.return_value = None # Node does not exist
|
|
||||||
|
|
||||||
mock_vector = AsyncMock()
|
vector_engine = get_vector_engine()
|
||||||
mock_get_vector_engine.return_value = mock_vector
|
await vector_engine.create_collection("Entity_name", payload_schema=Entity)
|
||||||
|
await vector_engine.create_collection("EntityType_name", payload_schema=EntityType)
|
||||||
|
|
||||||
mock_vector.search.side_effect = [
|
context = await retriever.get_context("Christina Mayer")
|
||||||
[AsyncMock(id="vec_1", score=0.4)], # Entity_name search
|
assert context == [], "Returned context should be empty on an empty graph"
|
||||||
[AsyncMock(id="vec_2", score=0.3)], # EntityType_name search
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_graph.get_connections.side_effect = lambda node_id: [
|
|
||||||
({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"})
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await mock_retriever.get_context("non_existing_query")
|
if __name__ == "__main__":
|
||||||
|
from asyncio import run
|
||||||
|
|
||||||
assert isinstance(result, list)
|
test = TestInsightsRetriever()
|
||||||
assert len(result) == 2
|
|
||||||
assert result[0][0]["id"] == "vec_1"
|
|
||||||
assert result[0][1]["relationship_name"] == "related_to"
|
|
||||||
assert result[0][2]["id"] == "456"
|
|
||||||
|
|
||||||
assert result[1][0]["id"] == "vec_2"
|
run(test.test_insights_context_simple())
|
||||||
assert result[1][1]["relationship_name"] == "related_to"
|
run(test.test_insights_context_complex())
|
||||||
assert result[1][2]["id"] == "456"
|
run(test.test_insights_context_on_empty_graph())
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_context_with_none_query(self, mock_retriever):
|
|
||||||
"""Test get_context with a None query (should return empty list)."""
|
|
||||||
result = await mock_retriever.get_context(None)
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_completion_with_context(self, mock_retriever):
|
|
||||||
"""Test get_completion when context is already provided."""
|
|
||||||
test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})]
|
|
||||||
result = await mock_retriever.get_completion("test_query", context=test_context)
|
|
||||||
assert result == test_context
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,196 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGCompletionRetriever:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_completion_context_simple(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document = TextDocument(
|
||||||
|
name="Steve Rodger's career",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = CompletionRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Mike")
|
||||||
|
|
||||||
|
assert context == "Mike Broski", "Failed to get Mike Broski"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_completion_context_complex(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document1 = TextDocument(
|
||||||
|
name="Employee List",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
document2 = TextDocument(
|
||||||
|
name="Car List",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk4 = DocumentChunk(
|
||||||
|
text="Range Rover",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk5 = DocumentChunk(
|
||||||
|
text="Hyundai",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk6 = DocumentChunk(
|
||||||
|
text="Chrysler",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
# TODO: top_k doesn't affect the output, it should be fixed.
|
||||||
|
retriever = CompletionRetriever(top_k=20)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Christina")
|
||||||
|
|
||||||
|
assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_rag_completion_context_on_empty_graph(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
retriever = CompletionRetriever()
|
||||||
|
|
||||||
|
with pytest.raises(NoDataError):
|
||||||
|
await retriever.get_context("Christina Mayer")
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Christina Mayer")
|
||||||
|
assert context == "", "Returned context should be empty on an empty graph"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from asyncio import run
|
||||||
|
|
||||||
|
test = TestRAGCompletionRetriever()
|
||||||
|
|
||||||
|
run(test.test_rag_completion_context_simple())
|
||||||
|
run(test.test_rag_completion_context_complex())
|
||||||
|
run(test.test_get_rag_completion_context_on_empty_graph())
|
||||||
|
|
@ -1,122 +1,168 @@
|
||||||
import uuid
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.tasks.summarization.models import TextSummary
|
||||||
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||||
|
|
||||||
|
|
||||||
class TestSummariesRetriever:
|
class TextSummariesRetriever:
|
||||||
@pytest.fixture
|
@pytest.mark.asyncio
|
||||||
def mock_retriever(self):
|
async def test_chunk_context(self):
|
||||||
return SummariesRetriever()
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
document1 = TextDocument(
|
||||||
|
name="Employee List",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
document2 = TextDocument(
|
||||||
|
name="Car List",
|
||||||
|
raw_data_location="somewhere",
|
||||||
|
external_metadata="",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk1 = DocumentChunk(
|
||||||
|
text="Steve Rodger",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk1_summary = TextSummary(
|
||||||
|
text="S.R.",
|
||||||
|
made_from=chunk1,
|
||||||
|
)
|
||||||
|
chunk2 = DocumentChunk(
|
||||||
|
text="Mike Broski",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk2_summary = TextSummary(
|
||||||
|
text="M.B.",
|
||||||
|
made_from=chunk2,
|
||||||
|
)
|
||||||
|
chunk3 = DocumentChunk(
|
||||||
|
text="Christina Mayer",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document1,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk3_summary = TextSummary(
|
||||||
|
text="C.M.",
|
||||||
|
made_from=chunk3,
|
||||||
|
)
|
||||||
|
chunk4 = DocumentChunk(
|
||||||
|
text="Range Rover",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk4_summary = TextSummary(
|
||||||
|
text="R.R.",
|
||||||
|
made_from=chunk4,
|
||||||
|
)
|
||||||
|
chunk5 = DocumentChunk(
|
||||||
|
text="Hyundai",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=1,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk5_summary = TextSummary(
|
||||||
|
text="H.Y.",
|
||||||
|
made_from=chunk5,
|
||||||
|
)
|
||||||
|
chunk6 = DocumentChunk(
|
||||||
|
text="Chrysler",
|
||||||
|
chunk_size=2,
|
||||||
|
chunk_index=2,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=document2,
|
||||||
|
contains=[],
|
||||||
|
)
|
||||||
|
chunk6_summary = TextSummary(
|
||||||
|
text="C.H.",
|
||||||
|
made_from=chunk6,
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = [
|
||||||
|
chunk1_summary,
|
||||||
|
chunk2_summary,
|
||||||
|
chunk3_summary,
|
||||||
|
chunk4_summary,
|
||||||
|
chunk5_summary,
|
||||||
|
chunk6_summary,
|
||||||
|
]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = SummariesRetriever(limit=20)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Christina")
|
||||||
|
|
||||||
|
assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
|
async def test_chunk_context_on_empty_graph(self):
|
||||||
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
|
system_directory_path = os.path.join(
|
||||||
# Setup
|
pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
|
||||||
query = "test query"
|
)
|
||||||
doc_id1 = str(uuid.uuid4())
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
doc_id2 = str(uuid.uuid4())
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
# Mock search results
|
await cognee.prune.prune_data()
|
||||||
mock_result_1 = MagicMock()
|
await cognee.prune.prune_system(metadata=True)
|
||||||
mock_result_1.payload = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"score": 0.95,
|
|
||||||
"payload": {
|
|
||||||
"text": "This is the first summary.",
|
|
||||||
"document_id": doc_id1,
|
|
||||||
"metadata": {"title": "Document 1"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
mock_result_2 = MagicMock()
|
|
||||||
mock_result_2.payload = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"score": 0.85,
|
|
||||||
"payload": {
|
|
||||||
"text": "This is the second summary.",
|
|
||||||
"document_id": doc_id2,
|
|
||||||
"metadata": {"title": "Document 2"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_search_results = [mock_result_1, mock_result_2]
|
retriever = SummariesRetriever()
|
||||||
mock_vector_engine = AsyncMock()
|
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
|
||||||
|
|
||||||
# Execute
|
with pytest.raises(NoDataError):
|
||||||
results = await mock_retriever.get_completion(query)
|
await retriever.get_context("Christina Mayer")
|
||||||
|
|
||||||
# Verify
|
vector_engine = get_vector_engine()
|
||||||
assert len(results) == 2
|
await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)
|
||||||
|
|
||||||
# Check first result
|
context = await retriever.get_context("Christina Mayer")
|
||||||
assert results[0]["payload"]["text"] == "This is the first summary."
|
assert context == [], "Returned context should be empty on an empty graph"
|
||||||
assert results[0]["payload"]["document_id"] == doc_id1
|
|
||||||
assert results[0]["payload"]["metadata"]["title"] == "Document 1"
|
|
||||||
assert results[0]["score"] == 0.95
|
|
||||||
|
|
||||||
# Check second result
|
|
||||||
assert results[1]["payload"]["text"] == "This is the second summary."
|
|
||||||
assert results[1]["payload"]["document_id"] == doc_id2
|
|
||||||
assert results[1]["payload"]["metadata"]["title"] == "Document 2"
|
|
||||||
assert results[1]["score"] == 0.85
|
|
||||||
|
|
||||||
# Verify search was called correctly
|
if __name__ == "__main__":
|
||||||
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
|
from asyncio import run
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
test = TextSummariesRetriever()
|
||||||
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
|
|
||||||
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
|
|
||||||
# Setup
|
|
||||||
query = "test query with no results"
|
|
||||||
mock_search_results = []
|
|
||||||
mock_vector_engine = AsyncMock()
|
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
|
||||||
|
|
||||||
# Execute
|
run(test.test_chunk_context())
|
||||||
results = await mock_retriever.get_completion(query)
|
run(test.test_chunk_context_on_empty_graph())
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert len(results) == 0
|
|
||||||
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
|
|
||||||
async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever):
|
|
||||||
# Setup
|
|
||||||
query = "test query with custom limit"
|
|
||||||
doc_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# Mock search results
|
|
||||||
mock_result = MagicMock()
|
|
||||||
mock_result.payload = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"score": 0.95,
|
|
||||||
"payload": {
|
|
||||||
"text": "This is a summary.",
|
|
||||||
"document_id": doc_id,
|
|
||||||
"metadata": {"title": "Document 1"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_search_results = [mock_result]
|
|
||||||
mock_vector_engine = AsyncMock()
|
|
||||||
mock_vector_engine.search.return_value = mock_search_results
|
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
|
||||||
|
|
||||||
# Set custom limit
|
|
||||||
mock_retriever.limit = 10
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await mock_retriever.get_completion(query)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0]["payload"]["text"] == "This is a summary."
|
|
||||||
|
|
||||||
# Verify search was called with custom limit
|
|
||||||
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10)
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
import pytest
|
import pytest
|
||||||
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
from unittest.mock import AsyncMock, patch
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||||
brute_force_search,
|
brute_force_search,
|
||||||
brute_force_triplet_search,
|
brute_force_triplet_search,
|
||||||
)
|
)
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -20,13 +20,11 @@ async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
|
||||||
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
mock_get_vector_engine.return_value = mock_vector_engine
|
||||||
|
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(CollectionDistancesNotFoundError):
|
||||||
await brute_force_search(
|
await brute_force_search(
|
||||||
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||||
|
|
@ -40,9 +38,7 @@ async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_e
|
||||||
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||||
mock_get_vector_engine.return_value = mock_vector_engine
|
mock_get_vector_engine.return_value = mock_vector_engine
|
||||||
|
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(CollectionDistancesNotFoundError):
|
||||||
await brute_force_triplet_search(
|
await brute_force_triplet_search(
|
||||||
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue