有没有 Go 的大佬能帮忙修改一下这个反向代理的代码？它会从 Redis 里取出 SSL 证书和反代的目标地址。
目前主要是全局变量的问题，需要改成并发安全的（刚入门 Go，还没研究明白）。
问 AI 也没能解决，它只是建议用上下文（context），或许是免费的 AI 不行。
非常感谢！
package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/redis/go-redis/v9"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)
// httpAddr is the listen address for plain-HTTP traffic.
var httpAddr = ":80"

// httpsAddr is the listen address for TLS traffic; main wraps this listener
// with tls.NewListener.
var httpsAddr = ":443"

// redisClient is the shared Redis client. It is assigned in init and used by
// getHostConf to read per-host configuration.
var redisClient *redis.Client
// proxyInfo carries what newProxy needs to build a reverse proxy: the upstream
// URL and the incoming request's path, raw query and extra headers.
//
// NOTE(review): nothing here is synchronised. A proxyInfo whose fields are
// written by several goroutines (e.g. one value shared by all requests) is a
// data race; give each request its own value.
type proxyInfo struct {
	// Upstream base URL, plus the path and raw query copied from the
	// incoming request.
	targetUrl, requestPath, requestRawQuery string

	// Headers that the proxy's Director Sets on the outgoing request.
	requestHeader map[string]string
}
// init creates the package-level Redis client that getHostConf reads host
// configuration from. It targets a local server (127.0.0.1:6379) using
// database 0 and an empty password.
func init() {
	opts := redis.Options{
		Addr:     "127.0.0.1:6379",
		Password: "",
		DB:       0,
	}
	redisClient = redis.NewClient(&opts)
}
// main opens the plain-HTTP (httpAddr) and TLS (httpsAddr) listeners and
// serves both with one http.Server. Its handler reverse-proxies each request
// according to that host's configuration in Redis. TLS certificates are
// chosen per handshake from the client's SNI name. main panics with the
// server's error as soon as either listener stops serving, instead of
// blocking forever.
func main() {
	// Plain-HTTP listener.
	tcpConn, err := net.Listen("tcp", httpAddr)
	if err != nil {
		panic(err)
	}
	defer tcpConn.Close()

	// Raw TCP listener for HTTPS; tls.NewListener below adds the TLS layer.
	tcpsConn, err := net.Listen("tcp", httpsAddr)
	if err != nil {
		panic(err)
	}
	defer tcpsConn.Close()

	// NOTE(review): pi is shared by every TLS handshake and every request.
	// getCertificate (through getHostConf) writes pi.targetUrl, so concurrent
	// handshakes race on that field. Per-request data must not live in pi.
	pi := &proxyInfo{}
	tlsConn := tls.NewListener(tcpsConn, &tls.Config{
		GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
			return pi.getCertificate(clientHello)
		},
	})

	httpServer := &http.Server{
		Handler: pi.proxyRequestHandler(),
		// Without a header deadline, a client can hold a connection open
		// indefinitely by sending headers slowly. Bodies are not time-limited
		// (no ReadTimeout/WriteTimeout), so long uploads and streamed
		// responses still work.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// http.Server.Serve always returns a non-nil error. Report the first one
	// instead of discarding it and blocking forever.
	errCh := make(chan error, 2)
	go func() {
		errCh <- fmt.Errorf("serving http on %s: %w", httpAddr, httpServer.Serve(tcpConn))
	}()
	go func() {
		errCh <- fmt.Errorf("serving https on %s: %w", httpsAddr, httpServer.Serve(tlsConn))
	}()
	panic(<-errCh)
}
// sharedProxyTransport is the upstream transport used by every proxy that
// newProxy builds. http.Transport pools idle keep-alive connections and is
// safe for concurrent use, so a single instance is shared. Creating a new
// Transport per request would defeat connection reuse and leave each
// request's idle upstream connections open until IdleConnTimeout expires.
var sharedProxyTransport = &http.Transport{
	Proxy: http.ProxyFromEnvironment,
	DialContext: (&net.Dialer{
		Timeout:   60 * time.Second,
		KeepAlive: 60 * time.Second,
	}).DialContext,
	ForceAttemptHTTP2:     true,
	MaxIdleConns:          100,
	IdleConnTimeout:       90 * time.Second,
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second,
	MaxIdleConnsPerHost:   20,
}

// newProxy builds a reverse proxy for the single request described by pi.
//   - The upstream scheme and host come from pi.targetUrl.
//   - The path and raw query come from pi.requestPath and pi.requestRawQuery.
//   - Every entry in pi.requestHeader is Set on the outgoing request.
//
// pi's fields are read while building the proxy (the header map is captured
// here, not re-read in the Director). pi should therefore be a per-request
// value, not one shared across goroutines.
//
// It returns an error only if pi.targetUrl cannot be parsed.
func (pi *proxyInfo) newProxy() (*httputil.ReverseProxy, error) {
	targetURL, err := url.Parse(pi.targetUrl)
	if err != nil {
		return nil, fmt.Errorf("parsing target url %q: %w", pi.targetUrl, err)
	}
	// Any path/query in the configured targetUrl is replaced by the incoming
	// request's.
	targetURL.Path = pi.requestPath
	targetURL.RawQuery = pi.requestRawQuery
	fmt.Println("反代的地址:", targetURL.String())

	proxy := httputil.NewSingleHostReverseProxy(targetURL)
	proxy.Transport = sharedProxyTransport

	headers := pi.requestHeader
	defaultDirector := proxy.Director
	proxy.Director = func(req *http.Request) {
		defaultDirector(req)
		// Send exactly targetURL, whose path/query already match the incoming
		// request. This undoes the target-path + request-path join that the
		// default director performs. Use a copy so the request never aliases
		// targetURL.
		outURL := *targetURL
		req.URL = &outURL
		req.Host = targetURL.Host
		for k, v := range headers {
			req.Header.Set(k, v)
		}
	}
	proxy.ModifyResponse = pi.modifyResponse()
	proxy.ErrorHandler = pi.errorHandler()
	return proxy, nil
}
// getCertificate is used as the tls.Config.GetCertificate hook. It takes the
// SNI server name from the ClientHello and loads that host's PEM certificate
// ("certPublic") and private key ("certPrivate") from Redis via getHostConf.
//
// Server names shorter than 4 bytes, including a missing SNI (""), are
// rejected before Redis is queried.
//
// NOTE(review): getHostConf also stores the host's targetUrl in pi.targetUrl,
// so calling this on a proxyInfo shared across connections is a data race.
func (pi *proxyInfo) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	serverName := clientHello.ServerName
	if len(serverName) < 4 {
		return nil, errors.New(serverName + ",域名长度不符合")
	}

	conf, err := pi.getHostConf(serverName)
	if err != nil {
		return nil, err
	}

	pair, err := tls.X509KeyPair([]byte(conf["certPublic"]), []byte(conf["certPrivate"]))
	if err != nil {
		return nil, err
	}
	return &pair, nil
}
// proxyRequestHandler returns the handler shared by the HTTP and HTTPS
// listeners. For every request it looks up the request's Host in Redis, builds
// a reverse proxy to that host's targetUrl and serves the request through it.
//
// All per-request data (target URL, path, query, headers) lives in a fresh
// proxyInfo created inside the handler. The receiver pi is neither read nor
// written here, so concurrent requests cannot overwrite each other's state.
//
// Responses it writes itself:
//   - 400 if the host lookup fails or the host has no targetUrl.
//   - 500 if the configured targetUrl cannot be parsed.
func (pi *proxyInfo) proxyRequestHandler() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		hostName := getHostName(r.Host)

		// Request-local state. getHostConf's side effect on targetUrl lands
		// here instead of on the shared pi.
		reqInfo := &proxyInfo{}

		// Resolve HTTPS requests as well, rather than trusting a targetUrl left
		// behind by a TLS handshake that may have been for another host.
		hostConf, err := reqInfo.getHostConf(hostName)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		targetURL := hostConf["targetUrl"]
		if targetURL == "" {
			http.Error(w, hostName+": no targetUrl configured", http.StatusBadRequest)
			return
		}

		reqInfo.targetUrl = targetURL
		reqInfo.requestPath = r.URL.Path
		reqInfo.requestRawQuery = r.URL.RawQuery
		reqInfo.requestHeader = map[string]string{
			"Referer":    targetURL,
			"User-Agent": r.Header.Get("User-Agent"),
			"Accept":     r.Header.Get("Accept"),
		}

		proxy, err := reqInfo.newProxy()
		if err != nil {
			// A malformed targetUrl in Redis is a server-side misconfiguration:
			// log it and answer 500 instead of panicking inside the handler.
			fmt.Printf("building proxy for %q: %v\n", hostName, err)
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}
		proxy.ServeHTTP(w, r)
	}
}
// modifyResponse returns the hook installed as ReverseProxy.ModifyResponse.
// It only prints the upstream response's Content-Type header and always
// returns nil, so no response is ever rejected.
func (pi *proxyInfo) modifyResponse() func(*http.Response) error {
	logContentType := func(resp *http.Response) error {
		fmt.Println(resp.Header.Get("Content-Type"))
		return nil
	}
	return logContentType
}
// errorHandler returns the hook installed as ReverseProxy.ErrorHandler.
// ReverseProxy calls it both for errors reaching the backend and for errors
// returned by ModifyResponse, so the log line names neither specifically. It
// logs the method, URL and error, then replies 500 "server error".
func (pi *proxyInfo) errorHandler() func(http.ResponseWriter, *http.Request, error) {
	return func(w http.ResponseWriter, req *http.Request, err error) {
		fmt.Printf("proxy error for %s %s: %v\n", req.Method, req.URL, err)
		w.WriteHeader(http.StatusInternalServerError)
		// The client may already be gone; a failed write has no recovery.
		_, _ = w.Write([]byte("server error"))
	}
}
// getHostConf loads the Redis hash stored under hostName and returns its
// fields. The callers read "targetUrl", "certPublic" and "certPrivate".
//
// It returns an error in two cases:
//   - the Redis call fails;
//   - the hash is empty. Redis HGETALL answers a missing key with an empty
//     reply rather than an error, so without this check unknown hosts would
//     slip through with no configuration.
//
// Side effect (kept for existing callers): on success it copies
// hostConf["targetUrl"] into pi.targetUrl. That write is unsynchronised, so
// concurrent calls on a shared pi race. Prefer reading "targetUrl" from the
// returned map.
func (pi *proxyInfo) getHostConf(hostName string) (map[string]string, error) {
	hostConf, err := redisClient.HGetAll(context.Background(), hostName).Result()
	if err != nil {
		return nil, fmt.Errorf("loading config for host %q: %w", hostName, err)
	}
	if len(hostConf) == 0 {
		return nil, fmt.Errorf("no config for host %q", hostName)
	}
	// Mock values for local testing (SSL certificate/key and proxy target):
	//hostConf["certPublic"] = "-----BEGIN CERTIFICATE-----\n"
	//hostConf["certPrivate"] = "-----BEGIN CERTIFICATE-----\n"
	//hostConf["targetUrl"] = "https://www.baidu.com"

	// Upstream URL this host is reverse-proxied to.
	pi.targetUrl = hostConf["targetUrl"]
	return hostConf, nil
}
// getHostName returns the host part of rawUrl with any port removed, e.g.
// "example.com:8080" -> "example.com" and "[::1]:443" -> "::1".
//
// rawUrl may be a bare host[:port] (as found in http.Request.Host) or a full
// http:// or https:// URL. It returns "" if rawUrl cannot be parsed.
func getHostName(rawUrl string) string {
	// url.Parse only treats the text after "scheme://" as the host, so add a
	// scheme when neither http:// nor https:// is already present.
	if !strings.HasPrefix(rawUrl, "http://") && !strings.HasPrefix(rawUrl, "https://") {
		rawUrl = "http://" + rawUrl
	}
	u, err := url.Parse(rawUrl)
	if err != nil {
		return ""
	}
	return u.Hostname()
}
|