Compare commits: main...feat/add-e
4 commits

| Author | SHA1 | Date |
|---|---|---|
|  | 277b5972d0 |  |
|  | 5fc4f10988 |  |
|  | b5852a2da7 |  |
|  | de39d5c49a |  |
8 changed files with 4847 additions and 4102 deletions
````
@@ -1,8 +1,6 @@
from uuid import UUID
import os
from typing import Union, BinaryIO, List, Optional, Dict, Any
from pydantic import BaseModel
from urllib.parse import urlparse
from cognee.modules.users.models import User
from cognee.modules.pipelines import Task, run_pipeline
from cognee.modules.pipelines.layers.resolve_authorized_user_dataset import (
@@ -14,17 +12,16 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
from cognee.modules.engine.operations.setup import setup
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
from cognee.shared.logging_utils import get_logger
from .preprocessors import get_preprocessor_registry, PreprocessorContext

logger = get_logger()

try:
    from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
    from cognee.context_global_variables import (
        tavily_config as tavily,
        soup_crawler_config as soup_crawler,
    )
    from .preprocessors.web_preprocessor import register_web_preprocessor

    register_web_preprocessor()
except ImportError:
    logger.debug(f"Unable to import {str(ImportError)}")
    logger.debug("Web preprocessor not available")
    pass

@@ -38,9 +35,8 @@ async def add(
    dataset_id: Optional[UUID] = None,
    preferred_loaders: List[str] = None,
    incremental_loading: bool = True,
    extraction_rules: Optional[Dict[str, Any]] = None,
    tavily_config: Optional[BaseModel] = None,
    soup_crawler_config: Optional[BaseModel] = None,
    preprocessors: Optional[List[str]] = None,
    **preprocessor_params,
):
    """
    Add data to Cognee for knowledge graph processing.
@@ -97,9 +93,9 @@ async def add(
        vector_db_config: Optional configuration for vector database (for custom setups).
        graph_db_config: Optional configuration for graph database (for custom setups).
        dataset_id: Optional specific dataset UUID to use instead of dataset_name.
        extraction_rules: Optional dictionary of rules (e.g., CSS selectors, XPath) for extracting specific content from web pages using BeautifulSoup
        tavily_config: Optional configuration for Tavily API, including API key and extraction settings
        soup_crawler_config: Optional configuration for BeautifulSoup crawler, specifying concurrency, crawl delay, and extraction rules.
        preprocessors: Optional list of preprocessor names to run. If None, no preprocessors will run.
            Available preprocessors: ["web_preprocessor"] for handling URLs.
        **preprocessor_params: Additional parameters passed to preprocessors (e.g., extraction_rules, tavily_config, soup_crawler_config for web preprocessor).

    Returns:
        PipelineRunInfo: Information about the ingestion pipeline execution including:
@@ -155,14 +151,14 @@ async def add(
            "description": "p",
            "more_info": "a[href*='more-info']"
        }
        await cognee.add("https://example.com",extraction_rules=extraction_rules)
        await cognee.add("https://example.com", preprocessors=["web_preprocessor"], extraction_rules=extraction_rules)

        # Add a single url and tavily extract ingestion method
        Make sure to set TAVILY_API_KEY = YOUR_TAVILY_API_KEY as a environment variable
        await cognee.add("https://example.com")
        await cognee.add("https://example.com", preprocessors=["web_preprocessor"], tavily_config=your_config)

        # Add multiple urls
        await cognee.add(["https://example.com","https://books.toscrape.com"])
        await cognee.add(["https://example.com","https://books.toscrape.com"], preprocessors=["web_preprocessor"])
        ```

    Environment Variables:
@@ -180,27 +176,37 @@ async def add(

    """

    registry = get_preprocessor_registry()
    preprocessor_context = PreprocessorContext(
        data=data,
        dataset_name=dataset_name,
        user=user,
        node_set=node_set,
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        dataset_id=dataset_id,
        preferred_loaders=preferred_loaders,
        incremental_loading=incremental_loading,
        extra_params={
            **preprocessor_params,
        },
    )

    try:
        if not soup_crawler_config and extraction_rules:
            soup_crawler_config = SoupCrawlerConfig(extraction_rules=extraction_rules)
        if not tavily_config and os.getenv("TAVILY_API_KEY"):
            tavily_config = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))

        soup_crawler.set(soup_crawler_config)
        tavily.set(tavily_config)

        http_schemes = {"http", "https"}

        def _is_http_url(item: Union[str, BinaryIO]) -> bool:
            return isinstance(item, str) and urlparse(item).scheme in http_schemes

        if _is_http_url(data):
            node_set = ["web_content"] if not node_set else node_set + ["web_content"]
        elif isinstance(data, list) and any(_is_http_url(item) for item in data):
            node_set = ["web_content"] if not node_set else node_set + ["web_content"]
    except NameError:
        logger.debug(f"Unable to import {str(ImportError)}")
        pass
        processed_context = await registry.process_with_selected_preprocessors(
            preprocessor_context, preprocessors or []
        )
        data = processed_context.data
        dataset_name = processed_context.dataset_name
        user = processed_context.user
        node_set = processed_context.node_set
        vector_db_config = processed_context.vector_db_config
        graph_db_config = processed_context.graph_db_config
        dataset_id = processed_context.dataset_id
        preferred_loaders = processed_context.preferred_loaders
        incremental_loading = processed_context.incremental_loading
    except Exception as e:
        logger.error(f"Preprocessor processing failed: {str(e)}")

    tasks = [
        Task(resolve_data_directories, include_subdirectories=True),
````
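The hunks above show the intent of the change: web handling in `add()` becomes opt-in through the new `preprocessors` argument instead of being wired into the function body. A minimal usage sketch based on the docstring examples in this diff, assuming the package is importable as `cognee`, the optional web-scraper dependencies are installed, and `TAVILY_API_KEY` is set where Tavily extraction is wanted:

```python
import asyncio

import cognee


async def main():
    # CSS extraction rules for the BeautifulSoup crawler (taken from the docstring example).
    extraction_rules = {
        "description": "p",
        "more_info": "a[href*='more-info']",
    }

    # Web content is only preprocessed when "web_preprocessor" is requested explicitly.
    await cognee.add(
        "https://example.com",
        preprocessors=["web_preprocessor"],
        extraction_rules=extraction_rules,
    )

    # Multiple URLs work the same way; TAVILY_API_KEY enables Tavily-based extraction.
    await cognee.add(
        ["https://example.com", "https://books.toscrape.com"],
        preprocessors=["web_preprocessor"],
    )


asyncio.run(main())
```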
cognee/api/v1/add/preprocessors/__init__.py (new file, 16 lines)

```python
"""
Preprocessor system for the cognee add function.

This module provides a plugin architecture that allows preprocessors to be easily
plugged into the add() function without modifying core code.
"""

from .base import Preprocessor, PreprocessorRegistry, PreprocessorContext
from .registry import get_preprocessor_registry

__all__ = [
    "Preprocessor",
    "PreprocessorRegistry",
    "get_preprocessor_registry",
    "PreprocessorContext",
]
```
cognee/api/v1/add/preprocessors/base.py (new file, 171 lines)

```python
"""
Base classes for the cognee add preprocessor system.
"""

from uuid import UUID
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union, BinaryIO
from pydantic import BaseModel

from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger

logger = get_logger()


class PreprocessorContext(BaseModel):
    """Context passed to preprocessors during processing."""

    model_config = {"arbitrary_types_allowed": True}

    data: Union[BinaryIO, List[BinaryIO], str, List[str]]
    dataset_name: str
    user: Optional[User] = None
    node_set: Optional[List[str]] = None
    vector_db_config: Optional[Dict] = None
    graph_db_config: Optional[Dict] = None
    dataset_id: Optional[UUID] = None
    preferred_loaders: Optional[List[str]] = None
    incremental_loading: bool = True
    extra_params: Dict[str, Any] = {}


class PreprocessorResult(BaseModel):
    """Result returned by preprocessors."""

    modified_context: Optional[PreprocessorContext] = None
    stop_processing: bool = False
    error: Optional[str] = None


class Preprocessor(ABC):
    """
    Base class for all cognee add preprocessors.

    Preprocessors can modify the processing context, add custom logic,
    or handle specific data types before main pipeline processing.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique name for this preprocessor."""
        pass

    @abstractmethod
    def can_handle(self, context: PreprocessorContext) -> bool:
        """
        Check if this preprocessor can handle the given context.

        Args:
            context: The current processing context

        Returns:
            True if this preprocessor should process this context
        """
        pass

    @abstractmethod
    async def process(self, context: PreprocessorContext) -> PreprocessorResult:
        """
        Process the given context.

        Args:
            context: The current processing context

        Returns:
            PreprocessorResult with any modifications or errors
        """
        pass

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name}"


class PreprocessorRegistry:
    """Registry for managing and executing preprocessors."""

    def __init__(self):
        self._preprocessors: List[Preprocessor] = []

    def register(self, preprocessor: Preprocessor) -> None:
        """Register a preprocessor."""
        if not isinstance(preprocessor, Preprocessor):
            raise TypeError(
                f"Preprocessor must inherit from Preprocessor, got {type(preprocessor)}"
            )

        if any(prep.name == preprocessor.name for prep in self._preprocessors):
            raise ValueError(f"Preprocessor with name '{preprocessor.name}' already registered")

        self._preprocessors.append(preprocessor)

    def unregister(self, name: str) -> bool:
        """Unregister a preprocessor by name."""
        for i, prep in enumerate(self._preprocessors):
            if prep.name == name:
                del self._preprocessors[i]
                return True
        return False

    def get_preprocessors(self) -> List[Preprocessor]:
        """Get all registered preprocessors ordered by priority."""
        return self._preprocessors.copy()

    def get_applicable_preprocessors(self, context: PreprocessorContext) -> List[Preprocessor]:
        """Get preprocessors that can handle the given context."""
        return [prep for prep in self._preprocessors if prep.can_handle(context)]

    async def process_with_selected_preprocessors(
        self, context: PreprocessorContext, preprocessor_names: List[str]
    ) -> PreprocessorContext:
        """
        Process context through only the specified preprocessors.

        Args:
            context: The initial context
            preprocessor_names: List of preprocessor names to run

        Returns:
            The final processed context

        Raises:
            Exception: If any preprocessor encounters an error or preprocessor name not found
        """
        current_context = context

        selected_preprocessors: List[Preprocessor] = []
        for name in preprocessor_names:
            preprocessor = next((prep for prep in self._preprocessors if prep.name == name), None)
            if preprocessor is None:
                available_names = [prep.name for prep in self._preprocessors]
                raise ValueError(
                    f"Preprocessor '{name}' not found. Available preprocessors: {available_names}"
                )
            selected_preprocessors.append(preprocessor)

        for preprocessor in selected_preprocessors:
            if not preprocessor.can_handle(current_context):
                logger.warning(
                    f"Preprocessor '{preprocessor.name}' cannot handle current context, skipping"
                )
                continue

            try:
                result = await preprocessor.process(current_context)

                if result.error:
                    raise Exception(f"Preprocessor '{preprocessor.name}' failed: {result.error}")

                if result.modified_context:
                    current_context = result.modified_context

                if result.stop_processing:
                    break

            except Exception as e:
                raise Exception(
                    f"Preprocessor '{preprocessor.name}' encountered an error: {str(e)}"
                ) from e

        return current_context
```
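base.py defines the whole plugin contract: a preprocessor exposes a `name`, a `can_handle()` check, and an async `process()` that returns a `PreprocessorResult`. A hypothetical minimal subclass, sketched only to illustrate that contract; the class name and behavior below are not part of this PR, and the import paths assume the package layout shown above:

```python
# Hypothetical example preprocessor; not part of the PR.
from cognee.api.v1.add.preprocessors import (
    Preprocessor,
    PreprocessorContext,
    get_preprocessor_registry,
)
from cognee.api.v1.add.preprocessors.base import PreprocessorResult


class ExampleTaggingPreprocessor(Preprocessor):
    """Adds an extra node_set entry to every context it processes."""

    @property
    def name(self) -> str:
        return "example_tagging_preprocessor"

    def can_handle(self, context: PreprocessorContext) -> bool:
        # A real preprocessor would inspect context.data here.
        return True

    async def process(self, context: PreprocessorContext) -> PreprocessorResult:
        # Return a copied context rather than mutating the original.
        modified = context.model_copy()
        modified.node_set = (context.node_set or []) + ["example_tag"]
        return PreprocessorResult(modified_context=modified)


# register() raises ValueError if the name is already taken.
get_preprocessor_registry().register(ExampleTaggingPreprocessor())
```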
cognee/api/v1/add/preprocessors/registry.py (new file, 12 lines)

```python
"""
Global preprocessor registry for cognee add function.
"""

from .base import PreprocessorRegistry

_registry = PreprocessorRegistry()


def get_preprocessor_registry() -> PreprocessorRegistry:
    """Get the global preprocessor registry."""
    return _registry
```
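registry.py keeps a single module-level `PreprocessorRegistry`, so every caller of `get_preprocessor_registry()` shares the same set of registered preprocessors. A short sketch of what that implies in practice; the listed name is only an example, and "web_preprocessor" appears only after `register_web_preprocessor()` has run with the web-scraper extras available:

```python
from cognee.api.v1.add.preprocessors import get_preprocessor_registry

registry = get_preprocessor_registry()

# Module-level singleton: every caller gets the same registry object.
assert registry is get_preprocessor_registry()

# List whatever is currently registered, e.g. ["web_preprocessor"] once the
# web preprocessor has been registered.
print([prep.name for prep in registry.get_preprocessors()])
```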
cognee/api/v1/add/preprocessors/web_preprocessor.py (new file, 98 lines)

```python
"""
Web preprocessor for handling URL inputs in the cognee add function.

This preprocessor handles web URLs by setting up appropriate crawling configurations
and modifying the processing context for web content.
"""

import os
from urllib.parse import urlparse
from typing import Union, BinaryIO

from .base import Preprocessor, PreprocessorContext, PreprocessorResult

try:
    from cognee.tasks.web_scraper.config import TavilyConfig, SoupCrawlerConfig
    from cognee.context_global_variables import (
        tavily_config as tavily,
        soup_crawler_config as soup_crawler,
    )

    WEB_SCRAPER_AVAILABLE = True
except ImportError:
    WEB_SCRAPER_AVAILABLE = False


class WebPreprocessor(Preprocessor):
    """Preprocessor for handling web URL inputs."""

    @property
    def name(self) -> str:
        return "web_preprocessor"

    def _is_http_url(self, item: Union[str, BinaryIO]) -> bool:
        """Check if an item is an HTTP/HTTPS URL."""
        http_schemes = {"http", "https"}
        return isinstance(item, str) and urlparse(item).scheme in http_schemes

    def can_handle(self, context: PreprocessorContext) -> bool:
        """Check if this preprocessor can handle the given context."""
        if not WEB_SCRAPER_AVAILABLE:
            return False

        if self._is_http_url(context.data):
            return True

        if isinstance(context.data, list):
            return any(self._is_http_url(item) for item in context.data)

        return False

    async def process(self, context: PreprocessorContext) -> PreprocessorResult:
        """Process web URLs by setting up crawling configurations."""
        try:
            extraction_rules = context.extra_params.get("extraction_rules")
            tavily_config_param = context.extra_params.get("tavily_config")
            soup_crawler_config_param = context.extra_params.get("soup_crawler_config")

            if not soup_crawler_config_param and extraction_rules:
                soup_crawler_config_param = SoupCrawlerConfig(extraction_rules=extraction_rules)

            if not tavily_config_param and os.getenv("TAVILY_API_KEY"):
                tavily_config_param = TavilyConfig(api_key=os.getenv("TAVILY_API_KEY"))

            if soup_crawler_config_param:
                soup_crawler.set(soup_crawler_config_param)

            tavily.set(tavily_config_param)

            modified_context = context.model_copy()

            if self._is_http_url(context.data):
                modified_context.node_set = (
                    ["web_content"] if not context.node_set else context.node_set + ["web_content"]
                )
            elif isinstance(context.data, list) and any(
                self._is_http_url(item) for item in context.data
            ):
                modified_context.node_set = (
                    ["web_content"] if not context.node_set else context.node_set + ["web_content"]
                )

            return PreprocessorResult(modified_context=modified_context)

        except Exception as e:
            return PreprocessorResult(error=f"Failed to configure web scraping: {str(e)}")


def register_web_preprocessor() -> None:
    """Register the web preprocessor with the global registry."""
    from .registry import get_preprocessor_registry

    registry = get_preprocessor_registry()

    if WEB_SCRAPER_AVAILABLE:
        try:
            registry.register(WebPreprocessor())
        except ValueError:
            pass
```
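The `WebPreprocessor` only activates when the optional web-scraper dependencies import cleanly, and its visible effect is to append "web_content" to the context's `node_set` for URL inputs after pushing the crawler/Tavily configs into the context variables. A rough sketch of exercising it directly, assuming those extras are installed so that `WEB_SCRAPER_AVAILABLE` is True:

```python
import asyncio

from cognee.api.v1.add.preprocessors import PreprocessorContext
from cognee.api.v1.add.preprocessors.web_preprocessor import WebPreprocessor


async def main():
    preprocessor = WebPreprocessor()
    context = PreprocessorContext(
        data="https://example.com",
        dataset_name="web_demo",
        extra_params={"extraction_rules": {"description": "p"}},
    )

    # can_handle() is False when the web-scraper extras are missing or data is not a URL.
    if preprocessor.can_handle(context):
        result = await preprocessor.process(context)
        if result.error:
            print(result.error)
        else:
            # The modified context's node_set now includes "web_content".
            print(result.modified_context.node_set)


asyncio.run(main())
```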
mypy.ini (2 changed lines)

```diff
@@ -1,5 +1,5 @@
 [mypy]
-python_version=3.8
+python_version=3.10
 ignore_missing_imports=false
 strict_optional=false
 warn_redundant_casts=true
```
```diff
@@ -48,7 +48,7 @@ dependencies = [
     "pympler>=1.1,<2.0.0",
     "onnxruntime<=1.22.1",
     "pylance>=0.22.0,<=0.36.0",
-    "kuzu (==0.11.2)",
+    "kuzu (==0.11.3)",
     "python-magic-bin<0.5 ; platform_system == 'Windows'", # Only needed for Windows
     "fastembed<=0.6.0",
     "networkx>=3.4.2,<4",
```