diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py
index d3f20ba17..d17cab4f2 100644
--- a/cognee/api/v1/cognify/cognify_v2.py
+++ b/cognee/api/v1/cognify/cognify_v2.py
@@ -40,88 +40,92 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
     if datasets is None or len(datasets) == 0:
         return await cognify(await db_engine.get_datasets())
 
-    if user is None:
-        user = await get_default_user()
-    async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
-        documents = [
-            PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
-            AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
-            ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
-            TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
-            for file in files
-        ]
+    db_engine = get_relational_engine()
+    async with db_engine.get_async_session() as session:
-        await check_permissions_on_documents(user, "read", [document.id for document in documents])
+        if user is None:
+            user = await get_default_user(session=session)
-        async with update_status_lock:
-            task_status = get_task_status([dataset_name])
-
-            if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
-                logger.info(f"Dataset {dataset_name} is being processed.")
-                return
-
-            update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
-        try:
-            cognee_config = get_cognify_config()
-            graph_config = get_graph_config()
-            root_node_id = None
-
-            if graph_config.infer_graph_topology and graph_config.graph_topology_task:
-                from cognee.modules.topology.topology import TopologyEngine
-                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
-                root_node_id = await topology_engine.add_graph_topology(files = files)
-            elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
-                from cognee.modules.topology.topology import TopologyEngine
-                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
-                await topology_engine.add_graph_topology(graph_config.topology_file_path)
-            elif not graph_config.graph_topology_task:
-                root_node_id = "ROOT"
-
-            tasks = [
-                Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
-                Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
-                Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
-                Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
-                Task(
-                    save_data_chunks,
-                    collection_name = "chunks",
-                ), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
-                run_tasks_parallel([
-                    Task(
-                        summarize_text_chunks,
-                        summarization_model = cognee_config.summarization_model,
-                        collection_name = "chunk_summaries",
-                    ), # Summarize the document chunks
-                    Task(
-                        classify_text_chunks,
-                        classification_model = cognee_config.classification_model,
-                    ),
-                ]),
-                Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
+        async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
+            documents = [
+                PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
+                AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
+                ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
+                TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
+                for file in files
             ]
-            pipeline = run_tasks(tasks, documents)
+            await check_permissions_on_documents(user, "read", [document.id for document in documents], session=session)
-            async for result in pipeline:
-                print(result)
+            async with update_status_lock:
+                task_status = get_task_status([dataset_name])
-            update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
-        except Exception as error:
-            update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
-            raise error
+                if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
+                    logger.info(f"Dataset {dataset_name} is being processed.")
+                    return
+
+                update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
+            try:
+                cognee_config = get_cognify_config()
+                graph_config = get_graph_config()
+                root_node_id = None
+
+                if graph_config.infer_graph_topology and graph_config.graph_topology_task:
+                    from cognee.modules.topology.topology import TopologyEngine
+                    topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
+                    root_node_id = await topology_engine.add_graph_topology(files = files)
+                elif not graph_config.infer_graph_topology and graph_config.graph_topology_task:
+                    from cognee.modules.topology.topology import TopologyEngine
+                    topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
+                    await topology_engine.add_graph_topology(graph_config.topology_file_path) # load the topology from the configured file instead of inferring it
+                elif not graph_config.graph_topology_task:
+                    root_node_id = "ROOT"
+
+                tasks = [
+                    Task(process_documents, parent_node_id = root_node_id), # Classify documents, save them as nodes in the graph db, and extract text chunks based on the document type
+                    Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
+                    Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach them to chunk nodes
+                    Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
+                    Task(
+                        save_data_chunks,
+                        collection_name = "chunks",
+                    ), # Save the document chunks in the vector db and as nodes in the graph db (connected to the document node and to each other)
+                    run_tasks_parallel([
+                        Task(
+                            summarize_text_chunks,
+                            summarization_model = cognee_config.summarization_model,
+                            collection_name = "chunk_summaries",
+                        ), # Summarize the document chunks
+                        Task(
+                            classify_text_chunks,
+                            classification_model = cognee_config.classification_model,
+                        ),
+                    ]),
+                    Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
+                ]
+
+                pipeline = run_tasks(tasks, documents)
+
+                async for result in pipeline:
+                    print(result)
+
+                update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
+            except Exception as error:
+                update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
+                raise error
-    existing_datasets = await db_engine.get_datasets()
-    awaitables = []
+        existing_datasets = await db_engine.get_datasets()
+        awaitables = []
-    for dataset in datasets:
-        dataset_name = generate_dataset_name(dataset)
+        for dataset in datasets:
+            dataset_name = generate_dataset_name(dataset)
-        if dataset_name in existing_datasets:
-            awaitables.append(run_cognify_pipeline(dataset, await db_engine.get_files_metadata(dataset_name)))
+            if dataset_name in existing_datasets:
+                awaitables.append(run_cognify_pipeline(dataset, await db_engine.get_files_metadata(dataset_name)))
-    return await asyncio.gather(*awaitables)
+        return await asyncio.gather(*awaitables)
 
 def generate_dataset_name(dataset_name: str) -> str:
     return dataset_name.replace(".", "_").replace(" ", "_")
diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py
index 5e0d3bcfe..fd384e50d 100644
--- a/cognee/modules/users/methods/get_default_user.py
+++ b/cognee/modules/users/methods/get_default_user.py
@@ -3,10 +3,9 @@ from cognee.infrastructure.databases.relational import get_relational_engine
 from sqlalchemy.future import select
 
-async def get_default_user():
-    db_engine = get_relational_engine()
-    async with db_engine.get_async_session() as session:
-        stmt = select(User).where(User.email == "default_user@example.com")
-        result = await session.execute(stmt)
-        user = result.scalars().first()
+async def get_default_user(session):
+
+    stmt = select(User).where(User.email == "default_user@example.com")
+    result = await session.execute(stmt)
+    user = result.scalars().first()
 
     return user
\ No newline at end of file
diff --git a/cognee/modules/users/permissions/methods/check_permissions_on_documents.py b/cognee/modules/users/permissions/methods/check_permissions_on_documents.py
index 4bc6f82c7..474855114 100644
--- a/cognee/modules/users/permissions/methods/check_permissions_on_documents.py
+++ b/cognee/modules/users/permissions/methods/check_permissions_on_documents.py
@@ -1,32 +1,37 @@
 import logging
+
+from sqlalchemy import select
+
 from cognee.infrastructure.databases.relational import get_relational_engine
 from ...models.User import User
 from ...models.ACL import ACL
 
 logger = logging.getLogger(__name__)
 
-async def check_permissions_on_documents(
-    user: User,
-    permission_type: str,
-    document_ids: list[str],
-):
+class PermissionDeniedException(Exception):
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
+
+
+async def check_permissions_on_documents(user: User, permission_type: str, document_ids: list[str], session):
     try:
-        relational_engine = get_relational_engine()
+        user_group_ids = [group.id for group in user.groups]
 
-        async with relational_engine.get_async_session() as session:
-            user_group_ids = [group.id for group in user.groups]
+        result = await session.execute(
+            select(ACL).filter(
+                ACL.principal_id.in_([user.id, *user_group_ids]),
+                ACL.permission.has(name=permission_type) # match on the related Permission's name
+            )
+        )
+        acls = result.scalars().all()
 
-            acls = session.query(ACL) \
-                .filter(ACL.principal_id.in_([user.id, *user_group_ids])) \
-                .filter(ACL.permission.name == permission_type) \
-                .all()
+        resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
+
+        has_permissions = all(document_id in resource_ids for document_id in document_ids)
-            resource_ids = [resource.resource_id for resource in acl.resources for acl in acls]
-
-            has_permissions = all([document_id in resource_ids for document_id in document_ids])
-
-            if not has_permissions:
-                raise Exception(f"User {user.username} does not have {permission_type} permission on documents")
+        if not has_permissions:
+            raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
     except Exception as error:
         logger.error("Error checking permissions on documents: %s", str(error))
-        raise error
+        raise
+
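A minimal usage sketch of how the pieces fit together after this change: the caller now owns the async session and threads it through both helpers, and permission failures surface as the typed `PermissionDeniedException` instead of a bare `Exception`. The import paths below mirror the file paths in this diff, and `document_ids` is a placeholder value; both are assumptions for illustration, not part of the patch.

```python
import asyncio

from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods.get_default_user import get_default_user
from cognee.modules.users.permissions.methods.check_permissions_on_documents import (
    PermissionDeniedException,
    check_permissions_on_documents,
)


async def main(document_ids: list[str]):
    db_engine = get_relational_engine()

    # One session is opened at the call site and shared by both helpers,
    # instead of each helper opening its own session as before.
    async with db_engine.get_async_session() as session:
        user = await get_default_user(session=session)

        try:
            await check_permissions_on_documents(user, "read", document_ids, session=session)
        except PermissionDeniedException as error:
            # The typed exception lets callers branch on authorization
            # failures without string-matching a generic Exception.
            print(error.message)


asyncio.run(main(["document-id-placeholder"]))  # placeholder ID, for illustration only
```

The design point of the refactor: session lifetime becomes the caller's responsibility, so a whole cognify run reuses one connection rather than opening a fresh session inside every helper call.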