Redefine the preferred_loaders parameter to allow per-loader arguments

This commit is contained in:
Daulet Amirkhanov 2025-10-21 17:10:45 +01:00
parent 7210198f2e
commit 322ef156cb
5 changed files with 11 additions and 9 deletions

View file

@@ -23,7 +23,7 @@ async def add(
vector_db_config: dict = None, vector_db_config: dict = None,
graph_db_config: dict = None, graph_db_config: dict = None,
dataset_id: Optional[UUID] = None, dataset_id: Optional[UUID] = None,
preferred_loaders: List[str] = None, preferred_loaders: dict[str, dict[str, Any]] = None,
incremental_loading: bool = True, incremental_loading: bool = True,
data_per_batch: Optional[int] = 20, data_per_batch: Optional[int] = 20,
): ):

View file

@@ -1,5 +1,5 @@
from uuid import UUID from uuid import UUID
from typing import Union, BinaryIO, List, Optional from typing import Union, BinaryIO, List, Optional, Any
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.api.v1.delete import delete from cognee.api.v1.delete import delete
@@ -15,7 +15,7 @@ async def update(
node_set: Optional[List[str]] = None, node_set: Optional[List[str]] = None,
vector_db_config: dict = None, vector_db_config: dict = None,
graph_db_config: dict = None, graph_db_config: dict = None,
preferred_loaders: List[str] = None, preferred_loaders: dict[str, dict[str, Any]] = None,
incremental_loading: bool = True, incremental_loading: bool = True,
): ):
""" """

View file

@@ -64,7 +64,9 @@ class LoaderEngine:
return True return True
def get_loader( def get_loader(
self, file_path: str, preferred_loaders: List[str] = None self,
file_path: str,
preferred_loaders: dict[str, dict[str, Any]],
) -> Optional[LoaderInterface]: ) -> Optional[LoaderInterface]:
""" """
Get appropriate loader for a file. Get appropriate loader for a file.
@@ -105,7 +107,7 @@ class LoaderEngine:
async def load_file( async def load_file(
self, self,
file_path: str, file_path: str,
preferred_loaders: Optional[List[str]] = None, preferred_loaders: dict[str, dict[str, Any]] = None,
**kwargs, **kwargs,
): ):
""" """

View file

@@ -1,6 +1,6 @@
import os import os
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import List, Tuple from typing import Any, List, Tuple
from pathlib import Path from pathlib import Path
import tempfile import tempfile
@@ -35,7 +35,7 @@ async def pull_from_s3(file_path, destination_file) -> None:
async def data_item_to_text_file( async def data_item_to_text_file(
data_item_path: str, data_item_path: str,
preferred_loaders: List[str], preferred_loaders: dict[str, dict[str, Any]] = None,
) -> Tuple[str, LoaderInterface]: ) -> Tuple[str, LoaderInterface]:
if isinstance(data_item_path, str): if isinstance(data_item_path, str):
parsed_url = urlparse(data_item_path) parsed_url = urlparse(data_item_path)

View file

@@ -27,7 +27,7 @@ async def ingest_data(
user: User, user: User,
node_set: Optional[List[str]] = None, node_set: Optional[List[str]] = None,
dataset_id: UUID = None, dataset_id: UUID = None,
preferred_loaders: List[str] = None, preferred_loaders: dict[str, dict[str, Any]] = None,
): ):
if not user: if not user:
user = await get_default_user() user = await get_default_user()
@@ -44,7 +44,7 @@ async def ingest_data(
user: User, user: User,
node_set: Optional[List[str]] = None, node_set: Optional[List[str]] = None,
dataset_id: UUID = None, dataset_id: UUID = None,
preferred_loaders: List[str] = None, preferred_loaders: dict[str, dict[str, Any]] = None,
): ):
new_datapoints = [] new_datapoints = []
existing_data_points = [] existing_data_points = []