Cog 678 relational database singleton (#38)

1. Set the relational database in cognee to be used as a singleton and made the
necessary changes to enable this
2. Added SQLite support to the dlt pipeline in ingest_data
Igor Ilic 2024-12-02 13:02:34 +01:00 committed by GitHub
commit 34971d16cc
5 changed files with 74 additions and 32 deletions

View file

@@ -1,8 +1,10 @@
+from functools import lru_cache
 from .config import get_relational_config
 from .create_relational_engine import create_relational_engine

+@lru_cache
 def get_relational_engine():
     relational_config = get_relational_config()
     return create_relational_engine(**relational_config.to_dict())
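
The @lru_cache decorator on a zero-argument factory is what turns the engine into a process-wide singleton: the body runs once and every later call gets the cached object. A minimal, self-contained sketch of the pattern (illustrative only, not cognee's engine code):

from functools import lru_cache

@lru_cache
def get_engine():
    # Runs only on the first call; lru_cache memoizes the result of the
    # zero-argument call and hands back the cached object afterwards.
    return object()

assert get_engine() is get_engine()  # same instance: a process-wide singleton
# get_engine.cache_clear() would drop the cached engine and force a rebuild.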

View file

@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
 from sqlalchemy import text, select, MetaData, Table
 from sqlalchemy.orm import joinedload
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
 from ..ModelBase import Base

 class SQLAlchemyAdapter():
@@ -171,6 +170,27 @@ class SQLAlchemyAdapter():
         results = await connection.execute(query)
         return {result["data_id"]: result["status"] for result in results}

+    async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
+        async with self.get_async_session() as session:
+            # Validate inputs to prevent SQL injection
+            if not table_name.isidentifier():
+                raise ValueError("Invalid table name")
+            if schema and not schema.isidentifier():
+                raise ValueError("Invalid schema name")
+
+            if self.engine.dialect.name == "sqlite":
+                table = await self.get_table(table_name)
+            else:
+                table = await self.get_table(table_name, schema)
+
+            # Query all data from the table
+            query = select(table)
+            result = await session.execute(query)
+
+            # Fetch all rows as a list of dictionaries
+            rows = result.mappings().all()
+            return rows
+
     async def execute_query(self, query):
         async with self.engine.begin() as connection:
             result = await connection.execute(text(query))
@@ -205,7 +225,6 @@ class SQLAlchemyAdapter():
             from cognee.infrastructure.files.storage import LocalStorage

             LocalStorage.remove(self.db_path)
-            self.db_path = None
         else:
             async with self.engine.begin() as connection:
                 schema_list = await self.get_schema_list()
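
The new get_all_data_from_table method boils down to select(table) followed by result.mappings().all(). A runnable sketch of that core against an in-memory SQLite database (assumes the aiosqlite driver is installed; the items table and rows are made up for illustration):

import asyncio
from sqlalchemy import MetaData, Table, Column, Integer, Text, select, insert
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

async def main():
    engine = create_async_engine("sqlite+aiosqlite://")
    metadata = MetaData()
    items = Table("items", metadata,
        Column("id", Integer, primary_key = True),
        Column("name", Text),
    )

    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)
        await connection.execute(insert(items), [{"name": "a"}, {"name": "b"}])

    session_maker = async_sessionmaker(engine)
    async with session_maker() as session:
        result = await session.execute(select(items))
        # Each row comes back as a dict-like mapping keyed by column name.
        print(result.mappings().all())  # [{'id': 1, 'name': 'a'}, {'id': 2, 'name': 'b'}]

asyncio.run(main())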

View file

@@ -1,12 +1,12 @@
-from ...relational.ModelBase import Base
 from ..get_vector_engine import get_vector_engine, get_vectordb_config
 from sqlalchemy import text

 async def create_db_and_tables():
     vector_config = get_vectordb_config()
     vector_engine = get_vector_engine()

     if vector_config.vector_db_provider == "pgvector":
+        await vector_engine.create_database()
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
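
After ensuring the database exists, the setup enables pgvector with CREATE EXTENSION IF NOT EXISTS vector, which requires sufficient privileges on the target database. A hedged sketch for verifying that step afterwards (the DSN and asyncpg driver are placeholders/assumptions; the query reads the standard pg_extension catalog):

import asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def check_vector_extension(dsn = "postgresql+asyncpg://user:pass@localhost/cognee_db"):
    engine = create_async_engine(dsn)
    async with engine.begin() as connection:
        result = await connection.execute(
            text("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
        )
        # A row means the extension is installed in this database.
        print("pgvector installed:", result.first() is not None)
    await engine.dispose()

asyncio.run(check_vector_extension())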

View file

@ -1,9 +1,12 @@
import os import os
from functools import lru_cache
import dlt import dlt
from typing import Union from typing import Union
from cognee.infrastructure.databases.relational import get_relational_config from cognee.infrastructure.databases.relational import get_relational_config
@lru_cache
def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]: def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
""" """
Handles propagation of the cognee database configuration to the dlt library Handles propagation of the cognee database configuration to the dlt library
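
Caching get_dlt_destination with @lru_cache mirrors the relational-engine singleton: the destination reflects the config seen on the first call, and a later config change would need get_dlt_destination.cache_clear() to take effect. As a rough sketch of the mapping this function performs for SQLite, assuming dlt's sqlalchemy destination and a placeholder database path (the pipeline name is hypothetical, not cognee's):

import dlt

# Illustrative only: dlt's sqlalchemy destination accepts a SQLAlchemy-style
# connection string, so a SQLite config maps to it roughly like this.
destination = dlt.destinations.sqlalchemy("sqlite:///cognee.db")

pipeline = dlt.pipeline(
    pipeline_name = "metadata_extraction",  # hypothetical name
    destination = destination,
)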

View file

@ -1,6 +1,7 @@
import dlt import dlt
import cognee.modules.ingestion as ingestion import cognee.modules.ingestion as ingestion
from uuid import UUID
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
@@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
     )

     @dlt.resource(standalone = True, merge_key = "id")
-    async def data_resources(file_paths: str, user: User):
+    async def data_resources(file_paths: str):
         for file_path in file_paths:
             with open(file_path.replace("file://", ""), mode = "rb") as file:
                 classified_data = ingestion.classify(file)
                 data_id = ingestion.identify(classified_data)
                 file_metadata = classified_data.get_metadata()
+                yield {
+                    "id": data_id,
+                    "name": file_metadata["name"],
+                    "file_path": file_metadata["file_path"],
+                    "extension": file_metadata["extension"],
+                    "mime_type": file_metadata["mime_type"],
+                }

-                from sqlalchemy import select
-                from cognee.modules.data.models import Data
-                db_engine = get_relational_engine()
-                async with db_engine.get_async_session() as session:
+    async def data_storing(table_name, dataset_name, user: User):
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            # Read metadata stored with dlt
+            files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
+
+            for file_metadata in files_metadata:
+                from sqlalchemy import select
+                from cognee.modules.data.models import Data
+
                 dataset = await create_dataset(dataset_name, user.id, session)
                 data = (await session.execute(
-                    select(Data).filter(Data.id == data_id)
+                    select(Data).filter(Data.id == UUID(file_metadata["id"]))
                 )).scalar_one_or_none()

                 if data is not None:
@@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     await session.commit()
                 else:
                     data = Data(
-                        id = data_id,
+                        id = UUID(file_metadata["id"]),
                         name = file_metadata["name"],
                         raw_data_location = file_metadata["file_path"],
                         extension = file_metadata["extension"],
@@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     dataset.data.append(data)
                     await session.commit()

-                yield {
-                    "id": data_id,
-                    "name": file_metadata["name"],
-                    "file_path": file_metadata["file_path"],
-                    "extension": file_metadata["extension"],
-                    "mime_type": file_metadata["mime_type"],
-                }
-
-                await give_permission_on_document(user, data_id, "read")
-                await give_permission_on_document(user, data_id, "write")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "write")

     send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)

-    run_info = pipeline.run(
-        data_resources(file_paths, user),
-        table_name = "file_metadata",
-        dataset_name = dataset_name,
-        write_disposition = "merge",
-    )
+    db_engine = get_relational_engine()
+
+    # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
+    # can't be used inside the pipeline
+    if db_engine.engine.dialect.name == "sqlite":
+        # To use sqlite with dlt dataset_name must be set to "main".
+        # Sqlite doesn't support schemas
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = "main",
+            write_disposition = "merge",
+        )
+    else:
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = dataset_name,
+            write_disposition = "merge",
+        )
+
+    await data_storing("file_metadata", dataset_name, user)

     send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
     return run_info
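
The dataset_name = "main" branch reflects a SQLite fact rather than a dlt quirk: SQLite has no user-defined schemas, and the primary database of every connection is fixed under the name "main". A quick standard-library demonstration:

import sqlite3

connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE file_metadata (id TEXT PRIMARY KEY)")

# The primary database of any SQLite connection is named "main".
print(connection.execute("PRAGMA database_list").fetchall())
# -> [(0, 'main', '')]

# Tables therefore live under the "main" schema, which is why the dlt
# dataset_name is pinned to "main" in the SQLite branch above.
print(connection.execute("SELECT name FROM main.sqlite_master").fetchall())
# -> [('file_metadata',)]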