Merge pull request #391 from topoteretes/COG-475-local-file-endpoint-deletion
Cog 475 local file endpoint deletion
This commit is contained in:
commit
51559f055b
4 changed files with 93 additions and 15 deletions
|
|
@ -1,15 +1,23 @@
|
|||
import os
|
||||
from os import path
|
||||
import logging
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, List
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy import text, select, MetaData, Table
|
||||
from sqlalchemy import text, select, MetaData, Table, delete
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.modules.data.models.Data import Data
|
||||
|
||||
from ..ModelBase import Base
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLAlchemyAdapter():
|
||||
def __init__(self, connection_string: str):
|
||||
self.db_path: str = None
|
||||
|
|
@ -86,9 +94,9 @@ class SQLAlchemyAdapter():
|
|||
return [schema[0] for schema in result.fetchall()]
|
||||
return []
|
||||
|
||||
async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
|
||||
async def delete_entity_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
|
||||
"""
|
||||
Delete data in given table based on id. Table must have an id Column.
|
||||
Delete entity in given table based on id. Table must have an id Column.
|
||||
"""
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
async with self.get_async_session() as session:
|
||||
|
|
@ -107,6 +115,42 @@ class SQLAlchemyAdapter():
|
|||
await session.commit()
|
||||
|
||||
|
||||
async def delete_data_entity(self, data_id: UUID):
|
||||
"""
|
||||
Delete data and local files related to data if there are no references to it anymore.
|
||||
"""
|
||||
async with self.get_async_session() as session:
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
# Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
|
||||
# so must be enabled for each database connection/session separately.
|
||||
await session.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
|
||||
try:
|
||||
data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
|
||||
except (ValueError, NoResultFound) as e:
|
||||
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
|
||||
|
||||
# Check if other data objects point to the same raw data location
|
||||
raw_data_location_entities = (await session.execute(
|
||||
select(Data.raw_data_location).where(Data.raw_data_location == data_entity.raw_data_location))).all()
|
||||
|
||||
# Don't delete local file unless this is the only reference to the file in the database
|
||||
if len(raw_data_location_entities) == 1:
|
||||
|
||||
# delete local file only if it's created by cognee
|
||||
from cognee.base_config import get_base_config
|
||||
config = get_base_config()
|
||||
|
||||
if config.data_root_directory in raw_data_location_entities[0].raw_data_location:
|
||||
if os.path.exists(raw_data_location_entities[0].raw_data_location):
|
||||
os.remove(raw_data_location_entities[0].raw_data_location)
|
||||
else:
|
||||
# Report bug as file should exist
|
||||
logger.error("Local file which should exist can't be found.")
|
||||
|
||||
await session.execute(delete(Data).where(Data.id == data_id))
|
||||
await session.commit()
|
||||
|
||||
async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table:
|
||||
"""
|
||||
Dynamically loads a table using the given table name and schema name.
|
||||
|
|
|
|||
|
|
@ -17,4 +17,4 @@ async def delete_data(data: Data):
|
|||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
return await db_engine.delete_data_by_id(data.__tablename__, data.id)
|
||||
return await db_engine.delete_data_entity(data.id)
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@ from cognee.infrastructure.databases.relational import get_relational_engine
|
|||
async def delete_dataset(dataset: Dataset):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
return await db_engine.delete_data_by_id(dataset.__tablename__, dataset.id)
|
||||
return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,53 @@ import os
|
|||
import logging
|
||||
import pathlib
|
||||
import cognee
|
||||
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
async def test_local_file_deletion(data_text, file_location):
|
||||
from sqlalchemy import select
|
||||
import hashlib
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
engine = get_relational_engine()
|
||||
|
||||
async with engine.get_async_session() as session:
|
||||
# Get hash of data contents
|
||||
encoded_text = data_text.encode("utf-8")
|
||||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||
# Get data entry from database based on hash contents
|
||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||
assert os.path.isfile(data.raw_data_location), f"Data location doesn't exist: {data.raw_data_location}"
|
||||
# Test deletion of data along with local files created by cognee
|
||||
await engine.delete_data_entity(data.id)
|
||||
assert not os.path.exists(
|
||||
data.raw_data_location), f"Data location still exists after deletion: {data.raw_data_location}"
|
||||
|
||||
async with engine.get_async_session() as session:
|
||||
# Get data entry from database based on file path
|
||||
data = (await session.scalars(select(Data).where(Data.raw_data_location == file_location))).one()
|
||||
assert os.path.isfile(data.raw_data_location), f"Data location doesn't exist: {data.raw_data_location}"
|
||||
# Test local files not created by cognee won't get deleted
|
||||
await engine.delete_data_entity(data.id)
|
||||
assert os.path.exists(data.raw_data_location), f"Data location doesn't exists: {data.raw_data_location}"
|
||||
|
||||
async def test_getting_of_documents(dataset_name_1):
|
||||
# Test getting of documents for search per dataset
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
user = await get_default_user()
|
||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||
assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||
|
||||
# Test getting of documents for search when no dataset is provided
|
||||
user = await get_default_user()
|
||||
document_ids = await get_document_ids_for_user(user.id)
|
||||
assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||
|
||||
|
||||
async def main():
|
||||
cognee.config.set_vector_db_config(
|
||||
|
|
@ -67,16 +108,7 @@ async def main():
|
|||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
# Test getting of documents for search per dataset
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
user = await get_default_user()
|
||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||
assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||
|
||||
# Test getting of documents for search when no dataset is provided
|
||||
user = await get_default_user()
|
||||
document_ids = await get_document_ids_for_user(user.id)
|
||||
assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||
await test_getting_of_documents(dataset_name_1)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
|
|
@ -106,6 +138,8 @@ async def main():
|
|||
results = await brute_force_triplet_search('What is a quantum computer?')
|
||||
assert len(results) > 0
|
||||
|
||||
await test_local_file_deletion(text, explanation_file_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue