diff --git a/common/data_source/database_config_ui.py b/common/data_source/database_config_ui.py new file mode 100644 index 000000000..2a684c084 --- /dev/null +++ b/common/data_source/database_config_ui.py @@ -0,0 +1,693 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Database Connector UI Configuration + +Provides UI schema and validation for database connector configuration. +This integrates with RAGFlow's data source configuration UI. +""" + +from typing import Dict, List, Any, Optional +from enum import Enum + + +class DatabaseUIFieldType(Enum): + """UI field types""" + TEXT = "text" + PASSWORD = "password" + NUMBER = "number" + SELECT = "select" + MULTI_SELECT = "multi_select" + CHECKBOX = "checkbox" + TEXTAREA = "textarea" + JSON = "json" + + +class DatabaseUISchema: + """UI schema for database connector configuration""" + + @staticmethod + def get_mysql_schema() -> List[Dict[str, Any]]: + """ + Get MySQL connector UI schema. 
+ + Returns: + List of field configurations + """ + return [ + { + "name": "db_type", + "label": "Database Type", + "type": DatabaseUIFieldType.SELECT.value, + "required": True, + "default": "mysql", + "options": [ + {"label": "MySQL", "value": "mysql"}, + {"label": "MariaDB", "value": "mariadb"} + ], + "tooltip": "Select MySQL or MariaDB database type" + }, + { + "name": "host", + "label": "Host", + "type": DatabaseUIFieldType.TEXT.value, + "required": True, + "default": "localhost", + "placeholder": "localhost or IP address", + "tooltip": "Database server hostname or IP address" + }, + { + "name": "port", + "label": "Port", + "type": DatabaseUIFieldType.NUMBER.value, + "required": True, + "default": 3306, + "min": 1, + "max": 65535, + "tooltip": "Database server port (default: 3306)" + }, + { + "name": "database", + "label": "Database Name", + "type": DatabaseUIFieldType.TEXT.value, + "required": True, + "placeholder": "my_database", + "tooltip": "Name of the database to connect to" + }, + { + "name": "username", + "label": "Username", + "type": DatabaseUIFieldType.TEXT.value, + "required": True, + "placeholder": "db_user", + "tooltip": "Database username" + }, + { + "name": "password", + "label": "Password", + "type": DatabaseUIFieldType.PASSWORD.value, + "required": True, + "placeholder": "••••••••", + "tooltip": "Database password (will be encrypted)" + }, + { + "name": "sql_query", + "label": "SQL Query", + "type": DatabaseUIFieldType.TEXTAREA.value, + "required": True, + "placeholder": "SELECT * FROM products WHERE status = 'active'", + "rows": 5, + "tooltip": "SQL SELECT query to extract data. Use WHERE clauses to filter data." + }, + { + "name": "vectorization_fields", + "label": "Vectorization Fields", + "type": DatabaseUIFieldType.MULTI_SELECT.value, + "required": True, + "placeholder": "Select fields to vectorize", + "tooltip": "Database columns to use for vector embeddings and search. 
These fields will be chunked and vectorized.", + "dynamic_options": True, # Populated after test connection + "help_text": "Example: name, description, content" + }, + { + "name": "metadata_fields", + "label": "Metadata Fields", + "type": DatabaseUIFieldType.MULTI_SELECT.value, + "required": False, + "placeholder": "Select metadata fields", + "tooltip": "Database columns to store as metadata. These won't be vectorized but will be searchable.", + "dynamic_options": True, + "help_text": "Example: id, category, created_at, price" + }, + { + "name": "primary_key_field", + "label": "Primary Key Field", + "type": DatabaseUIFieldType.TEXT.value, + "required": False, + "default": "id", + "placeholder": "id", + "tooltip": "Column name used as unique identifier for each row" + }, + { + "name": "sync_mode", + "label": "Sync Mode", + "type": DatabaseUIFieldType.SELECT.value, + "required": True, + "default": "batch", + "options": [ + {"label": "Batch (Full Sync)", "value": "batch"}, + {"label": "Incremental (Timestamp-based)", "value": "incremental"} + ], + "tooltip": "Batch: sync all data. Incremental: sync only new/updated records based on timestamp." 
+ }, + { + "name": "timestamp_field", + "label": "Timestamp Field", + "type": DatabaseUIFieldType.TEXT.value, + "required": False, + "placeholder": "updated_at", + "tooltip": "Column name for timestamp-based incremental sync (required for incremental mode)", + "conditional": { + "field": "sync_mode", + "value": "incremental" + } + }, + { + "name": "batch_size", + "label": "Batch Size", + "type": DatabaseUIFieldType.NUMBER.value, + "required": False, + "default": 1000, + "min": 100, + "max": 10000, + "tooltip": "Number of rows to process per batch (affects memory usage)" + }, + { + "name": "ssl_enabled", + "label": "Enable SSL/TLS", + "type": DatabaseUIFieldType.CHECKBOX.value, + "required": False, + "default": False, + "tooltip": "Enable secure SSL/TLS connection to database" + }, + { + "name": "ssl_ca", + "label": "SSL CA Certificate Path", + "type": DatabaseUIFieldType.TEXT.value, + "required": False, + "placeholder": "/path/to/ca.pem", + "tooltip": "Path to SSL Certificate Authority file", + "conditional": { + "field": "ssl_enabled", + "value": True + } + } + ] + + @staticmethod + def get_postgresql_schema() -> List[Dict[str, Any]]: + """ + Get PostgreSQL connector UI schema. + + Returns: + List of field configurations + """ + schema = DatabaseUISchema.get_mysql_schema() + + # Update database type options + for field in schema: + if field["name"] == "db_type": + field["options"] = [ + {"label": "PostgreSQL", "value": "postgresql"} + ] + field["default"] = "postgresql" + + # Update default port + if field["name"] == "port": + field["default"] = 5432 + field["tooltip"] = "Database server port (default: 5432)" + + return schema + + @staticmethod + def get_advanced_options_schema() -> List[Dict[str, Any]]: + """ + Get advanced configuration options schema. 
+ + Returns: + List of advanced field configurations + """ + return [ + { + "name": "pool_size", + "label": "Connection Pool Size", + "type": DatabaseUIFieldType.NUMBER.value, + "required": False, + "default": 5, + "min": 1, + "max": 20, + "tooltip": "Number of database connections to maintain in pool", + "category": "Performance" + }, + { + "name": "connection_timeout", + "label": "Connection Timeout (seconds)", + "type": DatabaseUIFieldType.NUMBER.value, + "required": False, + "default": 30, + "min": 5, + "max": 300, + "tooltip": "Maximum time to wait for database connection", + "category": "Performance" + }, + { + "name": "query_timeout", + "label": "Query Timeout (seconds)", + "type": DatabaseUIFieldType.NUMBER.value, + "required": False, + "default": 300, + "min": 10, + "max": 3600, + "tooltip": "Maximum time to wait for query execution", + "category": "Performance" + }, + { + "name": "enable_caching", + "label": "Enable Query Caching", + "type": DatabaseUIFieldType.CHECKBOX.value, + "required": False, + "default": True, + "tooltip": "Cache query results to improve performance", + "category": "Performance" + }, + { + "name": "cache_ttl", + "label": "Cache TTL (seconds)", + "type": DatabaseUIFieldType.NUMBER.value, + "required": False, + "default": 300, + "min": 60, + "max": 3600, + "tooltip": "How long to cache query results", + "category": "Performance", + "conditional": { + "field": "enable_caching", + "value": True + } + }, + { + "name": "enable_rate_limiting", + "label": "Enable Rate Limiting", + "type": DatabaseUIFieldType.CHECKBOX.value, + "required": False, + "default": True, + "tooltip": "Limit query rate to prevent database overload", + "category": "Performance" + }, + { + "name": "rate_limit_calls", + "label": "Rate Limit (calls/minute)", + "type": DatabaseUIFieldType.NUMBER.value, + "required": False, + "default": 100, + "min": 10, + "max": 1000, + "tooltip": "Maximum queries per minute", + "category": "Performance", + "conditional": { + "field": 
"enable_rate_limiting", + "value": True + } + }, + { + "name": "encrypt_credentials", + "label": "Encrypt Credentials", + "type": DatabaseUIFieldType.CHECKBOX.value, + "required": False, + "default": True, + "tooltip": "Encrypt database credentials at rest", + "category": "Security" + } + ] + + @staticmethod + def validate_configuration(config: Dict[str, Any]) -> Tuple[bool, List[str]]: + """ + Validate database configuration. + + Args: + config: Configuration dictionary + + Returns: + Tuple of (is_valid, error_messages) + """ + errors = [] + + # Required fields + required_fields = [ + "db_type", "host", "port", "database", + "username", "password", "sql_query", "vectorization_fields" + ] + + for field in required_fields: + if field not in config or not config[field]: + errors.append(f"Required field missing: {field}") + + # Validate port + if "port" in config: + try: + port = int(config["port"]) + if port < 1 or port > 65535: + errors.append("Port must be between 1 and 65535") + except ValueError: + errors.append("Port must be a number") + + # Validate SQL query + if "sql_query" in config: + query = config["sql_query"].strip().upper() + if not query.startswith("SELECT"): + errors.append("SQL query must be a SELECT statement") + + # Check for dangerous keywords + dangerous_keywords = ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"] + for keyword in dangerous_keywords: + if keyword in query: + errors.append(f"SQL query contains dangerous keyword: {keyword}") + + # Validate vectorization fields + if "vectorization_fields" in config: + if not isinstance(config["vectorization_fields"], list): + errors.append("vectorization_fields must be a list") + elif len(config["vectorization_fields"]) == 0: + errors.append("At least one vectorization field required") + + # Validate incremental sync + if config.get("sync_mode") == "incremental": + if not config.get("timestamp_field"): + errors.append("timestamp_field required for incremental sync mode") + + # 
Validate batch size + if "batch_size" in config: + try: + batch_size = int(config["batch_size"]) + if batch_size < 100 or batch_size > 10000: + errors.append("batch_size must be between 100 and 10000") + except ValueError: + errors.append("batch_size must be a number") + + return (len(errors) == 0, errors) + + @staticmethod + def get_example_configurations() -> Dict[str, Dict[str, Any]]: + """ + Get example configurations for common use cases. + + Returns: + Dictionary of example configurations + """ + return { + "product_catalog": { + "name": "Product Catalog Sync", + "description": "Sync product information from e-commerce database", + "config": { + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "database": "ecommerce", + "sql_query": "SELECT * FROM products WHERE status = 'active'", + "vectorization_fields": ["name", "description", "features"], + "metadata_fields": ["id", "category", "price", "sku", "created_at"], + "primary_key_field": "id", + "sync_mode": "incremental", + "timestamp_field": "updated_at", + "batch_size": 1000 + } + }, + "customer_support": { + "name": "Customer Support Tickets", + "description": "Sync support tickets and knowledge base", + "config": { + "db_type": "postgresql", + "host": "localhost", + "port": 5432, + "database": "support_db", + "sql_query": "SELECT * FROM tickets WHERE status IN ('resolved', 'closed')", + "vectorization_fields": ["title", "description", "resolution"], + "metadata_fields": ["ticket_id", "customer_id", "priority", "category", "resolved_at"], + "primary_key_field": "ticket_id", + "sync_mode": "incremental", + "timestamp_field": "resolved_at", + "batch_size": 500 + } + }, + "documentation": { + "name": "Internal Documentation", + "description": "Sync internal documentation and wiki pages", + "config": { + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "database": "wiki_db", + "sql_query": "SELECT * FROM pages WHERE published = 1", + "vectorization_fields": ["title", "content", 
"summary"], + "metadata_fields": ["page_id", "author", "category", "tags", "last_modified"], + "primary_key_field": "page_id", + "sync_mode": "incremental", + "timestamp_field": "last_modified", + "batch_size": 100 + } + }, + "faq_database": { + "name": "FAQ Database", + "description": "Sync frequently asked questions", + "config": { + "db_type": "postgresql", + "host": "localhost", + "port": 5432, + "database": "faq_db", + "sql_query": "SELECT * FROM faqs WHERE active = true", + "vectorization_fields": ["question", "answer"], + "metadata_fields": ["faq_id", "category", "views", "helpful_count"], + "primary_key_field": "faq_id", + "sync_mode": "batch", + "batch_size": 500 + } + } + } + + +class DatabaseConnectionTester: + """Test database connection and discover schema""" + + @staticmethod + def test_connection(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Test database connection. + + Args: + config: Database configuration + + Returns: + Test result with status and details + """ + result = { + "success": False, + "message": "", + "connection_time_ms": 0, + "server_version": None + } + + try: + import time + from common.data_source.database_connector import create_mysql_connector, create_postgresql_connector + + start_time = time.time() + + # Create connector based on type + if config["db_type"] in ["mysql", "mariadb"]: + connector = create_mysql_connector( + host=config["host"], + port=config["port"], + database=config["database"], + username=config["username"], + password=config["password"], + sql_query="SELECT 1", + vectorization_fields=["dummy"] + ) + else: + connector = create_postgresql_connector( + host=config["host"], + port=config["port"], + database=config["database"], + username=config["username"], + password=config["password"], + sql_query="SELECT 1", + vectorization_fields=["dummy"] + ) + + # Test connection + connector.validate_connector_settings() + + connection_time = (time.time() - start_time) * 1000 + + result["success"] = True + 
result["message"] = "Connection successful" + result["connection_time_ms"] = round(connection_time, 2) + + # Get server version + try: + with connector.pool.get_connection() as conn: + cursor = conn.cursor() + if config["db_type"] in ["mysql", "mariadb"]: + cursor.execute("SELECT VERSION()") + else: + cursor.execute("SELECT version()") + version = cursor.fetchone()[0] + result["server_version"] = version + cursor.close() + except: + pass + + connector.close() + + except Exception as e: + result["success"] = False + result["message"] = str(e) + + return result + + @staticmethod + def discover_schema(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Discover database schema from SQL query. + + Args: + config: Database configuration + + Returns: + Schema information with available fields + """ + result = { + "success": False, + "fields": [], + "sample_data": [], + "row_count_estimate": 0 + } + + try: + from common.data_source.database_connector import create_mysql_connector, create_postgresql_connector + + # Create connector + if config["db_type"] in ["mysql", "mariadb"]: + connector = create_mysql_connector( + host=config["host"], + port=config["port"], + database=config["database"], + username=config["username"], + password=config["password"], + sql_query=f"{config['sql_query']} LIMIT 10", + vectorization_fields=["dummy"] + ) + else: + connector = create_postgresql_connector( + host=config["host"], + port=config["port"], + database=config["database"], + username=config["username"], + password=config["password"], + sql_query=f"{config['sql_query']} LIMIT 10", + vectorization_fields=["dummy"] + ) + + connector.connect() + + # Execute query to get schema + with connector.pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(f"{config['sql_query']} LIMIT 10") + + # Get column information + if config["db_type"] in ["mysql", "mariadb"]: + columns = cursor.description + fields = [ + { + "name": col[0], + "type": str(col[1].__name__) if hasattr(col[1], 
'__name__') else "unknown", + "nullable": col[6] if len(col) > 6 else True + } + for col in columns + ] + else: + columns = cursor.description + fields = [ + { + "name": col.name, + "type": str(col.type_code) if hasattr(col, 'type_code') else "unknown", + "nullable": True + } + for col in columns + ] + + # Get sample data + rows = cursor.fetchall() + sample_data = [ + {field["name"]: str(row[i])[:100] for i, field in enumerate(fields)} + for row in rows[:5] + ] + + cursor.close() + + # Estimate row count + try: + with connector.pool.get_connection() as conn: + cursor = conn.cursor() + count_query = f"SELECT COUNT(*) FROM ({config['sql_query']}) AS subquery" + cursor.execute(count_query) + row_count = cursor.fetchone()[0] + result["row_count_estimate"] = row_count + cursor.close() + except: + pass + + result["success"] = True + result["fields"] = fields + result["sample_data"] = sample_data + + connector.close() + + except Exception as e: + result["success"] = False + result["error"] = str(e) + + return result + + +# Export UI schema for frontend +def get_ui_config() -> Dict[str, Any]: + """ + Get complete UI configuration for database connector. 
+ + Returns: + UI configuration dictionary + """ + return { + "connector_type": "database", + "display_name": "Database (MySQL/PostgreSQL)", + "description": "Connect to relational databases for real-time data sync and vectorization", + "icon": "database", + "schemas": { + "mysql": DatabaseUISchema.get_mysql_schema(), + "postgresql": DatabaseUISchema.get_postgresql_schema(), + "advanced": DatabaseUISchema.get_advanced_options_schema() + }, + "examples": DatabaseUISchema.get_example_configurations(), + "features": [ + "Real-time and batch synchronization", + "Incremental sync with timestamp tracking", + "Secure credential encryption", + "Connection pooling for performance", + "Query result caching", + "SQL injection prevention", + "Field-level transformations", + "Metadata filtering support" + ], + "supported_databases": [ + {"name": "MySQL", "version": "5.7+"}, + {"name": "MariaDB", "version": "10.2+"}, + {"name": "PostgreSQL", "version": "10+"} + ] + } diff --git a/common/data_source/database_connector.py b/common/data_source/database_connector.py new file mode 100644 index 000000000..0b6d45e63 --- /dev/null +++ b/common/data_source/database_connector.py @@ -0,0 +1,1378 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" +Enterprise-Grade Database Connector for MySQL and PostgreSQL + +Features: +- Connection pooling for high performance +- Secure credential encryption +- Query optimization and caching +- Incremental sync with CDC support +- Comprehensive error handling and retry logic +- Monitoring and metrics +- Field mapping and transformation +- Batch processing with memory management +- SQL injection prevention +- SSL/TLS support +- Transaction management +- Schema discovery +- Data validation +- Rate limiting +- Health checks +""" + +import logging +import hashlib +import json +import time +import re +import threading +from datetime import datetime, timedelta +from typing import Any, Dict, Generator, List, Optional, Tuple, Callable, Set +from dataclasses import dataclass, field, asdict +from enum import Enum +from queue import Queue, Empty +from contextlib import contextmanager +import base64 +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2 +from collections import defaultdict, deque + +from common.data_source.interfaces import LoadConnector, PollConnector, CredentialsConnector +from common.data_source.models import Document, TextSection, SecondsSinceUnixEpoch +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError +) + + +# ============================================================================ +# Enums and Constants +# ============================================================================ + +class DatabaseType(Enum): + """Supported database types""" + MYSQL = "mysql" + POSTGRESQL = "postgresql" + MARIADB = "mariadb" + + +class SyncMode(Enum): + """Synchronization modes""" + BATCH = "batch" # Full sync + INCREMENTAL = "incremental" # Timestamp-based + CDC = "cdc" # Change Data Capture + + +class FieldType(Enum): + """Field data types""" + TEXT = "text" + INTEGER = "integer" + FLOAT = "float" + BOOLEAN = 
"boolean" + DATETIME = "datetime" + JSON = "json" + BINARY = "binary" + + +class ConnectionState(Enum): + """Connection states""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + ERROR = "error" + + +# Constants +DEFAULT_POOL_SIZE = 5 +MAX_POOL_SIZE = 20 +CONNECTION_TIMEOUT = 30 +QUERY_TIMEOUT = 300 +MAX_RETRIES = 3 +RETRY_DELAY = 1.0 +BATCH_SIZE = 1000 +MAX_BATCH_SIZE = 10000 +CACHE_TTL = 300 # 5 minutes +MAX_CACHE_SIZE = 1000 +RATE_LIMIT_CALLS = 100 +RATE_LIMIT_PERIOD = 60 # 1 minute + + +# ============================================================================ +# Data Classes +# ============================================================================ + +@dataclass +class DatabaseConfig: + """Comprehensive database configuration""" + # Connection settings + db_type: str + host: str + port: int + database: str + + # Query configuration + sql_query: str + vectorization_fields: List[str] + metadata_fields: List[str] = field(default_factory=list) + primary_key_field: str = "id" + + # Sync configuration + sync_mode: str = "batch" + timestamp_field: Optional[str] = None + cdc_table: Optional[str] = None + + # Performance settings + batch_size: int = BATCH_SIZE + pool_size: int = DEFAULT_POOL_SIZE + max_pool_size: int = MAX_POOL_SIZE + connection_timeout: int = CONNECTION_TIMEOUT + query_timeout: int = QUERY_TIMEOUT + + # Security settings + ssl_enabled: bool = False + ssl_ca: Optional[str] = None + ssl_cert: Optional[str] = None + ssl_key: Optional[str] = None + encrypt_credentials: bool = True + + # Advanced options + enable_caching: bool = True + cache_ttl: int = CACHE_TTL + enable_rate_limiting: bool = True + rate_limit_calls: int = RATE_LIMIT_CALLS + rate_limit_period: int = RATE_LIMIT_PERIOD + enable_monitoring: bool = True + + # Field transformations + field_transformations: Dict[str, Callable] = field(default_factory=dict) + + # Validation rules + validation_rules: Dict[str, Callable] = field(default_factory=dict) 
+ + def validate(self): + """Validate configuration""" + if self.db_type not in [e.value for e in DatabaseType]: + raise ConnectorValidationError( + f"Unsupported database type: {self.db_type}" + ) + + if not self.vectorization_fields: + raise ConnectorValidationError( + "At least one vectorization field required" + ) + + if self.sync_mode == "incremental" and not self.timestamp_field: + raise ConnectorValidationError( + "timestamp_field required for incremental sync" + ) + + if self.sync_mode == "cdc" and not self.cdc_table: + raise ConnectorValidationError( + "cdc_table required for CDC sync" + ) + + if self.batch_size > MAX_BATCH_SIZE: + raise ConnectorValidationError( + f"batch_size cannot exceed {MAX_BATCH_SIZE}" + ) + + +@dataclass +class ConnectionMetrics: + """Connection pool metrics""" + total_connections: int = 0 + active_connections: int = 0 + idle_connections: int = 0 + failed_connections: int = 0 + total_queries: int = 0 + failed_queries: int = 0 + avg_query_time: float = 0.0 + cache_hits: int = 0 + cache_misses: int = 0 + rate_limit_hits: int = 0 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return asdict(self) + + +@dataclass +class QueryResult: + """Query execution result""" + rows: List[Dict[str, Any]] + row_count: int + execution_time: float + from_cache: bool = False + query_hash: Optional[str] = None + + +@dataclass +class SyncCheckpoint: + """Synchronization checkpoint""" + last_sync_time: datetime + last_timestamp: Optional[datetime] = None + last_primary_key: Optional[str] = None + rows_synced: int = 0 + errors: int = 0 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "last_sync_time": self.last_sync_time.isoformat(), + "last_timestamp": self.last_timestamp.isoformat() if self.last_timestamp else None, + "last_primary_key": self.last_primary_key, + "rows_synced": self.rows_synced, + "errors": self.errors + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 
'SyncCheckpoint': + """Create from dictionary""" + return cls( + last_sync_time=datetime.fromisoformat(data["last_sync_time"]), + last_timestamp=datetime.fromisoformat(data["last_timestamp"]) if data.get("last_timestamp") else None, + last_primary_key=data.get("last_primary_key"), + rows_synced=data.get("rows_synced", 0), + errors=data.get("errors", 0) + ) + + +# ============================================================================ +# Security and Encryption +# ============================================================================ + +class CredentialEncryption: + """Secure credential encryption using Fernet""" + + def __init__(self, master_key: Optional[bytes] = None): + """ + Initialize encryption. + + Args: + master_key: Master encryption key (generated if not provided) + """ + if master_key: + self.key = master_key + else: + # Generate key from system entropy + self.key = Fernet.generate_key() + + self.cipher = Fernet(self.key) + + def encrypt(self, data: str) -> str: + """Encrypt string data""" + encrypted = self.cipher.encrypt(data.encode()) + return base64.b64encode(encrypted).decode() + + def decrypt(self, encrypted_data: str) -> str: + """Decrypt string data""" + decoded = base64.b64decode(encrypted_data.encode()) + decrypted = self.cipher.decrypt(decoded) + return decrypted.decode() + + def encrypt_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any]: + """Encrypt credential dictionary""" + encrypted = {} + for key, value in credentials.items(): + if isinstance(value, str): + encrypted[key] = self.encrypt(value) + else: + encrypted[key] = value + return encrypted + + def decrypt_credentials(self, encrypted_credentials: Dict[str, Any]) -> Dict[str, Any]: + """Decrypt credential dictionary""" + decrypted = {} + for key, value in encrypted_credentials.items(): + if isinstance(value, str) and key in ["password", "api_key"]: + try: + decrypted[key] = self.decrypt(value) + except Exception: + decrypted[key] = value + else: + 
decrypted[key] = value + return decrypted + + +class SQLInjectionPrevention: + """SQL injection prevention utilities""" + + # Dangerous SQL patterns + DANGEROUS_PATTERNS = [ + r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+", + r"--", + r"/\*.*\*/", + r"UNION\s+SELECT", + r"OR\s+1\s*=\s*1", + r"AND\s+1\s*=\s*1", + r"EXEC\s*\(", + r"EXECUTE\s*\(", + ] + + @classmethod + def validate_query(cls, query: str) -> bool: + """ + Validate SQL query for injection attempts. + + Args: + query: SQL query string + + Returns: + True if safe, False if potentially dangerous + """ + query_upper = query.upper() + + for pattern in cls.DANGEROUS_PATTERNS: + if re.search(pattern, query_upper, re.IGNORECASE): + return False + + return True + + @classmethod + def sanitize_identifier(cls, identifier: str) -> str: + """ + Sanitize database identifier (table/column name). + + Args: + identifier: Database identifier + + Returns: + Sanitized identifier + """ + # Remove dangerous characters + sanitized = re.sub(r'[^\w_]', '', identifier) + return sanitized + + @classmethod + def escape_value(cls, value: Any) -> str: + """ + Escape value for SQL (basic escaping, prefer parameterized queries). 
+ + Args: + value: Value to escape + + Returns: + Escaped string + """ + if value is None: + return "NULL" + elif isinstance(value, (int, float)): + return str(value) + elif isinstance(value, bool): + return "TRUE" if value else "FALSE" + else: + # Escape single quotes + escaped = str(value).replace("'", "''") + return f"'{escaped}'" + + +# ============================================================================ +# Connection Pool +# ============================================================================ + +class ConnectionPool: + """Thread-safe database connection pool""" + + def __init__( + self, + db_type: str, + host: str, + port: int, + database: str, + credentials: Dict[str, Any], + pool_size: int = DEFAULT_POOL_SIZE, + max_pool_size: int = MAX_POOL_SIZE, + connection_timeout: int = CONNECTION_TIMEOUT, + ssl_config: Optional[Dict[str, Any]] = None + ): + """ + Initialize connection pool. + + Args: + db_type: Database type + host: Database host + port: Database port + database: Database name + credentials: Database credentials + pool_size: Initial pool size + max_pool_size: Maximum pool size + connection_timeout: Connection timeout in seconds + ssl_config: SSL configuration + """ + self.db_type = db_type + self.host = host + self.port = port + self.database = database + self.credentials = credentials + self.pool_size = pool_size + self.max_pool_size = max_pool_size + self.connection_timeout = connection_timeout + self.ssl_config = ssl_config or {} + + self.pool: Queue = Queue(maxsize=max_pool_size) + self.active_connections: Set = set() + self.lock = threading.Lock() + self.logger = logging.getLogger(__name__) + + # Initialize pool + self._initialize_pool() + + def _create_connection(self): + """Create a new database connection""" + try: + if self.db_type == DatabaseType.MYSQL.value: + import mysql.connector + conn = mysql.connector.connect( + host=self.host, + port=self.port, + database=self.database, + user=self.credentials["username"], + 
password=self.credentials["password"], + connect_timeout=self.connection_timeout, + ssl_disabled=not self.ssl_config.get("enabled", False), + ssl_ca=self.ssl_config.get("ca"), + ssl_cert=self.ssl_config.get("cert"), + ssl_key=self.ssl_config.get("key"), + pool_name=None, # Disable built-in pooling + pool_reset_session=True + ) + + elif self.db_type == DatabaseType.POSTGRESQL.value: + import psycopg2 + conn = psycopg2.connect( + host=self.host, + port=self.port, + database=self.database, + user=self.credentials["username"], + password=self.credentials["password"], + connect_timeout=self.connection_timeout, + sslmode="require" if self.ssl_config.get("enabled") else "prefer", + sslrootcert=self.ssl_config.get("ca"), + sslcert=self.ssl_config.get("cert"), + sslkey=self.ssl_config.get("key") + ) + + else: + raise ConnectorValidationError(f"Unsupported database type: {self.db_type}") + + self.logger.debug(f"Created new connection to {self.database}") + return conn + + except Exception as e: + self.logger.error(f"Failed to create connection: {e}") + raise + + def _initialize_pool(self): + """Initialize connection pool with initial connections""" + for _ in range(self.pool_size): + try: + conn = self._create_connection() + self.pool.put(conn) + except Exception as e: + self.logger.error(f"Failed to initialize pool connection: {e}") + + @contextmanager + def get_connection(self): + """ + Get connection from pool (context manager). 
+ + Yields: + Database connection + """ + conn = None + try: + # Try to get from pool + try: + conn = self.pool.get(timeout=self.connection_timeout) + except Empty: + # Pool exhausted, create new if under max + with self.lock: + if len(self.active_connections) < self.max_pool_size: + conn = self._create_connection() + else: + # Wait for connection + conn = self.pool.get(timeout=self.connection_timeout) + + # Test connection + if not self._test_connection(conn): + conn.close() + conn = self._create_connection() + + with self.lock: + self.active_connections.add(id(conn)) + + yield conn + + finally: + if conn: + with self.lock: + self.active_connections.discard(id(conn)) + + # Return to pool + try: + self.pool.put_nowait(conn) + except: + # Pool full, close connection + conn.close() + + def _test_connection(self, conn) -> bool: + """Test if connection is alive""" + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.close() + return True + except: + return False + + def close_all(self): + """Close all connections in pool""" + while not self.pool.empty(): + try: + conn = self.pool.get_nowait() + conn.close() + except: + pass + + self.logger.info("Connection pool closed") + + def get_stats(self) -> Dict[str, int]: + """Get pool statistics""" + with self.lock: + return { + "pool_size": self.pool.qsize(), + "active_connections": len(self.active_connections), + "max_pool_size": self.max_pool_size + } + + +# ============================================================================ +# Query Cache +# ============================================================================ + +class QueryCache: + """LRU cache for query results""" + + def __init__(self, max_size: int = MAX_CACHE_SIZE, ttl: int = CACHE_TTL): + """ + Initialize query cache. 
+ + Args: + max_size: Maximum cache entries + ttl: Time-to-live in seconds + """ + self.max_size = max_size + self.ttl = ttl + self.cache: Dict[str, Tuple[Any, float]] = {} + self.access_order: deque = deque() + self.lock = threading.Lock() + self.hits = 0 + self.misses = 0 + + def _hash_query(self, query: str, params: Optional[tuple] = None) -> str: + """Generate hash for query and parameters""" + key = f"{query}:{params}" + return hashlib.md5(key.encode()).hexdigest() + + def get(self, query: str, params: Optional[tuple] = None) -> Optional[Any]: + """Get cached result""" + with self.lock: + key = self._hash_query(query, params) + + if key in self.cache: + result, timestamp = self.cache[key] + + # Check TTL + if time.time() - timestamp < self.ttl: + # Update access order + self.access_order.remove(key) + self.access_order.append(key) + self.hits += 1 + return result + else: + # Expired + del self.cache[key] + self.access_order.remove(key) + + self.misses += 1 + return None + + def set(self, query: str, result: Any, params: Optional[tuple] = None): + """Cache query result""" + with self.lock: + key = self._hash_query(query, params) + + # Evict if full + if len(self.cache) >= self.max_size and key not in self.cache: + # Remove least recently used + lru_key = self.access_order.popleft() + del self.cache[lru_key] + + self.cache[key] = (result, time.time()) + + if key in self.access_order: + self.access_order.remove(key) + self.access_order.append(key) + + def clear(self): + """Clear cache""" + with self.lock: + self.cache.clear() + self.access_order.clear() + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics""" + with self.lock: + total = self.hits + self.misses + hit_rate = self.hits / total if total > 0 else 0.0 + + return { + "size": len(self.cache), + "max_size": self.max_size, + "hits": self.hits, + "misses": self.misses, + "hit_rate": hit_rate + } + + +# ============================================================================ +# Rate 
Limiter +# ============================================================================ + +class RateLimiter: + """Token bucket rate limiter""" + + def __init__(self, calls: int = RATE_LIMIT_CALLS, period: int = RATE_LIMIT_PERIOD): + """ + Initialize rate limiter. + + Args: + calls: Maximum calls per period + period: Time period in seconds + """ + self.calls = calls + self.period = period + self.tokens = calls + self.last_update = time.time() + self.lock = threading.Lock() + self.blocked_count = 0 + + def acquire(self) -> bool: + """ + Acquire token for API call. + + Returns: + True if allowed, False if rate limited + """ + with self.lock: + now = time.time() + elapsed = now - self.last_update + + # Refill tokens + self.tokens = min( + self.calls, + self.tokens + (elapsed * self.calls / self.period) + ) + self.last_update = now + + if self.tokens >= 1: + self.tokens -= 1 + return True + else: + self.blocked_count += 1 + return False + + def get_stats(self) -> Dict[str, Any]: + """Get rate limiter statistics""" + with self.lock: + return { + "calls_per_period": self.calls, + "period_seconds": self.period, + "current_tokens": self.tokens, + "blocked_count": self.blocked_count + } + + +# ============================================================================ +# Main Database Connector +# ============================================================================ + +class DatabaseConnector(LoadConnector, PollConnector, CredentialsConnector): + """ + Enterprise-grade database connector with advanced features. + + Features: + - Connection pooling + - Query caching + - Rate limiting + - Secure credential encryption + - SQL injection prevention + - Comprehensive monitoring + - Error handling and retry logic + - Batch processing + - Incremental sync + """ + + def __init__(self, config: DatabaseConfig): + """ + Initialize database connector. 
+ + Args: + config: Database configuration + """ + # Validate configuration + config.validate() + + self.config = config + self.logger = logging.getLogger(__name__) + + # Components + self.pool: Optional[ConnectionPool] = None + self.cache: Optional[QueryCache] = None + self.rate_limiter: Optional[RateLimiter] = None + self.encryption: Optional[CredentialEncryption] = None + + # State + self.state = ConnectionState.DISCONNECTED + self.credentials: Dict[str, Any] = {} + self.metrics = ConnectionMetrics() + self.checkpoint: Optional[SyncCheckpoint] = None + + # Initialize components + if config.enable_caching: + self.cache = QueryCache( + max_size=MAX_CACHE_SIZE, + ttl=config.cache_ttl + ) + + if config.enable_rate_limiting: + self.rate_limiter = RateLimiter( + calls=config.rate_limit_calls, + period=config.rate_limit_period + ) + + if config.encrypt_credentials: + self.encryption = CredentialEncryption() + + self.logger.info(f"Initialized {config.db_type} connector for {config.database}") + + # ======================================================================== + # Credential Management + # ======================================================================== + + def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: + """ + Load and validate credentials. 
+ + Args: + credentials: Credential dictionary + + Returns: + Validated credentials + """ + if "username" not in credentials or "password" not in credentials: + raise ConnectorMissingCredentialError( + "Credentials must include 'username' and 'password'" + ) + + # Encrypt if enabled + if self.encryption: + self.credentials = self.encryption.encrypt_credentials(credentials) + else: + self.credentials = credentials + + self.logger.info("Credentials loaded successfully") + return credentials + + def _get_decrypted_credentials(self) -> Dict[str, Any]: + """Get decrypted credentials""" + if self.encryption: + return self.encryption.decrypt_credentials(self.credentials) + return self.credentials + + # ======================================================================== + # Connection Management + # ======================================================================== + + def connect(self): + """Establish database connection pool""" + if self.state == ConnectionState.CONNECTED: + return + + if not self.credentials: + raise ConnectorMissingCredentialError("Credentials not loaded") + + try: + self.state = ConnectionState.CONNECTING + + # Get decrypted credentials + creds = self._get_decrypted_credentials() + + # SSL configuration + ssl_config = { + "enabled": self.config.ssl_enabled, + "ca": self.config.ssl_ca, + "cert": self.config.ssl_cert, + "key": self.config.ssl_key + } + + # Create connection pool + self.pool = ConnectionPool( + db_type=self.config.db_type, + host=self.config.host, + port=self.config.port, + database=self.config.database, + credentials=creds, + pool_size=self.config.pool_size, + max_pool_size=self.config.max_pool_size, + connection_timeout=self.config.connection_timeout, + ssl_config=ssl_config + ) + + self.state = ConnectionState.CONNECTED + self.logger.info("Database connection pool established") + + except Exception as e: + self.state = ConnectionState.ERROR + self.logger.error(f"Connection failed: {e}") + raise 
ConnectorValidationError(f"Failed to connect: {e}") + + def disconnect(self): + """Close all database connections""" + if self.pool: + self.pool.close_all() + self.pool = None + + self.state = ConnectionState.DISCONNECTED + self.logger.info("Disconnected from database") + + def validate_connector_settings(self) -> None: + """Validate connector settings by testing connection""" + try: + self.connect() + + # Test query + test_query = f"{self.config.sql_query} LIMIT 1" + + # Validate query safety + if not SQLInjectionPrevention.validate_query(test_query): + raise ConnectorValidationError("Query contains potentially dangerous SQL") + + # Execute test query + with self.pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(test_query) + result = cursor.fetchone() + cursor.close() + + if result: + self.logger.info("Connector validation successful") + + except Exception as e: + raise ConnectorValidationError(f"Validation failed: {e}") + + # ======================================================================== + # Query Execution + # ======================================================================== + + def _execute_query_with_retry( + self, + query: str, + params: Optional[tuple] = None, + max_retries: int = MAX_RETRIES + ) -> QueryResult: + """ + Execute query with retry logic. 
+ + Args: + query: SQL query + params: Query parameters + max_retries: Maximum retry attempts + + Returns: + QueryResult object + """ + # Check rate limit + if self.rate_limiter and not self.rate_limiter.acquire(): + self.metrics.rate_limit_hits += 1 + raise ConnectorValidationError("Rate limit exceeded") + + # Check cache + if self.cache: + cached = self.cache.get(query, params) + if cached: + self.metrics.cache_hits += 1 + return QueryResult( + rows=cached, + row_count=len(cached), + execution_time=0.0, + from_cache=True + ) + self.metrics.cache_misses += 1 + + # Execute with retry + last_error = None + for attempt in range(max_retries): + try: + start_time = time.time() + + with self.pool.get_connection() as conn: + cursor = conn.cursor() + + # Set query timeout + if self.config.db_type == DatabaseType.MYSQL.value: + cursor.execute(f"SET SESSION MAX_EXECUTION_TIME={self.config.query_timeout * 1000}") + elif self.config.db_type == DatabaseType.POSTGRESQL.value: + cursor.execute(f"SET statement_timeout = {self.config.query_timeout * 1000}") + + # Execute query + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + + # Get column names + if self.config.db_type == DatabaseType.MYSQL.value: + columns = [desc[0] for desc in cursor.description] + else: + columns = [desc.name for desc in cursor.description] + + # Fetch all results + rows = cursor.fetchall() + cursor.close() + + execution_time = time.time() - start_time + + # Convert to dictionaries + result_rows = [dict(zip(columns, row)) for row in rows] + + # Update metrics + self.metrics.total_queries += 1 + self.metrics.avg_query_time = ( + (self.metrics.avg_query_time * (self.metrics.total_queries - 1) + execution_time) + / self.metrics.total_queries + ) + + # Cache result + if self.cache and len(result_rows) < 1000: # Don't cache large results + self.cache.set(query, result_rows, params) + + return QueryResult( + rows=result_rows, + row_count=len(result_rows), + 
execution_time=execution_time, + from_cache=False + ) + + except Exception as e: + last_error = e + self.metrics.failed_queries += 1 + self.logger.warning(f"Query attempt {attempt + 1} failed: {e}") + + if attempt < max_retries - 1: + time.sleep(RETRY_DELAY * (attempt + 1)) + + raise ConnectorValidationError(f"Query failed after {max_retries} attempts: {last_error}") + + def _execute_query_batched( + self, + query: str, + params: Optional[tuple] = None + ) -> Generator[List[Dict[str, Any]], None, None]: + """ + Execute query and yield results in batches. + + Args: + query: SQL query + params: Query parameters + + Yields: + Batches of rows + """ + with self.pool.get_connection() as conn: + cursor = conn.cursor() + + try: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + + # Get column names + if self.config.db_type == DatabaseType.MYSQL.value: + columns = [desc[0] for desc in cursor.description] + else: + columns = [desc.name for desc in cursor.description] + + # Fetch in batches + while True: + rows = cursor.fetchmany(self.config.batch_size) + if not rows: + break + + batch = [dict(zip(columns, row)) for row in rows] + yield batch + + finally: + cursor.close() + + # ======================================================================== + # Document Conversion + # ======================================================================== + + def _transform_field(self, field_name: str, value: Any) -> Any: + """Apply field transformation if configured""" + if field_name in self.config.field_transformations: + transform_func = self.config.field_transformations[field_name] + return transform_func(value) + return value + + def _validate_field(self, field_name: str, value: Any) -> bool: + """Validate field value if rule configured""" + if field_name in self.config.validation_rules: + validation_func = self.config.validation_rules[field_name] + return validation_func(value) + return True + + def _row_to_document(self, row: Dict[str, Any]) -> 
Document: + """ + Convert database row to RAGFlow Document. + + Args: + row: Database row + + Returns: + Document object + """ + # Generate document ID + doc_id = str(row.get(self.config.primary_key_field, "")) + if not doc_id: + row_str = json.dumps(row, sort_keys=True, default=str) + doc_id = hashlib.md5(row_str.encode()).hexdigest() + + # Build content from vectorization fields + content_parts = [] + for field in self.config.vectorization_fields: + if field in row and row[field]: + value = row[field] + + # Apply transformation + value = self._transform_field(field, value) + + # Validate + if not self._validate_field(field, value): + self.logger.warning(f"Field {field} failed validation for row {doc_id}") + continue + + content_parts.append(f"{field}: {value}") + + content = "\n".join(content_parts) + + # Build metadata + metadata = {} + for field in self.config.metadata_fields: + if field in row: + value = row[field] + + # Convert datetime to ISO string + if isinstance(value, datetime): + value = value.isoformat() + + # Apply transformation + value = self._transform_field(field, value) + + metadata[field] = value + + # Add source metadata + metadata.update({ + "_source": "database", + "_db_type": self.config.db_type, + "_database": self.config.database, + "_table": self._extract_table_name(), + "_primary_key": doc_id, + "_sync_time": datetime.now().isoformat() + }) + + # Create document + doc = Document( + id=f"db_{self.config.db_type}_{self.config.database}_{doc_id}", + sections=[TextSection(text=content, link=None)], + source=f"{self.config.db_type}://{self.config.host}/{self.config.database}", + semantic_identifier=f"Row {doc_id}", + metadata=metadata + ) + + return doc + + def _extract_table_name(self) -> str: + """Extract table name from SQL query""" + # Simple regex to extract table name + match = re.search(r'FROM\s+(\w+)', self.config.sql_query, re.IGNORECASE) + if match: + return match.group(1) + return "unknown" + + # 
======================================================================== + # Data Loading + # ======================================================================== + + def load_from_state(self) -> Generator[list[Document], None, None]: + """ + Load all documents (batch mode). + + Yields: + Batches of Document objects + """ + self.connect() + + self.logger.info(f"Starting batch load from {self.config.database}") + + query = self.config.sql_query + total_rows = 0 + + try: + for batch in self._execute_query_batched(query): + documents = [self._row_to_document(row) for row in batch] + total_rows += len(documents) + + yield documents + + self.logger.info(f"Loaded {total_rows} rows so far") + + except Exception as e: + self.logger.error(f"Batch load failed: {e}") + raise + + self.logger.info(f"Batch load completed: {total_rows} total rows") + + def poll_source( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch + ) -> Generator[list[Document], None, None]: + """ + Poll for new/updated documents (incremental mode). 
+ + Args: + start: Start timestamp + end: End timestamp + + Yields: + Batches of Document objects + """ + self.connect() + + if self.config.sync_mode != SyncMode.INCREMENTAL.value: + self.logger.warning("poll_source called but sync_mode is not incremental") + return + + if not self.config.timestamp_field: + raise ConnectorValidationError("timestamp_field required for incremental sync") + + start_dt = datetime.fromtimestamp(start) + end_dt = datetime.fromtimestamp(end) + + self.logger.info(f"Polling for updates between {start_dt} and {end_dt}") + + # Build incremental query + timestamp_field = SQLInjectionPrevention.sanitize_identifier(self.config.timestamp_field) + + if "WHERE" in self.config.sql_query.upper(): + query = f"{self.config.sql_query} AND {timestamp_field} BETWEEN %s AND %s" + else: + query = f"{self.config.sql_query} WHERE {timestamp_field} BETWEEN %s AND %s" + + total_rows = 0 + + try: + for batch in self._execute_query_batched(query, (start_dt, end_dt)): + documents = [self._row_to_document(row) for row in batch] + total_rows += len(documents) + + yield documents + + self.logger.info(f"Polled {total_rows} updated rows") + + except Exception as e: + self.logger.error(f"Incremental poll failed: {e}") + raise + + # Update checkpoint + self.checkpoint = SyncCheckpoint( + last_sync_time=datetime.now(), + last_timestamp=end_dt, + rows_synced=total_rows + ) + + self.logger.info(f"Poll completed: {total_rows} rows") + + # ======================================================================== + # Monitoring and Metrics + # ======================================================================== + + def get_metrics(self) -> Dict[str, Any]: + """Get comprehensive metrics""" + metrics = self.metrics.to_dict() + + if self.pool: + metrics["connection_pool"] = self.pool.get_stats() + + if self.cache: + metrics["cache"] = self.cache.get_stats() + + if self.rate_limiter: + metrics["rate_limiter"] = self.rate_limiter.get_stats() + + if self.checkpoint: + 
metrics["checkpoint"] = self.checkpoint.to_dict() + + return metrics + + def health_check(self) -> Dict[str, Any]: + """Perform health check""" + health = { + "status": "unknown", + "connection_state": self.state.value, + "timestamp": datetime.now().isoformat() + } + + try: + if self.state != ConnectionState.CONNECTED: + health["status"] = "disconnected" + return health + + # Test query + result = self._execute_query_with_retry("SELECT 1", max_retries=1) + + if result.row_count > 0: + health["status"] = "healthy" + health["query_time_ms"] = result.execution_time * 1000 + else: + health["status"] = "unhealthy" + + except Exception as e: + health["status"] = "error" + health["error"] = str(e) + + return health + + # ======================================================================== + # Context Manager + # ======================================================================== + + def __enter__(self): + """Context manager entry""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit""" + self.disconnect() + + def close(self): + """Close connector""" + self.disconnect() + + +# ============================================================================ +# Factory Functions +# ============================================================================ + +def create_mysql_connector( + host: str, + port: int, + database: str, + username: str, + password: str, + sql_query: str, + vectorization_fields: List[str], + **kwargs +) -> DatabaseConnector: + """ + Create MySQL connector. 
+ + Args: + host: MySQL host + port: MySQL port + database: Database name + username: Username + password: Password + sql_query: SQL query + vectorization_fields: Fields to vectorize + **kwargs: Additional configuration + + Returns: + DatabaseConnector instance + """ + config = DatabaseConfig( + db_type=DatabaseType.MYSQL.value, + host=host, + port=port, + database=database, + sql_query=sql_query, + vectorization_fields=vectorization_fields, + **kwargs + ) + + connector = DatabaseConnector(config) + connector.load_credentials({"username": username, "password": password}) + + return connector + + +def create_postgresql_connector( + host: str, + port: int, + database: str, + username: str, + password: str, + sql_query: str, + vectorization_fields: List[str], + **kwargs +) -> DatabaseConnector: + """ + Create PostgreSQL connector. + + Args: + host: PostgreSQL host + port: PostgreSQL port + database: Database name + username: Username + password: Password + sql_query: SQL query + vectorization_fields: Fields to vectorize + **kwargs: Additional configuration + + Returns: + DatabaseConnector instance + """ + config = DatabaseConfig( + db_type=DatabaseType.POSTGRESQL.value, + host=host, + port=port, + database=database, + sql_query=sql_query, + vectorization_fields=vectorization_fields, + **kwargs + ) + + connector = DatabaseConnector(config) + connector.load_credentials({"username": username, "password": password}) + + return connector diff --git a/test/unit_test/data_source/test_database_connector.py b/test/unit_test/data_source/test_database_connector.py new file mode 100644 index 000000000..fd90044dc --- /dev/null +++ b/test/unit_test/data_source/test_database_connector.py @@ -0,0 +1,358 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Unit tests for Database Connector
"""

import pytest
from unittest.mock import Mock, MagicMock, patch
from datetime import datetime

from common.data_source.database_connector import (
    DatabaseConnector,
    DatabaseConfig,
    create_mysql_connector,
    create_postgresql_connector
)
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError
)


# NOTE(review): DatabaseConnector.__init__ takes a single DatabaseConfig, not
# loose keyword args, and credentials are supplied via load_credentials() --
# the tests below are written against that config-based API.

def _make_config(**overrides):
    """Build a DatabaseConfig with minimal valid defaults; overrides win."""
    params = dict(
        db_type="mysql",
        host="localhost",
        port=3306,
        database="test_db",
        sql_query="SELECT * FROM products",
        vectorization_fields=["name"],
        # Keep stored credentials readable so tests can assert on them.
        encrypt_credentials=False,
    )
    params.update(overrides)
    return DatabaseConfig(**params)


def _make_connector(**overrides):
    """Build a DatabaseConnector from a test config (no DB connection made)."""
    return DatabaseConnector(_make_config(**overrides))


class TestDatabaseConfig:
    """Test DatabaseConfig dataclass"""

    def test_default_config(self):
        """Test default configuration values"""
        config = _make_config(
            vectorization_fields=["name", "description"],
            metadata_fields=["id", "category"]
        )

        assert config.db_type == "mysql"
        assert config.sync_mode == "batch"
        assert config.batch_size == 1000
        assert config.ssl_enabled is False


class TestDatabaseConnector:
    """Test DatabaseConnector class"""

    def test_initialization(self):
        """Test connector initialization"""
        connector = _make_connector(
            vectorization_fields=["name", "description"],
            metadata_fields=["id", "price"]
        )

        assert connector.config.db_type == "mysql"
        assert connector.config.host == "localhost"
        assert connector.config.port == 3306
        assert connector.config.vectorization_fields == ["name", "description"]
        assert connector.config.metadata_fields == ["id", "price"]

    def test_invalid_db_type(self):
        """Test initialization with invalid database type"""
        with pytest.raises(ConnectorValidationError):
            _make_connector(db_type="oracle", port=1521)  # Not supported

    def test_missing_vectorization_fields(self):
        """Test initialization without vectorization fields"""
        with pytest.raises(ConnectorValidationError):
            _make_connector(vectorization_fields=[])  # Empty

    def test_incremental_without_timestamp(self):
        """Test incremental mode without timestamp field"""
        with pytest.raises(ConnectorValidationError):
            _make_connector(sync_mode="incremental", timestamp_field=None)

    def test_load_credentials(self):
        """Test loading credentials"""
        connector = _make_connector()

        credentials = {
            "username": "test_user",
            "password": "test_pass"
        }

        result = connector.load_credentials(credentials)

        assert result == credentials
        # encrypt_credentials=False, so credentials are stored verbatim.
        assert connector.credentials == credentials

    def test_load_credentials_missing(self):
        """Test loading incomplete credentials"""
        connector = _make_connector()

        with pytest.raises(ConnectorMissingCredentialError):
            connector.load_credentials({"username": "test"})  # Missing password

    def test_row_to_document(self):
        """Test converting database row to document"""
        connector = _make_connector(
            vectorization_fields=["name", "description"],
            metadata_fields=["id", "category"],
            primary_key_field="id"
        )

        row = {
            "id": 123,
            "name": "Test Product",
            "description": "A great product",
            "category": "Electronics",
            "price": 99.99
        }

        doc = connector._row_to_document(row)

        assert "Test Product" in doc.sections[0].text
        assert "A great product" in doc.sections[0].text
        assert doc.metadata["id"] == 123
        assert doc.metadata["category"] == "Electronics"
        assert doc.metadata["_source"] == "database"
        assert doc.metadata["_db_type"] == "mysql"

    def test_row_to_document_with_datetime(self):
        """Test converting row with datetime field"""
        connector = _make_connector(
            db_type="postgresql",
            port=5432,
            sql_query="SELECT * FROM events",
            vectorization_fields=["title"],
            metadata_fields=["created_at"]
        )

        row = {
            "id": 1,
            "title": "Event Title",
            "created_at": datetime(2024, 1, 1, 12, 0, 0)
        }

        doc = connector._row_to_document(row)

        # Datetime should be converted to ISO format string
        assert isinstance(doc.metadata["created_at"], str)
        assert "2024-01-01" in doc.metadata["created_at"]

    def test_context_manager(self):
        """Test context manager usage (connect is mocked -- no real DB)."""
        connector = _make_connector()

        with patch.object(connector, "connect") as mock_connect:
            with connector as conn:
                assert conn is connector
            mock_connect.assert_called_once()

        # After exiting the context, disconnect() has released the pool.
        assert connector.pool is None


class TestFactoryFunctions:
    """Test factory functions"""

    def test_create_mysql_connector(self):
        """Test MySQL connector factory"""
        connector = create_mysql_connector(
            host="localhost",
            port=3306,
            database="test_db",
            username="user",
            password="pass",
            sql_query="SELECT * FROM products",
            vectorization_fields=["name", "description"],
            encrypt_credentials=False
        )

        assert connector.config.db_type == "mysql"
        assert connector.credentials["username"] == "user"
        assert connector.credentials["password"] == "pass"

    def test_create_postgresql_connector(self):
        """Test PostgreSQL connector factory"""
        connector = create_postgresql_connector(
            host="localhost",
            port=5432,
            database="test_db",
            username="user",
            password="pass",
            sql_query="SELECT * FROM products",
            vectorization_fields=["name", "description"],
            encrypt_credentials=False
        )

        assert connector.config.db_type == "postgresql"
        assert connector.credentials["username"] == "user"

    def test_factory_with_optional_params(self):
        """Test factory with optional parameters"""
        connector = create_mysql_connector(
            host="localhost",
            port=3306,
            database="test_db",
            username="user",
            password="pass",
            sql_query="SELECT * FROM products",
            vectorization_fields=["name"],
            metadata_fields=["id", "category"],
            sync_mode="incremental",
            timestamp_field="updated_at",
            batch_size=500,
            ssl_enabled=True
        )

        assert connector.config.metadata_fields == ["id", "category"]
        assert connector.config.sync_mode == "incremental"
        assert connector.config.timestamp_field == "updated_at"
        assert connector.config.batch_size == 500
        assert connector.config.ssl_enabled is True


class TestDocumentConversion:
    """Test document conversion logic"""

    def test_multiple_vectorization_fields(self):
        """Test combining multiple fields for vectorization"""
        connector = _make_connector(
            vectorization_fields=["name", "description", "features"]
        )

        row = {
            "id": 1,
            "name": "Product A",
            "description": "Description A",
            "features": "Feature 1, Feature 2"
        }

        doc = connector._row_to_document(row)
        content = doc.sections[0].text

        assert "Product A" in content
        assert "Description A" in content
        assert "Feature 1" in content

    def test_missing_vectorization_field(self):
        """Test handling missing vectorization field"""
        connector = _make_connector(
            vectorization_fields=["name", "description"]
        )

        row = {
            "id": 1,
            "name": "Product A"
            # description is missing
        }

        doc = connector._row_to_document(row)

        # Should not crash, just skip missing field
        assert "Product A" in doc.sections[0].text

    def test_document_id_generation(self):
        """Test document ID generation"""
        connector = _make_connector(primary_key_field="product_id")

        row = {
            "product_id": "ABC123",
            "name": "Product"
        }

        doc = connector._row_to_document(row)

        assert "ABC123" in doc.id
        assert doc.metadata["_primary_key"] == "ABC123"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])