diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 94ead86e7..bcfd43273 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -61,91 +61,93 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session) + if out: - async with update_status_lock: - task_status = get_task_status([dataset_name]) - if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED": - logger.info(f"Dataset {dataset_name} is being processed.") - return + async with update_status_lock: + task_status = get_task_status([dataset_name]) - update_task_status(dataset_name, "DATASET_PROCESSING_STARTED") - try: - cognee_config = get_cognify_config() - graph_config = get_graph_config() - root_node_id = None + if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED": + logger.info(f"Dataset {dataset_name} is being processed.") + return - if graph_config.infer_graph_topology and graph_config.graph_topology_task: - from cognee.modules.topology.topology import TopologyEngine - topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology) - root_node_id = await topology_engine.add_graph_topology(files = files) - elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology: - from cognee.modules.topology.topology import TopologyEngine - topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology) - await topology_engine.add_graph_topology(graph_config.topology_file_path) - elif not graph_config.graph_topology_task: - root_node_id = "ROOT" + update_task_status(dataset_name, "DATASET_PROCESSING_STARTED") + try: + cognee_config = get_cognify_config() + graph_config = get_graph_config() + root_node_id = None - tasks = [ - Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type - Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data - Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes - Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks - Task( - save_data_chunks, - collection_name = "chunks", - ), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other) - run_tasks_parallel([ + if graph_config.infer_graph_topology and graph_config.graph_topology_task: + from cognee.modules.topology.topology import TopologyEngine + topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology) + root_node_id = await topology_engine.add_graph_topology(files = files) + elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology: + from cognee.modules.topology.topology import TopologyEngine + topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology) + await topology_engine.add_graph_topology(graph_config.topology_file_path) + elif not graph_config.graph_topology_task: + root_node_id = "ROOT" + + tasks = [ + Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type + Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data + Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes + Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks Task( - summarize_text_chunks, - summarization_model = cognee_config.summarization_model, - collection_name = "chunk_summaries", - ), # Summarize the document chunks - Task( - classify_text_chunks, - classification_model = cognee_config.classification_model, - ), - ]), - Task(remove_obsolete_chunks), # Remove the obsolete document chunks. - ] + save_data_chunks, + collection_name = "chunks", + ), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other) + run_tasks_parallel([ + Task( + summarize_text_chunks, + summarization_model = cognee_config.summarization_model, + collection_name = "chunk_summaries", + ), # Summarize the document chunks + Task( + classify_text_chunks, + classification_model = cognee_config.classification_model, + ), + ]), + Task(remove_obsolete_chunks), # Remove the obsolete document chunks. + ] - pipeline = run_tasks(tasks, [ - PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else - AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else - ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else - TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) - for file in files - ]) + pipeline = run_tasks(tasks, [ + PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else + AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else + ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else + TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) + for file in files + ]) - async for result in pipeline: - print(result) + async for result in pipeline: + print(result) - update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED") - except Exception as error: - update_task_status(dataset_name, "DATASET_PROCESSING_ERROR") - raise error + update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED") + except Exception as error: + update_task_status(dataset_name, "DATASET_PROCESSING_ERROR") + raise error - existing_datasets = db_engine.get_datasets() + existing_datasets = db_engine.get_datasets() - awaitables = [] + awaitables = [] - # dataset_files = [] - # dataset_name = datasets.replace(".", "_").replace(" ", "_") + # dataset_files = [] + # dataset_name = datasets.replace(".", "_").replace(" ", "_") - # for added_dataset in existing_datasets: - # if dataset_name in added_dataset: - # dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset))) + # for added_dataset in existing_datasets: + # if dataset_name in added_dataset: + # dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset))) - for dataset in datasets: - if dataset in existing_datasets: - # for file_metadata in files: - # if root_node_id is None: - # root_node_id=file_metadata['id'] - awaitables.append(run_cognify_pipeline(dataset, db_engine.get_files_metadata(dataset))) + for dataset in datasets: + if dataset in existing_datasets: + # for file_metadata in files: + # if root_node_id is None: + # root_node_id=file_metadata['id'] + awaitables.append(run_cognify_pipeline(dataset, db_engine.get_files_metadata(dataset))) - return await asyncio.gather(*awaitables) + return await asyncio.gather(*awaitables) # diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 8cec061b3..2d0bdeb75 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -4,6 +4,8 @@ from enum import Enum from typing import Dict, Any, Callable, List from pydantic import BaseModel, field_validator +from cognee.infrastructure.databases.relational.user_authentication.users import fast_api_users_init, \ + has_permission_document, get_async_session_context, get_document_ids_for_user from cognee.modules.search.graph import search_cypher from cognee.modules.search.graph.search_adjacent import search_adjacent from cognee.modules.search.vector.search_traverse import search_traverse @@ -41,8 +43,23 @@ class SearchParameters(BaseModel): async def search(search_type: str, params: Dict[str, Any]) -> List: - search_params = SearchParameters(search_type = search_type, params = params) - return await specific_search([search_params]) + active_user = await fast_api_users_init() + async with get_async_session_context() as session: + + extract_documents = await get_document_ids_for_user(active_user.current_user(active=True), session=session) + search_params = SearchParameters(search_type = search_type, params = params) + searches = await specific_search([search_params]) + + filtered_searches =[] + for document in searches: + for document_id in extract_documents: + if document_id in document: + filtered_searches.append(document) + + + return filtered_searches + + async def specific_search(query_params: List[SearchParameters]) -> List: diff --git a/cognee/infrastructure/databases/relational/user_authentication/users.py b/cognee/infrastructure/databases/relational/user_authentication/users.py index e48a4d9f4..7435eaf5b 100644 --- a/cognee/infrastructure/databases/relational/user_authentication/users.py +++ b/cognee/infrastructure/databases/relational/user_authentication/users.py @@ -12,6 +12,7 @@ from fastapi_users.authentication import ( from fastapi_users.exceptions import UserAlreadyExists from fastapi_users.db import SQLAlchemyUserDatabase from fastapi import Depends, HTTPException, status +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from cognee.infrastructure.databases.relational.user_authentication.authentication_db import User, get_user_db, \ @@ -240,4 +241,12 @@ async def give_permission_document(user: Optional[User], document_id: str, permi permission=permission ) session.add(acl_entry) - await session.commit() \ No newline at end of file + await session.commit() + + +async def get_document_ids_for_user(user_id: uuid.UUID, session: AsyncSession) -> list[str]: + result = await session.execute( + select(ACL.document_id).filter_by(user_id=user_id) + ) + document_ids = [row[0] for row in result.fetchall()] + return document_ids \ No newline at end of file