diff --git a/.env.template b/.env.template
index e9e9fb571..916a1ef76 100644
--- a/.env.template
+++ b/.env.template
@@ -121,6 +121,9 @@ ACCEPT_LOCAL_FILE_PATH=True
 # This protects against Server Side Request Forgery when proper infrastructure is not in place.
 ALLOW_HTTP_REQUESTS=True
 
+# When set to false don't allow cypher search to be used in Cognee.
+ALLOW_CYPHER_QUERY=True
+
 # When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
 RAISE_INCREMENTAL_LOADING_ERRORS=True
 
diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py
index d40345f8e..9e4bdbbfd 100644
--- a/cognee/api/v1/cognify/routers/get_cognify_router.py
+++ b/cognee/api/v1/cognify/routers/get_cognify_router.py
@@ -3,6 +3,7 @@ import asyncio
 from uuid import UUID
 from pydantic import Field
 from typing import List, Optional
+from fastapi.encoders import jsonable_encoder
 from fastapi.responses import JSONResponse
 from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
 from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
@@ -119,7 +120,7 @@ def get_cognify_router() -> APIRouter:
 
             # If any cognify run errored return JSONResponse with proper error status code
             if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
-                return JSONResponse(status_code=420, content=cognify_run)
+                return JSONResponse(status_code=420, content=jsonable_encoder(cognify_run))
             return cognify_run
         except Exception as error:
             return JSONResponse(status_code=409, content={"error": str(error)})
diff --git a/cognee/infrastructure/databases/relational/ModelBase.py b/cognee/infrastructure/databases/relational/ModelBase.py
index a4d3a1a19..3a2054207 100644
--- a/cognee/infrastructure/databases/relational/ModelBase.py
+++ b/cognee/infrastructure/databases/relational/ModelBase.py
@@ -1,7 +1,8 @@
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.ext.asyncio import AsyncAttrs
 
 
-class Base(DeclarativeBase):
+class Base(AsyncAttrs, DeclarativeBase):
     """
     Represents a base class for declarative models using SQLAlchemy.
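Note on the get_cognify_router.py change above: JSONResponse renders its content with the standard json encoder, which cannot serialise a dict keyed by UUID whose values are Pydantic models, so the run info must go through jsonable_encoder first. A minimal sketch of the failure mode and the fix; PipelineRunErrored here is a simplified stand-in, not Cognee's real model:

    from uuid import uuid4
    from pydantic import BaseModel
    from fastapi.encoders import jsonable_encoder
    from fastapi.responses import JSONResponse

    class PipelineRunErrored(BaseModel):  # simplified stand-in for the real model
        dataset_name: str
        error: str

    cognify_run = {uuid4(): PipelineRunErrored(dataset_name="docs", error="boom")}

    # JSONResponse(content=cognify_run) would fail at render time: UUID keys and
    # BaseModel values are not JSON-serialisable by the default encoder.
    # jsonable_encoder converts the keys to strings and the models to plain dicts.
    response = JSONResponse(status_code=420, content=jsonable_encoder(cognify_run))
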
diff --git a/cognee/infrastructure/files/utils/get_file_metadata.py b/cognee/infrastructure/files/utils/get_file_metadata.py
index 1eb7a1f79..23b10a6df 100644
--- a/cognee/infrastructure/files/utils/get_file_metadata.py
+++ b/cognee/infrastructure/files/utils/get_file_metadata.py
@@ -56,7 +56,12 @@ async def get_file_metadata(file: BinaryIO) -> FileMetadata:
     file_type = guess_file_type(file)
 
     file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
-    file_name = Path(file_path).stem if file_path else None
+
+    if isinstance(file_path, str):
+        file_name = Path(file_path).stem if file_path else None
+    else:
+        # In case file_path does not exist or is an integer, return None
+        file_name = None
 
     # Get file size
     pos = file.tell()  # remember current pointer
diff --git a/cognee/modules/pipelines/models/PipelineRunInfo.py b/cognee/modules/pipelines/models/PipelineRunInfo.py
index 5f5a91c34..2a1da34cc 100644
--- a/cognee/modules/pipelines/models/PipelineRunInfo.py
+++ b/cognee/modules/pipelines/models/PipelineRunInfo.py
@@ -1,6 +1,7 @@
-from typing import Any, Optional
+from typing import Any, Optional, List, Union
 from uuid import UUID
 from pydantic import BaseModel
+from cognee.modules.data.models.Data import Data
 
 
 class PipelineRunInfo(BaseModel):
@@ -8,11 +9,15 @@
     pipeline_run_id: UUID
     dataset_id: UUID
     dataset_name: str
-    payload: Optional[Any] = None
+    # Data must be mentioned in typing to allow custom encoders for Data to be activated
+    payload: Optional[Union[Any, List[Data]]] = None
     data_ingestion_info: Optional[list] = None
 
     model_config = {
         "arbitrary_types_allowed": True,
+        "from_attributes": True,
+        # Add custom encoding handler for Data ORM model
+        "json_encoders": {Data: lambda d: d.to_json()},
     }
 
 
diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py
index e671a7db3..551f77a16 100644
--- a/cognee/modules/search/methods/get_search_type_tools.py
+++ b/cognee/modules/search/methods/get_search_type_tools.py
@@ -1,3 +1,4 @@
+import os
 from typing import Callable, List, Optional, Type
 
 from cognee.modules.engine.models.node_set import NodeSet
@@ -160,6 +161,12 @@ async def get_search_type_tools(
     if query_type is SearchType.FEELING_LUCKY:
         query_type = await select_search_type(query_text)
 
+    if (
+        query_type in [SearchType.CYPHER, SearchType.NATURAL_LANGUAGE]
+        and os.getenv("ALLOW_CYPHER_QUERY", "true").lower() == "false"
+    ):
+        raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
+
     search_type_tools = search_tasks.get(query_type)
 
     if not search_type_tools:
diff --git a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py
index 5b242baa4..a8731a773 100644
--- a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py
+++ b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py
@@ -1,3 +1,5 @@
+from types import SimpleNamespace
+
 from cognee.shared.logging_utils import get_logger
 
 from ...models.User import User
@@ -17,9 +19,14 @@
         # Get all datasets all tenants have access to
         tenant = await get_tenant(user.tenant_id)
         datasets.extend(await get_principal_datasets(tenant, permission_type))
+
         # Get all datasets Users roles have access to
-        for role_name in user.roles:
-            role = await get_role(user.tenant_id, role_name)
+        if isinstance(user, SimpleNamespace):
+            # If simple namespace, use roles defined in user
+            roles = user.roles
+        else:
+            roles = await user.awaitable_attrs.roles
+        for role in roles:
             datasets.extend(await get_principal_datasets(role, permission_type))
 
     # Deduplicate datasets with same ID
diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py
index 1681b7867..40ae96548 100644
--- a/cognee/tests/test_cognee_server_start.py
+++ b/cognee/tests/test_cognee_server_start.py
@@ -48,7 +48,7 @@ class TestCogneeServerStart(unittest.TestCase):
         """Test that the server is running and can accept connections."""
         # Test health endpoint
         health_response = requests.get("http://localhost:8000/health", timeout=15)
-        self.assertIn(health_response.status_code, [200, 503])
+        self.assertIn(health_response.status_code, [200])
 
         # Test root endpoint
         root_response = requests.get("http://localhost:8000/", timeout=15)
@@ -88,7 +88,7 @@ class TestCogneeServerStart(unittest.TestCase):
         payload = {"datasets": [dataset_name]}
 
         add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
-        if add_response.status_code not in [200, 201, 409]:
+        if add_response.status_code not in [200, 201]:
             add_response.raise_for_status()
 
         # Cognify request
@@ -99,7 +99,7 @@ class TestCogneeServerStart(unittest.TestCase):
         }
 
         cognify_response = requests.post(url, headers=headers, json=payload, timeout=150)
-        if cognify_response.status_code not in [200, 201, 409]:
+        if cognify_response.status_code not in [200, 201]:
             cognify_response.raise_for_status()
 
         # TODO: Add test to verify cognify pipeline is complete before testing search
@@ -115,7 +115,7 @@ class TestCogneeServerStart(unittest.TestCase):
         payload = {"searchType": "GRAPH_COMPLETION", "query": "What's in the document?"}
 
         search_response = requests.post(url, headers=headers, json=payload, timeout=50)
-        if search_response.status_code not in [200, 201, 409]:
+        if search_response.status_code not in [200, 201]:
             search_response.raise_for_status()
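The roles lookup above relies on the AsyncAttrs mixin added to Base in ModelBase.py: under an async session, reading a lazy-loaded relationship such as user.roles directly raises MissingGreenlet, whereas user.awaitable_attrs.roles performs the lazy load through the async machinery. A rough sketch of that pattern with hypothetical User/Role models (not Cognee's actual ones), assuming the instance is attached to an active AsyncSession:

    from sqlalchemy import ForeignKey
    from sqlalchemy.ext.asyncio import AsyncAttrs
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(AsyncAttrs, DeclarativeBase):
        pass

    class Role(Base):
        __tablename__ = "roles"
        id: Mapped[int] = mapped_column(primary_key=True)
        user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))

    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(primary_key=True)
        roles: Mapped[list["Role"]] = relationship()

    async def load_roles(user: User) -> list[Role]:
        # Touching user.roles on an unloaded relationship inside an asyncio
        # context raises MissingGreenlet; awaitable_attrs awaits the lazy load.
        return await user.awaitable_attrs.roles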