From df8b4202f36df58fcef2d35b8ae873bd8465957f Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 21 Jul 2025 02:03:06 +0800 Subject: [PATCH] feat: Add SSL support for PostgreSQL database connections - Add SSL configuration options (ssl_mode, ssl_cert, ssl_key, ssl_root_cert, ssl_crl) - Support all PostgreSQL SSL modes (disable, allow, prefer, require, verify-ca, verify-full) - Add SSL context creation with certificate validation - Update initdb() method to handle SSL connection parameters - Add SSL environment variables to env.example - Maintain backward compatibility with existing non-SSL configurations --- env.example | 7 ++ lightrag/kg/postgres_impl.py | 136 ++++++++++++++++++++++++++++++++--- 2 files changed, 133 insertions(+), 10 deletions(-) diff --git a/env.example b/env.example index a7abaef9..42751abc 100644 --- a/env.example +++ b/env.example @@ -189,6 +189,13 @@ POSTGRES_DATABASE=your_database POSTGRES_MAX_CONNECTIONS=12 # POSTGRES_WORKSPACE=forced_workspace_name +### PostgreSQL SSL Configuration (Optional) +# POSTGRES_SSL_MODE=require +# POSTGRES_SSL_CERT=/path/to/client-cert.pem +# POSTGRES_SSL_KEY=/path/to/client-key.pem +# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem +# POSTGRES_SSL_CRL=/path/to/crl.pem + ### Neo4j Configuration NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io NEO4J_USERNAME=neo4j diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 842e1e54..9ac0f96f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from typing import Any, Union, final import numpy as np import configparser +import ssl from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge @@ -58,27 +59,121 @@ class PostgreSQLDB: self.increment = 1 self.pool: Pool | None = None + # SSL configuration + self.ssl_mode = config.get("ssl_mode") + self.ssl_cert = config.get("ssl_cert") + self.ssl_key = config.get("ssl_key") + self.ssl_root_cert = config.get("ssl_root_cert") + self.ssl_crl = config.get("ssl_crl") + if self.user is None or self.password is None or self.database is None: raise ValueError("Missing database user, password, or database") + def _create_ssl_context(self) -> ssl.SSLContext | None: + """Create SSL context based on configuration parameters.""" + if not self.ssl_mode: + return None + + ssl_mode = self.ssl_mode.lower() + + # For simple modes that don't require custom context + if ssl_mode in ["disable", "allow", "prefer", "require"]: + if ssl_mode == "disable": + return None + elif ssl_mode in ["require", "prefer"]: + # Return None for simple SSL requirement, handled in initdb + return None + + # For modes that require certificate verification + if ssl_mode in ["verify-ca", "verify-full"]: + try: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + + # Configure certificate verification + if ssl_mode == "verify-ca": + context.check_hostname = False + elif ssl_mode == "verify-full": + context.check_hostname = True + + # Load root certificate if provided + if self.ssl_root_cert: + if os.path.exists(self.ssl_root_cert): + context.load_verify_locations(cafile=self.ssl_root_cert) + logger.info( + f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}" + ) + else: + logger.warning( + f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}" + ) + + # Load client certificate and key if provided + if self.ssl_cert and self.ssl_key: + if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key): + context.load_cert_chain(self.ssl_cert, self.ssl_key) + logger.info( + f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}" + ) + else: + logger.warning( + "PostgreSQL, SSL client certificate or key file not found" + ) + + # Load certificate revocation list if provided + if self.ssl_crl: + if os.path.exists(self.ssl_crl): + context.load_verify_locations(crlfile=self.ssl_crl) + logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}") + else: + logger.warning( + f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}" + ) + + return context + + except Exception as e: + logger.error(f"PostgreSQL, Failed to create SSL context: {e}") + raise ValueError(f"SSL configuration error: {e}") + + # Unknown SSL mode + logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled") + return None + async def initdb(self): try: - self.pool = await asyncpg.create_pool( # type: ignore - user=self.user, - password=self.password, - database=self.database, - host=self.host, - port=self.port, - min_size=1, - max_size=self.max, - ) + # Prepare connection parameters + connection_params = { + "user": self.user, + "password": self.password, + "database": self.database, + "host": self.host, + "port": self.port, + "min_size": 1, + "max_size": self.max, + } + + # Add SSL configuration if provided + ssl_context = self._create_ssl_context() + if ssl_context is not None: + connection_params["ssl"] = ssl_context + logger.info("PostgreSQL, SSL configuration applied") + elif self.ssl_mode: + # Handle simple SSL modes without custom context + if self.ssl_mode.lower() in ["require", "prefer"]: + connection_params["ssl"] = True + elif self.ssl_mode.lower() == "disable": + connection_params["ssl"] = False + logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") + + self.pool = await asyncpg.create_pool(**connection_params) # type: ignore # Ensure VECTOR extension is available async with self.pool.acquire() as connection: await self.configure_vector_extension(connection) + ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL" logger.info( - f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}" + f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}" ) except Exception as e: logger.error( @@ -809,6 +904,27 @@ class ClientManager: "POSTGRES_MAX_CONNECTIONS", config.get("postgres", "max_connections", fallback=20), ), + # SSL configuration + "ssl_mode": os.environ.get( + "POSTGRES_SSL_MODE", + config.get("postgres", "ssl_mode", fallback=None), + ), + "ssl_cert": os.environ.get( + "POSTGRES_SSL_CERT", + config.get("postgres", "ssl_cert", fallback=None), + ), + "ssl_key": os.environ.get( + "POSTGRES_SSL_KEY", + config.get("postgres", "ssl_key", fallback=None), + ), + "ssl_root_cert": os.environ.get( + "POSTGRES_SSL_ROOT_CERT", + config.get("postgres", "ssl_root_cert", fallback=None), + ), + "ssl_crl": os.environ.get( + "POSTGRES_SSL_CRL", + config.get("postgres", "ssl_crl", fallback=None), + ), } @classmethod