feat: Add SSL support for PostgreSQL database connections
Browse files- Add SSL configuration options (ssl_mode, ssl_cert, ssl_key, ssl_root_cert, ssl_crl)
- Support all PostgreSQL SSL modes (disable, allow, prefer, require, verify-ca, verify-full)
- Add SSL context creation with certificate validation
- Update initdb() method to handle SSL connection parameters
- Add SSL environment variables to env.example
- Maintain backward compatibility with existing non-SSL configurations
- env.example +7 -0
- lightrag/kg/postgres_impl.py +126 -10
env.example
CHANGED
|
@@ -189,6 +189,13 @@ POSTGRES_DATABASE=your_database
|
|
| 189 |
POSTGRES_MAX_CONNECTIONS=12
|
| 190 |
# POSTGRES_WORKSPACE=forced_workspace_name
|
| 191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
### Neo4j Configuration
|
| 193 |
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
| 194 |
NEO4J_USERNAME=neo4j
|
|
|
|
| 189 |
POSTGRES_MAX_CONNECTIONS=12
|
| 190 |
# POSTGRES_WORKSPACE=forced_workspace_name
|
| 191 |
|
| 192 |
+
### PostgreSQL SSL Configuration (Optional)
|
| 193 |
+
# POSTGRES_SSL_MODE=require
|
| 194 |
+
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
| 195 |
+
# POSTGRES_SSL_KEY=/path/to/client-key.pem
|
| 196 |
+
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
|
| 197 |
+
# POSTGRES_SSL_CRL=/path/to/crl.pem
|
| 198 |
+
|
| 199 |
### Neo4j Configuration
|
| 200 |
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
| 201 |
NEO4J_USERNAME=neo4j
|
lightrag/kg/postgres_impl.py
CHANGED
|
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
|
| 8 |
from typing import Any, Union, final
|
| 9 |
import numpy as np
|
| 10 |
import configparser
|
|
|
|
| 11 |
|
| 12 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
| 13 |
|
|
@@ -58,27 +59,121 @@ class PostgreSQLDB:
|
|
| 58 |
self.increment = 1
|
| 59 |
self.pool: Pool | None = None
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
if self.user is None or self.password is None or self.database is None:
|
| 62 |
raise ValueError("Missing database user, password, or database")
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
async def initdb(self):
|
| 65 |
try:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Ensure VECTOR extension is available
|
| 77 |
async with self.pool.acquire() as connection:
|
| 78 |
await self.configure_vector_extension(connection)
|
| 79 |
|
|
|
|
| 80 |
logger.info(
|
| 81 |
-
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
|
| 82 |
)
|
| 83 |
except Exception as e:
|
| 84 |
logger.error(
|
|
@@ -809,6 +904,27 @@ class ClientManager:
|
|
| 809 |
"POSTGRES_MAX_CONNECTIONS",
|
| 810 |
config.get("postgres", "max_connections", fallback=20),
|
| 811 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
}
|
| 813 |
|
| 814 |
@classmethod
|
|
|
|
| 8 |
from typing import Any, Union, final
|
| 9 |
import numpy as np
|
| 10 |
import configparser
|
| 11 |
+
import ssl
|
| 12 |
|
| 13 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
| 14 |
|
|
|
|
| 59 |
self.increment = 1
|
| 60 |
self.pool: Pool | None = None
|
| 61 |
|
| 62 |
+
# SSL configuration
|
| 63 |
+
self.ssl_mode = config.get("ssl_mode")
|
| 64 |
+
self.ssl_cert = config.get("ssl_cert")
|
| 65 |
+
self.ssl_key = config.get("ssl_key")
|
| 66 |
+
self.ssl_root_cert = config.get("ssl_root_cert")
|
| 67 |
+
self.ssl_crl = config.get("ssl_crl")
|
| 68 |
+
|
| 69 |
if self.user is None or self.password is None or self.database is None:
|
| 70 |
raise ValueError("Missing database user, password, or database")
|
| 71 |
|
| 72 |
+
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
| 73 |
+
"""Create SSL context based on configuration parameters."""
|
| 74 |
+
if not self.ssl_mode:
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
ssl_mode = self.ssl_mode.lower()
|
| 78 |
+
|
| 79 |
+
# For simple modes that don't require custom context
|
| 80 |
+
if ssl_mode in ["disable", "allow", "prefer", "require"]:
|
| 81 |
+
if ssl_mode == "disable":
|
| 82 |
+
return None
|
| 83 |
+
elif ssl_mode in ["require", "prefer"]:
|
| 84 |
+
# Return None for simple SSL requirement, handled in initdb
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
# For modes that require certificate verification
|
| 88 |
+
if ssl_mode in ["verify-ca", "verify-full"]:
|
| 89 |
+
try:
|
| 90 |
+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
| 91 |
+
|
| 92 |
+
# Configure certificate verification
|
| 93 |
+
if ssl_mode == "verify-ca":
|
| 94 |
+
context.check_hostname = False
|
| 95 |
+
elif ssl_mode == "verify-full":
|
| 96 |
+
context.check_hostname = True
|
| 97 |
+
|
| 98 |
+
# Load root certificate if provided
|
| 99 |
+
if self.ssl_root_cert:
|
| 100 |
+
if os.path.exists(self.ssl_root_cert):
|
| 101 |
+
context.load_verify_locations(cafile=self.ssl_root_cert)
|
| 102 |
+
logger.info(
|
| 103 |
+
f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
logger.warning(
|
| 107 |
+
f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Load client certificate and key if provided
|
| 111 |
+
if self.ssl_cert and self.ssl_key:
|
| 112 |
+
if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
|
| 113 |
+
context.load_cert_chain(self.ssl_cert, self.ssl_key)
|
| 114 |
+
logger.info(
|
| 115 |
+
f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
|
| 116 |
+
)
|
| 117 |
+
else:
|
| 118 |
+
logger.warning(
|
| 119 |
+
"PostgreSQL, SSL client certificate or key file not found"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Load certificate revocation list if provided
|
| 123 |
+
if self.ssl_crl:
|
| 124 |
+
if os.path.exists(self.ssl_crl):
|
| 125 |
+
context.load_verify_locations(crlfile=self.ssl_crl)
|
| 126 |
+
logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
|
| 127 |
+
else:
|
| 128 |
+
logger.warning(
|
| 129 |
+
f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return context
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
|
| 136 |
+
raise ValueError(f"SSL configuration error: {e}")
|
| 137 |
+
|
| 138 |
+
# Unknown SSL mode
|
| 139 |
+
logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
async def initdb(self):
|
| 143 |
try:
|
| 144 |
+
# Prepare connection parameters
|
| 145 |
+
connection_params = {
|
| 146 |
+
"user": self.user,
|
| 147 |
+
"password": self.password,
|
| 148 |
+
"database": self.database,
|
| 149 |
+
"host": self.host,
|
| 150 |
+
"port": self.port,
|
| 151 |
+
"min_size": 1,
|
| 152 |
+
"max_size": self.max,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
# Add SSL configuration if provided
|
| 156 |
+
ssl_context = self._create_ssl_context()
|
| 157 |
+
if ssl_context is not None:
|
| 158 |
+
connection_params["ssl"] = ssl_context
|
| 159 |
+
logger.info("PostgreSQL, SSL configuration applied")
|
| 160 |
+
elif self.ssl_mode:
|
| 161 |
+
# Handle simple SSL modes without custom context
|
| 162 |
+
if self.ssl_mode.lower() in ["require", "prefer"]:
|
| 163 |
+
connection_params["ssl"] = True
|
| 164 |
+
elif self.ssl_mode.lower() == "disable":
|
| 165 |
+
connection_params["ssl"] = False
|
| 166 |
+
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
| 167 |
+
|
| 168 |
+
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
| 169 |
|
| 170 |
# Ensure VECTOR extension is available
|
| 171 |
async with self.pool.acquire() as connection:
|
| 172 |
await self.configure_vector_extension(connection)
|
| 173 |
|
| 174 |
+
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
| 175 |
logger.info(
|
| 176 |
+
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
|
| 177 |
)
|
| 178 |
except Exception as e:
|
| 179 |
logger.error(
|
|
|
|
| 904 |
"POSTGRES_MAX_CONNECTIONS",
|
| 905 |
config.get("postgres", "max_connections", fallback=20),
|
| 906 |
),
|
| 907 |
+
# SSL configuration
|
| 908 |
+
"ssl_mode": os.environ.get(
|
| 909 |
+
"POSTGRES_SSL_MODE",
|
| 910 |
+
config.get("postgres", "ssl_mode", fallback=None),
|
| 911 |
+
),
|
| 912 |
+
"ssl_cert": os.environ.get(
|
| 913 |
+
"POSTGRES_SSL_CERT",
|
| 914 |
+
config.get("postgres", "ssl_cert", fallback=None),
|
| 915 |
+
),
|
| 916 |
+
"ssl_key": os.environ.get(
|
| 917 |
+
"POSTGRES_SSL_KEY",
|
| 918 |
+
config.get("postgres", "ssl_key", fallback=None),
|
| 919 |
+
),
|
| 920 |
+
"ssl_root_cert": os.environ.get(
|
| 921 |
+
"POSTGRES_SSL_ROOT_CERT",
|
| 922 |
+
config.get("postgres", "ssl_root_cert", fallback=None),
|
| 923 |
+
),
|
| 924 |
+
"ssl_crl": os.environ.get(
|
| 925 |
+
"POSTGRES_SSL_CRL",
|
| 926 |
+
config.get("postgres", "ssl_crl", fallback=None),
|
| 927 |
+
),
|
| 928 |
}
|
| 929 |
|
| 930 |
@classmethod
|