fix: search results preview
parent 709a10c50c
commit cb9bfa27ea

27 changed files with 177 additions and 222 deletions
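In summary: the frontend now works with full dataset and data records instead of bare id strings (showing dataset names, file names with extensions, and raw data locations); the API serializes `Dataset`/`Data` rows through new `to_json()` helpers and the new `get_dataset_data` / `retrieve_datasets` operations; pipeline status is keyed by `run_id` instead of `run_name`; permission checks compare UUIDs directly; and search results are returned as plain JSON and rendered as strings in the search view.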
@@ -48,15 +48,15 @@ export default function Home() {
     });
   }, [showNotification])
 
-  const onDatasetCognify = useCallback((dataset: { id: string }) => {
-    showNotification(`Cognification started for dataset "${dataset.id}".`, 5000);
+  const onDatasetCognify = useCallback((dataset: { id: string, name: string }) => {
+    showNotification(`Cognification started for dataset "${dataset.name}".`, 5000);
 
     return cognifyDataset(dataset)
      .then(() => {
-        showNotification(`Dataset "${dataset.id}" cognified.`, 5000);
+        showNotification(`Dataset "${dataset.name}" cognified.`, 5000);
      })
      .catch(() => {
-        showNotification(`Dataset "${dataset.id}" cognification failed. Please try again.`, 5000);
+        showNotification(`Dataset "${dataset.name}" cognification failed. Please try again.`, 5000);
      });
   }, [showNotification]);
 

@@ -1,4 +1,4 @@
-export default function cognifyDataset(dataset: { id: string }) {
+export default function cognifyDataset(dataset: { id: string, name: string }) {
   return fetch('http://127.0.0.1:8000/cognify', {
     method: 'POST',
     headers: {
@@ -12,6 +12,7 @@
 }
+
 .dataTable {
   color: white;
   border-collapse: collapse;
 }
 .dataTable td, .dataTable th {
@@ -13,8 +13,9 @@ import RawDataPreview from './RawDataPreview';
 export interface Data {
   id: string;
   name: string;
-  filePath: string;
   mimeType: string;
+  extension: string;
+  rawDataLocation: string;
 }
 
 interface DatasetLike {

@@ -80,7 +81,6 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
           <th>Name</th>
           <th>File path</th>
           <th>MIME type</th>
-          <th>Keywords</th>
         </tr>
       </thead>
       <tbody>

@@ -104,10 +104,10 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
             <Text>{dataItem.id}</Text>
           </td>
           <td>
-            <Text>{dataItem.name}</Text>
+            <Text>{dataItem.name}.{dataItem.extension}</Text>
           </td>
           <td>
-            <Text>{dataItem.filePath}</Text>
+            <Text>{dataItem.rawDataLocation}</Text>
           </td>
           <td>
             <Text>{dataItem.mimeType}</Text>

@@ -40,7 +40,7 @@ export default function DatasetsView({
      .finally(() => enableCognifyRun());
   }
 
-  const [dataset, setExplorationDataset] = useState<{ id: string } | null>(null);
+  const [dataset, setExplorationDataset] = useState<{ id: string, name: string } | null>(null);
   const {
     value: isExplorationWindowShown,
     setTrue: showExplorationWindow,

@@ -97,7 +97,7 @@ export default function DatasetsView({
      </Stack>
      <Modal onClose={hideExplorationWindow} isOpen={isExplorationWindowShown} className={styles.explorerModal}>
        <Spacer horizontal="2" vertical="3" wrap>
-          <Text>{dataset?.id}</Text>
+          <Text>{dataset?.name}</Text>
        </Spacer>
        <Explorer dataset={dataset!} />
      </Modal>

@@ -67,7 +67,6 @@ function useDatasets() {
   const fetchDatasets = useCallback(() => {
     fetch('http://127.0.0.1:8000/datasets')
      .then((response) => response.json())
-      .then((datasets) => datasets.map((dataset: string) => ({ id: dataset, name: dataset })))
      .then((datasets) => {
        setDatasets(datasets);

@@ -7,7 +7,7 @@ import styles from './SearchView.module.css';
 interface Message {
   id: string;
   user: 'user' | 'system';
-  text: string;
+  text: any;
 }
 
 interface SelectOption {

@@ -98,7 +98,9 @@ export default function SearchView() {
            [styles.userMessage]: message.user === "user",
          })}
        >
-          {message.text}
+          {message?.text && (
+            typeof(message.text) == "string" ? message.text : JSON.stringify(message.text)
+          )}
        </Text>
      ))}
    </Stack>
@@ -108,7 +108,7 @@ def health_check():
     """
     return {"status": "OK"}
 
-@app.get("/datasets", response_model=list)
+@app.get("/datasets", response_model = list)
 async def get_datasets():
     try:
         from cognee.api.v1.datasets.datasets import datasets

@@ -116,18 +116,12 @@ async def get_datasets():
 
         return JSONResponse(
             status_code = 200,
-            content = [{
-                "id": str(dataset.id),
-                "name": dataset.name,
-                "created_at": dataset.created_at,
-                "updated_at": dataset.updated_at,
-                "data": dataset.data,
-            } for dataset in datasets],
+            content = [dataset.to_json() for dataset in datasets],
         )
     except Exception as error:
         raise HTTPException(status_code = 500, detail=f"Error retrieving datasets: {str(error)}") from error
 
-@app.delete("/datasets/{dataset_id}", response_model=dict)
+@app.delete("/datasets/{dataset_id}", response_model = dict)
 async def delete_dataset(dataset_id: str):
     from cognee.api.v1.datasets.datasets import datasets
     await datasets.delete_dataset(dataset_id)

@@ -159,17 +153,14 @@ async def get_dataset_graph(dataset_id: str):
 @app.get("/datasets/{dataset_id}/data", response_model=list)
 async def get_dataset_data(dataset_id: str):
     from cognee.api.v1.datasets.datasets import datasets
-    dataset_data = await datasets.list_data(dataset_id)
+
+    dataset_data = await datasets.list_data(dataset_id = dataset_id)
+
     if dataset_data is None:
-        raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
+        raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
+
     return [
-        dict(
-            id=data["id"],
-            name=f"{data['name']}.{data['extension']}",
-            filePath=data["file_path"],
-            mimeType=data["mime_type"],
-        )
-        for data in dataset_data
+        data.to_json() for data in dataset_data
     ]
 
 @app.get("/datasets/status", response_model=dict)

@@ -193,10 +184,12 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset
 async def get_raw_data(dataset_id: str, data_id: str):
     from cognee.api.v1.datasets.datasets import datasets
     dataset_data = await datasets.list_data(dataset_id)
+
     if dataset_data is None:
-        raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
-    data = [data for data in dataset_data if data["id"] == data_id][0]
-    return data["file_path"]
+        raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
+
+    data = [data for data in dataset_data if str(data.id) == data_id][0]
+    return data.raw_data_location
 
 class AddPayload(BaseModel):
     data: Union[str, UploadFile, List[Union[str, UploadFile]]]

@@ -276,18 +269,21 @@ async def search(payload: SearchPayload):
     from cognee.api.v1.search import search as cognee_search
     try:
         search_type = payload.query_params["searchType"]
+
         params = {
             "query": payload.query_params["query"],
         }
+
         results = await cognee_search(search_type, params)
+
         return JSONResponse(
-            status_code=200,
-            content=json.dumps(results)
+            status_code = 200,
+            content = results,
         )
     except Exception as error:
         return JSONResponse(
-            status_code=409,
-            content={"error": str(error)}
+            status_code = 409,
+            content = {"error": str(error)}
         )
 
 @app.get("/settings", response_model=dict)
@@ -130,6 +130,7 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
                 await session.merge(data)
             else:
                 data = Data(
                     id = data_id,
                     name = file_metadata["name"],
+                    raw_data_location = file_metadata["file_path"],
                     extension = file_metadata["extension"],

@@ -139,6 +140,8 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
 
             await session.merge(dataset)
 
+            await session.commit()
+
             yield {
                 "id": data_id,
                 "name": file_metadata["name"],
@@ -30,6 +30,7 @@ from cognee.shared.utils import send_telemetry
 from cognee.modules.tasks import create_task_status_table, update_task_status
 from cognee.shared.SourceCodeGraph import SourceCodeGraph
 from cognee.modules.tasks import get_task_status
+from cognee.modules.data.operations.get_dataset_data import get_dataset_data
 from cognee.infrastructure.data.chunking.config import get_chunk_config
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.infrastructure.databases.relational import get_relational_engine

@@ -90,7 +91,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
 
     for added_dataset in added_datasets:
         if dataset_name in added_dataset:
-            dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
+            dataset_files.append((added_dataset, await get_dataset_data(dataset_name = added_dataset)))
 
     chunk_config = get_chunk_config()
     chunk_engine = get_chunk_engine()
@@ -17,6 +17,9 @@ from cognee.modules.data.processing.filter_affected_chunks import filter_affecte
 from cognee.modules.data.processing.remove_obsolete_chunks import remove_obsolete_chunks
 from cognee.modules.data.extraction.knowledge_graph.expand_knowledge_graph import expand_knowledge_graph
 from cognee.modules.data.extraction.knowledge_graph.establish_graph_topology import establish_graph_topology
+from cognee.modules.data.models import Dataset, Data
+from cognee.modules.data.operations.get_dataset_data import get_dataset_data
+from cognee.modules.data.operations.retrieve_datasets import retrieve_datasets
 from cognee.modules.pipelines.tasks.Task import Task
 from cognee.modules.pipelines import run_tasks, run_tasks_parallel
 from cognee.modules.users.models import User

@@ -40,20 +43,25 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
     if datasets is None or len(datasets) == 0:
         return await cognify(await db_engine.get_datasets())
 
+    if type(datasets[0]) == str:
+        datasets = await retrieve_datasets(datasets)
+
     if user is None:
         user = await get_default_user()
 
-    async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
+    async def run_cognify_pipeline(dataset: Dataset):
+        data: list[Data] = await get_dataset_data(dataset_id = dataset.id)
 
         documents = [
-            PdfDocument(id = file["id"], title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
-            AudioDocument(id = file["id"], title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
-            ImageDocument(id = file["id"], title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
-            TextDocument(id = file["id"], title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
-            for file in files
+            PdfDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location) if data_item.extension == "pdf" else
+            AudioDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location) if data_item.extension == "audio" else
+            ImageDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location) if data_item.extension == "image" else
+            TextDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location)
+            for data_item in data
         ]
 
         document_ids = [document.id for document in documents]
+        document_ids_str = list(map(str, document_ids))
 
         await check_permissions_on_documents(
             user,

@@ -61,16 +69,19 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
             document_ids,
         )
 
-        async with update_status_lock:
-            task_status = await get_pipeline_status([dataset_name])
+        dataset_id = dataset.id
+        dataset_name = generate_dataset_name(dataset.name)
 
-            if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
-                logger.info(f"Dataset {dataset_name} is being processed.")
+        async with update_status_lock:
+            task_status = await get_pipeline_status([dataset_id])
+
+            if dataset_id in task_status and task_status[dataset_id] == "DATASET_PROCESSING_STARTED":
+                logger.info("Dataset %s is already being processed.", dataset_name)
                 return
 
-        await log_pipeline_status(dataset_name, "DATASET_PROCESSING_STARTED", {
+        await log_pipeline_status(dataset_id, "DATASET_PROCESSING_STARTED", {
             "dataset_name": dataset_name,
-            "files": document_ids,
+            "files": document_ids_str,
         })
         try:
             cognee_config = get_cognify_config()

@@ -80,7 +91,7 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
             if graph_config.infer_graph_topology and graph_config.graph_topology_task:
                 from cognee.modules.topology.topology import TopologyEngine
                 topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
-                root_node_id = await topology_engine.add_graph_topology(files = files)
+                root_node_id = await topology_engine.add_graph_topology(files = data)
             elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
                 from cognee.modules.topology.topology import TopologyEngine
                 topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)

@@ -116,14 +127,14 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
             async for result in pipeline:
                 print(result)
 
-            await log_pipeline_status(dataset_name, "DATASET_PROCESSING_FINISHED", {
+            await log_pipeline_status(dataset_id, "DATASET_PROCESSING_FINISHED", {
                 "dataset_name": dataset_name,
-                "files": document_ids,
+                "files": document_ids_str,
             })
         except Exception as error:
-            await log_pipeline_status(dataset_name, "DATASET_PROCESSING_ERROR", {
+            await log_pipeline_status(dataset_id, "DATASET_PROCESSING_ERROR", {
                 "dataset_name": dataset_name,
-                "files": document_ids,
+                "files": document_ids_str,
             })
             raise error

@@ -131,31 +142,14 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
     existing_datasets = [dataset.name for dataset in list(await db_engine.get_datasets())]
     awaitables = []
 
-    for dataset_name in datasets:
-        dataset_name = generate_dataset_name(dataset_name)
+    for dataset in datasets:
+        dataset_name = generate_dataset_name(dataset.name)
 
         if dataset_name in existing_datasets:
-            awaitables.append(run_cognify_pipeline(dataset_name, await db_engine.get_files_metadata(dataset_name)))
+            awaitables.append(run_cognify_pipeline(dataset))
 
     return await asyncio.gather(*awaitables)
 
 
 def generate_dataset_name(dataset_name: str) -> str:
     return dataset_name.replace(".", "_").replace(" ", "_")
-
-#
-# if __name__ == "__main__":
-# from cognee.api.v1.add import add
-# from cognee.api.v1.datasets.datasets import datasets
-#
-#
-# async def aa():
-# await add("TEXT ABOUT NLP AND MONKEYS")
-#
-# print(datasets.discover_datasets())
-#
-# return
-
-
-
-# asyncio.run(cognify())
@@ -1,5 +1,6 @@
 from duckdb import CatalogException
 from cognee.modules.ingestion import discover_directory_datasets
+from cognee.modules.data.operations.get_dataset_data import get_dataset_data
 from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
 from cognee.infrastructure.databases.relational import get_relational_engine
 

@@ -14,10 +15,9 @@ class datasets():
         return list(discover_directory_datasets(directory_path).keys())
 
     @staticmethod
-    async def list_data(dataset_name: str):
-        db = get_relational_engine()
+    async def list_data(dataset_id: str, dataset_name: str = None):
         try:
-            return await db.get_files_metadata(dataset_name)
+            return await get_dataset_data(dataset_id = dataset_id, dataset_name = dataset_name)
         except CatalogException:
             return None
 
@@ -28,15 +28,15 @@ class SearchType(Enum):
     def from_str(name: str):
         try:
             return SearchType[name.upper()]
-        except KeyError:
-            raise ValueError(f"{name} is not a valid SearchType")
+        except KeyError as error:
+            raise ValueError(f"{name} is not a valid SearchType") from error
 
 class SearchParameters(BaseModel):
     search_type: SearchType
     params: Dict[str, Any]
 
     @field_validator("search_type", mode="before")
-    def convert_string_to_enum(cls, value):
+    def convert_string_to_enum(cls, value): # pylint: disable=no-self-argument
         if isinstance(value, str):
             return SearchType.from_str(value)
         return value

@@ -46,20 +46,21 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
     if user is None:
         user = await get_default_user()
 
-    extract_documents = await get_document_ids_for_user(user.id)
+    own_document_ids = await get_document_ids_for_user(user.id)
     search_params = SearchParameters(search_type = search_type, params = params)
-    searches = await specific_search([search_params])
+    search_results = await specific_search([search_params])
 
-    filtered_searches = []
-    for document in searches:
-        for document_id in extract_documents:
-            if document_id in document:
-                filtered_searches.append(document)
-
-    return filtered_searches
+    from uuid import UUID
+
+    filtered_search_results = []
+
+    for search_result in search_results:
+        result_document_id = UUID(search_result["document_id"]) if "document_id" in search_result else None
+
+        if result_document_id is None or result_document_id in own_document_ids:
+            filtered_search_results.append(search_result)
+
+    return filtered_search_results
 
 
 async def specific_search(query_params: List[SearchParameters]) -> List:

@@ -71,7 +72,6 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
         SearchType.SIMILARITY: search_similarity,
     }
 
-    results = []
    search_tasks = []
 
    for search_param in query_params:
@@ -84,38 +84,6 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
     # Use asyncio.gather to run all scheduled tasks concurrently
     search_results = await asyncio.gather(*search_tasks)
 
-    # Update the results set with the results from all tasks
-    results.extend(search_results)
-
     send_telemetry("cognee.search")
 
-    return results[0] if len(results) == 1 else results
-
-
-if __name__ == "__main__":
-    async def main():
-        # Assuming 'graph' is your graph object, obtained from somewhere
-        search_type = 'CATEGORIES'
-        params = {'query': 'Ministarstvo', 'other_param': {"node_id": "LLM_LAYER_SUMMARY:DOCUMENT:881ecb36-2819-54c3-8147-ed80293084d6"}}
-
-        results = await search(search_type, params)
-        print(results)
-
-    # Run the async main function
-    asyncio.run(main())
-    # if __name__ == "__main__":
-    #     import asyncio
-
-    #     query_params = {
-    #         SearchType.SIMILARITY: {'query': 'your search query here'}
-    #     }
-    #     async def main():
-    #         graph_client = get_graph_engine()
-
-    #         await graph_client.load_graph_from_file()
-    #         graph = graph_client.graph
-    #         results = await search(graph, query_params)
-    #         print(results)
-
-    #     asyncio.run(main())
+    return search_results[0] if len(search_results) == 1 else search_results
@@ -56,15 +56,6 @@ class SQLAlchemyAdapter():
             await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
             await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))
 
-    async def get_files_metadata(self, dataset_name: str):
-        async with self.engine.connect() as connection:
-            result = await connection.execute(
-                text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
-            rows = result.fetchall()
-            metadata = [{"id": row.id, "name": row.name, "file_path": row.file_path, "extension": row.extension,
-                         "mime_type": row.mime_type} for row in rows]
-            return metadata
-
     async def delete_table(self, table_name: str):
         async with self.engine.connect() as connection:
             await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
@@ -26,6 +26,7 @@ async def summarize_text_chunks(data_chunks: list[DocumentChunk], summarization_
             id = str(chunk.chunk_id),
             payload = dict(
+                chunk_id = str(chunk.chunk_id),
                 document_id = str(chunk.document_id),
                 text = chunk_summaries[chunk_index].summary,
             ),
             embed_field = "text",
@@ -23,3 +23,15 @@ class Data(Base):
         secondary = DatasetData.__tablename__,
         back_populates = "data"
     )
+
+    def to_json(self) -> dict:
+        return {
+            "id": str(self.id),
+            "name": self.name,
+            "extension": self.extension,
+            "mimeType": self.mime_type,
+            "rawDataLocation": self.raw_data_location,
+            "createdAt": self.created_at.isoformat(),
+            "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            # "datasets": [dataset.to_json() for dataset in self.datasets]
+        }

@@ -20,3 +20,12 @@ class Dataset(Base):
         secondary = DatasetData.__tablename__,
         back_populates = "datasets"
     )
+
+    def to_json(self) -> dict:
+        return {
+            "id": str(self.id),
+            "name": self.name,
+            "createdAt": self.created_at.isoformat(),
+            "updatedAt": self.updated_at.isoformat() if self.updated_at else None,
+            "data": [data.to_json() for data in self.data]
+        }
cognee/modules/data/operations/get_dataset_data.py (new file, 18 lines)

@@ -0,0 +1,18 @@
+from uuid import UUID
+from sqlalchemy import select
+from cognee.modules.data.models import Data, Dataset
+from cognee.infrastructure.databases.relational import get_relational_engine
+
+async def get_dataset_data(dataset_id: UUID = None, dataset_name: str = None):
+    if dataset_id is None and dataset_name is None:
+        raise ValueError("get_dataset_data: Either dataset_id or dataset_name must be provided.")
+
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        result = await session.execute(
+            select(Data).join(Data.datasets).filter((Dataset.id == dataset_id) | (Dataset.name == dataset_name))
+        )
+        data = result.scalars().all()
+
+        return data
cognee/modules/data/operations/retrieve_datasets.py (new file, 13 lines)

@@ -0,0 +1,13 @@
+from sqlalchemy import select
+from cognee.infrastructure.databases.relational import get_relational_engine
+from ..models import Dataset
+
+async def retrieve_datasets(dataset_names: list[str]) -> list[Dataset]:
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        datasets = (await session.scalars(
+            select(Dataset).filter(Dataset.name.in_(dataset_names))
+        )).all()
+
+        return datasets
@@ -10,8 +10,7 @@ class PipelineRun(Base):
 
     created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc))
 
-    run_name = Column(String, index = True)
-
     status = Column(String)
 
+    run_id = Column(UUID(as_uuid = True), index = True)
     run_info = Column(JSON)
@@ -1,19 +1,20 @@
+from uuid import UUID
 from sqlalchemy import func, select
 from sqlalchemy.orm import aliased
 from cognee.infrastructure.databases.relational import get_relational_engine
 from ..models import PipelineRun
 
-async def get_pipeline_status(pipeline_names: [str]):
+async def get_pipeline_status(pipeline_ids: list[UUID]):
     db_engine = get_relational_engine()
 
     async with db_engine.get_async_session() as session:
         query = select(
             PipelineRun,
             func.row_number().over(
-                partition_by = PipelineRun.run_name,
+                partition_by = PipelineRun.run_id,
                 order_by = PipelineRun.created_at.desc(),
             ).label("rn")
-        ).filter(PipelineRun.run_name.in_(pipeline_names)).subquery()
+        ).filter(PipelineRun.run_id.in_(pipeline_ids)).subquery()
 
         aliased_pipeline_run = aliased(PipelineRun, query)
 

@@ -24,7 +25,7 @@ async def get_pipeline_status(pipeline_names: [str]):
         runs = (await session.execute(latest_runs)).scalars().all()
 
         pipeline_statuses = {
-            run.run_name: run.status for run in runs
+            str(run.run_id): run.status for run in runs
         }
 
         return pipeline_statuses

@@ -1,12 +1,13 @@
+from uuid import UUID
 from cognee.infrastructure.databases.relational import get_relational_engine
 from ..models.PipelineRun import PipelineRun
 
-async def log_pipeline_status(run_name: str, status: str, run_info: dict):
+async def log_pipeline_status(run_id: UUID, status: str, run_info: dict):
     db_engine = get_relational_engine()
 
     async with db_engine.get_async_session() as session:
         session.add(PipelineRun(
-            run_name = run_name,
+            run_id = run_id,
             status = status,
             run_info = run_info,
         ))
@@ -12,6 +12,16 @@ async def search_similarity(query: str) -> list[str, str]:
 
     similar_results = await vector_engine.search("chunks", query, limit = 5)
 
-    results = [result.payload for result in similar_results]
+    results = [
+        parse_payload(result.payload) for result in similar_results
+    ]
 
     return results
+
+
+def parse_payload(payload: dict) -> dict:
+    return {
+        "text": payload["text"],
+        "chunk_id": payload["chunk_id"],
+        "document_id": payload["document_id"],
+    }
@@ -3,10 +3,8 @@
import csv
import json
import logging
import os
from typing import Any, Dict, List, Optional, Union, Type

import asyncio
import aiofiles
import pandas as pd
from pydantic import BaseModel

@@ -14,16 +12,10 @@ from pydantic import BaseModel
from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.utils.extract_text_from_file import extract_text_from_file
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException
from cognee.modules.cognify.config import get_cognify_config
from cognee.base_config import get_base_config
from cognee.modules.topology.topology_data_models import NodeModel

cognify_config = get_cognify_config()
base_config = get_base_config()

logger = logging.getLogger("topology")

class TopologyEngine:
@@ -136,51 +128,3 @@ class TopologyEngine:
                 return
         except Exception as e:
             raise RuntimeError(f"Failed to add graph topology from {file_path}: {e}") from e
-
-
-
-async def main():
-    # text = """Conservative PP in the lead in Spain, according to estimate
-    # An estimate has been published for Spain:
-    #
-    # Opposition leader Alberto Núñez Feijóo’s conservative People’s party (PP): 32.4%
-    #
-    # Spanish prime minister Pedro Sánchez’s Socialist party (PSOE): 30.2%
-    #
-    # The far-right Vox party: 10.4%
-    #
-    # In Spain, the right has sought to turn the European election into a referendum on Sánchez.
-    #
-    # Ahead of the vote, public attention has focused on a saga embroiling the prime minister’s wife, Begoña Gómez, who is being investigated over allegations of corruption and influence-peddling, which Sanchez has dismissed as politically-motivated and totally baseless."""
-    # text_two = """The far-right Vox party: 10.4%"""
-
-    from cognee.api.v1.add import add
-    dataset_name = "explanations"
-    print(os.getcwd())
-    data_dir = os.path.abspath("../../.data")
-    print(os.getcwd())
-
-    await add(f"data://{data_dir}", dataset_name="explanations")
-
-    db_engine = get_relational_engine()
-
-    datasets = await db_engine.get_datasets()
-    dataset_files =[]
-
-    for added_dataset in datasets:
-        if dataset_name in added_dataset:
-            dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
-
-
-
-    print(dataset_files)
-    topology_engine = TopologyEngine(infer=True)
-    file_path = "example_data.json" # or 'example_data.csv'
-    #
-    # # Adding graph topology
-    graph = await topology_engine.add_graph_topology(file_path, files = dataset_files)
-    print(graph)
-
-# Run the main function
-if __name__ == "__main__":
-    asyncio.run(main())
@@ -1,6 +1,6 @@
 from uuid import uuid4
 from datetime import datetime, timezone
-# from sqlalchemy.orm import relationship, Mapped
+# from sqlalchemy.orm import relationship
 from sqlalchemy import Column, DateTime, UUID, String
 from cognee.infrastructure.databases.relational import Base
 

@@ -15,9 +15,3 @@ class Permission(Base):
     name = Column(String, unique = True, nullable = False, index = True)
 
     # acls = relationship("ACL", back_populates = "permission")
-
-    # groups: Mapped[list["Group"]] = relationship(
-    #     "Group",
-    #     secondary = "group_permissions",
-    #     back_populates = "permissions",
-    # )
@@ -1,4 +1,5 @@
 import logging
+from uuid import UUID
 from sqlalchemy import select
 from sqlalchemy.orm import joinedload
 from cognee.infrastructure.databases.relational import get_relational_engine

@@ -14,7 +15,7 @@ class PermissionDeniedException(Exception):
         super().__init__(self.message)
 
 
-async def check_permissions_on_documents(user: User, permission_type: str, document_ids: list[str]):
+async def check_permissions_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
     try:
         user_group_ids = [group.id for group in user.groups]
 

@@ -29,7 +30,7 @@ async def check_permissions_on_documents(user: User, permission_type: str, docum
             .where(ACL.permission.has(name = permission_type))
         )
         acls = result.unique().scalars().all()
-        resource_ids = [str(resource.resource_id) for acl in acls for resource in acl.resources]
+        resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
         has_permissions = all(document_id in resource_ids for document_id in document_ids)
 
         if not has_permissions:
@@ -1,24 +1,21 @@
 from uuid import UUID
 from sqlalchemy import select
 from cognee.infrastructure.databases.relational import get_relational_engine
-from ...models import ACL
+from ...models import ACL, Resource, Permission
 
 async def get_document_ids_for_user(user_id: UUID) -> list[str]:
     db_engine = get_relational_engine()
 
     async with db_engine.get_async_session() as session:
-        async with session.begin():
-            result = await session.execute(
-                select(ACL.resources.resource_id) \
-                .join(ACL.resources) \
-                .filter_by(
-                    ACL.principal_id == user_id,
-                    ACL.permission.name == "read",
-                )
-            )
-            document_ids = [row[0] for row in result.scalars().all()]
+        document_ids = (await session.scalars(
+            select(Resource.resource_id)
+            .join(ACL.resources)
+            .join(ACL.permission)
+            .where(
+                ACL.principal_id == user_id,
+                Permission.name == "read",
+            )
+        )).all()
 
         return document_ids