# preflight.py
"""
启动门禁:保证审计日志上传的所有前置条件都成立。
任一失败即阻止应用进入主流程,对用户只显示一句"应用启动失败,请联系 @柴进"。
详细错误脱敏后写入 logs/preflight_error.log。
"""
from __future__ import annotations

import logging
import re
import sys
import traceback
from datetime import datetime
from pathlib import Path
from typing import Tuple

import pymysql

from config_util import load_config_safe
from version import APP_VERSION


logger = logging.getLogger(__name__)


# Prefix of the error detail produced when the client version is too old.
# The main program uses it to dispatch to handle_version_too_old instead of
# handle_preflight_failure.
VERSION_ERROR_PREFIX = "VERSION_TOO_OLD::"

# Keys that must be present and non-empty inside config["db_config"].
REQUIRED_DB_FIELDS = (
    "host",
    "port",
    "user",
    "password",
    "database",
)

# Tables that preflight probes with a SELECT before the app may start.
REQUIRED_TABLES = (
    "nano_banana_user_use_log",
    "nano_banana_user_log",
    "nano_banana_app_config",
)

# Columns that must exist on nano_banana_user_use_log.
REQUIRED_USE_LOG_COLUMNS = (
    "user_name",
    "device_name",
    "prompt",
    "result_path",
    "status",
    "error_message",
    "model",
    "duration_ms",
    "finish_reason",
)

# Columns that must exist on nano_banana_user_log.
REQUIRED_LOGIN_LOG_COLUMNS = (
    "user_name",
    "local_ip",
    "public_ip",
    "device_name",
    "login_time",
)


def preflight_check(config_path: Path, audit_queue_path: Path) -> Tuple[bool, str, dict]:
    """
    Verify every startup precondition for audit-log upload.

    Checks run in order and stop at the first failure: the config file loads,
    ``db_config`` is complete, MySQL is reachable, the required tables and
    columns exist, the local version is not below the configured minimum, and
    the directory holding the local audit queue is writable.

    Args:
        config_path: Path passed to ``load_config_safe``.
        audit_queue_path: Path of the local audit queue; only its parent
            directory is created and probed for writability.

    Returns:
        ``(ok, error_detail, config)``:
        - ``ok`` is True when all checks passed; ``error_detail`` is then "".
        - On failure ``error_detail`` describes the failing check. It is NOT
          scrubbed (handle_preflight_failure scrubs before writing to disk).
          A version failure starts with ``VERSION_ERROR_PREFIX``.
        - ``config`` is whatever ``load_config_safe`` returned, or ``{}`` when
          the loader itself raised.
    """
    # 1. config.json must load; a crash in the loader is reported with its traceback.
    try:
        config, load_err = load_config_safe(config_path)
    except Exception:
        return False, f"config load crashed:\n{traceback.format_exc()}", {}

    if load_err:
        return False, f"config load error: {load_err}", config

    # 2. db_config must be a dict with every required field present and non-empty.
    db = config.get("db_config")
    if not db or not isinstance(db, dict):
        return False, "config.json 缺少 db_config 字段或格式错误", config

    missing = [k for k in REQUIRED_DB_FIELDS if not db.get(k)]
    if missing:
        return False, f"db_config 缺少字段: {missing}", config

    # 3. MySQL connection. The int() conversion of the port sits inside the
    #    try so a non-numeric port is reported as a connect failure.
    try:
        conn = pymysql.connect(
            host=db["host"],
            port=int(db["port"]),
            user=db["user"],
            password=db["password"],
            database=db["database"],
            connect_timeout=5,
            read_timeout=5,
            write_timeout=5,
            charset="utf8mb4",
        )
    except Exception as e:
        return False, f"MySQL connect 失败: {type(e).__name__}: {e}", config

    try:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            cur.fetchone()

            # 4. Every required table must be selectable.
            for table in REQUIRED_TABLES:
                try:
                    cur.execute(f"SELECT 1 FROM `{table}` LIMIT 1")
                    cur.fetchone()
                except Exception as e:
                    return False, f"审计表 {table} 不可用: {type(e).__name__}: {e}", config

            # 5. Required columns must exist on both log tables.
            ok, col_err = _check_columns(cur, db["database"], "nano_banana_user_use_log",
                                         REQUIRED_USE_LOG_COLUMNS)
            if not ok:
                return False, col_err, config
            ok, col_err = _check_columns(cur, db["database"], "nano_banana_user_log",
                                         REQUIRED_LOGIN_LOG_COLUMNS)
            if not ok:
                return False, col_err, config

            # 5.5. Version gate: local APP_VERSION >= min_client_version stored in MySQL.
            ok, ver_err = _check_version(cur)
            if not ok:
                return False, ver_err, config
    except Exception as e:
        # Queries outside the per-table probe (SELECT 1, the INFORMATION_SCHEMA
        # lookups) can still fail, e.g. on a timeout or dropped connection.
        # Report that as a normal preflight failure instead of letting the
        # exception escape and break the documented tuple contract.
        return False, f"MySQL 查询失败: {type(e).__name__}: {e}", config
    finally:
        # Closing is best-effort; a failure here must not mask the result.
        try:
            conn.close()
        except Exception:
            pass

    # 6. The directory of the local queue must be writable: create it if
    #    needed, then write and remove a probe file.
    try:
        audit_queue_path.parent.mkdir(parents=True, exist_ok=True)
        probe = audit_queue_path.parent / ".preflight_probe"
        probe.write_text("ok", encoding="utf-8")
        probe.unlink()
    except Exception as e:
        return False, f"审计队列目录不可写 ({audit_queue_path.parent}): {e}", config

    return True, "", config


def _parse_version(v: str) -> Tuple[int, int, int]:
    """Parse a dotted version string into a ``(major, minor, patch)`` tuple.

    Surrounding whitespace is stripped, only the first three dot-separated
    components are used, and missing components default to 0. Raises
    ValueError when a used component is not an integer; the caller catches
    this and lets the check pass.
    """
    numbers = [int(piece) for piece in v.strip().split(".")[:3]]
    # Pad short versions such as "1" or "1.2" with zeros.
    numbers.extend([0] * (3 - len(numbers)))
    major, minor, patch = numbers
    return (major, minor, patch)


def _check_version(cur) -> Tuple[bool, str]:
    """
    Enforce the minimum client version stored in the database.

    Reads ``min_client_version`` and ``download_url`` from
    ``nano_banana_app_config`` through ``cur`` and compares the minimum with
    the local ``APP_VERSION``.

    Fail-open policy: if the config cannot be read, the minimum is missing, or
    a version string fails to parse, a WARNING is logged and the check passes.
    This keeps one accidentally deleted row from locking out every user.

    Args:
        cur: An open DB cursor; rows are read as ``(config_key, config_value)``.

    Returns:
        ``(True, "")`` when startup may continue, otherwise
        ``(False, "VERSION_TOO_OLD::<min_ver>|<url>")``.
    """
    try:
        cur.execute(
            "SELECT config_key, config_value FROM nano_banana_app_config "
            "WHERE config_key IN ('min_client_version', 'download_url')"
        )
        rows = {r[0]: r[1] for r in cur.fetchall()}
    except Exception as e:
        logger.warning("读取 app_config 失败, 放行: %s", e)
        return True, ""

    min_ver = rows.get("min_client_version")
    # A download_url row whose value is NULL yields None; normalise it to ""
    # so the detail string never carries the literal text "None" as a URL.
    url = rows.get("download_url") or ""
    if not min_ver:
        logger.warning("app_config 缺少 min_client_version 配置, 放行")
        return True, ""

    try:
        local_t = _parse_version(APP_VERSION)
        min_t = _parse_version(min_ver)
    except Exception as e:
        logger.warning(
            "版本号解析失败 local=%r min=%r: %s, 放行", APP_VERSION, min_ver, e
        )
        return True, ""

    if local_t < min_t:
        # Detail format: "VERSION_TOO_OLD::<min_ver>|<url>"
        return False, f"{VERSION_ERROR_PREFIX}{min_ver}|{url}"
    return True, ""


def is_version_error(detail: str) -> bool:
    """Return True when ``detail`` starts with the version-too-old prefix.

    The main program uses this to decide which ``handle_*`` function should
    deal with a failed preflight.
    """
    has_version_prefix = detail.startswith(VERSION_ERROR_PREFIX)
    return has_version_prefix


def _check_columns(cur, db_name: str, table: str, required: tuple[str, ...]) -> Tuple[bool, str]:
    """Check that ``table`` in schema ``db_name`` has every column in ``required``.

    Column names are read from INFORMATION_SCHEMA.COLUMNS through ``cur``.
    Returns ``(True, "")`` when nothing is missing, otherwise ``(False, msg)``
    where ``msg`` lists the missing columns and names the migration to run.
    """
    query = (
        "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
        "WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s"
    )
    cur.execute(query, (db_name, table))

    present = set()
    for record in cur.fetchall():
        present.add(record[0])

    absent = [name for name in required if name not in present]
    if not absent:
        return True, ""
    return False, f"表 {table} 缺少列: {absent}(请运行 migrations/2026-04-21_add_audit_log_columns.sql)"


def handle_preflight_failure(detail: str, logs_dir: Path) -> None:
    """
    Log the scrubbed failure detail, show a one-line error dialog, then exit.

    The detail is passed through ``_scrub`` and appended to
    ``<logs_dir>/preflight_error.log``. Any error while logging is ignored so
    the user still sees the dialog. The process always ends with
    ``sys.exit(1)``. A QApplication is expected to exist before this is
    called; one is created here only as a last resort.
    """
    from PySide6.QtWidgets import QMessageBox, QApplication

    user_message = "应用启动失败,请联系 @柴进"

    # Best-effort: append the scrubbed detail under a timestamp header.
    try:
        logs_dir.mkdir(parents=True, exist_ok=True)
        with (logs_dir / "preflight_error.log").open("a", encoding="utf-8") as log_file:
            log_file.write(f"\n===== {datetime.now().isoformat(timespec='seconds')} =====\n")
            log_file.write(_scrub(detail))
            log_file.write("\n")
    except Exception:
        pass

    # The user only ever sees a single generic sentence.
    try:
        if QApplication.instance() is None:
            # Extreme case (should not happen): preflight failed before any
            # QApplication was created. Keep a reference so it stays alive.
            app = QApplication(sys.argv)
        dialog = QMessageBox()
        dialog.setIcon(QMessageBox.Critical)
        dialog.setWindowTitle("启动失败")
        dialog.setText(user_message)
        dialog.setStandardButtons(QMessageBox.Ok)
        dialog.exec()
    except Exception:
        # Worst case: even the dialog cannot be shown.
        print(user_message, file=sys.stderr)

    sys.exit(1)


def handle_version_too_old(detail: str, logs_dir: Path) -> None:
    """
    Tell the user the client is too old, offer the download page, then exit.

    ``detail`` has the form ``"VERSION_TOO_OLD::<min>|<url>"``. Nothing is
    scrubbed here: the minimum version and download URL are public, so they
    are logged and shown as-is to give the clearest upgrade guidance. The
    process always ends with ``sys.exit(1)``.
    """
    from PySide6.QtWidgets import QMessageBox, QApplication
    from PySide6.QtCore import QUrl
    from PySide6.QtGui import QDesktopServices

    # Split the payload after the prefix into minimum version and URL.
    min_ver, _sep, url = detail[len(VERSION_ERROR_PREFIX):].partition("|")

    # Best-effort log entry (plain text; version and URL are public).
    try:
        logs_dir.mkdir(parents=True, exist_ok=True)
        with (logs_dir / "preflight_error.log").open("a", encoding="utf-8") as log_file:
            log_file.write(f"\n===== {datetime.now().isoformat(timespec='seconds')} =====\n")
            log_file.write(f"版本过旧: local={APP_VERSION}, required>={min_ver}, url={url}\n")
    except Exception:
        pass

    try:
        if QApplication.instance() is None:
            # Keep a reference so a freshly created application stays alive.
            app = QApplication(sys.argv)
        dialog = QMessageBox()
        dialog.setIcon(QMessageBox.Information)
        dialog.setWindowTitle("需要升级")
        dialog.setText(
            f"当前版本 {APP_VERSION} 已不再支持,请升级到 {min_ver} 或更高版本后继续使用。"
        )
        # The download button only exists when a URL is available.
        open_btn = dialog.addButton("打开下载页", QMessageBox.AcceptRole) if url else None
        quit_btn = dialog.addButton("退出", QMessageBox.RejectRole)
        dialog.setDefaultButton(quit_btn if open_btn is None else open_btn)
        dialog.exec()
        if open_btn is not None and dialog.clickedButton() is open_btn:
            QDesktopServices.openUrl(QUrl(url))
    except Exception:
        # Dialog could not be shown; fall back to stderr.
        print(
            f"版本过旧,请升级到 {min_ver},下载: {url}",
            file=sys.stderr,
        )

    sys.exit(1)


# (compiled pattern, replacement) pairs applied in order by _scrub.
_SCRUB_PATTERNS = [
    # Quoted key/value pairs. The key and the value may each use either quote
    # style, so both JSON text ("password": "x") and Python repr text
    # ('password': 'x', as seen in tracebacks) are covered. The value's own
    # quote character is kept around the mask.
    (re.compile(r"""((["'])password\2\s*:\s*)(["'])(?:(?!\3).)*\3""", re.DOTALL), r"\1\3***\3"),
    (re.compile(r"""((["'])api_key\2\s*:\s*)(["'])(?:(?!\3).)*\3""", re.DOTALL), r"\1\3***\3"),
    # key=value form: mask everything up to the next whitespace.
    (re.compile(r"(password\s*=\s*)\S+"), r"\1***"),
    (re.compile(r"(api_key\s*=\s*)\S+"), r"\1***"),
]


def _scrub(detail: str) -> str:
    """Return ``detail`` with password / api_key values replaced by ``***``.

    Each pattern in ``_SCRUB_PATTERNS`` is applied in turn; text that matches
    none of them is returned unchanged.
    """
    out = detail
    for pat, repl in _SCRUB_PATTERNS:
        out = pat.sub(repl, out)
    return out