Merge branch 'dev' into feature/cog-2950-ontologyresolver-abstraction
This commit is contained in:
commit
f331cf85fb
8 changed files with 40 additions and 11 deletions
|
|
@ -121,6 +121,9 @@ ACCEPT_LOCAL_FILE_PATH=True
|
||||||
# This protects against Server Side Request Forgery when proper infrastructure is not in place.
|
# This protects against Server Side Request Forgery when proper infrastructure is not in place.
|
||||||
ALLOW_HTTP_REQUESTS=True
|
ALLOW_HTTP_REQUESTS=True
|
||||||
|
|
||||||
|
# When set to false don't allow cypher search to be used in Cognee.
|
||||||
|
ALLOW_CYPHER_QUERY=True
|
||||||
|
|
||||||
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
|
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
|
||||||
RAISE_INCREMENTAL_LOADING_ERRORS=True
|
RAISE_INCREMENTAL_LOADING_ERRORS=True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
|
||||||
from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
|
from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
|
||||||
|
|
@ -119,7 +120,7 @@ def get_cognify_router() -> APIRouter:
|
||||||
|
|
||||||
# If any cognify run errored return JSONResponse with proper error status code
|
# If any cognify run errored return JSONResponse with proper error status code
|
||||||
if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
|
if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
|
||||||
return JSONResponse(status_code=420, content=cognify_run)
|
return JSONResponse(status_code=420, content=jsonable_encoder(cognify_run))
|
||||||
return cognify_run
|
return cognify_run
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(AsyncAttrs, DeclarativeBase):
|
||||||
"""
|
"""
|
||||||
Represents a base class for declarative models using SQLAlchemy.
|
Represents a base class for declarative models using SQLAlchemy.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,12 @@ async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file)
|
||||||
|
|
||||||
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
||||||
file_name = Path(file_path).stem if file_path else None
|
|
||||||
|
if isinstance(file_path, str):
|
||||||
|
file_name = Path(file_path).stem if file_path else None
|
||||||
|
else:
|
||||||
|
# In case file_path does not exist or is a integer return None
|
||||||
|
file_name = None
|
||||||
|
|
||||||
# Get file size
|
# Get file size
|
||||||
pos = file.tell() # remember current pointer
|
pos = file.tell() # remember current pointer
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, List, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from cognee.modules.data.models.Data import Data
|
||||||
|
|
||||||
|
|
||||||
class PipelineRunInfo(BaseModel):
|
class PipelineRunInfo(BaseModel):
|
||||||
|
|
@ -8,11 +9,15 @@ class PipelineRunInfo(BaseModel):
|
||||||
pipeline_run_id: UUID
|
pipeline_run_id: UUID
|
||||||
dataset_id: UUID
|
dataset_id: UUID
|
||||||
dataset_name: str
|
dataset_name: str
|
||||||
payload: Optional[Any] = None
|
# Data must be mentioned in typing to allow custom encoders for Data to be activated
|
||||||
|
payload: Optional[Union[Any, List[Data]]] = None
|
||||||
data_ingestion_info: Optional[list] = None
|
data_ingestion_info: Optional[list] = None
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"arbitrary_types_allowed": True,
|
"arbitrary_types_allowed": True,
|
||||||
|
"from_attributes": True,
|
||||||
|
# Add custom encoding handler for Data ORM model
|
||||||
|
"json_encoders": {Data: lambda d: d.to_json()},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from typing import Callable, List, Optional, Type
|
from typing import Callable, List, Optional, Type
|
||||||
|
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
|
|
@ -160,6 +161,12 @@ async def get_search_type_tools(
|
||||||
if query_type is SearchType.FEELING_LUCKY:
|
if query_type is SearchType.FEELING_LUCKY:
|
||||||
query_type = await select_search_type(query_text)
|
query_type = await select_search_type(query_text)
|
||||||
|
|
||||||
|
if (
|
||||||
|
query_type in [SearchType.CYPHER, SearchType.NATURAL_LANGUAGE]
|
||||||
|
and os.getenv("ALLOW_CYPHER_QUERY", "true").lower() == "false"
|
||||||
|
):
|
||||||
|
raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
|
||||||
|
|
||||||
search_type_tools = search_tasks.get(query_type)
|
search_type_tools = search_tasks.get(query_type)
|
||||||
|
|
||||||
if not search_type_tools:
|
if not search_type_tools:
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from ...models.User import User
|
from ...models.User import User
|
||||||
|
|
@ -17,9 +19,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
|
||||||
# Get all datasets all tenants have access to
|
# Get all datasets all tenants have access to
|
||||||
tenant = await get_tenant(user.tenant_id)
|
tenant = await get_tenant(user.tenant_id)
|
||||||
datasets.extend(await get_principal_datasets(tenant, permission_type))
|
datasets.extend(await get_principal_datasets(tenant, permission_type))
|
||||||
|
|
||||||
# Get all datasets Users roles have access to
|
# Get all datasets Users roles have access to
|
||||||
for role_name in user.roles:
|
if isinstance(user, SimpleNamespace):
|
||||||
role = await get_role(user.tenant_id, role_name)
|
# If simple namespace use roles defined in user
|
||||||
|
roles = user.roles
|
||||||
|
else:
|
||||||
|
roles = await user.awaitable_attrs.roles
|
||||||
|
for role in roles:
|
||||||
datasets.extend(await get_principal_datasets(role, permission_type))
|
datasets.extend(await get_principal_datasets(role, permission_type))
|
||||||
|
|
||||||
# Deduplicate datasets with same ID
|
# Deduplicate datasets with same ID
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
"""Test that the server is running and can accept connections."""
|
"""Test that the server is running and can accept connections."""
|
||||||
# Test health endpoint
|
# Test health endpoint
|
||||||
health_response = requests.get("http://localhost:8000/health", timeout=15)
|
health_response = requests.get("http://localhost:8000/health", timeout=15)
|
||||||
self.assertIn(health_response.status_code, [200, 503])
|
self.assertIn(health_response.status_code, [200])
|
||||||
|
|
||||||
# Test root endpoint
|
# Test root endpoint
|
||||||
root_response = requests.get("http://localhost:8000/", timeout=15)
|
root_response = requests.get("http://localhost:8000/", timeout=15)
|
||||||
|
|
@ -88,7 +88,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
payload = {"datasets": [dataset_name]}
|
payload = {"datasets": [dataset_name]}
|
||||||
|
|
||||||
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
|
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
|
||||||
if add_response.status_code not in [200, 201, 409]:
|
if add_response.status_code not in [200, 201]:
|
||||||
add_response.raise_for_status()
|
add_response.raise_for_status()
|
||||||
|
|
||||||
# Cognify request
|
# Cognify request
|
||||||
|
|
@ -99,7 +99,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
cognify_response = requests.post(url, headers=headers, json=payload, timeout=150)
|
cognify_response = requests.post(url, headers=headers, json=payload, timeout=150)
|
||||||
if cognify_response.status_code not in [200, 201, 409]:
|
if cognify_response.status_code not in [200, 201]:
|
||||||
cognify_response.raise_for_status()
|
cognify_response.raise_for_status()
|
||||||
|
|
||||||
# TODO: Add test to verify cognify pipeline is complete before testing search
|
# TODO: Add test to verify cognify pipeline is complete before testing search
|
||||||
|
|
@ -115,7 +115,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
payload = {"searchType": "GRAPH_COMPLETION", "query": "What's in the document?"}
|
payload = {"searchType": "GRAPH_COMPLETION", "query": "What's in the document?"}
|
||||||
|
|
||||||
search_response = requests.post(url, headers=headers, json=payload, timeout=50)
|
search_response = requests.post(url, headers=headers, json=payload, timeout=50)
|
||||||
if search_response.status_code not in [200, 201, 409]:
|
if search_response.status_code not in [200, 201]:
|
||||||
search_response.raise_for_status()
|
search_response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue