fix: Resolve issue with text classification
This commit is contained in:
parent
7743071c51
commit
e2457ef277
3 changed files with 25 additions and 11 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
import io
|
import io
|
||||||
import os.path
|
import os.path
|
||||||
from typing import BinaryIO, TypedDict
|
from typing import BinaryIO, TypedDict, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -27,7 +27,7 @@ class FileMetadata(TypedDict):
|
||||||
file_size: int
|
file_size: int
|
||||||
|
|
||||||
|
|
||||||
async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
async def get_file_metadata(file: BinaryIO, name: Optional[str] = None) -> FileMetadata:
|
||||||
"""
|
"""
|
||||||
Retrieve metadata from a file object.
|
Retrieve metadata from a file object.
|
||||||
|
|
||||||
|
|
@ -53,15 +53,15 @@ async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||||
except io.UnsupportedOperation as error:
|
except io.UnsupportedOperation as error:
|
||||||
logger.error(f"Error retrieving content hash for file: {file.name} \n{str(error)}\n\n")
|
logger.error(f"Error retrieving content hash for file: {file.name} \n{str(error)}\n\n")
|
||||||
|
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file, name=name)
|
||||||
|
|
||||||
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
||||||
|
|
||||||
if isinstance(file_path, str):
|
if isinstance(file_path, str):
|
||||||
file_name = Path(file_path).stem if file_path else None
|
file_name = Path(file_path).stem if file_path else None
|
||||||
else:
|
else:
|
||||||
# If file_path does not exist or is an integer, return None
|
# If file_path does not exist, fall back to the caller-supplied name
|
||||||
file_name = None
|
file_name = name
|
||||||
|
|
||||||
# Get file size
|
# Get file size
|
||||||
pos = file.tell() # remember current pointer
|
pos = file.tell() # remember current pointer
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
from typing import BinaryIO
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import BinaryIO, Optional, Any
|
||||||
import filetype
|
import filetype
|
||||||
from .is_text_content import is_text_content
|
from tempfile import SpooledTemporaryFile
|
||||||
|
from filetype.types.base import Type
|
||||||
|
|
||||||
|
|
||||||
class FileTypeException(Exception):
|
class FileTypeException(Exception):
|
||||||
|
|
@ -22,7 +25,7 @@ class FileTypeException(Exception):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
def guess_file_type(file: BinaryIO) -> filetype.Type:
|
def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type:
|
||||||
"""
|
"""
|
||||||
Guess the file type from the given binary file stream.
|
Guess the file type from the given binary file stream.
|
||||||
|
|
||||||
|
|
@ -39,12 +42,23 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
|
||||||
|
|
||||||
- filetype.Type: The guessed file type, represented as filetype.Type.
|
- filetype.Type: The guessed file type, represented as filetype.Type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Note: If file has .txt or .text extension, consider it a plain text file as filetype.guess may not detect it properly
|
||||||
|
# as it contains no magic number encoding
|
||||||
|
ext = None
|
||||||
|
if isinstance(file, str):
|
||||||
|
ext = Path(file).suffix
|
||||||
|
elif name is not None:
|
||||||
|
ext = Path(name).suffix
|
||||||
|
|
||||||
|
if ext in [".txt", ".text"]:
|
||||||
|
file_type = Type("text/plain", "txt")
|
||||||
|
return file_type
|
||||||
|
|
||||||
file_type = filetype.guess(file)
|
file_type = filetype.guess(file)
|
||||||
|
|
||||||
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
|
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
|
||||||
if file_type is None:
|
if file_type is None:
|
||||||
from filetype.types.base import Type
|
|
||||||
|
|
||||||
file_type = Type("text/plain", "txt")
|
file_type = Type("text/plain", "txt")
|
||||||
|
|
||||||
if file_type is None:
|
if file_type is None:
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ class BinaryData(IngestionData):
|
||||||
|
|
||||||
async def ensure_metadata(self):
|
async def ensure_metadata(self):
|
||||||
if self.metadata is None:
|
if self.metadata is None:
|
||||||
self.metadata = await get_file_metadata(self.data)
|
self.metadata = await get_file_metadata(self.data, name=self.name)
|
||||||
|
|
||||||
if self.metadata["name"] is None:
|
if self.metadata["name"] is None:
|
||||||
self.metadata["name"] = self.name
|
self.metadata["name"] = self.name
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue