From b60f2603f4632c1d678d5677a07ba8c8eab4ddc7 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 20 Nov 2024 17:11:23 +0100 Subject: [PATCH] test: Add test for pgvector to confirm database deletion is working Added assert to verify all tables in database have been cleared. Added method to SqlAlchemyAdapter to get all table names in database. Test COG-488 --- .../sqlalchemy/SqlAlchemyAdapter.py | 23 +++++++++++++++++++ cognee/tests/test_pgvector.py | 6 ++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index febfe1931..aa2a022d3 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -130,6 +130,29 @@ class SQLAlchemyAdapter(): return metadata.tables[full_table_name] raise ValueError(f"Table '{full_table_name}' not found.") + async def get_table_names(self) -> List[str]: + """ + Return a list of all tables names in database + """ + table_names = [] + async with self.engine.begin() as connection: + if self.engine.dialect.name == "sqlite": + await connection.run_sync(Base.metadata.reflect) + for table in Base.metadata.tables: + table_names.append(str(table)) + else: + schema_list = await self.get_schema_list() + # Create a MetaData instance to load table information + metadata = MetaData() + # Drop all tables from all schemas + for schema_name in schema_list: + # Load the schema information into the MetaData object + await connection.run_sync(metadata.reflect, schema=schema_name) + for table in metadata.sorted_tables: + table_names.append(str(table)) + metadata.clear() + return table_names + async def get_data(self, table_name: str, filters: dict = None): async with self.engine.begin() as connection: diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 1466e195f..18d9e2c91 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -87,9 +87,13 @@ async def main(): print(f"{result}\n") history = await cognee.get_search_history() - assert len(history) == 6, "Search history is not correct." + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + tables_in_database = await vector_engine.get_table_names() + assert len(tables_in_database) == 0, "The database is not empty" + if __name__ == "__main__": import asyncio