#!/usr/bin/env python3
"""
Phase 4: Data cleaning and target filtering for Vibe Pentest.

Reads raw Katana crawl results, filters out:
- URLs not belonging to the target domain
- External domains and IPs
- Static resources (except .js for API endpoint extraction)
- Duplicate URLs

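Outputs:
- targets.txt: Cleaned URL list (one URL per line)
- requests.json: Structured request list (method, url, path, query_string, params, fragment)
- api_patterns.json: .js URLs collected for later API endpoint analysis
  (written only when at least one .js URL survives cleaning)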
"""

import argparse
import ipaddress
import json
import re
import sys
from collections import defaultdict
from pathlib import Path
from urllib.parse import urlparse, parse_qs, urlunparse


# Lower-case file extensions that the static-resource filter drops.
# JavaScript ('.js') is deliberately absent so script URLs survive cleaning.
EXCLUDE_EXTENSIONS = {
    # stylesheets and images
    '.css', '.svg', '.ico', '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.tiff',
    # fonts
    '.woff', '.woff2', '.ttf', '.otf', '.eot',
    # audio / video
    '.mp3', '.wav', '.mp4', '.webm', '.avi', '.mov',
    # archives
    '.zip', '.rar', '.7z', '.tar', '.gz',
    # office documents
    '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx',
}

# Switch for keeping .js files so API endpoints can later be mined from them.
KEEP_JS = True


def extract_domain(url: str) -> str:
    """Return the hostname of *url* as reported by ``urlparse``.

    Yields an empty string when the URL has no host component or when
    parsing it raises.
    """
    try:
        host = urlparse(url).hostname
    except Exception:
        return ""
    return host if host else ""


def is_ip(host: str) -> bool:
    """Return True if *host* is a literal IPv4 or IPv6 address.

    Delegates to ``ipaddress.ip_address``, so out-of-range dotted quads
    (e.g. ``999.1.1.1``) and strings with trailing whitespace are rejected,
    and bracket-free IPv6 literals (the form ``urlparse().hostname`` returns,
    e.g. ``::1``) are recognised.
    """
    try:
        ipaddress.ip_address(host)
    except ValueError:
        return False
    return True


def is_static_resource(path: str, keep_js: bool = True) -> bool:
    """Return True if *path* names a static resource that should be filtered out.

    Any ``?query`` or ``#fragment`` suffix is ignored, and the extension is
    compared case-insensitively against EXCLUDE_EXTENSIONS.

    Args:
        path: URL path (a query string or fragment may be attached).
        keep_js: When true, ``.js`` files are kept (returns False); when
            false, they are treated as static and excluded (returns True).

    Returns:
        True if the path should be dropped, False otherwise. Paths without
        an extension are never considered static.
    """
    # Drop query string and fragment before looking at the extension.
    clean_path = path.split('?', 1)[0].split('#', 1)[0]
    ext = Path(clean_path).suffix.lower()

    if not ext:
        return False

    if ext == '.js':
        # '.js' is not in EXCLUDE_EXTENSIONS, so keep_js alone decides.
        return not keep_js

    return ext in EXCLUDE_EXTENSIONS


def _parse_raw_line(line: str):
    """Extract ``(url, method)`` from one non-empty raw crawl line, or return None.

    Lines beginning with ``http://`` or ``https://`` are taken verbatim as GET
    URLs. Lines beginning with ``{`` are decoded as Katana JSONL records and
    read from ``request.endpoint``; ``request.method`` is used (upper-cased)
    when it is a non-empty string, otherwise GET is assumed. Anything else
    (other prefixes, malformed JSON, records without a string endpoint)
    yields None.
    """
    if line.startswith(('http://', 'https://')):
        return line, "GET"
    if not line.startswith('{'):
        return None
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return None
    # A JSON text starting with '{' is an object, but its "request" member
    # may be missing or not an object; skip rather than crash on .get().
    request = data.get('request')
    if not isinstance(request, dict):
        return None
    endpoint = request.get('endpoint')
    if not isinstance(endpoint, str) or not endpoint:
        return None
    method = request.get('method')
    if isinstance(method, str) and method.strip():
        return endpoint, method.strip().upper()
    return endpoint, "GET"


def _host_in_scope(host: str, target_base: str, allow_subdomains: bool) -> bool:
    """Return True if *host* equals *target_base* or, optionally, is a subdomain of it."""
    if host == target_base:
        return True
    return allow_subdomains and host.endswith('.' + target_base)


def _dedupe_key(parsed) -> tuple:
    """Build the de-duplication identity for a ``urlparse`` result.

    Scheme and netloc are lowercased, and the path is lowercased with trailing
    slashes stripped (so ``/Admin/`` and ``/admin`` collapse). Query and
    fragment keep their original case, because parameter values are
    case-sensitive.
    """
    return (
        parsed.scheme.lower(),
        parsed.netloc.lower(),
        parsed.path.rstrip('/').lower(),
        parsed.query,
        parsed.fragment,
    )


def clean_urls(raw_file: str, target_domain: str, allow_subdomains: bool = True) -> tuple:
    """
    Clean raw crawl results.

    Reads *raw_file* (plain URL lines and/or Katana JSONL records). Keeps only
    URLs whose host is *target_domain*, or a subdomain of it when
    *allow_subdomains* is true. Drops IP-literal hosts unless the target is
    itself an IP, drops static resources, and de-duplicates.

    Args:
        raw_file: Path to the raw crawler output (UTF-8).
        target_domain: Target host. An http(s) scheme and ``:port`` are
            stripped, and matching is case-insensitive.
        allow_subdomains: Also accept hosts ending in ``.<target_domain>``.

    Returns:
        (targets, requests):
        - ``targets`` is the list of unique cleaned URLs in first-seen order.
        - ``requests`` is a list of structured request dicts (method, url,
          path, query_string, params, fragment), one per unique
          (method, URL) pair. It can be longer than ``targets`` when the same
          URL appears with different HTTP methods.

    Exits the process with status 1 if *raw_file* is not an existing file.
    """
    raw_path = Path(raw_file)
    if not raw_path.is_file():
        print(f"Error: {raw_file} not found", file=sys.stderr)
        sys.exit(1)

    with open(raw_path, 'r', encoding='utf-8') as f:
        raw_lines = [stripped for stripped in (line.strip() for line in f) if stripped]

    print(f"[*] 原始爬取结果: {len(raw_lines)} 条记录")

    # Parse lines — support both plain URL format and Katana JSONL format.
    entries = [entry for entry in map(_parse_raw_line, raw_lines) if entry is not None]

    print(f"[*] 提取到 URL: {len(entries)} 条")

    # Normalize the target: drop scheme and port, and lowercase it, because
    # urlparse().hostname (used for each crawled URL) is always lowercase.
    target_host = target_domain
    if target_host.startswith(('http://', 'https://')):
        target_host = urlparse(target_host).hostname or target_host
    target_base = target_host.split(':')[0].lower()

    seen_urls = set()
    seen_requests = set()
    cleaned_urls = []
    structured_requests = []

    for url, method in entries:
        try:
            parsed = urlparse(url)
        except ValueError:
            # e.g. an unbalanced '[' in the authority (malformed IPv6 literal)
            continue

        # parsed.port is deliberately not accessed: it raises ValueError for a
        # non-numeric or out-of-range port, and the port is not needed here.
        host = parsed.hostname or ""
        path = parsed.path

        if not _host_in_scope(host, target_base, allow_subdomains):
            continue

        # IP filter (unless target itself is IP)
        if is_ip(host) and not is_ip(target_base):
            continue

        if is_static_resource(path, keep_js=KEEP_JS):
            continue

        url_key = _dedupe_key(parsed)
        request_key = (method, url_key)
        if request_key in seen_requests:
            continue
        seen_requests.add(request_key)

        # targets.txt lists each URL once, whatever methods it was seen with.
        if url_key not in seen_urls:
            seen_urls.add(url_key)
            cleaned_urls.append(url)

        structured_requests.append({
            "method": method,
            "url": url,
            "path": path,
            "query_string": parsed.query,
            # keep_blank_values so parameters such as "?id=" are still reported
            "params": parse_qs(parsed.query, keep_blank_values=True),
            "fragment": parsed.fragment,
        })

    print(f"[+] 清洗后: {len(cleaned_urls)} 条 URL")
    print(f"[+] 结构化请求: {len(structured_requests)} 条")

    return cleaned_urls, structured_requests


def extract_api_patterns(js_urls: list) -> dict:
    """Collect JavaScript file URLs for later API-endpoint analysis.

    A URL counts as a JavaScript file when its parsed *path* ends in ``.js``
    (case-insensitive). Query strings (``app.js?v=3``) and fragments are
    therefore ignored, and URLs that merely mention ``.js`` in a query value
    (``/x?next=b.js``) are not mistaken for scripts. URLs that ``urlparse``
    rejects are skipped.

    NOTE: only the URLs are inspected; JS content is not fetched, so
    ``potential_endpoints`` is always returned empty by this function.

    Returns:
        ``{"api_paths": [JS URLs in input order], "potential_endpoints": []}``
    """
    patterns = {
        "api_paths": [],
        "potential_endpoints": [],
    }

    for url in js_urls:
        try:
            path = urlparse(url).path
        except ValueError:
            continue
        if path.lower().endswith('.js'):
            patterns["api_paths"].append(url)

    return patterns


def main():
    """Command-line entry point: clean a Katana crawl file and write the outputs.

    Writes three outputs:
    - the cleaned URL list (``--output``);
    - the structured requests as JSON (``--requests-output``);
    - when any cleaned URL has a ``.js`` path, a JS summary (``--api-output``,
      default ``api_patterns.json``).

    Exits via argparse with an error if ``--domain`` does not name a host.
    """
    parser = argparse.ArgumentParser(description="Clean and filter Katana crawl results")
    parser.add_argument("raw_file", help="Path to raw Katana output file")
    parser.add_argument("--domain", required=True, help="Target domain to filter on")
    parser.add_argument("--output", default="targets.txt", help="Output targets file")
    parser.add_argument("--requests-output", default="requests.json", help="Output structured requests file")
    parser.add_argument("--api-output", default="api_patterns.json",
                        help="Output file for the JS/API pattern summary")
    parser.add_argument("--no-subdomains", action="store_true", help="Do not allow subdomains")
    args = parser.parse_args()

    target_domain = args.domain.strip()
    # Normalize: remove scheme and port. Crawled hostnames come from
    # urlparse().hostname, which is always lowercase, so lowercase the target too.
    if target_domain.startswith(('http://', 'https://')):
        target_domain = urlparse(target_domain).hostname or target_domain
    target_domain = target_domain.split(':')[0].lower()
    if not target_domain:
        parser.error("--domain must name a host")

    cleaned_urls, structured_requests = clean_urls(
        args.raw_file,
        target_domain,
        allow_subdomains=not args.no_subdomains,
    )

    # Write targets.txt
    with open(args.output, 'w', encoding='utf-8') as f:
        f.writelines(url + '\n' for url in cleaned_urls)
    print(f"[+] 目标列表已保存至: {args.output}")

    # Write requests.json
    with open(args.requests_output, 'w', encoding='utf-8') as f:
        json.dump(structured_requests, f, indent=2, ensure_ascii=False)
    print(f"[+] 结构化请求已保存至: {args.requests_output}")

    # Extract JS files for API pattern analysis. Judge by the URL path only,
    # so "?next=x.js" is not mistaken for a script and "APP.JS" is not missed.
    js_urls = [u for u in cleaned_urls if urlparse(u).path.lower().endswith('.js')]
    if js_urls:
        api_patterns = extract_api_patterns(js_urls)
        with open(args.api_output, 'w', encoding='utf-8') as f:
            json.dump(api_patterns, f, indent=2, ensure_ascii=False)
        print(f"[+] 发现 {len(js_urls)} 个 JS 文件，已提取 API 模式")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
