Updates to searches
This commit is contained in:
parent
7930586017
commit
797e7baba3
3 changed files with 99 additions and 71 deletions
|
|
@ -61,91 +61,93 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
|
|||
|
||||
out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session)
|
||||
|
||||
if out:
|
||||
|
||||
async with update_status_lock:
|
||||
task_status = get_task_status([dataset_name])
|
||||
|
||||
if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
|
||||
logger.info(f"Dataset {dataset_name} is being processed.")
|
||||
return
|
||||
async with update_status_lock:
|
||||
task_status = get_task_status([dataset_name])
|
||||
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
|
||||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
graph_config = get_graph_config()
|
||||
root_node_id = None
|
||||
if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
|
||||
logger.info(f"Dataset {dataset_name} is being processed.")
|
||||
return
|
||||
|
||||
if graph_config.infer_graph_topology and graph_config.graph_topology_task:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
|
||||
root_node_id = await topology_engine.add_graph_topology(files = files)
|
||||
elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
|
||||
await topology_engine.add_graph_topology(graph_config.topology_file_path)
|
||||
elif not graph_config.graph_topology_task:
|
||||
root_node_id = "ROOT"
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
|
||||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
graph_config = get_graph_config()
|
||||
root_node_id = None
|
||||
|
||||
tasks = [
|
||||
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
|
||||
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(
|
||||
save_data_chunks,
|
||||
collection_name = "chunks",
|
||||
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
|
||||
run_tasks_parallel([
|
||||
if graph_config.infer_graph_topology and graph_config.graph_topology_task:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
|
||||
root_node_id = await topology_engine.add_graph_topology(files = files)
|
||||
elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
|
||||
await topology_engine.add_graph_topology(graph_config.topology_file_path)
|
||||
elif not graph_config.graph_topology_task:
|
||||
root_node_id = "ROOT"
|
||||
|
||||
tasks = [
|
||||
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
|
||||
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(
|
||||
summarize_text_chunks,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
collection_name = "chunk_summaries",
|
||||
), # Summarize the document chunks
|
||||
Task(
|
||||
classify_text_chunks,
|
||||
classification_model = cognee_config.classification_model,
|
||||
),
|
||||
]),
|
||||
Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
|
||||
]
|
||||
save_data_chunks,
|
||||
collection_name = "chunks",
|
||||
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
|
||||
run_tasks_parallel([
|
||||
Task(
|
||||
summarize_text_chunks,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
collection_name = "chunk_summaries",
|
||||
), # Summarize the document chunks
|
||||
Task(
|
||||
classify_text_chunks,
|
||||
classification_model = cognee_config.classification_model,
|
||||
),
|
||||
]),
|
||||
Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, [
|
||||
PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
|
||||
AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
|
||||
ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
|
||||
TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
|
||||
for file in files
|
||||
])
|
||||
pipeline = run_tasks(tasks, [
|
||||
PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
|
||||
AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
|
||||
ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
|
||||
TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
|
||||
for file in files
|
||||
])
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
|
||||
except Exception as error:
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
|
||||
raise error
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
|
||||
except Exception as error:
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
|
||||
raise error
|
||||
|
||||
|
||||
existing_datasets = db_engine.get_datasets()
|
||||
existing_datasets = db_engine.get_datasets()
|
||||
|
||||
awaitables = []
|
||||
awaitables = []
|
||||
|
||||
|
||||
# dataset_files = []
|
||||
# dataset_name = datasets.replace(".", "_").replace(" ", "_")
|
||||
# dataset_files = []
|
||||
# dataset_name = datasets.replace(".", "_").replace(" ", "_")
|
||||
|
||||
# for added_dataset in existing_datasets:
|
||||
# if dataset_name in added_dataset:
|
||||
# dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
|
||||
# for added_dataset in existing_datasets:
|
||||
# if dataset_name in added_dataset:
|
||||
# dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
|
||||
|
||||
for dataset in datasets:
|
||||
if dataset in existing_datasets:
|
||||
# for file_metadata in files:
|
||||
# if root_node_id is None:
|
||||
# root_node_id=file_metadata['id']
|
||||
awaitables.append(run_cognify_pipeline(dataset, db_engine.get_files_metadata(dataset)))
|
||||
for dataset in datasets:
|
||||
if dataset in existing_datasets:
|
||||
# for file_metadata in files:
|
||||
# if root_node_id is None:
|
||||
# root_node_id=file_metadata['id']
|
||||
awaitables.append(run_cognify_pipeline(dataset, db_engine.get_files_metadata(dataset)))
|
||||
|
||||
return await asyncio.gather(*awaitables)
|
||||
return await asyncio.gather(*awaitables)
|
||||
|
||||
|
||||
#
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from enum import Enum
|
|||
from typing import Dict, Any, Callable, List
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from cognee.infrastructure.databases.relational.user_authentication.users import fast_api_users_init, \
|
||||
has_permission_document, get_async_session_context, get_document_ids_for_user
|
||||
from cognee.modules.search.graph import search_cypher
|
||||
from cognee.modules.search.graph.search_adjacent import search_adjacent
|
||||
from cognee.modules.search.vector.search_traverse import search_traverse
|
||||
|
|
@ -41,8 +43,23 @@ class SearchParameters(BaseModel):
|
|||
|
||||
|
||||
async def search(search_type: str, params: Dict[str, Any]) -> List:
|
||||
search_params = SearchParameters(search_type = search_type, params = params)
|
||||
return await specific_search([search_params])
|
||||
active_user = await fast_api_users_init()
|
||||
async with get_async_session_context() as session:
|
||||
|
||||
extract_documents = await get_document_ids_for_user(active_user.current_user(active=True), session=session)
|
||||
search_params = SearchParameters(search_type = search_type, params = params)
|
||||
searches = await specific_search([search_params])
|
||||
|
||||
filtered_searches =[]
|
||||
for document in searches:
|
||||
for document_id in extract_documents:
|
||||
if document_id in document:
|
||||
filtered_searches.append(document)
|
||||
|
||||
|
||||
return filtered_searches
|
||||
|
||||
|
||||
|
||||
|
||||
async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from fastapi_users.authentication import (
|
|||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from cognee.infrastructure.databases.relational.user_authentication.authentication_db import User, get_user_db, \
|
||||
|
|
@ -240,4 +241,12 @@ async def give_permission_document(user: Optional[User], document_id: str, permi
|
|||
permission=permission
|
||||
)
|
||||
session.add(acl_entry)
|
||||
await session.commit()
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_document_ids_for_user(user_id: uuid.UUID, session: AsyncSession) -> list[str]:
|
||||
result = await session.execute(
|
||||
select(ACL.document_id).filter_by(user_id=user_id)
|
||||
)
|
||||
document_ids = [row[0] for row in result.fetchall()]
|
||||
return document_ids
|
||||
Loading…
Add table
Reference in a new issue