File size: 6,809 Bytes
3060aa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# PyFundaments: A Secure Python Architecture
# Copyright 2008-2025 - Volkan Kücükbudak
# Apache License V. 2
# Repo: https://github.com/VolkanSah/PyFundaments
# fundaments/postgresql.py
import os
import logging
import asyncpg
import ssl
from urllib.parse import urlparse, urlencode, parse_qs, urlunparse
from typing import Optional

# Root logging is configured when this module is imported; per the logging
# docs, basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# Shared pool for this module: assigned/cleared by init_db_pool() and close_db_pool().
_db_pool: Optional[asyncpg.Pool] = None

def enforce_cloud_security(dsn_url: str) -> str:
    """
    Enforces security settings for cloud environments.

    - Ensures SSL mode is at least 'require'. The mode is always written back
      as a single, lower-case value, so repeated ``sslmode`` keys cannot leave
      an unchecked (e.g. ``disable``) duplicate in the URL for the driver.
    - Sets connect_timeout and keepalives_idle defaults when they are absent.
    - For Neon.tech hosts, drops 'options' values that mention
      statement_timeout (unsupported there) and logs only if one was dropped.

    Args:
        dsn_url: PostgreSQL connection URL (``postgresql://user:pass@host/db?...``).

    Returns:
        The DSN rebuilt with the adjusted query string.
    """
    parsed = urlparse(dsn_url)
    query_params = parse_qs(parsed.query)

    # Check the last occurrence (last-wins semantics) and replace the whole
    # list with one normalized value so no weaker duplicate survives.
    sslmode = query_params.get('sslmode', ['prefer'])[-1].lower()
    if sslmode not in ('require', 'verify-ca', 'verify-full'):
        sslmode = 'require'
    query_params['sslmode'] = [sslmode]

    # Set timeouts and keep-alives if not present
    query_params.setdefault('connect_timeout', ['5'])
    query_params.setdefault('keepalives_idle', ['60'])

    # Remove statement_timeout option for Neon. Match on the hostname only:
    # netloc also carries "user:password@", which must not trigger this.
    if 'neon.tech' in (parsed.hostname or '') and 'options' in query_params:
        original_options = query_params['options']
        kept_options = [opt for opt in original_options if 'statement_timeout' not in opt]
        if len(kept_options) < len(original_options):
            if kept_options:
                query_params['options'] = kept_options
            else:
                del query_params['options']
            logger.info("Removed unsupported 'statement_timeout' option for Neon.tech.")

    # TODO: Extend here for further providers...

    # Rebuild DSN
    new_query = urlencode(query_params, doseq=True)
    return urlunparse(parsed._replace(query=new_query))

def mask_dsn(dsn_url: str) -> str:
    """
    Masks username/password from DSN so they are not exposed in logs.

    The ``user:password@`` part of the authority is removed, and the values of
    credential query parameters (``user``, ``password``, ``sslpassword``) are
    replaced with ``***``. The host/port text is kept verbatim, so IPv6
    literals (``[::1]``), multi-host lists and malformed ports are preserved
    rather than rebuilt from ``urlparse().hostname``/``.port`` (``.port``
    raises ValueError on a non-numeric port).

    Args:
        dsn_url: A connection URL.

    Returns:
        The same URL without credentials, suitable for logging.
    """
    parsed = urlparse(dsn_url)
    # Everything before the last '@' is userinfo; keep only host[:port].
    safe_netloc = parsed.netloc.rpartition('@')[2]

    sensitive_keys = ('user', 'password', 'sslpassword')
    redacted_pairs = []
    for pair in (parsed.query.split('&') if parsed.query else ()):
        key, sep, _value = pair.partition('=')
        if sep and key.lower() in sensitive_keys:
            redacted_pairs.append(f"{key}=***")
        else:
            redacted_pairs.append(pair)

    return parsed._replace(netloc=safe_netloc, query='&'.join(redacted_pairs)).geturl()

async def ssl_runtime_check(conn: asyncpg.Connection, dsn: Optional[str] = None):
    """
    Performs a cloud-aware SSL runtime check on an active connection.

    Reads ``pg_stat_ssl`` for the current backend and raises if it does not
    report SSL as active. For Neon/Supabase DSNs any failure of the check is
    logged as a warning instead of raised.

    Args:
        conn: Open asyncpg connection to inspect.
        dsn: DSN used only to recognise Neon/Supabase. Defaults to
            DATABASE_URL, falling back to PG_DSN -- the same lookup order
            init_db_pool() uses, so a PG_DSN-only setup is recognised too.

    Raises:
        RuntimeError: SSL is reported inactive (non-cloud DSN only).
        Exception: Errors from the pg_stat_ssl query are re-raised unchanged
            for non-cloud DSNs.
    """
    if dsn is None:
        dsn = os.getenv("DATABASE_URL") or os.getenv("PG_DSN") or ""
    try:
        ssl_status = await conn.fetchval("""
            SELECT CASE WHEN ssl THEN 'active' ELSE 'INACTIVE' END
            FROM pg_stat_ssl WHERE pid = pg_backend_pid()
        """)
        # No row (None) is treated the same as an inactive connection.
        if ssl_status != 'active':
            logger.critical("CRITICAL ERROR: SSL connection is not active!")
            raise RuntimeError("SSL connection failed")
        logger.info("SSL connection is active.")
    except Exception as e:
        # Cloud: If pg_stat_ssl is not available, don't fail hard.
        # NOTE(review): the RuntimeError raised above for an inactive result is
        # caught here too, so on Neon/Supabase an 'INACTIVE' report is only a
        # warning. Presumably intended because these providers proxy client
        # connections and pg_stat_ssl then describes the proxy's leg -- confirm.
        if "neon.tech" in dsn or "supabase" in dsn:
            logger.warning("SSL check via pg_stat_ssl not possible (cloud restriction). Assuming SSL is active due to sslmode=require.")
        else:
            logger.critical(f"SSL runtime check failed: {e}")
            raise

async def init_db_pool(dsn_url: Optional[str] = None) -> Optional[asyncpg.Pool]:
    """
    Initializes the asynchronous database connection pool.

    Returns the already-initialized pool if there is one. Otherwise the DSN is
    taken from ``dsn_url`` or, if empty, from DATABASE_URL then PG_DSN; it is
    hardened with enforce_cloud_security(), a pool is created, and
    ssl_runtime_check() is run on one pooled connection.

    Args:
        dsn_url: Optional connection URL overriding the environment.

    Returns:
        The pool, or None if no DSN is configured or initialization failed
        (the application is allowed to keep running without a database).
    """
    global _db_pool
    if _db_pool:
        return _db_pool

    if not dsn_url:
        dsn_url = os.getenv("DATABASE_URL") or os.getenv("PG_DSN")
        if not dsn_url:
            logger.warning("No DATABASE_URL or PG_DSN found. Skipping DB pool initialization.")
            return None

    # Enforce cloud security and remove unsupported options
    secured_dsn = enforce_cloud_security(dsn_url)

    # ⚠ WARNING: This logs full credentials — keep only for secure DEV debugging
    logger.debug(f"[DEV ONLY] Full DSN used for DB connection: {secured_dsn}")

    # Always log a masked DSN for production safety
    logger.info(f"DSN used for DB connection (masked): {mask_dsn(secured_dsn)}")

    # For verify-full, hand asyncpg a default SSL context (certificate and
    # hostname verified against the system CA store).
    # NOTE(review): an explicit ssl= argument presumably takes precedence over
    # the DSN's ssl* parameters (e.g. sslrootcert) in asyncpg -- confirm.
    ssl_context = None
    if 'sslmode=verify-full' in secured_dsn:
        ssl_context = ssl.create_default_context()

    pool: Optional[asyncpg.Pool] = None
    try:
        logger.info("Initializing secure database pool...")
        pool = await asyncpg.create_pool(
            dsn=secured_dsn,
            min_size=1,
            max_size=10,
            timeout=5,
            command_timeout=30,
            ssl=ssl_context
        )
        _db_pool = pool
        # Post-init checks
        async with pool.acquire() as conn:
            await ssl_runtime_check(conn)
        logger.info("Secure database pool initialized.")
        return _db_pool
    except Exception as e:
        logger.critical(f"Pool initialization failed: {str(e)}")
        if pool is not None:
            # The pool already holds open connections; dropping the reference
            # alone would leak them. terminate() closes them immediately
            # instead of waiting for a graceful close().
            try:
                pool.terminate()
            except Exception as cleanup_error:
                logger.warning(f"Could not terminate pool after failed initialization: {cleanup_error}")
        _db_pool = None
        return None  # Fallback: allow app to run without DB

async def close_db_pool():
    """
    Gracefully closes the database connection pool.

    Does nothing when no pool is open. Otherwise awaits ``Pool.close()``,
    clears the module-level reference and logs the shutdown.
    """
    global _db_pool
    pool = _db_pool
    if not pool:
        return
    await pool.close()
    _db_pool = None
    logger.info("Database pool closed successfully.")

async def execute_secured_query(query: str, *params, fetch_method='fetch'):
    """
    Executes a parameterized query with integrated security checks.

    Values in ``params`` are passed to asyncpg separately from the SQL text
    (``$1``, ``$2``, ... placeholders), so they are never interpolated into
    ``query``.

    Args:
        query: SQL statement text.
        *params: Positional values bound to the query placeholders.
        fetch_method: Name of the asyncpg Connection method to run:
            'fetch', 'fetchrow', 'fetchval' or 'execute'.

    Returns:
        Whatever the selected Connection method returns.

    Raises:
        RuntimeError: The pool has not been initialized.
        ValueError: ``fetch_method`` is not supported. This is checked before a
            pooled connection is acquired.
        asyncpg.PostgresError: Re-raised after logging. In production only the
            SQLSTATE code is logged, not the error text. For SQLSTATE 08006 with
            a Neon.tech DATABASE_URL the pool is recreated before re-raising.
    """
    allowed_methods = ('fetch', 'fetchrow', 'fetchval', 'execute')
    if not _db_pool:
        raise RuntimeError("Database pool not initialized")
    # Validate up front so a bad argument never ties up a pooled connection.
    if fetch_method not in allowed_methods:
        raise ValueError("Invalid fetch_method")

    try:
        async with _db_pool.acquire() as conn:
            # Dynamic lookup is safe: fetch_method was checked against the allow-list.
            run = getattr(conn, fetch_method)
            return await run(query, *params)
    except asyncpg.PostgresError as e:
        sqlstate = getattr(e, 'sqlstate', None)
        error_type = "Security violation" if sqlstate == '42501' else "Database error"

        if os.getenv('APP_ENV') == 'production':
            logger.error(f"{error_type} [Code: {getattr(e, 'sqlstate', '?')}]")
        else:
            logger.error(f"{error_type}: {e}")

        # Neon: Reconnect if connection terminated (optional)
        if sqlstate == '08006' and 'neon.tech' in (os.getenv("DATABASE_URL") or ''):
            logger.warning("Neon.tech connection terminated. Restarting pool...")
            await close_db_pool()
            await init_db_pool(os.getenv("DATABASE_URL"))

        raise