import requests
import yaml
import os
import warnings
import time
import urllib3
import logging
from urllib.parse import urlparse, urljoin
from colorama import init, Fore

# Initialise colorama so the Fore.* colour codes used in log output are handled.
init()
# NOTE(review): presumably a trick to switch on ANSI escape handling in the
# Windows console — confirm before removing.
os.system("")
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
# Disable urllib3's insecure-request warning (requests in this file are sent with verify=False)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Configure logging: INFO level, green HH:MM timestamp followed by the message
logging.basicConfig(level=logging.INFO, format=Fore.GREEN + '%(asctime)s' + Fore.RESET + ' - %(message)s', datefmt='%H:%M')

# Baseline HTTP headers; send_request_and_validate merges POC-specific headers
# on top of these and also uses them unchanged for the follow-up GET request.
DEFAULT_HEADERS = {
    'Accept': '*/*',
    'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6) AppleWebKit/603.3.8 (KHTML, like Gecko) Version/10.1.2 Safari/603.3.8',
    'Referer': 'https://www.baidu.com',
    'Accept-Encoding': 'gzip, deflate',
    'Connection': 'keep-alive',
}

def check_url_status(url):
    """Report whether *url* answers a GET request with an accepted status.

    The request uses a 5-second timeout and skips TLS certificate
    verification.

    Returns:
        True when the response status is 200, 301, 302 or 307; False for
        any other status or when ``requests`` raises a RequestException.
    """
    accepted_codes = (200, 301, 302, 307)
    try:
        status = requests.get(url, timeout=5, verify=False).status_code
    except requests.RequestException:
        # Connection problems are treated as "unreachable", not reported.
        return False
    return status in accepted_codes

def load_yaml_file(file_path):
    """Parse the YAML document stored at *file_path*.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``.

    Returns:
        The parsed document, or None when the file cannot be opened, read,
        decoded or parsed (the failure is logged). An empty file also
        yields None, because that is what ``yaml.safe_load`` returns.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as stream:
            return yaml.safe_load(stream)
    # OSError covers FileNotFoundError plus directory/permission errors;
    # UnicodeDecodeError covers POC files that are not valid UTF-8. Without
    # them a single unreadable file would abort the whole scan.
    except (yaml.YAMLError, OSError, UnicodeDecodeError) as e:
        logging.error(f"无法解析 {file_path}: {e}")
        return None

def send_request_and_validate(request_details, response_details, target_url, timeout=10):
    """Send the request described by a POC and judge the target's response.

    Args:
        request_details: Mapping with optional keys ``method`` (default
            ``GET``), ``path`` (default ``/default_path``), ``headers`` and
            ``body-raw``.
        response_details: Mapping with optional keys ``status-code``
            (default 200), ``body`` (substring expected in the response
            text) and ``path``. When ``path`` is set, a follow-up GET to
            that path is sent after the attack request and its response is
            the one that gets validated.
        target_url: Base URL that the POC paths are joined onto.
        timeout: Per-request timeout in seconds. ``requests`` waits
            indefinitely when no timeout is given, so one unresponsive
            target could otherwise stall the scan.

    Returns:
        A tuple ``(result, status_code, url, response_time)``. ``result`` is
        one of "存在漏洞", "不存在漏洞" or "请求失败". ``status_code`` and
        ``response_time`` are None when no response could be obtained.
    """
    method = request_details.get('method', 'GET').lower()
    path = request_details.get('path', '/default_path')
    path_2 = response_details.get('path', '')
    url = urljoin(target_url, path)
    # Optional verification URL, fetched after the attack request.
    url_2 = urljoin(target_url, path_2) if path_2 else None
    logging.info(f"已发送攻击请求: {url}")

    # POC headers override the defaults; entries with empty values are dropped.
    # `or {}` tolerates a POC whose headers key is present but null.
    merged_headers = {**DEFAULT_HEADERS, **(request_details.get('headers') or {})}
    headers = {key: value for key, value in merged_headers.items() if value}
    data = request_details.get('body-raw', '')

    logging.info(f"请求方法: {method.upper()}")
    logging.info(f"请求头: {headers}")

    # Bound up front so the final return is valid even when a request raises
    # before any response exists.
    status_code = None
    response_time = None

    try:
        # Record when the request sequence starts
        start_time = time.time()

        response = requests.request(method, url, headers=headers, data=data, verify=False, timeout=timeout)
        if url_2:
            response = requests.request("GET", url_2, headers=DEFAULT_HEADERS, verify=False, timeout=timeout)

        # Elapsed time covers both requests when a follow-up GET was sent
        response_time = round(time.time() - start_time, 2)
        status_code = response.status_code

        logging.info(f"响应时间: {response_time} 秒")

        expected_status_code = response_details.get('status-code', 200)
        logging.info(f"攻击请求响应码: {expected_status_code} ; 实际响应码: {status_code}")

        # Explicit comparison instead of `assert`, which is removed when
        # Python runs with -O and would let every status code pass.
        if status_code != expected_status_code:
            logging.error(f"错误: 验证获得失败响应码: {status_code}")
            result = "不存在漏洞"
        else:
            result = validate_response_body(response.text, response_details.get('body', ''))

    except requests.RequestException as e:
        logging.error(f"HTTP请求错误: {e}")
        result = "请求失败"
    except Exception as e:
        # Best-effort boundary: one misbehaving POC must not stop the scan.
        logging.error(f"请求过程中发送了错误: {e}")
        result = "请求失败"
    return result, status_code, url, response_time

def validate_response_body(response_body, expected_body):
    """Decide whether the response text contains the expected marker.

    An empty *expected_body* counts as a match, so POCs without a body
    check are judged on status code alone.

    Returns:
        "存在漏洞" when the marker is found (or none is required),
        otherwise "不存在漏洞".
    """
    matched = not expected_body or expected_body in response_body
    if matched:
        logging.info("响应体成功匹配!")
        return "存在漏洞"

    logging.warning("响应体匹配失败。")
    return "不存在漏洞"

def validate_main(target_url):
    """Run every POC listed in ``./MatchedPOC.txt`` against *target_url*.

    Each listed YAML file is loaded, its metadata is logged, and its request
    is sent and validated. Files that cannot be loaded, or whose top-level
    document is not a mapping, are skipped.

    Returns:
        A tuple ``(results, description)``. ``results`` is a list of
        ``(name, result, status_code, url, response_time)`` tuples, one per
        processed POC. ``description`` is the description of the last
        processed POC, or "Description not found" when none was processed
        or the POC has no description.
    """
    results = []
    # Bound before the loop so the return is valid when no POC is processed.
    description = "Description not found"
    poc_paths = load_yaml_file_paths('./MatchedPOC.txt')
    for poc_path in poc_paths:
        yaml_content = load_yaml_file(poc_path)
        # None means the file was unreadable or empty; any other non-mapping
        # document has no .get() and cannot describe a POC.
        if not isinstance(yaml_content, dict):
            continue

        name = print_extracted_info(yaml_content)
        # `or {}` tolerates sections that are present in the YAML but null.
        request_details = yaml_content.get('requests') or {}
        response_details = yaml_content.get('response') or {}
        result, status_code, url, res_time = send_request_and_validate(request_details, response_details, target_url)
        results.append((name, result, status_code, url, res_time))
        # Keep the description of the most recently processed POC
        description = yaml_content.get("description", "Description not found")
        print("--------------------------------------------------")

    return results, description

def load_yaml_file_paths(file_path):
    """Read the list of POC file paths stored one per line in *file_path*.

    Surrounding whitespace is stripped from every line and blank lines are
    dropped, so a trailing newline or an empty line does not turn into an
    empty path that later fails to open.

    Returns:
        The list of non-empty paths, or an empty list (after logging an
        error) when *file_path* does not exist.
    """
    try:
        with open(file_path, 'r') as file:
            stripped_lines = (line.strip() for line in file.read().splitlines())
            return [path for path in stripped_lines if path]
    except FileNotFoundError as e:
        logging.error(f"未找到对应 POC 路径文件: {file_path}: {e}")
        return []

def print_extracted_info(yaml_content):
    """Log the POC's keyword, name, description and impact.

    Missing fields are logged as empty strings.

    Returns:
        The POC's ``name`` value ('' when absent).
    """
    wanted_fields = ('keyword', 'name', 'description', 'impact')
    info = {field: yaml_content.get(field, '') for field in wanted_fields}

    logging.info(f"关键词: {info['keyword']}")
    logging.info(f"漏洞名称: {info['name']}")
    logging.info(f"描述: {info['description']}")
    logging.info(f"影响: {info['impact']}")
    return info['name']

def remove_url_suffix(url):
    """Reduce *url* to ``scheme://netloc``, dropping path, query and fragment."""
    components = urlparse(url)
    scheme, host = components.scheme, components.netloc
    return f"{scheme}://{host}"

def poc_scan(target_url):
    """Scan *target_url* with the matched POCs if the target is reachable.

    When the reachability check fails, an error is logged and nothing is
    scanned. Otherwise the URL is reduced to ``scheme://netloc`` and handed
    to validate_main. Always returns None.
    """
    if check_url_status(target_url):
        base_url = remove_url_suffix(target_url)
        validate_main(base_url)
    else:
        logging.error(Fore.RED +"目标无法访问,请检查目标地址是否正确!" + Fore.RESET)

# Script entry point: prompt for the target URL on stdin and scan it.
if __name__ == "__main__":
    target_url = input("请输入待测目标,如:https://example.com\n")
    poc_scan(target_url)