#!/usr/bin/env python3
"""
Phase 3: Katana web crawler for Vibe Pentest.

Supports:
  - Anonymous crawl (no credentials)
  - Authenticated crawl (with cookies and/or an Authorization header, e.g. a Bearer token)
  - Multi-session crawl (one per test account)

Writes one JSON Lines file per crawl (crawled_<label>.jsonl) plus a
crawl_summary.json summary, for easy parsing by Phase 4.
"""

import argparse
import hashlib
import json
import os
import platform
import shutil
import subprocess
import sys
import tarfile
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path


# ── Katana base arguments ──────────────────────────────────────
# Flags passed to every Katana run; per-session -u / -H / -o are appended in run_katana().
KATANA_BASE_ARGS = [
    "-d", "5",           # crawl depth: 5 levels
    "-jc",               # JavaScript crawling
    "-kf", "all",        # crawl known files (robots.txt, sitemap.xml)
    "-aff",              # automatically fill forms
    "-hl",               # headless mode
    "-nos",              # --no-sandbox
    "-timeout", "15",    # request timeout: 15 seconds
    "-retry", "2",       # retry failed requests 2 times
    "-rl", "150",        # rate limit: 150 req/s
    "-j",                # JSONL output
]

# tools/ directory located at <this script's directory>/../tools.
TOOLS_DIR = Path(__file__).parent.parent / "tools"
# Manifest mapping platform keys (e.g. "linux-amd64") to Katana download info.
KATANA_DOWNLOADS_FILE = TOOLS_DIR / "katana_downloads.json"
# Downloaded archives and their per-platform extraction directories.
KATANA_CACHE_DIR = TOOLS_DIR / ".cache"
# Installed binaries, one sub-directory per platform key.
KATANA_BIN_DIR = TOOLS_DIR / "bin"


def current_platform_key() -> str:
    """Return the platform key used in katana_downloads.json, e.g. ``windows-amd64``.

    The OS name and CPU architecture reported by :mod:`platform` are normalized
    to the spellings used as manifest keys.

    Raises:
        RuntimeError: if the OS or architecture is not one Katana builds exist for here.
    """
    system = platform.system().lower()
    machine = platform.machine().lower()

    # Operating systems are already reported in the manifest spelling.
    os_name = system if system in ("windows", "linux", "darwin") else None

    # Architectures come in several spellings depending on OS and Python build.
    if machine in ("amd64", "x86_64"):
        arch_name = "amd64"
    elif machine in ("386", "i386", "i686", "x86"):
        arch_name = "386"
    elif machine in ("arm64", "aarch64"):
        arch_name = "arm64"
    else:
        arch_name = None

    if os_name is None or arch_name is None:
        raise RuntimeError(f"不支持的系统架构: {system}-{machine}")
    return f"{os_name}-{arch_name}"


def katana_binary_name(platform_key: str = None) -> str:
    """Return the Katana executable file name for *platform_key*.

    Windows builds are named ``katana.exe``; every other platform uses ``katana``.
    When *platform_key* is empty or None, the current platform is used.
    """
    key = platform_key if platform_key else current_platform_key()
    if key.startswith("windows-"):
        return "katana.exe"
    return "katana"


def is_executable_candidate(path: Path, platform_key: str = None) -> bool:
    """Return True if *path* is a Katana binary usable on *platform_key*.

    On Windows the file must be named ``katana.exe``. Elsewhere it must be named
    ``katana`` and be executable by the current user: a copy that lost its +x bit
    (e.g. checked out or unpacked without permissions) would otherwise be picked
    and then fail at launch with PermissionError.

    Args:
        path: candidate file.
        platform_key: e.g. ``"linux-amd64"``; defaults to the current platform.
    """
    platform_key = platform_key or current_platform_key()
    # is_file() is already False for missing paths and directories.
    if not path.is_file():
        return False
    if platform_key.startswith("windows-"):
        return path.name.lower() == "katana.exe"
    # Skipping a non-executable file lets find_katana fall back to PATH or to a
    # fresh install (which chmods the binary).
    return path.name == "katana" and os.access(path, os.X_OK)


def local_katana_candidates(platform_key: str = None) -> list[Path]:
    """Return local Katana paths to probe for *platform_key*, highest priority first.

    Order: the per-platform install directory under the script's tools/ dir, the
    same layout under ``./tools`` in the current working directory, then a loose
    binary directly inside each of those tools/ directories.
    """
    key = platform_key or current_platform_key()
    exe = katana_binary_name(key)
    cwd_tools = Path.cwd() / "tools"
    return [
        KATANA_BIN_DIR / key / exe,
        cwd_tools / "bin" / key / exe,
        TOOLS_DIR / exe,
        cwd_tools / exe,
    ]


def load_download_info(platform_key: str, manifest_file: Path = None) -> dict:
    """Return the Katana download entry for *platform_key* from the manifest.

    Accepted manifest layouts::

        {"katana": {"linux-amd64": {...}}}
        {"downloads": {"linux-amd64": {...}}}
        {"linux-amd64": {...}}

    A platform entry is either a dict with a non-empty ``url`` (other keys read by
    the installer: ``sha256``, ``archive``, ``filename``, ``binary``) or a bare URL
    string, which is normalized to ``{"url": ...}``.

    Args:
        platform_key: e.g. ``"linux-amd64"``.
        manifest_file: manifest to read; defaults to KATANA_DOWNLOADS_FILE.

    Raises:
        FileNotFoundError: if the manifest file does not exist.
        ValueError: if the manifest (or its downloads section) is not a JSON object.
        KeyError: if there is no usable entry with a url for *platform_key*.
    """
    manifest_file = KATANA_DOWNLOADS_FILE if manifest_file is None else Path(manifest_file)
    if not manifest_file.exists():
        raise FileNotFoundError(f"未找到下载清单: {manifest_file}")

    with open(manifest_file, "r", encoding="utf-8") as f:
        manifest = json.load(f)
    if not isinstance(manifest, dict):
        raise ValueError(f"{manifest_file} 顶层必须是 JSON 对象")

    downloads = manifest.get("katana") or manifest.get("downloads") or manifest
    if not isinstance(downloads, dict):
        raise ValueError(f"{manifest_file} 中的下载列表必须是 JSON 对象")

    info = downloads.get(platform_key)
    if isinstance(info, str):
        info = {"url": info}
    # Lists, numbers, null, or dicts without a url are all unusable entries.
    if not isinstance(info, dict) or not info.get("url"):
        raise KeyError(f"{manifest_file} 中缺少 {platform_key} 的 url")
    return info


def archive_type_for(path: Path, info: dict) -> str:
    """Return the archive format of *path*: ``"zip"`` or ``"tar.gz"``.

    An explicit, non-empty ``archive`` entry in *info* wins (lower-cased);
    otherwise the format is inferred from the file extension, case-insensitively.

    Raises:
        ValueError: if neither *info* nor the extension identifies the format.
    """
    declared = info.get("archive") or ""
    if declared:
        return declared.lower()

    lowered = path.name.lower()
    for suffixes, kind in (((".zip",), "zip"), ((".tar.gz", ".tgz"), "tar.gz")):
        if lowered.endswith(suffixes):
            return kind
    raise ValueError(f"无法识别压缩包格式: {path.name}")


def verify_sha256(path: Path, expected_sha256: str = None):
    """Check that the SHA-256 of *path* equals *expected_sha256* (case-insensitive).

    Does nothing (and does not open the file) when no expected digest is given.
    The file is hashed in 1 MiB chunks so large archives are never fully in memory.

    Raises:
        ValueError: on mismatch; the message carries the actual digest.
    """
    if expected_sha256:
        hasher = hashlib.sha256()
        with open(path, "rb") as stream:
            while chunk := stream.read(1 << 20):
                hasher.update(chunk)
        computed = hasher.hexdigest().lower()
        if computed != expected_sha256.lower():
            raise ValueError(f"Katana 下载文件 SHA256 不匹配: {computed}")


def download_file(url: str, dest: Path):
    """Download *url* to *dest*, creating parent directories as needed.

    The body is streamed into ``<dest>.part`` and atomically renamed onto *dest*
    only after the transfer completes. Callers treat an existing *dest* as a valid
    cache entry, so an interrupted download must never leave a truncated file there.

    Raises:
        urllib.error.URLError / OSError: on network or file-system failure; the
        partial ``.part`` file is removed in that case.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    print(f"  [*] 下载 Katana: {url}")
    tmp_path = dest.with_name(dest.name + ".part")
    try:
        with urllib.request.urlopen(url, timeout=60) as response, open(tmp_path, "wb") as f:
            shutil.copyfileobj(response, f)
        os.replace(tmp_path, dest)
    finally:
        # After a successful replace tmp_path no longer exists; on failure this
        # drops the partial download.
        tmp_path.unlink(missing_ok=True)


def _is_within_directory(base: Path, candidate: Path) -> bool:
    """Return True if *candidate*, once resolved, is *base* itself or lies beneath it.

    *base* must already be resolved. Path components are compared rather than
    string prefixes; a prefix test wrongly accepts siblings such as
    ``extract-linux-amd64-evil`` for base ``extract-linux-amd64``.
    """
    try:
        candidate.resolve().relative_to(base)
    except ValueError:
        return False
    return True


def extract_archive(archive_path: Path, extract_dir: Path, info: dict):
    """Extract *archive_path* into a freshly emptied *extract_dir*, rejecting unsafe members.

    Any existing *extract_dir* is deleted first. Before anything is extracted,
    every member path must resolve inside *extract_dir*. For tar archives, every
    symlink / hard-link target must as well.

    Raises:
        ValueError: on an unsafe member or an unknown/unsupported archive format.
    """
    if extract_dir.exists():
        shutil.rmtree(extract_dir)
    extract_dir.mkdir(parents=True, exist_ok=True)
    base = extract_dir.resolve()

    archive_type = archive_type_for(archive_path, info)
    if archive_type == "zip":
        with zipfile.ZipFile(archive_path) as zf:
            for member in zf.infolist():
                if not _is_within_directory(base, base / member.filename):
                    raise ValueError(f"压缩包包含非法路径: {member.filename}")
            zf.extractall(extract_dir)
    elif archive_type in {"tar.gz", "tgz"}:
        with tarfile.open(archive_path, "r:gz") as tf:
            for member in tf.getmembers():
                target = base / member.name
                if not _is_within_directory(base, target):
                    raise ValueError(f"压缩包包含非法路径: {member.name}")
                # Links must be checked too: a symlink pointing outside, followed by
                # a member written "through" it, escapes the extraction directory.
                if member.issym():
                    # Symlink targets are relative to the link's own directory.
                    link_target = target.parent / member.linkname
                elif member.islnk():
                    # Hard-link targets name another member, relative to the archive root.
                    link_target = base / member.linkname
                else:
                    link_target = None
                if link_target is not None and not _is_within_directory(base, link_target):
                    raise ValueError(f"压缩包包含非法路径: {member.name} -> {member.linkname}")
            if hasattr(tarfile, "data_filter"):
                # Extraction filters exist on Python 3.12+ and security backports;
                # the "data" filter re-enforces these rules inside tarfile itself.
                tf.extractall(extract_dir, filter="data")
            else:
                tf.extractall(extract_dir)
    else:
        raise ValueError(f"不支持的压缩包格式: {archive_type}")


def install_katana_from_manifest(platform_key: str) -> Path:
    """Download (or reuse a cached copy of) Katana for *platform_key* and install it.

    Steps:
      1. Read the manifest entry via load_download_info().
      2. Cache the archive under KATANA_CACHE_DIR.
      3. Verify the optional ``sha256``.
      4. Extract the archive.
      5. Copy the binary into KATANA_BIN_DIR/<platform_key>/.

    If a *cached* archive fails verification or extraction, it is treated as
    stale or corrupt: it is deleted, downloaded once more, and processed again.

    Returns:
        Path of the installed executable.

    Raises:
        FileNotFoundError: if the archive does not contain the expected binary.
        Errors from loading the manifest, downloading, verifying or extracting
        propagate (for a cached archive, those of the retry).
    """
    info = load_download_info(platform_key)
    url = info["url"]
    binary_name = info.get("binary") or katana_binary_name(platform_key)
    # Take the file name from the URL *path* only. A query string or fragment
    # (e.g. "...katana.zip?raw=true") would otherwise end up in the cache file name,
    # defeating extension-based format detection ('?' is also invalid on Windows).
    url_path = urllib.parse.urlsplit(url).path
    archive_name = info.get("filename") or url_path.rstrip("/").split("/")[-1]
    if not archive_name:
        archive_name = f"katana-{platform_key}.archive"

    archive_path = KATANA_CACHE_DIR / archive_name
    extract_dir = KATANA_CACHE_DIR / f"extract-{platform_key}"
    target_dir = KATANA_BIN_DIR / platform_key
    target_binary = target_dir / katana_binary_name(platform_key)
    expected_sha256 = info.get("sha256")

    from_cache = archive_path.exists()
    if from_cache:
        print(f"  [*] 使用 Katana 下载缓存: {archive_path}")
    else:
        download_file(url, archive_path)

    try:
        verify_sha256(archive_path, expected_sha256)
        extract_archive(archive_path, extract_dir, info)
    except Exception:
        # A fresh download that fails is a real error. A cached archive may simply
        # be stale or truncated; without this retry it would be reused (and fail)
        # on every future run until someone deletes it by hand.
        if not from_cache:
            raise
        print(f"  [!] Katana 下载缓存无效，重新下载: {archive_path}")
        archive_path.unlink()
        download_file(url, archive_path)
        verify_sha256(archive_path, expected_sha256)
        extract_archive(archive_path, extract_dir, info)

    extracted_binary = None
    for candidate in extract_dir.rglob(binary_name):
        if candidate.is_file():
            extracted_binary = candidate
            break
    if not extracted_binary:
        raise FileNotFoundError(f"压缩包中未找到 Katana 可执行文件: {binary_name}")

    target_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(extracted_binary, target_binary)
    if not platform_key.startswith("windows-"):
        # Ensure the binary is executable regardless of the mode stored in the archive.
        os.chmod(target_binary, target_binary.stat().st_mode | 0o755)

    print(f"  [+] Katana 已安装: {target_binary}")
    return target_binary


def find_katana() -> str:
    """Locate the Katana executable, installing it from the manifest if necessary.

    Lookup order:
      1. The local candidates from local_katana_candidates().
      2. ``katana`` on PATH.
      3. A download via install_katana_from_manifest().

    On any failure, diagnostics are written to stderr and the process exits
    with status 1.
    """
    try:
        key = current_platform_key()
        local = next(
            (p for p in local_katana_candidates(key) if is_executable_candidate(p, key)),
            None,
        )
        if local is not None:
            return str(local)

        # A katana found on PATH is the cross-platform fallback.
        on_path = shutil.which("katana")
        if on_path:
            return on_path

        return str(install_katana_from_manifest(key))
    except Exception as exc:
        err = sys.stderr
        print("[ERROR] 找不到或无法自动安装 katana 可执行文件", file=err)
        print(f"[INFO] 当前平台: {platform.system().lower()}-{platform.machine().lower()}", file=err)
        print("[INFO] 请确保 tools/katana_downloads.json 包含当前平台下载链接，或 katana 在 PATH 中", file=err)
        print(f"[INFO] 详细错误: {exc}", file=err)
        sys.exit(1)


def _redact_header_args(cmd: list) -> list:
    """Return a copy of *cmd* with the value after every ``-H`` flag masked.

    The header name is kept (``Cookie: ***``), so the logged command still shows
    which headers were sent without exposing session cookies or tokens.
    """
    redacted = []
    mask_next = False
    for arg in cmd:
        if mask_next:
            name, sep, _ = arg.partition(":")
            arg = f"{name}: ***" if sep else "***"
            mask_next = False
        elif arg == "-H":
            mask_next = True
        redacted.append(arg)
    return redacted


def _count_nonempty_lines(path: Path) -> int:
    """Return the number of non-blank lines in *path*, or 0 if it does not exist."""
    if not path.exists():
        return 0
    # errors="replace": a stray invalid byte must not turn a finished crawl into a failure.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        return sum(1 for line in f if line.strip())


def run_katana(target_url: str, output_dir: str,
               cookies: str = None, auth_header: str = None,
               label: str = "anonymous") -> tuple:
    """
    Run one Katana crawl and write its JSONL results to
    ``<output_dir>/crawled_<label>.jsonl``.

    Args:
        target_url: start URL passed to ``katana -u``.
        output_dir: directory for the results file (created if missing).
        cookies: optional Cookie header value, sent as ``-H "Cookie: ..."``.
        auth_header: optional full header line (e.g. ``"Authorization: Bearer ..."``),
            sent verbatim via ``-H``.
        label: session name used in log lines and in the output file name.

    Returns:
        (url_count, error_or_none):
          - url_count: the number of non-blank lines in the output file.
          - error_or_none: None on success, otherwise a short error string:
            ``"timeout"``, ``"exit code N"`` (non-zero exit with no results),
            or an exception message.
    """
    katana = find_katana()
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    output_file = output_path / f"crawled_{label}.jsonl"

    cmd = [katana, "-u", target_url] + KATANA_BASE_ARGS

    # Authenticated mode: inject the Cookie and Authorization headers.
    if cookies:
        cmd.extend(["-H", f"Cookie: {cookies}"])
        print(f"  [*] Cookie 认证: {cookies[:60]}...")
    if auth_header:
        cmd.extend(["-H", auth_header])
        print(f"  [*] Bearer 认证: {auth_header[:60]}...")

    cmd.extend(["-o", str(output_file)])

    print(f"  [*] 开始爬取 ({label}): {target_url}")
    # Header values are masked: printing the raw command would write the full
    # cookies / tokens to the console and any captured logs.
    print(f"  [*] 命令: {' '.join(_redact_header_args(cmd))}")

    try:
        # Drop results left over from an earlier run, so a crawl that fails before
        # writing anything is not credited with stale URLs.
        output_file.unlink(missing_ok=True)

        # Decode with encoding='utf-8', errors='replace' instead of text=True:
        # text=True decodes with the locale encoding (GBK on Chinese Windows), but
        # Katana emits UTF-8. Pages from non-UTF-8 sites (e.g. GBK-encoded Discuz)
        # can contain bytes that fail to decode and would raise UnicodeDecodeError.
        # errors='replace' substitutes U+FFFD for such bytes instead of crashing.
        result = subprocess.run(
            cmd,
            capture_output=True,
            encoding='utf-8',
            errors='replace',
            timeout=600,  # at most 10 minutes per crawl
        )

        if result.returncode != 0:
            stderr_preview = result.stderr[-500:] if result.stderr else ""
            print(f"  [!] Katana 退出码 {result.returncode}")
            if stderr_preview:
                print(f"  [!] stderr: {stderr_preview}")

        url_count = _count_nonempty_lines(output_file)
        print(f"  [+] 爬取完成 ({label}): {url_count} 个发现 → {output_file}")

        # A non-zero exit with no results means the crawl did not really run
        # (bad flags, missing browser, ...); report it instead of "0 URLs".
        if result.returncode != 0 and url_count == 0:
            return 0, f"exit code {result.returncode}"
        return url_count, None

    except subprocess.TimeoutExpired:
        print(f"  [!] 爬取超时 ({label})")
        return 0, "timeout"
    except Exception as e:
        print(f"  [!] 爬取失败 ({label}): {e}")
        return 0, str(e)


def load_credentials(sessions_dir: str) -> list:
    """
    Load crawl credentials from every ``*.json`` file in *sessions_dir*.

    Each file is expected to contain ``{"roles": [...]}``. For each role:
      - the ``auth_cookie_string`` is used as the cookie value;
      - the first ``auth_headers`` entry named ``Authorization`` (matched
        case-insensitively) supplies the auth header.

    Roles with neither are skipped, as are non-dict entries. Unreadable or
    malformed files are reported and skipped.

    Labels come from ``role_label``, falling back to the file stem. They become
    part of the output file name, so:
      - path separators and characters Windows forbids in file names are
        replaced with ``_``;
      - repeated labels, including ``anonymous`` (the label of the
        unauthenticated crawl), get a ``_2``, ``_3``, ... suffix so no session
        overwrites another's results.

    Returns:
        list of dicts: [{"label": "admin", "cookies": "...", "auth_header": "..."}, ...]
    """
    sessions_path = Path(sessions_dir)
    if not sessions_path.exists():
        print(f"  [*] 未找到 sessions 目录，将使用匿名爬取")
        return []

    unsafe_chars = str.maketrans({ch: "_" for ch in '/\\:*?"<>|'})
    # "anonymous" is the label main() uses for the unauthenticated crawl.
    used_labels = {"anonymous"}

    accounts = []
    for f in sorted(sessions_path.glob("*.json")):
        try:
            with open(f, 'r', encoding='utf-8') as fh:
                data = json.load(fh)
            roles = data.get("roles") or []
            for role in roles:
                if not isinstance(role, dict):
                    continue
                base_label = str(role.get("role_label") or f.stem).translate(unsafe_chars)
                cookies = role.get("auth_cookie_string") or ""
                auth_header = None
                for ah in role.get("auth_headers") or []:
                    # HTTP header names are case-insensitive; HTTP/2 captures use
                    # lowercase names such as "authorization".
                    if (isinstance(ah, dict)
                            and str(ah.get("name", "")).lower() == "authorization"
                            and "value" in ah):
                        auth_header = f'{ah["name"]}: {ah["value"]}'
                        break
                if not (cookies or auth_header):
                    continue

                label = base_label
                suffix = 2
                while label in used_labels:
                    label = f"{base_label}_{suffix}"
                    suffix += 1
                used_labels.add(label)

                accounts.append({
                    "label": label,
                    "cookies": cookies,
                    "auth_header": auth_header,
                })
                print(f"  [+] 加载凭证: {label} (cookies={bool(cookies)}, bearer={bool(auth_header)})")
        except (ValueError, OSError, AttributeError, TypeError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError; AttributeError /
            # TypeError cover files whose structure is not {"roles": [...]}.
            print(f"  [!] 读取失败 {f.name}: {e}")

    return accounts


def main():
    """CLI entry point: crawl anonymously, then once per loaded session, then summarize.

    The anonymous crawl always runs. Authenticated crawls run for every
    credential found by load_credentials() unless ``--anonymous-only`` is given.
    The per-session results are printed and saved to ``<output>/crawl_summary.json``.
    """
    parser = argparse.ArgumentParser(description="Vibe Pentest — Katana web crawler")
    parser.add_argument("--url", "-u", required=True, help="目标 URL")
    parser.add_argument("--output", "-o", default="output", help="输出目录")
    parser.add_argument("--sessions", "-s", help="sessions 凭证目录路径")
    parser.add_argument("--anonymous-only", action="store_true", help="仅匿名爬取，忽略 sessions")
    args = parser.parse_args()

    target = args.url
    out_dir = args.output
    print(f"[*] 目标: {target}")
    print(f"[*] 输出: {out_dir}")

    results = []

    def crawl(label, **auth):
        # Run one crawl and record its outcome in the summary list.
        count, err = run_katana(target, out_dir, label=label, **auth)
        results.append({"label": label, "count": count, "error": err})

    # The anonymous crawl always runs.
    print("\n--- 匿名爬取 ---")
    crawl("anonymous")

    # One authenticated crawl per loaded credential set.
    if not args.anonymous_only:
        for acc in load_credentials(args.sessions or "sessions"):
            print(f"\n--- 认证爬取: {acc['label']} ---")
            crawl(acc["label"], cookies=acc["cookies"], auth_header=acc["auth_header"])

    # Human-readable summary.
    rule = "=" * 50
    print(f"\n{rule}")
    print("  爬取汇总:")
    for entry in results:
        if entry["error"]:
            status = f"失败: {entry['error']}"
        else:
            status = f"{entry['count']} URLs"
        print(f"    [{entry['label']}] {status}")
    total = sum(entry["count"] for entry in results)
    print(f"  总计: {total} URLs")
    print(rule)

    # Machine-readable summary.
    summary = {
        "target": target,
        "sessions": results,
        "total_urls": total,
    }
    summary_file = Path(out_dir) / "crawl_summary.json"
    with open(summary_file, 'w', encoding='utf-8') as fh:
        json.dump(summary, fh, indent=2, ensure_ascii=False)
    print(f"\n  [+] 汇总已保存: {summary_file}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
