diff --git a/cognee/infrastructure/files/storage/S3FileStorage.py b/cognee/infrastructure/files/storage/S3FileStorage.py index 7c5a1033c..a0d611241 100644 --- a/cognee/infrastructure/files/storage/S3FileStorage.py +++ b/cognee/infrastructure/files/storage/S3FileStorage.py @@ -21,10 +21,11 @@ class S3FileStorage(Storage): def __init__(self, storage_path: str): self.storage_path = storage_path s3_config = get_s3_config() - if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None: + if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None and s3_config.aws_session_token is not None: self.s3 = s3fs.S3FileSystem( key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, + token=s3_config.aws_session_token, anon=False, endpoint_url=s3_config.aws_endpoint_url, client_kwargs={"region_name": s3_config.aws_region}, @@ -146,6 +147,11 @@ class S3FileStorage(Storage): self.s3.isfile, os.path.join(self.storage_path.replace("s3://", ""), file_path) ) + async def get_size(self, file_path: str) -> int: + return await run_async( + self.s3.size, os.path.join(self.storage_path.replace("s3://", ""), file_path) + ) + async def ensure_directory_exists(self, directory_path: str = ""): """ Ensure that the specified directory exists, creating it if necessary. diff --git a/cognee/infrastructure/files/storage/s3_config.py b/cognee/infrastructure/files/storage/s3_config.py index 0b9372b7e..3b59bcd57 100644 --- a/cognee/infrastructure/files/storage/s3_config.py +++ b/cognee/infrastructure/files/storage/s3_config.py @@ -8,9 +8,9 @@ class S3Config(BaseSettings): aws_endpoint_url: Optional[str] = None aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") - @lru_cache def get_s3_config(): return S3Config() diff --git a/cognee/tasks/ingestion/resolve_data_directories.py b/cognee/tasks/ingestion/resolve_data_directories.py index 1d3124a0c..cbd979e16 100644 --- a/cognee/tasks/ingestion/resolve_data_directories.py +++ b/cognee/tasks/ingestion/resolve_data_directories.py @@ -32,7 +32,7 @@ async def resolve_data_directories( import s3fs fs = s3fs.S3FileSystem( - key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False + key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key,token=s3_config.aws_session_token, anon=False ) for item in data: