diff --git a/cognee/infrastructure/databases/relational/get_relational_engine.py b/cognee/infrastructure/databases/relational/get_relational_engine.py
index 6024c7bd0..c0a66e28e 100644
--- a/cognee/infrastructure/databases/relational/get_relational_engine.py
+++ b/cognee/infrastructure/databases/relational/get_relational_engine.py
@@ -1,8 +1,10 @@
+from functools import lru_cache
+
 from .config import get_relational_config
 from .create_relational_engine import create_relational_engine
 
-
+@lru_cache
 def get_relational_engine():
     relational_config = get_relational_config()
 
-    return create_relational_engine(**relational_config.to_dict())
+    return create_relational_engine(**relational_config.to_dict())
\ No newline at end of file
diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
index aa2a022d3..a5733967e 100644
--- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
+++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
 from sqlalchemy import text, select, MetaData, Table
 from sqlalchemy.orm import joinedload
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
-
 from ..ModelBase import Base
 
 class SQLAlchemyAdapter():
@@ -171,6 +170,27 @@ class SQLAlchemyAdapter():
         results = await connection.execute(query)
         return {result["data_id"]: result["status"] for result in results}
 
+    async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
+        async with self.get_async_session() as session:
+            # Validate inputs to prevent SQL injection
+            if not table_name.isidentifier():
+                raise ValueError("Invalid table name")
+            if schema and not schema.isidentifier():
+                raise ValueError("Invalid schema name")
+
+            if self.engine.dialect.name == "sqlite":
+                table = await self.get_table(table_name)
+            else:
+                table = await self.get_table(table_name, schema)
+
+            # Query all data from the table
+            query = select(table)
+            result = await session.execute(query)
+
+            # Fetch all rows as a list of dictionaries
+            rows = result.mappings().all()
+            return rows
+
     async def execute_query(self, query):
         async with self.engine.begin() as connection:
             result = await connection.execute(text(query))
@@ -205,7 +225,6 @@ class SQLAlchemyAdapter():
             from cognee.infrastructure.files.storage import LocalStorage
 
             LocalStorage.remove(self.db_path)
-            self.db_path = None
         else:
             async with self.engine.begin() as connection:
                 schema_list = await self.get_schema_list()
diff --git a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
index f40299939..2f4c9cf3f 100644
--- a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
+++ b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
@@ -1,12 +1,12 @@
-from ...relational.ModelBase import Base
 from ..get_vector_engine import get_vector_engine, get_vectordb_config
 from sqlalchemy import text
 
+
 async def create_db_and_tables():
     vector_config = get_vectordb_config()
     vector_engine = get_vector_engine()
 
     if vector_config.vector_db_provider == "pgvector":
-        await vector_engine.create_database()
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
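The new SQLAlchemyAdapter.get_all_data_from_table() fetches every row through result.mappings().all(), which yields dict-like RowMapping objects keyed by column name. A minimal, self-contained sketch of that fetch pattern, assuming the aiosqlite driver is installed and using an illustrative table rather than cognee's real schema:

    import asyncio

    from sqlalchemy import MetaData, Table, select, text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.begin() as conn:
            await conn.execute(text("CREATE TABLE file_metadata (id TEXT, name TEXT)"))
            await conn.execute(text("INSERT INTO file_metadata VALUES ('abc', 'a.txt')"))
            # Reflect the table, then select every row as dict-like mappings
            table = await conn.run_sync(
                lambda sync_conn: Table("file_metadata", MetaData(), autoload_with=sync_conn)
            )
            result = await conn.execute(select(table))
            for row in result.mappings().all():
                print(row["id"], row["name"])  # each row behaves like a read-only dict

    asyncio.run(main())

Because each row behaves like a dict, the data_storing() coroutine introduced further down can index file_metadata["id"] directly.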
diff --git a/cognee/tasks/ingestion/get_dlt_destination.py b/cognee/tasks/ingestion/get_dlt_destination.py
index 97e3d3220..12042c75b 100644
--- a/cognee/tasks/ingestion/get_dlt_destination.py
+++ b/cognee/tasks/ingestion/get_dlt_destination.py
@@ -1,9 +1,12 @@
 import os
+from functools import lru_cache
+
 import dlt
 from typing import Union
 
 from cognee.infrastructure.databases.relational import get_relational_config
 
+@lru_cache
 def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
     """
     Handles propagation of the cognee database configuration to the dlt library
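Both get_relational_engine() and get_dlt_destination() are now wrapped in @lru_cache. On a zero-argument function the cache has a single slot, so the factory body runs once per process and every later call returns the same object. A minimal sketch of that behavior (the factory name and return value are placeholders, not cognee code):

    from functools import lru_cache

    @lru_cache
    def get_engine():
        print("building engine")  # executes only on the first call
        return object()           # stand-in for the real engine object

    a = get_engine()
    b = get_engine()   # cache hit: the factory body does not run again
    assert a is b

    get_engine.cache_clear()  # resets the cache, e.g. between tests

One consequence worth noting: configuration is read only on the first call, so anything that swaps databases at runtime (tests, for example) has to call cache_clear() on the wrapped functions.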
diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py
index cb4b54598..9418d035b 100644
--- a/cognee/tasks/ingestion/ingest_data.py
+++ b/cognee/tasks/ingestion/ingest_data.py
@@ -1,6 +1,7 @@
 import dlt
 import cognee.modules.ingestion as ingestion
 
+from uuid import UUID
 from cognee.shared.utils import send_telemetry
 from cognee.modules.users.models import User
 from cognee.infrastructure.databases.relational import get_relational_engine
@@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
     )
 
     @dlt.resource(standalone = True, merge_key = "id")
-    async def data_resources(file_paths: str, user: User):
+    async def data_resources(file_paths: list[str]):
         for file_path in file_paths:
             with open(file_path.replace("file://", ""), mode = "rb") as file:
                 classified_data = ingestion.classify(file)
-
                 data_id = ingestion.identify(classified_data)
-
                 file_metadata = classified_data.get_metadata()
+                yield {
+                    "id": data_id,
+                    "name": file_metadata["name"],
+                    "file_path": file_metadata["file_path"],
+                    "extension": file_metadata["extension"],
+                    "mime_type": file_metadata["mime_type"],
+                }
 
-                from sqlalchemy import select
-                from cognee.modules.data.models import Data
+    async def data_storing(table_name, dataset_name, user: User):
+        db_engine = get_relational_engine()
 
-                db_engine = get_relational_engine()
-
-                async with db_engine.get_async_session() as session:
+        async with db_engine.get_async_session() as session:
+            # Read metadata stored with dlt
+            files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
+            for file_metadata in files_metadata:
+                from sqlalchemy import select
+                from cognee.modules.data.models import Data
                 dataset = await create_dataset(dataset_name, user.id, session)
 
                 data = (await session.execute(
-                    select(Data).filter(Data.id == data_id)
+                    select(Data).filter(Data.id == UUID(file_metadata["id"]))
                 )).scalar_one_or_none()
 
                 if data is not None:
@@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     await session.commit()
                 else:
                     data = Data(
-                        id = data_id,
+                        id = UUID(file_metadata["id"]),
                         name = file_metadata["name"],
                         raw_data_location = file_metadata["file_path"],
                         extension = file_metadata["extension"],
@@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     dataset.data.append(data)
                     await session.commit()
 
-                yield {
-                    "id": data_id,
-                    "name": file_metadata["name"],
-                    "file_path": file_metadata["file_path"],
-                    "extension": file_metadata["extension"],
-                    "mime_type": file_metadata["mime_type"],
-                }
-
-                await give_permission_on_document(user, data_id, "read")
-                await give_permission_on_document(user, data_id, "write")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "write")
 
     send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
 
-    run_info = pipeline.run(
-        data_resources(file_paths, user),
-        table_name = "file_metadata",
-        dataset_name = dataset_name,
-        write_disposition = "merge",
-    )
+
+    db_engine = get_relational_engine()
+
+    # Note: the dlt pipeline runs its own event loop, so objects created in
+    # another event loop can't be used inside the pipeline
+    if db_engine.engine.dialect.name == "sqlite":
+        # To use SQLite with dlt, dataset_name must be set to "main",
+        # since SQLite doesn't support schemas
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = "main",
+            write_disposition = "merge",
+        )
+    else:
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = dataset_name,
+            write_disposition = "merge",
+        )
+
+    await data_storing("file_metadata", dataset_name, user)
 
     send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
 
     return run_info
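The rewritten ingest_data() splits ingestion into two passes: data_resources() yields plain dicts for dlt to persist, and data_storing() later reads them back through get_all_data_from_table(). Since dlt stores the id column as text, it is parsed with UUID(...) before being compared to Data.id. A short sketch of that round-trip (the uuid4() stand-in for ingestion.identify() is an assumption):

    from uuid import UUID, uuid4

    data_id = uuid4()                      # assumed shape of ingestion.identify()'s result
    stored = {"id": str(data_id)}          # dlt persists the id column as text
    assert UUID(stored["id"]) == data_id   # parsing restores an equal UUID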