cognee/cognee/tasks/ingestion/ingest_data.py
Boris ebf1f81b35
fix: code cleanup [COG-781] (#667)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-03-26 18:32:43 +01:00

152 lines
6.3 KiB
Python

from typing import Any, List
import dlt
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
from cognee.modules.data.models.DatasetData import DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import give_permission_on_document
from .get_dlt_destination import get_dlt_destination
from .save_data_item_to_storage import save_data_item_to_storage
from typing import Union, BinaryIO
import inspect
async def ingest_data(data: Any, dataset_name: str, user: User):
    """Store data items, attach them to a dataset and load their file metadata via dlt.

    Every item is saved with ``save_data_item_to_storage``, classified and
    identified, and then written as a ``Data`` row that is linked to the dataset
    obtained from ``create_dataset``. The user is given "read" and "write"
    permission on each item. Afterwards the metadata of all stored files is
    merged into the ``file_metadata`` table through a dlt pipeline.

    Args:
        data: A single data item or a list of data items.
        dataset_name: Name of the dataset the items are attached to.
        user: Owner of the ingested data.

    Returns:
        The result of ``get_dataset_data`` for the first dataset returned by
        ``get_datasets_by_name(dataset_name, user.id)``.
    """
    destination = get_dlt_destination()

    pipeline = dlt.pipeline(
        pipeline_name="file_load_from_filesystem",
        destination=destination,
    )

    def get_external_metadata_dict(data_item: Union[BinaryIO, str, Any]) -> dict[str, Any]:
        """Return metadata for items exposing a bound ``dict()`` method, otherwise ``{}``."""
        if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
            return {"metadata": data_item.dict(), "origin": str(type(data_item))}
        return {}

    @dlt.resource(standalone=True, primary_key="id", merge_key="id")
    async def data_resources(file_paths: List[str], user: User):
        """Yield one file-metadata record per stored file for the dlt pipeline."""
        for file_path in file_paths:
            with open(file_path.replace("file://", ""), mode="rb") as file:
                classified_data = ingestion.classify(file)
                data_id = ingestion.identify(classified_data, user)
                file_metadata = classified_data.get_metadata()

                yield {
                    "id": data_id,
                    "name": file_metadata["name"],
                    "file_path": file_metadata["file_path"],
                    "extension": file_metadata["extension"],
                    "mime_type": file_metadata["mime_type"],
                    "content_hash": file_metadata["content_hash"],
                    "owner_id": str(user.id),
                }

    async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
        """Save every item, upsert its ``Data`` row and return the stored file paths."""
        # Kept as function-local imports (as in the original), but executed once
        # instead of once per data item.
        from sqlalchemy import select
        from cognee.modules.data.models import Data

        if not isinstance(data, list):
            # Convert data to a list as we work with lists further down.
            data = [data]

        file_paths = []

        # Process data
        for data_item in data:
            file_path = await save_data_item_to_storage(data_item, dataset_name)
            file_paths.append(file_path)

            # Ingest data and add metadata. The file only needs to stay open while
            # it is classified, identified and its metadata is read.
            with open(file_path.replace("file://", ""), mode="rb") as file:
                classified_data = ingestion.classify(file)

                # data_id is the hash of file contents + owner id to avoid duplicate data
                data_id = ingestion.identify(classified_data, user)

                file_metadata = classified_data.get_metadata()

            external_metadata = get_external_metadata_dict(data_item)

            db_engine = get_relational_engine()

            async with db_engine.get_async_session() as session:
                dataset = await create_dataset(dataset_name, user.id, session)

                # Check to see if data should be updated
                data_point = (
                    await session.execute(select(Data).filter(Data.id == data_id))
                ).scalar_one_or_none()

                if data_point is not None:
                    data_point.name = file_metadata["name"]
                    data_point.raw_data_location = file_metadata["file_path"]
                    data_point.extension = file_metadata["extension"]
                    data_point.mime_type = file_metadata["mime_type"]
                    data_point.owner_id = user.id
                    data_point.content_hash = file_metadata["content_hash"]
                    # Store the dict itself, exactly like the create branch below.
                    # A trailing comma here would wrap it in a one-element tuple.
                    data_point.external_metadata = external_metadata
                    await session.merge(data_point)
                else:
                    data_point = Data(
                        id=data_id,
                        name=file_metadata["name"],
                        raw_data_location=file_metadata["file_path"],
                        extension=file_metadata["extension"],
                        mime_type=file_metadata["mime_type"],
                        owner_id=user.id,
                        content_hash=file_metadata["content_hash"],
                        external_metadata=external_metadata,
                        token_count=-1,
                    )

                # Check if data is already in dataset
                dataset_data = (
                    await session.execute(
                        select(DatasetData).filter(
                            DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
                        )
                    )
                ).scalar_one_or_none()

                # If data is not present in dataset add it
                if dataset_data is None:
                    dataset.data.append(data_point)

                await session.commit()

            await give_permission_on_document(user, data_id, "read")
            await give_permission_on_document(user, data_id, "write")

        return file_paths

    db_engine = get_relational_engine()

    file_paths = await store_data_to_dataset(data, dataset_name, user)

    # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
    # can't be used inside the pipeline

    # Sqlite doesn't support schemas, so with dlt the dataset_name must be "main" there.
    # For other databases the data is stored in the same schema to allow deduplication.
    dlt_dataset_name = "main" if db_engine.engine.dialect.name == "sqlite" else "public"

    pipeline.run(
        data_resources(file_paths, user),
        table_name="file_metadata",
        dataset_name=dlt_dataset_name,
        write_disposition="merge",
    )

    datasets = await get_datasets_by_name(dataset_name, user.id)
    dataset = datasets[0]

    data_documents = await get_dataset_data(dataset_id=dataset.id)
    return data_documents