Fixes to the SQLAlchemy adapter
This commit is contained in:
parent
085ca5ece8
commit
7d3e124e4f
4 changed files with 49 additions and 47 deletions
|
|
@@ -45,7 +45,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
# Has to be loaded in advance, multithreading doesn't work without it.
|
||||
nltk.download("stopwords", quiet=True)
|
||||
stopwords.ensure_loaded()
|
||||
create_task_status_table()
|
||||
await create_task_status_table()
|
||||
|
||||
graph_client = await get_graph_engine()
|
||||
|
||||
|
|
|
|||
|
|
@@ -35,7 +35,7 @@ class PermissionDeniedException(Exception):
|
|||
|
||||
async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
||||
db_engine = get_relational_engine()
|
||||
create_task_status_table()
|
||||
await create_task_status_table()
|
||||
|
||||
if datasets is None or len(datasets) == 0:
|
||||
return await cognify(await db_engine.get_datasets())
|
||||
|
|
|
|||
|
|
@@ -43,39 +43,39 @@ class SQLAlchemyAdapter():
|
|||
|
||||
async def get_datasets(self):
|
||||
async with self.engine.connect() as connection:
|
||||
result = await connection.execute(text("SELECT DISTINCT schema_name FROM information_schema.tables;"))
|
||||
tables = [row["schema_name"] for row in result]
|
||||
result = await connection.execute(text("SELECT DISTINCT table_schema FROM information_schema.tables;"))
|
||||
tables = [row[0] for row in result]
|
||||
return list(
|
||||
filter(
|
||||
lambda schema_name: not schema_name.endswith("staging") and schema_name != "cognee",
|
||||
lambda table_schema: not table_schema.endswith("staging") and table_schema != "cognee",
|
||||
tables
|
||||
)
|
||||
)
|
||||
|
||||
def get_files_metadata(self, dataset_name: str):
|
||||
with self.engine.connect() as connection:
|
||||
result = connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
|
||||
return [dict(row) for row in result]
|
||||
|
||||
async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
|
||||
fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
|
||||
async with self.engine.connect() as connection:
|
||||
async with self.engine.begin() as connection:
|
||||
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
|
||||
await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))
|
||||
|
||||
def delete_table(self, table_name: str):
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
|
||||
async def get_files_metadata(self, dataset_name: str):
|
||||
async with self.engine.connect() as connection:
|
||||
result = await connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
|
||||
return [dict(row) for row in result]
|
||||
|
||||
def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
|
||||
async def delete_table(self, table_name: str):
|
||||
async with self.engine.connect() as connection:
|
||||
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
|
||||
|
||||
async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
|
||||
columns = ", ".join(data[0].keys())
|
||||
values = ", ".join([f"({', '.join([f':{key}' for key in row.keys()])})" for row in data])
|
||||
insert_query = text(f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES {values};")
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(insert_query, data)
|
||||
async with self.engine.connect() as connection:
|
||||
await connection.execute(insert_query, data)
|
||||
|
||||
def get_data(self, table_name: str, filters: dict = None):
|
||||
with self.engine.connect() as connection:
|
||||
async def get_data(self, table_name: str, filters: dict = None):
|
||||
async with self.engine.connect() as connection:
|
||||
query = f"SELECT * FROM {table_name}"
|
||||
if filters:
|
||||
filter_conditions = " AND ".join([
|
||||
|
|
@@ -84,19 +84,19 @@ class SQLAlchemyAdapter():
|
|||
])
|
||||
query += f" WHERE {filter_conditions};"
|
||||
query = text(query)
|
||||
results = connection.execute(query, filters)
|
||||
results = await connection.execute(query, filters)
|
||||
else:
|
||||
query += ";"
|
||||
query = text(query)
|
||||
results = connection.execute(query)
|
||||
results = await connection.execute(query)
|
||||
return {result["data_id"]: result["status"] for result in results}
|
||||
|
||||
def execute_query(self, query):
|
||||
with self.engine.connect() as connection:
|
||||
result = connection.execute(text(query))
|
||||
async def execute_query(self, query):
|
||||
async with self.engine.connect() as connection:
|
||||
result = await connection.execute(text(query))
|
||||
return [dict(row) for row in result]
|
||||
|
||||
def load_cognify_data(self, data):
|
||||
async def load_cognify_data(self, data):
|
||||
metadata = MetaData()
|
||||
|
||||
cognify_table = Table(
|
||||
|
|
@@ -109,21 +109,22 @@ class SQLAlchemyAdapter():
|
|||
Column("document_id_target", String, nullable=True)
|
||||
)
|
||||
|
||||
metadata.create_all(self.engine)
|
||||
async with self.engine.begin() as connection:
|
||||
await connection.run_sync(metadata.create_all)
|
||||
|
||||
insert_query = cognify_table.insert().values(document_id=text(":document_id"))
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(insert_query, data)
|
||||
async with self.engine.connect() as connection:
|
||||
await connection.execute(insert_query, data)
|
||||
|
||||
def fetch_cognify_data(self, excluded_document_id: str):
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(text("""
|
||||
async def fetch_cognify_data(self, excluded_document_id: str):
|
||||
async with self.engine.connect() as connection:
|
||||
await connection.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS cognify (
|
||||
document_id STRING,
|
||||
document_id VARCHAR,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT NULL,
|
||||
processed BOOLEAN DEFAULT FALSE,
|
||||
document_id_target STRING NULL
|
||||
document_id_target VARCHAR NULL
|
||||
);
|
||||
"""))
|
||||
query = text("""
|
||||
|
|
@@ -131,27 +132,28 @@ class SQLAlchemyAdapter():
|
|||
FROM cognify
|
||||
WHERE document_id != :excluded_document_id AND processed = FALSE;
|
||||
""")
|
||||
records = connection.execute(query, {"excluded_document_id": excluded_document_id}).fetchall()
|
||||
records = await connection.execute(query, {"excluded_document_id": excluded_document_id})
|
||||
records = records.fetchall()
|
||||
|
||||
if records:
|
||||
document_ids = tuple(record["document_id"] for record in records)
|
||||
update_query = text("UPDATE cognify SET processed = TRUE WHERE document_id IN :document_ids;")
|
||||
connection.execute(update_query, {"document_ids": document_ids})
|
||||
await connection.execute(update_query, {"document_ids": document_ids})
|
||||
return [dict(record) for record in records]
|
||||
|
||||
def delete_cognify_data(self):
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(text("""
|
||||
async def delete_cognify_data(self):
|
||||
async with self.engine.connect() as connection:
|
||||
await connection.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS cognify (
|
||||
document_id STRING,
|
||||
document_id VARCHAR,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT NULL,
|
||||
processed BOOLEAN DEFAULT FALSE,
|
||||
document_id_target STRING NULL
|
||||
document_id_target VARCHAR NULL
|
||||
);
|
||||
"""))
|
||||
connection.execute(text("DELETE FROM cognify;"))
|
||||
connection.execute(text("DROP TABLE cognify;"))
|
||||
await connection.execute(text("DELETE FROM cognify;"))
|
||||
await connection.execute(text("DROP TABLE cognify;"))
|
||||
|
||||
async def drop_tables(self, connection):
|
||||
try:
|
||||
|
|
|
|||
|
|
@@ -1,10 +1,10 @@
|
|||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
def create_task_status_table():
|
||||
async def create_task_status_table():
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
db_engine.create_table("cognee.cognee", "cognee_task_status", [
|
||||
dict(name = "data_id", type = "STRING"),
|
||||
dict(name = "status", type = "STRING"),
|
||||
dict(name = "created_at", type = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
|
||||
await db_engine.create_table("cognee", "cognee_task_status", [
|
||||
dict(name="data_id", type="VARCHAR"),
|
||||
dict(name="status", type="VARCHAR"),
|
||||
dict(name="created_at", type="TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
|
||||
])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue