fix: s3 file storage (#1095)
## Description
Adds S3-backed file storage support: the `.cognee_system` / `.data_storage` suffixes are no longer appended to user-supplied root directories, derived database paths are filled from `system_root_directory` via pydantic validators, an `is_file` method is added across the storage backends, `open_data_file` and `save_data_item_to_storage` handle s3:// paths explicitly, and a dedicated S3 file storage test plus CI workflow are introduced.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent 67c006bd2f
commit c5bd6bed40
24 changed files with 4231 additions and 4056 deletions
.github/workflows/test_s3_file_storage.yml (new file, 36 lines)
@@ -0,0 +1,36 @@
name: test | s3 file storage

on:
  workflow_call:

permissions:
  contents: read

jobs:
  test-gemini:
    name: Run S3 File Storage Test
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Run S3 File Storage Test
        env:
          STORAGE_BACKEND: s3
          ENABLE_BACKEND_ACCESS_CONTROL: True
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: poetry run python ./cognee/tests/test_s3_file_storage.py
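Reviewer note: to reproduce this job locally one could mirror the same environment before invoking the test script. A rough sketch only, not part of the PR; every credential value below is a placeholder you must supply yourself:

```python
# Hypothetical local runner mirroring the CI job above; all values are placeholders.
import os
import subprocess

env = dict(
    os.environ,
    STORAGE_BACKEND="s3",
    ENABLE_BACKEND_ACCESS_CONTROL="True",
    AWS_ACCESS_KEY_ID="<key-id>",
    AWS_SECRET_ACCESS_KEY="<secret-key>",
    LLM_API_KEY="<llm-key>",
    EMBEDDING_API_KEY="<embedding-key>",
)

# Same entry point the workflow's final step runs.
subprocess.run(
    ["poetry", "run", "python", "./cognee/tests/test_s3_file_storage.py"],
    env=env,
    check=True,
)
```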
.github/workflows/test_suites.yml (6 lines changed)
@@ -97,6 +97,12 @@ jobs:
     uses: ./.github/workflows/db_examples_tests.yml
     secrets: inherit
 
+  s3-file-storage-test:
+    name: S3 File Storage Test
+    needs: [basic-tests, e2e-tests]
+    uses: ./.github/workflows/test_s3_file_storage.yml
+    secrets: inherit
+
   # Additional LLM tests
   gemini-tests:
     name: Gemini Tests
@@ -15,7 +15,7 @@ class config:
     @staticmethod
     def system_root_directory(system_root_directory: str):
         base_config = get_base_config()
-        base_config.system_root_directory = os.path.join(system_root_directory, ".cognee_system")
+        base_config.system_root_directory = system_root_directory
 
         databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
@@ -24,13 +24,7 @@ class config:
 
         graph_config = get_graph_config()
         graph_file_name = graph_config.graph_filename
-        # For Kuzu v0.11.0+, use single-file database with .kuzu extension
-        if graph_config.graph_database_provider.lower() == "kuzu":
-            graph_config.graph_file_path = os.path.join(
-                databases_directory_path, f"{graph_file_name}.kuzu"
-            )
-        else:
-            graph_config.graph_file_path = os.path.join(databases_directory_path, graph_file_name)
+        graph_config.graph_file_path = os.path.join(databases_directory_path, graph_file_name)
 
         vector_config = get_vectordb_config()
         if vector_config.vector_db_provider == "lancedb":
@@ -39,7 +33,7 @@ class config:
     @staticmethod
     def data_root_directory(data_root_directory: str):
         base_config = get_base_config()
-        base_config.data_root_directory = os.path.join(data_root_directory, ".data_storage")
+        base_config.data_root_directory = data_root_directory
 
     @staticmethod
     def monitoring_tool(monitoring_tool: object):
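Reviewer note: with the hard-coded `.cognee_system` / `.data_storage` suffixes gone, callers pass fully-qualified root directories straight through, which is what makes s3:// roots possible. A minimal sketch of the intended call pattern (bucket and prefix are placeholders; the new test later in this diff does the same thing with a per-run UUID):

```python
import cognee

# Fully-qualified roots are now used as-is, so S3 URLs work directly.
cognee.config.data_root_directory("s3://my-bucket/my-run/data")      # placeholder bucket/prefix
cognee.config.system_root_directory("s3://my-bucket/my-run/system")  # placeholder bucket/prefix
```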
@@ -5,8 +5,8 @@ from uuid import UUID
 from cognee.base_config import get_base_config
 from cognee.infrastructure.databases.utils import get_or_create_dataset_database
-from cognee.modules.users.methods import get_user
+from cognee.infrastructure.files.storage.config import file_storage_config
+from cognee.modules.users.methods import get_user
 
 # Note: ContextVar allows us to use different graph db configurations in Cognee
 # for different async tasks, threads and processes
@@ -46,13 +46,15 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
     data_root_directory = os.path.join(
         base_config.data_root_directory, str(user.tenant_id or user.id)
     )
-    system_directory_path = os.path.join(
+    databases_directory_path = os.path.join(
         base_config.system_root_directory, "databases", str(user.id)
     )
 
     # Set vector and graph database configuration based on dataset database information
     vector_config = {
-        "vector_db_url": os.path.join(system_directory_path, dataset_database.vector_database_name),
+        "vector_db_url": os.path.join(
+            databases_directory_path, dataset_database.vector_database_name
+        ),
         "vector_db_key": "",
         "vector_db_provider": "lancedb",
     }
@@ -60,7 +62,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
     graph_config = {
         "graph_database_provider": "kuzu",
         "graph_file_path": os.path.join(
-            system_directory_path, f"{dataset_database.graph_database_name}.kuzu"
+            databases_directory_path, dataset_database.graph_database_name
         ),
     }
@@ -50,19 +50,18 @@ class GraphConfig(BaseSettings):
     @pydantic.model_validator(mode="after")
     def fill_derived(cls, values):
         provider = values.graph_database_provider.lower()
-        # Set filename based on graph database provider if no filename is provided
+
+        # Set default filename if no filename is provided
         if not values.graph_filename:
             values.graph_filename = f"cognee_graph_{provider}"
 
         # Set file path based on graph database provider if no file path is provided
         if not values.graph_file_path:
             base_config = get_base_config()
 
-            base = os.path.join(base_config.system_root_directory, "databases")
-            # For Kuzu v0.11.0+, use single-file database with .kuzu extension
-            if provider == "kuzu":
-                values.graph_file_path = os.path.join(base, f"{values.graph_filename}.kuzu")
-            else:
-                values.graph_file_path = os.path.join(base, values.graph_filename)
+            databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
+            values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
 
         return values
 
     def to_dict(self) -> dict:
@@ -1452,9 +1452,13 @@ class KuzuAdapter(GraphDBInterface):
             db_name = os.path.basename(self.db_path)
             file_storage = get_file_storage(db_dir)
 
-            if await file_storage.file_exists(db_name):
-                await file_storage.remove_all()
-                logger.info(f"Deleted Kuzu database files at {self.db_path}")
+            if await file_storage.is_file(db_name):
+                await file_storage.remove(db_name)
+                await file_storage.remove(f"{db_name}.lock")
+            else:
+                await file_storage.remove_all(db_name)
+
+            logger.info(f"Deleted Kuzu database files at {self.db_path}")
 
         except Exception as e:
             logger.error(f"Failed to delete graph data: {e}")
@@ -679,11 +679,11 @@ class NetworkXAdapter(GraphDBInterface):
         )  # Assuming self.filename is defined elsewhere and holds the default graph file path
         try:
             file_dir_path = os.path.dirname(file_path)
-            file_path = os.path.basename(file_path)
+            file_name = os.path.basename(file_path)
 
             file_storage = get_file_storage(file_dir_path)
 
-            await file_storage.remove(file_path)
+            await file_storage.remove(file_name)
 
             self.graph = None
             logger.info("Graph deleted successfully.")
@@ -1,8 +1,10 @@
 import os
+import pydantic
 from typing import Union
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict
-from cognee.root_dir import get_absolute_path
+
+from cognee.base_config import get_base_config
 
 
 class RelationalConfig(BaseSettings):
@@ -10,7 +12,7 @@ class RelationalConfig(BaseSettings):
     Configure database connection settings.
     """
 
-    db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
+    db_path: str = ""
     db_name: str = "cognee_db"
     db_host: Union[str, None] = None  # "localhost"
     db_port: Union[str, None] = None  # "5432"
@@ -20,6 +22,16 @@ class RelationalConfig(BaseSettings):
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    @pydantic.model_validator(mode="after")
+    def fill_derived(cls, values):
+        # Set file path based on graph database provider if no file path is provided
+        if not values.db_path:
+            base_config = get_base_config()
+            databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
+            values.db_path = databases_directory_path
+
+        return values
+
     def to_dict(self) -> dict:
         """
         Return the database configuration as a dictionary.
@@ -496,9 +496,9 @@ class SQLAlchemyAdapter:
             # Wait for the database connections to close and release the file (Windows)
             await asyncio.sleep(2)
             db_directory = path.dirname(self.db_path)
-            file_path = path.basename(self.db_path)
+            file_name = path.basename(self.db_path)
             file_storage = get_file_storage(db_directory)
-            await file_storage.remove(file_path)
+            await file_storage.remove(file_name)
         else:
             async with self.engine.begin() as connection:
                 # Create a MetaData instance to load table information
@@ -1,7 +1,9 @@
 import os
+import pydantic
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict
-from cognee.root_dir import get_absolute_path
+
+from cognee.base_config import get_base_config
 
 
 class VectorConfig(BaseSettings):
@@ -20,15 +22,23 @@ class VectorConfig(BaseSettings):
     - vector_db_provider: The provider for the vector database.
     """
 
-    vector_db_url: str = os.path.join(
-        os.path.join(get_absolute_path(".cognee_system"), "databases"), "cognee.lancedb"
-    )
+    vector_db_url: str = ""
     vector_db_port: int = 1234
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    @pydantic.model_validator(mode="after")
+    def fill_derived(cls, values):
+        # Set file path based on graph database provider if no file path is provided
+        if not values.vector_db_url:
+            base_config = get_base_config()
+            databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
+            values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
+
+        return values
+
     def to_dict(self) -> dict:
         """
         Convert the configuration settings to a dictionary.
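Reviewer note: the same idea is applied to GraphConfig, RelationalConfig, and VectorConfig above: path defaults are left empty and a model validator derives them from `system_root_directory` once the settings object is built. A standalone sketch of the pattern, assuming nothing about cognee itself (class and field names below are illustrative only):

```python
import os
import pydantic
from pydantic_settings import BaseSettings

SYSTEM_ROOT = "/tmp/cognee_system"  # stand-in for base_config.system_root_directory


class ExampleConfig(BaseSettings):
    db_path: str = ""  # left empty on purpose; derived below

    @pydantic.model_validator(mode="after")
    def fill_derived(cls, values):
        # Only derive the path when the caller did not set one explicitly.
        if not values.db_path:
            values.db_path = os.path.join(SYSTEM_ROOT, "databases")
        return values


print(ExampleConfig().db_path)               # /tmp/cognee_system/databases
print(ExampleConfig(db_path="/x").db_path)   # /x (explicit value wins)
```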
@@ -1,6 +1,5 @@
-from ..get_vector_engine import get_vector_engine, get_vectordb_context_config
 from sqlalchemy import text
 from cognee.context_global_variables import vector_db_config as context_vector_db_config
+from ..get_vector_engine import get_vector_engine, get_vectordb_context_config
 
 
 async def create_db_and_tables():
@@ -171,7 +171,25 @@ class LocalFileStorage(Storage):
         return os.path.exists(os.path.join(parsed_storage_path, file_path))
 
-    def ensure_directory_exists(self, directory_path: str = None):
+    def is_file(self, file_path: str):
+        """
+        Check if a specified file is a regular file.
+
+        Parameters:
+        -----------
+
+        - file_path (str): The path of the file to check.
+
+        Returns:
+        --------
+
+        - bool: True if the file is a regular file, otherwise False.
+        """
+        parsed_storage_path = get_parsed_path(self.storage_path)
+
+        return os.path.isfile(os.path.join(parsed_storage_path, file_path))
+
+    def ensure_directory_exists(self, directory_path: str = ""):
         """
         Ensure that the specified directory exists, creating it if necessary.
@@ -182,11 +200,8 @@ class LocalFileStorage(Storage):
 
         - directory_path (str): The path of the directory to check or create.
         """
-        if directory_path is None:
+        if not directory_path.strip():
             directory_path = get_parsed_path(self.storage_path)
-        elif not directory_path or directory_path.strip() == "":
-            # Handle empty string case - use current directory or storage path
-            directory_path = get_parsed_path(self.storage_path) or "."
 
         if not os.path.exists(directory_path):
            os.makedirs(directory_path, exist_ok=True)
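Reviewer note: a quick way to see the new `is_file` next to the existing `file_exists` on the local backend. This sketch only relies on the constructor and the two check methods shown in this diff; the directory is a placeholder:

```python
import os
from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage

root = "/tmp/cognee-demo"  # placeholder local storage root
os.makedirs(root, exist_ok=True)
with open(os.path.join(root, "example.txt"), "w") as file:
    file.write("hello")

storage = LocalFileStorage(root)
print(storage.file_exists("example.txt"))  # True
print(storage.is_file("example.txt"))      # True: a regular file, not a directory
```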
@@ -3,7 +3,7 @@ import s3fs
 from typing import BinaryIO, Union
 from contextlib import asynccontextmanager
 
-from cognee.api.v1.add.config import get_s3_config
+from cognee.infrastructure.files.storage.s3_config import get_s3_config
 from cognee.infrastructure.utils.run_async import run_async
 from cognee.infrastructure.files.storage.FileBufferedReader import FileBufferedReader
 from .storage import Storage
@@ -127,7 +127,25 @@ class S3FileStorage(Storage):
             self.s3.exists, os.path.join(self.storage_path.replace("s3://", ""), file_path)
         )
 
-    async def ensure_directory_exists(self, directory_path: str = None):
+    async def is_file(self, file_path: str) -> bool:
+        """
+        Check if a specified file is a regular file.
+
+        Parameters:
+        -----------
+
+        - file_path (str): The path of the file to check.
+
+        Returns:
+        --------
+
+        - bool: True if the file is a regular file, otherwise False.
+        """
+        return await run_async(
+            self.s3.isfile, os.path.join(self.storage_path.replace("s3://", ""), file_path)
+        )
+
+    async def ensure_directory_exists(self, directory_path: str = ""):
         """
         Ensure that the specified directory exists, creating it if necessary.
@@ -138,7 +156,7 @@ class S3FileStorage(Storage):
 
         - directory_path (str): The path of the directory to check or create.
         """
-        if directory_path is None:
+        if not directory_path.strip():
             directory_path = self.storage_path.replace("s3://", "")
 
         def ensure_directory():
@@ -40,6 +40,12 @@ class StorageManager:
         else:
             return self.storage.file_exists(file_path)
 
+    async def is_file(self, file_path: str):
+        if inspect.iscoroutinefunction(self.storage.is_file):
+            return await self.storage.is_file(file_path)
+        else:
+            return self.storage.is_file(file_path)
+
     async def store(self, file_path: str, data: BinaryIO, overwrite: bool = False) -> str:
         """
         Store data at the specified file path.
@@ -78,7 +84,7 @@ class StorageManager:
         """
         # Check the actual storage type by class name to determine if open() is async or sync
 
-        if self.storage.__class__.__name__ == "S3FileStorage":
+        if self.storage.__class__.__name__ == "S3FileStorage" and file_path.startswith("s3://"):
             # S3FileStorage.open() is async
             async with self.storage.open(file_path, *args, **kwargs) as file:
                 yield file
@@ -87,7 +93,7 @@ class StorageManager:
             with self.storage.open(file_path, *args, **kwargs) as file:
                 yield file
 
-    async def ensure_directory_exists(self, directory_path: str = None):
+    async def ensure_directory_exists(self, directory_path: str = ""):
         """
         Ensure that the specified directory exists, creating it if necessary.
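Reviewer note: StorageManager keeps a single async surface while delegating to either the synchronous LocalFileStorage or the asynchronous S3FileStorage; the `inspect.iscoroutinefunction` check in the new `is_file` is the whole trick. A minimal, self-contained sketch of that dispatch pattern (class names here are illustrative, not the cognee classes):

```python
import asyncio
import inspect


class SyncBackend:
    def is_file(self, path: str) -> bool:
        return path.endswith(".txt")


class AsyncBackend:
    async def is_file(self, path: str) -> bool:
        return path.endswith(".txt")


class Manager:
    def __init__(self, storage):
        self.storage = storage

    async def is_file(self, path: str) -> bool:
        # Await only when the wrapped backend exposes a coroutine function.
        if inspect.iscoroutinefunction(self.storage.is_file):
            return await self.storage.is_file(path)
        return self.storage.is_file(path)


async def demo():
    print(await Manager(SyncBackend()).is_file("a.txt"))   # True
    print(await Manager(AsyncBackend()).is_file("a.bin"))  # False


asyncio.run(demo())
```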
@@ -24,6 +24,22 @@ class Storage(Protocol):
         """
         pass
 
+    def is_file(self, file_path: str) -> bool:
+        """
+        Check if a specified file is a regular file.
+
+        Parameters:
+        -----------
+
+        - file_path (str): The path of the file to check.
+
+        Returns:
+        --------
+
+        - bool: True if the file is a regular file, otherwise False.
+        """
+        pass
+
     def store(self, file_path: str, data: Union[BinaryIO, str], overwrite: bool):
         """
         Store data at the specified file path.
@@ -66,7 +82,7 @@ class Storage(Protocol):
         """
         pass
 
-    def ensure_directory_exists(self, directory_path: str = None):
+    def ensure_directory_exists(self, directory_path: str = ""):
         """
         Ensure that the specified directory exists, creating it if necessary.
@@ -16,11 +16,11 @@ async def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
             normalized_path = os.path.normpath(file_obj)
 
             file_dir_path = path.dirname(normalized_path)
-            file_path = path.basename(normalized_path)
+            file_name = path.basename(normalized_path)
 
             file_storage = get_file_storage(file_dir_path)
 
-            async with file_storage.open(file_path, "rb") as file:
+            async with file_storage.open(file_name, "rb") as file:
                 while True:
                     # Reading is buffered, so we can read smaller chunks.
                     chunk = file.read(h.block_size)
@@ -3,7 +3,8 @@ from os import path
 from urllib.parse import urlparse
 from contextlib import asynccontextmanager
 
-from cognee.infrastructure.files.storage import get_file_storage
+from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
+from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
 
 
 @asynccontextmanager
@@ -11,7 +12,7 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,
     # Check if this is a file URI BEFORE normalizing (which corrupts URIs)
     if file_path.startswith("file://"):
-        # Normalize the file URI for Windows - replace backslashes with forward slashes
-        normalized_file_uri = file_path.replace("\\", "/")
+        normalized_file_uri = os.path.normpath(file_path)
 
         parsed_url = urlparse(normalized_file_uri)
@@ -29,25 +30,39 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,
         file_dir_path = path.dirname(actual_fs_path)
         file_name = path.basename(actual_fs_path)
 
         file_storage = LocalFileStorage(file_dir_path)
 
         with file_storage.open(file_name, mode=mode, encoding=encoding, **kwargs) as file:
             yield file
 
     elif file_path.startswith("s3://"):
         # Handle S3 URLs without normalization (which corrupts them)
         parsed_url = urlparse(file_path)
 
         # For S3, reconstruct the directory path and filename
-        s3_path = parsed_url.path.lstrip("/")  # Remove leading slash
+        normalized_url = (
+            f"s3://{parsed_url.netloc}{os.sep}{os.path.normpath(parsed_url.path).lstrip(os.sep)}"
+        )
 
-        if "/" in s3_path:
-            s3_dir = "/".join(s3_path.split("/")[:-1])
-            s3_filename = s3_path.split("/")[-1]
-        else:
-            s3_dir = ""
-            s3_filename = s3_path
+        s3_dir_path = os.path.dirname(normalized_url)
+        s3_filename = os.path.basename(normalized_url)
+
+        # if "/" in s3_path:
+        #     s3_dir = "/".join(s3_path.split("/")[:-1])
+        #     s3_filename = s3_path.split("/")[-1]
+        # else:
+        #     s3_dir = ""
+        #     s3_filename = s3_path
 
         # Extract filesystem path from S3 URL structure
-        file_dir_path = (
-            f"s3://{parsed_url.netloc}/{s3_dir}" if s3_dir else f"s3://{parsed_url.netloc}"
-        )
-        file_name = s3_filename
+        # file_dir_path = (
+        #     f"s3://{parsed_url.netloc}/{s3_dir}" if s3_dir else f"s3://{parsed_url.netloc}"
+        # )
+        # file_name = s3_filename
 
         file_storage = S3FileStorage(s3_dir_path)
 
         async with file_storage.open(s3_filename, mode=mode, **kwargs) as file:
             yield file
 
     else:
         # Regular file path - normalize separators
@@ -55,11 +70,11 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,
         file_dir_path = path.dirname(normalized_path)
         file_name = path.basename(normalized_path)
 
-        # Validate that we have a proper filename
-        if not file_name or file_name == "." or file_name == "..":
-            raise ValueError(f"Invalid filename extracted: '{file_name}' from path: '{file_path}'")
+        # Validate that we have a proper filename
+        if not file_name or file_name == "." or file_name == "..":
+            raise ValueError(f"Invalid filename extracted: '{file_name}' from path: '{file_path}'")
 
-        file_storage = get_file_storage(file_dir_path)
+        file_storage = LocalFileStorage(file_dir_path)
 
-        async with file_storage.open(file_name, mode=mode, encoding=encoding, **kwargs) as file:
-            yield file
+        with file_storage.open(file_name, mode=mode, encoding=encoding, **kwargs) as file:
+            yield file
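Reviewer note: the S3 branch above now derives the directory and filename with `os.path.dirname` / `os.path.basename` on a rebuilt URL instead of manual string splitting. A small sketch of how that splitting behaves on a POSIX system (where `os.sep` is "/"); the bucket and key are placeholders:

```python
import os
from urllib.parse import urlparse

file_path = "s3://example-bucket/some/prefix/file.pdf"  # placeholder URL
parsed_url = urlparse(file_path)

# Mirrors the new open_data_file logic, assuming os.sep == "/".
normalized_url = (
    f"s3://{parsed_url.netloc}{os.sep}{os.path.normpath(parsed_url.path).lstrip(os.sep)}"
)

print(os.path.dirname(normalized_url))   # s3://example-bucket/some/prefix
print(os.path.basename(normalized_url))  # file.pdf
```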
@@ -1,10 +1,11 @@
 from pypdf import PdfReader
 
-from cognee.shared.logging_utils import get_logger
 from cognee.modules.chunking.Chunker import Chunker
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
 
 from .Document import Document
-from .exceptions.exceptions import PyPdfInternalError
 
 logger = get_logger("PDFDocument")
@@ -14,7 +15,7 @@ class PdfDocument(Document):
 
     async def read(self, chunker_cls: Chunker, max_chunk_size: int):
         async with open_data_file(self.raw_data_location, mode="rb") as stream:
-            logger.info(f"Reading PDF:{self.raw_data_location}")
+            logger.info(f"Reading PDF: {self.raw_data_location}")
 
             try:
                 file = PdfReader(stream, strict=False)
@@ -1,7 +1,7 @@
 import os
-from typing import List, Union, BinaryIO
 from urllib.parse import urlparse
-from cognee.api.v1.add.config import get_s3_config
+from typing import List, Union, BinaryIO
+from cognee.infrastructure.files.storage.s3_config import get_s3_config
 
 
 async def resolve_data_directories(
@@ -1,4 +1,5 @@
 import os
+from urllib.parse import urlparse
 from typing import Union, BinaryIO, Any
 
 from cognee.modules.ingestion.exceptions import IngestionError
@@ -20,17 +21,27 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
         # Dynamic import is used because the llama_index module is optional.
         from .transform_data import get_data_from_llama_index
 
-        file_path = await get_data_from_llama_index(data_item)
+        return await get_data_from_llama_index(data_item)
 
     # data is a file object coming from upload.
-    elif hasattr(data_item, "file"):
-        file_path = await save_data_to_file(data_item.file, filename=data_item.filename)
+    if hasattr(data_item, "file"):
+        return await save_data_to_file(data_item.file, filename=data_item.filename)
 
-    elif isinstance(data_item, str):
-        # data is s3 file or local file path
-        if data_item.startswith("s3://") or data_item.startswith("file://"):
-            file_path = data_item
+    # data is a file path
+    if isinstance(data_item, str):
+        parsed_url = urlparse(data_item)
+
+        # data is s3 file path
+        if parsed_url.scheme == "s3":
+            return data_item
+
+        # data is local file path
+        elif parsed_url.scheme == "file":
+            if settings.accept_local_file_path:
+                return data_item
+            else:
+                raise IngestionError(message="Local files are not accepted.")
 
+        # data is an absolute file path
         elif data_item.startswith("/") or (
             os.name == "nt" and len(data_item) > 1 and data_item[1] == ":"
         ):
@@ -41,12 +52,13 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
                 # Use forward slashes in file URLs for consistency
                 url_path = normalized_path.replace(os.sep, "/")
                 file_path = "file://" + url_path
+
+                return file_path
             else:
                 raise IngestionError(message="Local files are not accepted.")
-    # data is text
-    else:
-        file_path = await save_data_to_file(data_item)
-    else:
-        raise IngestionError(message=f"Data type not supported: {type(data_item)}")
 
-    return file_path
+        # data is text, save it to data storage and return the file path
+        return await save_data_to_file(data_item)
+
+    # data is not a supported type
+    raise IngestionError(message=f"Data type not supported: {type(data_item)}")
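Reviewer note: end to end, this routing is what lets a single `cognee.add` call mix raw text and s3:// objects. A hedged usage sketch (dataset name, bucket, and key are placeholders, mirroring the calls in the new test below):

```python
import asyncio
import cognee


async def main():
    # Raw text is written to data storage first; the s3:// URL is passed through untouched.
    await cognee.add(
        [
            "Plain text is saved to a file, then ingested.",
            "s3://example-bucket/prefix/example.pdf",  # placeholder S3 object
        ],
        "example_dataset",
    )


asyncio.run(main())
```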
cognee/tests/test_s3_file_storage.py (new executable file, 116 lines)
@@ -0,0 +1,116 @@
import os
import pathlib
from uuid import uuid4

import cognee
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType

logger = get_logger()


async def main():
    test_run_id = uuid4()
    data_directory_path = f"s3://cognee-storage-dev/{test_run_id}/data"
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = f"s3://cognee-storage-dev/{test_run_id}/system"
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "artificial_intelligence"

    ai_text_file_path = os.path.join(
        pathlib.Path(__file__).parent, "test_data/artificial-intelligence.pdf"
    )
    await cognee.add([ai_text_file_path], dataset_name)

    text = """A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification. LLMs acquire these abilities by learning statistical relationships from text documents during a computationally intensive self-supervised and semi-supervised training process. LLMs can be used for text generation, a form of generative AI, by taking an input text and repeatedly predicting the next token or word.
    LLMs are artificial neural networks. The largest and most capable, as of March 2024, are built with a decoder-only transformer-based architecture while some recent implementations are based on other architectures, such as recurrent neural network variants and Mamba (a state space model).
    Up to 2020, fine tuning was the only way a model could be adapted to be able to accomplish specific tasks. Larger sized models, such as GPT-3, however, can be prompt-engineered to achieve similar results.[6] They are thought to acquire knowledge about syntax, semantics and "ontology" inherent in human language corpora, but also inaccuracies and biases present in the corpora.
    Some notable LLMs are OpenAI's GPT series of models (e.g., GPT-3.5 and GPT-4, used in ChatGPT and Microsoft Copilot), Google's PaLM and Gemini (the latter of which is currently used in the chatbot of the same name), xAI's Grok, Meta's LLaMA family of open-source models, Anthropic's Claude models, Mistral AI's open source models, and Databricks' open source DBRX.
    """

    await cognee.add([text], dataset_name)

    await cognee.cognify([dataset_name])

    from cognee.infrastructure.databases.vector import get_vector_engine

    vector_engine = get_vector_engine()
    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
    random_node_name = random_node.payload["text"]

    search_results = await cognee.search(
        query_type=SearchType.INSIGHTS, query_text=random_node_name
    )
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted sentences are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(
        query_type=SearchType.SUMMARIES, query_text=random_node_name
    )
    assert len(search_results) != 0, "Query related summaries don't exist."
    print("\nExtracted summaries are:\n")
    for result in search_results:
        print(f"{result}\n")

    user = await get_default_user()
    history = await get_history(user.id)

    assert len(history) == 6, "Search history is not correct."

    # Assert local data files are cleaned properly
    await cognee.prune.prune_data()
    data_root_directory = get_storage_config()["data_root_directory"]
    assert not os.path.isdir(data_root_directory), "Local data files are not deleted"

    # Assert relational, vector and graph databases have been cleaned properly
    await cognee.prune.prune_system(metadata=True)

    connection = await vector_engine.get_connection()
    collection_names = await connection.table_names()
    assert len(collection_names) == 0, "LanceDB vector database is not empty"

    from cognee.infrastructure.databases.relational import get_relational_engine

    db_path = get_relational_engine().db_path
    dir_path = os.path.dirname(db_path)
    file_name = os.path.basename(db_path)
    file_storage = get_file_storage(dir_path)

    assert not await file_storage.file_exists(file_name), (
        "SQLite relational database is not deleted"
    )

    from cognee.infrastructure.databases.graph import get_graph_config

    graph_config = get_graph_config()
    # For Kuzu v0.11.0+, check if database file doesn't exist (single-file format with .kuzu extension)
    # For older versions or other providers, check if directory is empty
    if graph_config.graph_database_provider.lower() == "kuzu":
        assert not os.path.exists(graph_config.graph_file_path), (
            "Kuzu graph database file still exists"
        )
    else:
        assert not os.path.exists(graph_config.graph_file_path) or not os.listdir(
            graph_config.graph_file_path
        ), "Graph database directory is not empty"


if __name__ == "__main__":
    import asyncio

    asyncio.run(main(), debug=True)
@@ -1,6 +1,6 @@
 [project]
 name = "cognee"
-version = "0.2.1-dev1"
+version = "0.2.1-dev3"
 description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
 authors = [
     { name = "Vasilije Markovic" },
|
|||
"/.data",
|
||||
"/.github",
|
||||
"/alembic",
|
||||
"/distributed",
|
||||
"/deployment",
|
||||
"/cognee-mcp",
|
||||
"/cognee-frontend",
|
||||
|
|
@@ -170,6 +169,9 @@ exclude = [
     "/tools",
 ]
 
+[tool.hatch.build.targets.wheel]
+packages = ["cognee", "distributed"]
+
 [tool.ruff]
 line-length = 100
 exclude = [