Merge pull request #391 from topoteretes/COG-475-local-file-endpoint-deletion
COG-475: local file endpoint deletion
This commit is contained in:
commit
51559f055b
4 changed files with 93 additions and 15 deletions
|
|
@ -1,15 +1,23 @@
|
||||||
|
import os
|
||||||
from os import path
|
from os import path
|
||||||
|
import logging
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import AsyncGenerator, List
|
from typing import AsyncGenerator, List
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from sqlalchemy import text, select, MetaData, Table
|
from sqlalchemy import text, select, MetaData, Table, delete
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||||
|
from cognee.modules.data.models.Data import Data
|
||||||
|
|
||||||
from ..ModelBase import Base
|
from ..ModelBase import Base
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class SQLAlchemyAdapter():
|
class SQLAlchemyAdapter():
|
||||||
def __init__(self, connection_string: str):
|
def __init__(self, connection_string: str):
|
||||||
self.db_path: str = None
|
self.db_path: str = None
|
||||||
|
|
@ -86,9 +94,9 @@ class SQLAlchemyAdapter():
|
||||||
return [schema[0] for schema in result.fetchall()]
|
return [schema[0] for schema in result.fetchall()]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
|
async def delete_entity_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
|
||||||
"""
|
"""
|
||||||
Delete data in given table based on id. Table must have an id Column.
|
Delete entity in given table based on id. Table must have an id Column.
|
||||||
"""
|
"""
|
||||||
if self.engine.dialect.name == "sqlite":
|
if self.engine.dialect.name == "sqlite":
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
|
|
@ -107,6 +115,42 @@ class SQLAlchemyAdapter():
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_data_entity(self, data_id: UUID):
|
||||||
|
"""
|
||||||
|
Delete data and local files related to data if there are no references to it anymore.
|
||||||
|
"""
|
||||||
|
async with self.get_async_session() as session:
|
||||||
|
if self.engine.dialect.name == "sqlite":
|
||||||
|
# Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
|
||||||
|
# so must be enabled for each database connection/session separately.
|
||||||
|
await session.execute(text("PRAGMA foreign_keys = ON;"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
|
||||||
|
except (ValueError, NoResultFound) as e:
|
||||||
|
raise EntityNotFoundError(message=f"Entity not found: {str(e)}")
|
||||||
|
|
||||||
|
# Check if other data objects point to the same raw data location
|
||||||
|
raw_data_location_entities = (await session.execute(
|
||||||
|
select(Data.raw_data_location).where(Data.raw_data_location == data_entity.raw_data_location))).all()
|
||||||
|
|
||||||
|
# Don't delete local file unless this is the only reference to the file in the database
|
||||||
|
if len(raw_data_location_entities) == 1:
|
||||||
|
|
||||||
|
# delete local file only if it's created by cognee
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
config = get_base_config()
|
||||||
|
|
||||||
|
if config.data_root_directory in raw_data_location_entities[0].raw_data_location:
|
||||||
|
if os.path.exists(raw_data_location_entities[0].raw_data_location):
|
||||||
|
os.remove(raw_data_location_entities[0].raw_data_location)
|
||||||
|
else:
|
||||||
|
# Report bug as file should exist
|
||||||
|
logger.error("Local file which should exist can't be found.")
|
||||||
|
|
||||||
|
await session.execute(delete(Data).where(Data.id == data_id))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table:
|
async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table:
|
||||||
"""
|
"""
|
||||||
Dynamically loads a table using the given table name and schema name.
|
Dynamically loads a table using the given table name and schema name.
|
||||||
|
|
|
||||||
|
|
@ -17,4 +17,4 @@ async def delete_data(data: Data):
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
return await db_engine.delete_data_by_id(data.__tablename__, data.id)
|
return await db_engine.delete_data_entity(data.id)
|
||||||
|
|
|
||||||
|
|
@ -4,4 +4,4 @@ from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
async def delete_dataset(dataset: Dataset):
    """Remove the given dataset's row from its relational table, keyed by id."""
    relational_engine = get_relational_engine()
    return await relational_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,53 @@ import os
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import cognee
|
import cognee
|
||||||
|
|
||||||
|
from cognee.modules.data.models import Data
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
async def test_local_file_deletion(data_text, file_location):
    """
    Exercise ``delete_data_entity`` for both file-ownership cases.

    1. A file created by cognee (located via the MD5 content hash) must be
       deleted from disk together with its Data row.
    2. A user-supplied file (located via its path, ``file_location``) must
       survive deletion of its Data row.
    """
    import hashlib
    from sqlalchemy import select
    from cognee.infrastructure.databases.relational import get_relational_engine

    engine = get_relational_engine()

    async with engine.get_async_session() as session:
        # Get hash of data contents, then the matching data entry.
        encoded_text = data_text.encode("utf-8")
        data_hash = hashlib.md5(encoded_text).hexdigest()
        data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
        assert os.path.isfile(data.raw_data_location), f"Data location doesn't exist: {data.raw_data_location}"
        # Test deletion of data along with local files created by cognee.
        await engine.delete_data_entity(data.id)
        assert not os.path.exists(
            data.raw_data_location), f"Data location still exists after deletion: {data.raw_data_location}"

    async with engine.get_async_session() as session:
        # Get data entry from database based on file path.
        data = (await session.scalars(select(Data).where(Data.raw_data_location == file_location))).one()
        assert os.path.isfile(data.raw_data_location), f"Data location doesn't exist: {data.raw_data_location}"
        # Test local files not created by cognee won't get deleted.
        await engine.delete_data_entity(data.id)
        # NOTE: the original message here ("doesn't exists") was misleading —
        # this assert fires when the file WAS deleted even though it should
        # have been preserved.
        assert os.path.exists(
            data.raw_data_location), f"Data location was deleted but should have been preserved: {data.raw_data_location}"
|
||||||
|
|
||||||
|
async def test_getting_of_documents(dataset_name_1):
    """Verify document visibility for search, with and without a dataset filter."""
    from cognee.modules.users.permissions.methods import get_document_ids_for_user

    # Restricting the search to a single dataset should expose one document.
    default_user = await get_default_user()
    document_ids = await get_document_ids_for_user(default_user.id, [dataset_name_1])
    assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"

    # Without a dataset filter, both of the user's documents are visible.
    default_user = await get_default_user()
    document_ids = await get_document_ids_for_user(default_user.id)
    assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
cognee.config.set_vector_db_config(
|
cognee.config.set_vector_db_config(
|
||||||
|
|
@ -67,16 +108,7 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
# Test getting of documents for search per dataset
|
await test_getting_of_documents(dataset_name_1)
|
||||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
|
||||||
user = await get_default_user()
|
|
||||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
|
||||||
assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
|
||||||
|
|
||||||
# Test getting of documents for search when no dataset is provided
|
|
||||||
user = await get_default_user()
|
|
||||||
document_ids = await get_document_ids_for_user(user.id)
|
|
||||||
assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||||
|
|
@ -106,6 +138,8 @@ async def main():
|
||||||
results = await brute_force_triplet_search('What is a quantum computer?')
|
results = await brute_force_triplet_search('What is a quantum computer?')
|
||||||
assert len(results) > 0
|
assert len(results) > 0
|
||||||
|
|
||||||
|
await test_local_file_deletion(text, explanation_file_path)
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue