File size: 4,507 Bytes
bd7bd10
 
 
 
 
 
2e4a225
bd7bd10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e4a225
bd7bd10
2e4a225
bd7bd10
2e4a225
 
 
 
 
 
 
 
 
bd7bd10
04860cd
2e4a225
bd7bd10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import requests
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
import uvicorn
import os
import logging # 添加 logging 导入
import json # 添加 json 导入

# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

app = FastAPI()

# 允许的目标域名列表(可选,增加安全性)
# ALLOWED_HOSTS = {"google.com", "example.com"}

@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def reverse_proxy(request: Request, path: str):
    """
    反向代理到目标 URL。
    从路径中提取目标 URL,例如 /google.com/search?q=test -> https://google.com/search?q=test
    """
    target_url_str = path

    # 简单的检查,确保路径看起来像一个域名或包含协议
    if not ('.' in target_url_str or target_url_str.startswith(('http://', 'https://'))):
        return Response(content="Invalid target URL format in path.", status_code=400)

    # 如果路径不包含协议,默认添加 https://
    if not target_url_str.startswith(('http://', 'https://')):
        target_url_str = f"https://{target_url_str}"

    # 提取域名进行检查(可选)
    # try:
    #     from urllib.parse import urlparse
    #     parsed_url = urlparse(target_url_str)
    #     if parsed_url.netloc not in ALLOWED_HOSTS:
    #         return Response(content=f"Host not allowed: {parsed_url.netloc}", status_code=403)
    # except Exception:
    #     return Response(content="Could not parse target URL.", status_code=400)

    # 准备目标请求的 headers,复制客户端 headers,但移除 Host
    headers = {key: value for key, value in request.headers.items() if key.lower() != 'host'}
    # 可以选择性地添加或修改 headers
    # headers['X-Forwarded-For'] = request.client.host

    # 获取请求体
    body = await request.body()

    # 如果是 POST 请求,记录请求体内容
    if request.method == "POST":
        logging.info(f"Received POST request to {path}. Body follows:")
        try:
            # 尝试解码为 UTF-8 文本
            body_text = body.decode('utf-8')
            try:
                # 尝试将文本解析为 JSON
                body_json = json.loads(body_text)
                # 如果成功,格式化 JSON 并记录
                formatted_json = json.dumps(body_json, indent=2, ensure_ascii=False) # indent=2 用于缩进,ensure_ascii=False 支持中文
                logging.info(formatted_json)
            except json.JSONDecodeError:
                # 如果不是有效的 JSON,按原样记录文本(已包含换行处理)
                logging.info(body_text)
        except UnicodeDecodeError:
            # 如果解码失败,记录原始字节信息
            logging.info(f"Body (bytes): {body}")

    try:
        # 发送请求到目标服务器,允许重定向,不验证 SSL 证书(在某些情况下可能需要)
        target_response = requests.request(
            method=request.method,
            url=target_url_str,
            headers=headers,
            data=body,
            stream=True,  # 使用流式传输处理大文件或长时间响应
            allow_redirects=True, # 允许目标服务器重定向
            verify=True # 通常应保持 True,除非你知道目标证书有问题
        )

        # 过滤掉不应转发的响应头 (例如 'transfer-encoding', 'content-encoding', 'content-length')
        response_headers = {
            k: v for k, v in target_response.headers.items()
            if k.lower() not in ['transfer-encoding', 'content-encoding', 'content-length']
        }

        # 使用 StreamingResponse 将目标响应流式传输回客户端
        return StreamingResponse(
            target_response.iter_content(chunk_size=8192),
            status_code=target_response.status_code,
            headers=response_headers,
            media_type=target_response.headers.get('content-type')
        )

    except requests.exceptions.RequestException as e:
        return Response(content=f"Error connecting to target server: {e}", status_code=502) # Bad Gateway
    except Exception as e:
        return Response(content=f"An unexpected error occurred: {e}", status_code=500)

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860)) # Hugging Face Spaces 通常使用 7860 端口
    uvicorn.run(app, host="0.0.0.0", port=port)