├── .gitignore ├── healthcheck.go ├── apiserve ├── comm │ ├── definition.go │ └── report.go ├── resultcode │ └── code.go ├── router │ └── router.go ├── controller │ └── status_req_count.go └── init.go ├── common.go ├── status └── status.go ├── limit_req.go ├── go.mod ├── util ├── path.go ├── log │ ├── logrushook │ │ └── lfshook.go │ └── log.go ├── fixedqueue │ ├── fixedqueue.go │ └── fixedqueue_test.go ├── tcache │ └── tcache.go ├── http.go └── string.go ├── limit_req_test.go ├── main.go ├── defense_cc.go ├── bin └── configs │ └── conf.txt ├── config └── config.go ├── README.md ├── ocsp.go ├── tools └── getocsp │ └── get_ocsp.go ├── go.sum ├── main_test.go ├── loadconfig.go ├── sessiontickets.go └── gateway.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | baks/ 4 | pkg/ 5 | bin/logs/ -------------------------------------------------------------------------------- /healthcheck.go: -------------------------------------------------------------------------------- 1 | //对负载均衡后端服务器进行健康检查 2 | 3 | package main 4 | -------------------------------------------------------------------------------- /apiserve/comm/definition.go: -------------------------------------------------------------------------------- 1 | package comm 2 | 3 | type H map[string]interface{} 4 | -------------------------------------------------------------------------------- /apiserve/resultcode/code.go: -------------------------------------------------------------------------------- 1 | package resultcode 2 | 3 | //返回码 4 | const ( 5 | SUCCESS = 0 //成功 6 | ERR_NORMAL = -1 //通用失败 7 | ERR_RELOGIN = -2 //失败需要重新登录 8 | ) 9 | -------------------------------------------------------------------------------- /apiserve/router/router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "github.com/gorilla/mux" 5 | "minGateway/apiserve/controller" 6 | ) 7 | 8 | func 
InitRouter(router *mux.Router) { 9 | if router == nil { 10 | panic("mux.Router is nil") 11 | } 12 | 13 | //获取状态 14 | router.HandleFunc("/api_status/reqcount", controller.StatusReqCount) 15 | } 16 | -------------------------------------------------------------------------------- /apiserve/comm/report.go: -------------------------------------------------------------------------------- 1 | package comm 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | ) 7 | 8 | func Report(w http.ResponseWriter, code int, msg string, body interface{}) error { 9 | w.Header().Set("Content-Type", "application/json") 10 | w.WriteHeader(http.StatusOK) 11 | err := json.NewEncoder(w).Encode(map[string]interface{}{"code": code, "msg": msg, "body": body}) 12 | return err 13 | } 14 | -------------------------------------------------------------------------------- /apiserve/controller/status_req_count.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "fmt" 5 | "minGateway/apiserve/comm" 6 | "minGateway/apiserve/resultcode" 7 | "minGateway/status" 8 | "net/http" 9 | ) 10 | 11 | func StatusReqCount(w http.ResponseWriter, r *http.Request) { 12 | err := comm.Report(w, resultcode.SUCCESS, "success", comm.H{"count": status.Instance().GetReqCount()}) 13 | if err != nil { 14 | fmt.Println(err) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /apiserve/init.go: -------------------------------------------------------------------------------- 1 | package apiserve 2 | 3 | import ( 4 | "fmt" 5 | "minGateway/apiserve/router" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/gorilla/mux" 10 | ) 11 | 12 | func Run() *http.Server { 13 | 14 | muxRouter := mux.NewRouter() 15 | router.InitRouter(muxRouter) 16 | 17 | srv := &http.Server{ 18 | Handler: muxRouter, 19 | Addr: ":9900", 20 | WriteTimeout: 15 * time.Second, 21 | ReadTimeout: 15 * time.Second, 22 | } 23 | 24 | go func() { 25 | 
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { 26 | panic(err.Error()) 27 | } 28 | }() 29 | 30 | fmt.Println("\nAPI监听端口:9900") 31 | return srv 32 | } 33 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | ) 7 | 8 | var ipForwardeds []string 9 | 10 | // 如果消息是通过前端代理服务器转发或者cdn转发,则需要从消息头中获取IP地址(注意确保IP的真实性), 11 | // 如果消息直接来自于用户客户端,则使用req.RemoteAddr获取 12 | func getIpAddr(req *http.Request) []string { 13 | if ipForwardeds == nil { 14 | return []string{strings.Split(req.RemoteAddr, ":")[0]} 15 | } else { 16 | for _, v := range ipForwardeds { 17 | if addr, ok := req.Header[v]; ok && len(addr) > 0 { 18 | return addr 19 | } 20 | } 21 | return []string{strings.Split(req.RemoteAddr, ":")[0]} 22 | } 23 | } 24 | 25 | func assert(err error) { 26 | if err != nil { 27 | panic(err) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /status/status.go: -------------------------------------------------------------------------------- 1 | package status 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | var status *Status 8 | 9 | type Status struct { 10 | // 连接数 11 | ReqCount int 12 | ReqCountMu *sync.RWMutex 13 | } 14 | 15 | func Init() error { 16 | status = &Status{ 17 | ReqCount: 0, 18 | ReqCountMu: new(sync.RWMutex), 19 | } 20 | return nil 21 | } 22 | 23 | func Instance() *Status { 24 | return status 25 | } 26 | 27 | func (sta *Status) AddReqCount() { 28 | sta.ReqCountMu.Lock() 29 | sta.ReqCount++ 30 | sta.ReqCountMu.Unlock() 31 | } 32 | 33 | func (sta *Status) SubReqCount() { 34 | sta.ReqCountMu.Lock() 35 | sta.ReqCount-- 36 | sta.ReqCountMu.Unlock() 37 | } 38 | 39 | func (sta *Status) GetReqCount() int { 40 | sta.ReqCountMu.RLock() 41 | defer sta.ReqCountMu.RUnlock() 42 | return sta.ReqCount 43 | } 44 | 
-------------------------------------------------------------------------------- /limit_req.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "minGateway/config" 5 | "minGateway/util" 6 | "minGateway/util/tcache" 7 | "net/http" 8 | ) 9 | 10 | var limitReqMap *tcache.TCache 11 | 12 | func getReqCount(key string) int { 13 | hashkey := util.GetMD5(key) 14 | obj, ok := limitReqMap.Get(hashkey) 15 | if ok { 16 | count := obj.(int) 17 | limitReqMap.Set(hashkey, count+1) 18 | return count 19 | } else { 20 | limitReqMap.Set(hashkey, 1) 21 | } 22 | return 0 23 | } 24 | 25 | func ExceededLimitReq(ip string, r *http.Request) bool { 26 | var key string 27 | if config.Get().LimitReq.Mode == 0 { 28 | key = ip + r.Host + r.URL.Path 29 | } else { 30 | key = ip + r.RequestURI 31 | } 32 | count := getReqCount(key) 33 | if count >= config.Get().LimitReq.Count { 34 | return true 35 | } else { 36 | return false 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module minGateway 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/BurntSushi/toml v0.3.1 7 | github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect 8 | github.com/g4zhuj/hashring v0.0.0-20180426073119-5d542568fdbd 9 | github.com/gorilla/mux v1.7.4 10 | github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 // indirect 11 | github.com/jonboulle/clockwork v0.2.0 // indirect 12 | github.com/lestrrat-go/file-rotatelogs v2.3.0+incompatible 13 | github.com/lestrrat-go/strftime v1.0.3 // indirect 14 | github.com/pkg/errors v0.9.1 15 | github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 16 | github.com/sirupsen/logrus v1.6.0 17 | github.com/tebeka/strftime v0.1.5 // indirect 18 | golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 19 | golang.org/x/net 
v0.0.0-20200707034311-ab3426394381 20 | golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /util/path.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "path/filepath" 8 | "runtime" 9 | "strings" 10 | ) 11 | 12 | var ostype = runtime.GOOS 13 | 14 | func NormalPath(filepath string) string { 15 | if ostype == "windows" { 16 | filepath = strings.Replace(filepath, "/", "\\", -1) 17 | } 18 | return filepath 19 | } 20 | 21 | func NormalPathF(path string, args ...interface{}) string { 22 | str := fmt.Sprintf(path, args...) 23 | return NormalPath(str) 24 | } 25 | 26 | func GetApplicationDir() string { 27 | tmpDir, err := os.Getwd() 28 | if err != nil { 29 | file, _ := exec.LookPath(os.Args[0]) 30 | tfile, _ := filepath.Abs(file) 31 | tmpDir, _ = filepath.Split(tfile) 32 | } 33 | return tmpDir 34 | } 35 | 36 | func GetAbsolutePath(path string, args ...interface{}) string { 37 | appDir := GetApplicationDir() 38 | str := fmt.Sprintf(path, args...) 
39 | str = appDir + str 40 | return NormalPath(str) 41 | } 42 | 43 | func PathExists(path string) bool { 44 | _, err := os.Stat(path) 45 | if err == nil { 46 | return true 47 | } 48 | if os.IsNotExist(err) { 49 | return false 50 | } 51 | return false 52 | } 53 | -------------------------------------------------------------------------------- /util/log/logrushook/lfshook.go: -------------------------------------------------------------------------------- 1 | package logrushook 2 | 3 | import ( 4 | "github.com/lestrrat-go/file-rotatelogs" 5 | "github.com/pkg/errors" 6 | "github.com/rifflock/lfshook" 7 | log "github.com/sirupsen/logrus" 8 | "time" 9 | ) 10 | 11 | // WithMaxAge和WithRotationCount二者只能设置一个, 12 | // WithMaxAge设置文件清理前的最长保存时间, 13 | // WithRotationCount设置文件清理前最多保存的个数。 14 | // rotatelogs.WithMaxAge(time.Hour*24), 15 | func NewLfsHook(logPath string, maxAge time.Duration, rotationTime time.Duration) log.Hook { 16 | writer, err := rotatelogs.New( 17 | logPath+".%Y-%m-%d-%H-%M", 18 | rotatelogs.WithLinkName(logPath), // 生成软链,指向最新日志文件 19 | rotatelogs.WithMaxAge(maxAge), // 文件最大保存时间 20 | rotatelogs.WithRotationTime(rotationTime), // 日志切割时间间隔 21 | ) 22 | if err != nil { 23 | log.Errorf("config local file system logger error. 
%+v", errors.WithStack(err)) 24 | } 25 | lfsHook := lfshook.NewHook(lfshook.WriterMap{ 26 | log.DebugLevel: writer, 27 | log.InfoLevel: writer, 28 | log.WarnLevel: writer, 29 | log.ErrorLevel: writer, 30 | log.FatalLevel: writer, 31 | log.PanicLevel: writer, 32 | }, &log.TextFormatter{DisableColors: true}) 33 | 34 | return lfsHook 35 | } 36 | -------------------------------------------------------------------------------- /limit_req_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "minGateway/util" 7 | "sync" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestExceededLimitReq(t *testing.T) { 13 | quit := make(chan bool) 14 | 15 | succNum := 0 16 | failNum := 0 17 | maxConn := make(chan bool, 5000) 18 | 19 | mu := &sync.Mutex{} 20 | 21 | go func() { 22 | time.Sleep(11 * time.Second) 23 | quit <- true 24 | return 25 | }() 26 | 27 | ticker := time.NewTicker(2 * time.Second) 28 | 29 | for { 30 | select { 31 | case <-quit: 32 | fmt.Println("end") 33 | goto END 34 | case <-ticker.C: 35 | for i := 0; i < 20; i++ { 36 | go reqPing(&succNum, &failNum, mu, maxConn) 37 | time.Sleep(time.Millisecond * 10) 38 | } 39 | } 40 | } 41 | 42 | END: 43 | fmt.Println("succ:", succNum, " fail:", failNum) 44 | 45 | } 46 | 47 | func reqPing(succ, fail *int, mu *sync.Mutex, conn chan bool) { 48 | conn <- true 49 | req, err := util.HttpGet("http://127.0.0.1/api/ping?cid=100") 50 | <-conn 51 | if err != nil { 52 | fmt.Println("error:", err) 53 | } else { 54 | defer req.Body.Close() 55 | 56 | req, errC := ioutil.ReadAll(util.LimitReader(req.Body, 1024*1024)) 57 | if errC != nil { 58 | fmt.Println("error:", err) 59 | } 60 | if string(req) == "pong" { 61 | mu.Lock() 62 | *succ++ 63 | fmt.Println("pong") 64 | mu.Unlock() 65 | return 66 | } 67 | } 68 | mu.Lock() 69 | *fail++ 70 | fmt.Println("fail") 71 | mu.Unlock() 72 | return 73 | } 74 | 
-------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "fmt" 7 | "minGateway/apiserve" 8 | "minGateway/config" 9 | "minGateway/status" 10 | "minGateway/util" 11 | "minGateway/util/log" 12 | "os" 13 | "os/signal" 14 | "syscall" 15 | "time" 16 | ) 17 | 18 | func main() { 19 | 20 | //加载配置 21 | err := loadConfig() 22 | if err != nil { 23 | panic(err) 24 | } 25 | 26 | //log设置 27 | logSetting() 28 | 29 | //运行服务 30 | srv := new(GateServer) 31 | ss := srv.run() 32 | 33 | //运行API服务 34 | api := apiserve.Run() 35 | ss = append(ss, api) 36 | 37 | //wait exit 38 | quit := make(chan os.Signal) 39 | signal.Notify(quit, syscall.SIGKILL, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) 40 | sig := <-quit 41 | fmt.Println("Start shutdown Server ...") 42 | 43 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 44 | defer cancel() 45 | for _, s := range ss { 46 | if err := s.Shutdown(ctx); err != nil { 47 | fmt.Printf("Server Shutdown error, signal:%v error:%s\n", sig, err) 48 | } 49 | } 50 | 51 | fmt.Printf("Safe exit server, signal:%v\n", sig) 52 | } 53 | 54 | func init() { 55 | HostList = make(map[string]HostInfo) 56 | HostListWc = make(map[string]HostInfoWc) 57 | CertificateSet = make([]tls.Certificate, 0) 58 | 59 | // 状态信息初始化 60 | _ = status.Init() 61 | } 62 | 63 | func logSetting() { 64 | logConf := config.Get().LogConf 65 | 66 | // 设置日志输出级别 67 | // DebugLevel,InfoLevel,WarnLevel,ErrorLevel,FatalLevel,PanicLevel 68 | log.Init(logConf.LogLevel) 69 | 70 | // 是否写日志文件 71 | if logConf.WriteFile { 72 | 73 | logFile := "logging.log" 74 | if logConf.FileDir != "" { 75 | //FileDir不为空使用绝对路径 76 | logFile = logConf.FileDir + logFile 77 | } else { 78 | dir := util.GetAbsolutePath("/bin/logs") 79 | if !util.PathExists(dir) { 80 | _ = os.Mkdir(dir, 0666) 81 | } 82 | 
logFile = util.GetAbsolutePath("/bin/logs/%s", logFile) 83 | } 84 | 85 | log.WirteLog(logFile) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /defense_cc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "minGateway/util/fixedqueue" 5 | "minGateway/util/log" 6 | "minGateway/util/tcache" 7 | "time" 8 | ) 9 | 10 | //检查时间,毫秒 11 | var defenseCCCheckTime int 12 | 13 | //检查时间内的次数 14 | var defenseCCCheckCount int 15 | 16 | //加入黑名单时间,秒 17 | var defenseCCBlackTime int 18 | 19 | //IP记录 20 | var defenseCCIPsRecord *tcache.TCache 21 | 22 | //IP黑名单 23 | var defenseCCIPsBlacklist *tcache.TCache 24 | 25 | //配置中的黑白名单 26 | var defenseCCIPWhiteListConf map[string]struct{} 27 | var defenseCCIPBlackListConf map[string]struct{} 28 | 29 | //判断是否存在于黑名单中 30 | func existInCCBlacklist(ip string) bool { 31 | _, exist := defenseCCIPsBlacklist.Get(ip) 32 | return exist 33 | } 34 | 35 | //检查防御,如果返回true则加入黑名单 36 | func defenseCCBlockCheck(ip string) bool { 37 | 38 | //配置中的白名单 39 | if _, exist := defenseCCIPWhiteListConf[ip]; exist { 40 | return false 41 | } 42 | //配置中的黑名单 43 | if _, exist := defenseCCIPBlackListConf[ip]; exist { 44 | return true 45 | } 46 | 47 | if existInCCBlacklist(ip) { 48 | return true 49 | } 50 | 51 | isBlock := false 52 | lt, ok := defenseCCIPsRecord.Get(ip) 53 | if !ok { 54 | record := fixedqueue.NewFixedQueue(defenseCCCheckCount) 55 | record.Push(makeTimestamp()) 56 | defenseCCIPsRecord.Set(ip, record) 57 | return false 58 | } else { 59 | record := lt.(*fixedqueue.FixedQueue) 60 | if record.Len() == defenseCCCheckCount { //如果记录的长度已达到定义数 61 | //取出最老的那个时间 62 | oldest, _ := record.Get() 63 | now := makeTimestamp() 64 | if now-oldest.(int64) < int64(defenseCCCheckTime) { 65 | //当前时间减去最老的时间小于定义的时间限定 66 | //加入黑名单 67 | log.Warn(ip, "被加入黑名单", float64(defenseCCBlackTime)/60, "分钟") 68 | defenseCCIPsBlacklist.Set(ip, true) 69 | isBlock = true 70 | } else { 
71 | record.Push(now) 72 | } 73 | } else { 74 | record.Push(makeTimestamp()) 75 | } 76 | } 77 | return isBlock 78 | } 79 | 80 | func makeTimestamp() int64 { 81 | return time.Now().UnixNano() / int64(time.Millisecond) 82 | } 83 | -------------------------------------------------------------------------------- /util/fixedqueue/fixedqueue.go: -------------------------------------------------------------------------------- 1 | // 一个线程安全的固定长度队列,压入超过队列长度的数据时,替换最老的数据 2 | // 使用环形队列的方式存储数据,减少内存的创建开销和碎片产生 3 | // 经对比测试,比直接使用slice增删方式实现的队列性能好一半左右(测试的一组数据是:3560 ns/op,对比数据为:8876 ns/op) 4 | // Email:cgrencn@gmail.com 5 | 6 | package fixedqueue 7 | 8 | import ( 9 | "fmt" 10 | "sync" 11 | ) 12 | 13 | type FixedQueue struct { 14 | mu *sync.Mutex 15 | list []interface{} 16 | size int //队列大小 17 | p1 int //队列头位置 18 | p2 int //队列尾位置 19 | empty bool //是否列表为空 20 | } 21 | 22 | func NewFixedQueue(size int) *FixedQueue { 23 | q := new(FixedQueue) 24 | q.mu = new(sync.Mutex) 25 | q.list = make([]interface{}, size) 26 | q.size = size 27 | q.empty = true 28 | //q.p1 = 0 29 | //q.p2 = 0 30 | return q 31 | } 32 | 33 | //压入队列尾一个数据 34 | func (q *FixedQueue) Push(data interface{}) { 35 | q.mu.Lock() 36 | 37 | if q.empty { 38 | q.empty = false 39 | q.list[q.p2] = data 40 | q.mu.Unlock() 41 | return 42 | } 43 | 44 | if q.p2+1 == q.size { 45 | q.p2 = 0 46 | } else { 47 | q.p2++ 48 | } 49 | 50 | q.list[q.p2] = data 51 | 52 | if q.p1 == q.p2 { 53 | //追上头,吃掉最老的数据 54 | q.p1++ 55 | if q.p1 == q.size { 56 | q.p1 = 0 57 | } 58 | } 59 | 60 | q.mu.Unlock() 61 | return 62 | } 63 | 64 | //从队列头取出一个数据,同时从队列中删除它 65 | func (q *FixedQueue) Pop() (interface{}, bool) { 66 | q.mu.Lock() 67 | 68 | if q.empty { 69 | q.mu.Unlock() 70 | return nil, false 71 | } 72 | 73 | data := q.list[q.p1] 74 | 75 | if q.p1 == q.p2 { 76 | //取出最后一个数据 77 | q.empty = true 78 | q.mu.Unlock() 79 | return data, true 80 | } 81 | 82 | if q.p1+1 == q.size { 83 | q.p1 = 0 84 | } else { 85 | q.p1++ 86 | } 87 | 88 | q.mu.Unlock() 89 | return data, 
true 90 | } 91 | 92 | //获取队列头的数据,但不会从队列中删除 93 | func (q *FixedQueue) Get() (interface{}, bool) { 94 | q.mu.Lock() 95 | 96 | if q.empty { 97 | q.mu.Unlock() 98 | return nil, false 99 | } 100 | 101 | data := q.list[q.p1] 102 | 103 | q.mu.Unlock() 104 | return data, true 105 | } 106 | 107 | //获取队列的长度 108 | func (q *FixedQueue) Len() int { 109 | q.mu.Lock() 110 | if q.empty { 111 | q.mu.Unlock() 112 | return 0 113 | } 114 | l := q.p2 - q.p1 + 1 115 | if l <= 0 { 116 | l = q.size + l 117 | } 118 | q.mu.Unlock() 119 | return l 120 | } 121 | 122 | //清空队列,使队列长度为零 123 | func (q *FixedQueue) Clear() { 124 | q.mu.Lock() 125 | q.p1 = 0 126 | q.p2 = 0 127 | q.empty = true 128 | q.mu.Unlock() 129 | } 130 | 131 | func (q *FixedQueue) Print() { 132 | q.mu.Lock() 133 | fmt.Println(q.list) 134 | q.mu.Unlock() 135 | } 136 | -------------------------------------------------------------------------------- /util/tcache/tcache.go: -------------------------------------------------------------------------------- 1 | // 一个线程安全的可以设置过期时间的键值对存储 2 | // 基于性能考虑,采用被动式清除过期数据方式,即除非执行其中任一方法(Set,Get,Len,Update),否则不会去移除过期数据 3 | // 一般来说这些过期数据存在不会有什么影响,但如果这些数据带来问题,可在外部增加一个定时器,定时执行Update方法 4 | // Email:cgrencn@gmail.com 5 | 6 | package tcache 7 | 8 | import ( 9 | "container/list" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | type timeKey struct { 15 | inTime time.Time 16 | key interface{} 17 | } 18 | 19 | type TCache struct { 20 | keyQueue *list.List 21 | repeats map[interface{}]int 22 | datas map[interface{}]interface{} 23 | mu *sync.Mutex 24 | destroyTime time.Duration 25 | } 26 | 27 | /** 28 | destroyTime:设置过期时间 29 | */ 30 | func NewTimeCache(destroyTime time.Duration) *TCache { 31 | tc := new(TCache) 32 | tc.datas = make(map[interface{}]interface{}) //数据存储 33 | tc.repeats = make(map[interface{}]int) //key重复出现的次数 34 | tc.keyQueue = list.New() //key存储队列 35 | tc.mu = new(sync.Mutex) 36 | tc.destroyTime = destroyTime 37 | return tc 38 | } 39 | 40 | func (tc *TCache) Set(key, data interface{}) { 41 | 
tc.mu.Lock() 42 | defer tc.mu.Unlock() 43 | //先移除过期的数据 44 | tc.removeOverdue() 45 | //设置数据 46 | _, exist := tc.datas[key] 47 | tc.datas[key] = data 48 | tkey := timeKey{ 49 | inTime: time.Now(), 50 | key: key, 51 | } 52 | tc.keyQueue.PushBack(tkey) 53 | //是否原来已经存在 54 | if exist { 55 | count, ok := tc.repeats[key] //重复了几次 56 | if ok { 57 | tc.repeats[key] = count + 1 58 | } else { 59 | tc.repeats[key] = 1 60 | } 61 | } 62 | } 63 | 64 | func (tc *TCache) Get(key interface{}) (interface{}, bool) { 65 | tc.mu.Lock() 66 | defer tc.mu.Unlock() 67 | //先移除过期的数据 68 | tc.removeOverdue() 69 | //获取数据 70 | d, ok := tc.datas[key] 71 | return d, ok 72 | } 73 | 74 | func (tc *TCache) Len() int { 75 | tc.mu.Lock() 76 | defer tc.mu.Unlock() 77 | //先移除过期的数据 78 | tc.removeOverdue() 79 | //len 80 | return len(tc.datas) 81 | } 82 | 83 | func (tc *TCache) Update() { 84 | tc.mu.Lock() 85 | //移除过期的数据 86 | tc.removeOverdue() 87 | tc.mu.Unlock() 88 | } 89 | 90 | //no thread safe 91 | func (tc *TCache) removeOverdue() { 92 | if tc.keyQueue.Len() == 0 { 93 | return 94 | } 95 | for { 96 | e := tc.keyQueue.Front() 97 | if e == nil { 98 | return 99 | } 100 | tkey := (e.Value).(timeKey) 101 | if time.Now().Sub(tkey.inTime) > tc.destroyTime { 102 | tc.keyQueue.Remove(e) 103 | if count, ok := tc.repeats[tkey.key]; !ok { 104 | //只有repeats里没有此key,才去删除数据 105 | delete(tc.datas, tkey.key) 106 | } else { 107 | if count <= 1 { 108 | delete(tc.repeats, tkey.key) 109 | } else { 110 | tc.repeats[tkey.key] = count - 1 111 | } 112 | } 113 | continue 114 | } else { 115 | break 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /util/log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "github.com/sirupsen/logrus" 5 | "minGateway/util/log/logrushook" 6 | "time" 7 | ) 8 | 9 | var L *logrus.Logger 10 | 11 | func Init(level string) { 12 | L = logrus.New() 13 | 
L.SetLevel(getLogLevel(level)) 14 | } 15 | 16 | func WirteLog(logPath string) { 17 | //分割日志,每小时一个文件,最长保存15天 18 | lfsHook := logrushook.NewLfsHook(logPath, time.Hour*24*15, time.Hour) 19 | L.AddHook(lfsHook) 20 | } 21 | 22 | //设置最低loglevel: debug,info,warn,error,fatal,panic 23 | func getLogLevel(levelName string) logrus.Level { 24 | switch levelName { 25 | case "debug": 26 | return logrus.DebugLevel 27 | case "DebugLevel": 28 | return logrus.DebugLevel 29 | case "info": 30 | return logrus.InfoLevel 31 | case "InfoLevel": 32 | return logrus.InfoLevel 33 | case "warn": 34 | return logrus.WarnLevel 35 | case "WarnLevel": 36 | return logrus.WarnLevel 37 | case "error": 38 | return logrus.ErrorLevel 39 | case "ErrorLevel": 40 | return logrus.ErrorLevel 41 | case "fatal": 42 | return logrus.FatalLevel 43 | case "FatalLevel": 44 | return logrus.FatalLevel 45 | case "panic": 46 | return logrus.PanicLevel 47 | case "PanicLevel": 48 | return logrus.PanicLevel 49 | default: 50 | return logrus.WarnLevel 51 | } 52 | } 53 | 54 | func Tracef(format string, args ...interface{}) { 55 | L.Tracef(format, args...) 56 | } 57 | 58 | func Debugf(format string, args ...interface{}) { 59 | L.Debugf(format, args...) 60 | } 61 | 62 | func Infof(format string, args ...interface{}) { 63 | L.Infof(format, args...) 64 | } 65 | 66 | func Printf(format string, args ...interface{}) { 67 | L.Printf(format, args...) 68 | } 69 | 70 | func Warnf(format string, args ...interface{}) { 71 | L.Warnf(format, args...) 72 | } 73 | 74 | func Warningf(format string, args ...interface{}) { 75 | L.Warningf(format, args...) 76 | } 77 | 78 | func Errorf(format string, args ...interface{}) { 79 | L.Errorf(format, args...) 80 | } 81 | 82 | func Fatalf(format string, args ...interface{}) { 83 | L.Fatalf(format, args...) 84 | } 85 | 86 | func Panicf(format string, args ...interface{}) { 87 | L.Panicf(format, args...) 88 | } 89 | 90 | func Trace(args ...interface{}) { 91 | L.Trace(args...) 
92 | } 93 | 94 | func Debug(args ...interface{}) { 95 | L.Debug(args...) 96 | } 97 | 98 | func Info(args ...interface{}) { 99 | L.Info(args...) 100 | } 101 | 102 | func Print(args ...interface{}) { 103 | L.Print(args...) 104 | } 105 | 106 | func Warn(args ...interface{}) { 107 | L.Warn(args...) 108 | } 109 | 110 | func Warning(args ...interface{}) { 111 | L.Warning(args...) 112 | } 113 | 114 | func Error(args ...interface{}) { 115 | L.Error(args...) 116 | } 117 | 118 | func Fatal(args ...interface{}) { 119 | L.Fatal(args...) 120 | } 121 | 122 | func Panic(args ...interface{}) { 123 | L.Panic(args...) 124 | } 125 | -------------------------------------------------------------------------------- /bin/configs/conf.txt: -------------------------------------------------------------------------------- 1 | #服务器设置 2 | [core] 3 | #最大连接数,为0则不限流 4 | limitMaxConn = 30000 5 | #读超时,读完消息head和body的全部时间限制,为0则没有超时,单位秒 6 | readTimeout = 5 7 | #写超时,从读完消息开始到消息返回的用时限制,为0则没有超时,单位秒 8 | writeTimeout = 30 9 | #闲置超时,IdleTimeout是启用keep-alives状态后(默认启用)等待下一个请求的最长时间。 10 | #如果IdleTimeout为零,则使用ReadTimeout的值。 如果两者均为零,则没有超时。单位秒 11 | idleTimeout = 60 12 | #最大头字节,为0则使用默认值1024k, 这里设置为131072=128k 13 | maxHeaderBytes = 131072 14 | #设置读取消息头中IP转发字段名称,按照数组里的顺序查找,如果为空则获取的是TCP连接的IP地址 15 | #如果消息是通过前端代理服务器转发或者cdn转发,则需要从消息头中获取IP地址(注意确保IP的真实性) 16 | ipForwardeds = ["Ali-Cdn-Real-Ip","X-Forwarded-For","X-Real-Ip","X-Real-IP"] 17 | 18 | #CC防御(Challenge Collapsar) 19 | #如设置为3000毫秒钟内访问100次,则把此IP加入黑名单3600秒 20 | [ccDefense] 21 | #是否开启 22 | enable = false 23 | #此时间的访问内做检查(单位:毫秒) 24 | timeDuration = 3000 25 | #允许访问的次数 26 | count = 100 27 | #放入黑名单时间(单位:秒) 28 | blackTime = 3600 29 | #IP白名单 30 | whiteList = ["56.127.44.121","56.127.44.122"] 31 | #IP黑名单 32 | IPBlackList = [] 33 | 34 | #限制请求配置 35 | #限制在一个时间范围内同IP访问同一URL请求次数,超过的会被抛弃 36 | [limitReq] 37 | #是否开启 38 | enable = false 39 | #此时间的访问内做检查(单位:毫秒) 40 | timeDuration = 1000 41 | #允许访问的次数 42 | count = 1 43 | #0:不包含参数 1:包含参数(为0时限制范围更广) 44 | mode = 0 45 | 46 | #log设置 47 | [log] 
48 | #设置最低loglevel: debug,info,warn,error,fatal,panic 49 | logLevel = "debug" 50 | #开启写文件 51 | writeFile = true 52 | #存储日志文件的目录,不为空时是绝对路径,为空则是相对路径在程序的bin/logs/中 53 | fileDir = "" 54 | 55 | #代理设置 56 | [proxyInfo] 57 | 58 | #API服务器 59 | [proxyInfo.shop] 60 | #结尾有一个点的是泛解析,所有alogin开头的都走这里 61 | host = "alogin." 62 | #这里有2台服务器,用逗号间隔开 63 | target = ["http://172.226.10.17:8080","http://172.226.10.19:8080"] 64 | #服务器选择模式,1:随机方式 2:轮询方式 3:一致性哈希方式;如果未设置则使用随机方式 65 | obtainMode = 3 66 | 67 | #API服务器 68 | [proxyInfo.alogin] 69 | host = "alogin.obtc.com" 70 | target = ["http://172.226.10.17:8080","http://172.226.10.19:8080"] 71 | obtainMode = 3 72 | 73 | #管理后台 74 | [proxyInfo.admin] 75 | host = "admin.obtc.com" 76 | target = ["http://172.226.10.19:9150"] 77 | 78 | #微信公众号 79 | [proxyInfo.wxpay] 80 | host = "wxpay." 81 | target = ["http://172.226.10.17:9101"] 82 | 83 | #开放平台,苹果拉微信 84 | [proxyInfo.wxpaypay] 85 | host = "wxpaypay.obtc.com" 86 | target = ["http://172.226.10.17:9103"] 87 | 88 | #简单网站和隐私协议 89 | [proxyInfo.service] 90 | host = "service.obtc.com" 91 | target = ["http://172.226.10.17:9102"] 92 | 93 | #都没找到走这里 94 | [proxyInfo.default] 95 | host = "default" 96 | target = ["http://172.226.10.17:8080"] 97 | 98 | 99 | [sslBase] 100 | sessionTicket = true 101 | 102 | #SSL证书文件 103 | [sslCert] 104 | 105 | [sslCert.wxpaypay_obtc_com] 106 | ssl_certificate = "wxpaypay.obtc.com.crt" 107 | ssl_certificate_key = "wxpaypay.obtc.com.key" 108 | ocsp_stapling = true 109 | ocsp_stapling_local = true 110 | ocsp_stapling_file = "wxpaypay.obtc.com.ocsp" 111 | 112 | [sslCert.wxpaypay_lidu_com] 113 | ssl_certificate = "wxpaypay.lidu.com.crt" 114 | ssl_certificate_key = "wxpaypay.lidu.com.key" 115 | -------------------------------------------------------------------------------- /util/fixedqueue/fixedqueue_test.go: -------------------------------------------------------------------------------- 1 | package fixedqueue 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func 
TestFixedQueue(t *testing.T) { 10 | q := NewFixedQueue(8) 11 | q.Push(0) 12 | q.Push(1) 13 | q.Push(2) 14 | q.Push(3) 15 | q.Push(4) 16 | q.Print() 17 | fmt.Println(q.Len()) 18 | q.Push(5) 19 | q.Push(6) 20 | q.Push(7) 21 | q.Push(8) 22 | q.Push(9) 23 | 24 | q.Print() 25 | fmt.Println(q.Len()) 26 | 27 | fmt.Println(q.Pop()) 28 | fmt.Println(q.Pop()) 29 | fmt.Println(q.Pop()) 30 | fmt.Println(q.Pop()) 31 | 32 | fmt.Println(q.Len()) 33 | 34 | //q.Clear() 35 | //fmt.Println(q.Len()) 36 | 37 | fmt.Println(q.Pop()) 38 | fmt.Println(q.Pop()) 39 | fmt.Println(q.Pop()) 40 | fmt.Println(q.Pop()) 41 | fmt.Println(q.Pop()) 42 | 43 | fmt.Println(q.Len()) 44 | q.Print() 45 | } 46 | 47 | func TestNewFixedQueue(t *testing.T) { 48 | type args struct { 49 | size int 50 | } 51 | tests := []struct { 52 | name string 53 | args args 54 | want *FixedQueue 55 | }{ 56 | // TODO: Add test cases. 57 | { 58 | name: "testFixedQueue1", 59 | args: args{size: 10}, 60 | want: NewFixedQueue(10), 61 | }, 62 | { 63 | name: "testFixedQueue2", 64 | args: args{size: 50}, 65 | want: NewFixedQueue(50), 66 | }, 67 | { 68 | name: "testFixedQueue3", 69 | args: args{size: 1000}, 70 | want: NewFixedQueue(1000), 71 | }, 72 | } 73 | for _, tt := range tests { 74 | t.Run(tt.name, func(t *testing.T) { 75 | got := NewFixedQueue(tt.args.size) 76 | if !reflect.DeepEqual(got, tt.want) { 77 | t.Errorf("NewFixedQueue() = %v, want %v", got, tt.want) 78 | } 79 | for j := 0; j < 100000; j++ { 80 | for i := 0; i < 100; i++ { 81 | got.Push(i) 82 | } 83 | for i := 0; i < 50; i++ { 84 | got.Pop() 85 | } 86 | } 87 | t.Log(got) 88 | }) 89 | } 90 | } 91 | 92 | func BenchmarkNewFixedQueue(b *testing.B) { 93 | type args struct { 94 | size int 95 | } 96 | tests := []struct { 97 | name string 98 | args args 99 | want *FixedQueue 100 | }{ 101 | // TODO: Add test cases. 
102 | { 103 | name: "testFixedQueue1", 104 | args: args{size: 10}, 105 | want: NewFixedQueue(10), 106 | }, 107 | { 108 | name: "testFixedQueue2", 109 | args: args{size: 50}, 110 | want: NewFixedQueue(50), 111 | }, 112 | { 113 | name: "testFixedQueue3", 114 | args: args{size: 1000}, 115 | want: NewFixedQueue(1000), 116 | }, 117 | } 118 | for _, tt := range tests { 119 | b.Run(tt.name, func(b *testing.B) { 120 | got := NewFixedQueue(tt.args.size) 121 | if !reflect.DeepEqual(got, tt.want) { 122 | b.Errorf("NewFixedQueue() = %v, want %v", got, tt.want) 123 | } 124 | for j := 0; j < b.N; j++ { 125 | for i := 0; i < 100; i++ { 126 | got.Push(i) 127 | } 128 | for i := 0; i < 50; i++ { 129 | got.Pop() 130 | } 131 | } 132 | b.Log(got) 133 | }) 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "github.com/BurntSushi/toml" 6 | "minGateway/util" 7 | "sync" 8 | ) 9 | 10 | type ProxyConfig struct { 11 | CoreConf Core `toml:"core"` 12 | LimitReq LimitReq `toml:"limitReq"` 13 | CcDefense CcDefense `toml:"ccDefense"` 14 | LogConf Log `toml:"log"` 15 | HttpProxy map[string]ProxyInfo `toml:"proxyInfo"` 16 | SslBase SslBase `toml:"sslBase"` 17 | SslCert map[string]SslCert `toml:"sslCert"` 18 | } 19 | 20 | /* 21 | #最大连接数,为零,则不限流 22 | limitMaxConn = 30000 23 | #读超时,读完消息head和body的全部时间限制,为零,则没有超时,单位秒 24 | readTimeout = 5 25 | #写超时,从读完消息开始到消息返回的用时限制,为零,则没有超时,单位秒 26 | writeTimeout = 10 27 | #闲置超时,IdleTimeout是启用keep-alives状态后(默认启用)等待下一个请求的最长时间。 28 | #如果IdleTimeout为零,则使用ReadTimeout的值。 如果两者均为零,则没有超时。单位秒 29 | idleTimeout = 120 30 | #最大头字节,为0则使用默认 31 | maxHeaderBytes = 0 32 | #主动防御,如"5,30,1800"设置为3秒钟内访问10次,则把此IP加入黑名单1800秒。未定义或为空则不开启 33 | activeDefense = "5,20,1800" 34 | #设置读取消息头中IP转发字段名称,按照数组的顺序查找,如果为空则获取的是TCP连接的IP地址 35 | #如果消息是通过前端代理服务器转发或者cdn转发,则需要从消息头中获取IP地址(注意确保IP的真实性) 36 | ipForwardeds = 
["Ali-Cdn-Real-Ip","X-Forwarded-For","X-Real-Ip","X-Real-IP"] 37 | */ 38 | type Core struct { 39 | LimitMaxConn int `toml:"limitMaxConn"` 40 | ReadTimeout int `toml:"readTimeout"` 41 | WriteTimeout int `toml:"writeTimeout"` 42 | IdleTimeout int `toml:"idleTimeout"` 43 | MaxHeaderBytes int `toml:"maxHeaderBytes"` 44 | IpForwardeds []string `toml:"ipForwardeds"` 45 | } 46 | 47 | type LimitReq struct { 48 | Enable bool `toml:"enable"` 49 | TimeDuration int `toml:"timeDuration"` 50 | Count int `toml:"count"` 51 | Mode int `toml:"mode"` 52 | } 53 | 54 | type CcDefense struct { 55 | Enable bool `toml:"enable"` 56 | TimeDuration int `toml:"timeDuration"` 57 | Count int `toml:"count"` 58 | BlackTime int `toml:"blackTime"` 59 | WhiteList []string `toml:"whiteList"` 60 | BlackList []string `toml:"blackList"` 61 | } 62 | 63 | type Log struct { 64 | LogLevel string `toml:"logLevel"` 65 | WriteFile bool `toml:"writeFile"` 66 | FileDir string `toml:"fileDir"` 67 | } 68 | 69 | type ProxyInfo struct { 70 | Host string `toml:"host"` 71 | Target []string `toml:"target"` 72 | ObtainMode int `toml:"obtainMode"` 73 | } 74 | 75 | type SslBase struct { 76 | SessionTicket bool `toml:"sessionTicket"` 77 | } 78 | 79 | type SslCert struct { 80 | SslCertificate string `toml:"ssl_certificate"` 81 | SslCertificateKey string `toml:"ssl_certificate_key"` 82 | //是否开启ocsp stapling,如果开启默认先去ssl证书平台拉取,失败再看本地是否有ocsp文件 83 | OcspStapling bool `toml:"ocsp_stapling"` 84 | OcspStaplingLocal bool `toml:"ocsp_stapling_local"` 85 | OcspStaplingFile string `toml:"ocsp_stapling_file"` 86 | } 87 | 88 | var ( 89 | cfg ProxyConfig 90 | once sync.Once 91 | ) 92 | 93 | func Get() *ProxyConfig { 94 | once.Do(func() { 95 | cfg = ProxyConfig{} 96 | filePath := util.GetAbsolutePath("/bin/configs/conf.txt") 97 | if _, err := toml.DecodeFile(filePath, &cfg); err != nil { 98 | panic(err) 99 | } 100 | fmt.Printf("读取配置文件: %s\n", filePath) 101 | }) 102 | return &cfg 103 | } 104 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # minGateway 2 | 3 | 基于Go语言开发的一个小巧的http/https网关服务 4 | 5 | ## 优点: 6 | - 部署简单,程序只有一个单独的文件,支持多个操作系统,即拷即用 7 | - 性能优异,轻松实现每秒上万转发 8 | - 可扩展性好,根据自身业务扩展起来方便(主要是因为代码量少) 9 | 10 | ## 功能: 11 | - 负载均衡:支持三种目标路由方式可选,随机、轮询、哈希 12 | - 多种转发规则:支持以点号开头或以点号结尾的泛域名路由转发(如".obtc.com"、"alogin.") 13 | - 限流:设置最大连接数,实现限流功能 14 | - CC攻击(Challenge Collapsar)防御:开启后记录频繁访问的IP,将其加入黑名单,具有一定的CC防御攻击能力 15 | - 访问限制:开启后限制在一个时间范围内同IP访问同一URL请求次数,超过的会被抛弃 16 | - HTTP(S)反向代理:由网关处理https加密连接,后端服务器只需提供非加密的http服务 17 | - OCSP:可配置的OCSP Stapling功能 18 | - 管理API:当前只有获取当前正在访问的数量,需要自行扩展 19 | 20 | ## 压测数据: 21 | **内网环境:** 22 | - 网关服务器(4核8g centos 7.x),tomcat服务器(8核16g centos 7.x) 23 | - tomcat提供GET的"/ping"请求,返回字符串"pong" 24 | - 70微秒一次请求,约合每秒14285次请求,持续压测1分钟 25 | 26 | **测试结果:** 27 | - tomcat直压结果:共返回823842次成功结果,0次失败,约合每秒返回13730个成功结果 28 | - 通过minGateway转发:共返回606926次成功结果,0次失败,约合每秒10115个成功结果 29 | 30 | *另外,本次测试是在tomcat和minGateway的控制台都开着打印的情况,如果优化一下测试数据应该能再好一些。* 31 | 32 | ## 使用方法: 33 | 34 | 1. 配置go编译环境,本库在go 1.14中编译通过 35 | 36 | 2. 获取代码: `go get github.com/ccynet/minGateway` 37 | 38 | 3. 下载依赖: `cd minGateway && go mod tidy` 39 | 40 | 4. 编译程序: `go build` 41 | 42 | 5. 修改配置文件: 配置文件在minGateway的bin/configs/目录下 43 | 44 | 6. 如果需要设置TLS证书,证书文件保存在项目的bin/cert/目录下,并在配置文件中做相应设置 45 | 46 | 7. 
运行编译出来的程序,请确保可执行文件放置在bin文件夹的同级目录下 47 | 48 | ## 附配置文件示例: 49 | 50 | ```toml 51 | #服务器设置 52 | [core] 53 | #最大连接数,为0则不限流 54 | limitMaxConn = 30000 55 | #读超时,读完消息head和body的全部时间限制,为0则没有超时,单位秒 56 | readTimeout = 5 57 | #写超时,从读完消息开始到消息返回的用时限制,为0则没有超时,单位秒 58 | writeTimeout = 30 59 | #闲置超时,IdleTimeout是启用keep-alives状态后(默认启用)等待下一个请求的最长时间。 60 | #如果IdleTimeout为零,则使用ReadTimeout的值。 如果两者均为零,则没有超时。单位秒 61 | idleTimeout = 60 62 | #最大头字节,为0则使用默认值1024k, 这里设置为131072=128k 63 | maxHeaderBytes = 131072 64 | #设置读取消息头中IP转发字段名称,按照数组里的顺序查找,如果为空则获取的是TCP连接的IP地址 65 | #如果消息是通过前端代理服务器转发或者cdn转发,则需要从消息头中获取IP地址(注意确保IP的真实性) 66 | ipForwardeds = ["Ali-Cdn-Real-Ip","X-Forwarded-For","X-Real-Ip","X-Real-IP"] 67 | 68 | #CC防御(Challenge Collapsar) 69 | #如设置为3000毫秒钟内访问100次,则把此IP加入黑名单3600秒 70 | [ccDefense] 71 | #是否开启 72 | enable = false 73 | #此时间的访问内做检查(单位:毫秒) 74 | timeDuration = 3000 75 | #允许访问的次数 76 | count = 100 77 | #放入黑名单时间(单位:秒) 78 | blackTime = 3600 79 | #IP白名单 80 | whiteList = ["56.127.44.121","56.127.44.122"] 81 | #IP黑名单 82 | blackList = [] 83 | 84 | #限制请求配置 85 | #限制在一个时间范围内同IP访问同一URL请求次数,超过的会被抛弃 86 | [limitReq] 87 | #是否开启 88 | enable = false 89 | #此时间的访问内做检查(单位:毫秒) 90 | timeDuration = 1000 91 | #允许访问的次数 92 | count = 1 93 | #0:不包含参数 1:包含参数(为0时限制范围更广) 94 | mode = 0 95 | 96 | #log设置 97 | [log] 98 | #设置最低loglevel: debug,info,warn,error,fatal,panic 99 | logLevel = "debug" 100 | #开启写文件 101 | writeFile = true 102 | #存储日志文件的目录,不为空时是绝对路径,为空则是相对路径在程序的bin/logs/中 103 | fileDir = "" 104 | 105 | #代理设置 106 | [proxyInfo] 107 | 108 | #API服务器 109 | [proxyInfo.shop] 110 | #结尾有一个点的是泛解析,所有alogin开头的都走这里 111 | host = "alogin." 
112 | #这里有2台服务器,用逗号间隔开 113 | target = ["http://172.226.10.17:8080","http://172.226.10.19:8080"] 114 | #服务器选择模式,1:随机方式 2:轮询方式 3:一致性哈希方式;如果未设置则使用随机方式 115 | obtainMode = 3 116 | 117 | #API服务器 118 | [proxyInfo.alogin] 119 | host = "alogin.obtc.com" 120 | target = ["http://172.226.10.17:8080","http://172.226.10.19:8080"] 121 | obtainMode = 3 122 | 123 | #管理后台 124 | [proxyInfo.admin] 125 | host = "admin.obtc.com" 126 | target = ["http://172.226.10.19:9150"] 127 | 128 | #微信公众号 129 | [proxyInfo.wxpay] 130 | host = "wxpay." 131 | target = ["http://172.226.10.17:9101"] 132 | 133 | #开放平台,苹果拉微信 134 | [proxyInfo.wxpaypay] 135 | host = "wxpaypay.obtc.com" 136 | target = ["http://172.226.10.17:9103"] 137 | 138 | #简单网站和隐私协议 139 | [proxyInfo.service] 140 | host = "service.obtc.com" 141 | target = ["http://172.226.10.17:9102"] 142 | 143 | #都没找到走这里 144 | [proxyInfo.default] 145 | host = "default" 146 | target = ["http://172.226.10.17:8080"] 147 | 148 | 149 | [sslBase] 150 | sessionTicket = true 151 | 152 | #SSL证书文件 153 | [sslCert] 154 | 155 | [sslCert.wxpaypay_obtc_com] 156 | ssl_certificate = "wxpaypay.obtc.com.crt" 157 | ssl_certificate_key = "wxpaypay.obtc.com.key" 158 | ocsp_stapling = true 159 | ocsp_stapling_local = true 160 | ocsp_stapling_file = "wxpaypay.obtc.com.ocsp" 161 | 162 | [sslCert.wxpaypay_lidu_com] 163 | ssl_certificate = "wxpaypay.lidu.com.crt" 164 | ssl_certificate_key = "wxpaypay.lidu.com.key" 165 | 166 | ``` 167 | 168 | ## Version 169 | v0.1.0 170 | 171 | ## License 172 | Licensed under the New BSD License. 
173 | 174 | ## Author 175 | Tom Chen (cgrencn@gmail.com) 176 | -------------------------------------------------------------------------------- /ocsp.go: -------------------------------------------------------------------------------- 1 | /* 2 | OCSP(Online Certificate Status Protocol,在线证书状态协议)是用来检验证书合法性的在线查询服务,一般由证书所属 CA 提供。某些客户端会在 TLS 握手阶段 3 | 进一步协商时,实时查询 OCSP 接口,并在获得结果前阻塞后续流程。OCSP 查询本质是一次完整的 HTTP 请求 - 响应,这中间 DNS 查询、建立 TCP、服务端处理等环节 4 | 都可能耗费很长时间,导致最终建立 TLS 连接时间变得更长。 5 | 6 | OCSP Stapling(OCSP 封套),是指服务端主动获取 OCSP 查询结果并随着证书一起发送给客户端,从而让客户端跳过自己去验证的过程,提高 TLS 握手效率。 7 | 8 | 下面方法能在线获取OCSP信息,获取后将它设置到tls的OCSPStaple上,能开启go服务的OCSP Stapling 9 | 10 | 运行服务后,通过 openssl s_client -connect cspi.juleu.com:443 -status -tlsextdebug < /dev/null 2>&1 | grep -i "OCSP response" 查询是否开启 11 | 正常应该返回: 12 | OCSP response: 13 | OCSP Response Data: 14 | OCSP Response Status: successful (0x0) 15 | Response Type: Basic OCSP Response 16 | */ 17 | 18 | package main 19 | 20 | import ( 21 | "bytes" 22 | "crypto/x509" 23 | "encoding/pem" 24 | "errors" 25 | "fmt" 26 | "golang.org/x/crypto/ocsp" 27 | "io/ioutil" 28 | "minGateway/util" 29 | ) 30 | 31 | func GetOCSPForCert(cert [][]byte) ([]byte, *ocsp.Response, *x509.Certificate, error) { 32 | 33 | bundle := new(bytes.Buffer) 34 | for _, derBytes := range cert { 35 | err := pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 36 | if err != nil { 37 | fmt.Println(err) 38 | return nil, nil, nil, err 39 | } 40 | } 41 | pemBundle := bundle.Bytes() 42 | 43 | certificates, err := parsePEMBundle(pemBundle) 44 | if err != nil { 45 | return nil, nil, nil, err 46 | } 47 | 48 | // We expect the certificate slice to be ordered downwards the chain. 49 | // SRV CRT -> CA. We need to pull the leaf and issuer certs out of it, 50 | // which should always be the first two certificates. If there's no 51 | // OCSP server listed in the leaf cert, there's nothing to do. And if 52 | // we have only one certificate so far, we need to get the issuer cert. 
53 | issuedCert := certificates[0] 54 | if len(issuedCert.OCSPServer) == 0 { 55 | return nil, nil, nil, errors.New("no OCSP server specified in cert") 56 | } 57 | if len(certificates) == 1 { 58 | // TODO: build fallback. If this fails, check the remaining array entries. 59 | if len(issuedCert.IssuingCertificateURL) == 0 { 60 | return nil, nil, issuedCert, errors.New("no issuing certificate URL") 61 | } 62 | 63 | resp, errC := util.HttpGet(issuedCert.IssuingCertificateURL[0]) 64 | if errC != nil { 65 | return nil, nil, issuedCert, errC 66 | } 67 | defer resp.Body.Close() 68 | 69 | issuerBytes, errC := ioutil.ReadAll(util.LimitReader(resp.Body, 1024*1024)) 70 | if errC != nil { 71 | return nil, nil, issuedCert, errC 72 | } 73 | 74 | issuerCert, errC := x509.ParseCertificate(issuerBytes) 75 | if errC != nil { 76 | return nil, nil, issuedCert, errC 77 | } 78 | 79 | // Insert it into the slice on position 0 80 | // We want it ordered right SRV CRT -> CA 81 | certificates = append(certificates, issuerCert) 82 | } 83 | issuerCert := certificates[1] 84 | 85 | // Finally kick off the OCSP request. 
86 | ocspReq, err := ocsp.CreateRequest(issuedCert, issuerCert, nil) 87 | if err != nil { 88 | return nil, nil, issuedCert, err 89 | } 90 | 91 | reader := bytes.NewReader(ocspReq) 92 | fmt.Println("ocsp server url:", issuedCert.OCSPServer[0]) 93 | req, err := util.HttpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader) 94 | if err != nil { 95 | return nil, nil, issuedCert, err 96 | } 97 | defer req.Body.Close() 98 | 99 | ocspResBytes, err := ioutil.ReadAll(util.LimitReader(req.Body, 1024*1024)) 100 | if err != nil { 101 | return nil, nil, issuedCert, err 102 | } 103 | 104 | ocspRes, err := ocsp.ParseResponse(ocspResBytes, issuerCert) 105 | if err != nil { 106 | return nil, nil, issuedCert, err 107 | } 108 | 109 | return ocspResBytes, ocspRes, issuedCert, nil 110 | } 111 | 112 | // parsePEMBundle parses a certificate bundle from top to bottom and returns 113 | // a slice of x509 certificates. This function will error if no certificates are found. 114 | func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { 115 | var certificates []*x509.Certificate 116 | var certDERBlock *pem.Block 117 | 118 | for { 119 | certDERBlock, bundle = pem.Decode(bundle) 120 | if certDERBlock == nil { 121 | break 122 | } 123 | 124 | if certDERBlock.Type == "CERTIFICATE" { 125 | cert, err := x509.ParseCertificate(certDERBlock.Bytes) 126 | if err != nil { 127 | return nil, err 128 | } 129 | certificates = append(certificates, cert) 130 | } 131 | } 132 | 133 | if len(certificates) == 0 { 134 | return nil, errors.New("no certificates were found while parsing the bundle") 135 | } 136 | 137 | return certificates, nil 138 | } 139 | -------------------------------------------------------------------------------- /tools/getocsp/get_ocsp.go: -------------------------------------------------------------------------------- 1 | //用来缓存OCSP文件,因为有些时候请求的校验服务器(如https://ocsp.int-x3.letsencrypt.org)被墙了, 2 | //在本机获取OCSP的校验信息保存下来,让服务器直接读取使用 3 | 4 | package main 5 | 6 | import ( 7 | 
"bytes" 8 | "crypto/tls" 9 | "crypto/x509" 10 | "encoding/pem" 11 | "errors" 12 | "fmt" 13 | "golang.org/x/crypto/ocsp" 14 | "io/ioutil" 15 | "minGateway/util" 16 | "os" 17 | ) 18 | 19 | var ( 20 | fileDir = "/Users/chunyongchen/GoProjects/minGateway/bin/cert/" 21 | pemName = "truth.juleu.com.pem" 22 | keyName = "truth.juleu.com.key" 23 | ) 24 | 25 | func main() { 26 | getOcspCache() 27 | } 28 | 29 | func getOcspCache() { 30 | certFile := fileDir + pemName 31 | keyFile := fileDir + keyName 32 | crt, err := tls.LoadX509KeyPair(certFile, keyFile) 33 | if err != nil { 34 | fmt.Println(err) 35 | return 36 | } 37 | 38 | fmt.Println(certFile) 39 | fmt.Println(keyFile) 40 | 41 | OCSPBuf, _, certx509, err := GetOCSPForCert(crt.Certificate) 42 | if err == nil { 43 | //写文件 44 | for _, v := range certx509.DNSNames { 45 | name := v + ".ocsp" 46 | ocspFile := fileDir + name 47 | if util.PathExists(ocspFile) { 48 | _ = os.Remove(ocspFile) 49 | } 50 | _ = ioutil.WriteFile(ocspFile, OCSPBuf, 0644) 51 | fmt.Println(name + " file write success") 52 | } 53 | 54 | } else { 55 | fmt.Println(err) 56 | } 57 | } 58 | 59 | func GetOCSPForCert(cert [][]byte) ([]byte, *ocsp.Response, *x509.Certificate, error) { 60 | 61 | bundle := new(bytes.Buffer) 62 | for _, derBytes := range cert { 63 | err := pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 64 | if err != nil { 65 | fmt.Println(err) 66 | return nil, nil, nil, err 67 | } 68 | } 69 | pemBundle := bundle.Bytes() 70 | 71 | certificates, err := parsePEMBundle(pemBundle) 72 | if err != nil { 73 | return nil, nil, nil, err 74 | } 75 | 76 | // We expect the certificate slice to be ordered downwards the chain. 77 | // SRV CRT -> CA. We need to pull the leaf and issuer certs out of it, 78 | // which should always be the first two certificates. If there's no 79 | // OCSP server listed in the leaf cert, there's nothing to do. And if 80 | // we have only one certificate so far, we need to get the issuer cert. 
81 | issuedCert := certificates[0] 82 | if len(issuedCert.OCSPServer) == 0 { 83 | return nil, nil, nil, errors.New("no OCSP server specified in cert") 84 | } 85 | if len(certificates) == 1 { 86 | // TODO: build fallback. If this fails, check the remaining array entries. 87 | if len(issuedCert.IssuingCertificateURL) == 0 { 88 | return nil, nil, issuedCert, errors.New("no issuing certificate URL") 89 | } 90 | 91 | resp, errC := util.HttpGet(issuedCert.IssuingCertificateURL[0]) 92 | if errC != nil { 93 | return nil, nil, issuedCert, errC 94 | } 95 | defer resp.Body.Close() 96 | 97 | issuerBytes, errC := ioutil.ReadAll(util.LimitReader(resp.Body, 1024*1024)) 98 | if errC != nil { 99 | return nil, nil, issuedCert, errC 100 | } 101 | 102 | issuerCert, errC := x509.ParseCertificate(issuerBytes) 103 | if errC != nil { 104 | return nil, nil, issuedCert, errC 105 | } 106 | 107 | // Insert it into the slice on position 0 108 | // We want it ordered right SRV CRT -> CA 109 | certificates = append(certificates, issuerCert) 110 | } 111 | issuerCert := certificates[1] 112 | 113 | // Finally kick off the OCSP request. 
114 | ocspReq, err := ocsp.CreateRequest(issuedCert, issuerCert, nil) 115 | if err != nil { 116 | return nil, nil, issuedCert, err 117 | } 118 | 119 | reader := bytes.NewReader(ocspReq) 120 | fmt.Println("ocsp server url:", issuedCert.OCSPServer[0]) 121 | req, err := util.HttpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader) 122 | if err != nil { 123 | return nil, nil, issuedCert, err 124 | } 125 | defer req.Body.Close() 126 | 127 | ocspResBytes, err := ioutil.ReadAll(util.LimitReader(req.Body, 1024*1024)) 128 | if err != nil { 129 | return nil, nil, issuedCert, err 130 | } 131 | 132 | ocspRes, err := ocsp.ParseResponse(ocspResBytes, issuerCert) 133 | if err != nil { 134 | return nil, nil, issuedCert, err 135 | } 136 | 137 | return ocspResBytes, ocspRes, issuedCert, nil 138 | } 139 | 140 | // parsePEMBundle parses a certificate bundle from top to bottom and returns 141 | // a slice of x509 certificates. This function will error if no certificates are found. 142 | func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { 143 | var certificates []*x509.Certificate 144 | var certDERBlock *pem.Block 145 | 146 | for { 147 | certDERBlock, bundle = pem.Decode(bundle) 148 | if certDERBlock == nil { 149 | break 150 | } 151 | 152 | if certDERBlock.Type == "CERTIFICATE" { 153 | cert, err := x509.ParseCertificate(certDERBlock.Bytes) 154 | if err != nil { 155 | return nil, err 156 | } 157 | certificates = append(certificates, cert) 158 | } 159 | } 160 | 161 | if len(certificates) == 0 { 162 | return nil, errors.New("no certificates were found while parsing the bundle") 163 | } 164 | 165 | return certificates, nil 166 | } 167 | -------------------------------------------------------------------------------- /util/http.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | "os" 12 | "runtime" 
13 | "strings" 14 | "time" 15 | ) 16 | 17 | var ( 18 | // UserAgent (if non-empty) will be tacked onto the User-Agent string in requests. 19 | UserAgent string 20 | 21 | // HTTPClient is an HTTP client with a reasonable timeout value and 22 | // potentially a custom *x509.CertPool based on the caCertificatesEnvVar 23 | // environment variable (see the `initCertPool` function) 24 | HTTPClient = http.Client{ 25 | Transport: &http.Transport{ 26 | Proxy: http.ProxyFromEnvironment, 27 | DialContext: (&net.Dialer{ 28 | Timeout: 30 * time.Second, 29 | KeepAlive: 30 * time.Second, 30 | }).DialContext, 31 | TLSHandshakeTimeout: 15 * time.Second, 32 | ResponseHeaderTimeout: 15 * time.Second, 33 | ExpectContinueTimeout: 1 * time.Second, 34 | TLSClientConfig: &tls.Config{ 35 | ServerName: os.Getenv(caServerNameEnvVar), 36 | RootCAs: initCertPool(), 37 | }, 38 | }, 39 | } 40 | ) 41 | 42 | const ( 43 | // defaultGoUserAgent is the Go HTTP package user agent string. Too 44 | // bad it isn't exported. If it changes, we should update it here, too. 45 | defaultGoUserAgent = "Go-http-client/1.1" 46 | 47 | // ourUserAgent is the User-Agent of this underlying library package. 48 | ourUserAgent = "xenolf-acme" 49 | 50 | // caCertificatesEnvVar is the environment variable name that can be used to 51 | // specify the path to PEM encoded CA Certificates that can be used to 52 | // authenticate an ACME server with a HTTPS certificate not issued by a CA in 53 | // the system-wide trusted root list. 54 | caCertificatesEnvVar = "LEGO_CA_CERTIFICATES" 55 | 56 | // caServerNameEnvVar is the environment variable name that can be used to 57 | // specify the CA server name that can be used to 58 | // authenticate an ACME server with a HTTPS certificate not issued by a CA in 59 | // the system-wide trusted root list. 
60 | caServerNameEnvVar = "LEGO_CA_SERVER_NAME" 61 | ) 62 | 63 | // initCertPool creates a *x509.CertPool populated with the PEM certificates 64 | // found in the filepath specified in the caCertificatesEnvVar OS environment 65 | // variable. If the caCertificatesEnvVar is not set then initCertPool will 66 | // return nil. If there is an error creating a *x509.CertPool from the provided 67 | // caCertificatesEnvVar value then initCertPool will panic. 68 | func initCertPool() *x509.CertPool { 69 | if customCACertsPath := os.Getenv(caCertificatesEnvVar); customCACertsPath != "" { 70 | customCAs, err := ioutil.ReadFile(customCACertsPath) 71 | if err != nil { 72 | panic(fmt.Sprintf("error reading %s=%q: %v", 73 | caCertificatesEnvVar, customCACertsPath, err)) 74 | } 75 | certPool := x509.NewCertPool() 76 | if ok := certPool.AppendCertsFromPEM(customCAs); !ok { 77 | panic(fmt.Sprintf("error creating x509 cert pool from %s=%q: %v", 78 | caCertificatesEnvVar, customCACertsPath, err)) 79 | } 80 | return certPool 81 | } 82 | return nil 83 | } 84 | 85 | // httpHead performs a HEAD request with a proper User-Agent string. 86 | // The response body (resp.Body) is already closed when this function returns. 87 | func httpHead(url string) (resp *http.Response, err error) { 88 | req, err := http.NewRequest(http.MethodHead, url, nil) 89 | if err != nil { 90 | return nil, fmt.Errorf("failed to head %q: %v", url, err) 91 | } 92 | 93 | req.Header.Set("User-Agent", userAgent()) 94 | 95 | resp, err = HTTPClient.Do(req) 96 | if err != nil { 97 | return resp, fmt.Errorf("failed to do head %q: %v", url, err) 98 | } 99 | resp.Body.Close() 100 | return resp, err 101 | } 102 | 103 | // httpPost performs a POST request with a proper User-Agent string. 104 | // Callers should close resp.Body when done reading from it. 
105 | func HttpPost(url string, bodyType string, body io.Reader) (resp *http.Response, err error) { 106 | req, err := http.NewRequest(http.MethodPost, url, body) 107 | if err != nil { 108 | return nil, fmt.Errorf("failed to post %q: %v", url, err) 109 | } 110 | req.Header.Set("Content-Type", bodyType) 111 | req.Header.Set("User-Agent", userAgent()) 112 | 113 | return HTTPClient.Do(req) 114 | } 115 | 116 | // httpGet performs a GET request with a proper User-Agent string. 117 | // Callers should close resp.Body when done reading from it. 118 | func HttpGet(url string) (resp *http.Response, err error) { 119 | req, err := http.NewRequest(http.MethodGet, url, nil) 120 | if err != nil { 121 | return nil, fmt.Errorf("failed to get %q: %v", url, err) 122 | } 123 | req.Header.Set("User-Agent", userAgent()) 124 | 125 | return HTTPClient.Do(req) 126 | } 127 | 128 | // userAgent builds and returns the User-Agent string to use in requests. 129 | func userAgent() string { 130 | ua := fmt.Sprintf("%s %s (%s; %s) %s", UserAgent, ourUserAgent, runtime.GOOS, runtime.GOARCH, defaultGoUserAgent) 131 | return strings.TrimSpace(ua) 132 | } 133 | 134 | func LimitReader(rd io.ReadCloser, numBytes int64) io.ReadCloser { 135 | return http.MaxBytesReader(nil, rd, numBytes) 136 | } 137 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= 2 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 3 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | 
github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 h1:Ghm4eQYC0nEPnSJdVkTrXpu9KtoVCSo1hg7mtI7G9KU= 8 | github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239/go.mod h1:Gdwt2ce0yfBxPvZrHkprdPPTTS3N5rwmLE8T22KBXlw= 9 | github.com/g4zhuj/hashring v0.0.0-20180426073119-5d542568fdbd h1:FjwbWhR+yxNuQrGWu6/bN76+TlbMxDFUgjj0FcCyFDs= 10 | github.com/g4zhuj/hashring v0.0.0-20180426073119-5d542568fdbd/go.mod h1:9zdNtZS41ZZRfkQBuvhJxk/tKVdJewHxaUNajTPTUAU= 11 | github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= 12 | github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= 13 | github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 h1:IPJ3dvxmJ4uczJe5YQdrYB16oTJlGSC/OyZDqUk9xX4= 14 | github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869/go.mod h1:cJ6Cj7dQo+O6GJNiMx+Pa94qKj+TG8ONdKHgMNIyyag= 15 | github.com/jonboulle/clockwork v0.2.0 h1:J2SLSdy7HgElq8ekSl2Mxh6vrRNFxqbXGenYH2I02Vs= 16 | github.com/jonboulle/clockwork v0.2.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= 17 | github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= 18 | github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 19 | github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc h1:RKf14vYWi2ttpEmkA4aQ3j4u9dStX2t4M8UM6qqNsG8= 20 | github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc/go.mod h1:kopuH9ugFRkIXf3YoqHKyrJ9YfUFsckUU9S7B+XP+is= 21 | github.com/lestrrat-go/file-rotatelogs v2.3.0+incompatible h1:4mNlp+/SvALIPFpbXV3kxNJJno9iKFWGxSDE13Kl66Q= 22 | github.com/lestrrat-go/file-rotatelogs v2.3.0+incompatible/go.mod h1:ZQnN8lSECaebrkQytbHj4xNgtg8CR7RYXnPok8e0EHA= 23 | github.com/lestrrat-go/strftime v1.0.3 h1:qqOPU7y+TM8Y803I8fG9c/DyKG3xH/xkng6keC1015Q= 24 | github.com/lestrrat-go/strftime v1.0.3/go.mod h1:E1nN3pCbtMSu1yjSVeyuRFVm/U0xoR76fd03sz+Qz4g= 25 | github.com/pkg/errors 
v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 26 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 27 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 28 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 29 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 30 | github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 h1:mZHayPoR0lNmnHyvtYjDeq0zlVHn9K/ZXoy17ylucdo= 31 | github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5/go.mod h1:GEXHk5HgEKCvEIIrSpFI3ozzG5xOKA2DVlEX/gGnewM= 32 | github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= 33 | github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= 34 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 35 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 36 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 37 | github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= 38 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 39 | github.com/tebeka/strftime v0.1.5 h1:1NQKN1NiQgkqd/2moD6ySP/5CoZQsKa1d3ZhJ44Jpmg= 40 | github.com/tebeka/strftime v0.1.5/go.mod h1:29/OidkoWHdEKZqzyDLUyC+LmgDgdHo4WAFCDT7D/Ig= 41 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 42 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= 43 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 44 | golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 h1:DZhuSZLsGlFL4CmhA8BcRA0mnthyA/nZ00AqCUo7vHg= 45 | golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod 
h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 46 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 47 | golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU= 48 | golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= 49 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 50 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 51 | golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 52 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= 53 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 54 | golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666 h1:gVCS+QOncANNPlmlO1AhlU3oxs4V9z+gTtPwIk3p2N8= 55 | golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 56 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 57 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/g4zhuj/hashring" 6 | "github.com/sirupsen/logrus" 7 | "io/ioutil" 8 | llog "log" 9 | "minGateway/util" 10 | "sync" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | func TestHashRing(t *testing.T) { 16 | ips := []string{} 17 | for i := 0; i < 10000; i++ { 18 | ips = append(ips, fmt.Sprint("12.13.456.", i)) 19 | 20 | } 21 | 22 | nodeWeight := make(map[string]int) 23 | nodeWeight["node1"] = 1 24 | nodeWeight["node2"] = 1 25 | nodeWeight["node3"] = 2 26 | vitualSpots := 100 27 | hash := hashring.NewHashRing(vitualSpots) 28 | 29 
| //add nodes 30 | hash.AddNodes(nodeWeight) 31 | 32 | node1Count := 0 33 | node2Count := 0 34 | node3Count := 0 35 | //get key's node 36 | for _, ip := range ips { 37 | node := hash.GetNode(ip) 38 | if node == "node1" { 39 | node1Count++ 40 | } 41 | if node == "node2" { 42 | node2Count++ 43 | } 44 | if node == "node3" { 45 | node3Count++ 46 | } 47 | } 48 | 49 | fmt.Println("node1=", node1Count, " node2=", node2Count, " node3=", node3Count) 50 | 51 | fmt.Println(hash.GetNode("")) 52 | 53 | } 54 | 55 | func TestTimeD(t *testing.T) { 56 | in := time.Now() 57 | time.Sleep(100 * time.Millisecond) 58 | t.Log("耗时:", time.Now().Sub(in).Seconds(), "秒") 59 | } 60 | 61 | func TestBytes(t *testing.T) { 62 | b := 1 << 20 63 | t.Log(b, b/1024, "k") 64 | b = 1 << 17 65 | t.Log(b, b/1024, "k") 66 | } 67 | 68 | func TestPing(t *testing.T) { 69 | 70 | quit := make(chan bool) 71 | 72 | succNum := 0 73 | failNum := 0 74 | maxConn := make(chan bool, 5000) 75 | 76 | mu := &sync.Mutex{} 77 | 78 | go func() { 79 | time.Sleep(1 * time.Minute) 80 | quit <- true 81 | return 82 | }() 83 | 84 | ticker := time.NewTicker(20 * time.Millisecond) 85 | 86 | for { 87 | select { 88 | case <-quit: 89 | fmt.Println("end") 90 | goto END 91 | case <-ticker.C: 92 | go sendPing(&succNum, &failNum, mu, maxConn) 93 | } 94 | } 95 | 96 | END: 97 | fmt.Println("succ:", succNum, " fail:", failNum) 98 | 99 | } 100 | 101 | func sendPing(succ, fali *int, mu *sync.Mutex, conn chan bool) { 102 | //req, err := util.HttpGet("http://alogin.obtc.com/zgd/api_sys/ping") 103 | conn <- true 104 | req, err := util.HttpGet("http://139.199.75.120:8080/zgd/api_sys/ping") 105 | <-conn 106 | if err != nil { 107 | fmt.Println("error:", err) 108 | } else { 109 | defer req.Body.Close() 110 | 111 | req, errC := ioutil.ReadAll(util.LimitReader(req.Body, 1024*1024)) 112 | if errC != nil { 113 | fmt.Println("error:", err) 114 | } 115 | if string(req) == "pong" { 116 | mu.Lock() 117 | *succ++ 118 | mu.Unlock() 119 | return 120 | } 121 | } 
122 | mu.Lock() 123 | *fali++ 124 | mu.Unlock() 125 | return 126 | } 127 | 128 | var printTagCh = make(chan int, 1000) 129 | 130 | func BenchmarkLogrusChan(b *testing.B) { 131 | type args struct { 132 | level int 133 | } 134 | tests := []struct { 135 | name string 136 | args args 137 | }{ 138 | // TODO: Add test cases. 139 | { 140 | name: "test logrus.level 1", 141 | args: args{level: 1}, 142 | }, 143 | //{ 144 | // name: "test logrus.level 2", 145 | // args: args{level: 2}, 146 | //}, 147 | } 148 | for _, tt := range tests { 149 | b.Run(tt.name, func(b *testing.B) { 150 | l := logrus.New() 151 | 152 | //l.SetNoLock() 153 | 154 | if tt.args.level == 1 { 155 | l.SetLevel(logrus.DebugLevel) 156 | } else if tt.args.level == 2 { 157 | l.SetLevel(logrus.InfoLevel) 158 | } 159 | 160 | for j := 0; j < b.N; j++ { 161 | printTagCh <- j // NOTE(review): the goroutines below are never awaited, so the timing may not include all of their logging work 162 | go func() { 163 | l.Debug(<-printTagCh) 164 | }() 165 | } 166 | //b.Log(l) 167 | }) 168 | } 169 | } 170 | 171 | func BenchmarkLogrus(b *testing.B) { 172 | type args struct { 173 | level int 174 | } 175 | tests := []struct { 176 | name string 177 | args args 178 | }{ 179 | // TODO: Add test cases. 180 | { 181 | name: "test logrus.level 1", 182 | args: args{level: 1}, 183 | }, 184 | { 185 | name: "test logrus.level 2", 186 | args: args{level: 2}, 187 | }, 188 | } 189 | for _, tt := range tests { 190 | b.Run(tt.name, func(b *testing.B) { 191 | l := logrus.New() 192 | 193 | if tt.args.level == 1 { 194 | l.SetLevel(logrus.DebugLevel) 195 | } else if tt.args.level == 2 { 196 | l.SetLevel(logrus.InfoLevel) 197 | } 198 | 199 | for j := 0; j < b.N; j++ { 200 | l.Debug("info info info info info info") 201 | } 202 | //b.Log(l) 203 | }) 204 | } 205 | } 206 | 207 | func BenchmarkSysLog(b *testing.B) { 208 | type args struct { 209 | level int 210 | } 211 | tests := []struct { 212 | name string 213 | args args 214 | }{ 215 | // TODO: Add test cases.
216 | { 217 | name: "test logrus.level 1", 218 | args: args{level: 1}, 219 | }, 220 | } 221 | for _, tt := range tests { 222 | b.Run(tt.name, func(b *testing.B) { 223 | 224 | for j := 0; j < b.N; j++ { 225 | llog.Println("info info info info info info") 226 | fmt.Println("info info info info info info") // NOTE(review): this benchmark also times fmt.Println, not only the standard logger 227 | } 228 | //b.Log(l) 229 | }) 230 | } 231 | } 232 | 233 | func BenchmarkFmtLog(b *testing.B) { 234 | type args struct { 235 | level int 236 | } 237 | tests := []struct { 238 | name string 239 | args args 240 | }{ 241 | // TODO: Add test cases. 242 | { 243 | name: "test logrus.level 1", 244 | args: args{level: 1}, 245 | }, 246 | } 247 | for _, tt := range tests { 248 | b.Run(tt.name, func(b *testing.B) { 249 | 250 | for j := 0; j < b.N; j++ { 251 | fmt.Println("info info info info info info") 252 | } 253 | //b.Log(l) 254 | }) 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /loadconfig.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "github.com/g4zhuj/hashring" 7 | "io/ioutil" 8 | "minGateway/config" 9 | "minGateway/util" 10 | "minGateway/util/tcache" 11 | "strings" 12 | "time" 13 | ) 14 | // loadConfig copies the settings from config.Get() into the gateway's package-level state (connection limit, timeouts, CC defense, request limiting, host routes) and loads the configured SSL certificates into CertificateSet; it returns the first certificate-loading error. 15 | func loadConfig() error { 16 | coreConf := config.Get().CoreConf 17 | limitMaxConn = coreConf.LimitMaxConn 18 | readTimeout = coreConf.ReadTimeout 19 | writeTimeout = coreConf.WriteTimeout 20 | idleTimeout = coreConf.IdleTimeout 21 | maxHeaderBytes = coreConf.MaxHeaderBytes 22 | if len(coreConf.IpForwardeds) > 0 { 23 | ipForwardeds = coreConf.IpForwardeds 24 | } 25 | 26 | // the CC defense settings apply only when it is enabled 27 | if config.Get().CcDefense.Enable { 28 | // detection window, milliseconds 29 | defenseCCCheckTime = config.Get().CcDefense.TimeDuration 30 | 31 | // max requests allowed within the window 32 | defenseCCCheckCount = config.Get().CcDefense.Count 33 | 34 | // how long an IP stays blacklisted, seconds 35 | defenseCCBlackTime = config.Get().CcDefense.BlackTime 36 | 37 |
// keep per-IP request records for 10 hours 38 | defenseCCIPsRecord = tcache.NewTimeCache(10 * time.Hour) 39 | 40 | // IP blacklist, each entry kept for defenseCCBlackTime seconds 41 | defenseCCIPsBlacklist = tcache.NewTimeCache(time.Duration(defenseCCBlackTime) * time.Second) 42 | 43 | // static white/black lists from the config 44 | defenseCCIPWhiteListConf = make(map[string]struct{}) 45 | defenseCCIPBlackListConf = make(map[string]struct{}) 46 | for _, ip := range config.Get().CcDefense.WhiteList { 47 | defenseCCIPWhiteListConf[ip] = struct{}{} 48 | } 49 | for _, ip := range config.Get().CcDefense.BlackList { 50 | defenseCCIPBlackListConf[ip] = struct{}{} 51 | } 52 | } 53 | 54 | // request limiting; LimitReq.TimeDuration is in milliseconds 55 | if config.Get().LimitReq.Enable { 56 | limitReqMap = tcache.NewTimeCache(time.Millisecond * time.Duration(config.Get().LimitReq.TimeDuration)) 57 | } 58 | 59 | httpConfig := config.Get().HttpProxy 60 | for _, v := range httpConfig { 61 | 62 | info := getHostInfo(v) 63 | 64 | if v.Host == "default" { 65 | // "default" catches requests for hosts that match no other entry 66 | DefaultTarget = &info 67 | } else if strings.HasSuffix(v.Host, ".") { 68 | // keyword before the dot (e.g. "wxpay."): matched as a Host prefix 69 | wc := HostInfoWc{KeyPos: 0} 70 | wc.IsMultiTarget = info.IsMultiTarget 71 | wc.MultiTarget = info.MultiTarget 72 | wc.MultiTargetMode = info.MultiTargetMode 73 | wc.Target = info.Target 74 | if info.hashRing != nil { 75 | wc.hashRing = info.hashRing 76 | } 77 | HostListWc[v.Host] = wc 78 | } else if strings.HasPrefix(v.Host, ".") { 79 | // keyword after the dot (leading "."): matched as a Host suffix 80 | wc := HostInfoWc{KeyPos: 1} 81 | wc.IsMultiTarget = info.IsMultiTarget 82 | wc.MultiTarget = info.MultiTarget 83 | wc.MultiTargetMode = info.MultiTargetMode 84 | wc.Target = info.Target 85 | if info.hashRing != nil { 86 | wc.hashRing = info.hashRing 87 | } 88 | HostListWc[v.Host] = wc 89 | } else { 90 | HostList[v.Host] = info 91 | if strings.HasPrefix(v.Host, "www.") { 92 | if strings.Count(v.Host, ".") == 2 { 93 | // "www." + registered domain: also route the bare domain. TrimPrefix, not TrimLeft: TrimLeft treats "www." as a character set ("www.web.com" -> "eb.com") 94 | HostList[strings.TrimPrefix(v.Host, "www.")] = HostList[v.Host] 95 | } 96 | } else if strings.Count(v.Host, ".") == 1 { 97 | // exactly one dot and no leading/trailing dot (handled above): a bare domain, so also route "www." + host 98 |
HostList["www."+v.Host] = HostList[v.Host] 99 | } 100 | } 101 | } 102 | fmt.Println() 103 | fmt.Println("监听的反向代理域名:") 104 | for k, v := range HostList { 105 | if v.IsMultiTarget { 106 | fmt.Println(k, "->", v.MultiTarget, " mode:", v.MultiTargetMode) 107 | } else { 108 | fmt.Println(k, "->", v.Target) 109 | } 110 | } 111 | for k, v := range HostListWc { 112 | if v.IsMultiTarget { 113 | fmt.Println(k, "->", v.MultiTarget, " mode:", v.MultiTargetMode) 114 | } else { 115 | fmt.Println(k, "->", v.Target) 116 | } 117 | } 118 | if DefaultTarget != nil { 119 | if DefaultTarget.IsMultiTarget { 120 | fmt.Println("default", "->", DefaultTarget.MultiTarget, " mode:", DefaultTarget.MultiTargetMode) 121 | } else { 122 | fmt.Println("default", "->", DefaultTarget.Target) 123 | } 124 | } 125 | 126 | fmt.Println() 127 | fmt.Println("SSL证书:") 128 | sslConfig := config.Get().SslCert 129 | for k, v := range sslConfig { 130 | certFile := util.GetAbsolutePath("/bin/cert/%s", v.SslCertificate) 131 | keyFile := util.GetAbsolutePath("/bin/cert/%s", v.SslCertificateKey) 132 | crt, err := tls.LoadX509KeyPair(certFile, keyFile) 133 | if err != nil { 134 | fmt.Println(err) 135 | return err 136 | } 137 | 138 | fmt.Println(k, ":") 139 | fmt.Println(certFile) 140 | fmt.Println(keyFile) 141 | 142 | // OCSP Stapling 143 | // assigning crt.OCSPStaple enables OCSP stapling for this certificate 144 | if v.OcspStapling { 145 | // load the staple from a local file or fetch it online 146 | if v.OcspStaplingLocal { 147 | ocspFile := util.GetAbsolutePath("/bin/cert/%s", v.OcspStaplingFile) 148 | OCSPBuf, err := ioutil.ReadFile(ocspFile) 149 | if err == nil { 150 | crt.OCSPStaple = OCSPBuf 151 | fmt.Println(k, ", local load, OCSP Stapling OK") 152 | } else { 153 | fmt.Println(err) 154 | } 155 | } else { 156 | OCSPBuf, _, _, err := GetOCSPForCert(crt.Certificate) 157 | if err == nil { 158 | crt.OCSPStaple = OCSPBuf 159 | fmt.Println(k, ", online load, OCSP Stapling OK") 160 | } else { 161 | fmt.Println(err) 162 | } 163 | } 164 | } 165 | CertificateSet = append(CertificateSet, crt) 166 | } 167 | 168 |
fmt.Println() 169 | 170 | return nil 171 | } 172 | // getHostInfo builds the routing entry for one proxy config: it strips spaces from the targets in place (writing into the caller's Target backing array), panics when no target is configured, and builds a consistent-hash ring when ObtainMode selects hash mode. 173 | func getHostInfo(proxyInfo config.ProxyInfo) HostInfo { 174 | var hostInfo HostInfo 175 | for i := 0; i < len(proxyInfo.Target); i++ { 176 | proxyInfo.Target[i] = strings.ReplaceAll(proxyInfo.Target[i], " ", "") 177 | } 178 | if len(proxyInfo.Target) == 0 { 179 | panic(proxyInfo.Host + " : len(proxyInfo.Target) == 0") 180 | } else if len(proxyInfo.Target) == 1 { 181 | hostInfo = HostInfo{IsMultiTarget: false, Target: proxyInfo.Target[0]} 182 | } else { 183 | // several targets: split traffic according to ObtainMode 184 | targets := proxyInfo.Target 185 | hostInfo = HostInfo{IsMultiTarget: true, MultiTarget: targets, MultiTargetMode: ObtainMode(proxyInfo.ObtainMode)} 186 | if ObtainMode(proxyInfo.ObtainMode) == SelectModeHash { 187 | // put every target on the hash ring with a weight 188 | hostInfo.hashRing = hashring.NewHashRing(100) // virtualSpots = 100 189 | nodeWeight := make(map[string]int) 190 | for _, target := range targets { 191 | nodeWeight[target] = 1 // simplified: every node gets weight 1 192 | } 193 | hostInfo.hashRing.AddNodes(nodeWeight) 194 | } 195 | } 196 | return hostInfo 197 | } 198 | -------------------------------------------------------------------------------- /sessiontickets.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 Matthew Holt and The Caddy Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package main 16 | 17 | import ( 18 | "crypto/rand" 19 | "crypto/tls" 20 | "fmt" 21 | "io" 22 | "minGateway/util/log" 23 | "runtime/debug" 24 | "sync" 25 | "time" 26 | ) 27 | 28 | // SessionTicketService configures and manages TLS session tickets. 29 | type SessionTicketService struct { 30 | // How often the STEKs are rotated. Default: 12h (applied when <= 0). 31 | RotationInterval time.Duration `json:"rotation_interval,omitempty"` 32 | 33 | // The maximum number of keys to keep in rotation. Default: 4. 34 | MaxKeys int `json:"max_keys,omitempty"` 35 | 36 | // Disables STEK rotation. 37 | DisableRotation bool `json:"disable_rotation,omitempty"` 38 | 39 | // Disables TLS session resumption by tickets. NOTE(review): no method of this type reads it. 40 | Disabled bool `json:"disabled,omitempty"` 41 | 42 | ticker *time.Ticker // drives rotation; created by stayUpdated 43 | 44 | configs map[*tls.Config]struct{} // configs that receive rotated keys; guarded by mu 45 | stopChan chan struct{} // set by start, closed by Stop to end the rotation goroutine 46 | currentKeys [][32]byte // newest key first; guarded by mu 47 | mu *sync.Mutex // allocated by Run 48 | } 49 | // Run applies the defaults, generates the first key and starts the rotation goroutine. It must be called before Register or Unregister, which rely on the map and mutex it allocates. 50 | func (s *SessionTicketService) Run() error { 51 | s.configs = make(map[*tls.Config]struct{}) 52 | s.mu = new(sync.Mutex) 53 | 54 | // establish sane defaults; time.NewTicker panics on a non-positive interval 55 | if s.RotationInterval <= 0 { 56 | s.RotationInterval = defaultSTEKRotationInterval 57 | } 58 | 59 | if s.MaxKeys <= 0 { 60 | s.MaxKeys = defaultMaxSTEKs 61 | } 62 | 63 | // start generates the first key, so every config 64 | // registered afterwards has one to install 65 | return s.start() 66 | } 67 | 68 | // start generates the initial STEK and spawns a goroutine 69 | // which rotates the STEKs until Stop() is called. 70 | // If stopChan is already set (start was called before), 71 | // this is a no-op.
72 | func (s *SessionTicketService) start() error { 73 | if s.stopChan != nil { 74 | return nil 75 | } 76 | s.stopChan = make(chan struct{}) 77 | 78 | // initializing the key source gives us our 79 | // initial key(s) to start with; if successful, 80 | // we need to be sure to call Next() so that 81 | // the key source can know when it is done 82 | initialKey, err := s.generateSTEK() 83 | if err != nil { 84 | return fmt.Errorf("setting STEK module configuration: %v", err) 85 | } 86 | 87 | s.mu.Lock() 88 | s.currentKeys = [][32]byte{initialKey} 89 | s.mu.Unlock() 90 | 91 | // keep the keys rotated 92 | go s.stayUpdated() 93 | 94 | return nil 95 | } 96 | 97 | // stayUpdated is a blocking function which rotates 98 | // the keys whenever new ones are sent. It reads 99 | // from keysChan until s.stop() is called. 100 | func (s *SessionTicketService) stayUpdated() { 101 | defer func() { 102 | if err := recover(); err != nil { 103 | log.Printf("[PANIC] session ticket service: %v\n%s", err, debug.Stack()) 104 | } 105 | }() 106 | 107 | s.ticker = time.NewTicker(s.RotationInterval) 108 | defer func() { 109 | _ = s.ticker.Stop 110 | }() 111 | 112 | for { 113 | select { 114 | case <-s.ticker.C: 115 | s.mu.Lock() 116 | 117 | if s.DisableRotation { 118 | s.mu.Unlock() 119 | continue 120 | } 121 | 122 | newKeys, err := s.RotateSTEKs(s.currentKeys) 123 | if err == nil { 124 | s.currentKeys = newKeys 125 | } 126 | fmt.Println(newKeys) 127 | 128 | configs := s.configs 129 | 130 | s.mu.Unlock() 131 | 132 | for cfg := range configs { 133 | cfg.SetSessionTicketKeys(newKeys) 134 | } 135 | case <-s.stopChan: 136 | return 137 | } 138 | } 139 | } 140 | 141 | // stop terminates the key rotation goroutine. 142 | func (s *SessionTicketService) Stop() { 143 | if s.stopChan != nil { 144 | close(s.stopChan) 145 | } 146 | } 147 | 148 | // register sets the session ticket keys on cfg 149 | // and keeps them updated. 
Any values registered 150 | // must be unregistered, or they will not be 151 | // garbage-collected. s.start() must have been 152 | // called first. If session tickets are disabled 153 | // or if ticket key rotation is disabled, this 154 | // function is a no-op. 155 | func (s *SessionTicketService) Register(cfg *tls.Config) { 156 | s.mu.Lock() 157 | cfg.SetSessionTicketKeys(s.currentKeys) 158 | s.configs[cfg] = struct{}{} 159 | s.mu.Unlock() 160 | } 161 | 162 | // unregister stops session key management on cfg and 163 | // removes the internal stored reference to cfg. If 164 | // session tickets are disabled or if ticket key rotation 165 | // is disabled, this function is a no-op. 166 | func (s *SessionTicketService) Unregister(cfg *tls.Config) { 167 | s.mu.Lock() 168 | delete(s.configs, cfg) 169 | s.mu.Unlock() 170 | } 171 | 172 | // RotateSTEKs rotates the keys in keys by producing a new key and eliding 173 | // the oldest one. The new slice of keys is returned. 174 | func (s SessionTicketService) RotateSTEKs(keys [][32]byte) ([][32]byte, error) { 175 | // produce a new key 176 | newKey, err := s.generateSTEK() 177 | if err != nil { 178 | return nil, fmt.Errorf("generating STEK: %v", err) 179 | } 180 | 181 | // we need to prepend this new key to the list of 182 | // keys so that it is preferred, but we need to be 183 | // careful that we do not grow the slice larger 184 | // than MaxKeys, otherwise we'll be storing one 185 | // more key in memory than we expect; so be sure 186 | // that the slice does not grow beyond the limit 187 | // even for a brief period of time, since there's 188 | // no guarantee when that extra allocation will 189 | // be overwritten; this is why we first trim the 190 | // length to one less the max, THEN prepend the 191 | // new key 192 | if len(keys) >= s.MaxKeys { 193 | keys[len(keys)-1] = [32]byte{} // zero-out memory of oldest key 194 | keys = keys[:s.MaxKeys-1] // trim length of slice 195 | } 196 | keys = append([][32]byte{newKey}, 
keys...) // prepend new key 197 | 198 | return keys, nil 199 | } 200 | 201 | // generateSTEK generates key material suitable for use as a 202 | // session ticket ephemeral key. 203 | func (s *SessionTicketService) generateSTEK() ([32]byte, error) { 204 | var newTicketKey [32]byte 205 | _, err := io.ReadFull(rand.Reader, newTicketKey[:]) 206 | return newTicketKey, err 207 | } 208 | 209 | const ( 210 | defaultSTEKRotationInterval = 12 * time.Hour 211 | defaultMaxSTEKs = 4 212 | ) 213 | -------------------------------------------------------------------------------- /util/string.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "encoding/xml" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "math/rand" 11 | "reflect" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | //RandomStr 随机生成字符串 19 | func RandomStr(length int) string { 20 | str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 21 | bytes := []byte(str) 22 | result := []byte{} 23 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 24 | for i := 0; i < length; i++ { 25 | result = append(result, bytes[r.Intn(len(bytes))]) 26 | } 27 | return string(result) 28 | } 29 | 30 | // convert string to specify type 31 | 32 | type StrTo string 33 | 34 | func (f *StrTo) Set(v string) { 35 | if v != "" { 36 | *f = StrTo(v) 37 | } else { 38 | f.Clear() 39 | } 40 | } 41 | 42 | func (f *StrTo) Clear() { 43 | *f = StrTo(0x1E) 44 | } 45 | 46 | func (f StrTo) Exist() bool { 47 | return string(f) != string(0x1E) 48 | } 49 | 50 | func (f StrTo) Bool() (bool, error) { 51 | if f == "on" { 52 | return true, nil 53 | } 54 | return strconv.ParseBool(f.String()) 55 | } 56 | 57 | func (f StrTo) Float32() (float32, error) { 58 | v, err := strconv.ParseFloat(f.String(), 32) 59 | return float32(v), err 60 | } 61 | 62 | func (f StrTo) Float64() (float64, error) { 63 | return strconv.ParseFloat(f.String(), 64) 64 | } 
65 | 66 | func (f StrTo) Int() (int, error) { 67 | v, err := strconv.ParseInt(f.String(), 10, 32) 68 | return int(v), err 69 | } 70 | 71 | func (f StrTo) Int8() (int8, error) { 72 | v, err := strconv.ParseInt(f.String(), 10, 8) 73 | return int8(v), err 74 | } 75 | 76 | func (f StrTo) Int16() (int16, error) { 77 | v, err := strconv.ParseInt(f.String(), 10, 16) 78 | return int16(v), err 79 | } 80 | 81 | func (f StrTo) Int32() (int32, error) { 82 | v, err := strconv.ParseInt(f.String(), 10, 32) 83 | return int32(v), err 84 | } 85 | 86 | func (f StrTo) Int64() (int64, error) { 87 | v, err := strconv.ParseInt(f.String(), 10, 64) 88 | return int64(v), err 89 | } 90 | 91 | func (f StrTo) Uint() (uint, error) { 92 | v, err := strconv.ParseUint(f.String(), 10, 32) 93 | return uint(v), err 94 | } 95 | 96 | func (f StrTo) Uint8() (uint8, error) { 97 | v, err := strconv.ParseUint(f.String(), 10, 8) 98 | return uint8(v), err 99 | } 100 | 101 | func (f StrTo) Uint16() (uint16, error) { 102 | v, err := strconv.ParseUint(f.String(), 10, 16) 103 | return uint16(v), err 104 | } 105 | 106 | func (f StrTo) Uint32() (uint32, error) { 107 | v, err := strconv.ParseUint(f.String(), 10, 32) 108 | return uint32(v), err 109 | } 110 | 111 | func (f StrTo) Uint64() (uint64, error) { 112 | v, err := strconv.ParseUint(f.String(), 10, 64) 113 | return uint64(v), err 114 | } 115 | 116 | func (f StrTo) String() string { 117 | if f.Exist() { 118 | return string(f) 119 | } 120 | return "" 121 | } 122 | 123 | // convert any type to string 124 | func ToStr(value interface{}, args ...int) (s string) { 125 | switch v := value.(type) { 126 | case bool: 127 | s = strconv.FormatBool(v) 128 | case float32: 129 | s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) 130 | case float64: 131 | s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) 132 | case int: 133 | s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) 134 | case int8: 135 | 
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) 136 | case int16: 137 | s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) 138 | case int32: 139 | s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) 140 | case int64: 141 | s = strconv.FormatInt(v, argInt(args).Get(0, 10)) 142 | case uint: 143 | s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) 144 | case uint8: 145 | s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) 146 | case uint16: 147 | s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) 148 | case uint32: 149 | s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) 150 | case uint64: 151 | s = strconv.FormatUint(v, argInt(args).Get(0, 10)) 152 | case string: 153 | s = v 154 | case []byte: 155 | s = string(v) 156 | default: 157 | s = fmt.Sprintf("%v", v) 158 | } 159 | return s 160 | } 161 | 162 | // convert any numeric value to int64 163 | func ToInt64(value interface{}) (d int64, err error) { 164 | val := reflect.ValueOf(value) 165 | switch value.(type) { 166 | case int, int8, int16, int32, int64: 167 | d = val.Int() 168 | case uint, uint8, uint16, uint32, uint64: 169 | d = int64(val.Uint()) 170 | default: 171 | err = fmt.Errorf("ToInt64 need numeric not `%T`", value) 172 | } 173 | return 174 | } 175 | 176 | type argString []string 177 | 178 | func (a argString) Get(i int, args ...string) (r string) { 179 | if i >= 0 && i < len(a) { 180 | r = a[i] 181 | } else if len(args) > 0 { 182 | r = args[0] 183 | } 184 | return 185 | } 186 | 187 | type argInt []int 188 | 189 | func (a argInt) Get(i int, args ...int) (r int) { 190 | if i >= 0 && i < len(a) { 191 | r = a[i] 192 | } 193 | if len(args) > 0 { 194 | r = args[0] 195 | } 196 | return 197 | } 198 | 199 | type argAny []interface{} 200 | 201 | func (a argAny) Get(i int, args ...interface{}) (r interface{}) { 202 | if i >= 0 && i < len(a) { 203 | r = a[i] 204 | } 205 | if len(args) > 0 { 206 | r = args[0] 207 | } 208 | return 209 | } 210 | 211 | func 
formatMapToXML(req map[string]string) (buf []byte, err error) { 212 | bodyBuf := textBufferPool.Get().(*bytes.Buffer) 213 | bodyBuf.Reset() 214 | defer textBufferPool.Put(bodyBuf) 215 | 216 | if bodyBuf == nil { 217 | return []byte{}, errors.New("nil xmlWriter") 218 | } 219 | 220 | if _, err = io.WriteString(bodyBuf, ""); err != nil { 221 | return 222 | } 223 | 224 | for k, v := range req { 225 | if _, err = io.WriteString(bodyBuf, "<"+k+">"); err != nil { 226 | return 227 | } 228 | if err = xml.EscapeText(bodyBuf, []byte(v)); err != nil { 229 | return 230 | } 231 | if _, err = io.WriteString(bodyBuf, ""); err != nil { 232 | return 233 | } 234 | } 235 | 236 | if _, err = io.WriteString(bodyBuf, ""); err != nil { 237 | return 238 | } 239 | 240 | return bodyBuf.Bytes(), nil 241 | } 242 | 243 | var textBufferPool = sync.Pool{ 244 | New: func() interface{} { 245 | return bytes.NewBuffer(make([]byte, 0, 16<<10)) // 16KB 246 | }, 247 | } 248 | 249 | func GetMD5(args ...string) string { 250 | var str string 251 | for _, s := range args { 252 | str += s 253 | } 254 | value := md5.Sum([]byte(str)) 255 | rs := []rune(fmt.Sprintf("%x", value)) 256 | return string(rs) 257 | } 258 | 259 | func GetSign(args ...string) string { 260 | salt := "Yexhj8agldf3yaexuda7da" 261 | var str string 262 | for _, s := range args { 263 | str += s 264 | } 265 | str += salt 266 | value := md5.Sum([]byte(str)) 267 | rs := []rune(fmt.Sprintf("%x", value)) 268 | return string(rs) 269 | } 270 | 271 | //res:原字符串,sep替换的字符串,idx开始替换的位置 272 | //如果替换的字符串超出,就加在原字符串后面 273 | //idx从0开始 274 | func ReplaceString(res, sep string, idx int) string { 275 | sepLen := len(sep) 276 | if sepLen == 0 { 277 | return res 278 | } 279 | 280 | resLen := len(res) 281 | if idx > resLen-1 { 282 | return res + sep 283 | } 284 | 285 | allLen := resLen 286 | if sepLen > resLen-idx { 287 | allLen = idx + sepLen 288 | } 289 | 290 | buf := bytes.Buffer{} 291 | sepIdx := 0 292 | for i := 0; i < allLen; i++ { 293 | if i < idx { 294 | 
buf.WriteByte(res[i]) 295 | } else { 296 | if sepIdx < sepLen { 297 | buf.WriteByte(sep[sepIdx]) 298 | } else { 299 | buf.WriteByte(res[i]) 300 | } 301 | sepIdx++ 302 | } 303 | } 304 | return buf.String() 305 | } 306 | 307 | // 308 | func ConvFormMapToString(mData map[string]string) string { 309 | formBuf := bytes.Buffer{} 310 | l := len(mData) 311 | i := 0 312 | for k, v := range mData { 313 | formBuf.WriteString(k) 314 | formBuf.WriteString("=") 315 | formBuf.WriteString(v) 316 | if i < l { 317 | formBuf.WriteString("&") 318 | i++ 319 | } 320 | } 321 | return string(formBuf.Bytes()) 322 | } 323 | 324 | //检查是否包含,必须全部包含 325 | func CheckContains(s string, subArr ...string) bool { 326 | for _, sub := range subArr { 327 | if !strings.Contains(s, sub) { 328 | return false 329 | } 330 | } 331 | return true 332 | } 333 | 334 | //有任何一个包含就返回true 335 | func CheckContainsAny(s string, subArr ...string) bool { 336 | for _, sub := range subArr { 337 | if strings.Contains(s, sub) { 338 | return true 339 | } 340 | } 341 | return false 342 | } 343 | -------------------------------------------------------------------------------- /gateway.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "github.com/g4zhuj/hashring" 7 | "golang.org/x/net/netutil" 8 | "math/rand" 9 | "minGateway/config" 10 | "minGateway/status" 11 | "minGateway/util/log" 12 | "net" 13 | "net/http" 14 | "net/http/httputil" 15 | "net/url" 16 | "strings" 17 | "time" 18 | ) 19 | 20 | var ( 21 | limitMaxConn int //最大连接数 22 | readTimeout int //读超时 23 | writeTimeout int //写超时 24 | idleTimeout int //闲置超时 25 | maxHeaderBytes int //最大头字节 26 | ) 27 | 28 | type ObtainMode int //多转发目标时的选择模式 29 | 30 | const ( 31 | SelectModeRandom ObtainMode = 1 //随机选择 32 | SelectModePoll ObtainMode = 2 //轮询选择 33 | SelectModeHash ObtainMode = 3 //哈希选择 34 | ) 35 | 36 | type HostInfoInterface interface { 37 | GetTarget(req *http.Request) string 38 | 
} 39 | 40 | type HostInfo struct { 41 | Target string //转发目标域名 42 | MultiTarget []string //有多转发目标的域名集合 43 | IsMultiTarget bool //是否有多转发目标 44 | MultiTargetMode ObtainMode //多转发目标选择模式 45 | PoolModeIndex int //轮询模式索引 46 | hashRing *hashring.HashRing //一致性哈希 47 | } 48 | 49 | func (hostInfo *HostInfo) GetTarget(req *http.Request) string { 50 | var route string 51 | if hostInfo.IsMultiTarget { 52 | if hostInfo.MultiTargetMode == SelectModeRandom { //随机模式 53 | route = hostInfo.MultiTarget[rand.Int()%len(hostInfo.MultiTarget)] 54 | } else if hostInfo.MultiTargetMode == SelectModePoll { //轮询模式 55 | route = hostInfo.MultiTarget[hostInfo.PoolModeIndex] 56 | hostInfo.PoolModeIndex++ 57 | hostInfo.PoolModeIndex = hostInfo.PoolModeIndex % len(hostInfo.MultiTarget) 58 | } else if hostInfo.MultiTargetMode == SelectModeHash { //哈希模式 59 | ips := getIpAddr(req) 60 | route = hostInfo.hashRing.GetNode(ips[0]) 61 | } else { //未配置或配置错误使用随机模式 62 | route = hostInfo.MultiTarget[rand.Int()%len(hostInfo.MultiTarget)] 63 | } 64 | } else { 65 | route = hostInfo.Target 66 | } 67 | return route 68 | } 69 | 70 | var HostList map[string]HostInfo 71 | 72 | //通配符地址 wildcard character 73 | type HostInfoWc struct { 74 | HostInfo 75 | //关键字的位置 0:前面 1:后面 76 | //关键字书写 如:wxpay. 
那么KeyPos为0 77 | KeyPos int 78 | } 79 | 80 | var HostListWc map[string]HostInfoWc 81 | 82 | //缺省转发,如果配置文件上定义了缺省转发,那么有消息进入时没找到已定义的转发地址就转发到缺省定义上 83 | var DefaultTarget *HostInfo 84 | 85 | var CertificateSet []tls.Certificate 86 | 87 | type Proxy struct{} 88 | 89 | func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 90 | log.Debugf("-> %s ", r.Host) 91 | 92 | var ip string 93 | var ipsSize int 94 | if config.Get().CcDefense.Enable || config.Get().LimitReq.Enable { 95 | ips := getIpAddr(r) 96 | ipsSize = len(ips) 97 | if ipsSize > 0 { 98 | ip = ips[0] 99 | } 100 | } 101 | 102 | //如果配置开启CC防御 103 | if config.Get().CcDefense.Enable { 104 | //看是否在黑名单内,如果在黑名单内,直接返回 105 | if ipsSize > 0 { 106 | //检查是否符合阻挡规则 107 | if defenseCCBlockCheck(ip) { 108 | _ = r.Body.Close() 109 | return 110 | } 111 | } 112 | } 113 | 114 | //如果开启了请求数限制 115 | if config.Get().LimitReq.Enable { 116 | if ipsSize > 0 { 117 | // 是否超过限制请求数 118 | if ExceededLimitReq(ip, r) { 119 | return 120 | } 121 | } 122 | } 123 | 124 | in := time.Now() 125 | 126 | //设置状态:连接数,用于后面获取连接数 127 | status.Instance().AddReqCount() 128 | defer status.Instance().SubReqCount() 129 | 130 | //根据配置选择转发到哪 131 | var route string //转发的目标 132 | var existRoute = false 133 | if len(r.Host) == 0 { 134 | if DefaultTarget != nil { 135 | route = DefaultTarget.GetTarget(r) 136 | existRoute = true 137 | } 138 | } else if hostInfo, ok := HostList[r.Host]; ok { 139 | route = hostInfo.GetTarget(r) 140 | existRoute = true 141 | } else if len(HostListWc) > 0 { 142 | //轮询通配符集合,查找有没有符合的域名 143 | for likeHost, hostInfo := range HostListWc { 144 | if (hostInfo.KeyPos == 0 && strings.HasPrefix(r.Host, likeHost)) || 145 | (hostInfo.KeyPos == 1 && strings.HasSuffix(r.Host, likeHost)) { 146 | route = hostInfo.GetTarget(r) 147 | existRoute = true 148 | break 149 | } 150 | } 151 | } 152 | if !existRoute { 153 | if DefaultTarget != nil { 154 | route = DefaultTarget.GetTarget(r) 155 | existRoute = true 156 | } else { 157 | log.Warnf("未配置的代理, %s", 
r.Host) 158 | return 159 | } 160 | } 161 | 162 | log.Debugf("-> %s", route) 163 | 164 | //找到转发目标,继续 165 | if existRoute { 166 | target, err := url.Parse(route) 167 | if err != nil { 168 | log.Error("url.Parse失败") 169 | return 170 | } 171 | 172 | proxy := newHostReverseProxy(target) 173 | proxy.ServeHTTP(w, r) 174 | } 175 | 176 | log.Debug("耗时:", time.Now().Sub(in).Seconds(), "秒") 177 | } 178 | 179 | func singleJoiningSlash(a, b string) string { 180 | aslash := strings.HasSuffix(a, "/") 181 | bslash := strings.HasPrefix(b, "/") 182 | switch { 183 | case aslash && bslash: 184 | return a + b[1:] 185 | case !aslash && !bslash: 186 | return a + "/" + b 187 | } 188 | return a + b 189 | } 190 | 191 | func newHostReverseProxy(target *url.URL) *httputil.ReverseProxy { 192 | director := func(req *http.Request) { 193 | targetQuery := target.RawQuery 194 | req.URL.Scheme = target.Scheme 195 | req.URL.Host = target.Host 196 | req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) 197 | if targetQuery == "" || req.URL.RawQuery == "" { 198 | req.URL.RawQuery = targetQuery + req.URL.RawQuery 199 | } else { 200 | req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 201 | } 202 | if _, ok := req.Header["User-Agent"]; !ok { 203 | // explicitly disable User-Agent so it's not set to default value 204 | req.Header.Set("User-Agent", "") 205 | } 206 | req.Header["X-Real-Ip"] = getIpAddr(req) 207 | log.Debug("X-Real-Ip=", req.Header["X-Real-Ip"]) 208 | //for k, v := range req.Header { 209 | // log.Debug(k, v) 210 | //} 211 | } 212 | return &httputil.ReverseProxy{Director: director} 213 | } 214 | 215 | type GateServer struct{} 216 | 217 | func (s *GateServer) proxy80() *http.Server { 218 | ln, err := net.Listen("tcp", ":80") 219 | if err != nil { 220 | panic(err) 221 | } 222 | 223 | if limitMaxConn > 0 { 224 | //限流 225 | ln = netutil.LimitListener(ln, limitMaxConn) 226 | } 227 | 228 | p := &Proxy{} 229 | srv := &http.Server{Addr: ":80", Handler: p} 230 | if readTimeout > 0 { 231 | 
srv.ReadTimeout = time.Duration(readTimeout) * time.Second 232 | } 233 | if writeTimeout > 0 { 234 | srv.WriteTimeout = time.Duration(writeTimeout) * time.Second 235 | } 236 | if idleTimeout > 0 { 237 | srv.IdleTimeout = time.Duration(idleTimeout) * time.Second 238 | } 239 | if maxHeaderBytes > 0 { 240 | srv.MaxHeaderBytes = maxHeaderBytes 241 | } 242 | go func() { 243 | if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed { 244 | panic(err.Error()) 245 | } 246 | }() 247 | fmt.Println("网关监听端口:80") 248 | 249 | return srv 250 | } 251 | 252 | func makeTlsConfig() *tls.Config { 253 | config1 := &tls.Config{Certificates: CertificateSet} 254 | config1.BuildNameToCertificate() //BuildNameToCertificate()使之能嗅探域名,如果没找到信息则使用数组[0] 255 | fmt.Println("\nSSL Set:", config1.NameToCertificate) 256 | return config1 257 | } 258 | 259 | func (s *GateServer) proxy443(tlsConfig *tls.Config) *http.Server { 260 | ln, err := tls.Listen("tcp", ":443", tlsConfig) 261 | if err != nil { 262 | panic(err) 263 | } 264 | 265 | if limitMaxConn > 0 { 266 | //限流 267 | ln = netutil.LimitListener(ln, limitMaxConn) 268 | } 269 | 270 | p := &Proxy{} 271 | srv := &http.Server{Addr: ":443", Handler: p} 272 | if readTimeout > 0 { 273 | srv.ReadTimeout = time.Duration(readTimeout) * time.Second 274 | } 275 | if writeTimeout > 0 { 276 | srv.WriteTimeout = time.Duration(writeTimeout) * time.Second 277 | } 278 | if idleTimeout > 0 { 279 | srv.IdleTimeout = time.Duration(idleTimeout) * time.Second 280 | } 281 | if maxHeaderBytes > 0 { 282 | srv.MaxHeaderBytes = maxHeaderBytes 283 | } 284 | go func() { 285 | if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed { 286 | panic(err.Error()) 287 | } 288 | }() 289 | fmt.Println("网关监听端口:443") 290 | 291 | return srv 292 | } 293 | 294 | func (s *GateServer) run() []*http.Server { 295 | 296 | ss := make([]*http.Server, 0) 297 | 298 | p80 := s.proxy80() 299 | ss = append(ss, p80) 300 | 301 | if len(CertificateSet) > 0 { 302 | 303 | tlsConfig := 
makeTlsConfig() 304 | 305 | // 是否开启SessionTicket,TLS1.3中即是否开启PSK 306 | // TLS1.3中SessionTicket报文也是加密的,我通过抓包无法看到 New session ticket 的报文,这里可能有问题,SessionTicket可能没起效 307 | if !config.Get().SslBase.SessionTicket { 308 | tlsConfig.SessionTicketsDisabled = true //禁止 309 | } else { 310 | sessiontickets := &SessionTicketService{} 311 | err := sessiontickets.Run() 312 | if err != nil { 313 | log.Error("SessionTicketService error", err) 314 | } else { 315 | sessiontickets.Register(tlsConfig) 316 | } 317 | } 318 | 319 | p443 := s.proxy443(tlsConfig) 320 | ss = append(ss, p443) 321 | } 322 | 323 | return ss 324 | } 325 | --------------------------------------------------------------------------------