#!/usr/bin/env python3
"""漏洞扫描系统 API 封装脚本 — 下发任务 → 等待完成 → 生成并下载报告"""

import argparse
import base64
import sys
import time
from pathlib import Path

import requests
import urllib3

# _request() sends every call with verify=False by default, so urllib3 would
# emit an InsecureRequestWarning per request; silence it once at import time.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

RETRIES: int = 3  # total attempts per request on connection errors / timeouts
RETRY_BACKOFF: int = 2  # seconds before the first retry; later waits grow (2s, 4s, ... here)


class VulnerabilityScanner:
    """Client for the vulnerability-scanner HTTP API.

    Provides the calls used by the CLI pipeline: create a scan task, poll its
    status, fetch results, request a report, poll report generation and
    download the report. Each API method adds ``access_token`` to its request.
    """

    def __init__(self, base_url: str, access_token: str, user_id: int = 1):
        """
        Args:
            base_url: API root such as ``https://host:port``. A trailing "/" is
                stripped because every endpoint path below starts with "/".
            access_token: token sent with every API call.
            user_id: user id sent with task-related calls (default 1).
        """
        self.base_url = base_url.rstrip("/")
        self.access_token = access_token
        self.user_id = user_id

    # ------------------------------------------------------------------ helpers

    def _request(self, method: str, path: str, **kwargs) -> requests.Response:
        """Send an HTTP request to ``base_url + path``, retrying transient errors.

        Connection errors and timeouts are retried, for up to ``RETRIES``
        attempts in total. The wait between attempts starts at
        ``RETRY_BACKOFF`` seconds and doubles after each retry.

        HTTP error statuses are not retried. The first 200 characters of the
        body are printed, and ``requests.HTTPError`` is raised.

        ``timeout`` defaults to 30 s and ``verify`` (TLS certificate checking)
        defaults to ``False``. Callers may override both through ``kwargs``.

        NOTE(review): a POST retried after a read timeout may already have
        been processed by the server (e.g. a duplicate scan task). Confirm
        whether the API tolerates that.

        Raises:
            requests.HTTPError: the server answered with an error status.
            requests.exceptions.ConnectionError / Timeout: every attempt failed.
        """
        url = f"{self.base_url}{path}"
        kwargs.setdefault("timeout", 30)
        kwargs.setdefault("verify", False)
        last_exc: Exception | None = None
        for attempt in range(1, RETRIES + 1):
            try:
                r = requests.request(method, url, **kwargs)
                if not r.ok:
                    print(f"[HTTP {r.status_code}] {r.text[:200]}")
                r.raise_for_status()  # HTTPError propagates: not retried
                return r
            # ConnectTimeout subclasses both ConnectionError and Timeout, so the
            # first handler catches it and reports it as a dropped connection.
            except requests.exceptions.ConnectionError as e:
                last_exc, reason = e, "连接断开"
            except requests.exceptions.Timeout as e:
                last_exc, reason = e, "请求超时"
            if attempt < RETRIES:
                # Exponential backoff: RETRY_BACKOFF, then 2x, 4x, ... seconds.
                wait = RETRY_BACKOFF * 2 ** (attempt - 1)
                print(f"  {reason}，{wait}s 后重试 ({attempt}/{RETRIES})...")
                time.sleep(wait)
        if last_exc is None:  # only reachable when RETRIES < 1
            raise ValueError(f"RETRIES must be >= 1, got {RETRIES}")
        raise last_exc

    def _post(self, path: str, data: dict | None = None, files: dict | None = None) -> dict:
        """POST form ``data`` (and optional ``files``) and return the decoded JSON body."""
        r = self._request("POST", path, data=data, files=files)
        return r.json()

    def _get(self, path: str, params: dict | None = None) -> dict:
        """GET with query ``params``; return the JSON body, or {} for HTTP 204."""
        r = self._request("GET", path, params=params)
        if r.status_code == 204:
            return {}
        return r.json()

    @staticmethod
    def _safe_filename_part(text: str) -> str:
        """Replace characters unsafe in a file name (e.g. "/" in a CIDR target,
        ":" in an IPv6 address) with "_". Alphanumerics and "-", "_", "." are kept."""
        return "".join(c if c.isalnum() or c in "-_." else "_" for c in text)

    # ------------------------------------------------------------------ API

    def resolve_user_id(self) -> int:
        """Return the configured user id (1 unless overridden). No server lookup is made."""
        return self.user_id

    def create_task(
        self,
        targets: str,
        port: str,
        *,
        need_poc: int = 1,
        need_cpe: int = 1,
        need_pw: int = 0,
        need_all_port: int = 0,
        scan_speed: int = -2,
    ) -> int:
        """Create a scan task and return its ``task_id``.

        Args:
            targets: scan target(s), sent verbatim as ``targets``.
            port: port specification, sent as ``custom_scan_port``.
            need_poc, need_cpe, need_pw, need_all_port: 0/1 switches for POC
                checks, CPE version checks, weak-password checks and
                all-port scanning.
            scan_speed: speed code, forwarded as ``scan_speed``.

        Raises:
            RuntimeError: the response contained no ``task_id``. Without this
                check, the caller would poll a non-existent task forever.
        """
        task_name = f"scan_{targets}_{port}_{int(time.time())}"
        # The API receives the task name base64-encoded (presumably so arbitrary
        # characters survive the form post; TODO confirm against the API docs).
        encoded = base64.b64encode(task_name.encode()).decode()

        data = {
            "access_token": self.access_token,
            "user_id": str(self.user_id),
            "task_name": encoded,
            "targets": targets,
            "need_poc": str(need_poc),
            "need_cpe": str(need_cpe),
            "need_pw": str(need_pw),
            "need_all_port": str(need_all_port),
            "custom_scan_port": str(port),
            "scan_speed": str(scan_speed),
        }
        result = self._post("/v2/", data)
        task_id = result.get("task_id")
        if task_id is None:
            raise RuntimeError(f"创建任务失败，响应中没有 task_id: {result}")
        print(f"任务创建成功 — task_id: {task_id}  ({task_name})")
        return task_id

    def get_task_status(self, task_id: int) -> dict:
        """Return the first entry of the response's ``task`` list.

        The entry contains ``status``, ``finish_rate`` and risk counters.
        Returns {} when the list is missing or empty.
        """
        resp = self._get(f"/v1/{task_id}", {"user_id": str(self.user_id), "access_token": self.access_token})
        tasks = resp.get("task", [])
        return tasks[0] if tasks else {}

    def wait_for_scan(self, task_id: int, interval: int = 10, *, timeout: float | None = None) -> dict:
        """Poll every *interval* seconds until the task reports ``status == 4``.

        Returns the final status dict. ``status == 3`` is treated as a failed
        scan and terminates the whole process via ``sys.exit``.

        Args:
            timeout: optional limit in seconds; ``None`` (default) waits forever.

        Raises:
            TimeoutError: *timeout* elapsed before the scan finished.
        """
        print("等待扫描完成...")
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            t = self.get_task_status(task_id)
            status = t.get("status")
            rate = t.get("finish_rate", 0)
            print(f"  进度: {rate}%  status={status}  "
                  f"高危={t.get('high_risk',0)} 中危={t.get('middle_risk',0)} 低危={t.get('low_risk',0)}")

            if status == 4:
                print("扫描完成")
                return t
            if status == 3:
                sys.exit("扫描异常，终止")
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"扫描任务 {task_id} 在 {timeout}s 内未完成")
            time.sleep(interval)

    def get_scan_result(self, task_id: int, host: str | None = None, *, num: int = 10) -> dict:
        """Fetch scan result details for the whole task, or for one *host* if given.

        *num* is forwarded as the ``num`` query parameter.
        """
        if host:
            path = f"/v2/task/{task_id}/{host}"
        else:
            path = f"/v2/task/{task_id}"
        return self._get(path, {
            "access_token": self.access_token,
            "user_id": str(self.user_id),
            "num": str(num),
        })

    def generate_report(self, task_id: int, file_type: int = 1, filename: str = "report") -> int:
        """Request a report for scan *task_id* and return the report task's id.

        Args:
            file_type: report format code (the CLI documents 1=PDF, 2=Word, 3=Excel).
            filename: report name sent to the server.

        Raises:
            RuntimeError: the response contained no ``task_id``.
        """
        result = self._post("/v1/report/1", data={
            "access_token": self.access_token,
            "task_id": str(task_id),
            "file_type": str(file_type),
            "filename": filename,
        })
        report_task_id = result.get("task_id")
        if report_task_id is None:
            raise RuntimeError(f"创建报表任务失败，响应中没有 task_id: {result}")
        print(f"报表生成任务已创建 — task_id: {report_task_id}")
        return report_task_id

    def wait_for_report(self, report_task_id: int, interval: int = 5, *, timeout: float | None = None) -> str:
        """Poll until the report is ready and return its server-side ``filename``.

        The report is considered ready when ``task_status == 1`` and a filename
        is present. An empty response (e.g. HTTP 204) means "not yet".

        Args:
            timeout: optional limit in seconds; ``None`` (default) waits forever.

        Raises:
            TimeoutError: *timeout* elapsed before the report was ready.
        """
        print("等待报表生成...")
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            result = self._get(
                f"/v1/report/1/{report_task_id}",
                {"access_token": self.access_token},
            )
            if result:
                pct = result.get("percentage", 0)
                ts = result.get("task_status", 0)
                filename = result.get("filename", "")
                print(f"  报表进度: {pct}%  status={ts}")

                if ts == 1 and filename:
                    print("报表生成完成")
                    return filename
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"报表任务 {report_task_id} 在 {timeout}s 内未完成")
            time.sleep(interval)

    def download_report(self, filename: str, output_path: str) -> Path:
        """Download the report *filename* and save it to *output_path*.

        Missing parent directories are created. The body is streamed into a
        sibling ``.part`` file, which is renamed onto *output_path* only after
        the transfer completes. An interrupted download therefore never
        leaves a truncated file under the final name.
        """
        # Build the query string by hand so requests does not re-encode ``f``.
        path = f"/v1/download?access_token={self.access_token}&f={filename}"
        out = Path(output_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        tmp = out.with_name(out.name + ".part")
        try:
            # ``with`` on the response releases the streamed connection even on error.
            with self._request("GET", path, stream=True, timeout=60) as r, open(tmp, "wb") as f:
                for chunk in r.iter_content(8192):
                    f.write(chunk)
            tmp.replace(out)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise
        print(f"报表已保存: {out}  ({out.stat().st_size / 1024:.0f} KB)")
        return out

    # --------------------------------------------------------------- pipeline

    def run(
        self,
        ip: str,
        port: str,
        *,
        output_dir: str = ".",
        file_type: int = 1,
        need_poc: int = 1,
        need_cpe: int = 1,
        need_pw: int = 0,
        need_all_port: int = 0,
        scan_speed: int = -2,
    ) -> Path:
        """Run the full pipeline: create task → wait for scan → generate report → download.

        The report is saved as ``<output_dir>/scan_<ip>_<port>_<timestamp>.zip``.
        Characters in *ip* and *port* that are unsafe in a file name are
        replaced with "_". Returns the path of the saved file.
        """
        print("=" * 52)
        print(f"漏洞扫描: {ip}:{port}")
        print(f"接口地址: {self.base_url}")
        print("=" * 52)

        # 1. Confirm the user
        uid = self.resolve_user_id()
        print(f"使用 user_id: {uid}\n")

        # 2. Create the scan task
        ts = int(time.time())
        scan_task_id = self.create_task(ip, port,
                                        need_poc=need_poc, need_cpe=need_cpe,
                                        need_pw=need_pw, need_all_port=need_all_port,
                                        scan_speed=scan_speed)

        # 3. Wait for the scan to finish
        self.wait_for_scan(scan_task_id)

        # 4. Generate the report (the server-side name keeps the raw target)
        report_name = f"scan_{ip}_{port}_{ts}"
        report_task_id = self.generate_report(scan_task_id, file_type=file_type, filename=report_name)

        # 5. Wait for the report to be generated
        filename = self.wait_for_report(report_task_id)

        # 6. Download. The local name is sanitised so a CIDR "/" or an IPv6 ":"
        #    cannot produce an invalid or nested path.
        safe_ip = self._safe_filename_part(ip)
        safe_port = self._safe_filename_part(port)
        output_path = Path(output_dir) / f"scan_{safe_ip}_{safe_port}_{ts}.zip"
        return self.download_report(filename, str(output_path))


# -------------------------------------------------------------------- CLI

def main():
    """Parse command-line arguments and run one scan pipeline per requested port.

    Each comma-separated entry of ``--port`` becomes its own scan task, report
    and downloaded file, processed sequentially. Blank entries are ignored.
    If no port remains, argparse reports a usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(
        description="漏洞扫描系统 CLI — 下发扫描任务并下载报告",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python vuln_scan.py -H 192.168.5.16 -p 8080 -u https://192.168.1.170:23000 -t xxxxxxxxxx
  python vuln_scan.py -H 192.168.5.16 -p 443  -u https://192.168.1.170:23000 -t xxxxxxxxxx -o ./reports
  python vuln_scan.py -H 192.168.5.16 -p 1-65535 --all-port -u ... -t ...
        """,
    )
    parser.add_argument("-H", "--host", required=True, help="目标 IP 或域名")
    parser.add_argument("-p", "--port", required=True, help="目标端口，多个逗号分隔，如 80,443")
    parser.add_argument("-u", "--base-url", required=True, help="接口地址，如 https://192.168.1.170:23000")
    parser.add_argument("-t", "--token", required=True, help="access_token")
    parser.add_argument("--user-id", type=int, default=1, help="用户 ID（默认: 1）")
    parser.add_argument("-o", "--output-dir", default=".", help="报告输出目录 (默认: 当前目录)")
    parser.add_argument("--all-port", action="store_true", help="全端口扫描 (默认仅扫描 -p 指定端口)")
    parser.add_argument("--need-pw", action="store_true", help="开启弱口令扫描 (默认关闭)")
    parser.add_argument("--no-poc", action="store_true", help="关闭 POC 漏洞扫描")
    parser.add_argument("--no-cpe", action="store_true", help="关闭 CPE 版本漏洞扫描")
    parser.add_argument("--file-type", type=int, default=1, choices=[1, 2, 3],
                        help="报表类型: 1=PDF, 2=Word, 3=Excel (默认: 1)")
    parser.add_argument("--speed", type=int, default=-2, choices=[-1, -2, -3],
                        help="扫描速度: -1=超快, -2=快速, -3=慢速 (默认: -2)")

    args = parser.parse_args()

    # Multi-port support (comma-separated). Blank entries, e.g. from "80,,443"
    # or a trailing comma, are dropped instead of launching a task with an
    # empty port.
    ports = [p.strip() for p in args.port.split(",") if p.strip()]
    if not ports:
        parser.error("-p/--port 未包含有效端口")

    scanner = VulnerabilityScanner(
        base_url=args.base_url,
        access_token=args.token,
        user_id=args.user_id,
    )

    for port in ports:
        scanner.run(
            ip=args.host,
            port=port,
            output_dir=args.output_dir,
            file_type=args.file_type,
            need_poc=0 if args.no_poc else 1,
            need_cpe=0 if args.no_cpe else 1,
            need_pw=1 if args.need_pw else 0,
            need_all_port=1 if args.all_port else 0,
            scan_speed=args.speed,
        )


if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 on success.
    sys.exit(main())
