Cog 678 relational database singleton (#38)
1. Set relational database in cognee to be used as singleton and made necessary changes to enable this 2. Added SQLite support to dlt pipeline in ingest_data
This commit is contained in:
commit
34971d16cc
5 changed files with 74 additions and 32 deletions
|
|
@ -1,8 +1,10 @@
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from .config import get_relational_config
|
from .config import get_relational_config
|
||||||
from .create_relational_engine import create_relational_engine
|
from .create_relational_engine import create_relational_engine
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
def get_relational_engine():
|
def get_relational_engine():
|
||||||
relational_config = get_relational_config()
|
relational_config = get_relational_config()
|
||||||
|
|
||||||
return create_relational_engine(**relational_config.to_dict())
|
return create_relational_engine(**relational_config.to_dict())
|
||||||
|
|
@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
|
||||||
from sqlalchemy import text, select, MetaData, Table
|
from sqlalchemy import text, select, MetaData, Table
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from ..ModelBase import Base
|
from ..ModelBase import Base
|
||||||
|
|
||||||
class SQLAlchemyAdapter():
|
class SQLAlchemyAdapter():
|
||||||
|
|
@ -171,6 +170,27 @@ class SQLAlchemyAdapter():
|
||||||
results = await connection.execute(query)
|
results = await connection.execute(query)
|
||||||
return {result["data_id"]: result["status"] for result in results}
|
return {result["data_id"]: result["status"] for result in results}
|
||||||
|
|
||||||
|
async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
|
||||||
|
async with self.get_async_session() as session:
|
||||||
|
# Validate inputs to prevent SQL injection
|
||||||
|
if not table_name.isidentifier():
|
||||||
|
raise ValueError("Invalid table name")
|
||||||
|
if schema and not schema.isidentifier():
|
||||||
|
raise ValueError("Invalid schema name")
|
||||||
|
|
||||||
|
if self.engine.dialect.name == "sqlite":
|
||||||
|
table = await self.get_table(table_name)
|
||||||
|
else:
|
||||||
|
table = await self.get_table(table_name, schema)
|
||||||
|
|
||||||
|
# Query all data from the table
|
||||||
|
query = select(table)
|
||||||
|
result = await session.execute(query)
|
||||||
|
|
||||||
|
# Fetch all rows as a list of dictionaries
|
||||||
|
rows = result.mappings().all()
|
||||||
|
return rows
|
||||||
|
|
||||||
async def execute_query(self, query):
|
async def execute_query(self, query):
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
result = await connection.execute(text(query))
|
result = await connection.execute(text(query))
|
||||||
|
|
@ -205,7 +225,6 @@ class SQLAlchemyAdapter():
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
|
|
||||||
LocalStorage.remove(self.db_path)
|
LocalStorage.remove(self.db_path)
|
||||||
self.db_path = None
|
|
||||||
else:
|
else:
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
schema_list = await self.get_schema_list()
|
schema_list = await self.get_schema_list()
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from ...relational.ModelBase import Base
|
|
||||||
from ..get_vector_engine import get_vector_engine, get_vectordb_config
|
from ..get_vector_engine import get_vector_engine, get_vectordb_config
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
||||||
async def create_db_and_tables():
|
async def create_db_and_tables():
|
||||||
vector_config = get_vectordb_config()
|
vector_config = get_vectordb_config()
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
if vector_config.vector_db_provider == "pgvector":
|
if vector_config.vector_db_provider == "pgvector":
|
||||||
await vector_engine.create_database()
|
|
||||||
async with vector_engine.engine.begin() as connection:
|
async with vector_engine.engine.begin() as connection:
|
||||||
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import os
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
import dlt
|
import dlt
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_config
|
from cognee.infrastructure.databases.relational import get_relational_config
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
|
def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
|
||||||
"""
|
"""
|
||||||
Handles propagation of the cognee database configuration to the dlt library
|
Handles propagation of the cognee database configuration to the dlt library
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import dlt
|
import dlt
|
||||||
import cognee.modules.ingestion as ingestion
|
import cognee.modules.ingestion as ingestion
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
|
||||||
)
|
)
|
||||||
|
|
||||||
@dlt.resource(standalone = True, merge_key = "id")
|
@dlt.resource(standalone = True, merge_key = "id")
|
||||||
async def data_resources(file_paths: str, user: User):
|
async def data_resources(file_paths: str):
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
with open(file_path.replace("file://", ""), mode = "rb") as file:
|
with open(file_path.replace("file://", ""), mode = "rb") as file:
|
||||||
classified_data = ingestion.classify(file)
|
classified_data = ingestion.classify(file)
|
||||||
|
|
||||||
data_id = ingestion.identify(classified_data)
|
data_id = ingestion.identify(classified_data)
|
||||||
|
|
||||||
file_metadata = classified_data.get_metadata()
|
file_metadata = classified_data.get_metadata()
|
||||||
|
yield {
|
||||||
|
"id": data_id,
|
||||||
|
"name": file_metadata["name"],
|
||||||
|
"file_path": file_metadata["file_path"],
|
||||||
|
"extension": file_metadata["extension"],
|
||||||
|
"mime_type": file_metadata["mime_type"],
|
||||||
|
}
|
||||||
|
|
||||||
from sqlalchemy import select
|
async def data_storing(table_name, dataset_name, user: User):
|
||||||
from cognee.modules.data.models import Data
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
async with db_engine.get_async_session() as session:
|
||||||
|
# Read metadata stored with dlt
|
||||||
async with db_engine.get_async_session() as session:
|
files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
|
||||||
|
for file_metadata in files_metadata:
|
||||||
|
from sqlalchemy import select
|
||||||
|
from cognee.modules.data.models import Data
|
||||||
dataset = await create_dataset(dataset_name, user.id, session)
|
dataset = await create_dataset(dataset_name, user.id, session)
|
||||||
|
|
||||||
data = (await session.execute(
|
data = (await session.execute(
|
||||||
select(Data).filter(Data.id == data_id)
|
select(Data).filter(Data.id == UUID(file_metadata["id"]))
|
||||||
)).scalar_one_or_none()
|
)).scalar_one_or_none()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
|
|
@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
|
||||||
await session.commit()
|
await session.commit()
|
||||||
else:
|
else:
|
||||||
data = Data(
|
data = Data(
|
||||||
id = data_id,
|
id = UUID(file_metadata["id"]),
|
||||||
name = file_metadata["name"],
|
name = file_metadata["name"],
|
||||||
raw_data_location = file_metadata["file_path"],
|
raw_data_location = file_metadata["file_path"],
|
||||||
extension = file_metadata["extension"],
|
extension = file_metadata["extension"],
|
||||||
|
|
@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
|
||||||
dataset.data.append(data)
|
dataset.data.append(data)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
yield {
|
await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
|
||||||
"id": data_id,
|
await give_permission_on_document(user, UUID(file_metadata["id"]), "write")
|
||||||
"name": file_metadata["name"],
|
|
||||||
"file_path": file_metadata["file_path"],
|
|
||||||
"extension": file_metadata["extension"],
|
|
||||||
"mime_type": file_metadata["mime_type"],
|
|
||||||
}
|
|
||||||
|
|
||||||
await give_permission_on_document(user, data_id, "read")
|
|
||||||
await give_permission_on_document(user, data_id, "write")
|
|
||||||
|
|
||||||
|
|
||||||
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
|
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
|
||||||
run_info = pipeline.run(
|
|
||||||
data_resources(file_paths, user),
|
db_engine = get_relational_engine()
|
||||||
table_name = "file_metadata",
|
|
||||||
dataset_name = dataset_name,
|
# Note: DLT pipeline has its own event loop, therefore objects created in another event loop
|
||||||
write_disposition = "merge",
|
# can't be used inside the pipeline
|
||||||
)
|
if db_engine.engine.dialect.name == "sqlite":
|
||||||
|
# To use sqlite with dlt dataset_name must be set to "main".
|
||||||
|
# Sqlite doesn't support schemas
|
||||||
|
run_info = pipeline.run(
|
||||||
|
data_resources(file_paths),
|
||||||
|
table_name = "file_metadata",
|
||||||
|
dataset_name = "main",
|
||||||
|
write_disposition = "merge",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
run_info = pipeline.run(
|
||||||
|
data_resources(file_paths),
|
||||||
|
table_name="file_metadata",
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
write_disposition="merge",
|
||||||
|
)
|
||||||
|
|
||||||
|
await data_storing("file_metadata", dataset_name, user)
|
||||||
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
|
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
|
||||||
|
|
||||||
return run_info
|
return run_info
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue