cognee/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
2025-01-16 13:28:35 +01:00

308 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from os import path
import logging
from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
from sqlalchemy import text, select, MetaData, Table, delete
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.modules.data.models.Data import Data
from ..ModelBase import Base
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class SQLAlchemyAdapter:
    """
    Async relational-database adapter built on a SQLAlchemy async engine.

    Holds the engine and a session factory and offers helpers for table management,
    raw queries and database creation/removal. Behaviour branches on
    ``engine.dialect.name``: "sqlite" is handled without schemas, other dialects are
    addressed with schema-qualified SQL.

    NOTE(review): several methods interpolate schema/table/column names directly into
    SQL text (identifiers cannot be sent as bind parameters). Only
    ``get_all_data_from_table`` validates them, so callers must pass trusted names.
    """

    def __init__(self, connection_string: str):
        """
        Create the async engine and the session factory for the given database URL.

        Parameters:
            connection_string: SQLAlchemy connection URL; kept as ``db_uri``.
        """
        # Only set for SQLite: the part of the URL that follows "///".
        self.db_path: Optional[str] = None
        self.db_uri: str = connection_string

        self.engine = create_async_engine(connection_string)
        self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)

        if self.engine.dialect.name == "sqlite":
            self.db_path = connection_string.split("///")[1]

    @asynccontextmanager
    async def get_async_session(self) -> AsyncGenerator["AsyncSession", None]:
        """
        Async context manager that yields a new session from the session factory.
        The session is closed on exit, also when the body raises.
        """
        async with self.sessionmaker() as session:
            try:
                yield session
            finally:
                await session.close()  # Ensure the session is closed

    def get_session(self):
        """
        Generator that yields a session from the session factory and closes it afterwards.

        NOTE(review): ``self.sessionmaker`` is an ``async_sessionmaker``; its sessions
        are meant to be used with ``async with``, so this synchronous ``with`` looks
        like it cannot work as written. Confirm whether anything still calls this
        and prefer ``get_async_session``.
        """
        session_maker = self.sessionmaker
        with session_maker() as session:
            try:
                yield session
            finally:
                session.close()  # Ensure the session is closed

    async def get_datasets(self):
        """
        Return all Dataset rows with their ``data`` relationship loaded via a joined load.
        """
        from cognee.modules.data.models import Dataset

        async with self.get_async_session() as session:
            result = await session.execute(select(Dataset).options(joinedload(Dataset.data)))
            # unique() is applied before scalars() because the joined load repeats parent rows.
            return result.unique().scalars().all()

    async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
        """
        Create the schema and the table if they do not exist yet.

        Parameters:
            schema_name: Schema to create and to place the table in.
            table_name: Name of the table to create.
            table_config: Column definitions; every item provides "name" and "type".
        """
        column_definitions = ", ".join(f"{item['name']} {item['type']}" for item in table_config)

        # engine.begin() commits on success and releases the connection on exit.
        # The connection must not be closed by hand inside the block: closing it
        # rolls back the transaction that is still open.
        async with self.engine.begin() as connection:
            await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
            await connection.execute(
                text(
                    f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({column_definitions});"
                )
            )

    async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
        """
        Drop the given table if it exists.

        Parameters:
            table_name: Name of the table to drop.
            schema_name: Schema of the table; not used for SQLite.
        """
        if self.engine.dialect.name == "sqlite":
            # SQLite doesn't support schema namespaces and the CASCADE keyword.
            # However, foreign key constraint can be defined with ON DELETE CASCADE
            # during table creation.
            drop_statement = f"DROP TABLE IF EXISTS {table_name};"
        else:
            drop_statement = f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;"

        async with self.engine.begin() as connection:
            await connection.execute(text(drop_statement))

    async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
        """
        Insert rows into ``schema_name.table_name``.

        Parameters:
            schema_name: Schema of the target table.
            table_name: Name of the target table.
            data: One dict per row. The keys of the first row are used as the column
                names. An empty list is a no-op.
        """
        if not data:
            return

        column_names = list(data[0].keys())
        columns = ", ".join(column_names)
        placeholders = ", ".join(f":{name}" for name in column_names)

        # One VALUES tuple only: executing with a list of dicts runs the statement once
        # per row. Repeating the tuple per row would reuse the same bind names and
        # insert every row len(data) times.
        insert_query = text(
            f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES ({placeholders});"
        )

        async with self.engine.begin() as connection:
            await connection.execute(insert_query, data)

    async def get_schema_list(self) -> List[str]:
        """
        Return a list of all schema names in database.
        Only implemented for PostgreSQL; any other dialect yields an empty list.
        """
        if self.engine.dialect.name != "postgresql":
            return []

        async with self.engine.begin() as connection:
            result = await connection.execute(
                text(
                    """
                    SELECT schema_name FROM information_schema.schemata
                    WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
                    """
                )
            )
            return [schema[0] for schema in result.fetchall()]

    async def delete_entity_by_id(
        self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"
    ):
        """
        Delete entity in given table based on id. Table must have an id Column.

        Raises:
            EntityNotFoundError: If the table cannot be found (raised by ``get_table``).
        """
        async with self.get_async_session() as session:
            table = await self.get_table(table_name, schema_name)

            if self.engine.dialect.name == "sqlite":
                # Foreign key constraints are disabled by default in SQLite (for backwards
                # compatibility), so must be enabled for each database connection/session
                # separately.
                await session.execute(text("PRAGMA foreign_keys = ON;"))

            await session.execute(table.delete().where(table.c.id == data_id))
            await session.commit()

    async def delete_data_entity(self, data_id: UUID):
        """
        Delete data and local files related to data if there are no references to it anymore.

        Raises:
            EntityNotFoundError: If no single Data row with the given id can be loaded.
        """
        async with self.get_async_session() as session:
            if self.engine.dialect.name == "sqlite":
                # Foreign key constraints are disabled by default in SQLite (for backwards
                # compatibility), so must be enabled for each database connection/session
                # separately.
                await session.execute(text("PRAGMA foreign_keys = ON;"))

            try:
                data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
            except (ValueError, NoResultFound) as e:
                # Keep the original exception as the cause for easier debugging.
                raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e

            # Check if other data objects point to the same raw data location
            raw_data_location_entities = (
                await session.execute(
                    select(Data.raw_data_location).where(
                        Data.raw_data_location == data_entity.raw_data_location
                    )
                )
            ).all()

            # Don't delete local file unless this is the only reference to the file in the database
            if len(raw_data_location_entities) == 1:
                # delete local file only if it's created by cognee
                from cognee.base_config import get_base_config

                config = get_base_config()
                file_location = raw_data_location_entities[0].raw_data_location

                if config.data_root_directory in file_location:
                    if os.path.exists(file_location):
                        os.remove(file_location)
                    else:
                        # Report bug as file should exist
                        logger.error("Local file which should exist can't be found.")

            await session.execute(delete(Data).where(Data.id == data_id))
            await session.commit()

    async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> "Table":
        """
        Dynamically loads a table using the given table name and schema name.

        Parameters:
            table_name: Name of the table to load.
            schema_name: Schema to reflect; not used for SQLite.

        Raises:
            EntityNotFoundError: If the table is not present after reflection.
        """
        async with self.engine.begin() as connection:
            if self.engine.dialect.name == "sqlite":
                # Load the schema information into the shared Base metadata
                await connection.run_sync(Base.metadata.reflect)
                if table_name in Base.metadata.tables:
                    return Base.metadata.tables[table_name]
                raise EntityNotFoundError(message=f"Table '{table_name}' not found.")

            # Reflect the requested schema into a fresh MetaData instance
            metadata = MetaData()
            await connection.run_sync(metadata.reflect, schema=schema_name)

            # Reflected tables are keyed by their schema-qualified name
            full_table_name = f"{schema_name}.{table_name}"
            if full_table_name in metadata.tables:
                return metadata.tables[full_table_name]
            raise EntityNotFoundError(message=f"Table '{full_table_name}' not found.")

    async def get_table_names(self) -> List[str]:
        """
        Return a list of all tables names in database
        """
        table_names = []
        async with self.engine.begin() as connection:
            if self.engine.dialect.name == "sqlite":
                await connection.run_sync(Base.metadata.reflect)
                table_names.extend(str(name) for name in Base.metadata.tables)
            else:
                schema_list = await self.get_schema_list()
                # Create a MetaData instance to load table information
                metadata = MetaData()
                # Collect the table names of every schema
                for schema_name in schema_list:
                    # Load the schema information into the MetaData object
                    await connection.run_sync(metadata.reflect, schema=schema_name)
                    table_names.extend(str(table) for table in metadata.sorted_tables)
                    # Start from an empty MetaData for the next schema
                    metadata.clear()
        return table_names

    @staticmethod
    def _build_filter_clause(filters: dict) -> tuple:
        """
        Translate a filter dict into the body of a WHERE clause and its bind parameters.

        Scalar values become ``key = :key``. List values become
        ``key IN (:key0, :key1, ...)`` with one bind parameter per element; an empty
        list becomes the always-false condition ``1 = 0``.

        Returns:
            A ``(clause, params)`` tuple, e.g. ``{"a": 1, "b": [2, 3]}`` gives
            ``("a = :a AND b IN (:b0, :b1)", {"a": 1, "b0": 2, "b1": 3})``.
        """
        conditions = []
        params = {}
        for key, value in filters.items():
            if not isinstance(value, list):
                params[key] = value
                conditions.append(f"{key} = :{key}")
                continue

            if not value:
                # "IN ()" is not portable SQL; an empty list can never match anything.
                conditions.append("1 = 0")
                continue

            placeholder_names = []
            for index, item in enumerate(value):
                name = f"{key}{index}"
                params[name] = item
                placeholder_names.append(f":{name}")
            conditions.append(f"{key} IN ({', '.join(placeholder_names)})")

        return " AND ".join(conditions), params

    async def get_data(self, table_name: str, filters: Optional[dict] = None):
        """
        Select rows from a table and return a dict mapping ``data_id`` to ``status``.

        Parameters:
            table_name: Table to read; its rows must expose "data_id" and "status" columns.
            filters: Optional column -> value mapping combined with AND. A list value
                is matched with IN, any other value with equality.
        """
        query = f"SELECT * FROM {table_name}"

        async with self.engine.begin() as connection:
            if filters:
                where_clause, params = self._build_filter_clause(filters)
                results = await connection.execute(text(f"{query} WHERE {where_clause};"), params)
            else:
                results = await connection.execute(text(f"{query};"))

            # mappings() gives dict-like rows; plain rows cannot be indexed by column name.
            return {row["data_id"]: row["status"] for row in results.mappings()}

    async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
        """
        Return every row of the given table as a list of mappings.

        Raises:
            ValueError: If the table or schema name is not a plain identifier.
            EntityNotFoundError: If the table cannot be found (raised by ``get_table``).
        """
        # Validate inputs to prevent SQL injection (done before a session is opened)
        if not table_name.isidentifier():
            raise ValueError("Invalid table name")
        if schema and not schema.isidentifier():
            raise ValueError("Invalid schema name")

        async with self.get_async_session() as session:
            # get_table does not use the schema argument for SQLite
            table = await self.get_table(table_name, schema)

            # Query all data from the table
            result = await session.execute(select(table))

            # Fetch all rows as a list of dictionaries
            return result.mappings().all()

    async def execute_query(self, query):
        """
        Execute a raw SQL string and return the resulting rows as a list of dicts.

        Parameters:
            query: SQL text; it is executed as given, so it must come from a trusted source.
        """
        async with self.engine.begin() as connection:
            result = await connection.execute(text(query))
            # mappings() is required for dict conversion; plain rows behave like tuples.
            return [dict(row) for row in result.mappings()]

    async def drop_tables(self):
        """
        Drop the ``group_permission`` and ``permissions`` tables if they exist.
        Errors are reported on stdout and not re-raised.
        """
        async with self.engine.begin() as connection:
            try:
                for table_name in ("group_permission", "permissions"):
                    await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE"))
                # Add more table names to the tuple above as needed
                print("Database tables dropped successfully.")
            except Exception as e:
                print(f"Error dropping database tables: {e}")

    async def create_database(self):
        """
        Create all tables registered on ``Base.metadata``.
        For SQLite the directory of the database file is created first.
        """
        if self.engine.dialect.name == "sqlite":
            from cognee.infrastructure.files.storage import LocalStorage

            db_directory = path.dirname(self.db_path)
            LocalStorage.ensure_directory_exists(db_directory)

        async with self.engine.begin() as connection:
            # Nothing to create when no model has been registered
            if Base.metadata.tables:
                await connection.run_sync(Base.metadata.create_all)

    async def delete_database(self):
        """
        Delete the database content.

        SQLite: the database file is removed. Other dialects: every table of every
        schema returned by ``get_schema_list`` is dropped. Errors are reported on
        stdout and not re-raised.
        """
        try:
            if self.engine.dialect.name == "sqlite":
                from cognee.infrastructure.files.storage import LocalStorage

                LocalStorage.remove(self.db_path)
            else:
                async with self.engine.begin() as connection:
                    schema_list = await self.get_schema_list()
                    # Create a MetaData instance to load table information
                    metadata = MetaData()
                    # Drop all tables from all schemas
                    for schema_name in schema_list:
                        # Load the schema information into the MetaData object
                        await connection.run_sync(metadata.reflect, schema=schema_name)
                        for table in metadata.sorted_tables:
                            drop_table_query = text(
                                f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE"
                            )
                            await connection.execute(drop_table_query)
                        metadata.clear()
        except Exception as e:
            print(f"Error deleting database: {e}")
        else:
            # Only report success when nothing went wrong above
            print("Database deleted successfully.")