"""
审计日志本地队列 + 后台上传 worker。

核心保证:事件一旦 log_use / log_login 返回,就已经 fsync 到本地 NDJSON 文件。
后台 worker 负责把本地队列异步上传到 MySQL;失败指数退避重试,成功后 compaction
重写队列文件删除已送达行。应用退出时 flush 一次尽量送达。

公开接口:
- init_audit_logger(db_config, queue_path, logs_dir): 启动单例
- get_audit_logger(): 获取单例(未初始化返回 None)
- AuditLogger.log_use(...)
- AuditLogger.log_login(...)
- AuditLogger.shutdown(timeout=5.0)
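
Typical usage (a sketch; the paths and DB settings here are illustrative,
not shipped defaults):

    audit = init_audit_logger(
        db_config={"host": "127.0.0.1", "port": 3306, "user": "app",
                   "password": "...", "database": "app"},
        queue_path=Path("data/audit_queue.ndjson"),
        logs_dir=Path("logs"),
    )
    audit.log_login("alice", "192.168.1.10", None, "alice-laptop")
    audit.shutdown(timeout=5.0)

The queue file holds one JSON object per line (NDJSON), e.g.:

    {"kind": "login_log", "ts": "2024-01-01T12:00:00", "user_name": "alice", ...}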
"""
from __future__ import annotations

import json
import logging
import os
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

import pymysql
from PySide6.QtCore import QThread


logger = logging.getLogger(__name__)

_instance: Optional["AuditLogger"] = None
_instance_lock = threading.Lock()


def init_audit_logger(db_config: dict, queue_path: Path, logs_dir: Path) -> "AuditLogger":
    """在 preflight 通过后调用;幂等。"""
    global _instance
    with _instance_lock:
        if _instance is None:
            _instance = AuditLogger(db_config, queue_path, logs_dir)
            _instance.start()
        return _instance


def get_audit_logger() -> Optional["AuditLogger"]:
    return _instance


class AuditLogger:
    """
    对外门面。只负责:
      1. 落盘(log_use / log_login, fsync 后返回)
      2. 拉起/关闭 worker
    真正上传逻辑在 _UploadWorker。
    """

    def __init__(self, db_config: dict, queue_path: Path, logs_dir: Path):
        self._db_config = db_config
        self._queue_path = Path(queue_path)
        self._logs_dir = Path(logs_dir)
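        # One lock shared by producers (_append) and the worker (snapshot
        # read / compaction): every access to the queue file is serialized.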
        self._file_lock = threading.Lock()
        self._worker = _UploadWorker(
            db_config=db_config,
            queue_path=self._queue_path,
            file_lock=self._file_lock,
        )

    def start(self) -> None:
        self._queue_path.parent.mkdir(parents=True, exist_ok=True)
        self._worker.start()

    def log_use(
        self,
        user_name: str,
        device_name: str,
        prompt: str,
        result_path: Optional[str],
        status: str,
        error_message: Optional[str],
        model: Optional[str],
        duration_ms: Optional[int],
        finish_reason: Optional[str],
    ) -> None:
        record = {
            "kind": "use_log",
            "ts": datetime.now().isoformat(timespec="seconds"),
            "user_name": user_name or "未知用户",
            "device_name": device_name or "未知设备",
            "prompt": prompt or "",
            "result_path": result_path,
            "status": status,
            "error_message": error_message,
            "model": model,
            "duration_ms": duration_ms,
            "finish_reason": finish_reason,
        }
        self._append(record)

    def log_login(
        self,
        user_name: str,
        local_ip: Optional[str],
        public_ip: Optional[str],
        device_name: Optional[str],
    ) -> None:
        record = {
            "kind": "login_log",
            "ts": datetime.now().isoformat(timespec="seconds"),
            "user_name": user_name,
            "local_ip": local_ip,
            "public_ip": public_ip,
            "device_name": device_name,
            "login_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        self._append(record)

    def shutdown(self, timeout: float = 5.0) -> None:
        """应用退出前调用,尽量 flush。"""
        self._worker.stop(timeout)

    def _append(self, record: dict) -> None:
        """落盘 + fsync。发生任何异常都不向上抛,但会落 error 日志(而不是 pass 吞掉)。"""
        try:
            line = json.dumps(record, ensure_ascii=False, default=str)
        except Exception as e:
            logger.error(f"审计事件序列化失败,已丢弃: {e}; record keys={list(record.keys())}")
            return

        try:
            with self._file_lock:
                with open(self._queue_path, "a", encoding="utf-8") as f:
                    f.write(line + "\n")
                    f.flush()
                    os.fsync(f.fileno())
            self._worker.wake()
        except Exception as e:
            # Failing to write even the local disk is a genuinely severe fault.
            # Degrade to the error log; do not re-raise.
            logger.error(f"Failed to persist audit event: {e}; event captured in the error log as fallback: {line[:200]}")


class _UploadWorker(QThread):
    """后台线程:循环 drain 队列文件 → 批量 INSERT → compaction。"""

    def __init__(self, db_config: dict, queue_path: Path, file_lock: threading.Lock):
        super().__init__()
        self._db_config = db_config
        self._queue_path = Path(queue_path)
        self._file_lock = file_lock
        self._stop_event = threading.Event()
        self._wake_event = threading.Event()
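        # Exponential backoff for failed drains: doubles from 1 s up to a cap
        # of 300 s (see run()); reset to 1 s after a fully successful drain.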
        self._backoff = 1.0

    # --- External control ---

    def wake(self) -> None:
        self._wake_event.set()

    def stop(self, timeout: float = 5.0) -> None:
        self._stop_event.set()
        self._wake_event.set()
        self.wait(int(timeout * 1000))

    # --- Main loop ---

    def run(self) -> None:
        logger.info("audit UploadWorker started")
        while not self._stop_event.is_set():
            try:
                sent, unsent = self._drain_once()
            except Exception as e:
                logger.error(f"audit drain 抛出未预期异常: {e}", exc_info=True)
                sent, unsent = 0, 1  # 当做失败处理

            if unsent > 0:
                self._backoff = min(self._backoff * 2, 300.0)
                logger.debug(f"audit: unsent={unsent}, backoff={self._backoff}s")
            else:
                self._backoff = 1.0

            # On stop, skip the sleep and leave the loop; one final drain runs below.
            if self._stop_event.is_set():
                break

            wait_s = self._backoff if unsent > 0 else 60.0
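            # Clearing after wait() cannot lose events: a wake() landing
            # between wait() returning and clear() is absorbed by the next
            # drain, which re-reads the whole queue file anyway.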
            self._wake_event.wait(wait_s)
            self._wake_event.clear()

        # One last drain before exiting
        try:
            self._drain_once()
        except Exception:
            pass
        logger.info("audit UploadWorker stopped")

    # --- Core drain ---

    def _drain_once(self) -> tuple[int, int]:
        """
        读快照 -> 批量上传 -> compaction。
        返回 (sent_count, unsent_count)。
        """
        # 1. Snapshot read
        with self._file_lock:
            if not self._queue_path.exists():
                return 0, 0
            eof_at_read = self._queue_path.stat().st_size
            if eof_at_read == 0:
                return 0, 0
            with open(self._queue_path, "rb") as f:
                head_bytes = f.read(eof_at_read)

        try:
            head_text = head_bytes.decode("utf-8")
        except UnicodeDecodeError as e:
            logger.error(f"audit 队列文件不是合法 UTF-8,跳过本轮: {e}")
            return 0, 1

        lines = [ln for ln in head_text.split("\n") if ln.strip()]
        if not lines:
            return 0, 0

        # 2. Connect to DB + batch INSERT
        try:
            conn = pymysql.connect(
                host=self._db_config["host"],
                port=int(self._db_config.get("port", 3306)),
                user=self._db_config["user"],
                password=self._db_config["password"],
                database=self._db_config["database"],
                connect_timeout=5,
                read_timeout=10,
                write_timeout=10,
                charset="utf8mb4",
            )
        except Exception as e:
            logger.warning(f"audit connect 失败,稍后重试: {e}")
            return 0, len(lines)

        sent = 0
        unsent_lines: list[str] = []
        try:
            with conn.cursor() as cursor:
                for i, line in enumerate(lines):
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError as e:
                        logger.error(f"audit 队列出现坏行,已跳过: {e}; line={line[:120]!r}")
                        # 不保留到 unsent(避免无限重试坏行)
                        continue

                    try:
                        self._insert_one(cursor, record)
                        sent += 1
                    except Exception as e:
                        logger.warning(
                            f"audit INSERT 失败(后续全部留队列): {type(e).__name__}: {e}"
                        )
                        unsent_lines = lines[i:]
                        break
            conn.commit()
        except Exception as e:
            logger.warning(f"audit commit 失败: {e}")
            unsent_lines = lines
            sent = 0
        finally:
            try:
                conn.close()
            except Exception:
                pass

        # 3. Compaction: rewrite queue file = unsent_lines + tail appended meanwhile
        with self._file_lock:
            try:
                # Tail appended after the snapshot read
                current_size = self._queue_path.stat().st_size
                tail = b""
                if current_size > eof_at_read:
                    with open(self._queue_path, "rb") as f:
                        f.seek(eof_at_read)
                        tail = f.read()

                with open(self._queue_path, "wb") as f:
                    for ln in unsent_lines:
                        f.write((ln + "\n").encode("utf-8"))
                    if tail:
                        f.write(tail)
                    f.flush()
                    os.fsync(f.fileno())
            except Exception as e:
                # Compaction failure is not fatal: nothing is lost, and the
                # already-sent lines are simply re-sent on the next drain. Note
                # this is at-least-once delivery: re-sent rows get fresh
                # auto-increment ids, so duplicate audit rows are possible.
                logger.error(f"audit compaction failed: {e}", exc_info=True)

        if sent > 0:
            logger.info(f"audit drained: sent={sent}, unsent={len(unsent_lines)}")
        return sent, len(unsent_lines)

    # --- Row insertion ---

    def _insert_one(self, cursor, record: dict) -> None:
        kind = record.get("kind")
        if kind == "use_log":
            sql = """
                INSERT INTO `nano_banana_user_use_log`
                (`user_name`, `device_name`, `prompt`, `result_path`, `status`,
                 `error_message`, `model`, `duration_ms`, `finish_reason`)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
            """
            cursor.execute(
                sql,
                (
                    record.get("user_name", "未知用户"),
                    record.get("device_name", "未知设备"),
                    record.get("prompt", ""),
                    record.get("result_path"),
                    record.get("status", "unknown"),
                    record.get("error_message"),
                    record.get("model"),
                    record.get("duration_ms"),
                    record.get("finish_reason"),
                ),
            )
        elif kind == "login_log":
            sql = """
                INSERT INTO `nano_banana_user_log`
                (`user_name`, `local_ip`, `public_ip`, `device_name`, `login_time`)
                VALUES (%s, %s, %s, %s, %s)
            """
            login_time_val = record.get("login_time") or record.get("ts")
            cursor.execute(
                sql,
                (
                    record.get("user_name"),
                    record.get("local_ip"),
                    record.get("public_ip"),
                    record.get("device_name"),
                    login_time_val,
                ),
            )
        else:
            raise ValueError(f"未知审计事件 kind={kind!r}")