From 2a462085695a127a61467d1d14d8b2a86333be3a Mon Sep 17 00:00:00 2001 From: xdurawa Date: Mon, 1 Sep 2025 20:47:26 -0400 Subject: [PATCH 01/31] Adding AWS Bedrock support as a LLM provider Signed-off-by: xdurawa --- .github/workflows/test_bedrock_api_key.yml | 28 +++ .../test_bedrock_aws_credentials.yml | 29 ++++ .../workflows/test_bedrock_aws_profile.yml | 37 ++++ .github/workflows/test_suites.yml | 24 +++ README.md | 10 ++ cognee-starter-kit/README.md | 38 +++++ cognee/api/v1/add/add.py | 2 +- cognee/api/v1/search/search.py | 2 +- cognee/infrastructure/llm/config.py | 14 ++ .../llm/bedrock/__init__.py | 6 + .../litellm_instructor/llm/bedrock/adapter.py | 161 ++++++++++++++++++ .../litellm_instructor/llm/get_llm_client.py | 20 +++ cognee/modules/settings/get_settings.py | 21 ++- pyproject.toml | 2 +- uv.lock | 8 +- 15 files changed, 394 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/test_bedrock_api_key.yml create mode 100644 .github/workflows/test_bedrock_aws_credentials.yml create mode 100644 .github/workflows/test_bedrock_aws_profile.yml create mode 100644 cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py create mode 100644 cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py diff --git a/.github/workflows/test_bedrock_api_key.yml b/.github/workflows/test_bedrock_api_key.yml new file mode 100644 index 000000000..3f5ea94b3 --- /dev/null +++ b/.github/workflows/test_bedrock_api_key.yml @@ -0,0 +1,28 @@ +name: test | bedrock | api key + +on: + workflow_call: + +jobs: + test-bedrock-api-key: + name: Run Bedrock API Key Test + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Bedrock API Key Test + env: + LLM_PROVIDER: "bedrock" + LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }} + LLM_MODEL: 
"us.anthropic.claude-3-5-sonnet-20241022-v2:0" + AWS_REGION_NAME: "us-east-1" + EMBEDDING_PROVIDER: "bedrock" + EMBEDDING_MODEL: "amazon.titan-embed-text-v1" + EMBEDDING_DIMENSIONS: "1536" + run: poetry run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_bedrock_aws_credentials.yml b/.github/workflows/test_bedrock_aws_credentials.yml new file mode 100644 index 000000000..c086dceb3 --- /dev/null +++ b/.github/workflows/test_bedrock_aws_credentials.yml @@ -0,0 +1,29 @@ +name: test | bedrock | aws credentials + +on: + workflow_call: + +jobs: + test-bedrock-aws-credentials: + name: Run Bedrock AWS Credentials Test + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Bedrock AWS Credentials Test + env: + LLM_PROVIDER: "bedrock" + LLM_MODEL: "us.anthropic.claude-3-5-sonnet-20240620-v1:0" + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + AWS_REGION_NAME: "us-east-1" + EMBEDDING_PROVIDER: "cohere" + EMBEDDING_MODEL: "cohere.embed-english-v3" + EMBEDDING_DIMENSIONS: "1024" + run: poetry run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_bedrock_aws_profile.yml b/.github/workflows/test_bedrock_aws_profile.yml new file mode 100644 index 000000000..aa15074e1 --- /dev/null +++ b/.github/workflows/test_bedrock_aws_profile.yml @@ -0,0 +1,37 @@ +name: test | bedrock | aws profile + +on: + workflow_call: + +jobs: + test-bedrock-aws-profile: + name: Run Bedrock AWS Profile Test + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Configure AWS Profile + run: | + mkdir -p ~/.aws + cat > ~/.aws/credentials << EOF + [bedrock-test] + aws_access_key_id = ${{ 
secrets.AWS_ACCESS_KEY_ID }} + aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }} + EOF + + - name: Run Bedrock AWS Profile Test + env: + LLM_PROVIDER: "bedrock" + LLM_MODEL: "us.anthropic.claude-3-5-haiku-20241022-v1:0" + AWS_PROFILE_NAME: "bedrock-test" + AWS_REGION_NAME: "us-east-1" + EMBEDDING_PROVIDER: "bedrock" + EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" + EMBEDDING_DIMENSIONS: "1024" + run: poetry run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_suites.yml b/.github/workflows/test_suites.yml index d34523ce1..78b878931 100644 --- a/.github/workflows/test_suites.yml +++ b/.github/workflows/test_suites.yml @@ -110,6 +110,24 @@ jobs: uses: ./.github/workflows/test_gemini.yml secrets: inherit + bedrock-tests: + name: Bedrock Tests + needs: [basic-tests, e2e-tests] + uses: ./.github/workflows/test_bedrock_api_key.yml + secrets: inherit + + bedrock-aws-credentials-tests: + name: Bedrock AWS Credentials Tests + needs: [basic-tests, e2e-tests] + uses: ./.github/workflows/test_bedrock_aws_credentials.yml + secrets: inherit + + bedrock-aws-profile-tests: + name: Bedrock AWS Profile Tests + needs: [basic-tests, e2e-tests] + uses: ./.github/workflows/test_bedrock_aws_profile.yml + secrets: inherit + # Ollama tests moved to the end ollama-tests: name: Ollama Tests @@ -143,6 +161,9 @@ jobs: db-examples-tests, mcp-test, gemini-tests, + bedrock-tests, + bedrock-aws-credentials-tests, + bedrock-aws-profile-tests, ollama-tests, relational-db-migration-tests, docker-compose-test, @@ -163,6 +184,9 @@ jobs: "${{ needs.db-examples-tests.result }}" == "success" && "${{ needs.relational-db-migration-tests.result }}" == "success" && "${{ needs.gemini-tests.result }}" == "success" && + "${{ needs.bedrock-tests.result }}" == "success" && + "${{ needs.bedrock-aws-credentials-tests.result }}" == "success" && + "${{ needs.bedrock-aws-profile-tests.result }}" == "success" && "${{ needs.docker-compose-test.result }}" == "success" && "${{ 
needs.docker-ci-test.result }}" == "success" && "${{ needs.ollama-tests.result }}" == "success" ]]; then diff --git a/README.md b/README.md index 3486d2ce9..70422bc6a 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,16 @@ os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY" ``` You can also set the variables by creating .env file, using our template. + +**Supported LLM Providers:** OpenAI (default), Anthropic, Gemini, Ollama, AWS Bedrock + +**For AWS Bedrock:** Set `LLM_PROVIDER="bedrock"` and use one of three authentication methods: +- API Key: `LLM_API_KEY="your_bedrock_api_key"` +- AWS Credentials: `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` (+ `AWS_SESSION_TOKEN` if needed) +- AWS Profile: `AWS_PROFILE_NAME="your_profile"` + +Use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_Example_5:~:text=Use%20an%20inference%20profile%20in%20model%20invocation) for the model IDs. This usually means appending `us.*` (or other region) to the model ID (e.g., `us.anthropic.claude-3-5-sonnet-20241022-v2:0`). See [AWS Bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html). + To use different LLM providers, for more info check out our documentation diff --git a/cognee-starter-kit/README.md b/cognee-starter-kit/README.md index c265e278e..5a9369b89 100644 --- a/cognee-starter-kit/README.md +++ b/cognee-starter-kit/README.md @@ -25,6 +25,14 @@ uv sync ## Setup LLM Add environment variables to `.env` file. In case you choose to use OpenAI provider, add just the model and api_key. + +**Supported LLM Providers:** +- OpenAI (default) +- Anthropic +- Gemini +- Ollama +- AWS Bedrock + ``` LLM_PROVIDER="" LLM_MODEL="" @@ -39,6 +47,36 @@ EMBEDDING_API_KEY="" EMBEDDING_API_VERSION="" ``` +**For AWS Bedrock, you have three authentication options:** + +1. 
**API Key (Bearer Token):** +``` +LLM_PROVIDER="bedrock" +LLM_API_KEY="your_bedrock_api_key" +LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0" +AWS_REGION_NAME="us-east-1" +``` + +2. **AWS Credentials:** +``` +LLM_PROVIDER="bedrock" +LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0" +AWS_ACCESS_KEY_ID="your_aws_access_key" +AWS_SECRET_ACCESS_KEY="your_aws_secret_key" +[if needed] AWS_SESSION_TOKEN="your_session_token" +AWS_REGION_NAME="us-east-1" +``` + +3. **AWS Profile:** +``` +LLM_PROVIDER="bedrock" +LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0" +AWS_PROFILE_NAME="your_aws_profile" +AWS_REGION_NAME="us-east-1" +``` + +**Note:** For Bedrock models, use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_Example_5:~:text=Use%20an%20inference%20profile%20in%20model%20invocation) for `LLM_MODEL`. This usually means appending `us.*` (or other region) to the model ID (e.g., `us.anthropic.claude-3-5-sonnet-20241022-v2:0`). See [AWS Bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for available models. + Activate the Python environment: ``` source .venv/bin/activate diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 7daaaf1dd..01c58a134 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -127,7 +127,7 @@ async def add( - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.) 
Optional: - - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama" + - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "bedrock" - LLM_MODEL: Model name (default: "gpt-4o-mini") - DEFAULT_USER_EMAIL: Custom default user email - DEFAULT_USER_PASSWORD: Custom default user password diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 66ce48cc2..c13d1c366 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -154,7 +154,7 @@ async def search( - LLM_API_KEY: API key for your LLM provider Optional: - - LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses + - LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses (supports: openai, anthropic, gemini, ollama, bedrock) - VECTOR_DB_PROVIDER: Must match what was used during cognify - GRAPH_DATABASE_PROVIDER: Must match what was used during cognify diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index f31aada33..a6acb647d 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -27,6 +27,12 @@ class LLMConfig(BaseSettings): - embedding_rate_limit_enabled - embedding_rate_limit_requests - embedding_rate_limit_interval + - aws_access_key_id (Bedrock) + - aws_secret_access_key (Bedrock) + - aws_session_token (Bedrock) + - aws_region_name (Bedrock) + - aws_profile_name (Bedrock) + - aws_bedrock_runtime_endpoint (Bedrock) Public methods include: - ensure_env_vars_for_ollama @@ -63,6 +69,14 @@ class LLMConfig(BaseSettings): fallback_endpoint: str = "" fallback_model: str = "" + # AWS Bedrock configuration + aws_access_key_id: Optional[str] = None + aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None + aws_region_name: str = "us-east-1" + aws_profile_name: Optional[str] = None + aws_bedrock_runtime_endpoint: Optional[str] = None + baml_registry: ClassVar[ClientRegistry] = ClientRegistry() model_config = 
SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py new file mode 100644 index 000000000..6fb964a82 --- /dev/null +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py @@ -0,0 +1,6 @@ +"""Bedrock LLM adapter module.""" + +from .adapter import BedrockAdapter + +__all__ = ["BedrockAdapter"] + diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py new file mode 100644 index 000000000..868fe51b8 --- /dev/null +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py @@ -0,0 +1,161 @@ +import litellm +import instructor +from typing import Type, Optional +from pydantic import BaseModel +from litellm.exceptions import ContentPolicyViolationError +from instructor.exceptions import InstructorRetryException + +from cognee.exceptions import InvalidValueError +from cognee.infrastructure.llm.LLMGateway import LLMGateway +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( + LLMInterface, +) +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( + rate_limit_async, + rate_limit_sync, + sleep_and_retry_async, + sleep_and_retry_sync, +) +from cognee.modules.observability.get_observe import get_observe + +observe = get_observe() + + +class BedrockAdapter(LLMInterface): + """ + Adapter for AWS Bedrock API with support for three authentication methods: + 1. API Key (Bearer Token) + 2. 
AWS Credentials (access key + secret key) + 3. AWS Profile (boto3 credential chain) + """ + + name = "Bedrock" + model: str + api_key: str + aws_access_key_id: str + aws_secret_access_key: str + aws_region_name: str + aws_profile_name: str + + MAX_RETRIES = 5 + + def __init__( + self, + model: str, + api_key: str = None, + aws_access_key_id: str = None, + aws_secret_access_key: str = None, + aws_session_token: str = None, + aws_region_name: str = "us-east-1", + aws_profile_name: str = None, + aws_bedrock_runtime_endpoint: str = None, + max_tokens: int = 16384, + streaming: bool = False, + ): + self.aclient = instructor.from_litellm(litellm.acompletion) + self.client = instructor.from_litellm(litellm.completion) + self.model = model + self.api_key = api_key + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name + self.aws_bedrock_runtime_endpoint = aws_bedrock_runtime_endpoint + self.max_tokens = max_tokens + self.streaming = streaming + + def _create_bedrock_request( + self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + ) -> dict: + """Create Bedrock request with authentication and enhanced JSON formatting.""" + enhanced_system_prompt = f"""{system_prompt} + +IMPORTANT: You must respond with valid JSON only. Do not include any text before or after the JSON. 
The response must be a valid JSON object that can be parsed directly.""" + + request_params = { + "model": self.model, + "custom_llm_provider": "bedrock", + "drop_params": True, + "messages": [ + {"role": "user", "content": text_input}, + {"role": "system", "content": enhanced_system_prompt}, + ], + "response_model": response_model, + "max_retries": self.MAX_RETRIES, + "max_tokens": self.max_tokens, + "stream": self.streaming, + } + + # Add authentication parameters + if self.api_key: + request_params["api_key"] = self.api_key + elif self.aws_access_key_id and self.aws_secret_access_key: + request_params["aws_access_key_id"] = self.aws_access_key_id + request_params["aws_secret_access_key"] = self.aws_secret_access_key + if self.aws_session_token: + request_params["aws_session_token"] = self.aws_session_token + elif self.aws_profile_name: + request_params["aws_profile_name"] = self.aws_profile_name + + # Add optional parameters + if self.aws_region_name: + request_params["aws_region_name"] = self.aws_region_name + if self.aws_bedrock_runtime_endpoint: + request_params["aws_bedrock_runtime_endpoint"] = self.aws_bedrock_runtime_endpoint + + return request_params + + @observe(as_type="generation") + @sleep_and_retry_async() + @rate_limit_async + async def acreate_structured_output( + self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + ) -> BaseModel: + """Generate structured output from AWS Bedrock API.""" + + try: + request_params = self._create_bedrock_request(text_input, system_prompt, response_model) + return await self.aclient.chat.completions.create(**request_params) + + except ( + ContentPolicyViolationError, + InstructorRetryException, + ) as error: + if ( + isinstance(error, InstructorRetryException) + and "content management policy" not in str(error).lower() + ): + raise error + + raise ContentPolicyFilterError( + f"The provided input contains content that is not aligned with our content policy: {text_input}" + ) + + @observe + 
@sleep_and_retry_sync() + @rate_limit_sync + def create_structured_output( + self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + ) -> BaseModel: + """Generate structured output from AWS Bedrock API (synchronous).""" + + request_params = self._create_bedrock_request(text_input, system_prompt, response_model) + return self.client.chat.completions.create(**request_params) + + def show_prompt(self, text_input: str, system_prompt: str) -> str: + """Format and display the prompt for a user query.""" + if not text_input: + text_input = "No user input provided." + if not system_prompt: + raise InvalidValueError(message="No system prompt path provided.") + system_prompt = LLMGateway.read_query_prompt(system_prompt) + + formatted_prompt = ( + f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" + if system_prompt + else None + ) + return formatted_prompt diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 22d101077..0ade7a292 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -20,6 +20,7 @@ class LLMProvider(Enum): - ANTHROPIC: Represents the Anthropic provider. - CUSTOM: Represents a custom provider option. - GEMINI: Represents the Gemini provider. + - BEDROCK: Represents the AWS Bedrock provider. 
""" OPENAI = "openai" @@ -27,6 +28,7 @@ class LLMProvider(Enum): ANTHROPIC = "anthropic" CUSTOM = "custom" GEMINI = "gemini" + BEDROCK = "bedrock" def get_llm_client(): @@ -137,5 +139,23 @@ def get_llm_client(): streaming=llm_config.llm_streaming, ) + elif provider == LLMProvider.BEDROCK: + from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import ( + BedrockAdapter, + ) + + return BedrockAdapter( + model=llm_config.llm_model, + api_key=llm_config.llm_api_key, + aws_access_key_id=llm_config.aws_access_key_id, + aws_secret_access_key=llm_config.aws_secret_access_key, + aws_session_token=llm_config.aws_session_token, + aws_region_name=llm_config.aws_region_name, + aws_profile_name=llm_config.aws_profile_name, + aws_bedrock_runtime_endpoint=llm_config.aws_bedrock_runtime_endpoint, + max_tokens=max_tokens, + streaming=llm_config.llm_streaming, + ) + else: raise InvalidValueError(message=f"Unsupported LLM provider: {provider}") diff --git a/cognee/modules/settings/get_settings.py b/cognee/modules/settings/get_settings.py index fa7dfc2df..71017350f 100644 --- a/cognee/modules/settings/get_settings.py +++ b/cognee/modules/settings/get_settings.py @@ -15,10 +15,11 @@ class ModelName(Enum): ollama = "ollama" anthropic = "anthropic" gemini = "gemini" + bedrock = "bedrock" class LLMConfig(BaseModel): - api_key: str + api_key: Optional[str] model: str provider: str endpoint: Optional[str] @@ -72,6 +73,10 @@ def get_settings() -> SettingsDict: "value": "gemini", "label": "Gemini", }, + { + "value": "bedrock", + "label": "AWS Bedrock", + }, ] return SettingsDict.model_validate( @@ -134,6 +139,20 @@ def get_settings() -> SettingsDict: "label": "Gemini 2.0 Flash", }, ], + "bedrock": [ + { + "value": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "label": "Claude 3.5 Sonnet", + }, + { + "value": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "label": "Claude 3.5 Haiku", + }, + { + "value": 
"us.anthropic.claude-3-5-sonnet-20240620-v1:0", + "label": "Claude 3.5 Sonnet (June)", + }, + ], }, }, vector_db={ diff --git a/pyproject.toml b/pyproject.toml index 61076e86b..e6ad0eff4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "sqlalchemy>=2.0.39,<3.0.0", "aiosqlite>=0.20.0,<1.0.0", "tiktoken>=0.8.0,<1.0.0", - "litellm>=1.71.0, <2.0.0", + "litellm>=1.76.0, <2.0.0", "instructor>=1.9.1,<2.0.0", "langfuse>=2.32.0,<3", "filetype>=1.2.0,<2.0.0", diff --git a/uv.lock b/uv.lock index 137263188..963c6805e 100644 --- a/uv.lock +++ b/uv.lock @@ -1060,7 +1060,7 @@ requires-dist = [ { name = "langfuse", specifier = ">=2.32.0,<3" }, { name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" }, { name = "limits", specifier = ">=4.4.1,<5" }, - { name = "litellm", specifier = ">=1.71.0,<2.0.0" }, + { name = "litellm", specifier = ">=1.76.0,<2.0.0" }, { name = "llama-index-core", marker = "extra == 'llama-index'", specifier = ">=0.12.11,<0.13" }, { name = "matplotlib", specifier = ">=3.8.3,<4" }, { name = "mistral-common", marker = "extra == 'mistral'", specifier = ">=1.5.2,<2" }, @@ -3552,7 +3552,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.75.8" +version = "1.76.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3567,9 +3567,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/4e/48e3d6de19afe713223e3bc7009a2003501420de2a5d823c569cefbd9731/litellm-1.75.8.tar.gz", hash = "sha256:92061bd263ff8c33c8fff70ba92cd046adb7ea041a605826a915d108742fe59e", size = 10140384, upload-time = "2025-08-16T21:42:24.23Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/e8/47c791c3d2cb4397ddece90840aecfc6cdc4a003f039dde42d7c861f4709/litellm-1.76.0.tar.gz", hash = "sha256:d26d12333135edd72af60e0e310284dac3b079f4d7c47c79dfbb2430b9b4b421", size = 10170569, upload-time = 
"2025-08-24T05:14:01.176Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/82/c4d00fbeafd93c00dab6ea03f33cadd6a97adeb720ba1d89fc319e5cb10b/litellm-1.75.8-py3-none-any.whl", hash = "sha256:0bf004488df8506381ec6e35e1486e2870e8d578a7c3f2427cd497558ce07a2e", size = 8916305, upload-time = "2025-08-16T21:42:21.387Z" }, + { url = "https://files.pythonhosted.org/packages/86/f2/891b4b6c09021046d7f5bcff57178b18f352a67b032d35cc693d79b38620/litellm-1.76.0-py3-none-any.whl", hash = "sha256:357464242fc1eeda384810c9e334e48ad67a50ecd30cf61e86c15f89e2f2e0b4", size = 8953112, upload-time = "2025-08-24T05:13:58.642Z" }, ] [[package]] From c91d1ff0aed90e66073f5a2a284cb2d21237eb23 Mon Sep 17 00:00:00 2001 From: xdurawa Date: Wed, 3 Sep 2025 01:34:21 -0400 Subject: [PATCH 02/31] Remove documentation changes as requested by reviewers - Reverted README.md to original state - Reverted cognee-starter-kit/README.md to original state - Documentation will be updated separately by maintainers --- README.md | 10 ---------- cognee-starter-kit/README.md | 38 ------------------------------------ 2 files changed, 48 deletions(-) diff --git a/README.md b/README.md index 73c6aa898..e618d5bf9 100644 --- a/README.md +++ b/README.md @@ -125,16 +125,6 @@ os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY" ``` You can also set the variables by creating .env file, using our template. - -**Supported LLM Providers:** OpenAI (default), Anthropic, Gemini, Ollama, AWS Bedrock - -**For AWS Bedrock:** Set `LLM_PROVIDER="bedrock"` and use one of three authentication methods: -- API Key: `LLM_API_KEY="your_bedrock_api_key"` -- AWS Credentials: `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` (+ `AWS_SESSION_TOKEN` if needed) -- AWS Profile: `AWS_PROFILE_NAME="your_profile"` - -Use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_Example_5:~:text=Use%20an%20inference%20profile%20in%20model%20invocation) for the model IDs. 
This usually means appending `us.*` (or other region) to the model ID (e.g., `us.anthropic.claude-3-5-sonnet-20241022-v2:0`). See [AWS Bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html). - To use different LLM providers, for more info check out our documentation diff --git a/cognee-starter-kit/README.md b/cognee-starter-kit/README.md index 5a9369b89..c265e278e 100644 --- a/cognee-starter-kit/README.md +++ b/cognee-starter-kit/README.md @@ -25,14 +25,6 @@ uv sync ## Setup LLM Add environment variables to `.env` file. In case you choose to use OpenAI provider, add just the model and api_key. - -**Supported LLM Providers:** -- OpenAI (default) -- Anthropic -- Gemini -- Ollama -- AWS Bedrock - ``` LLM_PROVIDER="" LLM_MODEL="" @@ -47,36 +39,6 @@ EMBEDDING_API_KEY="" EMBEDDING_API_VERSION="" ``` -**For AWS Bedrock, you have three authentication options:** - -1. **API Key (Bearer Token):** -``` -LLM_PROVIDER="bedrock" -LLM_API_KEY="your_bedrock_api_key" -LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0" -AWS_REGION_NAME="us-east-1" -``` - -2. **AWS Credentials:** -``` -LLM_PROVIDER="bedrock" -LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0" -AWS_ACCESS_KEY_ID="your_aws_access_key" -AWS_SECRET_ACCESS_KEY="your_aws_secret_key" -[if needed] AWS_SESSION_TOKEN="your_session_token" -AWS_REGION_NAME="us-east-1" -``` - -3. **AWS Profile:** -``` -LLM_PROVIDER="bedrock" -LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0" -AWS_PROFILE_NAME="your_aws_profile" -AWS_REGION_NAME="us-east-1" -``` - -**Note:** For Bedrock models, use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_Example_5:~:text=Use%20an%20inference%20profile%20in%20model%20invocation) for `LLM_MODEL`. This usually means appending `us.*` (or other region) to the model ID (e.g., `us.anthropic.claude-3-5-sonnet-20241022-v2:0`). 
See [AWS Bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for available models. - Activate the Python environment: ``` source .venv/bin/activate From 0a4b1068a253df8fb4e39a93ee18a73c911ee49e Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 17 Nov 2025 17:42:22 +0100 Subject: [PATCH 03/31] feat: add kwargs to openai adapter functions --- .../litellm_instructor/llm/openai/adapter.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 305b426b8..152f43e33 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -108,7 +108,7 @@ class OpenAIAdapter(LLMInterface): reraise=True, ) async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a response from a user query. 
@@ -149,6 +149,7 @@ class OpenAIAdapter(LLMInterface): api_version=self.api_version, response_model=response_model, max_retries=self.MAX_RETRIES, + **kwargs, ) except ( ContentFilterFinishReasonError, @@ -174,6 +175,7 @@ class OpenAIAdapter(LLMInterface): # api_base=self.fallback_endpoint, response_model=response_model, max_retries=self.MAX_RETRIES, + **kwargs, ) except ( ContentFilterFinishReasonError, @@ -199,7 +201,7 @@ class OpenAIAdapter(LLMInterface): reraise=True, ) def create_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a response from a user query. @@ -239,6 +241,7 @@ class OpenAIAdapter(LLMInterface): api_version=self.api_version, response_model=response_model, max_retries=self.MAX_RETRIES, + **kwargs, ) @retry( @@ -248,7 +251,7 @@ class OpenAIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input): + async def create_transcript(self, input, **kwargs): """ Generate an audio transcript from a user query. @@ -275,6 +278,7 @@ class OpenAIAdapter(LLMInterface): api_base=self.endpoint, api_version=self.api_version, max_retries=self.MAX_RETRIES, + **kwargs, ) return transcription @@ -286,7 +290,7 @@ class OpenAIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def transcribe_image(self, input) -> BaseModel: + async def transcribe_image(self, input, **kwargs) -> BaseModel: """ Generate a transcription of an image from a user query. 
@@ -331,4 +335,5 @@ class OpenAIAdapter(LLMInterface): api_version=self.api_version, max_completion_tokens=300, max_retries=self.MAX_RETRIES, + **kwargs, ) From 3b78eb88bd4bc778089f1061408cb413b5e7ff20 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 24 Nov 2025 16:38:23 +0100 Subject: [PATCH 04/31] fix: use s3 config --- cognee/api/v1/search/search.py | 2 +- .../infrastructure/files/storage/s3_config.py | 3 ++ cognee/infrastructure/llm/config.py | 14 ------- .../litellm_instructor/llm/bedrock/adapter.py | 41 +++++++------------ .../litellm_instructor/llm/get_llm_client.py | 8 +--- 5 files changed, 19 insertions(+), 49 deletions(-) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index e64bcb848..49f7aee51 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -161,7 +161,7 @@ async def search( - LLM_API_KEY: API key for your LLM provider Optional: - - LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses (supports: openai, anthropic, gemini, ollama, bedrock) + - LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses - VECTOR_DB_PROVIDER: Must match what was used during cognify - GRAPH_DATABASE_PROVIDER: Must match what was used during cognify diff --git a/cognee/infrastructure/files/storage/s3_config.py b/cognee/infrastructure/files/storage/s3_config.py index 0b9372b7e..4cc6b1d63 100644 --- a/cognee/infrastructure/files/storage/s3_config.py +++ b/cognee/infrastructure/files/storage/s3_config.py @@ -8,6 +8,9 @@ class S3Config(BaseSettings): aws_endpoint_url: Optional[str] = None aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None + aws_profile_name: Optional[str] = None + aws_bedrock_runtime_endpoint: Optional[str] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 091f8e6ea..7aa8f33f7 100644 --- 
a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -27,12 +27,6 @@ class LLMConfig(BaseSettings): - embedding_rate_limit_enabled - embedding_rate_limit_requests - embedding_rate_limit_interval - - aws_access_key_id (Bedrock) - - aws_secret_access_key (Bedrock) - - aws_session_token (Bedrock) - - aws_region_name (Bedrock) - - aws_profile_name (Bedrock) - - aws_bedrock_runtime_endpoint (Bedrock) Public methods include: - ensure_env_vars_for_ollama @@ -71,14 +65,6 @@ class LLMConfig(BaseSettings): fallback_endpoint: str = "" fallback_model: str = "" - # AWS Bedrock configuration - aws_access_key_id: Optional[str] = None - aws_secret_access_key: Optional[str] = None - aws_session_token: Optional[str] = None - aws_region_name: str = "us-east-1" - aws_profile_name: Optional[str] = None - aws_bedrock_runtime_endpoint: Optional[str] = None - baml_registry: ClassVar[ClientRegistry] = ClientRegistry() model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py index 868fe51b8..66f484164 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py @@ -11,6 +11,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll LLMInterface, ) from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.files.storage.s3_config import get_s3_config from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( rate_limit_async, @@ -34,10 +35,6 @@ class BedrockAdapter(LLMInterface): name = "Bedrock" model: str api_key: 
str - aws_access_key_id: str - aws_secret_access_key: str - aws_region_name: str - aws_profile_name: str MAX_RETRIES = 5 @@ -45,12 +42,6 @@ class BedrockAdapter(LLMInterface): self, model: str, api_key: str = None, - aws_access_key_id: str = None, - aws_secret_access_key: str = None, - aws_session_token: str = None, - aws_region_name: str = "us-east-1", - aws_profile_name: str = None, - aws_bedrock_runtime_endpoint: str = None, max_tokens: int = 16384, streaming: bool = False, ): @@ -58,12 +49,6 @@ class BedrockAdapter(LLMInterface): self.client = instructor.from_litellm(litellm.completion) self.model = model self.api_key = api_key - self.aws_access_key_id = aws_access_key_id - self.aws_secret_access_key = aws_secret_access_key - self.aws_session_token = aws_session_token - self.aws_region_name = aws_region_name - self.aws_profile_name = aws_profile_name - self.aws_bedrock_runtime_endpoint = aws_bedrock_runtime_endpoint self.max_tokens = max_tokens self.streaming = streaming @@ -89,22 +74,24 @@ IMPORTANT: You must respond with valid JSON only. 
Do not include any text before "stream": self.streaming, } + s3_config = get_s3_config() + # Add authentication parameters if self.api_key: request_params["api_key"] = self.api_key - elif self.aws_access_key_id and self.aws_secret_access_key: - request_params["aws_access_key_id"] = self.aws_access_key_id - request_params["aws_secret_access_key"] = self.aws_secret_access_key - if self.aws_session_token: - request_params["aws_session_token"] = self.aws_session_token - elif self.aws_profile_name: - request_params["aws_profile_name"] = self.aws_profile_name + elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key: + request_params["aws_access_key_id"] = s3_config.aws_access_key_id + request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key + if s3_config.aws_session_token: + request_params["aws_session_token"] = s3_config.aws_session_token + elif s3_config.aws_profile_name: + request_params["aws_profile_name"] = s3_config.aws_profile_name # Add optional parameters - if self.aws_region_name: - request_params["aws_region_name"] = self.aws_region_name - if self.aws_bedrock_runtime_endpoint: - request_params["aws_bedrock_runtime_endpoint"] = self.aws_bedrock_runtime_endpoint + if s3_config.aws_region_name: + request_params["aws_region_name"] = s3_config.aws_region_name + if s3_config.aws_bedrock_runtime_endpoint: + request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint return request_params diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 946698a95..489f7ae8e 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -156,13 +156,7 @@ def get_llm_client(): return BedrockAdapter( 
model=llm_config.llm_model, api_key=llm_config.llm_api_key, - aws_access_key_id=llm_config.aws_access_key_id, - aws_secret_access_key=llm_config.aws_secret_access_key, - aws_session_token=llm_config.aws_session_token, - aws_region_name=llm_config.aws_region_name, - aws_profile_name=llm_config.aws_profile_name, - aws_bedrock_runtime_endpoint=llm_config.aws_bedrock_runtime_endpoint, - max_tokens=max_tokens, + max_tokens=max_completion_tokens, streaming=llm_config.llm_streaming, ) From e0d48c043a1594c567135a1e73cd209c1d07eba1 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 25 Nov 2025 12:58:07 +0100 Subject: [PATCH 05/31] fix: fixes to adapter and tests --- .github/workflows/test_bedrock_api_key.yml | 28 ------ .../test_bedrock_aws_credentials.yml | 29 ------ .../workflows/test_bedrock_aws_profile.yml | 37 -------- .github/workflows/test_llms.yml | 88 +++++++++++++++++++ .github/workflows/test_suites.yml | 24 ----- .../llm/bedrock/__init__.py | 1 - .../litellm_instructor/llm/bedrock/adapter.py | 35 ++++---- .../litellm_instructor/llm/get_llm_client.py | 7 +- cognee/modules/settings/get_settings.py | 12 +-- 9 files changed, 117 insertions(+), 144 deletions(-) delete mode 100644 .github/workflows/test_bedrock_api_key.yml delete mode 100644 .github/workflows/test_bedrock_aws_credentials.yml delete mode 100644 .github/workflows/test_bedrock_aws_profile.yml diff --git a/.github/workflows/test_bedrock_api_key.yml b/.github/workflows/test_bedrock_api_key.yml deleted file mode 100644 index 3f5ea94b3..000000000 --- a/.github/workflows/test_bedrock_api_key.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: test | bedrock | api key - -on: - workflow_call: - -jobs: - test-bedrock-api-key: - name: Run Bedrock API Key Test - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Bedrock API Key Test - env: - LLM_PROVIDER: "bedrock" 
- LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }} - LLM_MODEL: "us.anthropic.claude-3-5-sonnet-20241022-v2:0" - AWS_REGION_NAME: "us-east-1" - EMBEDDING_PROVIDER: "bedrock" - EMBEDDING_MODEL: "amazon.titan-embed-text-v1" - EMBEDDING_DIMENSIONS: "1536" - run: poetry run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_bedrock_aws_credentials.yml b/.github/workflows/test_bedrock_aws_credentials.yml deleted file mode 100644 index c086dceb3..000000000 --- a/.github/workflows/test_bedrock_aws_credentials.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: test | bedrock | aws credentials - -on: - workflow_call: - -jobs: - test-bedrock-aws-credentials: - name: Run Bedrock AWS Credentials Test - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Bedrock AWS Credentials Test - env: - LLM_PROVIDER: "bedrock" - LLM_MODEL: "us.anthropic.claude-3-5-sonnet-20240620-v1:0" - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - AWS_REGION_NAME: "us-east-1" - EMBEDDING_PROVIDER: "cohere" - EMBEDDING_MODEL: "cohere.embed-english-v3" - EMBEDDING_DIMENSIONS: "1024" - run: poetry run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_bedrock_aws_profile.yml b/.github/workflows/test_bedrock_aws_profile.yml deleted file mode 100644 index aa15074e1..000000000 --- a/.github/workflows/test_bedrock_aws_profile.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: test | bedrock | aws profile - -on: - workflow_call: - -jobs: - test-bedrock-aws-profile: - name: Run Bedrock AWS Profile Test - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Configure AWS Profile - run: | - mkdir -p ~/.aws - cat > 
~/.aws/credentials << EOF - [bedrock-test] - aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }} - aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }} - EOF - - - name: Run Bedrock AWS Profile Test - env: - LLM_PROVIDER: "bedrock" - LLM_MODEL: "us.anthropic.claude-3-5-haiku-20241022-v1:0" - AWS_PROFILE_NAME: "bedrock-test" - AWS_REGION_NAME: "us-east-1" - EMBEDDING_PROVIDER: "bedrock" - EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" - EMBEDDING_DIMENSIONS: "1024" - run: poetry run python ./examples/python/simple_example.py diff --git a/.github/workflows/test_llms.yml b/.github/workflows/test_llms.yml index 6b0221309..0cbbc7b3a 100644 --- a/.github/workflows/test_llms.yml +++ b/.github/workflows/test_llms.yml @@ -84,3 +84,91 @@ jobs: EMBEDDING_DIMENSIONS: "3072" EMBEDDING_MAX_TOKENS: "8191" run: uv run python ./examples/python/simple_example.py + + test-bedrock-api-key: + name: Run Bedrock API Key Test + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "aws" + + - name: Run Bedrock API Key Test + env: + LLM_PROVIDER: "bedrock" + LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }} + LLM_MODEL: "eu.amazon.nova-lite-v1:0" + LLM_MAX_TOKENS: "16384" + AWS_REGION_NAME: "eu-west-1" + EMBEDDING_PROVIDER: "bedrock" + EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }} + EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" + EMBEDDING_DIMENSIONS: "1024" + EMBEDDING_MAX_TOKENS: "8191" + run: poetry run python ./examples/python/simple_example.py + + test-bedrock-aws-credentials: + name: Run Bedrock AWS Credentials Test + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "aws" + + - name: Run Bedrock API Key Test + env: + LLM_PROVIDER: "bedrock" + LLM_MODEL: 
"eu.amazon.nova-lite-v1:0" + LLM_MAX_TOKENS: "16384" + AWS_REGION_NAME: "eu-west-1" + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + EMBEDDING_PROVIDER: "bedrock" + EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }} + EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" + EMBEDDING_DIMENSIONS: "1024" + EMBEDDING_MAX_TOKENS: "8191" + run: poetry run python ./examples/python/simple_example.py + + test-bedrock-aws-profile: + name: Run Bedrock AWS Profile Test + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Configure AWS Profile + run: | + mkdir -p ~/.aws + cat > ~/.aws/credentials << EOF + [bedrock-test] + aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }} + aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }} + EOF + + - name: Run Bedrock AWS Profile Test + env: + LLM_PROVIDER: "bedrock" + LLM_MODEL: "eu.amazon.nova-lite-v1:0" + AWS_PROFILE_NAME: "bedrock-test" + AWS_REGION_NAME: "eu-west-1" + EMBEDDING_PROVIDER: "bedrock" + EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" + EMBEDDING_DIMENSIONS: "1024" + EMBEDDING_MAX_TOKENS: "8191" + run: poetry run python ./examples/python/simple_example.py \ No newline at end of file diff --git a/.github/workflows/test_suites.yml b/.github/workflows/test_suites.yml index 9f2767faf..be1e354fc 100644 --- a/.github/workflows/test_suites.yml +++ b/.github/workflows/test_suites.yml @@ -139,24 +139,6 @@ jobs: uses: ./.github/workflows/test_llms.yml secrets: inherit - bedrock-tests: - name: Bedrock Tests - needs: [basic-tests, e2e-tests] - uses: ./.github/workflows/test_bedrock_api_key.yml - secrets: inherit - - bedrock-aws-credentials-tests: - name: Bedrock AWS Credentials Tests - needs: [basic-tests, e2e-tests] - uses: ./.github/workflows/test_bedrock_aws_credentials.yml - secrets: inherit - - 
bedrock-aws-profile-tests: - name: Bedrock AWS Profile Tests - needs: [basic-tests, e2e-tests] - uses: ./.github/workflows/test_bedrock_aws_profile.yml - secrets: inherit - # Ollama tests moved to the end ollama-tests: name: Ollama Tests @@ -193,9 +175,6 @@ jobs: db-examples-tests, mcp-test, llm-tests, - bedrock-tests, - bedrock-aws-credentials-tests, - bedrock-aws-profile-tests, ollama-tests, relational-db-migration-tests, docker-compose-test, @@ -218,9 +197,6 @@ jobs: "${{ needs.db-examples-tests.result }}" == "success" && "${{ needs.relational-db-migration-tests.result }}" == "success" && "${{ needs.llm-tests.result }}" == "success" && - "${{ needs.bedrock-tests.result }}" == "success" && - "${{ needs.bedrock-aws-credentials-tests.result }}" == "success" && - "${{ needs.bedrock-aws-profile-tests.result }}" == "success" && "${{ needs.docker-compose-test.result }}" == "success" && "${{ needs.docker-ci-test.result }}" == "success" && "${{ needs.ollama-tests.result }}" == "success" ]]; then diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py index 6fb964a82..ad7cdf994 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py @@ -3,4 +3,3 @@ from .adapter import BedrockAdapter __all__ = ["BedrockAdapter"] - diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py index 66f484164..c461a0886 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py @@ -1,18 +1,19 @@ 
import litellm import instructor -from typing import Type, Optional +from typing import Type from pydantic import BaseModel from litellm.exceptions import ContentPolicyViolationError from instructor.exceptions import InstructorRetryException -from cognee.exceptions import InvalidValueError from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) -from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.llm.exceptions import ( + ContentPolicyFilterError, + MissingSystemPromptPathError, +) from cognee.infrastructure.files.storage.s3_config import get_s3_config -from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( rate_limit_async, rate_limit_sync, @@ -35,6 +36,7 @@ class BedrockAdapter(LLMInterface): name = "Bedrock" model: str api_key: str + default_instructor_mode = "json_schema_mode" MAX_RETRIES = 5 @@ -42,23 +44,23 @@ class BedrockAdapter(LLMInterface): self, model: str, api_key: str = None, - max_tokens: int = 16384, + max_completion_tokens: int = 16384, streaming: bool = False, + instructor_mode: str = None, ): - self.aclient = instructor.from_litellm(litellm.acompletion) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + + self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode(self.instructor_mode)) self.client = instructor.from_litellm(litellm.completion) self.model = model self.api_key = api_key - self.max_tokens = max_tokens + self.max_completion_tokens = max_completion_tokens self.streaming = streaming def _create_bedrock_request( self, text_input: str, system_prompt: str, response_model: Type[BaseModel] ) -> dict: - """Create Bedrock request with authentication and enhanced JSON formatting.""" - 
enhanced_system_prompt = f"""{system_prompt} - -IMPORTANT: You must respond with valid JSON only. Do not include any text before or after the JSON. The response must be a valid JSON object that can be parsed directly.""" + """Create Bedrock request with authentication.""" request_params = { "model": self.model, @@ -66,11 +68,11 @@ IMPORTANT: You must respond with valid JSON only. Do not include any text before "drop_params": True, "messages": [ {"role": "user", "content": text_input}, - {"role": "system", "content": enhanced_system_prompt}, + {"role": "system", "content": system_prompt}, ], "response_model": response_model, "max_retries": self.MAX_RETRIES, - "max_tokens": self.max_tokens, + "max_completion_tokens": self.max_completion_tokens, "stream": self.streaming, } @@ -87,9 +89,10 @@ IMPORTANT: You must respond with valid JSON only. Do not include any text before elif s3_config.aws_profile_name: request_params["aws_profile_name"] = s3_config.aws_profile_name + if s3_config.aws_region: + request_params["aws_region_name"] = s3_config.aws_region + # Add optional parameters - if s3_config.aws_region_name: - request_params["aws_region_name"] = s3_config.aws_region_name if s3_config.aws_bedrock_runtime_endpoint: request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint @@ -137,7 +140,7 @@ IMPORTANT: You must respond with valid JSON only. Do not include any text before if not text_input: text_input = "No user input provided." 
if not system_prompt: - raise InvalidValueError(message="No system prompt path provided.") + raise MissingSystemPromptPathError() system_prompt = LLMGateway.read_query_prompt(system_prompt) formatted_prompt = ( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 086fd84de..954d85c1d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -172,8 +172,8 @@ def get_llm_client(raise_api_key_error: bool = True): ) elif provider == LLMProvider.BEDROCK: - if llm_config.llm_api_key is None and raise_api_key_error: - raise LLMAPIKeyNotSetError() + # if llm_config.llm_api_key is None and raise_api_key_error: + # raise LLMAPIKeyNotSetError() from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import ( BedrockAdapter, @@ -182,8 +182,9 @@ def get_llm_client(raise_api_key_error: bool = True): return BedrockAdapter( model=llm_config.llm_model, api_key=llm_config.llm_api_key, - max_tokens=max_completion_tokens, + max_completion_tokens=max_completion_tokens, streaming=llm_config.llm_streaming, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) else: diff --git a/cognee/modules/settings/get_settings.py b/cognee/modules/settings/get_settings.py index 7e58e981f..37093bb35 100644 --- a/cognee/modules/settings/get_settings.py +++ b/cognee/modules/settings/get_settings.py @@ -164,16 +164,16 @@ def get_settings() -> SettingsDict: ], "bedrock": [ { - "value": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", - "label": "Claude 3.5 Sonnet", + "value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0", + "label": "Claude 4.5 Sonnet", }, { - "value": "us.anthropic.claude-3-5-haiku-20241022-v1:0", - "label": "Claude 3.5 Haiku", + 
"value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0", + "label": "Claude 4.5 Haiku", }, { - "value": "us.anthropic.claude-3-5-sonnet-20240620-v1:0", - "label": "Claude 3.5 Sonnet (June)", + "value": "eu.amazon.nova-lite-v1:0", + "label": "Amazon Nova Lite", }, ], }, From 4c6bed885e04d1a97367493de276cc781bbfe8f4 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 25 Nov 2025 13:02:26 +0100 Subject: [PATCH 06/31] chore: ruff format --- .../litellm_instructor/llm/bedrock/adapter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py index c461a0886..1faec2d0b 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py @@ -50,7 +50,9 @@ class BedrockAdapter(LLMInterface): ): self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode(self.instructor_mode)) + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) + ) self.client = instructor.from_litellm(litellm.completion) self.model = model self.api_key = api_key From 9e652a3a935fecf7d973e9478bff2237df1cf09d Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 25 Nov 2025 13:34:18 +0100 Subject: [PATCH 07/31] fix: use uv instead of poetry in CI tests --- .github/workflows/test_llms.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_llms.yml b/.github/workflows/test_llms.yml index 0cbbc7b3a..cc21dc97b 100644 --- a/.github/workflows/test_llms.yml +++ b/.github/workflows/test_llms.yml @@ -110,7 +110,7 @@ jobs: EMBEDDING_MODEL: 
"amazon.titan-embed-text-v2:0" EMBEDDING_DIMENSIONS: "1024" EMBEDDING_MAX_TOKENS: "8191" - run: poetry run python ./examples/python/simple_example.py + run: uv run python ./examples/python/simple_example.py test-bedrock-aws-credentials: name: Run Bedrock AWS Credentials Test @@ -138,7 +138,7 @@ jobs: EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" EMBEDDING_DIMENSIONS: "1024" EMBEDDING_MAX_TOKENS: "8191" - run: poetry run python ./examples/python/simple_example.py + run: uv run python ./examples/python/simple_example.py test-bedrock-aws-profile: name: Run Bedrock AWS Profile Test @@ -171,4 +171,4 @@ jobs: EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0" EMBEDDING_DIMENSIONS: "1024" EMBEDDING_MAX_TOKENS: "8191" - run: poetry run python ./examples/python/simple_example.py \ No newline at end of file + run: uv run python ./examples/python/simple_example.py \ No newline at end of file From 7c5a17ecb5b54a2ac591e254c4a64f74113d7e5a Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 26 Nov 2025 11:02:36 +0100 Subject: [PATCH 08/31] test: add extra dependency to bedrock ci test --- .github/workflows/test_llms.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test_llms.yml b/.github/workflows/test_llms.yml index cc21dc97b..bb9d1792a 100644 --- a/.github/workflows/test_llms.yml +++ b/.github/workflows/test_llms.yml @@ -151,6 +151,7 @@ jobs: uses: ./.github/actions/cognee_setup with: python-version: '3.11.x' + extra-dependencies: "aws" - name: Configure AWS Profile run: | From 02b9fa485cb7cd74ca4a35e8b6bc8a0521c6c164 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 26 Nov 2025 12:33:22 +0100 Subject: [PATCH 09/31] fix: remove random addition to pyproject file --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index add0b1603..a9b895dfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,6 @@ exclude = [ "/dist", "/.data", "/.github", - "/alembic", "/deployment", "/cognee-mcp", 
"/cognee-frontend", From 700362a2332d4381c3f4b3dfe686ecbb8dc9a9d1 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 26 Nov 2025 13:35:56 +0100 Subject: [PATCH 10/31] fix: fix model names and test names --- .github/workflows/test_llms.yml | 13 +++++++------ cognee/modules/settings/get_settings.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_llms.yml b/.github/workflows/test_llms.yml index bb9d1792a..8f9d30d10 100644 --- a/.github/workflows/test_llms.yml +++ b/.github/workflows/test_llms.yml @@ -98,11 +98,11 @@ jobs: python-version: '3.11.x' extra-dependencies: "aws" - - name: Run Bedrock API Key Test + - name: Run Bedrock API Key Simple Example env: LLM_PROVIDER: "bedrock" LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }} - LLM_MODEL: "eu.amazon.nova-lite-v1:0" + LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0" LLM_MAX_TOKENS: "16384" AWS_REGION_NAME: "eu-west-1" EMBEDDING_PROVIDER: "bedrock" @@ -125,10 +125,10 @@ jobs: python-version: '3.11.x' extra-dependencies: "aws" - - name: Run Bedrock API Key Test + - name: Run Bedrock AWS Credentials Simple Example env: LLM_PROVIDER: "bedrock" - LLM_MODEL: "eu.amazon.nova-lite-v1:0" + LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0" LLM_MAX_TOKENS: "16384" AWS_REGION_NAME: "eu-west-1" AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} @@ -162,10 +162,11 @@ jobs: aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }} EOF - - name: Run Bedrock AWS Profile Test + - name: Run Bedrock AWS Profile Simple Example env: LLM_PROVIDER: "bedrock" - LLM_MODEL: "eu.amazon.nova-lite-v1:0" + LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0" + LLM_MAX_TOKENS: "16384" AWS_PROFILE_NAME: "bedrock-test" AWS_REGION_NAME: "eu-west-1" EMBEDDING_PROVIDER: "bedrock" diff --git a/cognee/modules/settings/get_settings.py b/cognee/modules/settings/get_settings.py index 37093bb35..4132ba048 100644 --- a/cognee/modules/settings/get_settings.py +++ 
b/cognee/modules/settings/get_settings.py @@ -164,7 +164,7 @@ def get_settings() -> SettingsDict: ], "bedrock": [ { - "value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0", + "value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", "label": "Claude 4.5 Sonnet", }, { From aa8afefe8a7ae4233e82edc71ee9441f0b68d325 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 27 Nov 2025 17:05:37 +0100 Subject: [PATCH 11/31] feat: add kwargs to cognify and related tasks --- cognee/api/v1/cognify/cognify.py | 4 ++++ cognee/infrastructure/llm/LLMGateway.py | 4 ++-- .../llm/extraction/knowledge_graph/extract_content_graph.py | 4 ++-- cognee/tasks/graph/extract_graph_from_data.py | 3 ++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 0fa345176..bb2ebe86e 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -53,6 +53,7 @@ async def cognify( custom_prompt: Optional[str] = None, temporal_cognify: bool = False, data_per_batch: int = 20, + **kwargs ): """ Transform ingested data into a structured knowledge graph. @@ -224,6 +225,7 @@ async def cognify( config=config, custom_prompt=custom_prompt, chunks_per_batch=chunks_per_batch, + **kwargs, ) # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for @@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's config: Config = None, custom_prompt: Optional[str] = None, chunks_per_batch: int = 100, + **kwargs, ) -> list[Task]: if config is None: ontology_config = get_ontology_env_config() @@ -286,6 +289,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's config=config, custom_prompt=custom_prompt, task_config={"batch_size": chunks_per_batch}, + **kwargs, ), # Generate knowledge graphs from the document chunks. 
Task( summarize_text, diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py index ab5bb35d7..fd42eb55e 100644 --- a/cognee/infrastructure/llm/LLMGateway.py +++ b/cognee/infrastructure/llm/LLMGateway.py @@ -11,7 +11,7 @@ class LLMGateway: @staticmethod def acreate_structured_output( - text_input: str, system_prompt: str, response_model: Type[BaseModel] + text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> Coroutine: llm_config = get_llm_config() if llm_config.structured_output_framework.upper() == "BAML": @@ -31,7 +31,7 @@ class LLMGateway: llm_client = get_llm_client() return llm_client.acreate_structured_output( - text_input=text_input, system_prompt=system_prompt, response_model=response_model + text_input=text_input, system_prompt=system_prompt, response_model=response_model, **kwargs ) @staticmethod diff --git a/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py index 59e6f563a..4a40979f4 100644 --- a/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +++ b/cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py @@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import ( async def extract_content_graph( - content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None + content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs ): if custom_prompt: system_prompt = custom_prompt @@ -30,7 +30,7 @@ async def extract_content_graph( system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory) content_graph = await LLMGateway.acreate_structured_output( - content, system_prompt, response_model + content, system_prompt, response_model, **kwargs ) return content_graph diff --git a/cognee/tasks/graph/extract_graph_from_data.py 
b/cognee/tasks/graph/extract_graph_from_data.py index 49b51af2d..965214677 100644 --- a/cognee/tasks/graph/extract_graph_from_data.py +++ b/cognee/tasks/graph/extract_graph_from_data.py @@ -99,6 +99,7 @@ async def extract_graph_from_data( graph_model: Type[BaseModel], config: Config = None, custom_prompt: Optional[str] = None, + **kwargs, ) -> List[DocumentChunk]: """ Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model. @@ -113,7 +114,7 @@ async def extract_graph_from_data( chunk_graphs = await asyncio.gather( *[ - extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt) + extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs) for chunk in data_chunks ] ) From af8c5bedcc48e18c3723a2fbfa8afba3de242cbb Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 17:47:23 +0100 Subject: [PATCH 12/31] feat: add kwargs to other adapters --- .../litellm_instructor/llm/anthropic/adapter.py | 2 +- .../litellm_instructor/llm/gemini/adapter.py | 2 +- .../litellm_instructor/llm/generic_llm_api/adapter.py | 2 +- .../litellm_instructor/llm/mistral/adapter.py | 2 +- .../litellm_instructor/llm/ollama/adapter.py | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index dbf0dfbea..46e2b2736 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -51,7 +51,7 @@ class AnthropicAdapter(LLMInterface): reraise=True, ) async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: 
Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a response from a user query. diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 226f291d7..66d53b842 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -79,7 +79,7 @@ class GeminiAdapter(LLMInterface): reraise=True, ) async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a response from a user query. diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 9d7f25fc5..3049b3c4f 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -79,7 +79,7 @@ class GenericAPIAdapter(LLMInterface): reraise=True, ) async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a response from a user query. 
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 355cdae0b..146d0a07a 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -68,7 +68,7 @@ class MistralAdapter(LLMInterface): reraise=True, ) async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a response from the user query. diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index aabd19867..5ae09a4ac 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -74,7 +74,7 @@ class OllamaAPIAdapter(LLMInterface): reraise=True, ) async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs ) -> BaseModel: """ Generate a structured output from the LLM using the provided text and system prompt. @@ -121,7 +121,7 @@ class OllamaAPIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input_file: str) -> str: + async def create_transcript(self, input_file: str, **kwargs) -> str: """ Generate an audio transcript from a user query. 
@@ -160,7 +160,7 @@ class OllamaAPIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def transcribe_image(self, input_file: str) -> str: + async def transcribe_image(self, input_file: str, **kwargs) -> str: """ Transcribe content from an image using base64 encoding. From 248ba74592fe4697d16746f38c382505e7e30685 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 18:18:42 +0100 Subject: [PATCH 13/31] test: remove codify-related stuff from mcp test --- cognee-mcp/src/test_client.py | 110 +--------------------------------- 1 file changed, 2 insertions(+), 108 deletions(-) diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index 23160d8b2..93a040e38 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -3,7 +3,7 @@ Test client for Cognee MCP Server functionality. This script tests all the tools and functions available in the Cognee MCP server, -including cognify, codify, search, prune, status checks, and utility functions. +including cognify, search, prune, status checks, and utility functions. 
Usage: # Set your OpenAI API key first @@ -35,7 +35,7 @@ from src.server import ( load_class, ) -# Set timeout for cognify/codify to complete in +# Set timeout for cognify to complete in TIMEOUT = 5 * 60 # 5 min in seconds @@ -151,12 +151,9 @@ DEBUG = True expected_tools = { "cognify", - "codify", "search", "prune", "cognify_status", - "codify_status", - "cognee_add_developer_rules", "list_data", "delete", } @@ -247,106 +244,6 @@ DEBUG = True } print(f"❌ {test_name} test failed: {e}") - async def test_codify(self): - """Test the codify functionality using MCP client.""" - print("\n🧪 Testing codify functionality...") - try: - async with self.mcp_server_session() as session: - codify_result = await session.call_tool( - "codify", arguments={"repo_path": self.test_repo_dir} - ) - - start = time.time() # mark the start - while True: - try: - # Wait a moment - await asyncio.sleep(5) - - # Check if codify processing is finished - status_result = await session.call_tool("codify_status", arguments={}) - if hasattr(status_result, "content") and status_result.content: - status_text = ( - status_result.content[0].text - if status_result.content - else str(status_result) - ) - else: - status_text = str(status_result) - - if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text: - break - elif time.time() - start > TIMEOUT: - raise TimeoutError("Codify did not complete in 5min") - except DatabaseNotCreatedError: - if time.time() - start > TIMEOUT: - raise TimeoutError("Database was not created in 5min") - - self.test_results["codify"] = { - "status": "PASS", - "result": codify_result, - "message": "Codify executed successfully", - } - print("✅ Codify test passed") - - except Exception as e: - self.test_results["codify"] = { - "status": "FAIL", - "error": str(e), - "message": "Codify test failed", - } - print(f"❌ Codify test failed: {e}") - - async def test_cognee_add_developer_rules(self): - """Test the cognee_add_developer_rules functionality using MCP client.""" - 
print("\n🧪 Testing cognee_add_developer_rules functionality...") - try: - async with self.mcp_server_session() as session: - result = await session.call_tool( - "cognee_add_developer_rules", arguments={"base_path": self.test_data_dir} - ) - - start = time.time() # mark the start - while True: - try: - # Wait a moment - await asyncio.sleep(5) - - # Check if developer rule cognify processing is finished - status_result = await session.call_tool("cognify_status", arguments={}) - if hasattr(status_result, "content") and status_result.content: - status_text = ( - status_result.content[0].text - if status_result.content - else str(status_result) - ) - else: - status_text = str(status_result) - - if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text: - break - elif time.time() - start > TIMEOUT: - raise TimeoutError( - "Cognify of developer rules did not complete in 5min" - ) - except DatabaseNotCreatedError: - if time.time() - start > TIMEOUT: - raise TimeoutError("Database was not created in 5min") - - self.test_results["cognee_add_developer_rules"] = { - "status": "PASS", - "result": result, - "message": "Developer rules addition executed successfully", - } - print("✅ Developer rules test passed") - - except Exception as e: - self.test_results["cognee_add_developer_rules"] = { - "status": "FAIL", - "error": str(e), - "message": "Developer rules test failed", - } - print(f"❌ Developer rules test failed: {e}") - async def test_search_functionality(self): """Test the search functionality with different search types using MCP client.""" print("\n🧪 Testing search functionality...") @@ -681,9 +578,6 @@ class TestModel: test_name="Cognify2", ) - await self.test_codify() - await self.test_cognee_add_developer_rules() - # Test list_data and delete functionality await self.test_list_data() await self.test_delete() From 0f50c993ac502da598db0f77ce2bcd510b79ce22 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 18:20:07 +0100 Subject: [PATCH 
14/31] chore: add quick option to isolate mcp CI test --- .github/workflows/test_mcp.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test_mcp.yml b/.github/workflows/test_mcp.yml index 09c8db69d..d667ed5f9 100644 --- a/.github/workflows/test_mcp.yml +++ b/.github/workflows/test_mcp.yml @@ -1,6 +1,9 @@ name: test | mcp on: + push: + branches: + - feature/cog-3543-remove-anything-codify-related-from-mcp-test workflow_call: jobs: From e211e66275919d6d56bd1dadfc7dc3e536fe10cc Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 18:29:17 +0100 Subject: [PATCH 15/31] chore: remove quick option to isolate mcp CI test --- .github/workflows/test_mcp.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/test_mcp.yml b/.github/workflows/test_mcp.yml index d667ed5f9..09c8db69d 100644 --- a/.github/workflows/test_mcp.yml +++ b/.github/workflows/test_mcp.yml @@ -1,9 +1,6 @@ name: test | mcp on: - push: - branches: - - feature/cog-3543-remove-anything-codify-related-from-mcp-test workflow_call: jobs: From 3b8a607b5fa524dd77fea2fe157ef01f0313159f Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 12 Dec 2025 11:37:27 +0100 Subject: [PATCH 16/31] test: fix errors in mcp test --- .github/workflows/test_mcp.yml | 3 +++ cognee-mcp/src/test_client.py | 18 ++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test_mcp.yml b/.github/workflows/test_mcp.yml index 09c8db69d..d667ed5f9 100644 --- a/.github/workflows/test_mcp.yml +++ b/.github/workflows/test_mcp.yml @@ -1,6 +1,9 @@ name: test | mcp on: + push: + branches: + - feature/cog-3543-remove-anything-codify-related-from-mcp-test workflow_call: jobs: diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index 93a040e38..c4e5b0573 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -256,7 +256,7 @@ DEBUG = True # Go through all Cognee search types for search_type in 
SearchType: # Don't test these search types - if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]: + if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER, SearchType.TRIPLET_COMPLETION]: break try: async with self.mcp_server_session() as session: @@ -420,15 +420,13 @@ DEBUG = True if invalid_result.content and len(invalid_result.content) > 0: invalid_content = invalid_result.content[0].text - if "Invalid UUID format" in invalid_content: - self.test_results["delete_error_handling"] = { - "status": "PASS", - "result": invalid_content, - "message": "delete error handling works correctly", - } - print("✅ delete error handling test passed") - else: - raise Exception(f"Expected UUID error not found: {invalid_content}") + assert("Invalid UUID format" in invalid_content) + self.test_results["delete_error_handling"] = { + "status": "PASS", + "result": invalid_content, + "message": "delete error handling works correctly", + } + print("✅ delete error handling test passed") else: raise Exception("Delete error test returned no content") From c48b2745712fe0c02a5e0f6c1d148ba77b68ddfb Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 12 Dec 2025 11:53:40 +0100 Subject: [PATCH 17/31] test: remove delete error from mcp test --- cognee-mcp/src/test_client.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index c4e5b0573..b03fef0db 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -408,25 +408,24 @@ DEBUG = True else: # Test with invalid UUIDs to check error handling - invalid_result = await session.call_tool( - "delete", - arguments={ - "data_id": "invalid-uuid", - "dataset_id": "another-invalid-uuid", - "mode": "soft", - }, - ) - - if invalid_result.content and len(invalid_result.content) > 0: - invalid_content = invalid_result.content[0].text - - assert("Invalid UUID format" in invalid_content) + try: + 
await session.call_tool( + "delete", + arguments={ + "data_id": "invalid-uuid", + "dataset_id": "another-invalid-uuid", + "mode": "soft", + }, + ) + except Exception as e: + assert ("Invalid UUID format" in e) self.test_results["delete_error_handling"] = { "status": "PASS", - "result": invalid_content, + "result": e, "message": "delete error handling works correctly", } print("✅ delete error handling test passed") + else: raise Exception("Delete error test returned no content") From bce6094010ef53af2130ed5ca045df7b5971ed8c Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 12 Dec 2025 12:43:54 +0100 Subject: [PATCH 18/31] test: change logger --- cognee-mcp/src/test_client.py | 42 +++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index b03fef0db..5ef8a8be4 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -23,6 +23,7 @@ import tempfile import time from contextlib import asynccontextmanager from cognee.shared.logging_utils import setup_logging +from logging import ERROR, INFO from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -407,27 +408,32 @@ DEBUG = True raise Exception("Delete returned no content") else: + logger = setup_logging(log_level=INFO) # Test with invalid UUIDs to check error handling - try: - await session.call_tool( - "delete", - arguments={ - "data_id": "invalid-uuid", - "dataset_id": "another-invalid-uuid", - "mode": "soft", - }, - ) - except Exception as e: - assert ("Invalid UUID format" in e) - self.test_results["delete_error_handling"] = { - "status": "PASS", - "result": e, - "message": "delete error handling works correctly", - } - print("✅ delete error handling test passed") + invalid_result = await session.call_tool( + "delete", + arguments={ + "data_id": "invalid-uuid", + "dataset_id": "another-invalid-uuid", + "mode": "soft", + }, + ) + if 
invalid_result.content and len(invalid_result.content) > 0: + invalid_content = invalid_result.content[0].text + + if "Invalid UUID format" in invalid_content: + self.test_results["delete_error_handling"] = { + "status": "PASS", + "result": invalid_content, + "message": "delete error handling works correctly", + } + print("✅ delete error handling test passed") + else: + raise Exception(f"Expected UUID error not found: {invalid_content}") else: raise Exception("Delete error test returned no content") + logger = setup_logging(log_level=ERROR) except Exception as e: self.test_results["delete"] = { @@ -630,7 +636,5 @@ async def main(): if __name__ == "__main__": - from logging import ERROR - logger = setup_logging(log_level=ERROR) asyncio.run(main()) From a337f4e54ca323e7c4dc060dadf7b78efe7b63f3 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 12 Dec 2025 13:02:55 +0100 Subject: [PATCH 19/31] test: testing logger --- cognee-mcp/src/test_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index 5ef8a8be4..ed779a106 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -408,7 +408,6 @@ DEBUG = True raise Exception("Delete returned no content") else: - logger = setup_logging(log_level=INFO) # Test with invalid UUIDs to check error handling invalid_result = await session.call_tool( "delete", @@ -433,7 +432,6 @@ DEBUG = True raise Exception(f"Expected UUID error not found: {invalid_content}") else: raise Exception("Delete error test returned no content") - logger = setup_logging(log_level=ERROR) except Exception as e: self.test_results["delete"] = { @@ -636,5 +634,5 @@ async def main(): if __name__ == "__main__": - logger = setup_logging(log_level=ERROR) + logger = setup_logging(log_level=INFO) asyncio.run(main()) From a225d7fc61afb86f0cc7303bdd7869f6d5730907 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 12 Dec 2025 13:44:58 +0100 
Subject: [PATCH 20/31] test: revert some changes --- .github/workflows/test_mcp.yml | 3 --- cognee-mcp/src/test_client.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/test_mcp.yml b/.github/workflows/test_mcp.yml index d667ed5f9..09c8db69d 100644 --- a/.github/workflows/test_mcp.yml +++ b/.github/workflows/test_mcp.yml @@ -1,9 +1,6 @@ name: test | mcp on: - push: - branches: - - feature/cog-3543-remove-anything-codify-related-from-mcp-test workflow_call: jobs: diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index ed779a106..06c1fceeb 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -634,5 +634,5 @@ async def main(): if __name__ == "__main__": - logger = setup_logging(log_level=INFO) + logger = setup_logging(log_level=ERROR) asyncio.run(main()) From 116b6f1eeb4cc6e068abadc8b42bee31dff40609 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 12 Dec 2025 13:46:16 +0100 Subject: [PATCH 21/31] chore: formatting --- cognee-mcp/src/test_client.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index 06c1fceeb..bce7f807f 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -257,7 +257,11 @@ DEBUG = True # Go through all Cognee search types for search_type in SearchType: # Don't test these search types - if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER, SearchType.TRIPLET_COMPLETION]: + if search_type in [ + SearchType.NATURAL_LANGUAGE, + SearchType.CYPHER, + SearchType.TRIPLET_COMPLETION, + ]: break try: async with self.mcp_server_session() as session: From 14ff94f269599140df6e830761ef3b6f2c99eb28 Mon Sep 17 00:00:00 2001 From: Pavel Zorin Date: Thu, 11 Dec 2025 12:38:19 +0100 Subject: [PATCH 22/31] Initial release pipeline --- .github/workflows/release.yml | 154 ++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create 
mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..a19635628 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,154 @@ +name: release.yml +on: + workflow_dispatch: + inputs: + flavour: + required: true + default: dev + type: choice + options: + - dev + - main + description: Dev or Main release + test_mode: + required: true + type: boolean + description: Aka Dry Run. If true, it won't affect public indices or repositories + +jobs: + release-github: + name: Create GitHub Release from ${{ inputs.flavour }} + outputs: + tag: ${{ steps.create_tag.outputs.tag }} + version: ${{ steps.create_tag.outputs.version }} + permissions: + contents: write + runs-on: ubuntu-latest + + steps: + - name: Check out ${{ inputs.flavour }} + uses: actions/checkout@v4 + with: + ref: ${{ inputs.flavour }} + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Create and push git tag + id: create_tag + env: + TEST_MODE: ${{ inputs.test_mode }} + run: | + VERSION="$(uv version --short)" + TAG="v${VERSION}" + + echo "Tag to create: ${TAG}" + + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + + echo "tag=${TAG}" >> "$GITHUB_OUTPUT" + echo "version=${VERSION}" >> "$GITHUB_OUTPUT" + + if [ "$TEST_MODE" = "false" ]; then + git tag "${TAG}" + git push origin "${TAG}" + else + echo "Test mode is enabled. Skipping tag creation and push." 
+ fi + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.create_tag.outputs.tag }} + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + release-pypi-package: + needs: release-github + name: Release PyPI Package from ${{ inputs.flavour }} + permissions: + contents: read + runs-on: ubuntu-latest + + steps: + - name: Check out ${{ inputs.flavour }} + uses: actions/checkout@v4 + with: + ref: ${{ inputs.flavour }} + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Install Python + run: uv python install + + - name: Install dependencies + run: uv sync --locked --all-extras + + - name: Build distributions + run: uv build + + - name: Publish ${{ inputs.flavour }} release to TestPyPI + if: ${{ !inputs.test_mode }} + env: + UV_PUBLISH_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }} + run: uv publish --publish-url https://test.pypi.org/legacy/ + + - name: Publish ${{ inputs.flavour }} release to PyPI + if: ${{ !inputs.test_mode }} + env: + UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }} + run: uv publish + + release-docker-image: + needs: release-github + name: Release Docker Image from ${{ inputs.flavour }} + permissions: + contents: read + runs-on: ubuntu-latest + + steps: + - name: Check out ${{ inputs.flavour }} + uses: actions/checkout@v4 + with: + ref: ${{ inputs.flavour }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Build and push Dev Docker Image + if: ${{ inputs.flavour == 'dev' }} + uses: docker/build-push-action@v5 + with: + context: . 
+ platforms: linux/amd64,linux/arm64 + push: ${{ !inputs.test_mode }} + tags: cognee/cognee:${{ needs.release-github.outputs.version }} + labels: | + version=${{ needs.release-github.outputs.version }} + flavour=${{ inputs.flavour }} + cache-from: type=registry,ref=cognee/cognee:buildcache + cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max + + - name: Build and push Main Docker Image + if: ${{ inputs.flavour == 'main' }} + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: ${{ !inputs.test_mode }} + tags: | + cognee/cognee:${{ needs.release-github.outputs.version }} + cognee/cognee:latest + labels: | + version=${{ needs.release-github.outputs.version }} + flavour=${{ inputs.flavour }} + cache-from: type=registry,ref=cognee/cognee:buildcache + cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max From a6bc27afaaeb901e5e771a84ca5e9ba2af473aba Mon Sep 17 00:00:00 2001 From: Pavel Zorin Date: Fri, 12 Dec 2025 17:31:54 +0100 Subject: [PATCH 23/31] Cleanup --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a19635628..ff2f809f3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -90,7 +90,7 @@ jobs: run: uv build - name: Publish ${{ inputs.flavour }} release to TestPyPI - if: ${{ !inputs.test_mode }} + if: ${{ inputs.test_mode }} env: UV_PUBLISH_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }} run: uv publish --publish-url https://test.pypi.org/legacy/ From c94225f5058990d99bacebddd5ad0883861ff478 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:30:22 +0100 Subject: [PATCH 24/31] fix: make ontology key an optional param in cognify (#1894) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Make ontology key optional in Swagger and None by default (it was 
"string" by default before change which was causing issues when running cognify endpoint) ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Documentation** * Enhanced API documentation with additional examples and validation metadata to improve request clarity and validation guidance. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- cognee/api/v1/cognify/routers/get_cognify_router.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 4f1497e3c..a499b3ca3 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -42,7 +42,9 @@ class CognifyPayloadDTO(InDTO): default="", description="Custom prompt for entity extraction and graph generation" ) ontology_key: Optional[List[str]] = Field( - default=None, description="Reference to one or more previously uploaded ontologies" + default=None, + examples=[[]], + description="Reference to one or more previously uploaded ontologies", ) From bad22ba26be2ef7539df73b9b03b93fe9c0a3994 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:45:35 +0100 Subject: [PATCH 25/31] chore: adds id generation to memify triplet embedding pipeline (#1895) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR adds id generation to the Triplet objects in triplet embedding memify pipeline. 
In some edge cases duplicated elements could have been ingested into the collection ## Type of Change - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit ## Release Notes * **Enhancements** * Relationship data now includes unique identifiers for improved tracking and data management capabilities. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- cognee/tasks/memify/get_triplet_datapoints.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cognee/tasks/memify/get_triplet_datapoints.py b/cognee/tasks/memify/get_triplet_datapoints.py index bfc02ec6a..764adfb63 100644 --- a/cognee/tasks/memify/get_triplet_datapoints.py +++ b/cognee/tasks/memify/get_triplet_datapoints.py @@ -1,5 +1,6 @@ from typing import AsyncGenerator, Dict, Any, List, Optional from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine +from cognee.modules.engine.utils import generate_node_id from cognee.shared.logging_utils import get_logger from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.infrastructure.engine import DataPoint @@ -155,7 +156,12 @@ def _process_single_triplet( embeddable_text = f"{start_node_text}-›{relationship_text}-›{end_node_text}".strip() - triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text) + relationship_name = relationship.get("relationship_name", "") + triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id)) + + triplet_obj = Triplet( + id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text + ) return triplet_obj, None From 14d9540d1b9d1aa3504baad0a026d7f92556c2e4 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:15:48 +0100 Subject: [PATCH 26/31] feat: Add database deletion on dataset delete (#1893) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description - Add support for database deletion when dataset is deleted - Simplify dataset handler usage in Cognee ## Type of Change - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] 
Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Bug Fixes** * Improved dataset deletion: stronger authorization checks and reliable removal of associated graph and vector storage. * **Tests** * Added end-to-end test to verify complete dataset deletion and cleanup of all related storage components. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- .github/workflows/e2e_tests.yml | 25 ++++++ cognee/api/v1/cognify/cognify.py | 2 +- .../datasets/routers/get_datasets_router.py | 6 +- .../databases/utils/__init__.py | 2 + .../get_graph_dataset_database_handler.py | 10 +++ .../get_vector_dataset_database_handler.py | 10 +++ ...esolve_dataset_database_connection_info.py | 34 ++++----- cognee/infrastructure/llm/LLMGateway.py | 5 +- cognee/modules/data/deletion/prune_system.py | 38 +++------- cognee/modules/data/methods/delete_dataset.py | 26 +++++++ cognee/tests/test_dataset_delete.py | 76 +++++++++++++++++++ 11 files changed, 183 insertions(+), 51 deletions(-) create mode 100644 cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py create mode 100644 cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py create mode 100644 cognee/tests/test_dataset_delete.py diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index cb69e9ef6..8cd62910c 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -237,6 +237,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_dataset_database_handler.py + test-dataset-database-deletion: + name: Test dataset database deletion in Cognee + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run dataset databases deletion test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python 
./cognee/tests/test_dataset_delete.py + test-permissions: name: Test permissions with different situations in Cognee runs-on: ubuntu-22.04 diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 9371f7ffd..ffc903d68 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -53,7 +53,7 @@ async def cognify( custom_prompt: Optional[str] = None, temporal_cognify: bool = False, data_per_batch: int = 20, - **kwargs + **kwargs, ): """ Transform ingested data into a structured knowledge graph. diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py index eff87b3af..ca738dfbe 100644 --- a/cognee/api/v1/datasets/routers/get_datasets_router.py +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter: }, ) - from cognee.modules.data.methods import get_dataset, delete_dataset + from cognee.modules.data.methods import delete_dataset - dataset = await get_dataset(user.id, dataset_id) + dataset = await get_authorized_existing_datasets([dataset_id], "delete", user) if dataset is None: raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.") - await delete_dataset(dataset) + await delete_dataset(dataset[0]) @router.delete( "/{dataset_id}/data/{data_id}", diff --git a/cognee/infrastructure/databases/utils/__init__.py b/cognee/infrastructure/databases/utils/__init__.py index f31d1e0dc..3907b4325 100644 --- a/cognee/infrastructure/databases/utils/__init__.py +++ b/cognee/infrastructure/databases/utils/__init__.py @@ -1,2 +1,4 @@ from .get_or_create_dataset_database import get_or_create_dataset_database from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info +from .get_graph_dataset_database_handler import get_graph_dataset_database_handler +from .get_vector_dataset_database_handler import get_vector_dataset_database_handler diff --git 
a/cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py b/cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py new file mode 100644 index 000000000..d88685b48 --- /dev/null +++ b/cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py @@ -0,0 +1,10 @@ +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict: + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler] + return handler diff --git a/cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py b/cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py new file mode 100644 index 000000000..5d1152c04 --- /dev/null +++ b/cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py @@ -0,0 +1,10 @@ +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict: + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler] + return handler diff --git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py index d33169642..561268eaf 100644 --- a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py @@ -1,24 +1,12 @@ +from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler 
import ( + get_graph_dataset_database_handler, +) +from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import ( + get_vector_dataset_database_handler, +) from cognee.modules.users.models.DatasetDatabase import DatasetDatabase -async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( - supported_dataset_database_handlers, - ) - - handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler] - return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) - - -async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( - supported_dataset_database_handlers, - ) - - handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler] - return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) - - async def resolve_dataset_database_connection_info( dataset_database: DatasetDatabase, ) -> DatasetDatabase: @@ -31,6 +19,12 @@ async def resolve_dataset_database_connection_info( Returns: DatasetDatabase instance with resolved connection info """ - dataset_database = await _get_vector_db_connection_info(dataset_database) - dataset_database = await _get_graph_db_connection_info(dataset_database) + vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database) + graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database) + dataset_database = await vector_dataset_database_handler[ + "handler_instance" + ].resolve_dataset_connection_info(dataset_database) + dataset_database = await graph_dataset_database_handler[ + "handler_instance" + ].resolve_dataset_connection_info(dataset_database) return 
dataset_database diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py index fd42eb55e..7bec9ca01 100644 --- a/cognee/infrastructure/llm/LLMGateway.py +++ b/cognee/infrastructure/llm/LLMGateway.py @@ -31,7 +31,10 @@ class LLMGateway: llm_client = get_llm_client() return llm_client.acreate_structured_output( - text_input=text_input, system_prompt=system_prompt, response_model=response_model, **kwargs + text_input=text_input, + system_prompt=system_prompt, + response_model=response_model, + **kwargs, ) @staticmethod diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py index 645e1a223..22a0fde5f 100644 --- a/cognee/modules/data/deletion/prune_system.py +++ b/cognee/modules/data/deletion/prune_system.py @@ -5,6 +5,10 @@ from cognee.context_global_variables import backend_access_control_enabled from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.utils import ( + get_graph_dataset_database_handler, + get_vector_dataset_database_handler, +) from cognee.shared.cache import delete_cache from cognee.modules.users.models import DatasetDatabase from cognee.shared.logging_utils import get_logger @@ -13,22 +17,13 @@ logger = get_logger() async def prune_graph_databases(): - async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict: - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( - supported_dataset_database_handlers, - ) - - handler = supported_dataset_database_handlers[ - dataset_database.graph_dataset_database_handler - ] - return await handler["handler_instance"].delete_dataset(dataset_database) - db_engine = get_relational_engine() try: - data = await 
db_engine.get_all_data_from_table("dataset_database") + dataset_databases = await db_engine.get_all_data_from_table("dataset_database") # Go through each dataset database and delete the graph database - for data_item in data: - await _prune_graph_db(data_item) + for dataset_database in dataset_databases: + handler = get_graph_dataset_database_handler(dataset_database) + await handler["handler_instance"].delete_dataset(dataset_database) except (OperationalError, EntityNotFoundError) as e: logger.debug( "Skipping pruning of graph DB. Error when accessing dataset_database table: %s", @@ -38,22 +33,13 @@ async def prune_graph_databases(): async def prune_vector_databases(): - async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict: - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( - supported_dataset_database_handlers, - ) - - handler = supported_dataset_database_handlers[ - dataset_database.vector_dataset_database_handler - ] - return await handler["handler_instance"].delete_dataset(dataset_database) - db_engine = get_relational_engine() try: - data = await db_engine.get_all_data_from_table("dataset_database") + dataset_databases = await db_engine.get_all_data_from_table("dataset_database") # Go through each dataset database and delete the vector database - for data_item in data: - await _prune_vector_db(data_item) + for dataset_database in dataset_databases: + handler = get_vector_dataset_database_handler(dataset_database) + await handler["handler_instance"].delete_dataset(dataset_database) except (OperationalError, EntityNotFoundError) as e: logger.debug( "Skipping pruning of vector DB. 
Error when accessing dataset_database table: %s", diff --git a/cognee/modules/data/methods/delete_dataset.py b/cognee/modules/data/methods/delete_dataset.py index ff20ff9e7..dea10e741 100644 --- a/cognee/modules/data/methods/delete_dataset.py +++ b/cognee/modules/data/methods/delete_dataset.py @@ -1,8 +1,34 @@ +from cognee.modules.users.models import DatasetDatabase +from sqlalchemy import select + from cognee.modules.data.models import Dataset +from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import ( + get_vector_dataset_database_handler, +) +from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import ( + get_graph_dataset_database_handler, +) from cognee.infrastructure.databases.relational import get_relational_engine async def delete_dataset(dataset: Dataset): db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + stmt = select(DatasetDatabase).where( + DatasetDatabase.dataset_id == dataset.id, + ) + dataset_database: DatasetDatabase = await session.scalar(stmt) + if dataset_database: + graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database) + vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database) + await graph_dataset_database_handler["handler_instance"].delete_dataset( + dataset_database + ) + await vector_dataset_database_handler["handler_instance"].delete_dataset( + dataset_database + ) + # TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well + # This blocks recreation of the dataset with the same name and data after deletion as + # it's marked as completed and will be just skipped even though it's empty. 
return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id) diff --git a/cognee/tests/test_dataset_delete.py b/cognee/tests/test_dataset_delete.py new file mode 100644 index 000000000..372945bdb --- /dev/null +++ b/cognee/tests/test_dataset_delete.py @@ -0,0 +1,76 @@ +import os +import asyncio +import pathlib +from uuid import UUID + +import cognee +from cognee.shared.logging_utils import setup_logging, ERROR +from cognee.modules.data.methods.delete_dataset import delete_dataset +from cognee.modules.data.methods.get_dataset import get_dataset +from cognee.modules.users.methods import get_default_user + + +async def main(): + # Set data and system directory paths + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_dataset_delete") + ).resolve() + ) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_dataset_delete") + ).resolve() + ) + cognee.config.system_root_directory(cognee_directory_path) + + # Create a clean slate for cognee -- reset data and system state + print("Resetting cognee data...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("Data reset complete.\n") + + # cognee knowledge graph will be created based on this text + text = """ + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. 
+ """ + + # Add the text, and make it available for cognify + await cognee.add(text, "nlp_dataset") + await cognee.add("Quantum computing is the study of quantum computers.", "quantum_dataset") + + # Use LLMs and cognee to create knowledge graph + ret_val = await cognee.cognify() + user = await get_default_user() + + for val in ret_val: + dataset_id = str(val) + vector_db_path = os.path.join( + cognee_directory_path, "databases", str(user.id), dataset_id + ".lance.db" + ) + graph_db_path = os.path.join( + cognee_directory_path, "databases", str(user.id), dataset_id + ".pkl" + ) + + # Check if databases are properly created and exist before deletion + assert os.path.exists(graph_db_path), "Graph database file not found." + assert os.path.exists(vector_db_path), "Vector database file not found." + + dataset = await get_dataset(user_id=user.id, dataset_id=UUID(dataset_id)) + await delete_dataset(dataset) + + # Confirm databases have been deleted + assert not os.path.exists(graph_db_path), "Graph database file found." + assert not os.path.exists(vector_db_path), "Vector database file found." + + +if __name__ == "__main__": + logger = setup_logging(log_level=ERROR) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) From 622f8fa79e459d4cec8000de0cbf704957405b05 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:30:35 +0100 Subject: [PATCH 27/31] chore: introduces 1 file upload in ontology endpoint (#1899) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR fixes the ontology upload endpoint by forcing 1 file upload at the time. Tests are adjusted in both server start and ontology endpoint unit test. API was tested. Do not merge it together with https://github.com/topoteretes/cognee/pull/1898 its either that or this one. 
## Type of Change - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **API Changes** * Ontology upload now accepts exactly one file per request; field renamed from "descriptions" to "description" and validated as a plain string. * Stricter form validation and tighter 400/500 error handling for malformed submissions. * **Tests** * Tests converted to real HTTP-style interactions using a shared test client and dependency overrides. * Payloads now use plain string fields; added coverage for single-file constraints and specific error responses. * **Style** * Minor formatting cleanups with no functional impact. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- .../ontologies/routers/get_ontology_router.py | 48 ++--- cognee/tests/test_cognee_server_start.py | 4 +- .../tests/unit/api/test_ontology_endpoint.py | 166 ++++++++---------- 3 files changed, 100 insertions(+), 118 deletions(-) diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py index ee31c683f..77667d88d 100644 --- a/cognee/api/v1/ontologies/routers/get_ontology_router.py +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException +from fastapi import APIRouter, File, Form, UploadFile, Depends, Request from fastapi.responses import JSONResponse from typing import Optional, List @@ -15,28 +15,25 @@ def get_ontology_router() -> APIRouter: @router.post("", response_model=dict) async def upload_ontology( + request: Request, ontology_key: str = Form(...), - ontology_file: List[UploadFile] = File(...), - descriptions: Optional[str] = Form(None), + ontology_file: UploadFile = File(...), + description: Optional[str] = Form(None), user: User = Depends(get_authenticated_user), ): """ - Upload ontology files with their respective keys for later use in cognify operations. - - Supports both single and multiple file uploads: - - Single file: ontology_key=["key"], ontology_file=[file] - - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2] + Upload a single ontology file for later use in cognify operations. ## Request Parameters - - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies - - **ontology_file** (List[UploadFile]): OWL format ontology files - - **descriptions** (Optional[str]): JSON array string of optional descriptions + - **ontology_key** (str): User-defined identifier for the ontology. + - **ontology_file** (UploadFile): Single OWL format ontology file + - **description** (Optional[str]): Optional description for the ontology. 
## Response - Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps. + Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp. ## Error Codes - - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded + - **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded - **500 Internal Server Error**: File system or processing errors """ send_telemetry( @@ -49,16 +46,22 @@ def get_ontology_router() -> APIRouter: ) try: - import json + # Enforce: exactly one uploaded file for "ontology_file" + form = await request.form() + uploaded_files = form.getlist("ontology_file") + if len(uploaded_files) != 1: + raise ValueError("Only one ontology_file is allowed") - ontology_keys = json.loads(ontology_key) - description_list = json.loads(descriptions) if descriptions else None + if ontology_key.strip().startswith(("[", "{")): + raise ValueError("ontology_key must be a string") + if description is not None and description.strip().startswith(("[", "{")): + raise ValueError("description must be a string") - if not isinstance(ontology_keys, list): - raise ValueError("ontology_key must be a JSON array") - - results = await ontology_service.upload_ontologies( - ontology_keys, ontology_file, user, description_list + result = await ontology_service.upload_ontology( + ontology_key=ontology_key, + file=ontology_file, + user=user, + description=description, ) return { @@ -70,10 +73,9 @@ def get_ontology_router() -> APIRouter: "uploaded_at": result.uploaded_at, "description": result.description, } - for result in results ] } - except (json.JSONDecodeError, ValueError) as e: + except ValueError as e: return JSONResponse(status_code=400, content={"error": str(e)}) except Exception as e: return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/cognee/tests/test_cognee_server_start.py 
b/cognee/tests/test_cognee_server_start.py index fece88240..a626088a3 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -148,8 +148,8 @@ class TestCogneeServerStart(unittest.TestCase): headers=headers, files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], data={ - "ontology_key": json.dumps([ontology_key]), - "description": json.dumps(["Test ontology"]), + "ontology_key": ontology_key, + "description": "Test ontology", }, ) self.assertEqual(ontology_response.status_code, 200) diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py index af3a4d90e..e072ceda8 100644 --- a/cognee/tests/unit/api/test_ontology_endpoint.py +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -1,17 +1,28 @@ import pytest import uuid from fastapi.testclient import TestClient -from unittest.mock import patch, Mock, AsyncMock +from unittest.mock import Mock from types import SimpleNamespace -import importlib from cognee.api.client import app +from cognee.modules.users.methods import get_authenticated_user -gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + +@pytest.fixture(scope="session") +def test_client(): + # Keep a single TestClient (and event loop) for the whole module. + # Re-creating TestClient repeatedly can break async DB connections (asyncpg loop mismatch). 
+ with TestClient(app) as c: + yield c @pytest.fixture -def client(): - return TestClient(app) +def client(test_client, mock_default_user): + async def override_get_authenticated_user(): + return mock_default_user + + app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user + yield test_client + app.dependency_overrides.pop(get_authenticated_user, None) @pytest.fixture @@ -32,12 +43,8 @@ def mock_default_user(): ) -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): +def test_upload_ontology_success(client): """Test successful ontology upload""" - import json - - mock_get_default_user.return_value = mock_default_user ontology_content = ( b"" ) @@ -46,7 +53,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use response = client.post( "/api/v1/ontologies", files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], - data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])}, + data={"ontology_key": unique_key, "description": "Test"}, ) assert response.status_code == 200 @@ -55,10 +62,8 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use assert "uploaded_at" in data["uploaded_ontologies"][0] -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): +def test_upload_ontology_invalid_file(client): """Test 400 response for non-.owl files""" - mock_get_default_user.return_value = mock_default_user unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" response = client.post( "/api/v1/ontologies", @@ -68,14 +73,10 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul assert response.status_code == 400 -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def 
test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): +def test_upload_ontology_missing_data(client): """Test 400 response for missing file or key""" - import json - - mock_get_default_user.return_value = mock_default_user # Missing file - response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])}) + response = client.post("/api/v1/ontologies", data={"ontology_key": "test"}) assert response.status_code == 400 # Missing key @@ -85,34 +86,25 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul assert response.status_code == 400 -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): - """Test behavior when default user is provided (no explicit authentication)""" - import json - +def test_upload_ontology_without_auth_header(client): + """Test behavior when no explicit authentication header is provided.""" unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" - mock_get_default_user.return_value = mock_default_user response = client.post( "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"", "application/xml"))], - data={"ontology_key": json.dumps([unique_key])}, + data={"ontology_key": unique_key}, ) - # The current system provides a default user when no explicit authentication is given - # This test verifies the system works with conditional authentication assert response.status_code == 200 data = response.json() assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key assert "uploaded_at" in data["uploaded_ontologies"][0] -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): - """Test uploading multiple ontology files in single request""" +def test_upload_multiple_ontologies_in_single_request_is_rejected(client): + """Uploading multiple ontology files in 
a single request should fail.""" import io - mock_get_default_user.return_value = mock_default_user - # Create mock files file1_content = b"" file2_content = b"" @@ -120,45 +112,34 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_ ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), ] - data = { - "ontology_key": '["vehicles", "manufacturers"]', - "descriptions": '["Base vehicles", "Car manufacturers"]', - } + data = {"ontology_key": "vehicles", "description": "Base vehicles"} response = client.post("/api/v1/ontologies", files=files, data=data) - assert response.status_code == 200 - result = response.json() - assert "uploaded_ontologies" in result - assert len(result["uploaded_ontologies"]) == 2 - assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles" - assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers" + assert response.status_code == 400 + assert "Only one ontology_file is allowed" in response.json()["error"] -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): - """Test that upload endpoint accepts array parameters""" +def test_upload_endpoint_rejects_array_style_fields(client): + """Array-style form values should be rejected (no backwards compatibility).""" import io import json - mock_get_default_user.return_value = mock_default_user file_content = b"" files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] data = { "ontology_key": json.dumps(["single_key"]), - "descriptions": json.dumps(["Single ontology"]), + "description": json.dumps(["Single ontology"]), } response = client.post("/api/v1/ontologies", files=files, data=data) - assert response.status_code == 200 - result = response.json() - assert 
result["uploaded_ontologies"][0]["ontology_key"] == "single_key" + assert response.status_code == 400 + assert "ontology_key must be a string" in response.json()["error"] -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): +def test_cognify_with_multiple_ontologies(client): """Test cognify endpoint accepts multiple ontology keys""" payload = { "datasets": ["test_dataset"], @@ -172,14 +153,11 @@ def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_de assert response.status_code in [200, 400, 409] # May fail for other reasons, not type -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): - """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" +def test_complete_multifile_workflow(client): + """Test workflow: upload ontologies one-by-one → cognify with multiple keys""" import io - import json - mock_get_default_user.return_value = mock_default_user - # Step 1: Upload multiple ontologies + # Step 1: Upload two ontologies (one-by-one) file1_content = b""" @@ -192,17 +170,21 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default """ - files = [ - ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), - ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), - ] - data = { - "ontology_key": json.dumps(["vehicles", "manufacturers"]), - "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]), - } + upload_response_1 = client.post( + "/api/v1/ontologies", + files=[("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml"))], + data={"ontology_key": "vehicles", "description": "Vehicle ontology"}, + ) + assert upload_response_1.status_code == 200 - upload_response = 
client.post("/api/v1/ontologies", files=files, data=data) - assert upload_response.status_code == 200 + upload_response_2 = client.post( + "/api/v1/ontologies", + files=[ + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")) + ], + data={"ontology_key": "manufacturers", "description": "Manufacturer ontology"}, + ) + assert upload_response_2.status_code == 200 # Step 2: Verify ontologies are listed list_response = client.get("/api/v1/ontologies") @@ -223,44 +205,42 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default assert cognify_response.status_code != 400 # Not a validation error -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): - """Test error handling for invalid multifile uploads""" +def test_upload_error_handling(client): + """Test error handling for invalid uploads (single-file endpoint).""" import io import json - # Test mismatched array lengths + # Array-style key should be rejected file_content = b"" files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))] data = { - "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file - "descriptions": json.dumps(["desc1"]), + "ontology_key": json.dumps(["key1", "key2"]), + "description": "desc1", } response = client.post("/api/v1/ontologies", files=files, data=data) assert response.status_code == 400 - assert "Number of keys must match number of files" in response.json()["error"] + assert "ontology_key must be a string" in response.json()["error"] - # Test duplicate keys - files = [ - ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")), - ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")), - ] - data = { - "ontology_key": json.dumps(["duplicate", "duplicate"]), - "descriptions": json.dumps(["desc1", "desc2"]), - } + # Duplicate key should be rejected + 
response_1 = client.post( + "/api/v1/ontologies", + files=[("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml"))], + data={"ontology_key": "duplicate", "description": "desc1"}, + ) + assert response_1.status_code == 200 - response = client.post("/api/v1/ontologies", files=files, data=data) - assert response.status_code == 400 - assert "Duplicate ontology keys not allowed" in response.json()["error"] + response_2 = client.post( + "/api/v1/ontologies", + files=[("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml"))], + data={"ontology_key": "duplicate", "description": "desc2"}, + ) + assert response_2.status_code == 400 + assert "already exists" in response_2.json()["error"] -@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): +def test_cognify_missing_ontology_key(client): """Test cognify with non-existent ontology key""" - mock_get_default_user.return_value = mock_default_user - payload = { "datasets": ["test_dataset"], "ontology_key": ["nonexistent_key"], From 67af8a7cb46f65c0075b0af5ea35f0607f026b9d Mon Sep 17 00:00:00 2001 From: Pavel Zorin Date: Mon, 15 Dec 2025 18:36:15 +0100 Subject: [PATCH 28/31] Bump version from 0.5.0.dev0 to 0.5.0.dev1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e4ed8a0d..cf2081d0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "cognee" -version = "0.5.0.dev0" +version = "0.5.0.dev1" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." 
authors = [ { name = "Vasilije Markovic" }, From 78028b819f0b9293ec60b5894c8e7155284c5fcd Mon Sep 17 00:00:00 2001 From: Pavel Zorin Date: Mon, 15 Dec 2025 18:42:02 +0100 Subject: [PATCH 29/31] update dev uv.lock --- uv.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uv.lock b/uv.lock index fccab8c40..884fb63be 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", @@ -946,7 +946,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.5.0.dev0" +version = "0.5.0.dev1" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From 4e8845c117ecf892c3f5554c94de4f9f1171b9ff Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:11:29 +0100 Subject: [PATCH 30/31] chore: retriever test reorganization + adding new tests (integration) (STEP 1) (#1881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR restructures/adds integration and unit tests for the retrieval module. 
-Old integration tests were updated and moved under unit tests + fixtures added -Added missing unit tests for all core retrieval business logic -Covered 100% of the core retrievers with tests -Minor changes (dead code deletion, typo fixed) ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Changes** * TripletRetriever now returns up to 5 results by default (was 1), providing richer context. * **Tests** * Reorganized test coverage: many unit tests removed and replaced with comprehensive integration tests across retrieval components (graph, chunks, RAG, summaries, temporal, triplets, structured output). * **Chores** * Simplified triplet formatting logic and removed debug output. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- cognee/modules/retrieval/triplet_retriever.py | 2 +- .../utils/brute_force_triplet_search.py | 18 - .../retrieval/test_chunks_retriever.py | 252 ++++++++ .../test_graph_completion_retriever.py | 268 ++++++++ ..._completion_retriever_context_extension.py | 226 +++++++ .../test_graph_completion_retriever_cot.py | 218 +++++++ .../test_rag_completion_retriever.py | 254 ++++++++ .../retrieval/test_structured_output.py} | 162 ++--- .../retrieval/test_summaries_retriever.py | 184 ++++++ .../retrieval/test_temporal_retriever.py | 306 +++++++++ .../retrieval/test_triplet_retriever.py | 35 + .../eval_framework/benchmark_adapters_test.py | 25 + .../eval_framework/corpus_builder_test.py | 37 +- .../retrieval/chunks_retriever_test.py | 201 ------ .../retrieval/conversation_history_test.py | 154 ----- ...letion_retriever_context_extension_test.py | 177 ----- .../graph_completion_retriever_cot_test.py | 170 ----- .../graph_completion_retriever_test.py | 223 ------- .../rag_completion_retriever_test.py | 205 ------ .../retrieval/summaries_retriever_test.py | 159 ----- .../retrieval/temporal_retriever_test.py | 224 ------- .../test_brute_force_triplet_search.py | 608 ------------------ .../retrieval/triplet_retriever_test.py | 83 --- 23 files changed, 1888 insertions(+), 2303 deletions(-) create mode 100644 cognee/tests/integration/retrieval/test_chunks_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py create mode 100644 cognee/tests/integration/retrieval/test_rag_completion_retriever.py rename cognee/tests/{unit/modules/retrieval/structured_output_test.py => integration/retrieval/test_structured_output.py} (65%) create mode 100644 cognee/tests/integration/retrieval/test_summaries_retriever.py create mode 100644 
cognee/tests/integration/retrieval/test_temporal_retriever.py delete mode 100644 cognee/tests/unit/modules/retrieval/chunks_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/conversation_history_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/summaries_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/temporal_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py delete mode 100644 cognee/tests/unit/modules/retrieval/triplet_retriever_test.py diff --git a/cognee/modules/retrieval/triplet_retriever.py b/cognee/modules/retrieval/triplet_retriever.py index d251d113a..b9d006312 100644 --- a/cognee/modules/retrieval/triplet_retriever.py +++ b/cognee/modules/retrieval/triplet_retriever.py @@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path - self.top_k = top_k if top_k is not None else 1 + self.top_k = top_k if top_k is not None else 5 self.system_prompt = system_prompt async def get_context(self, query: str) -> str: diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index bd412e0ca..a70fa661b 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -16,24 +16,6 @@ logger = get_logger(level=ERROR) def format_triplets(edges): - print("\n\n\n") - - def 
filter_attributes(obj, attributes): - """Helper function to filter out non-None properties, including nested dicts.""" - result = {} - for attr in attributes: - value = getattr(obj, attr, None) - if value is not None: - # If the value is a dict, extract relevant keys from it - if isinstance(value, dict): - nested_values = { - k: v for k, v in value.items() if k in attributes and v is not None - } - result[attr] = nested_values - else: - result[attr] = value - return result - triplets = [] for edge in edges: node1 = edge.node1 diff --git a/cognee/tests/integration/retrieval/test_chunks_retriever.py b/cognee/tests/integration/retrieval/test_chunks_retriever.py new file mode 100644 index 000000000..d2e5e6149 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_chunks_retriever.py @@ -0,0 +1,252 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import List +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_simple(): + """Set up a clean test environment with simple chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / 
".cognee_system/test_chunks_retriever_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_complex(): + """Set up a clean test environment with complex chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + 
raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def 
test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple): + """Integration test: verify ChunksRetriever can retrieve multiple chunks.""" + retriever = ChunksRetriever() + + context = await retriever.get_context("Steve") + + assert isinstance(context, list), "Context should be a list" + assert len(context) > 0, "Context should not be empty" + assert any(chunk["text"] == "Steve Rodger" for chunk in context), ( + "Failed to get Steve Rodger chunk" + ) + + +@pytest.mark.asyncio +async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex): + """Integration test: verify ChunksRetriever respects top_k parameter.""" + retriever = ChunksRetriever(top_k=2) + + context = await retriever.get_context("Employee") + + assert isinstance(context, list), "Context should be a list" + assert len(context) <= 2, "Should respect top_k limit" + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex): + """Integration test: verify ChunksRetriever can retrieve chunk context (complex).""" + retriever = ChunksRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify ChunksRetriever handles empty graph correctly.""" + retriever = ChunksRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) + + context = await retriever.get_context("Christina Mayer") + assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever.py 
b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py new file mode 100644 index 000000000..7367b353b --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py @@ -0,0 +1,268 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + description: str + + class Person(DataPoint): + name: str + description: str + works_for: Company + + company1 = Company(name="Figma", description="Figma is a company") + company2 = Company(name="Canva", description="Canvas is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + works_for=company1, + ) + person2 = Person( + name="Ike Loma", description="This is description about Ike Loma", works_for=company1 + ) + person3 = Person( + name="Jason Statham", + description="This is description about Jason Statham", + works_for=company1, + ) + person4 = Person( + name="Mike Broski", + description="This is description about Mike Broski", + works_for=company2, + ) + person5 = Person( + 
name="Christina Mayer", + description="This is description about Christina Mayer", + works_for=company2, + ) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina 
Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionRetriever can retrieve context (simple).""" + retriever = GraphCompletionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + # Ensure the top-level sections are present + assert "Nodes:" in context, "Missing 'Nodes:' section in context" + assert "Connections:" in context, "Missing 'Connections:' section in context" + + # --- Nodes headers --- + assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" + assert "Node: Figma" in context, "Missing node header for Figma" + assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" + assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" + assert "Node: Mike Broski" in 
context, "Missing node header for Mike Broski" + assert "Node: Canva" in context, "Missing node header for Canva" + assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" + + # --- Node contents --- + assert ( + "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" + in context + ), "Description block for Steve Rodger altered" + assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( + "Description block for Figma altered" + ) + assert ( + "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" + in context + ), "Description block for Ike Loma altered" + assert ( + "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" + in context + ), "Description block for Jason Statham altered" + assert ( + "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" + in context + ), "Description block for Mike Broski altered" + assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( + "Description block for Canva altered" + ) + assert ( + "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" + in context + ), "Description block for Christina Mayer altered" + + # --- Connections --- + assert "Steve Rodger --[works_for]--> Figma" in context, ( + "Connection Steve Rodger→Figma missing or changed" + ) + assert "Ike Loma --[works_for]--> Figma" in context, ( + "Connection Ike Loma→Figma missing or changed" + ) + assert "Jason Statham --[works_for]--> Figma" in context, ( + "Connection Jason Statham→Figma missing or changed" + ) + assert "Mike Broski --[works_for]--> Canva" in context, ( + "Connection Mike Broski→Canva missing or changed" + ) + assert "Christina Mayer --[works_for]--> Canva" in context, ( + "Connection Christina Mayer→Canva missing or changed" + ) + + +@pytest.mark.asyncio +async def 
test_graph_completion_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionRetriever can retrieve context (complex).""" + retriever = GraphCompletionRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + +@pytest.mark.asyncio +async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionRetriever handles empty graph correctly.""" + retriever = GraphCompletionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + +@pytest.mark.asyncio +async def test_graph_completion_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py new file mode 100644 index 000000000..c87de16ef --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py @@ -0,0 +1,226 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from 
cognee.modules.graph.utils import resolve_edges_to_text +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_simple" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_simple" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_complex" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_complex" + ) + + 
cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / 
".data_storage/test_get_graph_completion_extension_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple).""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex).""" + retriever = GraphCompletionContextExtensionRetriever(top_k=20) + + context = await resolve_edges_to_text( + await retriever.get_context("Who works at Figma and drives Tesla?") + ) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + 
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly.""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionContextExtensionRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py new file mode 100644 index 000000000..0db035e03 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -0,0 +1,218 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from 
cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_simple" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_complex" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await 
cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + 
cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (simple).""" + retriever = GraphCompletionCotRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (complex).""" + retriever = GraphCompletionCotRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio 
+async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever handles empty graph correctly.""" + retriever = GraphCompletionCotRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionCotRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_rag_completion_retriever.py b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py new file mode 100644 index 000000000..b01d58160 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py @@ -0,0 +1,254 @@ +import os +from typing import List +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.infrastructure.engine import DataPoint 
+from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_simple(): + """Set up a clean test environment with simple chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_complex(): + """Set up a clean test environment with complex chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = 
str(base_dir / ".cognee_system/test_rag_completion_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + 
system_directory_path = str( + base_dir / ".cognee_system/test_get_rag_completion_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_rag_completion_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple): + """Integration test: verify CompletionRetriever can retrieve context (simple).""" + retriever = CompletionRetriever() + + context = await retriever.get_context("Mike") + + assert isinstance(context, str), "Context should be a string" + assert "Mike Broski" in context, "Failed to get Mike Broski" + + +@pytest.mark.asyncio +async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple): + """Integration test: verify CompletionRetriever can retrieve context from multiple chunks.""" + retriever = CompletionRetriever() + + context = await retriever.get_context("Steve") + + assert isinstance(context, str), "Context should be a string" + assert "Steve Rodger" in context, "Failed to get Steve Rodger" + + +@pytest.mark.asyncio +async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex): + """Integration test: verify CompletionRetriever can retrieve context (complex).""" + # TODO: top_k doesn't affect the output, it should be fixed. 
+ retriever = CompletionRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify CompletionRetriever handles empty graph correctly.""" + retriever = CompletionRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) + + context = await retriever.get_context("Christina Mayer") + assert context == "", "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/integration/retrieval/test_structured_output.py similarity index 65% rename from cognee/tests/unit/modules/retrieval/structured_output_test.py rename to cognee/tests/integration/retrieval/test_structured_output.py index 4ad3019ff..13ffd8eef 100644 --- a/cognee/tests/unit/modules/retrieval/structured_output_test.py +++ b/cognee/tests/integration/retrieval/test_structured_output.py @@ -1,9 +1,9 @@ import asyncio - -import pytest -import cognee -import pathlib import os +import pytest +import pathlib +import pytest_asyncio +import cognee from pydantic import BaseModel from cognee.low_level import setup, DataPoint @@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion(): _assert_structured_answer(structured_answer) -class TestStructuredOutputCompletion: - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, 
".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest_asyncio.fixture +async def setup_test_environment(): + """Set up a clean test environment with graph and document data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion") + data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion") + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + yield + + try: await cognee.prune.prune_data() await 
cognee.prune.prune_system(metadata=True) - await setup() + except Exception: + pass - class Company(DataPoint): - name: str - class Person(DataPoint): - name: str - works_for: Company - works_since: int - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) - - entities = [company1, person1] - await add_data_points(entities) - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - await add_data_points(entities) - - entity_type = EntityType(name="Person", description="A human individual") - entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") - - entities = [entity] - await add_data_points(entities) - - await _test_get_structured_graph_completion_cot() - await _test_get_structured_graph_completion() - await _test_get_structured_graph_completion_temporal() - await _test_get_structured_graph_completion_rag() - await _test_get_structured_graph_completion_context_extension() - await _test_get_structured_entity_completion() +@pytest.mark.asyncio +async def test_get_structured_completion(setup_test_environment): + """Integration test: verify structured output completion for all retrievers.""" + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await 
_test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git a/cognee/tests/integration/retrieval/test_summaries_retriever.py b/cognee/tests/integration/retrieval/test_summaries_retriever.py new file mode 100644 index 000000000..a2f4e40b3 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_summaries_retriever.py @@ -0,0 +1,184 @@ +import os +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.tasks.summarization.models import TextSummary +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_with_summaries(): + """Set up a clean test environment with summaries.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context") + data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + 
cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk1_summary = TextSummary( + text="S.R.", + made_from=chunk1, + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2_summary = TextSummary( + text="M.B.", + made_from=chunk2, + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3_summary = TextSummary( + text="C.M.", + made_from=chunk3, + ) + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk4_summary = TextSummary( + text="R.R.", + made_from=chunk4, + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5_summary = TextSummary( + text="H.Y.", + made_from=chunk5, + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6_summary = TextSummary( + text="C.H.", + made_from=chunk6, + ) + + entities = [ + chunk1_summary, + chunk2_summary, + chunk3_summary, + chunk4_summary, + chunk5_summary, + chunk6_summary, + ] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without summaries.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context_empty") + data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context_empty") + + cognee.config.system_root_directory(system_directory_path) + 
cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_summaries_retriever_context(setup_test_environment_with_summaries): + """Integration test: verify SummariesRetriever can retrieve summary context.""" + retriever = SummariesRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert isinstance(context, list), "Context should be a list" + assert len(context) > 0, "Context should not be empty" + assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify SummariesRetriever handles empty graph correctly.""" + retriever = SummariesRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) + + context = await retriever.get_context("Christina Mayer") + assert context == [], "Returned context should be empty on an empty graph" diff --git a/cognee/tests/integration/retrieval/test_temporal_retriever.py b/cognee/tests/integration/retrieval/test_temporal_retriever.py new file mode 100644 index 000000000..8ce3b32f4 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_temporal_retriever.py @@ -0,0 +1,306 @@ +import os +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.engine.models.Event import Event +from cognee.modules.engine.models.Timestamp import Timestamp 
+from cognee.modules.engine.models.Interval import Interval + + +@pytest_asyncio.fixture +async def setup_test_environment_with_events(): + """Set up a clean test environment with temporal events.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + # Create timestamps for events + timestamp1 = Timestamp( + time_at=1609459200, # 2021-01-01 00:00:00 + year=2021, + month=1, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-01-01T00:00:00", + ) + + timestamp2 = Timestamp( + time_at=1612137600, # 2021-02-01 00:00:00 + year=2021, + month=2, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-02-01T00:00:00", + ) + + timestamp3 = Timestamp( + time_at=1614556800, # 2021-03-01 00:00:00 + year=2021, + month=3, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-03-01T00:00:00", + ) + + timestamp4 = Timestamp( + time_at=1625097600, # 2021-07-01 00:00:00 + year=2021, + month=7, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-07-01T00:00:00", + ) + + timestamp5 = Timestamp( + time_at=1633046400, # 2021-10-01 00:00:00 + year=2021, + month=10, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-10-01T00:00:00", + ) + + # Create interval for event spanning multiple timestamps + interval1 = Interval(time_from=timestamp2, time_to=timestamp3) + + # Create events with timestamps + event1 = Event( + name="Project Alpha Launch", + description="Launched Project Alpha at the beginning of 2021", + at=timestamp1, + location="San Francisco", + ) + + event2 = Event( + name="Team Meeting", + 
description="Monthly team meeting discussing Q1 goals", + during=interval1, + location="New York", + ) + + event3 = Event( + name="Product Release", + description="Released new product features in July", + at=timestamp4, + location="Remote", + ) + + event4 = Event( + name="Company Retreat", + description="Annual company retreat in October", + at=timestamp5, + location="Lake Tahoe", + ) + + entities = [event1, event2, event3, event4] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_graph_data(): + """Set up a clean test environment with graph data (for fallback to triplets).""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + description: str + + class Person(DataPoint): + name: str + description: str + works_for: Company + + company1 = Company(name="Figma", description="Figma is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + works_for=company1, + ) + + entities = [company1, person1] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = 
str(base_dir / ".cognee_system/test_temporal_retriever_empty") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve events within time range.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("What happened in January 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Project Alpha" in context or "Launch" in context, ( + "Should retrieve Project Alpha Launch event from January 2021" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve events at specific time.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("What happened in July 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Product Release" in context or "July" in context, ( + "Should retrieve Product Release event from July 2021" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_fallback_to_triplets( + setup_test_environment_with_graph_data, +): + """Integration test: verify TemporalRetriever falls back to triplets when no time extracted.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("Who works at Figma?") + + 
assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Steve" in context or "Figma" in context, ( + "Should retrieve graph data via triplet search fallback" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty): + """Integration test: verify TemporalRetriever handles empty graph correctly.""" + retriever = TemporalRetriever() + + context = await retriever.get_context("What happened?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) >= 0, "Context should be a string (possibly empty)" + + +@pytest.mark.asyncio +async def test_temporal_retriever_get_completion(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can generate completions.""" + retriever = TemporalRetriever() + + completion = await retriever.get_completion("What happened in January 2021?") + + assert isinstance(completion, list), "Completion should be a list" + assert len(completion) > 0, "Completion should not be empty" + assert all(isinstance(item, str) and item.strip() for item in completion), ( + "Completion items should be non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data): + """Integration test: verify TemporalRetriever get_completion works with triplet fallback.""" + retriever = TemporalRetriever() + + completion = await retriever.get_completion("Who works at Figma?") + + assert isinstance(completion, list), "Completion should be a list" + assert len(completion) > 0, "Completion should not be empty" + assert all(isinstance(item, str) and item.strip() for item in completion), ( + "Completion items should be non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever respects top_k 
parameter.""" + retriever = TemporalRetriever(top_k=2) + + context = await retriever.get_context("What happened in 2021?") + + assert isinstance(context, str), "Context should be a string" + separator_count = context.count("#####################") + assert separator_count <= 1, "Should respect top_k limit of 2 events" + + +@pytest.mark.asyncio +async def test_temporal_retriever_multiple_events(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve multiple events.""" + retriever = TemporalRetriever(top_k=10) + + context = await retriever.get_context("What events occurred in 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + + assert ( + "Project Alpha" in context + or "Team Meeting" in context + or "Product Release" in context + or "Company Retreat" in context + ), "Should retrieve at least one event from 2021" diff --git a/cognee/tests/integration/retrieval/test_triplet_retriever.py b/cognee/tests/integration/retrieval/test_triplet_retriever.py index e547b6cbe..ebe853e08 100644 --- a/cognee/tests/integration/retrieval/test_triplet_retriever.py +++ b/cognee/tests/integration/retrieval/test_triplet_retriever.py @@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip context = await retriever.get_context("Alice") assert "Alice knows Bob" in context, "Failed to get Alice triplet" + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + + +@pytest.mark.asyncio +async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets): + """Integration test: verify TripletRetriever can retrieve multiple triplets.""" + retriever = TripletRetriever(top_k=5) + + context = await retriever.get_context("Bob") + + assert "Alice knows Bob" in context or "Bob works at Tech Corp" in context, ( + "Failed to get Bob-related 
triplets" + ) + + +@pytest.mark.asyncio +async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets): + """Integration test: verify TripletRetriever respects top_k parameter.""" + retriever = TripletRetriever(top_k=1) + + context = await retriever.get_context("Alice") + + assert isinstance(context, str), "Context should be a string" + + +@pytest.mark.asyncio +async def test_triplet_retriever_context_empty(setup_test_environment_empty): + """Integration test: verify TripletRetriever handles empty graph correctly.""" + await setup() + + retriever = TripletRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Alice") diff --git a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py index 70ec43cf8..b18012594 100644 --- a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py +++ b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py @@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\ {"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]} """ +MOCK_HOTPOT_CORPUS = [ + { + "_id": "1", + "question": "Next to which country is Germany located?", + "answer": "Netherlands", + # HotpotQA uses "level"; TwoWikiMultiHop uses "type". 
+ "level": "easy", + "type": "comparison", + "context": [ + ["Germany", ["Germany is in Europe."]], + ["Netherlands", ["The Netherlands borders Germany."]], + ], + "supporting_facts": [["Netherlands", 0]], + } +] + ADAPTER_CLASSES = [ HotpotQAAdapter, @@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass): adapter = AdapterClass() result = adapter.load_corpus() + elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter): + with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + adapter = AdapterClass() + result = adapter.load_corpus() + else: adapter = AdapterClass() result = adapter.load_corpus() @@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass): ): adapter = AdapterClass() corpus_list, qa_pairs = adapter.load_corpus(limit=limit) + elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter): + with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + adapter = AdapterClass() + corpus_list, qa_pairs = adapter.load_corpus(limit=limit) else: adapter = AdapterClass() corpus_list, qa_pairs = adapter.load_corpus(limit=limit) diff --git a/cognee/tests/unit/eval_framework/corpus_builder_test.py b/cognee/tests/unit/eval_framework/corpus_builder_test.py index 14136bea5..53f886b58 100644 --- a/cognee/tests/unit/eval_framework/corpus_builder_test.py +++ b/cognee/tests/unit/eval_framework/corpus_builder_test.py @@ -2,15 +2,38 @@ import pytest from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor from cognee.infrastructure.databases.graph import get_graph_engine from unittest.mock import AsyncMock, patch +from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"] +MOCK_HOTPOT_CORPUS = [ + { + "_id": "1", + "question": "Next to which country is Germany located?", + "answer": "Netherlands", + # HotpotQA uses "level"; TwoWikiMultiHop uses "type". 
+ "level": "easy", + "type": "comparison", + "context": [ + ["Germany", ["Germany is in Europe."]], + ["Netherlands", ["The Netherlands borders Germany."]], + ], + "supporting_facts": [["Netherlands", 0]], + } +] + @pytest.mark.parametrize("benchmark", benchmark_options) def test_corpus_builder_load_corpus(benchmark): limit = 2 - corpus_builder = CorpusBuilderExecutor(benchmark, "Default") - raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + if benchmark in ("HotPotQA", "TwoWikiMultiHop"): + with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + else: + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}" assert len(questions) <= 2, ( f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" @@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark): @patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock) async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark): limit = 2 - corpus_builder = CorpusBuilderExecutor(benchmark, "Default") - questions = await corpus_builder.build_corpus(limit=limit) + if benchmark in ("HotPotQA", "TwoWikiMultiHop"): + with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + questions = await corpus_builder.build_corpus(limit=limit) + else: + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + questions = await corpus_builder.build_corpus(limit=limit) + assert len(questions) <= 2, ( f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" ) diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py 
b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py deleted file mode 100644 index 44786f79d..000000000 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -import pytest -import pathlib -from typing import List -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity - - -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} - - -class TestChunksRetriever: - @pytest.mark.asyncio - async def test_chunk_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - 
text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - - await add_data_points(entities) - - retriever = ChunksRetriever() - - context = await retriever.get_context("Mike") - - assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski" - - @pytest.mark.asyncio - async def test_chunk_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - 
chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] - - await add_data_points(entities) - - retriever = ChunksRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = ChunksRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - - context = await retriever.get_context("Christina Mayer") - assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py deleted file mode 100644 index d464a99d8..000000000 --- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py +++ /dev/null @@ -1,154 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from cognee.context_global_variables import session_user -import importlib - - -def create_mock_cache_engine(qa_history=None): - mock_cache = AsyncMock() - if qa_history is None: - 
qa_history = [] - mock_cache.get_latest_qa = AsyncMock(return_value=qa_history) - mock_cache.add_qa = AsyncMock(return_value=None) - return mock_cache - - -def create_mock_user(): - mock_user = MagicMock() - mock_user.id = "test-user-id-123" - return mock_user - - -class TestConversationHistoryUtils: - @pytest.mark.asyncio - async def test_get_conversation_history_returns_empty_when_no_history(self): - user = create_mock_user() - session_user.set(user) - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - from cognee.modules.retrieval.utils.session_cache import get_conversation_history - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_formats_history_correctly(self): - """Test get_conversation_history formats Q&A history with correct structure.""" - user = create_mock_user() - session_user.set(user) - - mock_history = [ - { - "time": "2024-01-15 10:30:45", - "question": "What is AI?", - "context": "AI is artificial intelligence", - "answer": "AI stands for Artificial Intelligence", - } - ] - mock_cache = create_mock_cache_engine(mock_history) - - # Import the real module to patch safely - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert "Previous conversation:" in result - 
assert "[2024-01-15 10:30:45]" in result - assert "QUESTION: What is AI?" in result - assert "CONTEXT: AI is artificial intelligence" in result - assert "ANSWER: AI stands for Artificial Intelligence" in result - - @pytest.mark.asyncio - async def test_save_to_session_cache_saves_correctly(self): - """Test save_conversation_history calls add_qa with correct parameters.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="What is Python?", - context_summary="Python is a programming language", - answer="Python is a high-level programming language", - session_id="my_session", - ) - - assert result is True - mock_cache.add_qa.assert_called_once() - - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["question"] == "What is Python?" 
- assert call_kwargs["context"] == "Python is a programming language" - assert call_kwargs["answer"] == "Python is a high-level programming language" - assert call_kwargs["session_id"] == "my_session" - - @pytest.mark.asyncio - async def test_save_to_session_cache_uses_default_session_when_none(self): - """Test save_conversation_history uses 'default_session' when session_id is None.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - session_id=None, - ) - - assert result is True - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["session_id"] == "default_session" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py deleted file mode 100644 index 0e21fe351..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ /dev/null @@ -1,177 +0,0 @@ -import os -import pytest -import pathlib -from typing import Optional, Union - -import cognee -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - 
GraphCompletionContextExtensionRetriever, -) - - -class TestGraphCompletionWithContextExtensionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_extension_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_simple", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_simple", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever() - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - - answer = await retriever.get_completion("Who works at Canva?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_graph_completion_extension_context_complex(self): - 
system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_complex", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever(top_k=20) - - context = await resolve_edges_to_text( - await retriever.get_context("Who works at Figma and drives Tesla?") - ) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, 
"Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_get_graph_completion_extension_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_extension_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionContextExtensionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py deleted file mode 100644 index 206cfaf84..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import pytest -import pathlib -from typing import Optional, Union - -import cognee 
-from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever - - -class TestGraphCompletionCoTRetriever: - @pytest.mark.asyncio - async def test_graph_completion_cot_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever() - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - - answer = await retriever.get_completion("Who works at Canva?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in 
answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_graph_completion_cot_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_cot_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who 
works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_get_graph_completion_cot_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_cot_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionCotRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py deleted file mode 100644 index f462baced..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ /dev/null @@ -1,223 +0,0 @@ -import os -import pytest -import 
pathlib -from typing import Optional, Union - -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever - - -class TestGraphCompletionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - description: str - - class Person(DataPoint): - name: str - description: str - works_for: Company - - company1 = Company(name="Figma", description="Figma is a company") - company2 = Company(name="Canva", description="Canvas is a company") - person1 = Person( - name="Steve Rodger", - description="This is description about Steve Rodger", - works_for=company1, - ) - person2 = Person( - name="Ike Loma", description="This is description about Ike Loma", works_for=company1 - ) - person3 = Person( - name="Jason Statham", - description="This is description about Jason Statham", - works_for=company1, - ) - person4 = Person( - name="Mike Broski", - description="This is description about Mike Broski", - works_for=company2, - ) - person5 = Person( - name="Christina Mayer", - description="This is description about Christina Mayer", - works_for=company2, - ) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionRetriever() - - context = await 
resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - # Ensure the top-level sections are present - assert "Nodes:" in context, "Missing 'Nodes:' section in context" - assert "Connections:" in context, "Missing 'Connections:' section in context" - - # --- Nodes headers --- - assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" - assert "Node: Figma" in context, "Missing node header for Figma" - assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" - assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" - assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" - assert "Node: Canva" in context, "Missing node header for Canva" - assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" - - # --- Node contents --- - assert ( - "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" - in context - ), "Description block for Steve Rodger altered" - assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( - "Description block for Figma altered" - ) - assert ( - "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" - in context - ), "Description block for Ike Loma altered" - assert ( - "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" - in context - ), "Description block for Jason Statham altered" - assert ( - "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" - in context - ), "Description block for Mike Broski altered" - assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( - "Description block for Canva altered" - ) - assert ( - "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" - in context - ), "Description block for Christina Mayer altered" - - # --- Connections --- - assert "Steve Rodger 
--[works_for]--> Figma" in context, ( - "Connection Steve Rodger→Figma missing or changed" - ) - assert "Ike Loma --[works_for]--> Figma" in context, ( - "Connection Ike Loma→Figma missing or changed" - ) - assert "Jason Statham --[works_for]--> Figma" in context, ( - "Connection Jason Statham→Figma missing or changed" - ) - assert "Mike Broski --[works_for]--> Canva" in context, ( - "Connection Mike Broski→Canva missing or changed" - ) - assert "Christina Mayer --[works_for]--> Canva" in context, ( - "Connection Christina Mayer→Canva missing or changed" - ) - - @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 
= Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py deleted file mode 100644 index 9bfed68f3..000000000 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ /dev/null @@ -1,205 +0,0 @@ -import os -from typing import List -import pytest -import pathlib -import cognee - -from cognee.low_level import setup -from 
cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity - - -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} - - -class TestRAGCompletionRetriever: - @pytest.mark.asyncio - async def test_rag_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, 
chunk2, chunk3] - - await add_data_points(entities) - - retriever = CompletionRetriever() - - context = await retriever.get_context("Mike") - - assert context == "Mike Broski", "Failed to get Mike Broski" - - @pytest.mark.asyncio - async def test_rag_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3, chunk4, 
chunk5, chunk6] - - await add_data_points(entities) - - # TODO: top_k doesn't affect the output, it should be fixed. - retriever = CompletionRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_get_rag_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = CompletionRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - - context = await retriever.get_context("Christina Mayer") - assert context == "", "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py deleted file mode 100644 index 5f4b93425..000000000 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -import pytest -import pathlib - -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.tasks.summarization.models import TextSummary -from cognee.modules.data.processing.document_types import TextDocument -from 
cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.summaries_retriever import SummariesRetriever - - -class TestSummariesRetriever: - @pytest.mark.asyncio - async def test_chunk_context(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk1_summary = TextSummary( - text="S.R.", - made_from=chunk1, - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2_summary = TextSummary( - text="M.B.", - made_from=chunk2, - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3_summary = TextSummary( - text="C.M.", - made_from=chunk3, - ) - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk4_summary = TextSummary( - text="R.R.", - made_from=chunk4, - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) 
- chunk5_summary = TextSummary( - text="H.Y.", - made_from=chunk5, - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6_summary = TextSummary( - text="C.H.", - made_from=chunk6, - ) - - entities = [ - chunk1_summary, - chunk2_summary, - chunk3_summary, - chunk4_summary, - chunk5_summary, - chunk6_summary, - ] - - await add_data_points(entities) - - retriever = SummariesRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = SummariesRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) - - context = await retriever.get_context("Christina Mayer") - assert context == [], "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py deleted file mode 100644 index c3c6a47f6..000000000 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ /dev/null @@ -1,224 +0,0 @@ -from types import SimpleNamespace -import pytest - -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever - - -# Test 
TemporalRetriever initialization defaults and overrides -def test_init_defaults_and_overrides(): - tr = TemporalRetriever() - assert tr.top_k == 5 - assert tr.user_prompt_path == "graph_context_for_question.txt" - assert tr.system_prompt_path == "answer_simple_question.txt" - assert tr.time_extraction_prompt_path == "extract_query_time.txt" - - tr2 = TemporalRetriever( - top_k=3, - user_prompt_path="u.txt", - system_prompt_path="s.txt", - time_extraction_prompt_path="t.txt", - ) - assert tr2.top_k == 3 - assert tr2.user_prompt_path == "u.txt" - assert tr2.system_prompt_path == "s.txt" - assert tr2.time_extraction_prompt_path == "t.txt" - - -# Test descriptions_to_string with basic and empty results -def test_descriptions_to_string_basic_and_empty(): - tr = TemporalRetriever() - - results = [ - {"description": " First "}, - {"nope": "no description"}, - {"description": "Second"}, - {"description": ""}, - {"description": " Third line "}, - ] - - s = tr.descriptions_to_string(results) - assert s == "First\n#####################\nSecond\n#####################\nThird line" - - assert tr.descriptions_to_string([]) == "" - - -# Test filter_top_k_events sorts and limits correctly -@pytest.mark.asyncio -async def test_filter_top_k_events_sorts_and_limits(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3 - not in vector results"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "e2"}, score=0.10), - SimpleNamespace(payload={"id": "e1"}, score=0.20), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - - assert [e["id"] for e in top] == ["e2", "e1"] - assert all("score" in e for e in top) - assert top[0]["score"] == 0.10 - assert top[1]["score"] == 0.20 - - -# Test filter_top_k_events handles unknown ids as infinite scores -@pytest.mark.asyncio -async def 
test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "known1", "description": "Known 1"}, - {"id": "unknown", "description": "Unknown"}, - {"id": "known2", "description": "Known 2"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "known2"}, score=0.05), - SimpleNamespace(payload={"id": "known1"}, score=0.50), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in top] == ["known2", "known1"] - assert all(e["score"] != float("inf") for e in top) - - -# Test descriptions_to_string with unicode and newlines -def test_descriptions_to_string_unicode_and_newlines(): - tr = TemporalRetriever() - results = [ - {"description": "Line A\nwith newline"}, - {"description": "This is a description"}, - ] - s = tr.descriptions_to_string(results) - assert "Line A\nwith newline" in s - assert "This is a description" in s - assert s.count("#####################") == 1 - - -# Test filter_top_k_events when top_k is larger than available events -@pytest.mark.asyncio -async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): - tr = TemporalRetriever(top_k=10) - relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] - scored_results = [ - SimpleNamespace(payload={"id": "a"}, score=0.1), - SimpleNamespace(payload={"id": "b"}, score=0.2), - ] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["a", "b"] - - -# Test filter_top_k_events when scored_results is empty -@pytest.mark.asyncio -async def test_filter_top_k_events_handles_empty_scored_results(): - tr = TemporalRetriever(top_k=2) - relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] - scored_results = [] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["x", "y"] - assert all(e["score"] == float("inf") for e in out) - - -# Test 
filter_top_k_events error handling for missing structure -@pytest.mark.asyncio -async def test_filter_top_k_events_error_handling(): - tr = TemporalRetriever(top_k=2) - with pytest.raises((KeyError, TypeError)): - await tr.filter_top_k_events([{}], []) - - -class _FakeRetriever(TemporalRetriever): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._calls = [] - - async def extract_time_from_query(self, query: str): - if "both" in query: - return "2024-01-01", "2024-12-31" - if "from_only" in query: - return "2024-01-01", None - if "to_only" in query: - return None, "2024-12-31" - return None, None - - async def get_triplets(self, query: str): - self._calls.append(("get_triplets", query)) - return [{"s": "a", "p": "b", "o": "c"}] - - async def resolve_edges_to_text(self, triplets): - self._calls.append(("resolve_edges_to_text", len(triplets))) - return "edges->text" - - async def _fake_graph_collect_ids(self, **kwargs): - return ["e1", "e2"] - - async def _fake_graph_collect_events(self, ids): - return [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3"}, - ] - } - ] - - async def _fake_vector_embed(self, texts): - assert isinstance(texts, list) and texts - return [[0.0, 1.0, 2.0]] - - async def _fake_vector_search(self, **kwargs): - return [ - SimpleNamespace(payload={"id": "e2"}, score=0.05), - SimpleNamespace(payload={"id": "e1"}, score=0.10), - ] - - async def get_context(self, query: str): - time_from, time_to = await self.extract_time_from_query(query) - - if not (time_from or time_to): - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets) - - ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to) - relevant_events = await self._fake_graph_collect_events(ids) - - _ = await self._fake_vector_embed([query]) - vector_search_results = await self._fake_vector_search( - 
collection_name="Event_name", query_vector=[0.0], limit=0 - ) - top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events) - - -# Test get_context fallback to triplets when no time is extracted -@pytest.mark.asyncio -async def test_fake_get_context_falls_back_to_triplets_when_no_time(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("no_time") - assert ctx == "edges->text" - assert tr._calls[0][0] == "get_triplets" - assert tr._calls[1][0] == "resolve_edges_to_text" - - -# Test get_context when time is extracted and vector ranking is applied -@pytest.mark.asyncio -async def test_fake_get_context_with_time_filters_and_vector_ranking(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("both time") - assert ctx.startswith("E2") - assert "#####################" in ctx - assert "E1" in ctx and "E3" not in ctx diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py deleted file mode 100644 index 3dc9f38d9..000000000 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ /dev/null @@ -1,608 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch - -from cognee.modules.retrieval.utils.brute_force_triplet_search import ( - brute_force_triplet_search, - get_memory_fragment, -) -from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph -from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError - - -class MockScoredResult: - """Mock class for vector search results.""" - - def __init__(self, id, score, payload=None): - self.id = id - self.score = score - self.payload = payload or {} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_empty_query(): - """Test that empty query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await 
brute_force_triplet_search(query="") - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_none_query(): - """Test that None query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await brute_force_triplet_search(query=None) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_negative_top_k(): - """Test that negative top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=-1) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_zero_top_k(): - """Test that zero top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=0) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_global_search(): - """Test that wide_search_limit is applied for global search (node_name=None).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=None, # Global search - wide_search_top_k=75, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 75 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): - """Test that wide_search_limit is None for filtered search (node_name provided).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - 
mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=["Node1"], - wide_search_top_k=50, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] is None - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_default(): - """Test that wide_search_top_k defaults to 100.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", node_name=None) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 100 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_default_collections(): - """Test that default collections are used when none provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test") - - expected_collections = [ - "Entity_name", - "TextSummary_text", - "EntityType_name", - "DocumentChunk_text", - "EdgeType_relationship_name", - ] - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert call_collections == expected_collections - - -@pytest.mark.asyncio -async def 
test_brute_force_triplet_search_custom_collections(): - """Test that custom collections are used when provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - custom_collections = ["CustomCol1", "CustomCol2"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=custom_collections) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_always_includes_edge_collection(): - """Test that EdgeType_relationship_name is always searched even when not in collections.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - collections_without_edge = ["Entity_name", "TextSummary_text"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=collections_without_edge) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert "EdgeType_relationship_name" in call_collections - assert set(call_collections) == set(collections_without_edge) | { - "EdgeType_relationship_name" - } - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_all_collections_empty(): - """Test that empty list is returned when all collections return no results.""" 
- mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - results = await brute_force_triplet_search(query="test") - assert results == [] - - -# Tests for query embedding - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_embeds_query(): - """Test that query is embedded before searching.""" - query_text = "test query" - expected_vector = [0.1, 0.2, 0.3] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query=query_text) - - mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["query_vector"] == expected_vector - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_extracts_node_ids_global_search(): - """Test that node IDs are extracted from search results for global search.""" - scored_results = [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - MockScoredResult("node3", 0.92), - ] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=scored_results) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - 
map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_reuses_provided_fragment(): - """Test that provided memory fragment is reused instead of creating new one.""" - provided_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" - ) as mock_get_fragment, - ): - await brute_force_triplet_search( - query="test", - memory_fragment=provided_fragment, - node_name=["node"], - ) - - mock_get_fragment.assert_not_called() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): - """Test that memory fragment is created when not provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = 
AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment, - ): - await brute_force_triplet_search(query="test", node_name=["node"]) - - mock_get_fragment.assert_called_once() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): - """Test that custom top_k is passed to importance calculation.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ), - ): - custom_top_k = 15 - await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) - - mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns 
empty graph when entity not found.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock( - side_effect=EntityNotFoundError("Entity not found") - ) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_error(): - """Test that get_memory_fragment returns empty graph on generic error.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_deduplicates_node_ids(): - """Test that duplicate node IDs across collections are deduplicated.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ] - elif collection_name == "TextSummary_text": - return [ - MockScoredResult("node1", 0.90), - MockScoredResult("node3", 0.92), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - 
patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - assert len(call_kwargs["relevant_ids_to_filter"]) == 3 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_excludes_edge_collection(): - """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "EdgeType_relationship_name": - return [MockScoredResult("edge1", 0.88)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search( - query="test", - node_name=None, - collections=["Entity_name", "EdgeType_relationship_name"], - ) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert 
call_kwargs["relevant_ids_to_filter"] == ["node1"] - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_skips_nodes_without_ids(): - """Test that nodes without ID attribute are skipped.""" - - class ScoredResultNoId: - """Mock result without id attribute.""" - - def __init__(self, score): - self.score = score - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - ScoredResultNoId(0.90), - MockScoredResult("node2", 0.87), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_handles_tuple_results(): - """Test that both list and tuple results are handled correctly.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return ( - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ) - else: - return [] - - mock_vector_engine = AsyncMock() - 
mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_mixed_empty_collections(): - """Test ID extraction with mixed empty and non-empty collections.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "TextSummary_text": - return [] - elif collection_name == "EntityType_name": - return [MockScoredResult("node2", 0.92)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - 
return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py deleted file mode 100644 index d79aca428..000000000 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.triplet_retriever import TripletRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.has_collection = AsyncMock(return_value=True) - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of triplet context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "Alice knows Bob"} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "Bob works at Tech Corp"} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = TripletRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Alice knows Bob\nBob works at Tech Corp" - mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) - - -@pytest.mark.asyncio -async def 
test_get_context_no_collection(mock_vector_engine): - """Test that NoDataError is raised when Triplet_text collection doesn't exist.""" - mock_vector_engine.has_collection.return_value = False - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="create_triplet_embeddings"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty string is returned when no triplets are found.""" - mock_vector_engine.search.return_value = [] - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "" - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found") - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query") From b4aaa7faefce804d9ad6fee93d9907b352206f25 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:59:33 +0100 Subject: [PATCH 31/31] chore: retriever test reorganization + adding new tests (smoke e2e) (STEP 1.5) (#1888) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR restructures the end-to-end tests for the multi-database search layer to improve maintainability, consistency, and coverage across supported Python versions and database settings. 
Key Changes

- Migrates the existing E2E tests to pytest for a more standard and extensible testing framework.
- Introduces pytest fixtures to centralize and reuse test setup logic.
- Implements proper event loop management to support multiple asynchronous pytest tests reliably.
- Improves SQLAlchemy handling in tests, ensuring clean setup and teardown of database state.
- Extends multi-database E2E test coverage across all supported Python versions.

Benefits

- Cleaner and more modular test structure.
- Reduced duplication and clearer test intent through fixtures.
- More reliable async test execution.
- Better alignment with our supported Python version matrix.

## Type of Change

- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)

## Pre-submission Checklist

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [x] My code follows the project's coding standards and style guidelines
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
## Summary by CodeRabbit * **Tests** * Expanded end-to-end test suite for the search database with comprehensive setup/teardown, new session-scoped fixtures, and multiple tests validating graph/vector consistency, retriever contexts, triplet metadata, search result shapes, side effects, and feedback-weight behavior. * **Chores** * CI updated to run matrixed test jobs across multiple Python versions and standardize test execution for more consistent, parallelized runs. ✏️ Tip: You can customize this high-level summary in your review settings. --- .github/workflows/search_db_tests.yml | 46 ++- cognee/tests/test_search_db.py | 529 +++++++++++++++++--------- 2 files changed, 374 insertions(+), 201 deletions(-) diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml index 118c1c06c..f0c7817cd 100644 --- a/.github/workflows/search_db_tests.yml +++ b/.github/workflows/search_db_tests.yml @@ -11,12 +11,21 @@ on: type: string default: "all" description: "Which vector databases to test (comma-separated list or 'all')" + python-versions: + required: false + type: string + default: '["3.10", "3.11", "3.12", "3.13"]' + description: "Python versions to test (JSON array)" jobs: run-kuzu-lance-sqlite-search-tests: - name: Search test for Kuzu/LanceDB/Sqlite + name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false steps: - name: Check out uses: actions/checkout@v4 @@ -26,7 +35,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} - name: Dependencies already installed run: echo "Dependencies already installed in setup" @@ -45,13 +54,16 @@ jobs: GRAPH_DATABASE_PROVIDER: 'kuzu' VECTOR_DB_PROVIDER: 'lancedb' 
DB_PROVIDER: 'sqlite' - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-neo4j-lance-sqlite-search-tests: - name: Search test for Neo4j/LanceDB/Sqlite + name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }} - + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false steps: - name: Check out uses: actions/checkout@v4 @@ -61,7 +73,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} - name: Setup Neo4j with GDS uses: ./.github/actions/setup_neo4j @@ -88,12 +100,16 @@ jobs: GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-kuzu-pgvector-postgres-search-tests: - name: Search test for Kuzu/PGVector/Postgres + name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false services: postgres: image: pgvector/pgvector:pg17 @@ -117,7 +133,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} extra-dependencies: "postgres" - name: Dependencies already installed @@ -143,12 +159,16 @@ jobs: DB_PORT: 5432 DB_USERNAME: cognee DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_search_db.py + 
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-neo4j-pgvector-postgres-search-tests: - name: Search test for Neo4j/PGVector/Postgres + name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false services: postgres: image: pgvector/pgvector:pg17 @@ -172,7 +192,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} extra-dependencies: "postgres" - name: Setup Neo4j with GDS @@ -205,4 +225,4 @@ jobs: DB_PORT: 5432 DB_USERNAME: cognee DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index ba150f813..0916be322 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -1,5 +1,10 @@ import pathlib import os +import asyncio +import pytest +import pytest_asyncio +from collections import Counter + import cognee from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine @@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever from cognee.modules.retrieval.triplet_retriever import 
TripletRetriever from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user -from collections import Counter logger = get_logger() -async def main(): - # This test runs for multiple db settings, to run this locally set the corresponding db envs +async def _reset_engines_and_prune() -> None: + """Reset db engine caches and prune data/system. + + Kept intentionally identical to the inlined setup logic to avoid event loop issues when + using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run. + """ + # Dispose of existing engines and clear caches to ensure fresh instances for each test + try: + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + # Dispose SQLAlchemy engine connection pool if it exists + if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"): + await vector_engine.engine.dispose(close=True) + except Exception: + # Engine might not exist yet + pass + + from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine + from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.relational.create_relational_engine import ( + create_relational_engine, + ) + + create_graph_engine.cache_clear() + create_vector_engine.cache_clear() + create_relational_engine.cache_clear() + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - dataset_name = "test_dataset" +async def _seed_default_dataset(dataset_name: str) -> dict: + """Add the shared test dataset contents and run cognify (same steps/order as before).""" text_1 = """Germany is located in europe right next to the Netherlands""" + + logger.info(f"Adding text data to dataset: {dataset_name}") await cognee.add(text_1, dataset_name) explanation_file_path_quantum = os.path.join( 
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt" ) + logger.info(f"Adding file data to dataset: {dataset_name}") await cognee.add([explanation_file_path_quantum], dataset_name) + logger.info(f"Running cognify on dataset: {dataset_name}") await cognee.cognify([dataset_name]) + return { + "dataset_name": dataset_name, + "text_1": text_1, + "explanation_file_path_quantum": explanation_file_path_quantum, + } + + +@pytest.fixture(scope="session") +def event_loop(): + """Use a single asyncio event loop for this test module. + + This helps avoid "Future attached to a different loop" when running multiple async + tests that share clients/engines. + """ + loop = asyncio.new_event_loop() + try: + yield loop + finally: + loop.close() + + +async def setup_test_environment(): + """Helper function to set up test environment with data, cognify, and triplet embeddings.""" + # This test runs for multiple db settings, to run this locally set the corresponding db envs + + dataset_name = "test_dataset" + logger.info("Starting test setup: pruning data and system") + await _reset_engines_and_prune() + state = await _seed_default_dataset(dataset_name=dataset_name) + user = await get_default_user() from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings + logger.info("Creating triplet embeddings") await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5) + # Check if Triplet_text collection was created + vector_engine = get_vector_engine() + has_collection = await vector_engine.has_collection(collection_name="Triplet_text") + logger.info(f"Triplet_text collection exists after creation: {has_collection}") + + if has_collection: + collection = await vector_engine.get_collection("Triplet_text") + count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown" + logger.info(f"Triplet_text collection row count: {count}") + + return state + + +async def setup_test_environment_for_feedback(): 
+ """Helper function to set up test environment for feedback weight calculation test.""" + dataset_name = "test_dataset" + await _reset_engines_and_prune() + return await _seed_default_dataset(dataset_name=dataset_name) + + +@pytest_asyncio.fixture(scope="session") +async def e2e_state(): + """Compute E2E artifacts once; tests only assert. + + This avoids repeating expensive setup and LLM calls across multiple tests. + """ + await setup_test_environment() + + # --- Graph/vector engine consistency --- graph_engine = await get_graph_engine() - nodes, edges = await graph_engine.get_graph_data() + _nodes, edges = await graph_engine.get_graph_data() vector_engine = get_vector_engine() collection = await vector_engine.search( - query_text="Test", limit=None, collection_name="Triplet_text" + collection_name="Triplet_text", query_text="Test", limit=None ) - assert len(edges) == len(collection), ( - f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection" - ) + # --- Retriever contexts --- + query = "Next to which country is Germany located?" - context_gk = await GraphCompletionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_cot = await GraphCompletionCotRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_sum = await GraphSummaryCompletionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_triplet = await TripletRetriever().get_context( - query="Next to which country is Germany located?" 
- ) + contexts = { + "graph_completion": await GraphCompletionRetriever().get_context(query=query), + "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query), + "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context( + query=query + ), + "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context( + query=query + ), + "chunks": await ChunksRetriever(top_k=5).get_context(query=query), + "summaries": await SummariesRetriever(top_k=5).get_context(query=query), + "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query), + "temporal": await TemporalRetriever(top_k=5).get_context(query=query), + "triplet": await TripletRetriever().get_context(query=query), + } - for name, context in [ - ("GraphCompletionRetriever", context_gk), - ("GraphCompletionCotRetriever", context_gk_cot), - ("GraphCompletionContextExtensionRetriever", context_gk_ext), - ("GraphSummaryCompletionRetriever", context_gk_sum), - ]: - assert isinstance(context, list), f"{name}: Context should be a list" - assert len(context) > 0, f"{name}: Context should not be empty" - - context_text = await resolve_edges_to_text(context) - lower = context_text.lower() - assert "germany" in lower or "netherlands" in lower, ( - f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}" - ) - - assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string" - assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty" - lower_triplet = context_triplet.lower() - assert "germany" in lower_triplet or "netherlands" in lower_triplet, ( - f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}" - ) - - triplets_gk = await GraphCompletionRetriever().get_triplets( - query="Next to which country is Germany located?" 
- ) - triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - - for name, triplets in [ - ("GraphCompletionRetriever", triplets_gk), - ("GraphCompletionCotRetriever", triplets_gk_cot), - ("GraphCompletionContextExtensionRetriever", triplets_gk_ext), - ("GraphSummaryCompletionRetriever", triplets_gk_sum), - ]: - assert isinstance(triplets, list), f"{name}: Triplets should be a list" - assert triplets, f"{name}: Triplets list should not be empty" - for edge in triplets: - assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - distance = edge.attributes.get("vector_distance") - node1_distance = edge.node1.attributes.get("vector_distance") - node2_distance = edge.node2.attributes.get("vector_distance") - assert isinstance(distance, float), ( - f"{name}: vector_distance should be float, got {type(distance)}" - ) - assert 0 <= distance <= 1, ( - f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen" - ) - assert 0 <= node1_distance <= 1, ( - f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen" - ) - assert 0 <= node2_distance <= 1, ( - f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen" - ) + # --- Retriever triplets + vector distance validation --- + triplets = { + "graph_completion": await GraphCompletionRetriever().get_triplets(query=query), + "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query), + "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets( + query=query + ), + "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets( + 
query=query + ), + } + # --- Search operations + graph side effects --- completion_gk = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Where is germany located, next to which country?", @@ -164,6 +214,26 @@ async def main(): query_text="Next to which country is Germany located?", save_interaction=True, ) + completion_chunks = await cognee.search( + query_type=SearchType.CHUNKS, + query_text="Germany", + save_interaction=False, + ) + completion_summaries = await cognee.search( + query_type=SearchType.SUMMARIES, + query_text="Germany", + save_interaction=False, + ) + completion_rag = await cognee.search( + query_type=SearchType.RAG_COMPLETION, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) + completion_temporal = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) await cognee.search( query_type=SearchType.FEEDBACK, @@ -171,134 +241,217 @@ async def main(): last_k=1, ) - for name, search_results in [ - ("GRAPH_COMPLETION", completion_gk), - ("GRAPH_COMPLETION_COT", completion_cot), - ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), - ("GRAPH_SUMMARY_COMPLETION", completion_sum), - ("TRIPLET_COMPLETION", completion_triplet), - ]: - assert isinstance(search_results, list), f"{name}: should return a list" - assert len(search_results) == 1, ( - f"{name}: expected single-element list, got {len(search_results)}" - ) + # Snapshot after all E2E operations above (used by assertion-only tests). 
+ graph_snapshot = await (await get_graph_engine()).get_graph_data() - from cognee.context_global_variables import backend_access_control_enabled + return { + "graph_edges": edges, + "triplet_collection": collection, + "vector_collection_edges_count": len(collection), + "graph_edges_count": len(edges), + "contexts": contexts, + "triplets": triplets, + "search_results": { + "graph_completion": completion_gk, + "graph_completion_cot": completion_cot, + "graph_completion_context_extension": completion_ext, + "graph_summary_completion": completion_sum, + "triplet_completion": completion_triplet, + "chunks": completion_chunks, + "summaries": completion_summaries, + "rag_completion": completion_rag, + "temporal": completion_temporal, + }, + "graph_snapshot": graph_snapshot, + } - if backend_access_control_enabled(): - text = search_results[0]["search_result"][0] - else: - text = search_results[0] - assert isinstance(text, str), f"{name}: element should be a string" - assert text.strip(), f"{name}: string should not be empty" - assert "netherlands" in text.lower(), ( - f"{name}: expected 'netherlands' in result, got: {text!r}" - ) - graph_engine = await get_graph_engine() - graph = await graph_engine.get_graph_data() - - type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) - - edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) - - # Assert there are exactly 4 CogneeUserInteraction nodes. - assert type_counts.get("CogneeUserInteraction", 0) == 4, ( - f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}" - ) - - # Assert there is exactly two CogneeUserFeedback nodes. - assert type_counts.get("CogneeUserFeedback", 0) == 2, ( - f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}" - ) - - # Assert there is exactly two NodeSet. 
- assert type_counts.get("NodeSet", 0) == 2, ( - f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}" - ) - - # Assert that there are at least 10 'used_graph_element_to_answer' edges. - assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, ( - f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}" - ) - - # Assert that there are exactly 2 'gives_feedback_to' edges. - assert edge_type_counts.get("gives_feedback_to", 0) == 2, ( - f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}" - ) - - # Assert that there are at least 6 'belongs_to_set' edges. - assert edge_type_counts.get("belongs_to_set", 0) == 6, ( - f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}" - ) - - nodes = graph[0] - - required_fields_user_interaction = {"question", "answer", "context"} - required_fields_feedback = {"feedback", "sentiment"} - - for node_id, data in nodes: - if data.get("type") == "CogneeUserInteraction": - assert required_fields_user_interaction.issubset(data.keys()), ( - f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}" - ) - - for field in required_fields_user_interaction: - value = data[field] - assert isinstance(value, str) and value.strip(), ( - f"Node {node_id} has invalid value for '{field}': {value!r}" - ) - - if data.get("type") == "CogneeUserFeedback": - assert required_fields_feedback.issubset(data.keys()), ( - f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}" - ) - - for field in required_fields_feedback: - value = data[field] - assert isinstance(value, str) and value.strip(), ( - f"Node {node_id} has invalid value for '{field}': {value!r}" - ) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - await cognee.add(text_1, dataset_name) - - 
await cognee.add([text], dataset_name) - - await cognee.cognify([dataset_name]) +@pytest_asyncio.fixture(scope="session") +async def feedback_state(): + """Feedback-weight scenario computed once (fresh environment).""" + await setup_test_environment_for_feedback() await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Next to which country is Germany located?", save_interaction=True, ) - await cognee.search( query_type=SearchType.FEEDBACK, query_text="This was the best answer I've ever seen", last_k=1, ) - await cognee.search( query_type=SearchType.FEEDBACK, query_text="Wow the correctness of this answer blows my mind", last_k=1, ) + graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() + return {"graph_snapshot": graph} - edges = graph[1] - for from_node, to_node, relationship_name, properties in edges: +@pytest.mark.asyncio +async def test_e2e_graph_vector_consistency(e2e_state): + """Graph and vector stores contain the same triplet edges.""" + assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"] + + +@pytest.mark.asyncio +async def test_e2e_retriever_contexts(e2e_state): + """All retrievers return non-empty, well-typed contexts.""" + contexts = e2e_state["contexts"] + + for name in [ + "graph_completion", + "graph_completion_cot", + "graph_completion_context_extension", + "graph_summary_completion", + ]: + ctx = contexts[name] + assert isinstance(ctx, list), f"{name}: Context should be a list" + assert ctx, f"{name}: Context should not be empty" + ctx_text = await resolve_edges_to_text(ctx) + lower = ctx_text.lower() + assert "germany" in lower or "netherlands" in lower, ( + f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}" + ) + + triplet_ctx = contexts["triplet"] + assert isinstance(triplet_ctx, str), "triplet: Context should be a string" + assert triplet_ctx.strip(), "triplet: Context should not be empty" + + chunks_ctx = contexts["chunks"] + assert 
isinstance(chunks_ctx, list), "chunks: Context should be a list" + assert chunks_ctx, "chunks: Context should not be empty" + chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower() + assert "germany" in chunks_text or "netherlands" in chunks_text + + summaries_ctx = contexts["summaries"] + assert isinstance(summaries_ctx, list), "summaries: Context should be a list" + assert summaries_ctx, "summaries: Context should not be empty" + assert any(str(item.get("text", "")).strip() for item in summaries_ctx) + + rag_ctx = contexts["rag_completion"] + assert isinstance(rag_ctx, str), "rag_completion: Context should be a string" + assert rag_ctx.strip(), "rag_completion: Context should not be empty" + + temporal_ctx = contexts["temporal"] + assert isinstance(temporal_ctx, str), "temporal: Context should be a string" + assert temporal_ctx.strip(), "temporal: Context should not be empty" + + +@pytest.mark.asyncio +async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): + """Graph retriever triplets include sane vector_distance metadata.""" + for name, triplets in e2e_state["triplets"].items(): + assert isinstance(triplets, list), f"{name}: Triplets should be a list" + assert triplets, f"{name}: Triplets list should not be empty" + for edge in triplets: + assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" + distance = edge.attributes.get("vector_distance") + node1_distance = edge.node1.attributes.get("vector_distance") + node2_distance = edge.node2.attributes.get("vector_distance") + assert isinstance(distance, float), f"{name}: vector_distance should be float" + assert 0 <= distance <= 1 + assert 0 <= node1_distance <= 1 + assert 0 <= node2_distance <= 1 + + +@pytest.mark.asyncio +async def test_e2e_search_results_and_wrappers(e2e_state): + """Search returns expected shapes across search types and access modes.""" + from cognee.context_global_variables import backend_access_control_enabled + + sr = 
e2e_state["search_results"] + + # Completion-like search types: validate wrapper + content + for name in [ + "graph_completion", + "graph_completion_cot", + "graph_completion_context_extension", + "graph_summary_completion", + "triplet_completion", + "rag_completion", + "temporal", + ]: + search_results = sr[name] + assert isinstance(search_results, list), f"{name}: should return a list" + assert len(search_results) == 1, f"{name}: expected single-element list" + + if backend_access_control_enabled(): + wrapper = search_results[0] + assert isinstance(wrapper, dict), ( + f"{name}: expected wrapper dict in access control mode" + ) + assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper" + assert wrapper.get("dataset_name") == "test_dataset" + assert "graphs" in wrapper + text = wrapper["search_result"][0] + else: + text = search_results[0] + + assert isinstance(text, str) and text.strip() + assert "netherlands" in text.lower() + + # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text + for name in ["chunks", "summaries"]: + search_results = sr[name] + assert isinstance(search_results, list), f"{name}: should return a list" + assert search_results, f"{name}: should not be empty" + + first = search_results[0] + assert isinstance(first, dict), f"{name}: expected dict entries" + + payloads = search_results + if "search_result" in first and "text" not in first: + payloads = (first.get("search_result") or [None])[0] + + assert isinstance(payloads, list) and payloads + assert isinstance(payloads[0], dict) + assert str(payloads[0].get("text", "")).strip() + + +@pytest.mark.asyncio +async def test_e2e_graph_side_effects_and_node_fields(e2e_state): + """Search interactions create expected graph nodes/edges and required fields.""" + graph = e2e_state["graph_snapshot"] + nodes, edges = graph + + type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes) + edge_type_counts = Counter(edge_type[2] for edge_type in edges) + + 
assert type_counts.get("CogneeUserInteraction", 0) == 4 + assert type_counts.get("CogneeUserFeedback", 0) == 2 + assert type_counts.get("NodeSet", 0) == 2 + assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10 + assert edge_type_counts.get("gives_feedback_to", 0) == 2 + assert edge_type_counts.get("belongs_to_set", 0) >= 6 + + required_fields_user_interaction = {"question", "answer", "context"} + required_fields_feedback = {"feedback", "sentiment"} + + for node_id, data in nodes: + if data.get("type") == "CogneeUserInteraction": + assert required_fields_user_interaction.issubset(data.keys()) + for field in required_fields_user_interaction: + value = data[field] + assert isinstance(value, str) and value.strip() + + if data.get("type") == "CogneeUserFeedback": + assert required_fields_feedback.issubset(data.keys()) + for field in required_fields_feedback: + value = data[field] + assert isinstance(value, str) and value.strip() + + +@pytest.mark.asyncio +async def test_e2e_feedback_weight_calculation(feedback_state): + """Positive feedback increases used_graph_element_to_answer feedback_weight.""" + _nodes, edges = feedback_state["graph_snapshot"] + for _from_node, _to_node, relationship_name, properties in edges: if relationship_name == "used_graph_element_to_answer": assert properties["feedback_weight"] >= 6, ( "Feedback weight calculation is not correct, it should be more then 6." ) - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main())