Compare commits
2 commits
main
...
fix-get-st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
613eb2cc4e | ||
|
|
9c8c274574 |
3 changed files with 25 additions and 7 deletions
|
|
@ -237,8 +237,12 @@ async def run_cognify_as_background_process(
|
||||||
graph_db_config: dict = None,
|
graph_db_config: dict = None,
|
||||||
vector_db_config: dict = False,
|
vector_db_config: dict = False,
|
||||||
):
|
):
|
||||||
|
# Convert dataset to list if it's a string
|
||||||
|
if isinstance(datasets, str):
|
||||||
|
datasets = [datasets]
|
||||||
|
|
||||||
# Store pipeline status for all pipelines
|
# Store pipeline status for all pipelines
|
||||||
pipeline_run_started_info = []
|
pipeline_run_started_info = {}
|
||||||
|
|
||||||
async def handle_rest_of_the_run(pipeline_list):
|
async def handle_rest_of_the_run(pipeline_list):
|
||||||
# Execute all provided pipelines one by one to avoid database write conflicts
|
# Execute all provided pipelines one by one to avoid database write conflicts
|
||||||
|
|
@ -269,7 +273,14 @@ async def run_cognify_as_background_process(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save dataset Pipeline run started info
|
# Save dataset Pipeline run started info
|
||||||
pipeline_run_started_info.append(await anext(pipeline_run))
|
run_info = await anext(pipeline_run)
|
||||||
|
pipeline_run_started_info[run_info.dataset_id] = run_info
|
||||||
|
|
||||||
|
if pipeline_run_started_info[run_info.dataset_id].payload:
|
||||||
|
# Remove payload info to avoid serialization
|
||||||
|
# TODO: Handle payload serialization
|
||||||
|
pipeline_run_started_info[run_info.dataset_id].payload = []
|
||||||
|
|
||||||
pipeline_list.append(pipeline_run)
|
pipeline_list.append(pipeline_run)
|
||||||
|
|
||||||
# Send all started pipelines to execute one by one in background
|
# Send all started pipelines to execute one by one in background
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ logger = get_logger("api.cognify")
|
||||||
class CognifyPayloadDTO(InDTO):
|
class CognifyPayloadDTO(InDTO):
|
||||||
datasets: Optional[List[str]] = None
|
datasets: Optional[List[str]] = None
|
||||||
dataset_ids: Optional[List[UUID]] = None
|
dataset_ids: Optional[List[UUID]] = None
|
||||||
graph_model: Optional[BaseModel] = KnowledgeGraph
|
|
||||||
run_in_background: Optional[bool] = False
|
run_in_background: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -96,7 +95,7 @@ def get_cognify_router() -> APIRouter:
|
||||||
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
|
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
|
||||||
|
|
||||||
cognify_run = await cognee_cognify(
|
cognify_run = await cognee_cognify(
|
||||||
datasets, user, payload.graph_model, run_in_background=payload.run_in_background
|
datasets, user, run_in_background=payload.run_in_background
|
||||||
)
|
)
|
||||||
|
|
||||||
return cognify_run
|
return cognify_run
|
||||||
|
|
|
||||||
|
|
@ -331,11 +331,19 @@ def get_datasets_router() -> APIRouter:
|
||||||
## Error Codes
|
## Error Codes
|
||||||
- **500 Internal Server Error**: Error retrieving status information
|
- **500 Internal Server Error**: Error retrieving status information
|
||||||
"""
|
"""
|
||||||
from cognee.modules.data.methods import get_dataset_status
|
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
||||||
|
|
||||||
dataset_status = await get_dataset_status(datasets, user.id)
|
try:
|
||||||
|
# Verify user has permission to read dataset
|
||||||
|
authorized_datasets = await get_authorized_existing_datasets(datasets, "read", user)
|
||||||
|
|
||||||
return dataset_status
|
datasets_statuses = await cognee_datasets.get_status(
|
||||||
|
[dataset.id for dataset in authorized_datasets]
|
||||||
|
)
|
||||||
|
|
||||||
|
return datasets_statuses
|
||||||
|
except Exception as error:
|
||||||
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
||||||
@router.get("/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
@router.get("/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||||
async def get_raw_data(
|
async def get_raw_data(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue