feat: Make relational databases work as a singleton

Moved the dlt pipeline to run in its own function so it doesn't use get_relational_engine. dlt has its own async event loop, and objects can't be shared between event loops. Feature COG-678
This commit is contained in:
parent dfd30d8e54
commit 9bd3011264

3 changed files with 45 additions and 24 deletions
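
The event-loop constraint mentioned above is the core of the problem: dlt drives pipeline.run with its own asyncio event loop, while the cached engine and its sessions are bound to the caller's loop. A minimal, repository-independent sketch of that failure mode (plain asyncio, no dlt or SQLAlchemy involved; the exact error message varies by Python version):

import asyncio

async def first_loop():
    event = asyncio.Event()
    waiter = asyncio.create_task(event.wait())  # waiting binds the Event to this loop
    await asyncio.sleep(0)                      # let the waiter actually start waiting
    event.set()
    await waiter
    event.clear()
    return event

async def second_loop(event):
    # Re-using the loop-bound object from a different event loop fails.
    await event.wait()

event = asyncio.run(first_loop())   # loop #1 creates and uses the object
asyncio.run(second_loop(event))     # loop #2 -> RuntimeError: bound to a different event loop
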
@@ -1,8 +1,10 @@
+from functools import lru_cache
+
 from .config import get_relational_config
 from .create_relational_engine import create_relational_engine


+@lru_cache
 def get_relational_engine():
     relational_config = get_relational_config()

     return create_relational_engine(**relational_config.to_dict())
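
This first hunk is where the singleton behaviour comes from: a zero-argument function wrapped in functools.lru_cache runs once, and every later call returns the cached result, so the whole process shares one engine. A small illustrative sketch (the class name is a stand-in, not code from the repository):

from functools import lru_cache

class Engine:
    """Stand-in for whatever create_relational_engine() returns."""

@lru_cache
def get_engine() -> Engine:
    # Built on the first call only; subsequent calls return the cached instance.
    return Engine()

assert get_engine() is get_engine()  # same object every time -> de facto singleton
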
@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
 from sqlalchemy import text, select, MetaData, Table
 from sqlalchemy.orm import joinedload
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker

 from ..ModelBase import Base

 class SQLAlchemyAdapter():

@@ -171,6 +170,24 @@ class SQLAlchemyAdapter():
             results = await connection.execute(query)
             return {result["data_id"]: result["status"] for result in results}

+    async def get_all_data_from_table(self, table_name: str, schema: str = None):
+        async with self.get_async_session() as session:
+            # Validate inputs to prevent SQL injection
+            if not table_name.isidentifier():
+                raise ValueError("Invalid table name")
+            if schema and not schema.isidentifier():
+                raise ValueError("Invalid schema name")
+
+            table = await self.get_table(table_name, schema)
+
+            # Query all data from the table
+            query = select(table)
+            result = await session.execute(query)
+
+            # Fetch all rows as a list of dictionaries
+            rows = result.mappings().all()  # Use `.mappings()` to get key-value pairs
+            return rows
+
     async def execute_query(self, query):
         async with self.engine.begin() as connection:
             result = await connection.execute(text(query))
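
For context, the new adapter method returns dict-like RowMapping objects (via result.mappings()), so callers can index rows by column name. A hedged usage sketch follows; the adapter instance is a placeholder, and the table/schema names are simply the ones used elsewhere in this commit:

# Sketch only: `adapter` is assumed to be an initialised SQLAlchemyAdapter.
async def print_file_metadata(adapter, dataset_name: str):
    rows = await adapter.get_all_data_from_table("file_metadata", schema = dataset_name)
    for row in rows:
        print(row["id"], row["name"], row["mime_type"])
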
@@ -17,25 +17,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
     )

     @dlt.resource(standalone = True, merge_key = "id")
-    async def data_resources(file_paths: str, user: User):
+    async def data_resources(file_paths: str):
         for file_path in file_paths:
             with open(file_path.replace("file://", ""), mode = "rb") as file:
                 classified_data = ingestion.classify(file)

                 data_id = ingestion.identify(classified_data)

                 file_metadata = classified_data.get_metadata()
+                yield {
+                    "id": data_id,
+                    "name": file_metadata["name"],
+                    "file_path": file_metadata["file_path"],
+                    "extension": file_metadata["extension"],
+                    "mime_type": file_metadata["mime_type"],
+                }

-                from sqlalchemy import select
-                from cognee.modules.data.models import Data
-                db_engine = get_relational_engine()
-
-                async with db_engine.get_async_session() as session:
+    async def data_storing(table_name, dataset_name, user: User):
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            # Read metadata stored with dlt
+            files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
+            for file_metadata in files_metadata:
+                from sqlalchemy import select
+                from cognee.modules.data.models import Data
+
                 dataset = await create_dataset(dataset_name, user.id, session)

                 data = (await session.execute(
-                    select(Data).filter(Data.id == data_id)
+                    select(Data).filter(Data.id == file_metadata["id"])
                 )).scalar_one_or_none()

                 if data is not None:
@@ -48,7 +56,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     await session.commit()
                 else:
                     data = Data(
-                        id = data_id,
+                        id = file_metadata["id"],
                         name = file_metadata["name"],
                         raw_data_location = file_metadata["file_path"],
                         extension = file_metadata["extension"],
@@ -58,25 +66,19 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     dataset.data.append(data)
                     await session.commit()

-                yield {
-                    "id": data_id,
-                    "name": file_metadata["name"],
-                    "file_path": file_metadata["file_path"],
-                    "extension": file_metadata["extension"],
-                    "mime_type": file_metadata["mime_type"],
-                }
-
-                await give_permission_on_document(user, data_id, "read")
-                await give_permission_on_document(user, data_id, "write")
+                await give_permission_on_document(user, file_metadata["id"], "read")
+                await give_permission_on_document(user, file_metadata["id"], "write")


     send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
     run_info = pipeline.run(
-        data_resources(file_paths, user),
+        data_resources(file_paths),
         table_name = "file_metadata",
         dataset_name = dataset_name,
         write_disposition = "merge",
     )

+    await data_storing("file_metadata", dataset_name, user)
     send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)

     return run_info
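
The net effect of the last file is a two-phase flow: pipeline.run lets dlt write the file_metadata rows inside its own event loop, and only afterwards does data_storing touch the cached relational engine from the caller's loop. A stubbed sketch of that ordering (no real dlt or cognee calls, just the shape):

import asyncio

def run_dlt_pipeline(rows):
    # Stand-in for pipeline.run(...): dlt manages its own event loop internally,
    # so no loop-bound objects (engines, sessions) are handed to it.
    for row in rows:
        print("dlt stored:", row)

async def data_storing(rows):
    # Stand-in for the new data_storing(...): the only place the shared
    # relational engine is used, and it runs in the caller's loop.
    for row in rows:
        print("db stored:", row)

async def ingest(rows):
    run_dlt_pipeline(rows)     # phase 1: dlt ingestion
    await data_storing(rows)   # phase 2: relational bookkeeping afterwards

asyncio.run(ingest([{"id": "1", "name": "a.txt"}]))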