Merge branch 'dev' into feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-2

This commit is contained in:
hajdul88 2025-12-16 09:55:27 +01:00 committed by GitHub
commit de525a6324
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 771 additions and 301 deletions

View file

@ -237,6 +237,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_dataset_database_handler.py run: uv run python ./cognee/tests/test_dataset_database_handler.py
# Integration test: deleting a dataset must also clean up its per-dataset databases.
test-dataset-database-deletion:
  name: Test dataset database deletion in Cognee
  runs-on: ubuntu-22.04
  steps:
    - name: Check out repository
      uses: actions/checkout@v4

    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'

    - name: Run dataset databases deletion test
      env:
        ENV: 'dev'
        LLM_MODEL: ${{ secrets.LLM_MODEL }}
        LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
        LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
        LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
        EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
        EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
        EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
        EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
      run: uv run python ./cognee/tests/test_dataset_delete.py
test-permissions: test-permissions:
name: Test permissions with different situations in Cognee name: Test permissions with different situations in Cognee
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04

154
.github/workflows/release.yml vendored Normal file
View file

@ -0,0 +1,154 @@
---
# Manual release workflow: tags the repository, creates a GitHub Release, and
# publishes to (Test)PyPI and Docker Hub. `test_mode` is a dry run that must
# not touch public indices or repositories.
name: release.yml

on:
  workflow_dispatch:
    inputs:
      flavour:
        required: true
        default: dev
        type: choice
        options:
          - dev
          - main
        description: Dev or Main release
      test_mode:
        required: true
        type: boolean
        description: Aka Dry Run. If true, it won't affect public indices or repositories

jobs:
  release-github:
    name: Create GitHub Release from ${{ inputs.flavour }}
    outputs:
      tag: ${{ steps.create_tag.outputs.tag }}
      version: ${{ steps.create_tag.outputs.version }}
    permissions:
      contents: write
    runs-on: ubuntu-latest
    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Install uv
        uses: astral-sh/setup-uv@v7

      # Derives the version from the project metadata and publishes it as job
      # outputs even in test mode, so downstream jobs can still build/tag images.
      - name: Create and push git tag
        id: create_tag
        env:
          TEST_MODE: ${{ inputs.test_mode }}
        run: |
          VERSION="$(uv version --short)"
          TAG="v${VERSION}"
          echo "Tag to create: ${TAG}"
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
          echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
          if [ "$TEST_MODE" = "false" ]; then
            git tag "${TAG}"
            git push origin "${TAG}"
          else
            echo "Test mode is enabled. Skipping tag creation and push."
          fi

      # BUGFIX: skip this step in test mode. action-gh-release creates the tag
      # (and a public release) when tag_name does not exist yet, so running it
      # during a dry run would mutate the public repository.
      - name: Create GitHub Release
        if: ${{ !inputs.test_mode }}
        uses: softprops/action-gh-release@v2
        with:
          tag_name: ${{ steps.create_tag.outputs.tag }}
          generate_release_notes: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

  release-pypi-package:
    needs: release-github
    name: Release PyPI Package from ${{ inputs.flavour }}
    permissions:
      contents: read
    runs-on: ubuntu-latest
    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Install uv
        uses: astral-sh/setup-uv@v7

      - name: Install Python
        run: uv python install

      - name: Install dependencies
        run: uv sync --locked --all-extras

      - name: Build distributions
        run: uv build

      # Dry runs go to TestPyPI; real releases go to PyPI. The two steps are
      # mutually exclusive via test_mode.
      - name: Publish ${{ inputs.flavour }} release to TestPyPI
        if: ${{ inputs.test_mode }}
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }}
        run: uv publish --publish-url https://test.pypi.org/legacy/

      - name: Publish ${{ inputs.flavour }} release to PyPI
        if: ${{ !inputs.test_mode }}
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
        run: uv publish

  release-docker-image:
    needs: release-github
    name: Release Docker Image from ${{ inputs.flavour }}
    permissions:
      contents: read
    runs-on: ubuntu-latest
    steps:
      - name: Check out ${{ inputs.flavour }}
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.flavour }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      # Dev releases push only the version tag; main releases also move `latest`.
      # In test mode the image is built but never pushed (push: false).
      - name: Build and push Dev Docker Image
        if: ${{ inputs.flavour == 'dev' }}
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: ${{ !inputs.test_mode }}
          tags: cognee/cognee:${{ needs.release-github.outputs.version }}
          labels: |
            version=${{ needs.release-github.outputs.version }}
            flavour=${{ inputs.flavour }}
          cache-from: type=registry,ref=cognee/cognee:buildcache
          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max

      - name: Build and push Main Docker Image
        if: ${{ inputs.flavour == 'main' }}
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: ${{ !inputs.test_mode }}
          tags: |
            cognee/cognee:${{ needs.release-github.outputs.version }}
            cognee/cognee:latest
          labels: |
            version=${{ needs.release-github.outputs.version }}
            flavour=${{ inputs.flavour }}
          cache-from: type=registry,ref=cognee/cognee:buildcache
          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max

View file

@ -84,3 +84,93 @@ jobs:
EMBEDDING_DIMENSIONS: "3072" EMBEDDING_DIMENSIONS: "3072"
EMBEDDING_MAX_TOKENS: "8191" EMBEDDING_MAX_TOKENS: "8191"
run: uv run python ./examples/python/simple_example.py run: uv run python ./examples/python/simple_example.py
# Bedrock auth path 1 of 3: bearer-token API key for both LLM and embeddings.
test-bedrock-api-key:
  name: Run Bedrock API Key Test
  runs-on: ubuntu-22.04
  steps:
    - name: Check out repository
      uses: actions/checkout@v4

    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'
        extra-dependencies: "aws"

    - name: Run Bedrock API Key Simple Example
      env:
        LLM_PROVIDER: "bedrock"
        LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
        LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
        LLM_MAX_TOKENS: "16384"
        AWS_REGION_NAME: "eu-west-1"
        EMBEDDING_PROVIDER: "bedrock"
        EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
        EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
        EMBEDDING_DIMENSIONS: "1024"
        EMBEDDING_MAX_TOKENS: "8191"
      run: uv run python ./examples/python/simple_example.py
# Bedrock auth path 2 of 3: explicit AWS access key + secret for the LLM.
# NOTE(review): the embedding side still uses BEDROCK_API_KEY here, so this job
# mixes two auth methods — confirm that is intentional and not a copy-paste slip.
test-bedrock-aws-credentials:
  name: Run Bedrock AWS Credentials Test
  runs-on: ubuntu-22.04
  steps:
    - name: Check out repository
      uses: actions/checkout@v4

    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'
        extra-dependencies: "aws"

    - name: Run Bedrock AWS Credentials Simple Example
      env:
        LLM_PROVIDER: "bedrock"
        LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
        LLM_MAX_TOKENS: "16384"
        AWS_REGION_NAME: "eu-west-1"
        AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
        AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        EMBEDDING_PROVIDER: "bedrock"
        EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
        EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
        EMBEDDING_DIMENSIONS: "1024"
        EMBEDDING_MAX_TOKENS: "8191"
      run: uv run python ./examples/python/simple_example.py
# Bedrock auth path 3 of 3: a named AWS profile resolved via the boto3
# credential chain. The profile is written to the runner's ~/.aws/credentials.
test-bedrock-aws-profile:
  name: Run Bedrock AWS Profile Test
  runs-on: ubuntu-22.04
  steps:
    - name: Check out repository
      uses: actions/checkout@v4

    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'
        extra-dependencies: "aws"

    - name: Configure AWS Profile
      run: |
        mkdir -p ~/.aws
        cat > ~/.aws/credentials << EOF
        [bedrock-test]
        aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
        aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        EOF

    - name: Run Bedrock AWS Profile Simple Example
      env:
        LLM_PROVIDER: "bedrock"
        LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
        LLM_MAX_TOKENS: "16384"
        AWS_PROFILE_NAME: "bedrock-test"
        AWS_REGION_NAME: "eu-west-1"
        EMBEDDING_PROVIDER: "bedrock"
        EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
        EMBEDDING_DIMENSIONS: "1024"
        EMBEDDING_MAX_TOKENS: "8191"
      run: uv run python ./examples/python/simple_example.py

View file

@ -3,7 +3,7 @@
Test client for Cognee MCP Server functionality. Test client for Cognee MCP Server functionality.
This script tests all the tools and functions available in the Cognee MCP server, This script tests all the tools and functions available in the Cognee MCP server,
including cognify, codify, search, prune, status checks, and utility functions. including cognify, search, prune, status checks, and utility functions.
Usage: Usage:
# Set your OpenAI API key first # Set your OpenAI API key first
@ -23,6 +23,7 @@ import tempfile
import time import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from cognee.shared.logging_utils import setup_logging from cognee.shared.logging_utils import setup_logging
from logging import ERROR, INFO
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
@ -35,7 +36,7 @@ from src.server import (
load_class, load_class,
) )
# Set timeout for cognify/codify to complete in # Set timeout for cognify to complete in
TIMEOUT = 5 * 60 # 5 min in seconds TIMEOUT = 5 * 60 # 5 min in seconds
@ -151,12 +152,9 @@ DEBUG = True
expected_tools = { expected_tools = {
"cognify", "cognify",
"codify",
"search", "search",
"prune", "prune",
"cognify_status", "cognify_status",
"codify_status",
"cognee_add_developer_rules",
"list_data", "list_data",
"delete", "delete",
} }
@ -247,106 +245,6 @@ DEBUG = True
} }
print(f"{test_name} test failed: {e}") print(f"{test_name} test failed: {e}")
async def test_codify(self):
"""Test the codify functionality using MCP client."""
print("\n🧪 Testing codify functionality...")
try:
async with self.mcp_server_session() as session:
codify_result = await session.call_tool(
"codify", arguments={"repo_path": self.test_repo_dir}
)
start = time.time() # mark the start
while True:
try:
# Wait a moment
await asyncio.sleep(5)
# Check if codify processing is finished
status_result = await session.call_tool("codify_status", arguments={})
if hasattr(status_result, "content") and status_result.content:
status_text = (
status_result.content[0].text
if status_result.content
else str(status_result)
)
else:
status_text = str(status_result)
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
break
elif time.time() - start > TIMEOUT:
raise TimeoutError("Codify did not complete in 5min")
except DatabaseNotCreatedError:
if time.time() - start > TIMEOUT:
raise TimeoutError("Database was not created in 5min")
self.test_results["codify"] = {
"status": "PASS",
"result": codify_result,
"message": "Codify executed successfully",
}
print("✅ Codify test passed")
except Exception as e:
self.test_results["codify"] = {
"status": "FAIL",
"error": str(e),
"message": "Codify test failed",
}
print(f"❌ Codify test failed: {e}")
async def test_cognee_add_developer_rules(self):
"""Test the cognee_add_developer_rules functionality using MCP client."""
print("\n🧪 Testing cognee_add_developer_rules functionality...")
try:
async with self.mcp_server_session() as session:
result = await session.call_tool(
"cognee_add_developer_rules", arguments={"base_path": self.test_data_dir}
)
start = time.time() # mark the start
while True:
try:
# Wait a moment
await asyncio.sleep(5)
# Check if developer rule cognify processing is finished
status_result = await session.call_tool("cognify_status", arguments={})
if hasattr(status_result, "content") and status_result.content:
status_text = (
status_result.content[0].text
if status_result.content
else str(status_result)
)
else:
status_text = str(status_result)
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_text:
break
elif time.time() - start > TIMEOUT:
raise TimeoutError(
"Cognify of developer rules did not complete in 5min"
)
except DatabaseNotCreatedError:
if time.time() - start > TIMEOUT:
raise TimeoutError("Database was not created in 5min")
self.test_results["cognee_add_developer_rules"] = {
"status": "PASS",
"result": result,
"message": "Developer rules addition executed successfully",
}
print("✅ Developer rules test passed")
except Exception as e:
self.test_results["cognee_add_developer_rules"] = {
"status": "FAIL",
"error": str(e),
"message": "Developer rules test failed",
}
print(f"❌ Developer rules test failed: {e}")
async def test_search_functionality(self): async def test_search_functionality(self):
"""Test the search functionality with different search types using MCP client.""" """Test the search functionality with different search types using MCP client."""
print("\n🧪 Testing search functionality...") print("\n🧪 Testing search functionality...")
@ -359,7 +257,11 @@ DEBUG = True
# Go through all Cognee search types # Go through all Cognee search types
for search_type in SearchType: for search_type in SearchType:
# Don't test these search types # Don't test these search types
if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]: if search_type in [
SearchType.NATURAL_LANGUAGE,
SearchType.CYPHER,
SearchType.TRIPLET_COMPLETION,
]:
break break
try: try:
async with self.mcp_server_session() as session: async with self.mcp_server_session() as session:
@ -681,9 +583,6 @@ class TestModel:
test_name="Cognify2", test_name="Cognify2",
) )
await self.test_codify()
await self.test_cognee_add_developer_rules()
# Test list_data and delete functionality # Test list_data and delete functionality
await self.test_list_data() await self.test_list_data()
await self.test_delete() await self.test_delete()
@ -739,7 +638,5 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
from logging import ERROR
logger = setup_logging(log_level=ERROR) logger = setup_logging(log_level=ERROR)
asyncio.run(main()) asyncio.run(main())

View file

@ -155,7 +155,7 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.) - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional: Optional:
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral" - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
- LLM_MODEL: Model name (default: "gpt-5-mini") - LLM_MODEL: Model name (default: "gpt-5-mini")
- DEFAULT_USER_EMAIL: Custom default user email - DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password - DEFAULT_USER_PASSWORD: Custom default user password

View file

@ -53,6 +53,7 @@ async def cognify(
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
temporal_cognify: bool = False, temporal_cognify: bool = False,
data_per_batch: int = 20, data_per_batch: int = 20,
**kwargs,
): ):
""" """
Transform ingested data into a structured knowledge graph. Transform ingested data into a structured knowledge graph.
@ -223,6 +224,7 @@ async def cognify(
config=config, config=config,
custom_prompt=custom_prompt, custom_prompt=custom_prompt,
chunks_per_batch=chunks_per_batch, chunks_per_batch=chunks_per_batch,
**kwargs,
) )
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
@ -251,6 +253,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config: Config = None, config: Config = None,
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
chunks_per_batch: int = 100, chunks_per_batch: int = 100,
**kwargs,
) -> list[Task]: ) -> list[Task]:
if config is None: if config is None:
ontology_config = get_ontology_env_config() ontology_config = get_ontology_env_config()
@ -288,6 +291,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
config=config, config=config,
custom_prompt=custom_prompt, custom_prompt=custom_prompt,
task_config={"batch_size": chunks_per_batch}, task_config={"batch_size": chunks_per_batch},
**kwargs,
), # Generate knowledge graphs from the document chunks. ), # Generate knowledge graphs from the document chunks.
Task( Task(
summarize_text, summarize_text,

View file

@ -42,7 +42,9 @@ class CognifyPayloadDTO(InDTO):
default="", description="Custom prompt for entity extraction and graph generation" default="", description="Custom prompt for entity extraction and graph generation"
) )
ontology_key: Optional[List[str]] = Field( ontology_key: Optional[List[str]] = Field(
default=None, description="Reference to one or more previously uploaded ontologies" default=None,
examples=[[]],
description="Reference to one or more previously uploaded ontologies",
) )

View file

@ -208,14 +208,14 @@ def get_datasets_router() -> APIRouter:
}, },
) )
from cognee.modules.data.methods import get_dataset, delete_dataset from cognee.modules.data.methods import delete_dataset
dataset = await get_dataset(user.id, dataset_id) dataset = await get_authorized_existing_datasets([dataset_id], "delete", user)
if dataset is None: if dataset is None:
raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.") raise DatasetNotFoundError(message=f"Dataset ({str(dataset_id)}) not found.")
await delete_dataset(dataset) await delete_dataset(dataset[0])
@router.delete( @router.delete(
"/{dataset_id}/data/{data_id}", "/{dataset_id}/data/{data_id}",

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException from fastapi import APIRouter, File, Form, UploadFile, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import Optional, List from typing import Optional, List
@ -15,28 +15,25 @@ def get_ontology_router() -> APIRouter:
@router.post("", response_model=dict) @router.post("", response_model=dict)
async def upload_ontology( async def upload_ontology(
request: Request,
ontology_key: str = Form(...), ontology_key: str = Form(...),
ontology_file: List[UploadFile] = File(...), ontology_file: UploadFile = File(...),
descriptions: Optional[str] = Form(None), description: Optional[str] = Form(None),
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
): ):
""" """
Upload ontology files with their respective keys for later use in cognify operations. Upload a single ontology file for later use in cognify operations.
Supports both single and multiple file uploads:
- Single file: ontology_key=["key"], ontology_file=[file]
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
## Request Parameters ## Request Parameters
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies - **ontology_key** (str): User-defined identifier for the ontology.
- **ontology_file** (List[UploadFile]): OWL format ontology files - **ontology_file** (UploadFile): Single OWL format ontology file
- **descriptions** (Optional[str]): JSON array string of optional descriptions - **description** (Optional[str]): Optional description for the ontology.
## Response ## Response
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps. Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp.
## Error Codes ## Error Codes
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded - **400 Bad Request**: Invalid file format, duplicate key, multiple files uploaded
- **500 Internal Server Error**: File system or processing errors - **500 Internal Server Error**: File system or processing errors
""" """
send_telemetry( send_telemetry(
@ -49,16 +46,22 @@ def get_ontology_router() -> APIRouter:
) )
try: try:
import json # Enforce: exactly one uploaded file for "ontology_file"
form = await request.form()
uploaded_files = form.getlist("ontology_file")
if len(uploaded_files) != 1:
raise ValueError("Only one ontology_file is allowed")
ontology_keys = json.loads(ontology_key) if ontology_key.strip().startswith(("[", "{")):
description_list = json.loads(descriptions) if descriptions else None raise ValueError("ontology_key must be a string")
if description is not None and description.strip().startswith(("[", "{")):
raise ValueError("description must be a string")
if not isinstance(ontology_keys, list): result = await ontology_service.upload_ontology(
raise ValueError("ontology_key must be a JSON array") ontology_key=ontology_key,
file=ontology_file,
results = await ontology_service.upload_ontologies( user=user,
ontology_keys, ontology_file, user, description_list description=description,
) )
return { return {
@ -70,10 +73,9 @@ def get_ontology_router() -> APIRouter:
"uploaded_at": result.uploaded_at, "uploaded_at": result.uploaded_at,
"description": result.description, "description": result.description,
} }
for result in results
] ]
} }
except (json.JSONDecodeError, ValueError) as e: except ValueError as e:
return JSONResponse(status_code=400, content={"error": str(e)}) return JSONResponse(status_code=400, content={"error": str(e)})
except Exception as e: except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)}) return JSONResponse(status_code=500, content={"error": str(e)})

View file

@ -1,2 +1,4 @@
from .get_or_create_dataset_database import get_or_create_dataset_database from .get_or_create_dataset_database import get_or_create_dataset_database
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
from .get_graph_dataset_database_handler import get_graph_dataset_database_handler
from .get_vector_dataset_database_handler import get_vector_dataset_database_handler

View file

@ -0,0 +1,10 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
    """Return the registered handler entry for this dataset's graph database.

    Looks up ``dataset_database.graph_dataset_database_handler`` in the global
    registry of supported handlers; raises ``KeyError`` for an unknown handler name.
    """
    # Imported locally (presumably to avoid a circular import at module load —
    # TODO confirm) rather than at module scope.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    return supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]

View file

@ -0,0 +1,10 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
    """Return the registered handler entry for this dataset's vector database.

    Looks up ``dataset_database.vector_dataset_database_handler`` in the global
    registry of supported handlers; raises ``KeyError`` for an unknown handler name.
    """
    # Imported locally (presumably to avoid a circular import at module load —
    # TODO confirm) rather than at module scope.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    return supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]

View file

@ -1,24 +1,12 @@
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
get_graph_dataset_database_handler,
)
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
get_vector_dataset_database_handler,
)
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def resolve_dataset_database_connection_info( async def resolve_dataset_database_connection_info(
dataset_database: DatasetDatabase, dataset_database: DatasetDatabase,
) -> DatasetDatabase: ) -> DatasetDatabase:
@ -31,6 +19,12 @@ async def resolve_dataset_database_connection_info(
Returns: Returns:
DatasetDatabase instance with resolved connection info DatasetDatabase instance with resolved connection info
""" """
dataset_database = await _get_vector_db_connection_info(dataset_database) vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
dataset_database = await _get_graph_db_connection_info(dataset_database) graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
dataset_database = await vector_dataset_database_handler[
"handler_instance"
].resolve_dataset_connection_info(dataset_database)
dataset_database = await graph_dataset_database_handler[
"handler_instance"
].resolve_dataset_connection_info(dataset_database)
return dataset_database return dataset_database

View file

@ -9,6 +9,8 @@ class S3Config(BaseSettings):
aws_access_key_id: Optional[str] = None aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None aws_session_token: Optional[str] = None
aws_profile_name: Optional[str] = None
aws_bedrock_runtime_endpoint: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -11,7 +11,7 @@ class LLMGateway:
@staticmethod @staticmethod
def acreate_structured_output( def acreate_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel] text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> Coroutine: ) -> Coroutine:
llm_config = get_llm_config() llm_config = get_llm_config()
if llm_config.structured_output_framework.upper() == "BAML": if llm_config.structured_output_framework.upper() == "BAML":
@ -31,7 +31,10 @@ class LLMGateway:
llm_client = get_llm_client() llm_client = get_llm_client()
return llm_client.acreate_structured_output( return llm_client.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model text_input=text_input,
system_prompt=system_prompt,
response_model=response_model,
**kwargs,
) )
@staticmethod @staticmethod

View file

@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
async def extract_content_graph( async def extract_content_graph(
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
): ):
if custom_prompt: if custom_prompt:
system_prompt = custom_prompt system_prompt = custom_prompt
@ -30,7 +30,7 @@ async def extract_content_graph(
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory) system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await LLMGateway.acreate_structured_output( content_graph = await LLMGateway.acreate_structured_output(
content, system_prompt, response_model content, system_prompt, response_model, **kwargs
) )
return content_graph return content_graph

View file

@ -52,7 +52,7 @@ class AnthropicAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.

View file

@ -0,0 +1,5 @@
"""Bedrock LLM adapter module."""
from .adapter import BedrockAdapter
__all__ = ["BedrockAdapter"]

View file

@ -0,0 +1,153 @@
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
MissingSystemPromptPathError,
)
from cognee.infrastructure.files.storage.s3_config import get_s3_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
sleep_and_retry_sync,
)
from cognee.modules.observability.get_observe import get_observe
observe = get_observe()
class BedrockAdapter(LLMInterface):
"""
Adapter for AWS Bedrock API with support for three authentication methods:
1. API Key (Bearer Token)
2. AWS Credentials (access key + secret key)
3. AWS Profile (boto3 credential chain)
"""
name = "Bedrock"
model: str
api_key: str
default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
def __init__(
self,
model: str,
api_key: str = None,
max_completion_tokens: int = 16384,
streaming: bool = False,
instructor_mode: str = None,
):
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(litellm.completion)
self.model = model
self.api_key = api_key
self.max_completion_tokens = max_completion_tokens
self.streaming = streaming
def _create_bedrock_request(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> dict:
"""Create Bedrock request with authentication."""
request_params = {
"model": self.model,
"custom_llm_provider": "bedrock",
"drop_params": True,
"messages": [
{"role": "user", "content": text_input},
{"role": "system", "content": system_prompt},
],
"response_model": response_model,
"max_retries": self.MAX_RETRIES,
"max_completion_tokens": self.max_completion_tokens,
"stream": self.streaming,
}
s3_config = get_s3_config()
# Add authentication parameters
if self.api_key:
request_params["api_key"] = self.api_key
elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
request_params["aws_access_key_id"] = s3_config.aws_access_key_id
request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
if s3_config.aws_session_token:
request_params["aws_session_token"] = s3_config.aws_session_token
elif s3_config.aws_profile_name:
request_params["aws_profile_name"] = s3_config.aws_profile_name
if s3_config.aws_region:
request_params["aws_region_name"] = s3_config.aws_region
# Add optional parameters
if s3_config.aws_bedrock_runtime_endpoint:
request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
return request_params
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""Generate structured output from AWS Bedrock API."""
try:
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
return await self.aclient.chat.completions.create(**request_params)
except (
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
@observe
@sleep_and_retry_sync()
@rate_limit_sync
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""Generate structured output from AWS Bedrock API (synchronous)."""
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
return self.client.chat.completions.create(**request_params)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
    """Format and display the prompt for a user query.

    Resolves ``system_prompt`` via ``LLMGateway.read_query_prompt`` and
    combines it with ``text_input`` into a single display string.

    Raises
    ------
    MissingSystemPromptPathError
        If no system prompt path is provided.
    """
    if not text_input:
        text_input = "No user input provided."
    if not system_prompt:
        raise MissingSystemPromptPathError()
    resolved_prompt = LLMGateway.read_query_prompt(system_prompt)
    # Mirrors the original behavior: an unresolved (falsy) prompt yields None.
    if not resolved_prompt:
        return None
    return f"""System Prompt:\n{resolved_prompt}\n\nUser Input:\n{text_input}\n"""

View file

@ -80,7 +80,7 @@ class GeminiAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.

View file

@ -80,7 +80,7 @@ class GenericAPIAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.

View file

@ -24,6 +24,7 @@ class LLMProvider(Enum):
- CUSTOM: Represents a custom provider option. - CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider. - GEMINI: Represents the Gemini provider.
- MISTRAL: Represents the Mistral AI provider. - MISTRAL: Represents the Mistral AI provider.
- BEDROCK: Represents the AWS Bedrock provider.
""" """
OPENAI = "openai" OPENAI = "openai"
@ -32,6 +33,7 @@ class LLMProvider(Enum):
CUSTOM = "custom" CUSTOM = "custom"
GEMINI = "gemini" GEMINI = "gemini"
MISTRAL = "mistral" MISTRAL = "mistral"
BEDROCK = "bedrock"
def get_llm_client(raise_api_key_error: bool = True): def get_llm_client(raise_api_key_error: bool = True):
@ -154,7 +156,7 @@ def get_llm_client(raise_api_key_error: bool = True):
) )
elif provider == LLMProvider.MISTRAL: elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None and raise_api_key_error:
raise LLMAPIKeyNotSetError() raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
@ -169,5 +171,21 @@ def get_llm_client(raise_api_key_error: bool = True):
instructor_mode=llm_config.llm_instructor_mode.lower(), instructor_mode=llm_config.llm_instructor_mode.lower(),
) )
elif provider == LLMProvider.BEDROCK:
# if llm_config.llm_api_key is None and raise_api_key_error:
# raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
BedrockAdapter,
)
return BedrockAdapter(
model=llm_config.llm_model,
api_key=llm_config.llm_api_key,
max_completion_tokens=max_completion_tokens,
streaming=llm_config.llm_streaming,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
else: else:
raise UnsupportedLLMProviderError(provider) raise UnsupportedLLMProviderError(provider)

View file

@ -69,7 +69,7 @@ class MistralAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from the user query. Generate a response from the user query.

View file

@ -76,7 +76,7 @@ class OllamaAPIAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a structured output from the LLM using the provided text and system prompt. Generate a structured output from the LLM using the provided text and system prompt.
@ -123,7 +123,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def create_transcript(self, input_file: str) -> str: async def create_transcript(self, input_file: str, **kwargs) -> str:
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@ -162,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def transcribe_image(self, input_file: str) -> str: async def transcribe_image(self, input_file: str, **kwargs) -> str:
""" """
Transcribe content from an image using base64 encoding. Transcribe content from an image using base64 encoding.

View file

@ -112,7 +112,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.
@ -154,6 +154,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version, api_version=self.api_version,
response_model=response_model, response_model=response_model,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
except ( except (
ContentFilterFinishReasonError, ContentFilterFinishReasonError,
@ -180,6 +181,7 @@ class OpenAIAdapter(LLMInterface):
# api_base=self.fallback_endpoint, # api_base=self.fallback_endpoint,
response_model=response_model, response_model=response_model,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
except ( except (
ContentFilterFinishReasonError, ContentFilterFinishReasonError,
@ -205,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True, reraise=True,
) )
def create_structured_output( def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.
@ -245,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version, api_version=self.api_version,
response_model=response_model, response_model=response_model,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
@retry( @retry(
@ -254,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def create_transcript(self, input): async def create_transcript(self, input, **kwargs):
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@ -281,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
api_base=self.endpoint, api_base=self.endpoint,
api_version=self.api_version, api_version=self.api_version,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
return transcription return transcription
@ -292,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def transcribe_image(self, input) -> BaseModel: async def transcribe_image(self, input, **kwargs) -> BaseModel:
""" """
Generate a transcription of an image from a user query. Generate a transcription of an image from a user query.
@ -337,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version, api_version=self.api_version,
max_completion_tokens=300, max_completion_tokens=300,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )

View file

@ -5,6 +5,10 @@ from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.utils import (
get_graph_dataset_database_handler,
get_vector_dataset_database_handler,
)
from cognee.shared.cache import delete_cache from cognee.shared.cache import delete_cache
from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import DatasetDatabase
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
@ -13,22 +17,13 @@ logger = get_logger()
async def prune_graph_databases(): async def prune_graph_databases():
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[
dataset_database.graph_dataset_database_handler
]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine() db_engine = get_relational_engine()
try: try:
data = await db_engine.get_all_data_from_table("dataset_database") dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the graph database # Go through each dataset database and delete the graph database
for data_item in data: for dataset_database in dataset_databases:
await _prune_graph_db(data_item) handler = get_graph_dataset_database_handler(dataset_database)
await handler["handler_instance"].delete_dataset(dataset_database)
except (OperationalError, EntityNotFoundError) as e: except (OperationalError, EntityNotFoundError) as e:
logger.debug( logger.debug(
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s", "Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
@ -38,22 +33,13 @@ async def prune_graph_databases():
async def prune_vector_databases(): async def prune_vector_databases():
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[
dataset_database.vector_dataset_database_handler
]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine() db_engine = get_relational_engine()
try: try:
data = await db_engine.get_all_data_from_table("dataset_database") dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the vector database # Go through each dataset database and delete the vector database
for data_item in data: for dataset_database in dataset_databases:
await _prune_vector_db(data_item) handler = get_vector_dataset_database_handler(dataset_database)
await handler["handler_instance"].delete_dataset(dataset_database)
except (OperationalError, EntityNotFoundError) as e: except (OperationalError, EntityNotFoundError) as e:
logger.debug( logger.debug(
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s", "Skipping pruning of vector DB. Error when accessing dataset_database table: %s",

View file

@ -1,8 +1,34 @@
from cognee.modules.users.models import DatasetDatabase
from sqlalchemy import select
from cognee.modules.data.models import Dataset from cognee.modules.data.models import Dataset
from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
get_vector_dataset_database_handler,
)
from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
get_graph_dataset_database_handler,
)
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
async def delete_dataset(dataset: Dataset): async def delete_dataset(dataset: Dataset):
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
stmt = select(DatasetDatabase).where(
DatasetDatabase.dataset_id == dataset.id,
)
dataset_database: DatasetDatabase = await session.scalar(stmt)
if dataset_database:
graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
await graph_dataset_database_handler["handler_instance"].delete_dataset(
dataset_database
)
await vector_dataset_database_handler["handler_instance"].delete_dataset(
dataset_database
)
# TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well
# This blocks recreation of the dataset with the same name and data after deletion as
# it's marked as completed and will be just skipped even though it's empty.
return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id) return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)

View file

@ -16,6 +16,7 @@ class ModelName(Enum):
anthropic = "anthropic" anthropic = "anthropic"
gemini = "gemini" gemini = "gemini"
mistral = "mistral" mistral = "mistral"
bedrock = "bedrock"
class LLMConfig(BaseModel): class LLMConfig(BaseModel):
@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
"value": "mistral", "value": "mistral",
"label": "Mistral", "label": "Mistral",
}, },
{
"value": "bedrock",
"label": "Bedrock",
},
] ]
return SettingsDict.model_validate( return SettingsDict.model_validate(
@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
"label": "Mistral Large 2.1", "label": "Mistral Large 2.1",
}, },
], ],
"bedrock": [
{
"value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
"label": "Claude 4.5 Sonnet",
},
{
"value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
"label": "Claude 4.5 Haiku",
},
{
"value": "eu.amazon.nova-lite-v1:0",
"label": "Amazon Nova Lite",
},
],
}, },
}, },
vector_db={ vector_db={

View file

@ -97,6 +97,7 @@ async def extract_graph_from_data(
graph_model: Type[BaseModel], graph_model: Type[BaseModel],
config: Config = None, config: Config = None,
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
**kwargs,
) -> List[DocumentChunk]: ) -> List[DocumentChunk]:
""" """
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model. Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@ -111,7 +112,7 @@ async def extract_graph_from_data(
chunk_graphs = await asyncio.gather( chunk_graphs = await asyncio.gather(
*[ *[
extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt) extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
for chunk in data_chunks for chunk in data_chunks
] ]
) )

View file

@ -1,5 +1,6 @@
from typing import AsyncGenerator, Dict, Any, List, Optional from typing import AsyncGenerator, Dict, Any, List, Optional
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.modules.engine.utils import generate_node_id
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -155,7 +156,12 @@ def _process_single_triplet(
embeddable_text = f"{start_node_text}-{relationship_text}-{end_node_text}".strip() embeddable_text = f"{start_node_text}-{relationship_text}-{end_node_text}".strip()
triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text) relationship_name = relationship.get("relationship_name", "")
triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
triplet_obj = Triplet(
id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
)
return triplet_obj, None return triplet_obj, None

View file

@ -148,8 +148,8 @@ class TestCogneeServerStart(unittest.TestCase):
headers=headers, headers=headers,
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={ data={
"ontology_key": json.dumps([ontology_key]), "ontology_key": ontology_key,
"description": json.dumps(["Test ontology"]), "description": "Test ontology",
}, },
) )
self.assertEqual(ontology_response.status_code, 200) self.assertEqual(ontology_response.status_code, 200)

View file

@ -0,0 +1,76 @@
import os
import asyncio
import pathlib
from uuid import UUID
import cognee
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.modules.data.methods.delete_dataset import delete_dataset
from cognee.modules.data.methods.get_dataset import get_dataset
from cognee.modules.users.methods import get_default_user
async def main():
    """End-to-end test: create datasets, cognify them, then verify that
    deleting a dataset also removes its per-dataset graph and vector
    database files on disk."""
    # Set data and system directory paths
    # Scope both the data and system roots to this test so it cannot
    # interfere with other test runs sharing the repository checkout.
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_dataset_delete")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_dataset_delete")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)
    # Create a clean slate for cognee -- reset data and system state
    print("Resetting cognee data...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    print("Data reset complete.\n")
    # cognee knowledge graph will be created based on this text
    text = """
    Natural language processing (NLP) is an interdisciplinary
    subfield of computer science and information retrieval.
    """
    # Add the text, and make it available for cognify
    await cognee.add(text, "nlp_dataset")
    await cognee.add("Quantum computing is the study of quantum computers.", "quantum_dataset")
    # Use LLMs and cognee to create knowledge graph
    ret_val = await cognee.cognify()
    user = await get_default_user()
    # assumes cognify() returns an iterable of dataset identifiers
    # convertible to UUID strings -- TODO confirm against cognify's contract
    for val in ret_val:
        dataset_id = str(val)
        # Expected on-disk layout: <system_root>/databases/<user_id>/<dataset_id>.lance.db
        # (vector) and .pkl (graph) -- presumably matches the default
        # LanceDB/pickle handlers; verify if handlers are reconfigured.
        vector_db_path = os.path.join(
            cognee_directory_path, "databases", str(user.id), dataset_id + ".lance.db"
        )
        graph_db_path = os.path.join(
            cognee_directory_path, "databases", str(user.id), dataset_id + ".pkl"
        )
        # Check if databases are properly created and exist before deletion
        assert os.path.exists(graph_db_path), "Graph database file not found."
        assert os.path.exists(vector_db_path), "Vector database file not found."
        dataset = await get_dataset(user_id=user.id, dataset_id=UUID(dataset_id))
        await delete_dataset(dataset)
        # Confirm databases have been deleted
        assert not os.path.exists(graph_db_path), "Graph database file found."
        assert not os.path.exists(vector_db_path), "Vector database file found."
if __name__ == "__main__":
    logger = setup_logging(log_level=ERROR)
    # Run the async test on a dedicated event loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        # Drain async generators before shutdown, then close the loop;
        # the original version leaked the loop by never closing it.
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()

View file

@ -1,17 +1,28 @@
import pytest import pytest
import uuid import uuid
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from unittest.mock import patch, Mock, AsyncMock from unittest.mock import Mock
from types import SimpleNamespace from types import SimpleNamespace
import importlib
from cognee.api.client import app from cognee.api.client import app
from cognee.modules.users.methods import get_authenticated_user
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@pytest.fixture(scope="session")
def test_client():
# Keep a single TestClient (and event loop) for the whole module.
# Re-creating TestClient repeatedly can break async DB connections (asyncpg loop mismatch).
with TestClient(app) as c:
yield c
@pytest.fixture @pytest.fixture
def client(): def client(test_client, mock_default_user):
return TestClient(app) async def override_get_authenticated_user():
return mock_default_user
app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
yield test_client
app.dependency_overrides.pop(get_authenticated_user, None)
@pytest.fixture @pytest.fixture
@ -32,12 +43,8 @@ def mock_default_user():
) )
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_success(client):
def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
"""Test successful ontology upload""" """Test successful ontology upload"""
import json
mock_get_default_user.return_value = mock_default_user
ontology_content = ( ontology_content = (
b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>" b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
) )
@ -46,7 +53,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
response = client.post( response = client.post(
"/api/v1/ontologies", "/api/v1/ontologies",
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])}, data={"ontology_key": unique_key, "description": "Test"},
) )
assert response.status_code == 200 assert response.status_code == 200
@ -55,10 +62,8 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use
assert "uploaded_at" in data["uploaded_ontologies"][0] assert "uploaded_at" in data["uploaded_ontologies"][0]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_invalid_file(client):
def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
"""Test 400 response for non-.owl files""" """Test 400 response for non-.owl files"""
mock_get_default_user.return_value = mock_default_user
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
response = client.post( response = client.post(
"/api/v1/ontologies", "/api/v1/ontologies",
@ -68,14 +73,10 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul
assert response.status_code == 400 assert response.status_code == 400
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_missing_data(client):
def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
"""Test 400 response for missing file or key""" """Test 400 response for missing file or key"""
import json
mock_get_default_user.return_value = mock_default_user
# Missing file # Missing file
response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])}) response = client.post("/api/v1/ontologies", data={"ontology_key": "test"})
assert response.status_code == 400 assert response.status_code == 400
# Missing key # Missing key
@ -85,34 +86,25 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul
assert response.status_code == 400 assert response.status_code == 400
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_without_auth_header(client):
def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): """Test behavior when no explicit authentication header is provided."""
"""Test behavior when default user is provided (no explicit authentication)"""
import json
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
mock_get_default_user.return_value = mock_default_user
response = client.post( response = client.post(
"/api/v1/ontologies", "/api/v1/ontologies",
files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))], files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))],
data={"ontology_key": json.dumps([unique_key])}, data={"ontology_key": unique_key},
) )
# The current system provides a default user when no explicit authentication is given
# This test verifies the system works with conditional authentication
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
assert "uploaded_at" in data["uploaded_ontologies"][0] assert "uploaded_at" in data["uploaded_ontologies"][0]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_multiple_ontologies_in_single_request_is_rejected(client):
def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): """Uploading multiple ontology files in a single request should fail."""
"""Test uploading multiple ontology files in single request"""
import io import io
mock_get_default_user.return_value = mock_default_user
# Create mock files
file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>" file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>" file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
@ -120,45 +112,34 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
] ]
data = { data = {"ontology_key": "vehicles", "description": "Base vehicles"}
"ontology_key": '["vehicles", "manufacturers"]',
"descriptions": '["Base vehicles", "Car manufacturers"]',
}
response = client.post("/api/v1/ontologies", files=files, data=data) response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 200 assert response.status_code == 400
result = response.json() assert "Only one ontology_file is allowed" in response.json()["error"]
assert "uploaded_ontologies" in result
assert len(result["uploaded_ontologies"]) == 2
assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_endpoint_rejects_array_style_fields(client):
def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): """Array-style form values should be rejected (no backwards compatibility)."""
"""Test that upload endpoint accepts array parameters"""
import io import io
import json import json
mock_get_default_user.return_value = mock_default_user
file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>" file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
data = { data = {
"ontology_key": json.dumps(["single_key"]), "ontology_key": json.dumps(["single_key"]),
"descriptions": json.dumps(["Single ontology"]), "description": json.dumps(["Single ontology"]),
} }
response = client.post("/api/v1/ontologies", files=files, data=data) response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 200 assert response.status_code == 400
result = response.json() assert "ontology_key must be a string" in response.json()["error"]
assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_cognify_with_multiple_ontologies(client):
def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
"""Test cognify endpoint accepts multiple ontology keys""" """Test cognify endpoint accepts multiple ontology keys"""
payload = { payload = {
"datasets": ["test_dataset"], "datasets": ["test_dataset"],
@ -172,14 +153,11 @@ def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_de
assert response.status_code in [200, 400, 409] # May fail for other reasons, not type assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_complete_multifile_workflow(client):
def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): """Test workflow: upload ontologies one-by-one → cognify with multiple keys"""
"""Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
import io import io
import json
mock_get_default_user.return_value = mock_default_user # Step 1: Upload two ontologies (one-by-one)
# Step 1: Upload multiple ontologies
file1_content = b"""<?xml version="1.0"?> file1_content = b"""<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" <rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:owl="http://www.w3.org/2002/07/owl#"> xmlns:owl="http://www.w3.org/2002/07/owl#">
@ -192,17 +170,21 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
<owl:Class rdf:ID="Manufacturer"/> <owl:Class rdf:ID="Manufacturer"/>
</rdf:RDF>""" </rdf:RDF>"""
files = [ upload_response_1 = client.post(
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), "/api/v1/ontologies",
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), files=[("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml"))],
] data={"ontology_key": "vehicles", "description": "Vehicle ontology"},
data = { )
"ontology_key": json.dumps(["vehicles", "manufacturers"]), assert upload_response_1.status_code == 200
"descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
}
upload_response = client.post("/api/v1/ontologies", files=files, data=data) upload_response_2 = client.post(
assert upload_response.status_code == 200 "/api/v1/ontologies",
files=[
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml"))
],
data={"ontology_key": "manufacturers", "description": "Manufacturer ontology"},
)
assert upload_response_2.status_code == 200
# Step 2: Verify ontologies are listed # Step 2: Verify ontologies are listed
list_response = client.get("/api/v1/ontologies") list_response = client.get("/api/v1/ontologies")
@ -223,44 +205,42 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
assert cognify_response.status_code != 400 # Not a validation error assert cognify_response.status_code != 400 # Not a validation error
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_error_handling(client):
def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): """Test error handling for invalid uploads (single-file endpoint)."""
"""Test error handling for invalid multifile uploads"""
import io import io
import json import json
# Test mismatched array lengths # Array-style key should be rejected
file_content = b"<rdf:RDF></rdf:RDF>" file_content = b"<rdf:RDF></rdf:RDF>"
files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))] files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
data = { data = {
"ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file "ontology_key": json.dumps(["key1", "key2"]),
"descriptions": json.dumps(["desc1"]), "description": "desc1",
} }
response = client.post("/api/v1/ontologies", files=files, data=data) response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 400 assert response.status_code == 400
assert "Number of keys must match number of files" in response.json()["error"] assert "ontology_key must be a string" in response.json()["error"]
# Test duplicate keys # Duplicate key should be rejected
files = [ response_1 = client.post(
("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")), "/api/v1/ontologies",
("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")), files=[("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml"))],
] data={"ontology_key": "duplicate", "description": "desc1"},
data = { )
"ontology_key": json.dumps(["duplicate", "duplicate"]), assert response_1.status_code == 200
"descriptions": json.dumps(["desc1", "desc2"]),
}
response = client.post("/api/v1/ontologies", files=files, data=data) response_2 = client.post(
assert response.status_code == 400 "/api/v1/ontologies",
assert "Duplicate ontology keys not allowed" in response.json()["error"] files=[("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml"))],
data={"ontology_key": "duplicate", "description": "desc2"},
)
assert response_2.status_code == 400
assert "already exists" in response_2.json()["error"]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_cognify_missing_ontology_key(client):
def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
"""Test cognify with non-existent ontology key""" """Test cognify with non-existent ontology key"""
mock_get_default_user.return_value = mock_default_user
payload = { payload = {
"datasets": ["test_dataset"], "datasets": ["test_dataset"],
"ontology_key": ["nonexistent_key"], "ontology_key": ["nonexistent_key"],

View file

@ -1,7 +1,7 @@
[project] [project]
name = "cognee" name = "cognee"
version = "0.5.0.dev0" version = "0.5.0.dev1"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [ authors = [
{ name = "Vasilije Markovic" }, { name = "Vasilije Markovic" },

4
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1 version = 1
revision = 2 revision = 3
requires-python = ">=3.10, <3.14" requires-python = ">=3.10, <3.14"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
@ -946,7 +946,7 @@ wheels = [
[[package]] [[package]]
name = "cognee" name = "cognee"
version = "0.5.0.dev0" version = "0.5.0.dev1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiofiles" }, { name = "aiofiles" },