feat: Make relational databases work as singleton

Moved the dlt pipeline to run in its own function so it doesn't use get_relational_database.
dlt has its own async event loop, and objects can't be shared between event loops.

Feature COG-678
This commit is contained in:
Igor Ilic 2024-11-28 12:59:04 +01:00
parent dfd30d8e54
commit 9bd3011264
3 changed files with 45 additions and 24 deletions

View file

@ -1,8 +1,10 @@
from functools import lru_cache

from .config import get_relational_config
from .create_relational_engine import create_relational_engine


@lru_cache
def get_relational_engine():
    """Return the process-wide relational engine singleton.

    ``lru_cache`` on a zero-argument function memoizes the first result,
    so every caller shares one engine instance instead of creating a new
    engine per call (COG-678).
    """
    relational_config = get_relational_config()
    return create_relational_engine(**relational_config.to_dict())

View file

@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
from sqlalchemy import text, select, MetaData, Table from sqlalchemy import text, select, MetaData, Table
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from ..ModelBase import Base from ..ModelBase import Base
class SQLAlchemyAdapter(): class SQLAlchemyAdapter():
@ -171,6 +170,24 @@ class SQLAlchemyAdapter():
results = await connection.execute(query) results = await connection.execute(query)
return {result["data_id"]: result["status"] for result in results} return {result["data_id"]: result["status"] for result in results}
async def get_all_data_from_table(self, table_name: str, schema: str = None):
    """Return every row of ``table_name`` as a list of dict-like mappings.

    Parameters
    ----------
    table_name:
        Name of the table to read; must be a plain Python identifier
        (guards against SQL injection via interpolated names).
    schema:
        Optional schema name, validated the same way.
        NOTE(review): the caller in ingest_data passes ``dataset_name``
        here — confirm dataset names are valid schema identifiers.

    Raises
    ------
    ValueError
        If either name fails the identifier check.
    """
    # Validate before opening a session so bad input never touches the DB.
    if not table_name.isidentifier():
        raise ValueError("Invalid table name")
    if schema and not schema.isidentifier():
        raise ValueError("Invalid schema name")

    async with self.get_async_session() as session:
        table = await self.get_table(table_name, schema)
        # Query all data from the table.
        query = select(table)
        result = await session.execute(query)
        # .mappings() yields key/value rows instead of bare tuples.
        return result.mappings().all()
async def execute_query(self, query): async def execute_query(self, query):
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
result = await connection.execute(text(query)) result = await connection.execute(text(query))

View file

@ -17,25 +17,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
) )
@dlt.resource(standalone = True, merge_key = "id")
async def data_resources(file_paths: list[str]):
    """Yield one metadata record per ingested file for the dlt pipeline.

    Deliberately touches no relational-database objects: dlt runs its own
    async event loop, and DB objects can't be shared across loops (COG-678).
    Annotation fixed: callers pass a list of paths, not a single string.
    """
    for file_path in file_paths:
        # Paths may arrive as file:// URIs; strip the scheme before opening.
        with open(file_path.replace("file://", ""), mode = "rb") as file:
            classified_data = ingestion.classify(file)
            data_id = ingestion.identify(classified_data)
            file_metadata = classified_data.get_metadata()
            yield {
                "id": data_id,
                "name": file_metadata["name"],
                "file_path": file_metadata["file_path"],
                "extension": file_metadata["extension"],
                "mime_type": file_metadata["mime_type"],
            }
from sqlalchemy import select async def data_storing(table_name, dataset_name, user: User):
from cognee.modules.data.models import Data db_engine = get_relational_engine()
db_engine = get_relational_engine() async with db_engine.get_async_session() as session:
# Read metadata stored with dlt
async with db_engine.get_async_session() as session: files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
for file_metadata in files_metadata:
from sqlalchemy import select
from cognee.modules.data.models import Data
dataset = await create_dataset(dataset_name, user.id, session) dataset = await create_dataset(dataset_name, user.id, session)
data = (await session.execute( data = (await session.execute(
select(Data).filter(Data.id == data_id) select(Data).filter(Data.id == file_metadata["id"])
)).scalar_one_or_none() )).scalar_one_or_none()
if data is not None: if data is not None:
@ -48,7 +56,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
await session.commit() await session.commit()
else: else:
data = Data( data = Data(
id = data_id, id = file_metadata["id"],
name = file_metadata["name"], name = file_metadata["name"],
raw_data_location = file_metadata["file_path"], raw_data_location = file_metadata["file_path"],
extension = file_metadata["extension"], extension = file_metadata["extension"],
@ -58,25 +66,19 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
dataset.data.append(data) dataset.data.append(data)
await session.commit() await session.commit()
yield { await give_permission_on_document(user, file_metadata["id"], "read")
"id": data_id, await give_permission_on_document(user, file_metadata["id"], "write")
"name": file_metadata["name"],
"file_path": file_metadata["file_path"],
"extension": file_metadata["extension"],
"mime_type": file_metadata["mime_type"],
}
await give_permission_on_document(user, data_id, "read")
await give_permission_on_document(user, data_id, "write")
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id) send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
run_info = pipeline.run( run_info = pipeline.run(
data_resources(file_paths, user), data_resources(file_paths),
table_name = "file_metadata", table_name = "file_metadata",
dataset_name = dataset_name, dataset_name = dataset_name,
write_disposition = "merge", write_disposition = "merge",
) )
await data_storing("file_metadata", dataset_name, user)
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id) send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
return run_info return run_info