diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index badf88e71..f6baf4242 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -659,3 +659,51 @@ jobs: EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_pipeline_cache.py + + run_usage_logger_test: + name: Usage logger test (API/MCP) + runs-on: ubuntu-latest + defaults: + run: + shell: bash + services: + redis: + image: redis:7 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 5s + --health-timeout 3s + --health-retries 5 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "redis" + + - name: Install cognee-mcp (local version) + shell: bash + run: | + uv pip install -e ./cognee-mcp + + - name: Run api/tool usage logger + env: + ENV: dev + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + GRAPH_DATABASE_PROVIDER: 'kuzu' + USAGE_LOGGING: true + CACHE_BACKEND: 'redis' + run: uv run pytest cognee/tests/test_usage_logger_e2e.py -v --log-level=INFO diff --git a/Dockerfile b/Dockerfile index 49bb29445..9e4f83d56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,10 +34,6 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./ RUN --mount=type=cache,target=/root/.cache/uv \ uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --extra chromadb --frozen --no-install-project --no-dev --no-editable -# Copy Alembic configuration -COPY alembic.ini /app/alembic.ini -COPY alembic/ /app/alembic - # Then, add the rest of the project source code and install it # Installing separately from its dependencies allows optimal layer caching COPY ./cognee /app/cognee diff --git a/cognee-mcp/Dockerfile b/cognee-mcp/Dockerfile index 6608102c8..cf30466df 100644 --- a/cognee-mcp/Dockerfile +++ b/cognee-mcp/Dockerfile @@ -34,8 +34,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv sync --frozen --no-install-project --no-dev --no-editable # Copy Alembic configuration -COPY alembic.ini /app/alembic.ini -COPY alembic/ /app/alembic +COPY cognee/alembic.ini /app/cognee/alembic.ini +COPY cognee/alembic/ /app/cognee/alembic # Then, add the rest of the project source code and install it # Installing separately from its dependencies allows optimal layer caching diff --git a/cognee-mcp/entrypoint.sh b/cognee-mcp/entrypoint.sh index b4df5ba00..60b6ad459 100644 --- a/cognee-mcp/entrypoint.sh +++ b/cognee-mcp/entrypoint.sh @@ -56,24 +56,22 @@ if [ -n "$API_URL" ]; then echo "Skipping database migrations (API server handles its own database)" else echo "Direct mode: Using local cognee instance" - # Run Alembic migrations with proper error handling. - # Note on UserAlreadyExists error handling: - # During database migrations, we attempt to create a default user. 
If this user - # already exists (e.g., from a previous deployment or migration), it's not a - # critical error and shouldn't prevent the application from starting. This is - # different from other migration errors which could indicate database schema - # inconsistencies and should cause the startup to fail. This check allows for - # smooth redeployments and container restarts while maintaining data integrity. echo "Running database migrations..." - MIGRATION_OUTPUT=$(alembic upgrade head) + set +e # Disable exit on error to handle specific migration errors + MIGRATION_OUTPUT=$(cd cognee && alembic upgrade head) MIGRATION_EXIT_CODE=$? + set -e if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then - if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then - echo "Warning: Default user already exists, continuing startup..." - else - echo "Migration failed with unexpected error." + + echo "Migration failed with unexpected error. Trying to run Cognee without migrations." + echo "Initializing database tables..." + python /app/src/run_cognee_database_setup.py + INIT_EXIT_CODE=$? + + if [[ $INIT_EXIT_CODE -ne 0 ]]; then + echo "Database initialization failed!" exit 1 fi fi diff --git a/cognee-mcp/pyproject.toml b/cognee-mcp/pyproject.toml index 27b80e72e..03b8c448e 100644 --- a/cognee-mcp/pyproject.toml +++ b/cognee-mcp/pyproject.toml @@ -8,7 +8,6 @@ requires-python = ">=3.10" dependencies = [ # For local cognee repo usage remove comment bellow and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes. #"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee", - # TODO: Remove gemini from optional dependecnies for new Cognee version after 0.3.4 "cognee[postgres,docs,neo4j]==0.5.0", "fastmcp>=2.10.0,<3.0.0", "mcp>=1.12.0,<2.0.0", diff --git a/cognee-mcp/src/run_cognee_database_setup.py b/cognee-mcp/src/run_cognee_database_setup.py new file mode 100644 index 000000000..e0ac91ec4 --- /dev/null +++ b/cognee-mcp/src/run_cognee_database_setup.py @@ -0,0 +1,5 @@ +from cognee.modules.engine.operations.setup import setup +import asyncio + +if __name__ == "__main__": + asyncio.run(setup()) diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index c02de06c8..fc745b24b 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Optional from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location +from cognee.shared.usage_logger import log_usage import importlib.util from contextlib import redirect_stdout import mcp.types as types @@ -91,6 +92,7 @@ async def health_check(request): @mcp.tool() +@log_usage(function_name="MCP cognify", log_type="mcp_tool") async def cognify( data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None ) -> list: @@ -257,6 +259,7 @@ async def cognify( @mcp.tool( name="save_interaction", description="Logs user-agent interactions and query-answer pairs" ) +@log_usage(function_name="MCP save_interaction", log_type="mcp_tool") async def save_interaction(data: str) -> list: """ Transform and save a user-agent interaction into structured knowledge. 
@@ -316,6 +319,7 @@ async def save_interaction(data: str) -> list: @mcp.tool() +@log_usage(function_name="MCP search", log_type="mcp_tool") async def search(search_query: str, search_type: str, top_k: int = 10) -> list: """ Search and query the knowledge graph for insights, information, and connections. @@ -496,6 +500,7 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list: @mcp.tool() +@log_usage(function_name="MCP list_data", log_type="mcp_tool") async def list_data(dataset_id: str = None) -> list: """ List all datasets and their data items with IDs for deletion operations. @@ -624,6 +629,7 @@ async def list_data(dataset_id: str = None) -> list: @mcp.tool() +@log_usage(function_name="MCP delete", log_type="mcp_tool") async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list: """ Delete specific data from a dataset in the Cognee knowledge graph. @@ -703,6 +709,7 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list: @mcp.tool() +@log_usage(function_name="MCP prune", log_type="mcp_tool") async def prune(): """ Reset the Cognee knowledge graph by removing all stored information. @@ -739,6 +746,7 @@ async def prune(): @mcp.tool() +@log_usage(function_name="MCP cognify_status", log_type="mcp_tool") async def cognify_status(): """ Get the current status of the cognify pipeline. @@ -884,26 +892,11 @@ async def main(): await setup() - # Run Alembic migrations from the main cognee directory where alembic.ini is located + # Run Cognee migrations logger.info("Running database migrations...") - migration_result = subprocess.run( - ["python", "-m", "alembic", "upgrade", "head"], - capture_output=True, - text=True, - cwd=Path(__file__).resolve().parent.parent.parent, - ) + from cognee.run_migrations import run_migrations - if migration_result.returncode != 0: - migration_output = migration_result.stderr + migration_result.stdout - # Check for the expected UserAlreadyExists error (which is not critical) - if ( - "UserAlreadyExists" in migration_output - or "User default_user@example.com already exists" in migration_output - ): - logger.warning("Warning: Default user already exists, continuing startup...") - else: - logger.error(f"Migration failed with unexpected error: {migration_output}") - sys.exit(1) + await run_migrations() logger.info("Database migrations done.") elif args.api_url: diff --git a/cognee/__init__.py b/cognee/__init__.py index 4d150ce4e..cea55e031 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -33,3 +33,5 @@ from .api.v1.ui import start_ui # Pipelines from .modules import pipelines + +from cognee.run_migrations import run_migrations diff --git a/alembic.ini b/cognee/alembic.ini similarity index 100% rename from alembic.ini rename to cognee/alembic.ini diff --git a/alembic/README b/cognee/alembic/README similarity index 100% rename from alembic/README rename to cognee/alembic/README diff --git a/alembic/env.py b/cognee/alembic/env.py similarity index 100% rename from alembic/env.py rename to cognee/alembic/env.py diff --git a/alembic/script.py.mako b/cognee/alembic/script.py.mako similarity index 100% rename from alembic/script.py.mako rename to cognee/alembic/script.py.mako diff --git a/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py b/cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py similarity index 100% rename from alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py rename to 
cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py diff --git a/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py b/cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py similarity index 100% rename from alembic/versions/1d0bb7fede17_add_pipeline_run_status.py rename to cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py diff --git a/alembic/versions/1daae0df1866_incremental_loading.py b/cognee/alembic/versions/1daae0df1866_incremental_loading.py similarity index 100% rename from alembic/versions/1daae0df1866_incremental_loading.py rename to cognee/alembic/versions/1daae0df1866_incremental_loading.py diff --git a/alembic/versions/211ab850ef3d_add_sync_operations_table.py b/cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py similarity index 100% rename from alembic/versions/211ab850ef3d_add_sync_operations_table.py rename to cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py diff --git a/alembic/versions/45957f0a9849_add_notebook_table.py b/cognee/alembic/versions/45957f0a9849_add_notebook_table.py similarity index 100% rename from alembic/versions/45957f0a9849_add_notebook_table.py rename to cognee/alembic/versions/45957f0a9849_add_notebook_table.py diff --git a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py b/cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py similarity index 100% rename from alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py rename to cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py diff --git a/alembic/versions/482cd6517ce4_add_default_user.py b/cognee/alembic/versions/482cd6517ce4_add_default_user.py similarity index 100% rename from alembic/versions/482cd6517ce4_add_default_user.py rename to cognee/alembic/versions/482cd6517ce4_add_default_user.py diff --git a/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py b/cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py similarity index 100% rename from alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py rename to cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py diff --git a/alembic/versions/8057ae7329c2_initial_migration.py b/cognee/alembic/versions/8057ae7329c2_initial_migration.py similarity index 100% rename from alembic/versions/8057ae7329c2_initial_migration.py rename to cognee/alembic/versions/8057ae7329c2_initial_migration.py diff --git a/alembic/versions/9e7a3cb85175_loader_separation.py b/cognee/alembic/versions/9e7a3cb85175_loader_separation.py similarity index 100% rename from alembic/versions/9e7a3cb85175_loader_separation.py rename to cognee/alembic/versions/9e7a3cb85175_loader_separation.py diff --git a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py b/cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py similarity index 100% rename from alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py rename to cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py diff --git a/alembic/versions/ab7e313804ae_permission_system_rework.py b/cognee/alembic/versions/ab7e313804ae_permission_system_rework.py similarity index 100% rename from alembic/versions/ab7e313804ae_permission_system_rework.py rename to cognee/alembic/versions/ab7e313804ae_permission_system_rework.py diff --git a/alembic/versions/b9274c27a25a_kuzu_11_migration.py b/cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py similarity index 100% rename 
from alembic/versions/b9274c27a25a_kuzu_11_migration.py rename to cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py diff --git a/alembic/versions/c946955da633_multi_tenant_support.py b/cognee/alembic/versions/c946955da633_multi_tenant_support.py similarity index 100% rename from alembic/versions/c946955da633_multi_tenant_support.py rename to cognee/alembic/versions/c946955da633_multi_tenant_support.py diff --git a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py b/cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py similarity index 100% rename from alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py rename to cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py diff --git a/alembic/versions/e4ebee1091e7_expand_data_model_info.py b/cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py similarity index 100% rename from alembic/versions/e4ebee1091e7_expand_data_model_info.py rename to cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index 39dc1a3e6..96c716eec 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user from cognee.shared.utils import send_telemetry from cognee.modules.pipelines.models import PipelineRunErrored from cognee.shared.logging_utils import get_logger +from cognee.shared.usage_logger import log_usage from cognee import __version__ as cognee_version logger = get_logger() @@ -19,6 +20,7 @@ def get_add_router() -> APIRouter: router = APIRouter() @router.post("", response_model=dict) + @log_usage(function_name="POST /v1/add", log_type="api_endpoint") async def add( data: List[UploadFile] = File(default=None), datasetName: Optional[str] = Form(default=None), diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 0e2bf2bda..863b68f90 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -29,6 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import ( ) from cognee.shared.logging_utils import get_logger from cognee.shared.utils import send_telemetry +from cognee.shared.usage_logger import log_usage from cognee import __version__ as cognee_version logger = get_logger("api.cognify") @@ -57,6 +58,7 @@ def get_cognify_router() -> APIRouter: router = APIRouter() @router.post("", response_model=dict) + @log_usage(function_name="POST /v1/cognify", log_type="api_endpoint") async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)): """ Transform datasets into structured knowledge graphs through cognitive processing. 
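The routers below follow the same wiring: the usage decorator sits directly under the FastAPI route decorator so it wraps the async handler itself. A minimal sketch of that pattern, with an illustrative GET /v1/example route that is not part of this change:

from fastapi import APIRouter, Depends

from cognee.modules.users.methods import get_authenticated_user
from cognee.modules.users.models import User
from cognee.shared.usage_logger import log_usage

router = APIRouter()

@router.get("", response_model=dict)
@log_usage(function_name="GET /v1/example", log_type="api_endpoint")
async def example(user: User = Depends(get_authenticated_user)):
    # With USAGE_LOGGING enabled, the wrapper records the handler's parameters,
    # result, success flag, duration, and the id taken from the `user` argument;
    # with it disabled, the handler runs completely unchanged.
    return {"status": "ok"}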
diff --git a/cognee/api/v1/memify/routers/get_memify_router.py b/cognee/api/v1/memify/routers/get_memify_router.py index c63e4a394..0e54d7508 100644 --- a/cognee/api/v1/memify/routers/get_memify_router.py +++ b/cognee/api/v1/memify/routers/get_memify_router.py @@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user from cognee.shared.utils import send_telemetry from cognee.modules.pipelines.models import PipelineRunErrored from cognee.shared.logging_utils import get_logger +from cognee.shared.usage_logger import log_usage from cognee import __version__ as cognee_version logger = get_logger() @@ -35,6 +36,7 @@ def get_memify_router() -> APIRouter: router = APIRouter() @router.post("", response_model=dict) + @log_usage(function_name="POST /v1/memify", log_type="api_endpoint") async def memify(payload: MemifyPayloadDTO, user: User = Depends(get_authenticated_user)): """ Enrichment pipeline in Cognee, can work with already built graphs. If no data is provided existing knowledge graph will be used as data, diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 26327628e..c4284bb8b 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -13,6 +13,7 @@ from cognee.modules.users.models import User from cognee.modules.search.operations import get_history from cognee.modules.users.methods import get_authenticated_user from cognee.shared.utils import send_telemetry +from cognee.shared.usage_logger import log_usage from cognee import __version__ as cognee_version from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError from cognee.exceptions import CogneeValidationError @@ -75,6 +76,7 @@ def get_search_router() -> APIRouter: return JSONResponse(status_code=500, content={"error": str(error)}) @router.post("", response_model=Union[List[SearchResult], List]) + @log_usage(function_name="POST /v1/search", log_type="api_endpoint") async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): """ Search for nodes in the graph database. diff --git a/cognee/infrastructure/databases/cache/cache_db_interface.py b/cognee/infrastructure/databases/cache/cache_db_interface.py index 801e86188..c93cf652e 100644 --- a/cognee/infrastructure/databases/cache/cache_db_interface.py +++ b/cognee/infrastructure/databases/cache/cache_db_interface.py @@ -8,10 +8,13 @@ class CacheDBInterface(ABC): Provides a common interface for lock acquisition, release, and context-managed locking. """ - def __init__(self, host: str, port: int, lock_key: str): + def __init__( + self, host: str, port: int, lock_key: str = "default_lock", log_key: str = "usage_logs" + ): self.host = host self.port = port self.lock_key = lock_key + self.log_key = log_key self.lock = None @abstractmethod @@ -77,3 +80,37 @@ class CacheDBInterface(ABC): Gracefully close any async connections. """ pass + + @abstractmethod + async def log_usage( + self, + user_id: str, + log_entry: dict, + ttl: int | None = 604800, + ): + """ + Log usage information (API endpoint calls, MCP tool invocations) to cache. + + Args: + user_id: The user ID. + log_entry: Dictionary containing usage log information. + ttl: Optional time-to-live (seconds). If provided, the log list expires after this time. + + Raises: + CacheConnectionError: If cache connection fails or times out. 
+ """ + pass + + @abstractmethod + async def get_usage_logs(self, user_id: str, limit: int = 100): + """ + Retrieve usage logs for a given user. + + Args: + user_id: The user ID. + limit: Maximum number of logs to retrieve (default: 100). + + Returns: + List of usage log entries, most recent first. + """ + pass diff --git a/cognee/infrastructure/databases/cache/config.py b/cognee/infrastructure/databases/cache/config.py index 88ac05885..de316d3fe 100644 --- a/cognee/infrastructure/databases/cache/config.py +++ b/cognee/infrastructure/databases/cache/config.py @@ -13,6 +13,8 @@ class CacheConfig(BaseSettings): - cache_port: Port number for the cache service. - agentic_lock_expire: Automatic lock expiration time (in seconds). - agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release. + - usage_logging: Enable/disable usage logging for API endpoints and MCP tools. + - usage_logging_ttl: Time-to-live for usage logs in seconds (default: 7 days). """ cache_backend: Literal["redis", "fs"] = "fs" @@ -24,6 +26,8 @@ class CacheConfig(BaseSettings): cache_password: Optional[str] = None agentic_lock_expire: int = 240 agentic_lock_timeout: int = 300 + usage_logging: bool = False + usage_logging_ttl: int = 604800 model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -38,6 +42,8 @@ class CacheConfig(BaseSettings): "cache_password": self.cache_password, "agentic_lock_expire": self.agentic_lock_expire, "agentic_lock_timeout": self.agentic_lock_timeout, + "usage_logging": self.usage_logging, + "usage_logging_ttl": self.usage_logging_ttl, } diff --git a/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py b/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py index 497e6afec..6cedcdc8f 100644 --- a/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +++ b/cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py @@ -89,6 +89,27 @@ class FSCacheAdapter(CacheDBInterface): return None return json.loads(value) + async def log_usage( + self, + user_id: str, + log_entry: dict, + ttl: int | None = 604800, + ): + """ + Usage logging is not supported in filesystem cache backend. + This method is a no-op to satisfy the interface. + """ + logger.warning("Usage logging not supported in FSCacheAdapter, skipping") + pass + + async def get_usage_logs(self, user_id: str, limit: int = 100): + """ + Usage logging is not supported in filesystem cache backend. + This method returns an empty list to satisfy the interface. 
+ """ + logger.warning("Usage logging not supported in FSCacheAdapter, returning empty list") + return [] + async def close(self): if self.cache is not None: self.cache.expire() diff --git a/cognee/infrastructure/databases/cache/get_cache_engine.py b/cognee/infrastructure/databases/cache/get_cache_engine.py index f70358607..b7d015278 100644 --- a/cognee/infrastructure/databases/cache/get_cache_engine.py +++ b/cognee/infrastructure/databases/cache/get_cache_engine.py @@ -1,7 +1,6 @@ """Factory to get the appropriate cache coordination engine (e.g., Redis).""" from functools import lru_cache -import os from typing import Optional from cognee.infrastructure.databases.cache.config import get_cache_config from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface @@ -17,6 +16,7 @@ def create_cache_engine( cache_username: str, cache_password: str, lock_key: str, + log_key: str, agentic_lock_expire: int = 240, agentic_lock_timeout: int = 300, ): @@ -30,6 +30,7 @@ def create_cache_engine( - cache_username: Username to authenticate with. - cache_password: Password to authenticate with. - lock_key: Identifier used for the locking resource. + - log_key: Identifier used for usage logging. - agentic_lock_expire: Duration to hold the lock after acquisition. - agentic_lock_timeout: Max time to wait for the lock before failing. @@ -37,7 +38,7 @@ def create_cache_engine( -------- - CacheDBInterface: An instance of the appropriate cache adapter. """ - if config.caching: + if config.caching or config.usage_logging: from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter if config.cache_backend == "redis": @@ -47,6 +48,7 @@ def create_cache_engine( username=cache_username, password=cache_password, lock_name=lock_key, + log_key=log_key, timeout=agentic_lock_expire, blocking_timeout=agentic_lock_timeout, ) @@ -61,7 +63,10 @@ def create_cache_engine( return None -def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface: +def get_cache_engine( + lock_key: Optional[str] = "default_lock", + log_key: Optional[str] = "usage_logs", +) -> Optional[CacheDBInterface]: """ Returns a cache adapter instance using current context configuration. """ @@ -72,6 +77,7 @@ def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface: cache_username=config.cache_username, cache_password=config.cache_password, lock_key=lock_key, + log_key=log_key, agentic_lock_expire=config.agentic_lock_expire, agentic_lock_timeout=config.agentic_lock_timeout, ) diff --git a/cognee/infrastructure/databases/cache/redis/RedisAdapter.py b/cognee/infrastructure/databases/cache/redis/RedisAdapter.py index b0c51d68e..98021e968 100644 --- a/cognee/infrastructure/databases/cache/redis/RedisAdapter.py +++ b/cognee/infrastructure/databases/cache/redis/RedisAdapter.py @@ -17,13 +17,14 @@ class RedisAdapter(CacheDBInterface): host, port, lock_name="default_lock", + log_key="usage_logs", username=None, password=None, timeout=240, blocking_timeout=300, connection_timeout=30, ): - super().__init__(host, port, lock_name) + super().__init__(host, port, lock_name, log_key) self.host = host self.port = port @@ -177,6 +178,64 @@ class RedisAdapter(CacheDBInterface): entries = await self.async_redis.lrange(session_key, 0, -1) return [json.loads(e) for e in entries] + async def log_usage( + self, + user_id: str, + log_entry: dict, + ttl: int | None = 604800, + ): + """ + Log usage information (API endpoint calls, MCP tool invocations) to Redis. + + Args: + user_id: The user ID. 
+ log_entry: Dictionary containing usage log information. + ttl: Optional time-to-live (seconds). If provided, the log list expires after this time. + + Raises: + CacheConnectionError: If Redis connection fails or times out. + """ + try: + usage_logs_key = f"{self.log_key}:{user_id}" + + await self.async_redis.rpush(usage_logs_key, json.dumps(log_entry)) + + if ttl is not None: + await self.async_redis.expire(usage_logs_key, ttl) + + except (redis.ConnectionError, redis.TimeoutError) as e: + error_msg = f"Redis connection error while logging usage: {str(e)}" + logger.error(error_msg) + raise CacheConnectionError(error_msg) from e + except Exception as e: + error_msg = f"Unexpected error while logging usage to Redis: {str(e)}" + logger.error(error_msg) + raise CacheConnectionError(error_msg) from e + + async def get_usage_logs(self, user_id: str, limit: int = 100): + """ + Retrieve usage logs for a given user. + + Args: + user_id: The user ID. + limit: Maximum number of logs to retrieve (default: 100). + + Returns: + List of usage log entries, most recent first. + """ + try: + usage_logs_key = f"{self.log_key}:{user_id}" + entries = await self.async_redis.lrange(usage_logs_key, -limit, -1) + return [json.loads(e) for e in reversed(entries)] if entries else [] + except (redis.ConnectionError, redis.TimeoutError) as e: + error_msg = f"Redis connection error while retrieving usage logs: {str(e)}" + logger.error(error_msg) + raise CacheConnectionError(error_msg) from e + except Exception as e: + error_msg = f"Unexpected error while retrieving usage logs from Redis: {str(e)}" + logger.error(error_msg) + raise CacheConnectionError(error_msg) from e + async def close(self): """ Gracefully close the async Redis connection. diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index c37af2102..bd2a6f68d 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -24,7 +24,6 @@ async def get_graph_engine() -> GraphDBInterface: return graph_client -@lru_cache def create_graph_engine( graph_database_provider, graph_file_path, @@ -35,6 +34,35 @@ def create_graph_engine( graph_database_port="", graph_database_key="", graph_dataset_database_handler="", +): + """ + Wrapper function to call create graph engine with caching. + For a detailed description, see _create_graph_engine. + """ + return _create_graph_engine( + graph_database_provider, + graph_file_path, + graph_database_url, + graph_database_name, + graph_database_username, + graph_database_password, + graph_database_port, + graph_database_key, + graph_dataset_database_handler, + ) + + +@lru_cache +def _create_graph_engine( + graph_database_provider, + graph_file_path, + graph_database_url="", + graph_database_name="", + graph_database_username="", + graph_database_password="", + graph_database_port="", + graph_database_key="", + graph_dataset_database_handler="", ): """ Create a graph engine based on the specified provider type. 
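The vector engine factory receives the same treatment further down: the lru_cache decorator moves onto a private helper while the public function becomes a thin pass-through, so construction is still memoised per configuration. A generic sketch of the effect, using placeholder names rather than the real factories:

from functools import lru_cache

def create_engine(provider: str, url: str = "", name: str = ""):
    # Public, un-decorated wrapper; delegates to the cached private factory.
    return _create_engine(provider, url, name)

@lru_cache
def _create_engine(provider: str, url: str = "", name: str = ""):
    # One instance per distinct (provider, url, name) combination.
    return object()  # stands in for building the real adapter

assert create_engine("kuzu") is create_engine("kuzu")  # cached: same object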
diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py index 72a1fac01..9289bb6c8 100644 --- a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py @@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): query_vector: Optional[List[float]] = None, limit: Optional[int] = None, with_vector: bool = False, + include_payload: bool = False, # TODO: Add support for this parameter ): """ Perform a search in the specified collection using either a text query or a vector @@ -319,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): self._na_exception_handler(e, query_string) async def batch_search( - self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False + self, + collection_name: str, + query_texts: List[str], + limit: int, + with_vectors: bool = False, + include_payload: bool = False, ): """ Perform a batch search using multiple text queries against a collection. @@ -342,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): data_vectors = await self.embedding_engine.embed_text(query_texts) return await asyncio.gather( *[ - self.search(collection_name, None, vector, limit, with_vectors) + self.search( + collection_name, + None, + vector, + limit, + with_vectors, + include_payload=include_payload, + ) for vector in data_vectors ] ) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index 3380125ce..bec97ca94 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, + include_payload: bool = False, # TODO: Add support for this parameter when set to False ): """ Search for items in a collection using either a text or a vector query. @@ -441,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface): query_texts: List[str], limit: int = 5, with_vectors: bool = False, + include_payload: bool = False, ): """ Perform multiple searches in a single request for efficiency, returning results for each diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 8a87f0339..cdf65514f 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi from functools import lru_cache -@lru_cache def create_vector_engine( vector_db_provider: str, vector_db_url: str, @@ -15,6 +14,29 @@ def create_vector_engine( vector_db_port: str = "", vector_db_key: str = "", vector_dataset_database_handler: str = "", +): + """ + Wrapper function to call create vector engine with caching. + For a detailed description, see _create_vector_engine. 
+ """ + return _create_vector_engine( + vector_db_provider, + vector_db_url, + vector_db_name, + vector_db_port, + vector_db_key, + vector_dataset_database_handler, + ) + + +@lru_cache +def _create_vector_engine( + vector_db_provider: str, + vector_db_url: str, + vector_db_name: str, + vector_db_port: str = "", + vector_db_key: str = "", + vector_dataset_database_handler: str = "", ): """ Create a vector database engine based on the specified provider. diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 6d724f9d7..baef75d9e 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, + include_payload: bool = False, ): if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface): if limit <= 0: return [] - result_values = await collection.vector_search(query_vector).limit(limit).to_list() + # Note: Exclude payload if not needed to optimize performance + select_columns = ( + ["id", "vector", "payload", "_distance"] + if include_payload + else ["id", "vector", "_distance"] + ) + result_values = ( + await collection.vector_search(query_vector) + .select(select_columns) + .limit(limit) + .to_list() + ) if not result_values: return [] - normalized_values = normalize_distances(result_values) return [ ScoredResult( id=parse_id(result["id"]), - payload=result["payload"], + payload=result["payload"] if include_payload else None, score=normalized_values[value_index], ) for value_index, result in enumerate(result_values) @@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface): query_texts: List[str], limit: Optional[int] = None, with_vectors: bool = False, + include_payload: bool = False, ): query_vectors = await self.embedding_engine.embed_text(query_texts) @@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface): query_vector=query_vector, limit=limit, with_vector=with_vectors, + include_payload=include_payload, ) for query_vector in query_vectors ] diff --git a/cognee/infrastructure/databases/vector/models/ScoredResult.py b/cognee/infrastructure/databases/vector/models/ScoredResult.py index 0a8cc9888..b4792ce28 100644 --- a/cognee/infrastructure/databases/vector/models/ScoredResult.py +++ b/cognee/infrastructure/databases/vector/models/ScoredResult.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from uuid import UUID from pydantic import BaseModel @@ -12,10 +12,10 @@ class ScoredResult(BaseModel): - id (UUID): Unique identifier for the scored result. - score (float): The score associated with the result, where a lower score indicates a better outcome. - - payload (Dict[str, Any]): Additional information related to the score, stored as + - payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as key-value pairs in a dictionary. 
""" id: UUID score: float # Lower score is better - payload: Dict[str, Any] + payload: Optional[Dict[str, Any]] = None diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 1986fae48..5e2c356ee 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_vector: Optional[List[float]] = None, limit: Optional[int] = 15, with_vector: bool = False, + include_payload: bool = False, ) -> List[ScoredResult]: if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # NOTE: This needs to be initialized in case search doesn't return a value closest_items = [] + # Note: Exclude payload from returned columns if not needed to optimize performance + select_columns = ( + [PGVectorDataPoint] + if include_payload + else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector] + ) # Use async session to connect to the database async with self.get_async_session() as session: query = select( - PGVectorDataPoint, + *select_columns, PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), ).order_by("similarity") @@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): vector_list.append( { "id": parse_id(str(vector.id)), - "payload": vector.payload, + "payload": vector.payload if include_payload else None, "_distance": vector.similarity, } ) @@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Create and return ScoredResult objects return [ - ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) + ScoredResult( + id=row.get("id"), + payload=row.get("payload") if include_payload else None, + score=row.get("score"), + ) for row in vector_list ] @@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_texts: List[str], limit: int = None, with_vectors: bool = False, + include_payload: bool = False, ): query_vectors = await self.embedding_engine.embed_text(query_texts) @@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_vector=query_vector, limit=limit, with_vector=with_vectors, + include_payload=include_payload, ) for query_vector in query_vectors ] diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index 12ace1a6c..4376d8713 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -87,6 +87,7 @@ class VectorDBInterface(Protocol): query_vector: Optional[List[float]], limit: Optional[int], with_vector: bool = False, + include_payload: bool = False, ): """ Perform a search in the specified collection using either a text query or a vector @@ -103,6 +104,9 @@ class VectorDBInterface(Protocol): - limit (Optional[int]): The maximum number of results to return from the search. - with_vector (bool): Whether to return the vector representations with search results. (default False) + - include_payload (bool): Whether to include the payload data with search. Search is faster when set to False. 
+ Payload contains metadata about the data point, useful for searches that are only based on embedding distances + like the RAG_COMPLETION search type, but not needed when search also contains graph data. """ raise NotImplementedError @@ -113,6 +117,7 @@ class VectorDBInterface(Protocol): query_texts: List[str], limit: Optional[int], with_vectors: bool = False, + include_payload: bool = False, ): """ Perform a batch search using multiple text queries against a collection. @@ -125,6 +130,9 @@ class VectorDBInterface(Protocol): - limit (Optional[int]): The maximum number of results to return for each query. - with_vectors (bool): Whether to include vector representations with search results. (default False) + - include_payload (bool): Whether to include the payload data with search. Search is faster when set to False. + Payload contains metadata about the data point, useful for searches that are only based on embedding distances + like the RAG_COMPLETION search type, but not needed when search also contains graph data. """ raise NotImplementedError diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index f67c026d3..aad6ad858 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,5 +1,6 @@ import time from cognee.shared.logging_utils import get_logger +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any from cognee.modules.graph.exceptions import ( @@ -44,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph): def add_edge(self, edge: Edge) -> None: self.edges.append(edge) + + edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + edge.attributes["edge_type_id"] = ( + generate_edge_id(edge_id=edge_text) if edge_text else None + ) # Update edge with generated edge_type_id + edge.node1.add_skeleton_edge(edge) edge.node2.add_skeleton_edge(edge) key = edge.get_distance_key() @@ -284,13 +291,7 @@ class CogneeGraph(CogneeAbstractGraph): for query_index, scored_results in enumerate(per_query_scored_results): for result in scored_results: - payload = getattr(result, "payload", None) - if not isinstance(payload, dict): - continue - text = payload.get("text") - if not text: - continue - matching_edges = self.edges_by_distance_key.get(str(text)) + matching_edges = self.edges_by_distance_key.get(str(result.id)) if not matching_edges: continue for edge in matching_edges: diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index c9226b6a1..e8e06920d 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -141,7 +141,7 @@ class Edge: self.status = np.ones(dimension, dtype=int) def get_distance_key(self) -> Optional[str]: - key = self.attributes.get("edge_text") or self.attributes.get("relationship_type") + key = self.attributes.get("edge_type_id") if key is None: return None return str(key) diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index ce9b8233b..1a31087d6 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever): vector_engine = get_vector_engine() try: - found_chunks = await vector_engine.search("DocumentChunk_text", 
query, limit=self.top_k) + found_chunks = await vector_engine.search( + "DocumentChunk_text", query, limit=self.top_k, include_payload=True + ) logger.info(f"Found {len(found_chunks)} chunks from vector search") await update_node_access_timestamps(found_chunks) diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 0e9a4167c..f3a7bd505 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever): vector_engine = get_vector_engine() try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + found_chunks = await vector_engine.search( + "DocumentChunk_text", query, limit=self.top_k, include_payload=True + ) if len(found_chunks) == 0: return "" diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 13972bb8d..e79bb514d 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever): try: summaries_results = await vector_engine.search( - "TextSummary_text", query, limit=self.top_k + "TextSummary_text", query, limit=self.top_k, include_payload=True ) logger.info(f"Found {len(summaries_results)} summaries from vector search") diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 87d2ab009..cebd03a97 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever): async def filter_top_k_events(self, relevant_events, scored_results): # Build a score lookup from vector search results - score_lookup = {res.payload["id"]: res.score for res in scored_results} + score_lookup = {res.id: res.score for res in scored_results} events_with_scores = [] for event in relevant_events[0]["events"]: diff --git a/cognee/modules/retrieval/triplet_retriever.py b/cognee/modules/retrieval/triplet_retriever.py index b9d006312..ece9c6f85 100644 --- a/cognee/modules/retrieval/triplet_retriever.py +++ b/cognee/modules/retrieval/triplet_retriever.py @@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever): "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. " ) - found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k) + found_triplets = await vector_engine.search( + "Triplet_text", query, limit=self.top_k, include_payload=True + ) if len(found_triplets) == 0: return "" diff --git a/cognee/run_migrations.py b/cognee/run_migrations.py new file mode 100644 index 000000000..e501456dd --- /dev/null +++ b/cognee/run_migrations.py @@ -0,0 +1,48 @@ +import os +import sys +import subprocess +from pathlib import Path +import importlib.resources as pkg_resources + +# Assuming your package is named 'cognee' and the migrations are under 'cognee/alembic' +# This is a placeholder for the path logic. +MIGRATIONS_PACKAGE = "cognee" +MIGRATIONS_DIR_NAME = "alembic" + + +async def run_migrations(): + """ + Finds the Alembic configuration within the installed package and + programmatically executes 'alembic upgrade head'. + """ + # 1. Locate the base path of the installed package. + # This reliably finds the root directory of the installed 'cognee' package. + # We look for the parent of the 'migrations' directory. 
+ package_root = str(pkg_resources.files(MIGRATIONS_PACKAGE)) + + # 2. Define the paths for config and scripts + alembic_ini_path = os.path.join(package_root, "alembic.ini") + script_location_path = os.path.join(package_root, MIGRATIONS_DIR_NAME) + + if not os.path.exists(alembic_ini_path): + raise FileNotFoundError( + f"Error: alembic.ini not found at expected locations for package '{MIGRATIONS_PACKAGE}'." + ) + if not os.path.exists(script_location_path): + raise FileNotFoundError( + f"Error: Migrations directory not found at expected locations for package '{MIGRATIONS_PACKAGE}'." + ) + + migration_result = subprocess.run( + ["python", "-m", "alembic", "upgrade", "head"], + capture_output=True, + text=True, + cwd=Path(package_root), + ) + + if migration_result.returncode != 0: + migration_output = migration_result.stderr + migration_result.stdout + print(f"Migration failed with unexpected error: {migration_output}") + sys.exit(1) + + print("Migration completed successfully.") diff --git a/cognee/shared/exceptions/__init__.py b/cognee/shared/exceptions/__init__.py index 5e2ae6875..776a803b4 100644 --- a/cognee/shared/exceptions/__init__.py +++ b/cognee/shared/exceptions/__init__.py @@ -4,6 +4,4 @@ Custom exceptions for the Cognee API. This module defines a set of exceptions for handling various shared utility errors """ -from .exceptions import ( - IngestionError, -) +from .exceptions import IngestionError, UsageLoggerError diff --git a/cognee/shared/exceptions/exceptions.py b/cognee/shared/exceptions/exceptions.py index 3740f677d..c00a39b9e 100644 --- a/cognee/shared/exceptions/exceptions.py +++ b/cognee/shared/exceptions/exceptions.py @@ -1,4 +1,4 @@ -from cognee.exceptions import CogneeValidationError +from cognee.exceptions import CogneeConfigurationError, CogneeValidationError from fastapi import status @@ -10,3 +10,13 @@ class IngestionError(CogneeValidationError): status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, ): super().__init__(message, name, status_code) + + +class UsageLoggerError(CogneeConfigurationError): + def __init__( + self, + message: str = "Usage logging configuration is invalid.", + name: str = "UsageLoggerError", + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ): + super().__init__(message, name, status_code) diff --git a/cognee/shared/usage_logger.py b/cognee/shared/usage_logger.py new file mode 100644 index 000000000..5dde5b9ab --- /dev/null +++ b/cognee/shared/usage_logger.py @@ -0,0 +1,332 @@ +import asyncio +import inspect +import os +from datetime import datetime, timezone +from functools import singledispatch, wraps +from typing import Any, Callable, Optional +from uuid import UUID + +from cognee.infrastructure.databases.cache.config import get_cache_config +from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine +from cognee.shared.exceptions import UsageLoggerError +from cognee.shared.logging_utils import get_logger +from cognee import __version__ as cognee_version + +logger = get_logger("usage_logger") + + +@singledispatch +def _sanitize_value(value: Any) -> Any: + """Default handler for JSON serialization - converts to string.""" + try: + str_repr = str(value) + if str_repr.startswith("<") and str_repr.endswith(">"): + return f"" + return str_repr + except Exception: + return f"" + + +@_sanitize_value.register(type(None)) +def _(value: None) -> None: + """Handle None values - returns None as-is.""" + return None + + +@_sanitize_value.register(str) +@_sanitize_value.register(int) +@_sanitize_value.register(float) 
+@_sanitize_value.register(bool) +def _(value: str | int | float | bool) -> str | int | float | bool: + """Handle primitive types - returns value as-is since they're JSON-serializable.""" + return value + + +@_sanitize_value.register(UUID) +def _(value: UUID) -> str: + """Convert UUID to string representation.""" + return str(value) + + +@_sanitize_value.register(datetime) +def _(value: datetime) -> str: + """Convert datetime to ISO format string.""" + return value.isoformat() + + +@_sanitize_value.register(list) +@_sanitize_value.register(tuple) +def _(value: list | tuple) -> list: + """Recursively sanitize list or tuple elements.""" + return [_sanitize_value(v) for v in value] + + +@_sanitize_value.register(dict) +def _(value: dict) -> dict: + """Recursively sanitize dictionary keys and values.""" + sanitized = {} + for k, v in value.items(): + key_str = k if isinstance(k, str) else _sanitize_dict_key(k) + sanitized[key_str] = _sanitize_value(v) + return sanitized + + +def _sanitize_dict_key(key: Any) -> str: + """Convert a non-string dict key to a string.""" + sanitized_key = _sanitize_value(key) + if isinstance(sanitized_key, str): + if sanitized_key.startswith("" + return sanitized_key + return str(sanitized_key) + + +def _get_param_names(func: Callable) -> list[str]: + """Get parameter names from function signature.""" + try: + return list(inspect.signature(func).parameters.keys()) + except Exception: + return [] + + +def _get_param_defaults(func: Callable) -> dict[str, Any]: + """Get parameter defaults from function signature.""" + try: + sig = inspect.signature(func) + defaults = {} + for param_name, param in sig.parameters.items(): + if param.default != inspect.Parameter.empty: + defaults[param_name] = param.default + return defaults + except Exception: + return {} + + +def _extract_user_id(args: tuple, kwargs: dict, param_names: list[str]) -> Optional[str]: + """Extract user_id from function arguments if available.""" + try: + if "user" in kwargs and kwargs["user"] is not None: + user = kwargs["user"] + if hasattr(user, "id"): + return str(user.id) + + for i, param_name in enumerate(param_names): + if i < len(args) and param_name == "user": + user = args[i] + if user is not None and hasattr(user, "id"): + return str(user.id) + return None + except Exception: + return None + + +def _extract_parameters(args: tuple, kwargs: dict, param_names: list[str], func: Callable) -> dict: + """Extract function parameters - captures all parameters including defaults, sanitizes for JSON.""" + params = {} + + for key, value in kwargs.items(): + if key != "user": + params[key] = _sanitize_value(value) + + if param_names: + for i, param_name in enumerate(param_names): + if i < len(args) and param_name != "user" and param_name not in kwargs: + params[param_name] = _sanitize_value(args[i]) + else: + for i, arg_value in enumerate(args): + params[f"arg_{i}"] = _sanitize_value(arg_value) + + if param_names: + defaults = _get_param_defaults(func) + for param_name in param_names: + if param_name != "user" and param_name not in params and param_name in defaults: + params[param_name] = _sanitize_value(defaults[param_name]) + + return params + + +async def _log_usage_async( + function_name: str, + log_type: str, + user_id: Optional[str], + parameters: dict, + result: Any, + success: bool, + error: Optional[str], + duration_ms: float, + start_time: datetime, + end_time: datetime, +): + """Asynchronously log function usage to Redis. + + Args: + function_name: Name of the function being logged. 
+ log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool", "function"). + user_id: User identifier, or None to use "unknown". + parameters: Dictionary of function parameters (sanitized). + result: Function return value (will be sanitized). + success: Whether the function executed successfully. + error: Error message if function failed, None otherwise. + duration_ms: Execution duration in milliseconds. + start_time: Function start timestamp. + end_time: Function end timestamp. + + Note: + This function silently handles errors to avoid disrupting the original + function execution. Logs are written to Redis with TTL from config. + """ + try: + logger.debug(f"Starting to log usage for {function_name} at {start_time.isoformat()}") + config = get_cache_config() + if not config.usage_logging: + logger.debug("Usage logging disabled, skipping log") + return + + logger.debug(f"Getting cache engine for {function_name}") + cache_engine = get_cache_engine() + if cache_engine is None: + logger.warning( + f"Cache engine not available for usage logging (function: {function_name})" + ) + return + + logger.debug(f"Cache engine obtained for {function_name}") + + if user_id is None: + user_id = "unknown" + logger.debug(f"No user_id provided, using 'unknown' for {function_name}") + + log_entry = { + "timestamp": start_time.isoformat(), + "type": log_type, + "function_name": function_name, + "user_id": user_id, + "parameters": parameters, + "result": _sanitize_value(result), + "success": success, + "error": error, + "duration_ms": round(duration_ms, 2), + "start_time": start_time.isoformat(), + "end_time": end_time.isoformat(), + "metadata": { + "cognee_version": cognee_version, + "environment": os.getenv("ENV", "prod"), + }, + } + + logger.debug(f"Calling log_usage for {function_name}, user_id={user_id}") + await cache_engine.log_usage( + user_id=user_id, + log_entry=log_entry, + ttl=config.usage_logging_ttl, + ) + logger.info(f"Successfully logged usage for {function_name} (user_id={user_id})") + except Exception as e: + logger.error(f"Failed to log usage for {function_name}: {str(e)}", exc_info=True) + + +def log_usage(function_name: Optional[str] = None, log_type: str = "function"): + """ + Decorator to log function usage to Redis. + + This decorator is completely transparent - it doesn't change function behavior. + It logs function name, parameters, result, timing, and user (if available). + + Args: + function_name: Optional name for the function (defaults to func.__name__) + log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool") + + Usage: + @log_usage(function_name="MCP my_mcp_tool", log_type="mcp_tool") + async def my_mcp_tool(...): + # mcp code + + @log_usage(function_name="POST API /v1/add", log_type="api_endpoint") + async def add(...): + # endpoint code + """ + + def decorator(func: Callable) -> Callable: + """Inner decorator that wraps the function with usage logging. + + Args: + func: The async function to wrap with usage logging. + + Returns: + Callable: The wrapped function with usage logging enabled. + + Raises: + UsageLoggerError: If the function is not async. + """ + if not inspect.iscoroutinefunction(func): + raise UsageLoggerError( + f"@log_usage requires an async function. Got {func.__name__} which is not async." + ) + + @wraps(func) + async def async_wrapper(*args, **kwargs): + """Wrapper function that executes the original function and logs usage. 
+ + This wrapper: + - Extracts user ID and parameters from function arguments + - Executes the original function + - Captures result, success status, and any errors + - Logs usage information asynchronously without blocking + + Args: + *args: Positional arguments passed to the original function. + **kwargs: Keyword arguments passed to the original function. + + Returns: + Any: The return value of the original function. + + Raises: + Any exception raised by the original function (re-raised after logging). + """ + config = get_cache_config() + if not config.usage_logging: + return await func(*args, **kwargs) + + start_time = datetime.now(timezone.utc) + + param_names = _get_param_names(func) + user_id = _extract_user_id(args, kwargs, param_names) + parameters = _extract_parameters(args, kwargs, param_names, func) + + result = None + success = True + error = None + + try: + result = await func(*args, **kwargs) + return result + except Exception as e: + success = False + error = str(e) + raise + finally: + end_time = datetime.now(timezone.utc) + duration_ms = (end_time - start_time).total_seconds() * 1000 + + try: + await _log_usage_async( + function_name=function_name or func.__name__, + log_type=log_type, + user_id=user_id, + parameters=parameters, + result=result, + success=success, + error=error, + duration_ms=duration_ms, + start_time=start_time, + end_time=end_time, + ) + except Exception as e: + logger.error( + f"Failed to log usage for {function_name or func.__name__}: {str(e)}", + exc_info=True, + ) + + return async_wrapper + + return decorator diff --git a/cognee/tests/integration/shared/test_usage_logger_integration.py b/cognee/tests/integration/shared/test_usage_logger_integration.py new file mode 100644 index 000000000..6bdd58e3c --- /dev/null +++ b/cognee/tests/integration/shared/test_usage_logger_integration.py @@ -0,0 +1,255 @@ +"""Integration tests for usage logger with real Redis components.""" + +import os +import pytest +import asyncio +from datetime import datetime, timezone +from types import SimpleNamespace +from uuid import UUID +from unittest.mock import patch + +from cognee.shared.usage_logger import log_usage +from cognee.infrastructure.databases.cache.config import get_cache_config +from cognee.infrastructure.databases.cache.get_cache_engine import ( + get_cache_engine, + create_cache_engine, +) + + +@pytest.fixture +def usage_logging_config(): + """Fixture to enable usage logging via environment variables.""" + original_env = os.environ.copy() + os.environ["USAGE_LOGGING"] = "true" + os.environ["CACHE_BACKEND"] = "redis" + os.environ["CACHE_HOST"] = "localhost" + os.environ["CACHE_PORT"] = "6379" + get_cache_config.cache_clear() + create_cache_engine.cache_clear() + yield + os.environ.clear() + os.environ.update(original_env) + get_cache_config.cache_clear() + create_cache_engine.cache_clear() + + +@pytest.fixture +def usage_logging_disabled(): + """Fixture to disable usage logging via environment variables.""" + original_env = os.environ.copy() + os.environ["USAGE_LOGGING"] = "false" + os.environ["CACHE_BACKEND"] = "redis" + get_cache_config.cache_clear() + create_cache_engine.cache_clear() + yield + os.environ.clear() + os.environ.update(original_env) + get_cache_config.cache_clear() + create_cache_engine.cache_clear() + + +@pytest.fixture +def redis_adapter(): + """Real RedisAdapter instance for testing.""" + from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter + + try: + yield RedisAdapter(host="localhost", port=6379, 
log_key="test_usage_logs") + except Exception as e: + pytest.skip(f"Redis not available: {e}") + + +@pytest.fixture +def test_user(): + """Test user object.""" + return SimpleNamespace(id="test-user-123") + + +class TestDecoratorBehavior: + """Test decorator behavior with real components.""" + + @pytest.mark.asyncio + async def test_decorator_configuration( + self, usage_logging_disabled, usage_logging_config, redis_adapter + ): + """Test decorator skips when disabled and logs when enabled.""" + # Test disabled + call_count = 0 + + @log_usage(function_name="test_func", log_type="test") + async def test_func(): + nonlocal call_count + call_count += 1 + return "result" + + assert await test_func() == "result" + assert call_count == 1 + + # Test enabled with cache engine None + with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get: + mock_get.return_value = None + assert await test_func() == "result" + + @pytest.mark.asyncio + async def test_decorator_logging(self, usage_logging_config, redis_adapter, test_user): + """Test decorator logs to Redis with correct structure.""" + + @log_usage(function_name="test_func", log_type="test") + async def test_func(param1: str, param2: int = 42, user=None): + await asyncio.sleep(0.01) + return {"result": f"{param1}_{param2}"} + + with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get: + mock_get.return_value = redis_adapter + + result = await test_func("value1", user=test_user) + assert result == {"result": "value1_42"} + + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + log = logs[0] + assert log["function_name"] == "test_func" + assert log["type"] == "test" + assert log["user_id"] == "test-user-123" + assert log["parameters"]["param1"] == "value1" + assert log["parameters"]["param2"] == 42 + assert log["success"] is True + assert all( + field in log + for field in [ + "timestamp", + "result", + "error", + "duration_ms", + "start_time", + "end_time", + "metadata", + ] + ) + assert "cognee_version" in log["metadata"] + + @pytest.mark.asyncio + async def test_multiple_calls(self, usage_logging_config, redis_adapter, test_user): + """Test multiple consecutive calls are all logged.""" + + @log_usage(function_name="multi_test", log_type="test") + async def multi_func(call_num: int, user=None): + return {"call": call_num} + + with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get: + mock_get.return_value = redis_adapter + + for i in range(3): + await multi_func(i, user=test_user) + + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert len(logs) >= 3 + call_nums = {log["parameters"]["call_num"] for log in logs[:3]} + assert call_nums == {0, 1, 2} + + +class TestRealRedisIntegration: + """Test real Redis integration.""" + + @pytest.mark.asyncio + async def test_redis_storage_retrieval_and_ttl( + self, usage_logging_config, redis_adapter, test_user + ): + """Test logs are stored, retrieved with correct order/limits, and TTL is set.""" + + @log_usage(function_name="redis_test", log_type="test") + async def redis_func(data: str, user=None): + return {"processed": data} + + @log_usage(function_name="order_test", log_type="test") + async def order_func(num: int, user=None): + return {"num": num} + + with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get: + mock_get.return_value = redis_adapter + + # Storage + await redis_func("test_data", user=test_user) + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert logs[0]["function_name"] 
== "redis_test" + assert logs[0]["parameters"]["data"] == "test_data" + + # Order (most recent first) + for i in range(3): + await order_func(i, user=test_user) + await asyncio.sleep(0.01) + + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert [log["parameters"]["num"] for log in logs[:3]] == [2, 1, 0] + + # Limit + assert len(await redis_adapter.get_usage_logs("test-user-123", limit=2)) == 2 + + # TTL + ttl = await redis_adapter.async_redis.ttl("test_usage_logs:test-user-123") + assert 0 < ttl <= 604800 + + +class TestEdgeCases: + """Test edge cases in integration tests.""" + + @pytest.mark.asyncio + async def test_edge_cases(self, usage_logging_config, redis_adapter, test_user): + """Test no params, defaults, complex structures, exceptions, None, circular refs.""" + + @log_usage(function_name="no_params", log_type="test") + async def no_params_func(user=None): + return "result" + + @log_usage(function_name="defaults_only", log_type="test") + async def defaults_only_func(param1: str = "default1", param2: int = 42, user=None): + return {"param1": param1, "param2": param2} + + @log_usage(function_name="complex_test", log_type="test") + async def complex_func(user=None): + return { + "nested": { + "list": [1, 2, 3], + "uuid": UUID("123e4567-e89b-12d3-a456-426614174000"), + "datetime": datetime(2024, 1, 15, tzinfo=timezone.utc), + } + } + + @log_usage(function_name="exception_test", log_type="test") + async def exception_func(user=None): + raise RuntimeError("Test exception") + + @log_usage(function_name="none_test", log_type="test") + async def none_func(user=None): + return None + + with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get: + mock_get.return_value = redis_adapter + + # No parameters + await no_params_func(user=test_user) + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert logs[0]["parameters"] == {} + + # Default parameters + await defaults_only_func(user=test_user) + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert logs[0]["parameters"]["param1"] == "default1" + assert logs[0]["parameters"]["param2"] == 42 + + # Complex nested structures + await complex_func(user=test_user) + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert isinstance(logs[0]["result"]["nested"]["uuid"], str) + assert isinstance(logs[0]["result"]["nested"]["datetime"], str) + + # Exception handling + with pytest.raises(RuntimeError): + await exception_func(user=test_user) + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert logs[0]["success"] is False + assert "Test exception" in logs[0]["error"] + + # None return value + assert await none_func(user=test_user) is None + logs = await redis_adapter.get_usage_logs("test-user-123", limit=10) + assert logs[0]["result"] is None diff --git a/cognee/tests/test_chromadb.py b/cognee/tests/test_chromadb.py index 767edf3dc..b5d1c4675 100644 --- a/cognee/tests/test_chromadb.py +++ b/cognee/tests/test_chromadb.py @@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit(): query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0] result = await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=None + collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True ) # Check that we did not accidentally use any default value for limit diff --git a/cognee/tests/test_kuzu.py b/cognee/tests/test_kuzu.py index 
fe9da6dcb..63c9a983f 100644 --- a/cognee/tests/test_kuzu.py +++ b/cognee/tests/test_kuzu.py @@ -70,7 +70,9 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_lancedb.py b/cognee/tests/test_lancedb.py index 115ba99fd..29b149217 100644 --- a/cognee/tests/test_lancedb.py +++ b/cognee/tests/test_lancedb.py @@ -149,7 +149,9 @@ async def main(): await test_getting_of_documents(dataset_name_1) vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 893b836c0..403bb9e29 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -48,7 +48,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 925614e67..6cc2d7fec 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -63,7 +63,9 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_neptune_analytics_vector.py b/cognee/tests/test_neptune_analytics_vector.py index 99c4d94b4..d86dd6a63 100644 --- a/cognee/tests/test_neptune_analytics_vector.py +++ b/cognee/tests/test_neptune_analytics_vector.py @@ -52,7 +52,9 @@ async def main(): await cognee.cognify([dataset_name]) vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_permissions.py b/cognee/tests/test_permissions.py index 10696441e..9d949c92b 100644 --- a/cognee/tests/test_permissions.py +++ b/cognee/tests/test_permissions.py @@ -41,14 +41,14 @@ async def _reset_engines_and_prune() -> None: except Exception: pass - from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine from cognee.infrastructure.databases.relational.create_relational_engine import ( create_relational_engine, ) - from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine + from 
cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine - create_graph_engine.cache_clear() - create_vector_engine.cache_clear() + _create_graph_engine.cache_clear() + _create_vector_engine.cache_clear() create_relational_engine.cache_clear() await cognee.prune.prune_data() diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 240f9e9bb..8e4b3e8f0 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -163,7 +163,9 @@ async def main(): await test_getting_of_documents(dataset_name_1) vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_remote_kuzu.py b/cognee/tests/test_remote_kuzu.py index 1c619719c..cea5be904 100644 --- a/cognee/tests/test_remote_kuzu.py +++ b/cognee/tests/test_remote_kuzu.py @@ -58,7 +58,9 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_s3_file_storage.py b/cognee/tests/test_s3_file_storage.py index c7fc62cf2..eeb372753 100755 --- a/cognee/tests/test_s3_file_storage.py +++ b/cognee/tests/test_s3_file_storage.py @@ -43,7 +43,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 37b8ae45b..cdb2bbc64 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -48,14 +48,14 @@ async def _reset_engines_and_prune() -> None: # Engine might not exist yet pass - from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine - from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine + from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine from cognee.infrastructure.databases.relational.create_relational_engine import ( create_relational_engine, ) - create_graph_engine.cache_clear() - create_vector_engine.cache_clear() + _create_graph_engine.cache_clear() + _create_vector_engine.cache_clear() create_relational_engine.cache_clear() await cognee.prune.prune_data() diff --git a/cognee/tests/test_usage_logger_e2e.py b/cognee/tests/test_usage_logger_e2e.py new file mode 100644 index 000000000..db34df2cb --- /dev/null +++ b/cognee/tests/test_usage_logger_e2e.py @@ -0,0 +1,268 @@ +import os +import pytest +import pytest_asyncio +import asyncio +from fastapi.testclient import TestClient + +import cognee +from cognee.api.client import app +from cognee.modules.users.methods import get_default_user, get_authenticated_user + + +async def 
_reset_engines_and_prune(): + """Reset db engine caches and prune data/system.""" + try: + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"): + await vector_engine.engine.dispose(close=True) + except Exception: + pass + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + +@pytest.fixture(scope="session") +def event_loop(): + """Use a single asyncio event loop for this test module.""" + loop = asyncio.new_event_loop() + try: + yield loop + finally: + loop.close() + + +@pytest.fixture(scope="session") +def e2e_config(): + """Configure environment for E2E tests.""" + original_env = os.environ.copy() + os.environ["USAGE_LOGGING"] = "true" + os.environ["CACHE_BACKEND"] = "redis" + os.environ["CACHE_HOST"] = "localhost" + os.environ["CACHE_PORT"] = "6379" + yield + os.environ.clear() + os.environ.update(original_env) + + +@pytest.fixture(scope="session") +def authenticated_client(test_client): + """Override authentication to use default user.""" + + async def override_get_authenticated_user(): + return await get_default_user() + + app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user + yield test_client + app.dependency_overrides.pop(get_authenticated_user, None) + + +@pytest_asyncio.fixture(scope="session") +async def test_data_setup(): + """Set up test data: prune first, then add text and cognify.""" + await _reset_engines_and_prune() + + dataset_name = "test_e2e_dataset" + test_text = "Germany is located in Europe right next to the Netherlands." + + await cognee.add(test_text, dataset_name) + await cognee.cognify([dataset_name]) + + yield dataset_name + + await _reset_engines_and_prune() + + +@pytest_asyncio.fixture +async def mcp_data_setup(): + """Set up test data for MCP tests: prune first, then add text and cognify.""" + await _reset_engines_and_prune() + + dataset_name = "test_mcp_dataset" + test_text = "Germany is located in Europe right next to the Netherlands."
+ + await cognee.add(test_text, dataset_name) + await cognee.cognify([dataset_name]) + + yield dataset_name + + await _reset_engines_and_prune() + + +@pytest.fixture(scope="session") +def test_client(): + """TestClient instance for API calls.""" + with TestClient(app) as client: + yield client + + +@pytest_asyncio.fixture +async def cache_engine(e2e_config): + """Get cache engine for log verification in test's event loop.""" + from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter + from cognee.infrastructure.databases.cache.config import get_cache_config + + config = get_cache_config() + if not config.usage_logging or config.cache_backend != "redis": + pytest.skip("Redis usage logging not configured") + + engine = RedisAdapter( + host=config.cache_host, + port=config.cache_port, + username=config.cache_username, + password=config.cache_password, + log_key="usage_logs", + ) + return engine + + +@pytest.mark.asyncio +async def test_api_endpoint_logging(e2e_config, authenticated_client, cache_engine): + """Test that API endpoints succeed and log to Redis.""" + user = await get_default_user() + dataset_name = "test_e2e_api_dataset" + + add_response = authenticated_client.post( + "/api/v1/add", + data={"datasetName": dataset_name}, + files=[ + ( + "data", + ( + "test.txt", + b"Germany is located in Europe right next to the Netherlands.", + "text/plain", + ), + ) + ], + ) + assert add_response.status_code in [200, 201], f"Add endpoint failed: {add_response.text}" + + cognify_response = authenticated_client.post( + "/api/v1/cognify", + json={"datasets": [dataset_name], "run_in_background": False}, + ) + assert cognify_response.status_code in [200, 201], ( + f"Cognify endpoint failed: {cognify_response.text}" + ) + + search_response = authenticated_client.post( + "/api/v1/search", + json={"query": "Germany", "search_type": "GRAPH_COMPLETION", "datasets": [dataset_name]}, + ) + assert search_response.status_code == 200, f"Search endpoint failed: {search_response.text}" + + logs = await cache_engine.get_usage_logs(str(user.id), limit=20) + + add_logs = [log for log in logs if log.get("function_name") == "POST /v1/add"] + assert len(add_logs) > 0 + assert add_logs[0]["type"] == "api_endpoint" + assert add_logs[0]["user_id"] == str(user.id) + assert add_logs[0]["success"] is True + + cognify_logs = [log for log in logs if log.get("function_name") == "POST /v1/cognify"] + assert len(cognify_logs) > 0 + assert cognify_logs[0]["type"] == "api_endpoint" + assert cognify_logs[0]["user_id"] == str(user.id) + assert cognify_logs[0]["success"] is True + + search_logs = [log for log in logs if log.get("function_name") == "POST /v1/search"] + assert len(search_logs) > 0 + assert search_logs[0]["type"] == "api_endpoint" + assert search_logs[0]["user_id"] == str(user.id) + assert search_logs[0]["success"] is True + + +@pytest.mark.asyncio +async def test_mcp_tool_logging(e2e_config, cache_engine): + """Test that MCP tools succeed and log to Redis.""" + import sys + import importlib.util + from pathlib import Path + + await _reset_engines_and_prune() + + repo_root = Path(__file__).parent.parent.parent + mcp_src_path = repo_root / "cognee-mcp" / "src" + mcp_server_path = mcp_src_path / "server.py" + + if not mcp_server_path.exists(): + pytest.skip(f"MCP server not found at {mcp_server_path}") + + if str(mcp_src_path) not in sys.path: + sys.path.insert(0, str(mcp_src_path)) + + spec = importlib.util.spec_from_file_location("mcp_server_module", mcp_server_path) + mcp_server_module = 
importlib.util.module_from_spec(spec) + + import os + + original_cwd = os.getcwd() + try: + os.chdir(str(mcp_src_path)) + spec.loader.exec_module(mcp_server_module) + finally: + os.chdir(original_cwd) + + if mcp_server_module.cognee_client is None: + cognee_client_path = mcp_src_path / "cognee_client.py" + if cognee_client_path.exists(): + spec_client = importlib.util.spec_from_file_location( + "cognee_client", cognee_client_path + ) + cognee_client_module = importlib.util.module_from_spec(spec_client) + spec_client.loader.exec_module(cognee_client_module) + CogneeClient = cognee_client_module.CogneeClient + mcp_server_module.cognee_client = CogneeClient() + else: + pytest.skip(f"CogneeClient not found at {cognee_client_path}") + + test_text = "Germany is located in Europe right next to the Netherlands." + await mcp_server_module.cognify(data=test_text) + await asyncio.sleep(30.0) + + list_result = await mcp_server_module.list_data() + assert list_result is not None, "List data should return results" + + search_result = await mcp_server_module.search( + search_query="Germany", search_type="GRAPH_COMPLETION", top_k=5 + ) + assert search_result is not None, "Search should return results" + + interaction_data = "User: What is Germany?\nAgent: Germany is a country in Europe." + await mcp_server_module.save_interaction(data=interaction_data) + await asyncio.sleep(30.0) + + status_result = await mcp_server_module.cognify_status() + assert status_result is not None, "Cognify status should return results" + + await mcp_server_module.prune() + await asyncio.sleep(0.5) + + logs = await cache_engine.get_usage_logs("unknown", limit=50) + mcp_logs = [log for log in logs if log.get("type") == "mcp_tool"] + assert len(mcp_logs) > 0, ( + f"Should have MCP tool logs with user_id='unknown'. Found logs: {[log.get('function_name') for log in logs[:5]]}" + ) + assert len(mcp_logs) == 6 + function_names = [log.get("function_name") for log in mcp_logs] + expected_tools = [ + "MCP cognify", + "MCP list_data", + "MCP search", + "MCP save_interaction", + "MCP cognify_status", + "MCP prune", + ] + + for expected_tool in expected_tools: + assert expected_tool in function_names, ( + f"Should have {expected_tool} log. 
Found: {function_names}" + ) + + for log in mcp_logs: + assert log["type"] == "mcp_tool" + assert log["user_id"] == "unknown" + assert log["success"] is True diff --git a/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py b/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py index 837a9955c..ace43628a 100644 --- a/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +++ b/cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py @@ -62,6 +62,8 @@ def test_cache_config_to_dict(): "cache_password": None, "agentic_lock_expire": 100, "agentic_lock_timeout": 200, + "usage_logging": False, + "usage_logging_ttl": 604800, } diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index a13031ac5..5e40ce3a6 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import AsyncMock +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node @@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): graph.add_edge(edge) edge_distances = [ - MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph): graph.add_edge(edge1) graph.add_edge(edge2) + edge_1_text = "CONNECTS_TO" edge_distances = [ - MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr ) graph.add_edge(edge) + edge_text = "KNOWS" edge_distances = [ - MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}), + MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): ) graph.add_edge(edge) + edge_text = "SOME_OTHER_EDGE" edge_distances = [ - MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), + MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): graph.add_edge(edge1) graph.add_edge(edge2) + edge_1_text = "A" + edge_2_text = "B" edge_distances = [ - [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0 - [MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1 + [ + MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text}) + ], # query 0 + [ + MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text}) + ], # query 1 ] await graph.map_vector_distances_to_graph_edges( @@ -541,8 +551,11 @@ async def 
test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se graph.add_edge(edge1) graph.add_edge(edge2) + edge_1_text = "A" edge_distances = [ - [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped + [ + MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text}) + ], # query 0: only edge1 mapped [], # query 1: no edges mapped ] diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 98bfd48fe..feb254155 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine): assert len(context) == 2 assert context[0]["text"] == "Steve Rodger" assert context[1]["text"] == "Mike Broski" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=5, include_payload=True + ) @pytest.mark.asyncio @@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine): context = await retriever.get_context("test query") assert len(context) == 3 - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=3, include_payload=True + ) @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index e998d419d..4a73ef380 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine): context = await retriever.get_context("test query") assert context == "Steve Rodger\nMike Broski" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=2, include_payload=True + ) @pytest.mark.asyncio @@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine): context = await retriever.get_context("test query") assert context == "Chunk 0\nChunk 1" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=2, include_payload=True + ) @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index e552ac74a..7bec8afdf 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine): assert len(context) == 2 assert context[0]["text"] == "S.R." assert context[1]["text"] == "M.B." 
- mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) + mock_vector_engine.search.assert_awaited_once_with( + "TextSummary_text", "test query", limit=5, include_payload=True + ) @pytest.mark.asyncio @@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine): context = await retriever.get_context("test query") assert len(context) == 3 - mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) + mock_vector_engine.search.assert_awaited_once_with( + "TextSummary_text", "test query", limit=3, include_payload=True + ) @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index 1d2f4c84d..a0459b227 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits(): ] scored_results = [ - SimpleNamespace(payload={"id": "e2"}, score=0.10), - SimpleNamespace(payload={"id": "e1"}, score=0.20), + SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10), + SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20), ] top = await tr.filter_top_k_events(relevant_events, scored_results) @@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k ] scored_results = [ - SimpleNamespace(payload={"id": "known2"}, score=0.05), - SimpleNamespace(payload={"id": "known1"}, score=0.50), + SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05), + SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50), ] top = await tr.filter_top_k_events(relevant_events, scored_results) @@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): tr = TemporalRetriever(top_k=10) relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] scored_results = [ - SimpleNamespace(payload={"id": "a"}, score=0.1), - SimpleNamespace(payload={"id": "b"}, score=0.2), + SimpleNamespace(id="a", payload={"id": "a"}, score=0.1), + SimpleNamespace(id="b", payload={"id": "b"}, score=0.2), ] out = await tr.filter_top_k_events(relevant_events, scored_results) assert [e["id"] for e in out] == ["a", "b"] @@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine } ] - mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) - mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10) mock_vector_engine.search.return_value = [mock_result1, mock_result2] with ( @@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine) } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_eng } ] - mock_result = 
SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] mock_user = MagicMock() @@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_ve } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index fcbfd2434..4f41f9e3d 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import ( get_memory_fragment, format_triplets, ) +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError @@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes(): ] } + edge_1_text = "relates_to" + edge_2_text = "relates_to" edge_distances_batch = [ - [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})], - [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})], + [MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})], + [MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})], ] await graph.map_vector_distances_to_graph_nodes( diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py index 83612c7aa..e914b0aa4 100644 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine): context = await retriever.get_context("test query") assert context == "Alice knows Bob\nBob works at Tech Corp" - mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) + mock_vector_engine.search.assert_awaited_once_with( + "Triplet_text", "test query", limit=5, include_payload=True + ) @pytest.mark.asyncio diff --git a/cognee/tests/unit/shared/test_usage_logger.py b/cognee/tests/unit/shared/test_usage_logger.py new file mode 100644 index 000000000..fe4ebb15c --- /dev/null +++ b/cognee/tests/unit/shared/test_usage_logger.py @@ -0,0 +1,241 @@ +"""Unit tests for usage logger core functions.""" + +import pytest +from datetime import datetime, timezone +from uuid import UUID +from 
types import SimpleNamespace + +from cognee.shared.usage_logger import ( + _sanitize_value, + _sanitize_dict_key, + _get_param_names, + _get_param_defaults, + _extract_user_id, + _extract_parameters, + log_usage, +) +from cognee.shared.exceptions import UsageLoggerError + + +class TestSanitizeValue: + """Test _sanitize_value function.""" + + @pytest.mark.parametrize( + "value,expected", + [ + (None, None), + ("string", "string"), + (42, 42), + (3.14, 3.14), + (True, True), + (False, False), + ], + ) + def test_basic_types(self, value, expected): + assert _sanitize_value(value) == expected + + def test_uuid_and_datetime(self): + """Test UUID and datetime serialization.""" + uuid_val = UUID("123e4567-e89b-12d3-a456-426614174000") + dt = datetime(2024, 1, 15, 12, 30, 45, tzinfo=timezone.utc) + + assert _sanitize_value(uuid_val) == "123e4567-e89b-12d3-a456-426614174000" + assert _sanitize_value(dt) == "2024-01-15T12:30:45+00:00" + + def test_collections(self): + """Test list, tuple, and dict serialization.""" + assert _sanitize_value( + [1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None] + ) == [1, "string", "123e4567-e89b-12d3-a456-426614174000", None] + assert _sanitize_value((1, "string", True)) == [1, "string", True] + assert _sanitize_value({"key": UUID("123e4567-e89b-12d3-a456-426614174000")}) == { + "key": "123e4567-e89b-12d3-a456-426614174000" + } + assert _sanitize_value([]) == [] + assert _sanitize_value({}) == {} + + def test_nested_and_complex(self): + """Test nested structures and non-serializable types.""" + # Nested structure + nested = {"level1": {"level2": {"level3": [1, 2, {"nested": "value"}]}}} + assert _sanitize_value(nested)["level1"]["level2"]["level3"][2]["nested"] == "value" + + # Non-serializable + class CustomObject: + def __str__(self): + return "" + + result = _sanitize_value(CustomObject()) + assert isinstance(result, str) + assert "" + + result = _sanitize_dict_key(BadKey()) + assert isinstance(result, str) + assert " List[Person]: + system_prompt = ( + "Extract people mentioned in the text. " + "Return as `persons: Person[]` with each Person having `name` and optional `knows` relations. " + "If the text says someone knows someone set `knows` accordingly. " + "Only include facts explicitly stated." + ) + people = await LLMGateway.acreate_structured_output(text, system_prompt, People) + return people.persons + + +async def main(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + text = "Alice knows Bob." 
+ + tasks = [ + Task(extract_people), # input: text -> output: list[Person] + Task(add_data_points), # input: list[Person] -> output: list[Person] + ] + + async for _ in run_pipeline(tasks=tasks, data=text, datasets=["people_demo"]): + pass + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/graph_visualization.py b/examples/guides/graph_visualization.py new file mode 100644 index 000000000..3ee980081 --- /dev/null +++ b/examples/guides/graph_visualization.py @@ -0,0 +1,14 @@ +import asyncio +import cognee +from cognee.api.v1.visualize.visualize import visualize_graph + + +async def main(): + await cognee.add(["Alice knows Bob.", "NLP is a subfield of CS."]) + await cognee.cognify() + + await visualize_graph("./graph_after_cognify.html") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/low_level_llm.py b/examples/guides/low_level_llm.py new file mode 100644 index 000000000..454f53f44 --- /dev/null +++ b/examples/guides/low_level_llm.py @@ -0,0 +1,31 @@ +import asyncio + +from pydantic import BaseModel +from typing import List +from cognee.infrastructure.llm.LLMGateway import LLMGateway + + +class MiniEntity(BaseModel): + name: str + type: str + + +class MiniGraph(BaseModel): + nodes: List[MiniEntity] + + +async def main(): + system_prompt = ( + "Extract entities as nodes with name and type. " + "Use concise, literal values present in the text." + ) + + text = "Apple develops iPhone; Audi produces the R8." + + result = await LLMGateway.acreate_structured_output(text, system_prompt, MiniGraph) + print(result) + # MiniGraph(nodes=[MiniEntity(name='Apple', type='Organization'), ...]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/memify_quickstart.py b/examples/guides/memify_quickstart.py new file mode 100644 index 000000000..61fdf6991 --- /dev/null +++ b/examples/guides/memify_quickstart.py @@ -0,0 +1,30 @@ +import asyncio +import cognee +from cognee.modules.search.types import SearchType + + +async def main(): + # 1) Add two short chats and build a graph + await cognee.add( + [ + "We follow PEP8. Add type hints and docstrings.", + "Releases should not be on Friday. Susan must review PRs.", + ], + dataset_name="rules_demo", + ) + await cognee.cognify(datasets=["rules_demo"]) # builds graph + + # 2) Enrich the graph (uses default memify tasks) + await cognee.memify(dataset="rules_demo") + + # 3) Query the new coding rules + rules = await cognee.search( + query_type=SearchType.CODING_RULES, + query_text="List coding rules", + node_name=["coding_agent_rules"], + ) + print("Rules:", rules) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/ontology_input_example/basic_ontology.owl b/examples/guides/ontology_input_example/basic_ontology.owl new file mode 100644 index 000000000..81c4182d4 --- /dev/null +++ b/examples/guides/ontology_input_example/basic_ontology.owl @@ -0,0 +1,290 @@ + + + + Created for making cars accessible to everyone. + + + + + + + + + + + + + + + + + Famous for high-performance sports cars. + + + + + + + + + Pioneering social media and virtual reality technology. + + + + + + + + Known for its innovative consumer electronics and software. + + + + + + + + + + + + Known for its modern designs and technology. + + + + + + + + + + + + + + + + Focused on performance and driving pleasure. + + + + + + + + + + + + + + + + + + + + + Started as a search engine and expanded into cloud computing and AI. 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Dominant in software, cloud computing, and gaming. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Synonymous with luxury and quality. + + + + + + + + + + + From e-commerce to cloud computing giant with AWS. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/guides/ontology_quickstart.py b/examples/guides/ontology_quickstart.py new file mode 100644 index 000000000..e3b58d5d6 --- /dev/null +++ b/examples/guides/ontology_quickstart.py @@ -0,0 +1,29 @@ +import asyncio +import cognee +import os +from cognee.modules.ontology.ontology_config import Config +from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver + + +async def main(): + texts = ["Audi produces the R8 and e-tron.", "Apple develops iPhone and MacBook."] + + await cognee.add(texts) + # or: await cognee.add("/path/to/folder/of/files") + + ontology_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl" + ) + + # Create full config structure manually + config: Config = { + "ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path) + } + } + + await cognee.cognify(config=config) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/s3_storage.py b/examples/guides/s3_storage.py new file mode 100644 index 000000000..e64859db7 --- /dev/null +++ b/examples/guides/s3_storage.py @@ -0,0 +1,25 @@ +import asyncio +import cognee + + +async def main(): + # Single file + await cognee.add("s3://cognee-s3-small-test/Natural_language_processing.txt") + + # Folder/prefix (recursively expands) + await cognee.add("s3://cognee-s3-small-test") + + # Mixed list + await cognee.add( + [ + "s3://cognee-s3-small-test/Natural_language_processing.txt", + "Some inline text to ingest", + ] + ) + + # Process the data + await cognee.cognify() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/search_basics_additional.py b/examples/guides/search_basics_additional.py new file mode 100644 index 000000000..f98dd8db9 --- /dev/null +++ b/examples/guides/search_basics_additional.py @@ -0,0 +1,52 @@ +import asyncio +import cognee + +from cognee.api.v1.search import SearchType + + +async def main(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + text = """ + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. + First rule of coding: Do not talk about coding. + """ + + text2 = """ + Sandwiches are best served toasted with cheese, ham, mayo, + lettuce, mustard, and salt & pepper. + """ + + await cognee.add(text, dataset_name="NLP_coding") + await cognee.add(text2, dataset_name="Sandwiches") + await cognee.add(text2) + + await cognee.cognify() + + # Make sure you've already run cognee.cognify(...) 
so the graph has content + answers = await cognee.search(query_text="What are the main themes in my data?") + assert len(answers) > 0 + + answers = await cognee.search( + query_text="List coding guidelines", + query_type=SearchType.CODING_RULES, + ) + assert len(answers) > 0 + + answers = await cognee.search( + query_text="Give me a confident answer: What is NLP?", + system_prompt="Answer succinctly and state confidence at the end.", + ) + assert len(answers) > 0 + + answers = await cognee.search( + query_text="Tell me about NLP", + only_context=True, + ) + assert len(answers) > 0 + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/search_basics_core.py b/examples/guides/search_basics_core.py new file mode 100644 index 000000000..15e2b5670 --- /dev/null +++ b/examples/guides/search_basics_core.py @@ -0,0 +1,27 @@ +import asyncio +import cognee + + +async def main(): + # Start clean (optional in your app) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + # Prepare knowledge base + await cognee.add( + [ + "Alice moved to Paris in 2010. She works as a software engineer.", + "Bob lives in New York. He is a data scientist.", + "Alice and Bob met at a conference in 2015.", + ] + ) + + await cognee.cognify() + + # Make sure you've already run cognee.cognify(...) so the graph has content + answers = await cognee.search(query_text="What are the main themes in my data?") + for answer in answers: + print(answer) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/guides/temporal_cognify.py b/examples/guides/temporal_cognify.py new file mode 100644 index 000000000..34c1ee33c --- /dev/null +++ b/examples/guides/temporal_cognify.py @@ -0,0 +1,57 @@ +import asyncio +import cognee + + +async def main(): + text = """ + In 1998 the project launched. In 2001 version 1.0 shipped. In 2004 the team merged + with another group. In 2010 support for v1 ended. + """ + + await cognee.add(text, dataset_name="timeline_demo") + + await cognee.cognify(datasets=["timeline_demo"], temporal_cognify=True) + + from cognee.api.v1.search import SearchType + + # Before / after queries + result = await cognee.search( + query_type=SearchType.TEMPORAL, query_text="What happened before 2000?", top_k=10 + ) + + assert result != [] + + result = await cognee.search( + query_type=SearchType.TEMPORAL, query_text="What happened after 2010?", top_k=10 + ) + + assert result != [] + + # Between queries + result = await cognee.search( + query_type=SearchType.TEMPORAL, query_text="Events between 2001 and 2004", top_k=10 + ) + + assert result != [] + + # Scoped descriptions + result = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="Key project milestones between 1998 and 2010", + top_k=10, + ) + + assert result != [] + + result = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="What happened after 2004?", + datasets=["timeline_demo"], + top_k=10, + ) + + assert result != [] + + +if __name__ == "__main__": + asyncio.run(main())
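A minimal end-to-end sketch of the usage-logging flow introduced in this patch. It is hedged, not part of the diff: it assumes a local Redis on localhost:6379 and that the default cache engine writes under the "usage_logs" key that the e2e test reads from; demo_tool and the SimpleNamespace user object are hypothetical stand-ins.

import asyncio
import os
from types import SimpleNamespace

# Enable the feature the same way the CI workflow and e2e fixtures do,
# before cognee reads its cache configuration.
os.environ["USAGE_LOGGING"] = "true"
os.environ["CACHE_BACKEND"] = "redis"
os.environ["CACHE_HOST"] = "localhost"
os.environ["CACHE_PORT"] = "6379"

from cognee.shared.usage_logger import log_usage
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import create_cache_engine
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter

# Drop any cached config/engine so the env vars above are picked up,
# mirroring what the integration-test fixtures do.
get_cache_config.cache_clear()
create_cache_engine.cache_clear()


@log_usage(function_name="demo_tool", log_type="function")
async def demo_tool(query: str, user=None):
    # @log_usage only accepts async functions; the user kwarg is how the
    # decorator resolves user_id (it falls back to "unknown" when absent).
    return {"echo": query}


async def main():
    user = SimpleNamespace(id="demo-user")  # hypothetical user object
    await demo_tool("hello", user=user)

    # Read the entries back the way the integration/e2e tests do,
    # assuming the default "usage_logs" key.
    adapter = RedisAdapter(host="localhost", port=6379, log_key="usage_logs")
    for entry in await adapter.get_usage_logs("demo-user", limit=5):
        print(entry["function_name"], entry["success"], entry["duration_ms"])


if __name__ == "__main__":
    asyncio.run(main())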