Fixes to the sqlalchemy adapter
parent b5a3b69e49
commit 9a2cde95d0
3 changed files with 103 additions and 95 deletions
@@ -40,88 +40,92 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
     if datasets is None or len(datasets) == 0:
         return await cognify(await db_engine.get_datasets())

-    if user is None:
-        user = await get_default_user()
-
     async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
         documents = [
             PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
             AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
             ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
             TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
             for file in files
         ]

-        await check_permissions_on_documents(user, "read", [document.id for document in documents])
-
-        async with update_status_lock:
-            task_status = get_task_status([dataset_name])
-
-            if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
-                logger.info(f"Dataset {dataset_name} is being processed.")
-                return
-
-            update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
-        try:
-            cognee_config = get_cognify_config()
-            graph_config = get_graph_config()
-            root_node_id = None
-
-            if graph_config.infer_graph_topology and graph_config.graph_topology_task:
-                from cognee.modules.topology.topology import TopologyEngine
-                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
-                root_node_id = await topology_engine.add_graph_topology(files = files)
-            elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
-                from cognee.modules.topology.topology import TopologyEngine
-                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
-                await topology_engine.add_graph_topology(graph_config.topology_file_path)
-            elif not graph_config.graph_topology_task:
-                root_node_id = "ROOT"
-
-            tasks = [
-                Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as nodes in the graph db, extract text chunks based on the document type
-                Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
-                Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach them to chunk nodes
-                Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
-                Task(
-                    save_data_chunks,
-                    collection_name = "chunks",
-                ), # Save the document chunks in the vector db and as nodes in the graph db (connected to the document node and between each other)
-                run_tasks_parallel([
-                    Task(
-                        summarize_text_chunks,
-                        summarization_model = cognee_config.summarization_model,
-                        collection_name = "chunk_summaries",
-                    ), # Summarize the document chunks
-                    Task(
-                        classify_text_chunks,
-                        classification_model = cognee_config.classification_model,
-                    ),
-                ]),
-                Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
-            ]
-
-            pipeline = run_tasks(tasks, documents)
-
-            async for result in pipeline:
-                print(result)
-
-            update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
-        except Exception as error:
-            update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
-            raise error
+        db_engine = get_relational_engine()
+        async with db_engine.get_async_session() as session:
+            if user is None:
+                user = await get_default_user(session = session)
+
+            await check_permissions_on_documents(user, "read", [document.id for document in documents], session=session)
+
+            async with update_status_lock:
+                task_status = get_task_status([dataset_name])
+
+                if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
+                    logger.info(f"Dataset {dataset_name} is being processed.")
+                    return
+
+                update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
+            try:
+                cognee_config = get_cognify_config()
+                graph_config = get_graph_config()
+                root_node_id = None
+
+                if graph_config.infer_graph_topology and graph_config.graph_topology_task:
+                    from cognee.modules.topology.topology import TopologyEngine
+                    topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
+                    root_node_id = await topology_engine.add_graph_topology(files = files)
+                elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
+                    from cognee.modules.topology.topology import TopologyEngine
+                    topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
+                    await topology_engine.add_graph_topology(graph_config.topology_file_path)
+                elif not graph_config.graph_topology_task:
+                    root_node_id = "ROOT"
+
+                tasks = [
+                    Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as nodes in the graph db, extract text chunks based on the document type
+                    Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
+                    Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach them to chunk nodes
+                    Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
+                    Task(
+                        save_data_chunks,
+                        collection_name = "chunks",
+                    ), # Save the document chunks in the vector db and as nodes in the graph db (connected to the document node and between each other)
+                    run_tasks_parallel([
+                        Task(
+                            summarize_text_chunks,
+                            summarization_model = cognee_config.summarization_model,
+                            collection_name = "chunk_summaries",
+                        ), # Summarize the document chunks
+                        Task(
+                            classify_text_chunks,
+                            classification_model = cognee_config.classification_model,
+                        ),
+                    ]),
+                    Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
+                ]

+                pipeline = run_tasks(tasks, documents)
+
+                async for result in pipeline:
+                    print(result)
+
+                update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
+            except Exception as error:
+                update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
+                raise error

     existing_datasets = await db_engine.get_datasets()
     awaitables = []

     for dataset in datasets:
         dataset_name = generate_dataset_name(dataset)

         if dataset_name in existing_datasets:
             awaitables.append(run_cognify_pipeline(dataset, await db_engine.get_files_metadata(dataset_name)))

     return await asyncio.gather(*awaitables)

 def generate_dataset_name(dataset_name: str) -> str:
     return dataset_name.replace(".", "_").replace(" ", "_")
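The reworked pipeline above opens a single async session in run_cognify_pipeline and hands it to get_default_user and check_permissions_on_documents instead of letting each helper open its own. Below is a minimal, self-contained sketch of that caller-owned-session pattern using plain SQLAlchemy 2.0 async APIs; the in-memory SQLite database (via aiosqlite), the toy User model, and the load_default_user helper are stand-ins for illustration, not cognee's actual adapter code.

import asyncio

from sqlalchemy import String, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"

    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String, unique=True)


engine = create_async_engine("sqlite+aiosqlite:///:memory:")
make_session = async_sessionmaker(engine, expire_on_commit=False)


async def load_default_user(session):
    # The helper queries through the caller's session instead of opening its own.
    result = await session.execute(select(User).where(User.email == "default_user@example.com"))
    return result.scalars().first()


async def run_pipeline():
    # Create the schema once for this toy example.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    # One session owns the unit of work; every helper call shares it.
    async with make_session() as session:
        session.add(User(email="default_user@example.com"))
        await session.commit()

        return await load_default_user(session)


print(asyncio.run(run_pipeline()).email)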
@@ -3,10 +3,9 @@ from cognee.infrastructure.databases.relational import get_relational_engine
 from sqlalchemy.future import select

-async def get_default_user():
-    db_engine = get_relational_engine()
-    async with db_engine.get_async_session() as session:
-        stmt = select(User).where(User.email == "default_user@example.com")
-        result = await session.execute(stmt)
-        user = result.scalars().first()
+async def get_default_user(session):
+
+    stmt = select(User).where(User.email == "default_user@example.com")
+    result = await session.execute(stmt)
+    user = result.scalars().first()

     return user
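With the session injected as a parameter, the caller now owns the transaction scope. A short usage sketch assembled from the cognify hunk above (nothing here beyond what that hunk already shows):

db_engine = get_relational_engine()

async with db_engine.get_async_session() as session:
    if user is None:
        user = await get_default_user(session = session)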
@@ -1,32 +1,37 @@
 import logging

 from sqlalchemy import select

 from cognee.infrastructure.databases.relational import get_relational_engine
 from ...models.User import User
 from ...models.ACL import ACL

 logger = logging.getLogger(__name__)

-async def check_permissions_on_documents(
-    user: User,
-    permission_type: str,
-    document_ids: list[str],
-):
+class PermissionDeniedException(Exception):
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
+
+
+async def check_permissions_on_documents(user: User, permission_type: str, document_ids: list[str], session):
     try:
-        relational_engine = get_relational_engine()
-
-        async with relational_engine.get_async_session() as session:
-            user_group_ids = [group.id for group in user.groups]
-            result = await session.execute(
-                select(ACL).filter(
-                    ACL.principal_id.in_([user.id, *user_group_ids]),
-                    ACL.permission.name == permission_type
-                )
-            )
-            acls = result.scalars().all()
-            resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
-            has_permissions = all(document_id in resource_ids for document_id in document_ids)
-
-            if not has_permissions:
-                raise Exception(f"User {user.username} does not have {permission_type} permission on documents")
+        user_group_ids = [group.id for group in user.groups]
+
+        acls = session.query(ACL) \
+            .filter(ACL.principal_id.in_([user.id, *user_group_ids])) \
+            .filter(ACL.permission.name == permission_type) \
+            .all()
+
+        resource_ids = [resource.resource_id for resource in acl.resources for acl in acls]
+
+        has_permissions = all([document_id in resource_ids for document_id in document_ids])
+
+        if not has_permissions:
+            raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
     except Exception as error:
         logger.error("Error checking permissions on documents: %s", str(error))
-        raise error
+        raise
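Because the function re-raises the original exception, callers that share a session can catch the new exception type explicitly. A small sketch, assuming document_ids, session, and the caller's logger already exist as in the hunks above:

try:
    await check_permissions_on_documents(user, "read", document_ids, session=session)
except PermissionDeniedException as error:
    # The bare raise above preserves the original exception type, so it can be caught here.
    logger.warning("Permission check failed: %s", error.message)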