Fixes to the sqlalchemy adapter

This commit is contained in:
Vasilije 2024-08-05 10:31:52 +02:00
parent 085ca5ece8
commit 7d3e124e4f
4 changed files with 49 additions and 47 deletions

View file

@@ -45,7 +45,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
     # Has to be loaded in advance, multithreading doesn't work without it.
     nltk.download("stopwords", quiet=True)
     stopwords.ensure_loaded()
-    create_task_status_table()
+    await create_task_status_table()
 
     graph_client = await get_graph_engine()

View file

@@ -35,7 +35,7 @@ class PermissionDeniedException(Exception):
 async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
     db_engine = get_relational_engine()
 
-    create_task_status_table()
+    await create_task_status_table()
 
     if datasets is None or len(datasets) == 0:
         return await cognify(await db_engine.get_datasets())

View file

@@ -43,39 +43,39 @@ class SQLAlchemyAdapter():
     async def get_datasets(self):
         async with self.engine.connect() as connection:
-            result = await connection.execute(text("SELECT DISTINCT schema_name FROM information_schema.tables;"))
-            tables = [row["schema_name"] for row in result]
+            result = await connection.execute(text("SELECT DISTINCT table_schema FROM information_schema.tables;"))
+            tables = [row[0] for row in result]
             return list(
                 filter(
-                    lambda schema_name: not schema_name.endswith("staging") and schema_name != "cognee",
+                    lambda table_schema: not table_schema.endswith("staging") and table_schema != "cognee",
                     tables
                 )
             )
 
-    def get_files_metadata(self, dataset_name: str):
-        with self.engine.connect() as connection:
-            result = connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
-            return [dict(row) for row in result]
-
     async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
         fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
-        async with self.engine.connect() as connection:
+        async with self.engine.begin() as connection:
             await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
             await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))
 
-    def delete_table(self, table_name: str):
-        with self.engine.connect() as connection:
-            connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
+    async def get_files_metadata(self, dataset_name: str):
+        async with self.engine.connect() as connection:
+            result = await connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
+            return [dict(row) for row in result]
 
-    def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
+    async def delete_table(self, table_name: str):
+        async with self.engine.connect() as connection:
+            await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
+
+    async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
         columns = ", ".join(data[0].keys())
         values = ", ".join([f"({', '.join([f':{key}' for key in row.keys()])})" for row in data])
         insert_query = text(f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES {values};")
-        with self.engine.connect() as connection:
-            connection.execute(insert_query, data)
+        async with self.engine.connect() as connection:
+            await connection.execute(insert_query, data)
 
-    def get_data(self, table_name: str, filters: dict = None):
-        with self.engine.connect() as connection:
+    async def get_data(self, table_name: str, filters: dict = None):
+        async with self.engine.connect() as connection:
             query = f"SELECT * FROM {table_name}"
             if filters:
                 filter_conditions = " AND ".join([
@@ -84,19 +84,19 @@ class SQLAlchemyAdapter():
                 ])
                 query += f" WHERE {filter_conditions};"
                 query = text(query)
-                results = connection.execute(query, filters)
+                results = await connection.execute(query, filters)
             else:
                 query += ";"
                 query = text(query)
-                results = connection.execute(query)
+                results = await connection.execute(query)
             return {result["data_id"]: result["status"] for result in results}
 
-    def execute_query(self, query):
-        with self.engine.connect() as connection:
-            result = connection.execute(text(query))
+    async def execute_query(self, query):
+        async with self.engine.connect() as connection:
+            result = await connection.execute(text(query))
             return [dict(row) for row in result]
 
-    def load_cognify_data(self, data):
+    async def load_cognify_data(self, data):
         metadata = MetaData()
         cognify_table = Table(
@@ -109,21 +109,22 @@ class SQLAlchemyAdapter():
             Column("document_id_target", String, nullable=True)
         )
-        metadata.create_all(self.engine)
+        async with self.engine.begin() as connection:
+            await connection.run_sync(metadata.create_all)
 
         insert_query = cognify_table.insert().values(document_id=text(":document_id"))
-        with self.engine.connect() as connection:
-            connection.execute(insert_query, data)
+        async with self.engine.connect() as connection:
+            await connection.execute(insert_query, data)
 
-    def fetch_cognify_data(self, excluded_document_id: str):
-        with self.engine.connect() as connection:
-            connection.execute(text("""
+    async def fetch_cognify_data(self, excluded_document_id: str):
+        async with self.engine.connect() as connection:
+            await connection.execute(text("""
                 CREATE TABLE IF NOT EXISTS cognify (
-                    document_id STRING,
+                    document_id VARCHAR,
                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updated_at TIMESTAMP DEFAULT NULL,
                     processed BOOLEAN DEFAULT FALSE,
-                    document_id_target STRING NULL
+                    document_id_target VARCHAR NULL
                 );
             """))
 
             query = text("""
@@ -131,27 +132,28 @@ class SQLAlchemyAdapter():
                 FROM cognify
                 WHERE document_id != :excluded_document_id AND processed = FALSE;
             """)
-            records = connection.execute(query, {"excluded_document_id": excluded_document_id}).fetchall()
+            records = await connection.execute(query, {"excluded_document_id": excluded_document_id})
+            records = records.fetchall()
             if records:
                 document_ids = tuple(record["document_id"] for record in records)
                 update_query = text("UPDATE cognify SET processed = TRUE WHERE document_id IN :document_ids;")
-                connection.execute(update_query, {"document_ids": document_ids})
+                await connection.execute(update_query, {"document_ids": document_ids})
             return [dict(record) for record in records]
 
-    def delete_cognify_data(self):
-        with self.engine.connect() as connection:
-            connection.execute(text("""
+    async def delete_cognify_data(self):
+        async with self.engine.connect() as connection:
+            await connection.execute(text("""
                 CREATE TABLE IF NOT EXISTS cognify (
-                    document_id STRING,
+                    document_id VARCHAR,
                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updated_at TIMESTAMP DEFAULT NULL,
                     processed BOOLEAN DEFAULT FALSE,
-                    document_id_target STRING NULL
+                    document_id_target VARCHAR NULL
                 );
             """))
-            connection.execute(text("DELETE FROM cognify;"))
-            connection.execute(text("DROP TABLE cognify;"))
+            await connection.execute(text("DELETE FROM cognify;"))
+            await connection.execute(text("DROP TABLE cognify;"))
 
     async def drop_tables(self, connection):
         try:

View file

@@ -1,10 +1,10 @@
 from cognee.infrastructure.databases.relational import get_relational_engine
 
-def create_task_status_table():
+async def create_task_status_table():
     db_engine = get_relational_engine()
-    db_engine.create_table("cognee.cognee", "cognee_task_status", [
-        dict(name = "data_id", type = "STRING"),
-        dict(name = "status", type = "STRING"),
-        dict(name = "created_at", type = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
+    await db_engine.create_table("cognee", "cognee_task_status", [
+        dict(name="data_id", type="VARCHAR"),
+        dict(name="status", type="VARCHAR"),
+        dict(name="created_at", type="TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
     ])