From 7d3e124e4f3149a2175d0004eae37ddc826cd372 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Mon, 5 Aug 2024 10:31:52 +0200 Subject: [PATCH] Fixes to the sqlalchemy adapter --- cognee/api/v1/cognify/cognify.py | 2 +- cognee/api/v1/cognify/cognify_v2.py | 2 +- .../sqlalchemy/SqlAlchemyAdapter.py | 82 ++++++++++--------- .../modules/tasks/create_task_status_table.py | 10 +-- 4 files changed, 49 insertions(+), 47 deletions(-) diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 77f19f1ec..e695be405 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -45,7 +45,7 @@ async def cognify(datasets: Union[str, List[str]] = None): # Has to be loaded in advance, multithreading doesn't work without it. nltk.download("stopwords", quiet=True) stopwords.ensure_loaded() - create_task_status_table() + await create_task_status_table() graph_client = await get_graph_engine() diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 6b2502c46..7c6139b3c 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -35,7 +35,7 @@ class PermissionDeniedException(Exception): async def cognify(datasets: Union[str, list[str]] = None, user: User = None): db_engine = get_relational_engine() - create_task_status_table() + await create_task_status_table() if datasets is None or len(datasets) == 0: return await cognify(await db_engine.get_datasets()) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 7de1ca785..5b5b1fcbe 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -43,39 +43,39 @@ class SQLAlchemyAdapter(): async def get_datasets(self): async with self.engine.connect() as connection: - result = await connection.execute(text("SELECT DISTINCT schema_name FROM information_schema.tables;")) - tables = [row["schema_name"] for row in result] + result = await connection.execute(text("SELECT DISTINCT table_schema FROM information_schema.tables;")) + tables = [row[0] for row in result] return list( filter( - lambda schema_name: not schema_name.endswith("staging") and schema_name != "cognee", + lambda table_schema: not table_schema.endswith("staging") and table_schema != "cognee", tables ) ) - def get_files_metadata(self, dataset_name: str): - with self.engine.connect() as connection: - result = connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;")) - return [dict(row) for row in result] - async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]): fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config] - async with self.engine.connect() as connection: + async with self.engine.begin() as connection: await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};")) await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});")) - def delete_table(self, table_name: str): - with self.engine.connect() as connection: - connection.execute(text(f"DROP TABLE IF EXISTS {table_name};")) + async def get_files_metadata(self, dataset_name: str): + async with self.engine.connect() as connection: + result = await connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;")) + return [dict(row) for row in result] - def insert_data(self, schema_name: str, table_name: str, data: list[dict]): + async def delete_table(self, table_name: str): + async with self.engine.connect() as connection: + await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};")) + + async def insert_data(self, schema_name: str, table_name: str, data: list[dict]): columns = ", ".join(data[0].keys()) values = ", ".join([f"({', '.join([f':{key}' for key in row.keys()])})" for row in data]) insert_query = text(f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES {values};") - with self.engine.connect() as connection: - connection.execute(insert_query, data) + async with self.engine.connect() as connection: + await connection.execute(insert_query, data) - def get_data(self, table_name: str, filters: dict = None): - with self.engine.connect() as connection: + async def get_data(self, table_name: str, filters: dict = None): + async with self.engine.connect() as connection: query = f"SELECT * FROM {table_name}" if filters: filter_conditions = " AND ".join([ @@ -84,19 +84,19 @@ class SQLAlchemyAdapter(): ]) query += f" WHERE {filter_conditions};" query = text(query) - results = connection.execute(query, filters) + results = await connection.execute(query, filters) else: query += ";" query = text(query) - results = connection.execute(query) + results = await connection.execute(query) return {result["data_id"]: result["status"] for result in results} - def execute_query(self, query): - with self.engine.connect() as connection: - result = connection.execute(text(query)) + async def execute_query(self, query): + async with self.engine.connect() as connection: + result = await connection.execute(text(query)) return [dict(row) for row in result] - def load_cognify_data(self, data): + async def load_cognify_data(self, data): metadata = MetaData() cognify_table = Table( @@ -109,21 +109,22 @@ class SQLAlchemyAdapter(): Column("document_id_target", String, nullable=True) ) - metadata.create_all(self.engine) + async with self.engine.begin() as connection: + await connection.run_sync(metadata.create_all) insert_query = cognify_table.insert().values(document_id=text(":document_id")) - with self.engine.connect() as connection: - connection.execute(insert_query, data) + async with self.engine.connect() as connection: + await connection.execute(insert_query, data) - def fetch_cognify_data(self, excluded_document_id: str): - with self.engine.connect() as connection: - connection.execute(text(""" + async def fetch_cognify_data(self, excluded_document_id: str): + async with self.engine.connect() as connection: + await connection.execute(text(""" CREATE TABLE IF NOT EXISTS cognify ( - document_id STRING, + document_id VARCHAR, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT NULL, processed BOOLEAN DEFAULT FALSE, - document_id_target STRING NULL + document_id_target VARCHAR NULL ); """)) query = text(""" @@ -131,27 +132,28 @@ class SQLAlchemyAdapter(): FROM cognify WHERE document_id != :excluded_document_id AND processed = FALSE; """) - records = connection.execute(query, {"excluded_document_id": excluded_document_id}).fetchall() + records = await connection.execute(query, {"excluded_document_id": excluded_document_id}) + records = records.fetchall() if records: document_ids = tuple(record["document_id"] for record in records) update_query = text("UPDATE cognify SET processed = TRUE WHERE document_id IN :document_ids;") - connection.execute(update_query, {"document_ids": document_ids}) + await connection.execute(update_query, {"document_ids": document_ids}) return [dict(record) for record in records] - def delete_cognify_data(self): - with self.engine.connect() as connection: - connection.execute(text(""" + async def delete_cognify_data(self): + async with self.engine.connect() as connection: + await connection.execute(text(""" CREATE TABLE IF NOT EXISTS cognify ( - document_id STRING, + document_id VARCHAR, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT NULL, processed BOOLEAN DEFAULT FALSE, - document_id_target STRING NULL + document_id_target VARCHAR NULL ); """)) - connection.execute(text("DELETE FROM cognify;")) - connection.execute(text("DROP TABLE cognify;")) + await connection.execute(text("DELETE FROM cognify;")) + await connection.execute(text("DROP TABLE cognify;")) async def drop_tables(self, connection): try: diff --git a/cognee/modules/tasks/create_task_status_table.py b/cognee/modules/tasks/create_task_status_table.py index 42b6a0333..48763600d 100644 --- a/cognee/modules/tasks/create_task_status_table.py +++ b/cognee/modules/tasks/create_task_status_table.py @@ -1,10 +1,10 @@ from cognee.infrastructure.databases.relational import get_relational_engine -def create_task_status_table(): +async def create_task_status_table(): db_engine = get_relational_engine() - db_engine.create_table("cognee.cognee", "cognee_task_status", [ - dict(name = "data_id", type = "STRING"), - dict(name = "status", type = "STRING"), - dict(name = "created_at", type = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), + await db_engine.create_table("cognee", "cognee_task_status", [ + dict(name="data_id", type="VARCHAR"), + dict(name="status", type="VARCHAR"), + dict(name="created_at", type="TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), ])