├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── amqp ├── amqp.go ├── amqp_test.go ├── app.go ├── consumer.go ├── context.go ├── exchange.go ├── message.go ├── mq.go ├── queue.go └── readme.md ├── config ├── config.go ├── config_test.go ├── readme.md └── test │ ├── config.json │ ├── config.local.toml │ ├── config.myenv.toml │ ├── config.toml │ └── config.yml ├── db ├── config.go ├── gorm.go ├── gorm_test.go ├── readme.md ├── resource │ ├── gorm.go │ ├── gorm_test.go │ ├── gorm_util.go │ ├── readme.md │ └── resource.go ├── sqlite │ └── sqlite.go ├── xorm.go └── xorm_test.go ├── email ├── email.go ├── email_test.go └── readme.md ├── emit ├── emit.go ├── emit_test.go └── readme.md ├── gin ├── midleware │ ├── cache-control.go │ ├── error.go │ └── logger.go ├── readme.md └── util │ ├── readme.md │ └── response.go ├── go.mod ├── go.sum ├── http ├── http.go ├── http_test.go ├── readme.md ├── request.go ├── response.go └── v2 │ ├── http.go │ ├── http_test.go │ ├── readme.md │ ├── request.go │ ├── request_api.go │ ├── response.go │ ├── transform.go │ ├── types.go │ └── util.go ├── log ├── log.go ├── log_test.go ├── logger.go └── readme.md ├── readme.md ├── redis ├── cmd.go ├── readme.md ├── redis.go └── redis_test.go ├── tcp ├── client.go ├── client_test.go ├── config.go ├── conn.go ├── logger.go ├── message.go ├── package.go ├── readme.md ├── server.go ├── server_test.go └── tcpsrv │ ├── client.go │ ├── client_test.go │ ├── context.go │ ├── readme.md │ ├── server.go │ └── server_test.go ├── test.sh ├── types ├── json.go ├── json_test.go ├── readme.md ├── time.go └── time_test.go ├── util ├── bytes.go ├── crypto.go ├── crypto_test.go ├── readme.md ├── resp.go ├── util.go └── util_test.go └── websocket ├── conn.go ├── logger.go ├── message.go ├── readme.md ├── websocket.go ├── websocket_test.go └── wsrv ├── context.go ├── readme.md ├── srv.go └── srv_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | runtime 
2 | .runtime 3 | .idea 4 | *.log -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "go.formatTool": "goimports" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 eyasliu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /amqp/amqp.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | const ( 8 | ExchangeDirect = "direct" // 直连交换机 9 | ExchangeFanout = "fanout" // 扇形交换机 10 | ExchangeTopic = "topic" // 主题交换机 11 | ExchangeHeaders = "headers" // 头交换机 12 | ) 13 | 14 | // Config 配置项 15 | // ExchangeName 和 Exchange 二选一,用于指定发布和订阅时使用的交换机 16 | type Config struct { 17 | Addr string // rabbitmq 地址 18 | ExchangeName string // 使用该值创建一个直连的交换机 19 | Exchange *Exchange // 自定义默认交换机 20 | Consumer *Consumer // 在定于队列时,作为消费者使用的参数 21 | } 22 | 23 | // Init 初始化 24 | func New(conf *Config) (*MQ, error) { 25 | if conf.Exchange == nil && conf.ExchangeName == "" { 26 | return nil, errors.New("exchange must defined") 27 | } 28 | 29 | if conf.Exchange == nil && conf.ExchangeName != "" { 30 | conf.Exchange = defaultExchange(conf.ExchangeName) 31 | } 32 | 33 | if conf.Consumer == nil { 34 | conf.Consumer = defaultConsumer() 35 | } 36 | 37 | m := &MQ{Addr: conf.Addr, Exchange: conf.Exchange, Consumer: conf.Consumer} 38 | 39 | err := m.Init() 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | return m, nil 45 | } 46 | -------------------------------------------------------------------------------- /amqp/amqp_test.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestAmqp(t *testing.T) { 11 | queue := &Queue{Name: "toolkit.queue1.test"} 12 | // exchange := &Exchange{Name: "toolkit.exchange.test"} 13 | 14 | msg := &Message{ 15 | Data: []byte("{\"seqno\":\"1563541319\",\"cmd\":\"44\",\"data\":{\"mid\":1070869}}"), 16 | } 17 | 18 | mq, err := New(&Config{ 19 | Addr: "amqp://guest:guest@mq.aam.test:5672/", 20 | ExchangeName: "toolkit.exchange1.test", 21 | }) 22 | if err != nil { 23 | panic(err) 24 | } 25 | 
26 | testCount := 1000 27 | 28 | startTime := time.Now() 29 | 30 | wg := sync.WaitGroup{} 31 | si := 0 32 | for ; si < testCount; si++ { 33 | err := mq.Pub(queue, msg) 34 | if err != nil { 35 | panic(err) 36 | } 37 | } 38 | t.Logf("发送 %d 条数据, 耗时 %d 纳秒 \n", si, time.Since(startTime)) 39 | 40 | startTime1 := time.Now() 41 | wg.Add(testCount) 42 | go func() { 43 | msgs, err := mq.Sub(queue) 44 | if err != nil { 45 | panic(err) 46 | } 47 | for range msgs { 48 | wg.Done() 49 | } 50 | }() 51 | 52 | wg.Wait() 53 | t.Logf("消费 %d 条数据, 耗时 %d 纳秒 \n", testCount, time.Since(startTime1)) 54 | } 55 | 56 | func TestExchangePub(t *testing.T) { 57 | queue := &Queue{Name: "toolkit.queue.test2", Key: "toolkit.queue2.*"} 58 | mq, err := New(&Config{ 59 | Addr: "amqp://guest:guest@10.0.2.252:5672/", 60 | ExchangeName: "toolkit.exchange.test2", // 直连交换机名称 61 | }) 62 | if err != nil { 63 | panic(err) 64 | } 65 | 66 | count := 100 67 | 68 | wg2 := sync.WaitGroup{} 69 | wg2.Add(count) 70 | go func() { 71 | msgs, err := mq.Sub(queue) 72 | if err != nil { 73 | panic(err) 74 | } 75 | for msg := range msgs { 76 | var v interface{} 77 | err := msg.JSON(&v) 78 | if err != nil { 79 | panic(err) 80 | } 81 | wg2.Done() 82 | // fmt.Printf("msg: %s \n", v) 83 | } 84 | }() 85 | 86 | <-time.After(100 * time.Millisecond) 87 | 88 | msg := &Message{ 89 | Data: []byte(`{"seqno":"1563541319","cmd":"44","data":{"uid":1070869}}`), 90 | } 91 | ex := &Exchange{Name: "toolkit.ex.test.fanout", Kind: ExchangeFanout, AutoDelete: true} 92 | 93 | for i := 0; i < count; i++ { 94 | err := mq.Pub(queue, msg, ex) 95 | if err != nil { 96 | panic(err) 97 | } 98 | } 99 | 100 | wg2.Wait() 101 | 102 | } 103 | 104 | func TestAmqpApp(t *testing.T) { 105 | testQueue := &Queue{Name: "toolkit.queue.test3", Key: "toolkit.queue.test3"} 106 | testReplyQueue := &Queue{Name: "ttoolkit.queue.reply.test3", Key: "toolkit.queue.reply.test3"} 107 | mq, err := NewApp(&Config{ 108 | Addr: "amqp://guest:guest@10.0.2.252:5672/", 109 | 
ExchangeName: "toolkit.exchange.test3", // 直连交换机名称 110 | }) 111 | 112 | if err != nil { 113 | t.Errorf("amqp error: %v", err) 114 | } 115 | var wg sync.WaitGroup 116 | wg.Add(2) 117 | mq.On(testQueue, func(c *MQContext) { 118 | t.Log("mq listener here") 119 | wg.Done() 120 | }) 121 | mq.Route(map[*Queue]MQHandler{ 122 | testQueue: func(c *MQContext) { 123 | body := map[string]interface{}{} 124 | if err := c.BindJSON(&body); err != nil { 125 | t.Errorf("bind error") 126 | return 127 | } 128 | t.Logf("mq context here, data: %+v", body) 129 | c.Pub(testReplyQueue, &Message{Data: []byte(`{"hello":"world"}`)}) 130 | wg.Done() 131 | }, 132 | }) 133 | 134 | mq.Pub(testQueue, &Message{Data: []byte(`{"hello":"world"}`)}) 135 | wg.Wait() 136 | } 137 | 138 | func ExampleSimple() { 139 | queue := &Queue{Name: "toolkit.queue.test", Key: "toolkit.queue.*"} 140 | mq, _ := New(&Config{ 141 | Addr: "amqp://guest:guest@10.0.2.252:5672/", 142 | ExchangeName: "toolkit.exchange.test", // 直连交换机名称 143 | }) 144 | go func() { 145 | msgs, err := mq.Sub(queue) 146 | if err != nil { 147 | panic(err) 148 | } 149 | for msg := range msgs { 150 | var v interface{} 151 | err := msg.JSON(&v) 152 | if err != nil { 153 | panic(err) 154 | } 155 | fmt.Printf("msg: %s", v) 156 | } 157 | }() 158 | 159 | } 160 | 161 | func TestAmqpR(t *testing.T) { 162 | queue := &Queue{ 163 | Name: "debug.desktop.v1.server.rpc.req_mall_card_detail_list", 164 | Durable: true, 165 | } 166 | // exchange := &Exchange{Name: "toolkit.exchange.test"} 167 | 168 | mq, err := New(&Config{ 169 | Addr: "amqp://guest:guest@10.0.2.252:5672/", 170 | ExchangeName: "desktop.exchange.v1", 171 | }) 172 | if err != nil { 173 | panic(err) 174 | } 175 | 176 | testCount := 500000 177 | 178 | startTime := time.Now() 179 | 180 | // wg := sync.WaitGroup{} 181 | si := 0 182 | for ; si < testCount; si++ { 183 | msg := &Message{ 184 | Data: []byte(fmt.Sprintf(`{"cmd":"req_mall_card_detail_list","data":{"card_id":1,"mid":9100251},"seqno":"%d"}`, 
si)), 185 | } 186 | err := mq.Pub(queue, msg) 187 | if err != nil { 188 | panic(err) 189 | } 190 | } 191 | t.Logf("发送 %d 条数据, 耗时 %d 纳秒 \n", si, time.Since(startTime)) 192 | 193 | // startTime1 := time.Now() 194 | // wg.Add(testCount) 195 | // go func() { 196 | // msgs, err := mq.Sub(queue) 197 | // if err != nil { 198 | // panic(err) 199 | // } 200 | // for range msgs { 201 | // wg.Done() 202 | // } 203 | // }() 204 | 205 | // wg.Wait() 206 | // t.Logf("消费 %d 条数据, 耗时 %d 纳秒 \n", testCount, time.Since(startTime1)) 207 | } 208 | -------------------------------------------------------------------------------- /amqp/app.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import "sync" 4 | 5 | type MQApp struct { 6 | Client *MQ 7 | listenerMu sync.RWMutex 8 | listener map[*Queue][]MQHandler 9 | listenerRecord map[*Queue]bool 10 | } 11 | 12 | func NewApp(config *Config) (*MQApp, error) { 13 | mq, err := New(config) 14 | return &MQApp{ 15 | Client: mq, 16 | listener: make(map[*Queue][]MQHandler), 17 | listenerRecord: make(map[*Queue]bool), 18 | }, err 19 | } 20 | 21 | // 监听队列触发函数 22 | func (mq *MQApp) On(queue *Queue, handler ...MQHandler) { 23 | mq.listenerMu.RLock() 24 | handlers, ok := mq.listener[queue] 25 | mq.listenerMu.RUnlock() 26 | 27 | if !ok { 28 | handlers = handler 29 | } else { 30 | handlers = append(handlers, handler...) 
31 | } 32 | 33 | mq.listenerMu.Lock() 34 | mq.listener[queue] = handlers 35 | mq.listenerMu.Unlock() 36 | 37 | mq.startListen(queue) 38 | } 39 | 40 | func (mq *MQApp) Route(routes map[*Queue]MQHandler) { 41 | for q, handler := range routes { 42 | mq.On(q, handler) 43 | } 44 | } 45 | 46 | func (mq *MQApp) startListen(queue *Queue) { 47 | _, ok := mq.listenerRecord[queue] 48 | 49 | // 之前已经开始了 50 | if ok { 51 | return 52 | } 53 | 54 | // 开始监听 55 | go func(queue *Queue) { 56 | ch, err := mq.Client.Sub(queue) 57 | if err != nil { 58 | return 59 | } 60 | mq.listenerMu.Lock() 61 | mq.listenerRecord[queue] = true 62 | mq.listenerMu.Unlock() 63 | for msg := range ch { 64 | 65 | ctx := &MQContext{ 66 | Request: msg, 67 | Client: mq.Client, 68 | App: mq, 69 | } 70 | handlers := mq.listener[queue] 71 | go func() { 72 | // TODO defer error 73 | for _, h := range handlers { 74 | h(ctx) 75 | } 76 | }() 77 | } 78 | }(queue) 79 | } 80 | 81 | func (mq *MQApp) Pub(q *Queue, msg *Message) error { 82 | return mq.Client.Pub(q, msg) 83 | } 84 | -------------------------------------------------------------------------------- /amqp/consumer.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import "github.com/streadway/amqp" 4 | 5 | func defaultConsumer() *Consumer { 6 | return &Consumer{"", true, false, false, false, nil} 7 | } 8 | 9 | // Consumer 定义消费者选项 10 | type Consumer struct { 11 | Name string 12 | AutoAck bool // 自动确认 13 | Exclusive bool 14 | NoLocal bool 15 | NoWait bool 16 | Args amqp.Table 17 | } 18 | -------------------------------------------------------------------------------- /amqp/context.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type MQHandler func(*MQContext) 8 | 9 | type MQContext struct { 10 | Request *Message 11 | Client *MQ 12 | App *MQApp 13 | } 14 | 15 | func (c *MQContext) BindJSON(v interface{}) 
error { 16 | return json.Unmarshal(c.Request.Data, v) 17 | } 18 | func (c *MQContext) Pub(q *Queue, msg *Message) error { 19 | return c.Client.Pub(q, msg) 20 | } 21 | -------------------------------------------------------------------------------- /amqp/exchange.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "github.com/streadway/amqp" 5 | ) 6 | 7 | func defaultExchange(name string) *Exchange { 8 | return &Exchange{ 9 | Name: name, 10 | Kind: amqp.ExchangeDirect, 11 | Durable: true, // 持久化 12 | } 13 | } 14 | 15 | // Exchange 定义交换机 16 | type Exchange struct { 17 | Name string // 名称 18 | Kind string // 交换机类型,4 种类型之一 19 | Durable bool // 是否持久化 20 | AutoDelete bool // 是否自动删除 21 | Internal bool // 是否内置,如果设置 为true,则表示是内置的交换器,客户端程序无法直接发送消息到这个交换器中,只能通过交换器路由到交换器的方式 22 | NoWait bool // 是否等待通知定义交换机结果 23 | Args amqp.Table 24 | IsDeclare bool // 是否已定义 25 | } 26 | -------------------------------------------------------------------------------- /amqp/message.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | // Message 消息体 8 | type Message struct { 9 | ContentType string // 消息类型 10 | Queue *Queue // 来自于哪个队列 11 | Data []byte // 消息数据 12 | mq *MQ 13 | } 14 | 15 | func (m *Message) contentType() string { 16 | if m.ContentType == "" { 17 | return "text/plain" 18 | } 19 | return m.ContentType 20 | } 21 | 22 | // JSON 以 json 解析消息体的数据为指定结构体 23 | func (m *Message) JSON(v interface{}) error { 24 | return json.Unmarshal(m.Data, v) 25 | } 26 | 27 | // ReplyTo 给回复的队列发送消息 28 | func (m *Message) ReplyTo(msg *Message) error { 29 | return m.mq.Pub(m.Queue.ReplyTo, msg) 30 | } 31 | -------------------------------------------------------------------------------- /amqp/mq.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "fmt" 5 | 
"github.com/streadway/amqp" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | type MQ struct { 11 | Addr string 12 | Client *amqp.Connection 13 | Channel *amqp.Channel 14 | Exchange *Exchange 15 | Consumer *Consumer 16 | // notifyClose chan *amqp.Error 17 | subQueues []*Queue // 已注册为消费者的通道 18 | retrying bool 19 | } 20 | 21 | // Init 初始化 22 | // 1. 初始化交换机 23 | func (mq *MQ) Init() error { 24 | mq.subQueues = []*Queue{} 25 | 26 | err := mq.connect() 27 | if err != nil { 28 | return err 29 | } 30 | 31 | // 初始化默认交换机 32 | err = mq.ExchangeDeclare(mq.Exchange) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | return nil 38 | } 39 | 40 | func (mq *MQ) connect() error { 41 | conn, err := amqp.Dial(mq.Addr) 42 | if err != nil { 43 | return err 44 | } 45 | mq.Client = conn 46 | channel, err := conn.Channel() 47 | if err != nil { 48 | return err 49 | } 50 | mq.Channel = channel 51 | 52 | // 重连后重新注册消费者 53 | for _, q := range mq.subQueues { 54 | q.IsDeclare = false 55 | q.exchange = nil 56 | q.q = nil 57 | mq.bindMQChan(q) 58 | } 59 | 60 | // 断线重连 61 | if !mq.retrying { 62 | go mq.reconnect() 63 | } 64 | 65 | 66 | return nil 67 | } 68 | 69 | func (mq *MQ) reconnect() { 70 | mq.retrying = true 71 | 72 | if mq.Client != nil && mq.Channel != nil { 73 | // 已经连上了,监听关闭消息 74 | closeCh := make(chan *amqp.Error) 75 | mq.Channel.NotifyClose(closeCh) 76 | err := <- closeCh 77 | fmt.Printf("rabbitmq connection is close: %v, retrying...\n", err) 78 | mq.Client.Close() 79 | mq.Channel.Close() 80 | mq.Client = nil 81 | mq.Channel = nil 82 | } 83 | 84 | err := mq.connect() 85 | if err != nil { 86 | fmt.Printf("rabbitmq connection retry fail: %v next retrying...\n", err) 87 | } else { 88 | fmt.Printf("rabbitmq connection retry ok\n") 89 | } 90 | time.Sleep(2 * time.Second) 91 | mq.reconnect() 92 | // if err != nil { 93 | // 94 | // err := mq.connect() 95 | // if err != nil { 96 | // mq.reconnect() 97 | // } 98 | // } 99 | 100 | // for { 101 | // closeCh := make(chan *amqp.Error) 102 | // 
mq.Channel.NotifyClose(closeCh) 103 | // 104 | // err, ok := <-closeCh 105 | // if !ok { 106 | // continue 107 | // } 108 | // if err == nil { 109 | // continue 110 | // } 111 | // 112 | // fmt.Printf("rabbitmq connection is close: %v, retrying...\n", err) 113 | // if mq.Client != nil { 114 | // mq.Client.Close() 115 | // mq.Client = nil 116 | // 117 | // } 118 | // <-time.After(2 * time.Second) // 隔 2s 重连一次 119 | // mq.connect() 120 | // 121 | // } 122 | } 123 | 124 | var subMu sync.Mutex 125 | 126 | // Sub 定于队列消息 127 | // q 队列 128 | //return 接收消息的通道 , 错误对象 129 | func (mq *MQ) Sub(q *Queue) (<-chan *Message, error) { 130 | subMu.Lock() 131 | defer subMu.Unlock() 132 | 133 | mq.subQueues = append(mq.subQueues, q) 134 | 135 | // 初始化接收通道 136 | if q.consumerChan == nil { 137 | q.consumerChan = make(chan *Message, 2) 138 | } 139 | 140 | mq.bindMQChan(q) 141 | 142 | return q.consumerChan, nil 143 | } 144 | 145 | var bindMu sync.Mutex 146 | 147 | // 将 mq 通道绑到队列通道中 148 | func (mq *MQ) bindMQChan(q *Queue) error { 149 | bindMu.Lock() 150 | defer bindMu.Unlock() 151 | // 定义队列 152 | if !q.IsDeclare { 153 | err := mq.QueueDeclare(q) 154 | if err != nil { 155 | return err 156 | } 157 | } 158 | 159 | e := mq.Exchange 160 | 161 | // 绑定交换机 162 | if q.exchange != e { 163 | err := mq.QueueBind(q, e) 164 | if err != nil { 165 | return err 166 | } 167 | } 168 | msgChan, err := mq.Channel.Consume( 169 | q.Name, 170 | mq.Consumer.Name, 171 | mq.Consumer.AutoAck, 172 | mq.Consumer.Exclusive, 173 | mq.Consumer.NoLocal, 174 | mq.Consumer.NoWait, 175 | mq.Consumer.Args, 176 | ) 177 | 178 | if err != nil { 179 | return err 180 | } 181 | 182 | go func(ch chan<- *Message) { 183 | for d := range msgChan { 184 | msg := &Message{ 185 | ContentType: d.ContentType, 186 | mq: mq, 187 | Queue: q, 188 | Data: d.Body, 189 | } 190 | ch <- msg 191 | } 192 | }(q.consumerChan) 193 | 194 | return nil 195 | } 196 | 197 | var pubMu sync.Mutex 198 | 199 | // Pub 给队列发送消息, 200 | // q 队列, 201 | // msg 消息, 202 | 
// exchanges 交换机,可以用多个交换机多次发送,默认使用初始化时指定的交换机 203 | func (mq *MQ) Pub(q *Queue, msg *Message, exchanges ...*Exchange) error { 204 | pubMu.Lock() 205 | defer pubMu.Unlock() 206 | 207 | // 定义队列 208 | if !q.IsDeclare { 209 | err := mq.QueueDeclare(q) 210 | if err != nil { 211 | return err 212 | } 213 | } 214 | 215 | if len(exchanges) == 0 { 216 | exchanges = append(exchanges, mq.Exchange) 217 | // 绑定初始化的交换机 218 | if q.exchange != mq.Exchange { 219 | err := mq.QueueBind(q, mq.Exchange) 220 | if err != nil { 221 | return err 222 | } 223 | } 224 | } else { 225 | for _, e := range exchanges { 226 | if !e.IsDeclare { 227 | err := mq.ExchangeDeclare(e) 228 | if err != nil { 229 | return err 230 | } 231 | } 232 | err := mq.Channel.QueueBind( 233 | q.Name, 234 | q.GetKey(), 235 | e.Name, 236 | false, 237 | nil, 238 | ) 239 | if err != nil { 240 | return err 241 | } 242 | } 243 | } 244 | 245 | for _, e := range exchanges { 246 | // 发消息 247 | err := mq.Channel.Publish( 248 | e.Name, 249 | q.GetKey(), 250 | false, 251 | false, 252 | amqp.Publishing{ 253 | ContentType: msg.ContentType, 254 | ReplyTo: q.ReplyQueue(), 255 | Body: msg.Data, 256 | }, 257 | ) 258 | if err != nil { 259 | return err 260 | } 261 | } 262 | 263 | return nil 264 | 265 | } 266 | 267 | func (mq *MQ) QueueDeclare(q *Queue) error { 268 | queue, err := mq.Channel.QueueDeclare( 269 | q.Name, 270 | q.Durable, 271 | q.AutoDelete, 272 | q.Exclusive, 273 | q.NoWait, 274 | q.Args, 275 | ) 276 | if err != nil { 277 | return err 278 | } 279 | q.q = &queue 280 | q.IsDeclare = true 281 | return nil 282 | } 283 | 284 | func (mq *MQ) ExchangeDeclare(e *Exchange) error { 285 | if e.IsDeclare { 286 | return nil 287 | } 288 | err := mq.Channel.ExchangeDeclare( 289 | e.Name, 290 | e.Kind, 291 | e.Durable, 292 | e.AutoDelete, 293 | e.Internal, 294 | e.NoWait, 295 | e.Args, 296 | ) 297 | if err == nil { 298 | e.IsDeclare = true 299 | } 300 | return err 301 | } 302 | 303 | func (mq *MQ) QueueBind(q *Queue, e *Exchange) error { 304 
| if !e.IsDeclare { 305 | mq.ExchangeDeclare(e) 306 | } 307 | err := mq.Channel.QueueBind( 308 | q.Name, 309 | q.GetKey(), 310 | e.Name, 311 | false, 312 | nil, 313 | ) 314 | if err != nil { 315 | return err 316 | } 317 | q.exchange = e 318 | return nil 319 | } 320 | -------------------------------------------------------------------------------- /amqp/queue.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "github.com/streadway/amqp" 5 | ) 6 | 7 | // Queue 队列 8 | type Queue struct { 9 | Name string // 必须包含前缀标识使用类型 msg. | rpc. | reply. | notify. 10 | Key string // 和交换机绑定时用的Key 11 | Durable bool // 消息代理重启后,队列依旧存在 12 | AutoDelete bool // 当最后一个消费者退订后即被删除 13 | Exclusive bool // 只被一个连接(connection)使用,而且当连接关闭后队列即被删除 14 | NoWait bool // 不需要服务器返回 15 | ReplyTo *Queue // rpc 的消息回应道哪个队列 16 | Args amqp.Table // 一些消息代理用他来完成类似与TTL的某些额外功能 17 | IsDeclare bool // 是否已定义 18 | 19 | q *amqp.Queue 20 | exchange *Exchange // 绑定的交换机 21 | consumerChan chan *Message // 接收该队列数据的通道 22 | } 23 | 24 | func (q *Queue) ReplyQueue() string { 25 | if q.ReplyTo == nil { 26 | return "" 27 | } 28 | return q.ReplyTo.Name 29 | } 30 | 31 | func (q *Queue) GetKey() string { 32 | if q.Key == "" { 33 | return q.Name 34 | } 35 | return q.Key 36 | } 37 | -------------------------------------------------------------------------------- /amqp/readme.md: -------------------------------------------------------------------------------- 1 | # Rabbitmq 封装 2 | 3 | 封装 amqp 协议的基本使用方法,让amqp用起来更简单 4 | 5 | ## 使用 6 | 7 | ### 配置项 8 | 9 | ```go 10 | // Config 配置项 11 | // ExchangeName 和 Exchange 二选一,用于指定发布和订阅时使用的交换机 12 | type Config struct { 13 | Addr string // rabbitmq 地址 14 | ExchangeName string // 使用该值创建一个直连的交换机 15 | Exchange *Exchange // 自定义默认交换机 16 | Consumer *Consumer // 在定于队列时,作为消费者使用的参数 17 | } 18 | ``` 19 | 20 | ### 交换机 21 | 22 | ```go 23 | const ( 24 | ExchangeDirect = "direct" 25 | ExchangeFanout = "fanout" 26 | ExchangeTopic = "topic" 27 | 
ExchangeHeaders = "headers" 28 | ) 29 | 30 | type Exchange struct { 31 | Name string // 名称 32 | Kind string // 交换机类型,4 种类型之一 33 | Durable bool // 是否持久化 34 | AutoDelete bool // 是否自动删除 35 | Internal bool // 是否内置,如果设置 为true,则表示是内置的交换器,客户端程序无法直接发送消息到这个交换器中,只能通过交换器路由到交换器的方式 36 | NoWait bool // 是否等待通知定义交换机结果 37 | Args amqp.Table 38 | } 39 | ``` 40 | 41 | ### 队列 42 | 43 | ```go 44 | type Queue struct { 45 | Name string // 必须包含前缀标识使用类型 msg. | rpc. | reply. | notify. 46 | Key string // 和交换机绑定时用的Key, 如果不设置,默认和 Name 一样 47 | Durable bool // 消息代理重启后,队列依旧存在 48 | AutoDelete bool // 当最后一个消费者退订后即被删除 49 | Exclusive bool // 只被一个连接(connection)使用,而且当连接关闭后队列即被删除 50 | NoWait bool // 不需要服务器返回 51 | ReplyTo *Queue // rpc 的消息回应道哪个队列 52 | Args amqp.Table // 一些消息代理用他来完成类似与TTL的某些额外功能 53 | } 54 | ``` 55 | 56 | ### 消息结构 57 | 58 | ```go 59 | // Message 消息体 60 | type Message struct { 61 | ContentType string // 消息类型 62 | Queue *Queue // 来自于哪个队列 63 | Data []byte // 消息数据 64 | } 65 | ``` 66 | 67 | ### 示例 68 | 69 | ```go 70 | import ( 71 | "github.com/go-eyas/toolkit/amqp" 72 | ) 73 | 74 | func main() { 75 | mq := amqp.New(*amqp.Config{ 76 | Addr: "amqp://guest:guest@127.0.0.1:5672", 77 | ExchangeName: "toolkit.exchange.test", 78 | }) 79 | queue := &amqp.Queue{Name: "toolkit.queue.test"} 80 | err := mq.Pub(queue, &amqp.Message{Data: []byte("{\"hello\":\"world\"}")}) 81 | 82 | msgch, err := mq.Sub(queue) 83 | for msg := range msgch { 84 | fmt.Printf("%s", string(msg.Data)) 85 | } 86 | } 87 | 88 | ``` 89 | 90 | ## godoc 91 | 92 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/amqp) 93 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/jinzhu/configor" 7 | ) 8 | 9 | // 自动搜索配置文件 config.xxx 并自动加载,如果配置文件不存在,使用默认配置 10 | // 支持三种配置文件格式 11 | // 并支持 环境变量覆盖配置文件 12 | 13 | // 获取有效的 14 | func parseFiles(name 
string) []string { 15 | exts := []string{"toml", "json", "yml"} 16 | env := os.Getenv("CONFIG_ENV") 17 | if env == "" { 18 | env = "local" 19 | } 20 | 21 | filelist := []string{} 22 | 23 | for _, ext := range exts { 24 | filelist = append(filelist, 25 | name+"."+env+"."+ext, 26 | name+"."+ext, 27 | "../"+name+"."+env+"."+ext, 28 | "../"+name+"."+ext, 29 | ) 30 | } 31 | 32 | validFiles := []string{} 33 | for _, f := range filelist { 34 | if _, err := os.Stat(f); !os.IsNotExist(err) { 35 | validFiles = append(validFiles, f) 36 | } 37 | } 38 | 39 | return validFiles 40 | } 41 | 42 | // Init 初始化配置文件 43 | func Init(file string, v interface{}) error { 44 | files := parseFiles(file) 45 | 46 | err := configor.New(&configor.Config{AutoReload: true}).Load(v, files...) 47 | 48 | return err 49 | } 50 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | type ConfigT struct { 8 | IsParseJSON bool `json:"isParseJson"` 9 | IsParseToml bool `toml:"isParseToml"` 10 | IsParseYml bool `yaml:"isParseYml"` 11 | IsParseLocalToml bool `toml:"isParseLocalToml"` 12 | 13 | Ext string 14 | Env string 15 | Obj struct { 16 | Array []int 17 | Boolean bool 18 | } 19 | } 20 | 21 | func TestConfig(t *testing.T) { 22 | files := parseFiles("test/config") 23 | t.Logf("valid files: %#v", files) 24 | 25 | conf := &ConfigT{} 26 | err := Init("test/config", conf) 27 | 28 | if err != nil { 29 | panic(err) 30 | } 31 | t.Logf("config: %+v", conf) 32 | if conf.IsParseJSON && conf.IsParseToml && conf.IsParseYml { 33 | t.Log("parse config success") 34 | } else { 35 | panic("parse config error") 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /config/readme.md: -------------------------------------------------------------------------------- 1 | # 配置文件 2 | 3 | * 支持 
YAML, JSON, TOML, 环境变量 设置配置项的值 4 | * 支持多文件覆盖配置 5 | * 支持默认值 6 | * 修改文件后自动重载 7 | 8 | # 使用 9 | 10 | ```go 11 | import "github.com/go-eyas/toolkit/config" 12 | 13 | type Config struct { 14 | IsParseJSON bool `json:"isParseJson"` 15 | IsParseToml bool `toml:"isParseToml"` 16 | IsParseYml bool `yaml:"isParseYml"` 17 | IsParseLocalToml bool `toml:"isParseLocalToml"` 18 | 19 | Ext string 20 | Env string `default:"shell" env:"APP_ENV" required:"true"` 21 | Obj struct { 22 | Array []int 23 | Boolean bool 24 | } 25 | } 26 | 27 | func main() { 28 | conf := &Config{} 29 | err := config.Init("test/config", conf) 30 | } 31 | 32 | ``` 33 | 34 | ## 配置文件加载顺序 35 | 36 | 后面的会覆盖前面的,假设 file 参数传的是 `config`, shell 的环境变量 `CONFIG_ENV=dev` 37 | 38 | ``` 39 | ../config.dev.yml 40 | ../config.yml 41 | config.dev.yml 42 | config.yml 43 | ../config.dev.json 44 | ../config.json 45 | config.dev.json 46 | config.json 47 | ../config.dev.toml 48 | ../config.toml 49 | config.dev.toml 50 | config.toml 51 | 环境变量 52 | ``` 53 | 54 | > 如果环境变量 `CONFIG_ENV` 没有设置,则默认为 `local` 55 | 56 | # godoc 57 | 58 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/config) -------------------------------------------------------------------------------- /config/test/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "ext": "json", 3 | "env": "none", 4 | "isParseJson": true, 5 | "obj": { 6 | "array": [1], 7 | "boolean": true 8 | } 9 | } -------------------------------------------------------------------------------- /config/test/config.local.toml: -------------------------------------------------------------------------------- 1 | ext = "toml" 2 | env = "local" 3 | isParseLocalToml = true 4 | [obj] 5 | # array = [1,2,3,4,5] 6 | boolean = true -------------------------------------------------------------------------------- /config/test/config.myenv.toml: -------------------------------------------------------------------------------- 1 | env = "local" 2 | [obj] 3 | 
array = [1,2,3,4,5,6,7,8,9,0] -------------------------------------------------------------------------------- /config/test/config.toml: -------------------------------------------------------------------------------- 1 | ext = "toml" 2 | env = "" 3 | isParseToml = true 4 | 5 | [obj] 6 | array = [1,2,3,4,5,6,7] 7 | boolean = true -------------------------------------------------------------------------------- /config/test/config.yml: -------------------------------------------------------------------------------- 1 | ext: "yml" 2 | env: "" 3 | isParseYml: true 4 | obj: 5 | array: 6 | - 1 7 | - 2 8 | - 5 9 | boolean: false -------------------------------------------------------------------------------- /db/config.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | // Config 数据库配置项 4 | type Config struct { 5 | Driver string `yaml:"driver" json:"driver" toml:"driver" env:"DB_DRIVER"` 6 | URI string `yaml:"uri" json:"uri" toml:"uri" env:"DB_URI"` 7 | Debug bool 8 | Logger Logger 9 | } 10 | 11 | // Logger 日志对象 12 | type Logger interface { 13 | Debug(...interface{}) 14 | Debugf(string, ...interface{}) 15 | Error(...interface{}) 16 | Errorf(string, ...interface{}) 17 | } 18 | 19 | type ViewModel interface{ 20 | From() string 21 | } -------------------------------------------------------------------------------- /db/gorm.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "github.com/go-eyas/toolkit/log" 6 | "github.com/jinzhu/gorm" 7 | "strings" 8 | "time" 9 | "github.com/novalagung/gubrak" 10 | 11 | // load drivers 12 | _ "github.com/jinzhu/gorm/dialects/mssql" 13 | _ "github.com/jinzhu/gorm/dialects/mysql" 14 | _ "github.com/jinzhu/gorm/dialects/postgres" 15 | ) 16 | 17 | type gormLogger struct { 18 | logger Logger 19 | } 20 | 21 | func (l *gormLogger) Print(v ...interface{}) { 22 | var level = v[0] 23 | 24 | if level == "sql" { 25 | tm 
:= v[2].(time.Duration) 26 | msgs := gorm.LogFormatter(v...) 27 | l.logger.Debug("SQL [", v[5], " rows][", tm.String(), "]: ", msgs[3]) 28 | } else { 29 | l.logger.Debug(v...) 30 | } 31 | } 32 | 33 | // Gorm 初始化 gorm,返回 gorm 实例 34 | func Gorm(conf *Config) (*gorm.DB, error) { 35 | db, err := gorm.Open(conf.Driver, conf.URI) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | var logger *gormLogger 41 | if conf.Logger != nil { 42 | logger = &gormLogger{conf.Logger} 43 | } else { 44 | logger = &gormLogger{log.SugaredLogger} 45 | } 46 | db.SetLogger(logger) 47 | if conf.Debug { 48 | db.LogMode(conf.Debug) 49 | } 50 | 51 | return db, nil 52 | } 53 | 54 | func GormViewMigrate(db *gorm.DB, v ...ViewModel) { 55 | for _, m := range v { 56 | if db.HasTable(m) { 57 | continue 58 | } 59 | var tags []string 60 | scope := db.NewScope(m) 61 | ms := scope.GetModelStruct() 62 | for _, field := range ms.StructFields { 63 | if !field.IsIgnored { 64 | tags = append(tags, scope.Quote(field.DBName)) 65 | } 66 | } 67 | 68 | // 去掉重复字段 69 | _tags, _ := gubrak.Filter(tags, func(s string) bool { 70 | res, _ := gubrak.Find(tags, func(i string) bool { 71 | index := strings.LastIndex(i, "."+s) 72 | return index != -1 73 | }) 74 | return res == nil 75 | }) 76 | tags = _tags.([]string) 77 | 78 | db.Exec(fmt.Sprintf("CREATE VIEW %v AS SELECT %s %s", scope.QuotedTableName(), strings.Join(tags, ","), m.From())) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /db/gorm_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "github.com/go-eyas/toolkit/log" 5 | "github.com/jinzhu/gorm" 6 | "os" 7 | "testing" 8 | ) 9 | 10 | type gormModelTest struct { 11 | gorm.Model 12 | } 13 | 14 | func (gormModelTest) TableName() string { 15 | return "gorm_test" 16 | } 17 | 18 | type gormView struct { 19 | ID int64 20 | } 21 | 22 | func (gormView) From() string { 23 | return `From 
gorm_test` 24 | } 25 | 26 | func TestGorm(t *testing.T) { 27 | // test init 28 | db, err := Gorm(&Config{ 29 | Debug: true, 30 | Driver: "mysql", 31 | URI: os.Getenv("DB"), 32 | Logger: log.SugaredLogger, 33 | }) 34 | if err != nil { 35 | panic(err) 36 | } 37 | i := 0 38 | // test query 39 | err = db.Raw("SELECT 1 + 1").Row().Scan(&i) 40 | if err != nil { 41 | panic(err) 42 | } 43 | if i == 2 { 44 | t.Log("test gorm success") 45 | } 46 | 47 | // test migrate 48 | db.AutoMigrate(gormModelTest{}) 49 | list := []*gormModelTest{} 50 | err = db.Model(gormModelTest{}).Find(&list).Error 51 | if err != nil { 52 | panic(err) 53 | } 54 | 55 | // test view 56 | v := gormView{} 57 | GormViewMigrate(db, v) 58 | listView := []*gormView{} 59 | err = db.Model(v).Find(&listView).Error 60 | if err != nil { 61 | panic(err) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /db/readme.md: -------------------------------------------------------------------------------- 1 | # 数据库封装 2 | 3 | 封装 orm 和数据库驱动 4 | 5 | ## 初始化 6 | 7 | * 使用 db.Gorm 使用 gorm 初始化 8 | * 使用 db.Xorm 使用 xorm 初始化 9 | 10 | ```go 11 | import ( 12 | "github.com/go-eyas/toolkit/db" 13 | "github.com/go-eyas/toolkit/log" 14 | ) 15 | 16 | func main() { 17 | log.Init(&log.Config{}) 18 | var err error 19 | // gorm 20 | var db *gorm.DB 21 | db, err = db.Gorm(db.Config{ 22 | Driver: "mysql", 23 | URI: "user:password@(127.0.0.1:3306)/mydb", 24 | Logger: log.SugaredLogger, 25 | }) 26 | 27 | // xorm 28 | // var db *xorm.Engine 29 | // db, err = db.Xorm(db.Config{ 30 | // Driver: "mysql", 31 | // URI: "user:password@(127.0.0.1:3306)/mydb", 32 | // Logger: log.SugaredLogger, 33 | // }) 34 | 35 | 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | defer db.Close() 41 | } 42 | ``` 43 | 44 | ## 视图 View 45 | 在支持视图的数据库,可使用 `db.GormViewMigrate` 或 `db.XormViewMigrate` 用于创建视图,视图的字段名称映射和 model 一致 46 | 47 | 视图的模型要实现接口 48 | 49 | ```go 50 | type ViewModel interface{ 51 | From() string // From 返回 
创建视图时的 FROM 部分语句 52 | } 53 | ``` 54 | 55 | ```go 56 | type User struct { 57 | ID int64 58 | UserName string 59 | Status byte 60 | } 61 | type Company struct { 62 | UID int64 63 | CompanyName string 64 | } 65 | 66 | type UserCompany struct { 67 | *User 68 | *Company 69 | } 70 | 71 | func (UserCompany) From() string { 72 | return "FROM users JOIN company ON company.uid = users.uid" 73 | } 74 | 75 | db.GormViewMigrate(DB, &UserCompany{}) 76 | 77 | DB.Model(UserCompany{}).Where("id = ?", 1).Find(&userCompany) 78 | ``` 79 | 80 | 81 | ## 驱动 82 | 83 | 初始化的时候,配置项为 84 | 85 | ```go 86 | // Config 数据库配置项 87 | type Config struct { 88 | Driver string `yaml:"driver" json:"driver" toml:"driver" env:"DB_DRIVER"` 89 | URI string `yaml:"uri" json:"uri" toml:"uri" env:"DB_URI"` 90 | Debug bool 91 | Logger Logger 92 | } 93 | ``` 94 | 95 | Driver 的可选项为 96 | 97 | * mysql 98 | * postgres 99 | * mssql: gorm 为 mssql,xorm 为 sqlserver 100 | 101 | 这些驱动都已提前导入,初始化的时候无需再导入驱动 102 | 103 | #### sqlite 104 | 105 | 因为sqlite驱动是CGO的包,所以默认不导入, 如果要使用 sqlite 数据库,请按照以下指引 106 | 107 | 1. 导入驱动 108 | ```go 109 | import _ "github.com/go-eyas/toolkit/db/sqlite" 110 | ``` 111 | 2. 安装 Gcc, G++ 编译环境,windows可使用 [TDM-GCC](http://tdm-gcc.tdragon.net/download) ,其他系统的自行解决 112 | 3. 
使用环境变量启用CGO: `CGO_ENABLED=1` 113 | 114 | #### 其他数据库 115 | 116 | 如果要是用其他数据库,如 oracle,tidb等等,执行查找资料并引入驱动 117 | 118 | 119 | # godoc 120 | 121 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/db) -------------------------------------------------------------------------------- /db/resource/gorm.go: -------------------------------------------------------------------------------- 1 | package resource 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "github.com/jinzhu/gorm" 7 | "reflect" 8 | ) 9 | 10 | type Field struct { 11 | ColumnName string // 数据库列名 12 | StructKey string // 结构体的名字 13 | isPrimaryKey bool // 是否主键 14 | isIgnore bool // 该字段是否忽略 15 | JsonKey string // json key 16 | Search string // 查询方式 17 | Order string // 排序方式 18 | Update bool // 是否可更新,默认不可以 19 | Create bool // 是否可在创建时指定 20 | } 21 | 22 | type Resource struct { 23 | db *gorm.DB 24 | tableName string 25 | model *gorm.DB 26 | modelTypeName string 27 | scope *gorm.Scope 28 | pk string 29 | sample interface{} 30 | Fields []*Field 31 | fieldsStructKeyMap map[string]*Field 32 | defaultOrder []string 33 | } 34 | 35 | // NewGormResource 实例化资源 36 | // 37 | // example: 38 | // 39 | // type Article struct { 40 | // ID int64 `resource:"pk;search:=;order:desc" json:"id"` 41 | // Title string `resource:"create;update;search:like" json:"title"` 42 | // Content string `resource:"create;update;search:like" json:"text"` 43 | // Status byte `resource:"search:=" json:"-"` 44 | // } 45 | // 46 | // r := NewGormResource(db, &Article{}) 47 | func NewGormResource(db *gorm.DB, v interface{}) *Resource { 48 | scope := db.NewScope(v) 49 | rv := reflect.ValueOf(v) 50 | if rv.Kind() == reflect.Ptr { 51 | rv = rv.Elem() 52 | } 53 | rt := rv.Type() 54 | if rt.Kind() == reflect.Ptr { 55 | rt = rt.Elem() 56 | } 57 | 58 | r := &Resource{ 59 | sample: rv.Interface(), 60 | db: db, 61 | tableName: scope.TableName(), 62 | model: db.Table(scope.TableName()), 63 | scope: scope, 64 | } 65 | r.modelTypeName = rt.PkgPath() + "." 
+ rt.Name() 66 | fields, pk := r.parseFields(scope) 67 | r.Fields = fields 68 | r.pk = pk 69 | r.defaultOrder = []string{} 70 | keyMap := map[string]*Field{} 71 | for _, field := range fields { 72 | keyMap[field.StructKey] = field 73 | if field.Order == "DESC" || field.Order == "ASC" { 74 | r.defaultOrder = append(r.defaultOrder, field.ColumnName+" "+field.Order) 75 | } 76 | } 77 | r.fieldsStructKeyMap = keyMap 78 | return r 79 | } 80 | 81 | // Model 返回绑定了数据表的 gorm 实例 82 | // r.Model().Where("status = ?", status).Count() 83 | func (r *Resource) Model() *gorm.DB { 84 | return r.model 85 | } 86 | 87 | // Row 返回绑定了主键值的 gorm 实例 88 | // 89 | // r.Row().Update("status", 1) 90 | func (r *Resource) Row(pk interface{}) *gorm.DB { 91 | return r.model.Where(r.pk+" = ?", pk) 92 | } 93 | 94 | // Create 创建资源,支持传入 struct、map,创建前会重置 resource tag 未设置 create 的字段为0值,使得创建记录时忽略值 95 | // 96 | // err := r.Create(&Article{Title: "Hello", Status: 2}) 这里 status 会被重置为 0 ,因为 status 的 resource tag 未设置 create 97 | func (r *Resource) Create(v interface{}) error { 98 | model, err := r.toCreateStruct(v, true) 99 | if err != nil { 100 | return err 101 | } 102 | return r.model.Create(model).Error 103 | } 104 | 105 | // CreateX 创建资源,支持传入 struct、map,传入的值均有效 106 | // 107 | // err := r.CreateX(&Article{Title: "Hello", Status: 2}) 这里 status 成功设置值 108 | func (r *Resource) CreateX(v interface{}) error { 109 | model, err := r.toCreateStruct(v, false) 110 | if err != nil { 111 | return err 112 | } 113 | return r.model.Create(model).Error 114 | } 115 | 116 | // Update 更新资源,支持传入 struct、map,只会更新 resource tag 设置了 update 的字段 117 | // 118 | // err := r.Update(1, map[string]string{"title": "after title"}) 119 | func (r *Resource) Update(pk interface{}, v interface{}) error { 120 | updates, err := r.toUpdateMap(v, true) 121 | if err != nil { 122 | return err 123 | } 124 | if len(updates) > 0 { 125 | return r.Row(pk).Updates(updates).Error 126 | } 127 | return nil 128 | } 129 | 130 | // UpdateX 更新资源,支持传入 
struct、map,更新传入的所有字段,如果传入的是 struct ,则会忽略 0 值 131 | // 132 | // err := r.UpdateX(1, map[string]byte{"status": 1}) 133 | func (r *Resource) UpdateX(pk interface{}, v interface{}) error { 134 | updates, err := r.toUpdateMap(v, false) 135 | if err != nil { 136 | return err 137 | } 138 | if len(updates) > 0 { 139 | return r.Row(pk).Updates(updates).Error 140 | } 141 | return nil 142 | } 143 | 144 | // Detail 查询指定主键的记录 145 | // 146 | // article := &Article{} 147 | // err := r.Detail(1, article) 148 | func (r *Resource) Detail(pk interface{}, v interface{}) error { 149 | return r.Row(pk).First(v).Error 150 | } 151 | 152 | // List 查询资源列表,提供查询条件,排序规则,查询列表,查询规则会以resource tag 的search 值为准 153 | // 154 | // list := []*Article{} 155 | // total, err := r.List(&list, map[string]byte{"status": 1}) 156 | // 157 | func (r *Resource) List(slice interface{}, args ...interface{}) (int64, error) { 158 | switch len(args) { 159 | case 0: 160 | return r.listQuery(slice, nil, nil, nil) 161 | case 1: 162 | page := &Pagination{} 163 | query := args[0] 164 | raw, err := json.Marshal(query) 165 | if err != nil { 166 | return 0, errors.New("query parse error") 167 | } 168 | err = json.Unmarshal(raw, page) 169 | if err != nil { 170 | return 0, errors.New("query parse error") 171 | } 172 | order := r.getOrderArgs(query) 173 | return r.listQuery(slice, page, query, order) 174 | case 2: 175 | page := &Pagination{} 176 | query := args[0] 177 | raw, err := json.Marshal(query) 178 | if err != nil { 179 | return 0, errors.New("query parse error") 180 | } 181 | err = json.Unmarshal(raw, page) 182 | if err != nil { 183 | return 0, errors.New("query parse error") 184 | } 185 | return r.listQuery(slice, page, args[0], args[1]) 186 | } 187 | return 0, errors.New("list param error") 188 | } 189 | // ListPage is like List but takes pagination explicitly; page is forwarded to listQuery even when no query/order args are given. 190 | func (r *Resource) ListPage(slice interface{}, page *Pagination, args ...interface{}) (int64, error) { 191 | switch len(args) { 192 | case 0: 193 | return r.listQuery(slice, page, nil, nil) 194 | case 1: 195 |
return r.listQuery(slice, page, args[0], nil) 196 | case 2: 197 | return r.listQuery(slice, page, args[0], args[1]) 198 | } 199 | return 0, errors.New("list param error") 200 | } 201 | 202 | // Delete 删除指定主键的资源 203 | // 204 | // err := r.Delete(1) 205 | // 206 | func (r *Resource) Delete(pk interface{}) error { 207 | return r.model.Delete(r.sample, r.pk+" = ?", pk).Error 208 | } 209 | 210 | -------------------------------------------------------------------------------- /db/resource/readme.md: -------------------------------------------------------------------------------- 1 | # Resource 资源自动维护、检索 2 | 3 | 基于 Gorm API,根据 Restful API 设计对资源进行维护并检索 4 | 5 | 简单说就是:自动 curd 6 | 7 | # Usage 8 | 9 | 先来看个栗子 10 | 11 | ```go 12 | package main 13 | 14 | import ( 15 | "net/http" 16 | "github.com/gin-gonic/gin" 17 | "github.com/go-eyas/toolkit/db" 18 | "github.com/go-eyas/toolkit/db/resource" 19 | ) 20 | 21 | 22 | type Article struct { 23 | ID int64 `resource:"pk;search:none"` 24 | Title string `resource:"create;update;search:like"` 25 | Content string `resource:"create;update;search:like"` 26 | Status byte `resource:"search:="` 27 | } 28 | 29 | func main() { 30 | DB, err := db.Gorm(&db.Config{URI: "root:123456@(127.0.0.1:3306)/test"}) 31 | r := resource.NewGormResource(DB, &Article{}) 32 | 33 | /******* create resource ********/ 34 | err = r.Create(&Article{Title: "the title", Content: "the content"}) // 使用原本类型结构体 35 | 36 | // 使用临时结构体 37 | err = r.Create(&struct{ 38 | Title string 39 | Content string 40 | }{Title: "the title", Content: "the content"}) 41 | 42 | // 使用 map, 会自动匹配map 为 数据库列名,或者struct Key,或者 json key 43 | err = r.Create(map[string]interface{}{"title": "the title", "content": "the content"}) 44 | 45 | 46 | /******* update resource ********/ 47 | err = r.Update(1, &Article{Title: "the title", Content: "the content"}) 48 | err = r.Update(1, &struct{ 49 | Title string 50 | Content string 51 | }{Title: "the title", Content: "the content"}) 52 | err = r.Update(1, 
map[string]interface{}{"title": "the title", "content": "the content"}) 53 | 54 | 55 | /******* delete resource ********/ 56 | err = r.Delete(1) 57 | 58 | 59 | /******* list resource ********/ 60 | list := []*Article{} 61 | 62 | // 查询所有记录 63 | total, err := r.List(&list) 64 | 65 | // 使用原始结构体设置查询参数,会忽略 0 值 66 | total, err = r.List(&list, &Article{Title: "the title", Content: "the content"}) 67 | 68 | // 临时结构体,会自动匹配原始结构体的 struct key 69 | total, err = r.List(&list, &struct{ 70 | Title string 71 | Content string 72 | Status byte 73 | }{Status: 1}) 74 | 75 | // map 定义查询条件,key 会自动匹配struct key 和 数据库列名 76 | total, err = r.List(&list, map[string]interface{}{ 77 | "status": 1, 78 | }) 79 | 80 | // 查询,并指定列名排序,会使用数组顺序作为排序权重 81 | total, err = r.List(&list, nil, []string{"id DESC", "status ASC"}) 82 | 83 | // 查询,并使用 map 指定排序,map无序,无法指定多个排序权重 84 | total, err = r.List(&list, nil, map[string]string{ 85 | "id": "DESC", 86 | "status": "ASC", 87 | }) 88 | 89 | } 90 | ``` 91 | 92 | ## API 93 | 94 | ### 资源定义 95 | 96 | 资源定义,继承gorm的所有struct tag,可以与gorm的模型结构体共用同一个 97 | 98 | ```go 99 | type Article struct { 100 | ID int64 `resource:"pk;search:=;order:desc" json:"id"` 101 | Title string `resource:"create;update;search:like" json:"title"` 102 | Content string `resource:"create;update;search:like" json:"text"` 103 | Status byte `resource:"search:=" json:"-"` 104 | } 105 | ``` 106 | 107 | tag 的 key 为 `resource`,每对键值使用 `:` 定义,以 `;` 分隔 108 | 109 | |key|默认值|说明| 110 | | ---: | :--- | :--- | 111 | |pk| false | 是否为主键| 112 | |search|=| 该字段的查询类型,sql语句的 where 匹配关系,为 `-` 时该字段不可作为查询条件| 113 | |order|-|查询时默认的排序规则| 114 | |create|-|该字段在调用 Create 时是否有效,即新增记录的时候是否给该字段赋值| 115 | |update|-|该字段在调用 Update 时是否有效,即修改记录的时候是否给该字段赋值| 116 | 117 | 118 | ### API 函数 119 | 120 | #### resource.New(conf *db.Config, model interface{}) (*Resource, *gorm.DB, error) 121 | 122 | 使用数据库配置创建资源实例 123 | 124 | * conf 数据库配置 `github.com/go-eyas/toolkit/db.Config` 125 | * model 资源模型 126 | 127 | 返回 128 | 129 | * *Resource 资源实例 130 |
* *gorm.DB gorm 实例 131 | * error 错误对象 132 | 133 | #### resource.NewGormResource(db *gorm.DB, model interface{}) error 134 | 135 | 基于gorm 实例,创建一个资源,返回资源实例 136 | 137 | ```go 138 | r := resource.NewGormResource(db, Article{}) 139 | ``` 140 | 141 | #### r.Create(data interface{}) error 142 | 143 | 创建资源,data 可以是结构体、map 144 | 145 | * 为结构体时,匹配 struct Key,为结构体可以获取到创建完成后的其他字段,如id,时间等 146 | * 为 map 时,顺序匹配struct Key、column key、json key 147 | 148 | 返回错误对象 149 | 150 | #### r.Update(pk interface{}, data interface{}) error 151 | 152 | 更新资源, pk 是主键的值,data 是要更改的之后的数据 153 | 154 | * 为结构体时,匹配 struct Key,为结构体可以获取到创建完成后的其他字段,如id,时间等 155 | * 为 map 时,顺序匹配struct Key、column key、json key 156 | 157 | 返回错误对象 158 | 159 | #### r.Detail(pk interface{}, dest interface{}) error 160 | 161 | 查询指定资源,pk是主键的值,查询到资源后赋值给 dest,dest必须为指针类型 162 | 163 | #### r.Delete(pk interface{}) error 164 | 165 | 删除指定主键值的记录 166 | 167 | #### r.List(list interface{}, query interface{}, order interface{}) (int64, error) 168 | 169 | 查询资源 170 | 171 | * list 是数组或切片的指针对象,查询到数据后赋值到该变量 172 | * query 可选,查询条件,查询类型会按照 tag 定义的search值,可为 struct 和 map, 为 map 时,顺序匹配struct Key、column key、json key 173 | * query 有两个内部字段会用上, offset 和 limit,用于定义分页,返回值的总数会忽略这两个值的查询条件 174 | * order 可选,定义排序规则,如果不传,默认会按照 tag 定义的 order 值,可为 []string 和 map,为 []string 可定义排序字段权重,为map时无序 175 | 176 | 返回值 177 | * int64 为在query查询条件能查到的总数,查询条件不会包含 offset 和 limit 178 | * error 错误对象 179 | 180 | #### r.Row(pk interface{}) *gorm.DB 181 | 182 | 返回 gorm 绑定了主键查询条件的实例 183 | 184 | -------------------------------------------------------------------------------- /db/resource/resource.go: -------------------------------------------------------------------------------- 1 | package resource 2 | 3 | import ( 4 | "github.com/go-eyas/toolkit/db" 5 | "github.com/jinzhu/gorm" 6 | ) 7 | 8 | func New(conf *db.Config, model interface{}) (*Resource, *gorm.DB, error) { 9 | db, err := db.Gorm(conf) 10 | if err != nil { 11 | return nil, nil, err 12 | } 13 | r := NewGormResource(db, model) 14 | 
return r, db, nil 15 | } -------------------------------------------------------------------------------- /db/sqlite/sqlite.go: -------------------------------------------------------------------------------- 1 | package sqlite 2 | 3 | import ( 4 | _ "github.com/mattn/go-sqlite3" 5 | ) -------------------------------------------------------------------------------- /db/xorm.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | // load mysql driver 5 | _ "github.com/go-sql-driver/mysql" 6 | "xorm.io/xorm" 7 | "xorm.io/xorm/log" 8 | 9 | // load postgresql driver 10 | _ "github.com/lib/pq" 11 | 12 | // load mssql 13 | _ "github.com/denisenkom/go-mssqldb" 14 | ) 15 | // xormLogger adapts the package Logger for xorm: Debug/Info levels go to Logger.Debug*, Warn/Error levels go to Logger.Error*. 16 | type xormLogger struct { 17 | logger Logger 18 | } 19 | 20 | func (xl *xormLogger) Debug(v ...interface{}) { 21 | xl.logger.Debug(v...) 22 | } 23 | 24 | func (xl *xormLogger) Debugf(f string, v ...interface{}) { 25 | xl.logger.Debugf(f, v...) 26 | } 27 | 28 | func (xl *xormLogger) Info(v ...interface{}) { 29 | xl.logger.Debug(v...) 30 | } 31 | 32 | func (xl *xormLogger) Infof(f string, v ...interface{}) { 33 | xl.logger.Debugf(f, v...) 34 | } 35 | 36 | func (xl *xormLogger) Warn(v ...interface{}) { 37 | xl.logger.Error(v...) 38 | } 39 | 40 | func (xl *xormLogger) Warnf(f string, v ...interface{}) { 41 | xl.logger.Errorf(f, v...) 42 | } 43 | 44 | func (xl *xormLogger) Error(v ...interface{}) { 45 | xl.logger.Error(v...) 46 | } 47 | 48 | func (xl *xormLogger) Errorf(f string, v ...interface{}) { 49 | xl.logger.Errorf(f, v...)
50 | } 51 | 52 | func (xl *xormLogger) Level() log.LogLevel { 53 | return 0 54 | } 55 | 56 | func (xl *xormLogger) SetLevel(l log.LogLevel) { 57 | } 58 | 59 | func (xl *xormLogger) ShowSQL(b ...bool) { 60 | } 61 | 62 | func (xl *xormLogger) IsShowSQL() bool { 63 | return true 64 | } 65 | 66 | // Xorm 初始化Xorm 67 | func Xorm(conf *Config) (*xorm.Engine, error) { 68 | db, err := xorm.NewEngine(conf.Driver, conf.URI) 69 | 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | if conf.Debug { 75 | db.ShowSQL(conf.Debug) 76 | if conf.Logger != nil { 77 | logger := log.NewLoggerAdapter(&xormLogger{conf.Logger}) 78 | db.SetLogger(logger) 79 | } 80 | } 81 | 82 | return db, nil 83 | } 84 | -------------------------------------------------------------------------------- /db/xorm_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "github.com/go-eyas/toolkit/log" 5 | "os" 6 | "testing" 7 | ) 8 | 9 | func TestXorm(t *testing.T) { 10 | db, err := Xorm(&Config{ 11 | Debug: true, 12 | Driver: "mysql", 13 | URI: os.Getenv("DB"), 14 | Logger: log.SugaredLogger, 15 | }) 16 | if err != nil { 17 | panic(err) 18 | } 19 | i := 0 20 | _, err = db.SQL("SELECT 1 + 1").Get(&i) 21 | if err != nil { 22 | panic(err) 23 | } 24 | 25 | if i == 2 { 26 | t.Log("test xorm success") 27 | } 28 | type XormTest struct { 29 | ID int64 `xorm:"id"` 30 | } 31 | db.Sync2(XormTest{}) 32 | list := []*XormTest{} 33 | err = db.Table(XormTest{}).Find(&list) 34 | if err != nil { 35 | panic(err) 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /email/email.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "bytes" 5 | "crypto/tls" 6 | "fmt" 7 | "html/template" 8 | "net/smtp" 9 | "sync" 10 | 11 | "github.com/go-eyas/toolkit/util" 12 | "github.com/jordan-wright/email" 13 | ) 14 | 15 | type TPL struct { 16 | Subject 
string 17 | From string 18 | To []string 19 | Bcc []string 20 | Cc []string 21 | Text string 22 | HTML string 23 | } 24 | // Config holds the SMTP account/server settings and the named templates; Secure makes Send use mail.SendWithTLS. 25 | type Config struct { 26 | Name string 27 | Account string 28 | Password string 29 | Host string 30 | Port string 31 | Secure bool 32 | TPL map[string]*TPL 33 | } 34 | // Email sends mail through one SMTP server; Send calls are serialised by sendMu. 35 | type Email struct { 36 | mailHostAddr string 37 | SmtpAuth smtp.Auth 38 | conf *Config 39 | TLS *tls.Config 40 | TPLs map[string]*TPL 41 | cacheName string 42 | sendMu sync.Mutex 43 | } 44 | // New builds an Email sender from conf; conf.TPL (or a fresh empty map when it is nil) becomes e.TPLs. 45 | func New(conf *Config) *Email { 46 | auth := smtp.PlainAuth("", conf.Account, conf.Password, conf.Host) // identity "" = act as Account; conf.Name is only the From display name 47 | email := &Email{ 48 | mailHostAddr: conf.Host + ":" + conf.Port, 49 | SmtpAuth: auth, 50 | conf: conf, 51 | cacheName: util.RandomStr(6), 52 | } 53 | if conf.Secure { 54 | email.TLS = &tls.Config{ 55 | ServerName: conf.Host, 56 | InsecureSkipVerify: true, // NOTE(review): certificate verification is disabled, exposing the SMTP password to MITM; ServerName is set, so this can likely be false — confirm against deployed servers 57 | } 58 | } 59 | 60 | if conf.TPL == nil { 61 | email.TPLs = make(map[string]*TPL) 62 | } else { 63 | email.TPLs = conf.TPL 64 | } 65 | 66 | return email 67 | } 68 | 69 | func (e *Email) NewEmail() *email.Email { 70 | return email.NewEmail() 71 | } 72 | 73 | func (e *Email) NewEmailByTpl(tplName string, data interface{}) (*email.Email, error) { 74 | tpl, ok := e.TPLs[tplName] // e.TPLs is never nil and also holds templates registered after New 75 | if !ok { 76 | return nil, fmt.Errorf("tpl name %s is not defined", tplName) 77 | } 78 | var err error 79 | mail := email.NewEmail() 80 | cachePrefix := e.cacheName + "." + tplName + "."
81 | subject, err := templateParse(cachePrefix+"subject", tpl.Subject, data) 82 | if err != nil { 83 | return nil, err 84 | } 85 | mail.Subject = string(subject) 86 | 87 | if tpl.Text != "" { 88 | mail.Text, err = templateParse(cachePrefix+"text", tpl.Text, data) 89 | if err != nil { 90 | return nil, err 91 | } 92 | } 93 | if tpl.HTML != "" { 94 | mail.HTML, err = templateParse(cachePrefix+"html", tpl.HTML, data) 95 | if err != nil { 96 | return nil, err 97 | } 98 | } 99 | 100 | mail.From = fmt.Sprintf("%s <%s>", e.conf.Name, e.conf.Account) 101 | if len(tpl.Bcc) > 0 { 102 | mail.Bcc = append([]string(nil), tpl.Bcc...) // copy: never share a backing array with the template 103 | } 104 | 105 | if len(tpl.Cc) > 0 { 106 | mail.Cc = append([]string(nil), tpl.Cc...) 107 | } 108 | 109 | if len(tpl.To) > 0 { 110 | mail.To = append([]string(nil), tpl.To...) // copy: appending a recipient later must not write into tpl.To 111 | } else { 112 | mail.To = []string{} 113 | } 114 | 115 | return mail, nil 116 | } 117 | // Send delivers mail to its existing recipients plus addr; the caller's mail is left unmodified. 118 | func (e *Email) Send(addr string, mail *email.Email) error { 119 | e.sendMu.Lock() 120 | defer e.sendMu.Unlock() 121 | msg := *mail // shallow copy so repeated Send calls with the same mail do not accumulate recipients 122 | msg.To = append(append(make([]string, 0, len(mail.To)+1), mail.To...), addr) 123 | if e.TLS == nil { 124 | return msg.Send(e.mailHostAddr, e.SmtpAuth) 125 | } 126 | return msg.SendWithTLS(e.mailHostAddr, e.SmtpAuth, e.TLS) 127 | } 128 | // SendByTpl renders template tplName with data and sends the result to addr. 129 | func (e *Email) SendByTpl(addr string, tplName string, data interface{}) error { 130 | mail, err := e.NewEmailByTpl(tplName, data) 131 | if err != nil { 132 | return err 133 | } 134 | return e.Send(addr, mail) 135 | } 136 | // templateCache memoises parsed templates by name. NOTE(review): the file imports html/template, so Subject and Text output is HTML-escaped too — confirm that is intended. 137 | var templateCache = map[string]*template.Template{} 138 | var cacheMu sync.RWMutex 139 | 140 | func templateParse(name, src string, data interface{}) ([]byte, error) { 141 | var err error 142 | cacheMu.RLock() 143 | parse, ok := templateCache[name] 144 | cacheMu.RUnlock() 145 | if !ok { 146 | parse, err = template.New(name).Parse(src) 147 | if err != nil { 148 | return nil, err 149 | } 150 | cacheMu.Lock() 151 | templateCache[name] = parse 152 | cacheMu.Unlock() 153 | } 154 | content := new(bytes.Buffer) 155 | err = parse.Execute(content, data) 156 | if err != nil { 157 | return nil, err 158 | } 159 | return
content.Bytes(), nil 160 | } 161 | -------------------------------------------------------------------------------- /email/email_test.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/BurntSushi/toml" 8 | ) 9 | 10 | var tomlConfig = ` 11 | host = "smtp.qq.com" 12 | port = "465" 13 | account = "893521870@qq.com" 14 | password = "haha, wo cai bu gao su ni ne" 15 | name = "unit test" 16 | secure = true 17 | [tpl.a] 18 | bcc = ["Jeason "] # 抄送 19 | cc = [] # 抄送人 20 | subject = "Welcome, {{.Name}}" # 主题 21 | text = "Hello, I am {{.Name}}" # 文本 22 | html = "

Hello, I am {{.Name}}

" # html 内容 23 | ` 24 | 25 | func TestEmail(t *testing.T) { 26 | conf := &Config{} 27 | toml.Decode(tomlConfig, conf) 28 | conf.Password = os.Getenv("password") 29 | 30 | email := New(conf) 31 | err := email.SendByTpl("Yuesong Liu ", "a", struct{ Name string }{"Batman"}) 32 | 33 | if err != nil { 34 | t.Errorf("mail send fail: %v", err) 35 | } else { 36 | t.Log("mail send success") 37 | } 38 | } 39 | 40 | func ExampleSample() { 41 | tomlConfig := ` 42 | host = "smtp.qq.com" 43 | port = "465" 44 | account = "893521870@qq.com" 45 | password = "haha, wo cai bu gao su ni ne" 46 | name = "unit test" 47 | secure = true 48 | [tpl.a] 49 | bcc = ["Jeason "] # 抄送 50 | cc = [] # 抄送人 51 | subject = "Welcome, {{.Name}}" # 主题 52 | text = "Hello, I am {{.Name}}" # 文本 53 | html = "

Hello, I am {{.Name}}

" # html 内容 54 | ` 55 | conf := &Config{} 56 | toml.Decode(tomlConfig, conf) 57 | email := New(conf) 58 | email.SendByTpl("Yuesong Liu ", "a", struct{ Name string }{"Batman"}) 59 | } 60 | -------------------------------------------------------------------------------- /email/readme.md: -------------------------------------------------------------------------------- 1 | # 发邮件 2 | 3 | 发送邮件,就是这么简单 4 | 5 | ## 使用 6 | 7 | ```go 8 | import ( 9 | "github.com/go-eyas/toolkit/email" 10 | "github.com/BurntSushi/toml" 11 | ) 12 | 13 | func ExampleSample() { 14 | tomlConfig := ` 15 | host = "smtp.qq.com" 16 | port = "465" 17 | account = "893521870@qq.com" 18 | password = "haha, wo cai bu gao su ni ne" 19 | name = "unit test" 20 | secure = true 21 | [tpl.a] 22 | bcc = ["Jeason "] # 抄送 23 | cc = [] # 抄送人 24 | subject = "Welcome, {{.Name}}" # 主题 25 | text = "Hello, I am {{.Name}}" # 文本 26 | html = "

Hello, I am {{.Name}}

" # html 内容 27 | ` 28 | conf := &Config{} 29 | toml.Decode(tomlConfig, conf) 30 | email := New(conf) 31 | email.SendByTpl("Yuesong Liu ", "a", struct{ Name string }{"Batman"}) 32 | } 33 | ``` -------------------------------------------------------------------------------- /emit/emit.go: -------------------------------------------------------------------------------- 1 | package emit 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | ) 7 | 8 | // Handler 监听触发的回调函数 9 | type Handler func(interface{}) 10 | 11 | type Emitter struct { 12 | lisMu sync.RWMutex 13 | listener map[string][]Handler 14 | } 15 | 16 | // New 实例化监听器 17 | func New() *Emitter { 18 | return &Emitter{ 19 | listener: make(map[string][]Handler), 20 | } 21 | } 22 | 23 | // On 增加事件监听 24 | func (e *Emitter) On(name string, handler ...Handler) *Emitter { 25 | e.lisMu.RLock() 26 | handlers, ok := e.listener[name] 27 | e.lisMu.RUnlock() 28 | if !ok { 29 | handlers = handler 30 | } else { 31 | handlers = append(handlers, handler...) 32 | } 33 | e.lisMu.Lock() 34 | e.listener[name] = handlers 35 | e.lisMu.Unlock() 36 | return e 37 | } 38 | 39 | // Off 取消监听 40 | // Off("evt") 取消所有 41 | // Off("evt", handler1, handler2) 取消指定函数 42 | func (e *Emitter) Off(name string, handler ...Handler) *Emitter { 43 | if len(handler) == 0 { 44 | delete(e.listener, name) 45 | return e 46 | } 47 | e.lisMu.RLock() 48 | handlers, ok := e.listener[name] 49 | e.lisMu.RUnlock() 50 | if !ok || len(handlers) == 0 { 51 | return e 52 | } 53 | nextHandlers := []Handler{} 54 | for _, existH := range handlers { 55 | rm := false 56 | for _, offH := range handler { 57 | if reflect.ValueOf(existH).Pointer() == reflect.ValueOf(offH).Pointer() { 58 | rm = true 59 | } 60 | } 61 | if !rm { 62 | nextHandlers = append(nextHandlers, existH) 63 | } 64 | } 65 | 66 | e.lisMu.Lock() 67 | e.listener[name] = nextHandlers 68 | e.lisMu.Unlock() 69 | return e 70 | } 71 | 72 | // Emit 分发事件 73 | func (e *Emitter) Emit(name string, v interface{}) *Emitter { 74 | 
e.lisMu.RLock() 75 | handlers, ok := e.listener[name] 76 | e.lisMu.RUnlock() // released before calling handlers, so a handler may call On/Off without deadlocking 77 | if !ok || len(handlers) == 0 { 78 | return e 79 | } 80 | for _, h := range handlers { 81 | h(v) 82 | } 83 | return e 84 | } 85 | 86 | // emit is the package-level default emitter; it is built with New so its listener map is initialised. 87 | var emit = New() 88 | // On registers handler(s) for name on the package-level emitter. 89 | func On(name string, handler ...Handler) *Emitter { 90 | return emit.On(name, handler...) 91 | } 92 | // Off removes handler(s) for name from the package-level emitter; with no handler it removes all of them. 93 | func Off(name string, handler ...Handler) *Emitter { 94 | return emit.Off(name, handler...) 95 | } 96 | // Emit dispatches v to the package-level emitter's handlers for name. 97 | func Emit(name string, v interface{}) *Emitter { 98 | return emit.Emit(name, v) 99 | } 100 | -------------------------------------------------------------------------------- /emit/emit_test.go: -------------------------------------------------------------------------------- 1 | package emit 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | ) 7 | 8 | func TestEmitter(t *testing.T) { 9 | e := New() 10 | 11 | var wg sync.WaitGroup 12 | 13 | fn1 := func(data interface{}) { 14 | t.Logf("fn1 receive event data: %v", data) 15 | wg.Done() 16 | } 17 | 18 | fn2 := func(data interface{}) { 19 | t.Logf("fn2 receive event data: %v", data) 20 | wg.Done() 21 | } 22 | 23 | fn3 := func(data interface{}) { 24 | t.Logf("fn3 receive event data: %v", data) 25 | wg.Done() 26 | } 27 | 28 | // test on 29 | e.On("testEvt", fn1) 30 | e.On("testEvt", fn2, fn3) 31 | wg.Add(3) 32 | 33 | // test emit 34 | e.Emit("testEvt", "this is data") 35 | wg.Wait() 36 | 37 | // test off 38 | e.Off("testEvt", fn3) 39 | wg.Add(2) 40 | e.Emit("testEvt", "this is data") 41 | wg.Wait() 42 | 43 | } 44 | -------------------------------------------------------------------------------- /emit/readme.md: -------------------------------------------------------------------------------- 1 | # 事件监听器 2 | 3 | 最简洁的事件监听与分发 4 | 5 | ## API 6 | 7 | 默认使用全局监听器,在全局范围的监听均有效,如果需要局部的事件分发,可使用 New 重新实例化一个,实例化的事件分发与全局的完全隔离 8 | 9 | 支持链式操作 10 | 11 | #### New() 12 | 13 | 重新实例化一个事件分发器 14 | 15 | #### On(name string, handler ...func(interface{})) 16 | 17 | 增加监听 18 | 19 | * name
事件名称 20 | * handler 事件触发回调函数 21 | 22 | #### Off(name string, handler ...func(interface{})) 23 | 24 | 取消事件监听,如果 handler 为空,则取消该事件的所有监听,如果不为空,则只取消指定的监听函数 25 | 26 | * name 事件名称 27 | * handler 事件触发回调函数 28 | 29 | ```go 30 | emit.Off("evt") // 取消 evt 事件的所有函数监听 31 | emit.Off("evt", fn1) // 只取消 fn1 函数的监听 32 | 33 | ``` 34 | 35 | #### Emit(name string, data interface{}) 36 | 37 | 触发事件 38 | 39 | * name 事件名称 40 | * data 事件触发携带的数据 41 | 42 | ## 使用 43 | 44 | ```go 45 | import "github.com/go-eyas/toolkit/emit" 46 | 47 | fn1 := func(data interface{}) { 48 | fmt.Printf("fn1 receive data: %v", data) 49 | } 50 | 51 | fn2 := func(data interface{}) { 52 | fmt.Printf("fn2 receive data: %v", data) 53 | } 54 | 55 | fn3 := func(data interface{}) { 56 | fmt.Printf("fn3 receive data: %v", data) 57 | } 58 | 59 | emit. 60 | On("evt", fn1). 61 | On("evt", fn2, fn3). 62 | Emit("evt", "hello emitter") 63 | 64 | emit.Off("evt", fn3) 65 | emit.Emit("evt", "hello emitter again") 66 | 67 | // or 68 | e := emit.New().On(...).Off(...) 69 | 70 | e.Emit(...) 
71 | ``` 72 | -------------------------------------------------------------------------------- /gin/midleware/cache-control.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "regexp" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | // StaticRegexp matches request URIs that end in a static-asset extension, optionally followed by a query string. 9 | var StaticRegexp = regexp.MustCompile(`\.(js|css|png|jpg|woff2?|ttf|eot)(\?.*)?$`) 10 | // CacheControl marks responses whose RequestURI matches filter as publicly cacheable for 31536000s (one year); a nil filter disables this, and Cache-Control is reset for non-2xx statuses. 11 | func CacheControl(filter *regexp.Regexp) gin.HandlerFunc { 12 | return func(c *gin.Context) { 13 | path := c.Request.RequestURI 14 | if filter != nil && filter.MatchString(path) { 15 | c.Header("Cache-Control", "public, max-age=31536000") 16 | } 17 | 18 | c.Next() 19 | status := c.Writer.Status() 20 | if status >= 300 || status < 200 { // NOTE(review): if the handler already wrote the response, headers are likely flushed and this reset may not take effect — confirm 21 | c.Header("Cache-Control", "") 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /gin/midleware/error.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/go-eyas/toolkit/gin/util" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | // errLogger is the only logging capability ErrorMiddleware needs. 9 | type errLogger interface { 10 | Errorf(string, ...interface{}) 11 | } 12 | 13 | // var codeUnknowError = 999999 14 | 15 | // ErrorMiddleware recovers panics raised while handling a request, logs them via logger (skipped when logger is nil) and renders them with util.R(ctx).Error. 16 | // Handlers may panic with an error instead of returning it, which removes repetitive if err != nil code: 17 | // panic("text") => {msg: "text", code: 0, data: {}} 18 | // panic(gin.H{"code": 0, "msg": "some error"}) => {same data as passed}; code defaults to 999999, status to 400, msg to unknow error 19 | // panic(errors.New("some error")) => {msg: "some error", code: 999999, data: {}} 20 | // panic(Struct{...}) => {msg: "unknow", code: 999999, data: {...struct data}} 21 | func ErrorMiddleware(logger errLogger) gin.HandlerFunc { 22 | return func(ctx *gin.Context) { 23 | defer func() { 24 | if err := recover(); err != nil { 25 | if logger != nil { // the gin readme allows passing nil to disable logging 26 | logger.Errorf("%v", err) 27 | } 28 | ctx.Abort() 29 | util.R(ctx).Error(err) 30 | } 31 | }() 32 | ctx.Next() 33 | } 34 | } 35 | 
-------------------------------------------------------------------------------- /gin/midleware/logger.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | "net" 8 | "net/http" 9 | "net/http/httputil" 10 | "os" 11 | "regexp" 12 | "runtime/debug" 13 | "strings" 14 | "time" 15 | 16 | "github.com/gin-gonic/gin" 17 | "go.uber.org/zap" 18 | "go.uber.org/zap/zapcore" 19 | ) 20 | 21 | // Ginzap returns a gin.HandlerFunc (middleware) that logs requests using uber-go/zap. 22 | // 23 | // Requests with errors log each entry of c.Errors via logger.Error. 24 | // Requests without errors are logged via logger.Debug with status, method, host, ip, user-agent and latency fields (plus origin, content-type and body when present). 25 | // 26 | // It receives: 27 | // 1. logger, the destination, and printBody, which also logs non-GET, non-form-data request bodies (read fully into memory). 28 | // 2. filter, which when non-nil restricts logging to RequestURIs it matches; other requests pass through unlogged. 29 | func Ginzap(logger *zap.Logger, printBody bool, filter *regexp.Regexp) gin.HandlerFunc { 30 | return func(c *gin.Context) { 31 | // some middlewares modify these values, so capture the path before c.Next() 32 | path := c.Request.RequestURI 33 | if filter != nil && !filter.MatchString(path) { 34 | c.Next() 35 | return 36 | } 37 | start := time.Now() 38 | contentType := c.GetHeader("content-type") 39 | body := "" 40 | if printBody && c.Request.Method != "GET" && !strings.Contains(contentType, "form-data") { 41 | bodyByte, _ := ioutil.ReadAll(c.Request.Body) // NOTE(review): the read error is discarded and the whole body is buffered; consider a size limit 42 | c.Request.Body = ioutil.NopCloser(io.Reader(bytes.NewReader(bodyByte))) // restore the body so downstream handlers can still read it 43 | body = string(bodyByte) 44 | } 45 | 46 | c.Next() 47 | 48 | end := time.Now() 49 | latency := end.Sub(start) 50 | 51 | if len(c.Errors) > 0 { 52 | // Log each gin error; the request fields below are only attached on the success path.
53 | for _, e := range c.Errors.Errors() { 54 | logger.Error(e) 55 | } 56 | } else { 57 | args := []zapcore.Field{ 58 | zap.Int("status", c.Writer.Status()), 59 | zap.String("method", c.Request.Method), 60 | zap.String("host", c.Request.Host), 61 | zap.String("ip", c.ClientIP()), 62 | zap.String("user-agent", c.Request.UserAgent()), 63 | zap.String("latency", latency.String()), 64 | } 65 | origin := c.Request.Header.Get("Origin") 66 | if origin != "" { 67 | args = append(args, zap.String("origin", origin)) 68 | } 69 | if contentType != "" { 70 | args = append(args, zap.String("content-type", contentType)) 71 | } 72 | if printBody && c.Request.Method != "GET" && body != "" { 73 | args = append(args, zap.String("body", body)) 74 | } 75 | logger.Debug(path, args...) 76 | } 77 | } 78 | } 79 | 80 | // RecoveryWithZap returns a gin.HandlerFunc (middleware) 81 | // that recovers from any panics and logs requests using uber-go/zap. 82 | // All errors are logged using zap.Error(). 83 | // stack means whether output the stack info. 84 | // The stack info is easy to find where the error occurs but the stack info is too large. 85 | func RecoveryWithZap(logger *zap.Logger, stack bool) gin.HandlerFunc { 86 | return func(c *gin.Context) { 87 | defer func() { 88 | if err := recover(); err != nil { 89 | // Check for a broken connection, as it is not really a 90 | // condition that warrants a panic stack trace. 
91 | var brokenPipe bool 92 | if ne, ok := err.(*net.OpError); ok { 93 | if se, ok := ne.Err.(*os.SyscallError); ok { 94 | if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { 95 | brokenPipe = true 96 | } 97 | } 98 | } 99 | 100 | httpRequest, _ := httputil.DumpRequest(c.Request, false) 101 | if brokenPipe { 102 | logger.Error(c.Request.URL.Path, 103 | zap.String("error", err.(string)), 104 | zap.String("request", string(httpRequest)), 105 | ) 106 | // If the connection is dead, we can't write a status to it. 107 | _ = c.Error(err.(error)) // nolint: errcheck 108 | c.Abort() 109 | return 110 | } 111 | 112 | if stack { 113 | logger.Error("[Recovery from panic]", 114 | zap.Time("time", time.Now()), 115 | zap.String("error", err.(string)), 116 | zap.String("request", string(httpRequest)), 117 | zap.String("stack", string(debug.Stack())), 118 | ) 119 | } else { 120 | logger.Error("[Recovery from panic]", 121 | zap.Time("time", time.Now()), 122 | zap.Any("error", err), 123 | zap.String("request", string(httpRequest)), 124 | ) 125 | } 126 | c.AbortWithStatus(http.StatusInternalServerError) 127 | } 128 | }() 129 | c.Next() 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /gin/readme.md: -------------------------------------------------------------------------------- 1 | # gin 工具箱 2 | 3 | ## 中间件 4 | 5 | #### error 6 | 7 | 捕获到在http处理时的错误 8 | 9 | 在 handler 和其他地方如果产生了 error 可直接panic,到这里统一处理,简化 if err != nil 之类的代码 10 | 11 | ```go 12 | import ( 13 | "github.com/go-eyas/toolkit/gin/middleware" 14 | "github.com/go-eyas/toolkit/log" 15 | ) 16 | 17 | log.Init(&log.Config{}) 18 | 19 | // engine 20 | route.Use(middleware.Error(log.SugaredLogger)) // 如果不需要日志记录,传 nil 21 | 22 | // handler 23 | func HelloHandler(c *gin.Context) { 24 | panic("text") // {msg: "text", code: 0, data: {}} 25 | panic(gin.H{"code": 0, "msg": "some error"}) // 
{与传入的数据一致,} code 默认999999,status 默认 400,msg 默认 unknow error 26 | panic(errors.New("some error")) // {msg: "some error", code: 999999, data: {}} 27 | panic(Struct{...}) // {msg: "unknow", code: 999999, data: {...struct 数据}} 28 | } 29 | ``` 30 | 31 | #### logger 32 | 33 | 使用 zap 打印日志 34 | 35 | log.Init(&log.Config{}) 36 | 37 | ```go 38 | import ( 39 | "github.com/go-eyas/toolkit/gin/middleware" 40 | "github.com/go-eyas/toolkit/log" 41 | ) 42 | 43 | log.Init(&log.Config{}) 44 | 45 | route.Use(middleware.Ginzap(log.SugaredLogger)) 46 | ``` 47 | 48 | ## [工具函数](./util) 49 | 50 | 51 | # godoc 52 | 53 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/gin) -------------------------------------------------------------------------------- /gin/util/readme.md: -------------------------------------------------------------------------------- 1 | # 给 gin 定制的工具函数 2 | 3 | ## response 4 | 5 | 使接口的回应固定格式为 6 | 7 | ```json 8 | { 9 | "status": 0, 10 | "msg": "ok", 11 | "data": {} 12 | } 13 | ``` 14 | 15 | 使用 16 | 17 | ```go 18 | import ( 19 | "github.com/gin-gonic/gin" 20 | "github.com/go-eyas/toolkit/gin/util" 21 | ) 22 | 23 | func HelloHandler(c *gin.Context) { 24 | util.R(c).OK(gin.H{ 25 | "hello": "world", 26 | }) 27 | // 将会响应为 28 | // { 29 | // "status": 0, 30 | // "msg": "ok", 31 | // "data": { 32 | // "hello": "world", 33 | // } 34 | // } 35 | } 36 | ``` -------------------------------------------------------------------------------- /gin/util/response.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | 9 | type RData struct { 10 | Code int `json:"-"` // http 状态码 11 | Status int `json:"status"` 12 | Msg string `json:"msg"` 13 | Data interface{} `json:"data"` 14 | } 15 | 16 | // resp 回应工具 17 | type resp struct { 18 | c *gin.Context 19 | } 20 | 21 | // R 封装响应数据 22 | func R(c *gin.Context) *resp { 23 | return &resp{c} 24 | } 25 | 26 | const 
CodeSuccess = 0 27 | const CodeUnknowError = 99999999 28 | 29 | // Response 响应数据 30 | func (r resp) Response(data *RData) { 31 | c := r.c 32 | c.JSON(data.Code, data) 33 | } 34 | 35 | // parse 解析响应数据 36 | func (r resp) Parse(v interface{}) *RData { 37 | data := &RData{ 38 | Code: http.StatusOK, 39 | Msg: "ok", 40 | Status: CodeSuccess, 41 | } 42 | switch v.(type) { 43 | case error: 44 | res := v.(error) 45 | data.Code = http.StatusInternalServerError 46 | data.Msg = res.Error() 47 | data.Status = CodeUnknowError 48 | data.Data = gin.H{} 49 | 50 | case string: 51 | data.Data = v.(string) 52 | 53 | case gin.H, *gin.H, map[string]interface{}: 54 | var e gin.H 55 | if b, ok := v.(gin.H); ok { 56 | e = b 57 | } else if b, ok := v.(map[string]interface{}); ok { 58 | e = gin.H(b) 59 | } else if b, ok := v.(*gin.H); ok { 60 | e = *b 61 | } 62 | 63 | resCode := e["code"] 64 | if resCode == nil { 65 | resCode = http.StatusOK 66 | } 67 | 68 | resStatus := e["status"] 69 | if resStatus == nil { 70 | resStatus = CodeSuccess 71 | } 72 | 73 | resMsg := e["msg"] 74 | if resMsg == nil { 75 | resMsg = "ok" 76 | } else if errmsgError, ok := resMsg.(error); ok { 77 | resMsg = errmsgError.Error() 78 | } 79 | 80 | resData := e["data"] 81 | if resData == nil { 82 | resData = gin.H{} 83 | } 84 | 85 | data = &RData{ 86 | Code: resCode.(int), 87 | Status: resStatus.(int), 88 | Msg: resMsg.(string), 89 | Data: resData, 90 | } 91 | 92 | case RData, *RData: 93 | if b, ok := v.(RData); ok { 94 | data = &b 95 | } else { 96 | data = v.(*RData) 97 | } 98 | default: 99 | data.Data = v 100 | } 101 | 102 | return data 103 | } 104 | 105 | // OK 响应成功 106 | func (r resp) OK(v interface{}) { 107 | r.Response(&RData{ 108 | Code: http.StatusOK, 109 | Msg: "ok", 110 | Status: CodeSuccess, 111 | Data: v, 112 | }) 113 | } 114 | 115 | // Res 通用回应 116 | func (r resp) Res(v interface{}) { 117 | r.Response(r.Parse(v)) 118 | } 119 | 120 | // Err 回应错误 121 | func (r resp) Err(v error) { 122 | r.Res(v) 123 | } 124 | 
125 | func (r resp) Error(v interface{}) { 126 | data := r.Parse(v) 127 | if data.Code == 0 { 128 | data.Code = 500 129 | } 130 | if data.Code == http.StatusOK { 131 | data.Code = http.StatusInternalServerError 132 | } 133 | if data.Msg == "ok" { 134 | if msg, ok := data.Data.(string); ok { 135 | data.Msg = msg 136 | data.Data = gin.H{} 137 | } else { 138 | data.Msg = "unknow error" 139 | } 140 | } 141 | if data.Status == CodeSuccess { 142 | data.Status = CodeUnknowError 143 | } 144 | r.Response(data) 145 | } 146 | 147 | // Forbidden 回应禁止访问 148 | func (r resp) Forbidden(v error) { 149 | data := r.Parse(v) 150 | data.Code = http.StatusForbidden 151 | r.Response(data) 152 | } 153 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-eyas/toolkit 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/BurntSushi/toml v0.3.1 7 | github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc 8 | github.com/elazarl/goproxy v0.0.0-20200426045556-49ad98f6dac1 // indirect 9 | github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect 10 | github.com/gin-gonic/gin v1.6.3 11 | github.com/go-playground/validator/v10 v10.2.0 12 | github.com/go-redis/redis v6.15.7+incompatible 13 | github.com/go-sql-driver/mysql v1.5.0 14 | github.com/gorilla/websocket v1.4.2 15 | github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 // indirect 16 | github.com/jinzhu/configor v1.2.0 17 | github.com/jinzhu/gorm v1.9.12 18 | github.com/jonboulle/clockwork v0.1.0 // indirect 19 | github.com/jordan-wright/email v0.0.0-20200322182553-8eef2508c362 20 | github.com/lestrrat-go/file-rotatelogs v2.3.0+incompatible 21 | github.com/lestrrat-go/strftime v1.0.1 // indirect 22 | github.com/lib/pq v1.5.2 23 | github.com/mattn/go-sqlite3 v2.0.3+incompatible 24 | github.com/novalagung/gubrak v1.0.0 25 | github.com/parnurzeal/gorequest v0.2.16 26 | 
github.com/rs/xid v1.2.1 27 | github.com/smartystreets/goconvey v1.6.4 // indirect 28 | github.com/streadway/amqp v0.0.0-20200108173154-1c71cc93ed71 29 | github.com/tebeka/strftime v0.1.4 // indirect 30 | go.uber.org/zap v1.15.0 31 | golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 32 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859 33 | moul.io/http2curl v1.0.0 // indirect 34 | xorm.io/xorm v1.0.1 35 | ) 36 | -------------------------------------------------------------------------------- /http/http.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | ) 7 | 8 | var defaultRequest = New() 9 | 10 | // Type 请求提交方式,默认json 11 | func Type(name string) Request { 12 | return defaultRequest.Type(name) 13 | } 14 | 15 | // UserAgent 设置请求 user-agent,默认是 chrome 75.0 16 | func UserAgent(name string) Request { 17 | return defaultRequest.UserAgent(name) 18 | } 19 | 20 | // Cookie 设置请求 Cookie 21 | func Cookie(c *http.Cookie) Request { 22 | return defaultRequest.Cookie(c) 23 | } 24 | 25 | // Header 设置请求 Header 26 | func Header(key, val string) Request { 27 | return defaultRequest.Header(key, val) 28 | } 29 | 30 | // Proxy 设置请求代理 31 | func Proxy(url string) Request { 32 | return defaultRequest.Proxy(url) 33 | } 34 | 35 | // Query 设置请求代理 36 | func Query(query interface{}) Request { 37 | return defaultRequest.Query(query) 38 | } 39 | 40 | // Timeout 设置请求代理 41 | func Timeout(timeout time.Duration) Request { 42 | return defaultRequest.Timeout(timeout) 43 | } 44 | 45 | // UseRequest 增加请求中间件 46 | func UseRequest(mdl requestMiddlewareHandler) Request { 47 | return defaultRequest.UseRequest(mdl) 48 | } 49 | 50 | // UseResponse 增加响应中间件 51 | func UseResponse(mdl responseMidlewareHandler) Request { 52 | return defaultRequest.UseResponse(mdl) 53 | } 54 | 55 | // BaseURL 设置url前缀 56 | func BaseURL(url string) Request { 57 | return defaultRequest.BaseURL(url) 58 | } 59 | 60 | // Head 发起 head 
请求 61 | func Head(url string, query interface{}) (*Response, error) { 62 | return defaultRequest.Do("HEAD", url, query, nil, nil) 63 | } 64 | 65 | // Get 发起 get 请求, query 查询参数 66 | func Get(url string, query interface{}) (*Response, error) { 67 | return defaultRequest.Do("GET", url, query, nil, nil) 68 | } 69 | 70 | // Post 发起 post 请求,body 是请求带的参数,可使用json字符串或者结构体 71 | func Post(url string, body interface{}) (*Response, error) { 72 | return defaultRequest.Do("POST", url, nil, body, nil) 73 | } 74 | 75 | // Put 发起 put 请求,body 是请求带的参数,可使用json字符串或者结构体 76 | func Put(url string, body interface{}) (*Response, error) { 77 | return defaultRequest.Do("PUT", url, nil, body, nil) 78 | } 79 | 80 | // Del 发起 delete 请求,body 是请求带的参数,可使用json字符串或者结构体 81 | func Del(url string, body interface{}) (*Response, error) { 82 | return defaultRequest.Do("DELETE", url, nil, body, nil) 83 | } 84 | 85 | // Patch 发起 patch 请求,body 是请求带的参数,可使用json字符串或者结构体 86 | func Patch(url string, body interface{}) (*Response, error) { 87 | return defaultRequest.Do("PATCH", url, nil, body, nil) 88 | } 89 | 90 | // Options 发起 options 请求,query 查询参数 91 | func Options(url string, query interface{}) (*Response, error) { 92 | return defaultRequest.Do("OPTIONS", url, query, nil, nil) 93 | } 94 | 95 | // PostFile 发起 post 请求上传文件,将使用表单提交,file 是文件地址或者文件流, body 是请求带的参数,可使用json字符串或者结构体 96 | func PostFile(url string, file interface{}, body interface{}) (*Response, error) { 97 | return defaultRequest.Do("PUT", url, nil, body, file) 98 | } 99 | 100 | // PutFile 发起 put 请求上传文件,将使用表单提交,file 是文件地址或者文件流, body 是请求带的参数,可使用json字符串或者结构体 101 | func PutFile(url string, file interface{}, body interface{}) (*Response, error) { 102 | return defaultRequest.Do("PUT", url, nil, body, file) 103 | } 104 | -------------------------------------------------------------------------------- /http/http_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 
"time" 7 | ) 8 | 9 | func TestGet(t *testing.T) { 10 | http := New().UseRequest(func(req Request) Request { 11 | fmt.Printf("http 发送 %s %s header=%+v data=%+v\n", req.SuperAgent.Method, req.SuperAgent.Url, req.SuperAgent.Header, req.SuperAgent.Data) 12 | return req 13 | }).UseResponse(func(req Request, res *Response) *Response { 14 | fmt.Printf("http 接收 %s %s\n", req.SuperAgent.Method, req.SuperAgent.Url) 15 | return res 16 | }).Timeout(time.Second * 10). 17 | BaseURL("https://api.github.com"). 18 | BaseURL("/repos") 19 | 20 | res, err := http.Get("/eyasliu/blog/issues", map[string]interface{}{ 21 | "per_page": 1, 22 | }) 23 | if err != nil { 24 | panic(err) 25 | } 26 | s := []struct { 27 | URL string `json:"url"` 28 | Title string 29 | }{} 30 | err = res.JSON(&s) 31 | if err != nil { 32 | panic(err) 33 | } 34 | if len(s) > 0 { 35 | t.Logf("get res struct: %+v", s) 36 | } 37 | } 38 | 39 | func TestError(t *testing.T) { 40 | h := New().UseRequest(func(req Request) Request { 41 | fmt.Printf("http 发送 %s %s header=%+v data=%+v\n", req.SuperAgent.Method, req.SuperAgent.Url, req.SuperAgent.Header, req.SuperAgent.Data) 42 | return req 43 | }).UseResponse(func(req Request, res *Response) *Response { 44 | fmt.Printf("http 接收 %s %s\n", req.SuperAgent.Method, req.SuperAgent.Url) 45 | return res 46 | }) 47 | res, err := h.Header("just-test", "1234").Get("https://api.github.com/repos/eyasliu/blog/issuesx", nil) 48 | if err != nil { 49 | t.Logf("get error success: statusCode=%d body=%s error=%s", res.Status(), res.String(), err.Error()) 50 | } else { 51 | t.Fatalf("res: statusCode=%d body=%s", res.Status(), res.String()) 52 | panic("should get 404 error") 53 | } 54 | 55 | res, err = h.Get("", nil) 56 | if err != nil { 57 | t.Logf("success empty url, statusCode=%d body=%s error=%s", res.Status(), res.String(), err.Error()) 58 | } else { 59 | panic("should error") 60 | } 61 | 62 | } 63 | 64 | // func TestProxy(t *testing.T) { 65 | // h := Proxy("http://127.0.0.1:1080") 66 | // 
res, err := h.Get("https://www.google.com", map[string]string{ 67 | // "hl": "zh-Hans", 68 | // }) 69 | // if err != nil { 70 | // panic(err) 71 | // } 72 | // t.Logf("google html: %s", res.String()) 73 | // } 74 | -------------------------------------------------------------------------------- /http/readme.md: -------------------------------------------------------------------------------- 1 | # HTTP 客户端 2 | 3 | 封装http 客户端 4 | 5 | ## 使用 6 | 7 | ```go 8 | import ( 9 | "fmt" 10 | "github.com/go-eyas/toolkit/http" 11 | ) 12 | 13 | func main() { 14 | h := http.Header("Authorization", "Bearer xxxxxxxxxxxxxxx"). 15 | UserAgent("your custom user-agent"). 16 | Cookie(). 17 | BaseURL("https://api.github.com") 18 | 19 | res, err := h.Get("/repos/eyasliu/blog/issues", map[string]string{ 20 | "per_page": 1, 21 | }) 22 | 23 | // 获取字符串 24 | fmt.Printf("print string: %s\n", res.String()) 25 | // 获取字节 26 | fmt.Printf("print bytes: %v", res.Byte()) 27 | 28 | // 绑定结构体 29 | s := []struct { 30 | URL string `json:"url"` 31 | Title string `json:"title"` 32 | }{} 33 | res.JSON(&s) 34 | fmt.Printf("print Struct: %v", s) 35 | 36 | // 使用代理 37 | res, err := http.Proxy("http://127.0.0.1:1080").Get("https://www.google.com", map[string]string{ 38 | "hl": "zh-Hans", 39 | }) 40 | fmt.Printf("google html: %s", res.String()) 41 | } 42 | ``` 43 | 44 | ## 使用指南 45 | 46 | #### 请求示例 47 | 48 | ```go 49 | // get url 50 | http.Get("https://api.github.com", nil) 51 | 52 | // 带查询参数 53 | http.Get("https://www.google.com", "hl=zh-Hans") // 查询参数可以是字符串 54 | http.Get("https://www.google.com", map[string]string{ 55 | "hl": "zh-Hans", 56 | }) // 可以是map 57 | http.Get("https://www.google.com", struct{ 58 | HL string `json:"hl"` 59 | }{"zh-Hans"}) // 可以是结构体,使用json key作为查询参数的key 60 | 61 | // post 请求 62 | http.Post("https://api.github.com", nil) 63 | 64 | // post 带json参数 65 | http.Post("https://api.github.com", `{"hello": "world"}`) // 可以是字符串 66 | http.Post("https://api.github.com", map[string]interface{}{"hello": 
"world"}) // 可以是map 67 | http.Post("https://api.github.com", struct{ 68 | Hello string `json:"hello"` 69 | }{"world"}) // 可以是结构体,使用json 序列化字符串 70 | 71 | // post 带 查询参数,带json参数 72 | http.Query("hl=zh-Hans").Post("https://api.github.com", `{"hello": "world"}`) 73 | 74 | // post form表单 75 | http.Type("multipart").Post("https://api.github.com", map[string]interface{}{"hello": "world"}) 76 | // post 上传文件,会以表单提交 77 | http.PostFile("https://api.github.com", "./example_file.txt", map[string]interface{}{"hello": "world"}) 78 | 79 | // post 上传文件,使用file文件流 80 | file, _ := ioutil.ReadFile("./example_file.txt") 81 | file, _ := os.Open("./example_file.txt") 82 | http.PostFile("https://api.github.com", file, map[string]interface{}{"hello": "world"}) 83 | 84 | // put, 和post完全一致 85 | http.Put("https://api.github.com", nil) 86 | 87 | // delete, 和post完全一致 88 | http.Del("https://api.github.com", nil) 89 | 90 | // patch, 和post完全一致 91 | http.Patch("https://api.github.com", nil) 92 | 93 | // head, 和get完全一致 94 | http.Head("https://api.github.com", nil) 95 | 96 | // options, 和get完全一致 97 | http.Options("https://api.github.com", nil) 98 | ``` 99 | 100 | #### 响应示例 101 | 102 | ```go 103 | res, err := http.Options("https://api.github.com", nil) 104 | 105 | // 错误信息 106 | if err != nil { 107 | err.Error() // 错误信息 108 | } 109 | res.Err().Error() // 与上面等价 110 | 111 | // 响应数据 112 | // 将响应数据转为字符串 113 | var str string = res.String() 114 | 115 | // 将响应数据转为字节 116 | var bt []byte = res.Byte() 117 | 118 | // 获取响应状态码 119 | var statusCode = res.Status() 120 | 121 | // 获取响应的 header 122 | var http.Header = res.Header() 123 | 124 | // 获取响应的 cookies 125 | var []*http.Cookie = res.Cookies() 126 | 127 | // 与结构体绑定 128 | type ResTest struct { 129 | Hello string `json:"hello"` 130 | } 131 | rt := &ResTest{} 132 | res.JSON(rt) 133 | ``` 134 | 135 | **注意:** 136 | 137 | * http的响应状态码 >= 400 时会被视为错误,err 值是 `fmt.Errorf("http response status code %d", statusCode)` 138 | 139 | #### 提前设置通用项 140 | 141 | ```go 142 | h := 
http.Header("Authorization", "Bearer xxxxxxxxxxxxxxx"). // 设置header 143 | UserAgent("your custom user-agent"). // 设置 useragent 144 | Timeout(10 * time.Second). // 设置请求超时时间 145 | Query("lang=zh_ch"). // 设置查询参数 146 | Proxy("http://127.0.0.1:1080") // 设置代理 147 | 148 | h.Get("xxxx", nil) 149 | ``` 150 | 151 | #### 中间件支持 152 | 153 | 可以增加请求中间件和响应中间件,用于在请求或响应中改变内部操作 154 | 155 | ```go 156 | http.UseRequest(func(req *http.Request) *http.Request { 157 | fmt.Printf("http 发送 %s %s\n", req.SuperAgent.Method, req.SuperAgent.Url) 158 | return req 159 | }).UseResponse(func(req *http.Request, res *http.Response) *http.Response { 160 | fmt.Printf("http 接收 %s %s\n", req.SuperAgent.Method, req.SuperAgent.Url) 161 | return res 162 | }) 163 | ``` 164 | 165 | #### 代理设置 166 | 167 | 默认会获取环境变量 `http_proxy` 的值使用代理,但是可以手动指定 168 | 169 | ```go 170 | http.Proxy("http://127.0.0.1:1080").Get("https://www.google.com", map[string]string{ 171 | "hl": "zh-Hans", 172 | }) 173 | 174 | // 临时取消代理 175 | http.Proxy("").Get("https://www.google.com", map[string]string{ 176 | "hl": "zh-Hans", 177 | }) 178 | ``` 179 | 180 | #### 提交方式 181 | 182 | 也就是 `Type(t string)` 函数支持的值 183 | 184 | ``` 185 | "text/html" uses "html" 186 | "application/json" uses "json" 187 | "application/xml" uses "xml" 188 | "text/plain" uses "text" 189 | "application/x-www-form-urlencoded" uses "urlencoded", "form" or "form-data" 190 | ``` 191 | 192 | 如果是文件上传,则应该设置为 `multipart` 193 | 194 | ## godoc 195 | 196 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/http) -------------------------------------------------------------------------------- /http/request.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | "time" 9 | 10 | "github.com/parnurzeal/gorequest" 11 | ) 12 | 13 | type requestMiddlewareHandler func(Request) Request 14 | type responseMidlewareHandler func(Request, *Response) *Response 15 | 16 | func 
newRaw() Request { 17 | return Request{ 18 | SuperAgent: gorequest.New(), 19 | reqMdls: []requestMiddlewareHandler{}, 20 | resMdls: []responseMidlewareHandler{}, 21 | } 22 | } 23 | 24 | // New 新建请求对象,默认数据类型 json 25 | func New() Request { 26 | r := newRaw() 27 | r = r.Type("json") 28 | 29 | return r 30 | } 31 | 32 | // Request 请求结构 33 | type Request struct { 34 | SuperAgent *gorequest.SuperAgent 35 | querys []interface{} 36 | headers map[string]interface{} 37 | reqMdls []requestMiddlewareHandler 38 | resMdls []responseMidlewareHandler 39 | cookies []*http.Cookie 40 | baseURL string 41 | contentType string 42 | proxy string 43 | timeout time.Duration 44 | } 45 | 46 | func (r Request) Clone() Request { 47 | req := newRaw() 48 | req.baseURL = r.baseURL 49 | req.proxy = r.proxy 50 | req.contentType = r.contentType 51 | 52 | // query 53 | for _, query := range r.querys { 54 | req.querys = append(req.querys, query) 55 | } 56 | 57 | for _, cookie := range r.SuperAgent.Cookies { 58 | req.SuperAgent.Cookies = append(req.SuperAgent.Cookies, cookie) 59 | } 60 | 61 | // req mdl 62 | for _, mdl := range r.reqMdls { 63 | req.reqMdls = append(req.reqMdls, mdl) 64 | } 65 | 66 | // res mdl 67 | for _, mdl := range r.resMdls { 68 | req.resMdls = append(req.resMdls, mdl) 69 | } 70 | 71 | // headers 72 | req.headers = map[string]interface{}{} 73 | //for k, v := range r.SuperAgent.Header { 74 | // req.SuperAgent.Header[k] = v 75 | //} 76 | 77 | return req 78 | } 79 | 80 | // Type 请求提交方式,默认json 81 | func (r Request) Type(name string) Request { 82 | req := r.Clone() 83 | req.contentType = name 84 | return req 85 | } 86 | 87 | // UserAgent 设置请求 user-agent,默认是 chrome 75.0 88 | func (r Request) UserAgent(name string) Request { 89 | req := r.Clone() 90 | req.SuperAgent = req.SuperAgent.Set("User-Agent", name) 91 | return req 92 | } 93 | 94 | // Cookie 设置请求 Cookie 95 | func (r Request) Cookie(c *http.Cookie) Request { 96 | req := r.Clone() 97 | req.cookies = append(req.cookies, c) 98 | return 
req 99 | } 100 | 101 | // Header 设置请求 Header 102 | func (r Request) Header(key, val string) Request { 103 | req := r.Clone() 104 | req.SuperAgent = req.SuperAgent.Set(key, val) 105 | return req 106 | } 107 | 108 | // Proxy 设置请求代理 109 | func (r Request) Proxy(url string) Request { 110 | req := r.Clone() 111 | req.proxy = url 112 | return req 113 | } 114 | 115 | // Query 增加查询参数 116 | func (r Request) Query(query interface{}) Request { 117 | req := r.Clone() 118 | req.querys = append(req.querys, query) 119 | return req 120 | } 121 | 122 | // Timeout 请求超时时间 123 | func (r Request) Timeout(timeout time.Duration) Request { 124 | req := r.Clone() 125 | req.timeout = timeout 126 | return req 127 | } 128 | 129 | // UseRequest 增加请求中间件 130 | func (r Request) UseRequest(mdl requestMiddlewareHandler) Request { 131 | req := r.Clone() 132 | req.reqMdls = append(req.reqMdls, mdl) 133 | return req 134 | } 135 | 136 | // UseResponse 增加响应中间件 137 | func (r Request) UseResponse(mdl responseMidlewareHandler) Request { 138 | req := r.Clone() 139 | req.resMdls = append(req.resMdls, mdl) 140 | return req 141 | } 142 | 143 | // BaseURL 设置url前缀 144 | func (r Request) BaseURL(url string) Request { 145 | req := r.Clone() 146 | req.baseURL += url 147 | return req 148 | } 149 | 150 | // Do 发出请求,method 请求方法,url 请求地址, query 查询参数,body 请求数据,file 文件对象/地址 151 | func (r Request) Do(method, url string, args ...interface{}) (*Response, error) { 152 | var query, body, file interface{} 153 | switch len(args) { 154 | case 1: 155 | query = args[0] 156 | case 2: 157 | query = args[0] 158 | body = args[1] 159 | default: 160 | query = args[0] 161 | body = args[1] 162 | file = args[2] 163 | } 164 | 165 | r = r.Clone() 166 | // set mthod url 167 | if method == "" || url == "" { 168 | return &Response{ 169 | Request: &r, 170 | Raw: nil, 171 | Body: []byte{}, 172 | Errs: []error{errors.New("url is empty")}, 173 | }, fmt.Errorf("http url can't empty") 174 | } 175 | // r.SuperAgent = r.SuperAgent.CustomMethod(method, 
r.baseURL+url) 176 | r.SuperAgent.Method = strings.ToUpper(method) 177 | r.SuperAgent.Url = r.baseURL + url 178 | r.SuperAgent.Errors = nil 179 | 180 | if r.contentType != "" { 181 | r.SuperAgent = r.SuperAgent.Type(r.contentType) 182 | } 183 | if r.timeout > 0 { 184 | r.SuperAgent = r.SuperAgent.Timeout(r.timeout) 185 | } 186 | 187 | if r.proxy != "" { 188 | r.SuperAgent = r.SuperAgent.Proxy(r.proxy) 189 | } 190 | 191 | // set query string 192 | if query != nil { 193 | r.SuperAgent = r.SuperAgent.Query(query) 194 | } 195 | for _, q := range r.querys { 196 | r.SuperAgent = r.SuperAgent.Query(q) 197 | } 198 | 199 | // set body 200 | if body != nil { 201 | r.SuperAgent = r.SuperAgent.Send(body) 202 | } 203 | 204 | if file != nil { 205 | r.Type("multipart") 206 | r.SuperAgent = r.SuperAgent.SendFile(file) 207 | } 208 | 209 | // 执行请求中间件 210 | for _, mdl := range r.reqMdls { 211 | r1 := mdl(r) 212 | r = r1 213 | } 214 | 215 | res, resBody, errs := r.SuperAgent.EndBytes() 216 | 217 | response := &Response{ 218 | Request: &r, 219 | Raw: res, 220 | Body: resBody, 221 | Errs: errs, 222 | } 223 | 224 | // 执行响应中间件 225 | for _, mdl := range r.resMdls { 226 | response = mdl(r, response) 227 | } 228 | 229 | statusCode := response.Status() 230 | if statusCode >= 400 { 231 | response.Errs = response.Errs.Add(fmt.Errorf("http response status code %d", statusCode)) 232 | } 233 | 234 | return response, response.Err() 235 | } 236 | 237 | // Head 发起 head 请求 238 | func (r Request) Head(url string, args ...interface{}) (*Response, error) { 239 | return r.Do("HEAD", url, args...) 240 | } 241 | 242 | // Get 发起 get 请求, query 查询参数 243 | func (r Request) Get(url string, args ...interface{}) (*Response, error) { 244 | return r.Do("GET", url, args...) 245 | } 246 | 247 | // Post 发起 post 请求,body 是请求带的参数,可使用json字符串或者结构体 248 | func (r Request) Post(url string, body ...interface{}) (*Response, error) { 249 | args := append(make([]interface{}, 1), body...) 250 | return r.Do("POST", url, args...) 
251 | } 252 | 253 | // Put 发起 put 请求,body 是请求带的参数,可使用json字符串或者结构体 254 | func (r Request) Put(url string, body ...interface{}) (*Response, error) { 255 | args := append(make([]interface{}, 1), body...) 256 | return r.Do("PUT", url, args...) 257 | } 258 | 259 | // Del 发起 delete 请求,body 是请求带的参数,可使用json字符串或者结构体 260 | func (r Request) Del(url string, body ...interface{}) (*Response, error) { 261 | args := append(make([]interface{}, 1), body...) 262 | return r.Do("DELETE", url, args...) 263 | } 264 | 265 | // Patch 发起 patch 请求,body 是请求带的参数,可使用json字符串或者结构体 266 | func (r Request) Patch(url string, body ...interface{}) (*Response, error) { 267 | args := append(make([]interface{}, 1), body...) 268 | return r.Do("PATCH", url, args...) 269 | } 270 | 271 | // Options 发起 options 请求,query 查询参数 272 | func (r Request) Options(url string, args ...interface{}) (*Response, error) { 273 | return r.Do("OPTIONS", url, args...) 274 | } 275 | 276 | // PostFile 发起 post 请求上传文件,将使用表单提交,file 是文件地址或者文件流, body 是请求带的参数,可使用json字符串或者结构体 277 | func (r Request) PostFile(url string, file interface{}, body interface{}) (*Response, error) { 278 | return r.Do("PUT", url, nil, body, file) 279 | } 280 | 281 | // PutFile 发起 put 请求上传文件,将使用表单提交,file 是文件地址或者文件流, body 是请求带的参数,可使用json字符串或者结构体 282 | func (r Request) PutFile(url string, file interface{}, body interface{}) (*Response, error) { 283 | return r.Do("PUT", url, nil, body, file) 284 | } 285 | -------------------------------------------------------------------------------- /http/response.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | // NewResponse 新建回应对象 10 | func NewResponse() *Response { 11 | return &Response{} 12 | } 13 | 14 | // ResponseError 响应错误对象 15 | type ResponseError []error 16 | 17 | // Error 实现 error 接口 18 | func (e ResponseError) Error() string { 19 | errs := []error(e) 20 | s := []string{} 21 | for _, e := 
range errs { 22 | s = append(s, e.Error()) 23 | } 24 | 25 | return strings.Join(s, "\n") 26 | } 27 | 28 | // HasErr 是否有错误 29 | func (e ResponseError) HasErr() bool { 30 | if len(e) == 0 { 31 | return false 32 | } 33 | return true 34 | } 35 | 36 | // Add 增加错误 37 | func (e ResponseError) Add(err error) ResponseError { 38 | e = append(e, err) 39 | return e 40 | } 41 | 42 | // Response 回应对象 43 | type Response struct { 44 | Request *Request 45 | Raw *http.Response 46 | Body []byte 47 | Errs ResponseError 48 | } 49 | 50 | // Err 获取响应错误 51 | func (r *Response) Err() error { 52 | if r.Errs.HasErr() { 53 | return r.Errs 54 | } 55 | return nil 56 | } 57 | 58 | // JSON 根据json绑定结构体 59 | func (r *Response) JSON(v interface{}) error { 60 | return json.Unmarshal(r.Body, v) 61 | } 62 | 63 | // String 获取响应字符串 64 | func (r *Response) String() string { 65 | return string(r.Body) 66 | } 67 | 68 | // Byte 获取响应字节 69 | func (r *Response) Byte() []byte { 70 | return r.Body 71 | } 72 | 73 | // Status 获取响应状态码 74 | func (r *Response) Status() int { 75 | if r.Raw != nil { 76 | return r.Raw.StatusCode 77 | } 78 | return 0 79 | } 80 | 81 | // Header 获取响应header 82 | func (r *Response) Header() http.Header { 83 | if r.Raw != nil { 84 | return r.Raw.Header 85 | } 86 | return nil 87 | } 88 | 89 | // Cookies 获取响应 cookie 90 | func (r *Response) Cookies() []*http.Cookie { 91 | if r.Raw != nil { 92 | return r.Raw.Cookies() 93 | } 94 | return nil 95 | } 96 | 97 | // IsError 是否响应错误 98 | func (r *Response) IsError() bool { 99 | return r.Raw.StatusCode >= 400 && r.Err() != nil 100 | } 101 | -------------------------------------------------------------------------------- /http/v2/http.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import "time" 4 | 5 | var defClient = New().Type("json") 6 | 7 | 8 | // Type 请求提交方式,默认json 9 | func Type(name string) *Client { 10 | return defClient.Type(name) 11 | } 12 | 13 | // UserAgent 设置请求 user-agent 14 | func 
UserAgent(name string) *Client { 15 | return defClient.UserAgent(name) 16 | } 17 | 18 | // Cookie 设置请求 Cookie 19 | func Cookie(k string, v string) *Client { 20 | return defClient.Cookie(k, v) 21 | } 22 | 23 | // Header 设置请求 Header 24 | func Header(key, val string) *Client { 25 | return defClient.Header(key, val) 26 | } 27 | 28 | // Proxy 设置请求代理 29 | func Proxy(url string) *Client { 30 | return defClient.Proxy(url) 31 | } 32 | 33 | // Query 设置请求代理 34 | func Query(query interface{}) *Client { 35 | return defClient.Query(query) 36 | } 37 | 38 | // Timeout 设置请求代理 39 | func Timeout(timeout time.Duration) *Client { 40 | return defClient.Timeout(timeout) 41 | } 42 | 43 | // UseRequest 增加请求中间件 44 | func Use(mdl ClientMiddleware) *Client { 45 | return defClient.Use(mdl) 46 | } 47 | 48 | // UseResponse 增加响应中间件 49 | func UseResponse(mdl responseMiddlewareHandler) *Client { 50 | return defClient.TransformResponse(mdl) 51 | } 52 | 53 | // BaseURL 设置url前缀 54 | func BaseURL(url string) *Client { 55 | return defClient.BaseURL(url) 56 | } 57 | 58 | // BaseURL 设置url前缀 59 | func Retry(n int) *Client { 60 | return defClient.Retry(n) 61 | } 62 | 63 | // Head 发起 head 请求 64 | func Head(url string, data ...interface{}) (*Response, error) { 65 | return defClient.Head(url, data...) 66 | } 67 | 68 | // Get 发起 get 请求, query 查询参数 69 | func Get(url string, data ...interface{}) (*Response, error) { 70 | return defClient.Get(url, data...) 71 | } 72 | 73 | // Post 发起 post 请求,body 是请求带的参数,可使用json字符串或者结构体 74 | func Post(url string, data ...interface{}) (*Response, error) { 75 | return defClient.Post(url, data...) 76 | } 77 | 78 | // Put 发起 put 请求,body 是请求带的参数,可使用json字符串或者结构体 79 | func Put(url string, data ...interface{}) (*Response, error) { 80 | return defClient.Put(url, data...) 81 | } 82 | 83 | // Del 发起 delete 请求,body 是请求带的参数,可使用json字符串或者结构体 84 | func Del(url string, data ...interface{}) (*Response, error) { 85 | return defClient.Del(url, data...) 
86 | } 87 | 88 | // Patch 发起 patch 请求,body 是请求带的参数,可使用json字符串或者结构体 89 | func Patch(url string, data ...interface{}) (*Response, error) { 90 | return defClient.Patch(url, data...) 91 | } 92 | 93 | // Options 发起 options 请求,query 查询参数 94 | func Options(url string, data ...interface{}) (*Response, error) { 95 | return defClient.Options(url, data...) 96 | } 97 | -------------------------------------------------------------------------------- /http/v2/http_test.go: -------------------------------------------------------------------------------- 1 | package http_test 2 | 3 | import ( 4 | "github.com/go-eyas/toolkit/log" 5 | "github.com/go-eyas/toolkit/http/v2" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestGet(t *testing.T) { 11 | h := http.Use(http.AccessLogger(log.SugaredLogger)).Timeout(time.Second * 10). 12 | BaseURL("https://api.github.com"). 13 | BaseURL("/repos") 14 | 15 | res, err := h.Get("/eyasliu/blog/issues", map[string]interface{}{ 16 | "per_page": 1, 17 | }) 18 | if err != nil { 19 | panic(err) 20 | } 21 | var s []struct { 22 | URL string `json:"url"` 23 | Title string 24 | } 25 | err = res.JSON(&s) 26 | if err != nil { 27 | panic(err) 28 | } 29 | if len(s) > 0 { 30 | t.Logf("get res struct: %+v", s) 31 | } 32 | } 33 | 34 | func TestError(t *testing.T) { 35 | h := http.New().Use(http.AccessLogger(log.SugaredLogger)) 36 | res, err := h.Header("just-test", "1234"). 37 | Header("func-header", func() string {return "val in fn"}). 
38 | Get("https://api.github.com/repos/eyasliu/blog/issuesx") 39 | if err == nil { 40 | panic("should get 404 error") 41 | } 42 | 43 | res, err = h.Header("a", "1").Get("", nil) 44 | if err != nil { 45 | t.Logf("success empty url, statusCode=%d body=%s error=%s", res.Status(), res.String(), err.Error()) 46 | } else { 47 | panic("should error") 48 | } 49 | 50 | res, err = h.Header("b", "2").Type("form").Post("https://api.github.com/repos/eyasliu/blog/issuesx", map[string]interface{}{"hello": "test"}) 51 | if err != nil { 52 | t.Logf("success empty url, statusCode=%d body=%s error=%s", res.Status(), res.String(), err.Error()) 53 | } else { 54 | panic("should error") 55 | } 56 | h.Safe(false) 57 | h.Header("c", "3") 58 | res, err = h.Header("d", "4").Post("https://api.github.com/repos/eyasliu/blog/issuesx", map[string]interface{}{"hello": "test"}) 59 | if err != nil { 60 | t.Logf("success empty url, statusCode=%d body=%s error=%s", res.Status(), res.String(), err.Error()) 61 | } else { 62 | panic("should error") 63 | } 64 | h = h.Safe(true).BaseURL("http://notexistdomain.qwer") 65 | h.Type("form").Post("/", `file=@file:./http.go`) 66 | 67 | } 68 | 69 | // func TestProxy(t *testing.T) { 70 | // h := Proxy("http://127.0.0.1:1080") 71 | // res, err := h.Get("https://www.google.com", map[string]string{ 72 | // "hl": "zh-Hans", 73 | // }) 74 | // if err != nil { 75 | // panic(err) 76 | // } 77 | // t.Logf("google html: %s", res.String()) 78 | // } 79 | 80 | -------------------------------------------------------------------------------- /http/v2/readme.md: -------------------------------------------------------------------------------- 1 | # HTTP 客户端 2 | 3 | 封装http 客户端 4 | 5 | ## 使用 6 | 7 | ```go 8 | import ( 9 | "fmt" 10 | "github.com/go-eyas/toolkit/http/v2" 11 | ) 12 | 13 | func main() { 14 | h := http. 
15 | New().TransformRequest(func(client *http.Client, req *nethttp.Request) *http.Client { 16 | fmt.Printf("HTTP SEND %s %s header=%v\n", req.Method, req.URL, req.Header) 17 | return client 18 | }). 19 | TransformResponse(func (c *http.Client, req *nethttp.Request, resp *http.Response) *http.Response { 20 | fmt.Printf("HTTP RECV %s %s %d\n", req.Method, req.URL, resp.StatusCode()) 21 | if resp.StatusCode() >= 400 { 22 | resp.SetBody([]byte("error! error!")) 23 | } 24 | return resp 25 | }). 26 | Type("json"). 27 | Use(http.AccessLogger(logger)). 28 | Header("Authorization", "Bearer xxxxxxxxxxxxxxx"). 29 | Header("x-test", func() string { return "in func string" }). 30 | Header("x-test2", func(cli *http.Client) string { return "in func string2" }). 31 | UserAgent("your custom user-agent"). 32 | Cookie("sid", "sgf2fdas"). 33 | BaseURL("https://api.github.com"). 34 | Config(&http.ClientConfig{ 35 | BaseURL: "/api", // 叠加 36 | }) 37 | 38 | res, err := h.Get("/repos/eyasliu/blog/issues", map[string]string{ 39 | "per_page": "1", 40 | }) 41 | 42 | // 获取字符串 43 | fmt.Printf("print string: %s\n", res.String()) 44 | // 获取字节 45 | fmt.Printf("print bytes: %v", res.Byte()) 46 | 47 | // 绑定结构体 48 | s := []struct { 49 | URL string `json:"url"` 50 | Title string `json:"title"` 51 | }{} 52 | res.JSON(&s) 53 | fmt.Printf("print Struct: %v", s) 54 | 55 | // 使用代理 56 | res, err = http.Proxy("http://127.0.0.1:1080").Get("https://www.google.com", map[string]string{ 57 | "hl": "zh-Hans", 58 | }) 59 | fmt.Printf("google html: %s", res.String()) 60 | } 61 | ``` 62 | 63 | ## 使用指南 64 | 65 | #### 请求示例 66 | 67 | ```go 68 | // get url, 第二个参数可忽略 69 | http.Get("https://api.github.com") 70 | 71 | // 带查询参数 72 | http.Get("https://www.google.com", "hl=zh-Hans") // 查询参数可以是字符串 73 | http.Get("https://www.google.com", map[string]string{ 74 | "hl": "zh-Hans", 75 | }) // 可以是map 76 | http.Get("https://www.google.com", struct{ 77 | HL string `json:"hl"` 78 | }{"zh-Hans"}) // 可以是结构体,使用json key作为查询参数的key 79 | 80 | // post
请求,第二个参数可忽略 81 | http.Post("https://api.github.com") 82 | 83 | // post 带json body参数 84 | http.Post("https://api.github.com", `{"hello": "world"}`) // 可以是字符串 85 | http.Post("https://api.github.com", map[string]interface{}{"hello": "world"}) // 可以是map 86 | http.Post("https://api.github.com", struct{ 87 | Hello string `json:"hello"` 88 | }{"world"}) // 可以是结构体,使用json 序列化字符串 89 | 90 | // post 带 查询参数,带json body 参数 91 | http.Query("hl=zh-Hans").Post("https://api.github.com", `{"hello": "world"}`) 92 | 93 | // post form表单 94 | http.Type("multipart").Post("https://api.github.com", map[string]interface{}{"hello": "world"}) 95 | // post 上传文件,会以表单提交 96 | http.Post("https://api.github.com", "name=@file:./example_file.txt") 97 | 98 | // post 上传多文件,使用file文件流 99 | http.Post("https://api.github.com", "name=@file:./example_file.txt&name=@file:./example_file.txt") 100 | 101 | // put, 和post完全一致 102 | http.Put("https://api.github.com") 103 | 104 | // delete, 和post完全一致 105 | http.Del("https://api.github.com") 106 | 107 | // patch, 和post完全一致 108 | http.Patch("https://api.github.com") 109 | 110 | // head, 和get完全一致 111 | http.Head("https://api.github.com") 112 | 113 | // options, 和get完全一致 114 | http.Options("https://api.github.com") 115 | ``` 116 | 117 | #### 响应示例 118 | 119 | ```go 120 | res, err := http.Options("https://api.github.com", nil) 121 | 122 | // 错误信息 123 | if err != nil { 124 | err.Error() // 错误信息 125 | } 126 | res.Error() // 与上面等价 127 | 128 | // 响应数据 129 | // 将响应数据转为字符串 130 | var str string = res.String() 131 | 132 | // 将响应数据转为字节 133 | var bt []byte = res.Byte() 134 | 135 | // 获取响应状态码 136 | var statusCode = res.Status() 137 | 138 | // 获取响应的 header 139 | var headers nethttp.Header = res.Header() 140 | 141 | // 获取响应的 cookies 142 | var cookies []*nethttp.Cookie = res.Cookie() 143 | 144 | // 与 json 结构体绑定 145 | type ResTest struct { 146 | Hello string `json:"hello"` 147 | } 148 | rt := &ResTest{} 149 | res.JSON(rt) 150 | 151 | // 与 xml 结构体绑定 152 | type
ResTest struct { 153 | Hello string `xml:"hello"` 154 | } 155 | rt := &ResTest{} 156 | res.XML(rt) 157 | ``` 158 | 159 | **注意:** 160 | 161 | * http的响应状态码 >= 400 时会被视为错误,err 值是 `fmt.Errorf("http status code %d", statusCode)` 162 | 163 | #### 链式安全调用 164 | 165 | 默认是链式安全的,即在链式调用的时候,返回的 `*http.Client` 是个新的实例,不会影响之前链式阶段的配置,如 166 | 167 | ```go 168 | cli := http.BaseURL("http://xxx.com") 169 | cli2 := cli.BaseURL("/api") // 需要重新给 cli 赋值才会让其生效 170 | 171 | cli.Get("/users") // GET http://xxx.com/users 172 | cli2.Get("/users") // GET http://xxx.com/api/users 173 | ``` 174 | 175 | 如果不希望链式安全调用,可以关闭 176 | 177 | ```go 178 | cli := http.New().Safe(false) // 关闭后要赋值一次 179 | 180 | // 下面的赋值都将生效,链式安全关闭后,赋值和不赋值没有区别 181 | cli.BaseURL("http://xxx.com") 182 | cli = cli.BaseURL("/api") // 可赋值,可也不赋值,没有区别 183 | 184 | cli.Get("/users") // GET http://xxx.com/api/users 185 | 186 | // 可以后面再进行开启 187 | cli = cli.Safe(true) 188 | cli = cli.BaseURL("/v1") // 开启后如果不赋值将不会生效 189 | 190 | ``` 191 | 192 | #### 提前设置通用项 193 | 194 | ```go 195 | h := http.Header("Authorization", "Bearer xxxxxxxxxxxxxxx"). // 设置header 196 | UserAgent("your custom user-agent"). // 设置 useragent 197 | Timeout(10 * time.Second). // 设置请求超时时间 198 | Query("lang=zh_ch"). // 设置查询参数 199 | Proxy("http://127.0.0.1:1080") // 设置代理 200 | 201 | 202 | h.Get("xxxx", nil) 203 | ``` 204 | 205 | #### 中间件支持 206 | 207 | 可以增加请求中间件和响应中间件,用于在请求或响应中改变内部操作 208 | 209 | ```go 210 | http.New().TransformRequest(func(client *http.Client, req *nethttp.Request) *http.Client { 211 | fmt.Printf("HTTP SEND %s %s header=%v\n", req.Method, req.URL, req.Header) 212 | return client 213 | }). 214 | TransformResponse(func (c *http.Client, req *nethttp.Request, resp *http.Response) *http.Response { 215 | fmt.Printf("HTTP RECV %s %s %d\n", req.Method, req.URL, resp.StatusCode()) 216 | if resp.StatusCode() >= 400 { 217 | resp.SetBody([]byte("error!
error!")) 218 | } 219 | return resp 220 | }) 221 | ``` 222 | 223 | #### 代理设置 224 | 225 | 默认会获取环境变量 `http_proxy` 的值使用代理,但是可以手动指定 226 | 227 | ```go 228 | http.Proxy("http://127.0.0.1:1080").Get("https://www.google.com", map[string]string{ 229 | "hl": "zh-Hans", 230 | }) 231 | 232 | // 临时取消代理 233 | http.Proxy("").Get("https://www.google.com", map[string]string{ 234 | "hl": "zh-Hans", 235 | }) 236 | ``` 237 | 238 | #### 提交方式 239 | 240 | 也就是 `Type(t string)` 函数支持的值 241 | 242 | ``` 243 | "text/html" uses "html" 244 | "application/json" uses "json" 245 | "application/xml" uses "xml" 246 | "text/plain" uses "text" 247 | "application/x-www-form-urlencoded" uses "urlencoded", "form" or "form-data" 248 | ``` 249 | 250 | 如果是文件上传,会自动设置为 `multipart`,无需手动指定 251 | 252 | ## godoc 253 | 254 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/http/v2) 255 | 256 | ## Tranks 257 | 258 | 部分 api 和代码借鉴了以下库 259 | 260 | * [GoFrame](https://goframe.org/) 部分api,思路借鉴自 GoFrame 261 | * axios 中间件思路来源于 axios -------------------------------------------------------------------------------- /http/v2/request_api.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | ) 7 | 8 | // 批量设置 HTTP Client 的配置项 9 | type ClientConfig struct { 10 | TransformRequest requestMiddlewareHandler // 请求中间件 11 | TransformResponse responseMiddlewareHandler // 响应中间件 12 | Headers map[string]string // 预设 Header 13 | Cookies map[string]string // 预设请求 Cookie 14 | Type string // 预设请求数据类型, 该配置为空时忽略 15 | UserAgent string // 预设 User-Agent, 该配置为空时忽略 16 | Proxy string // 预设请求代理,该配置为空时忽略 17 | BaseURL string // 预设 url 前缀,叠加 18 | Query interface{} // 预设请求查询参数 19 | Timeout time.Duration // 请求超时时间,该配置为 0 时忽略 20 | Retry int // 重试次数,该配置为 0 时忽略 21 | } 22 | 23 | // Config 批量设置 HTTP Client 的配置项 24 | func (c *Client) Config(conf *ClientConfig) *Client { 25 | cli := c.getSetting() 26 | 27 | if conf.TransformRequest != nil { 28 | cli = 
cli.TransformRequest(conf.TransformRequest) 29 | } 30 | 31 | if conf.TransformResponse != nil { 32 | cli = cli.TransformResponse(conf.TransformResponse) 33 | } 34 | 35 | if len(conf.Headers) > 0 { 36 | for k, v := range conf.Headers { 37 | cli = cli.Header(k, v) 38 | } 39 | } 40 | 41 | if len(conf.Cookies) > 0 { 42 | for k, v := range conf.Cookies { 43 | cli = cli.Cookie(k, v) 44 | } 45 | } 46 | 47 | if conf.Type != "" { 48 | cli = cli.Type(conf.Type) 49 | } 50 | if conf.UserAgent != "" { 51 | cli = cli.UserAgent(conf.UserAgent) 52 | } 53 | 54 | if conf.Proxy != "" { 55 | cli = cli.Proxy(conf.Proxy) 56 | } 57 | 58 | if conf.BaseURL != "" { 59 | cli = cli.BaseURL(conf.BaseURL) 60 | } 61 | 62 | if conf.Query != nil { 63 | cli = cli.Query(conf.Query) 64 | } 65 | 66 | if conf.Timeout > 0 { 67 | cli = cli.Timeout(conf.Timeout) 68 | } 69 | if conf.Retry > 0 { 70 | cli = cli.Retry(conf.Retry) 71 | } 72 | 73 | return cli 74 | } 75 | 76 | func (c *Client) SetClient(rawClient http.Client) *Client { 77 | cli := c.getSetting() 78 | cli.Client = rawClient 79 | return cli 80 | } 81 | 82 | // Header 设置请求 Header 83 | func (c *Client) Header(k string, v interface{}) *Client { 84 | cli := c.getSetting() 85 | cli.headers[k] = v 86 | return cli 87 | } 88 | 89 | // TransformRequest 增加请求中间件,可以在请求发起前对整个请求做前置处理,比如修改 body, header, proxy, url, 配置项等等,也可以获取请求的各种数据,如自定义日志,类似于 axios 的 transformRequest 90 | func (c *Client) TransformRequest(h requestMiddlewareHandler) *Client { 91 | cli := c.getSetting() 92 | cli.reqMdls = append(cli.reqMdls, h) 93 | return cli 94 | } 95 | 96 | // TransformResponse 增加响应中间件,可以在收到请求后第一时间对请求做处理,如验证 status code,验证 body 数据,甚至重置 body 数据,更改响应等等任何操作,类似于 axios 的 transformResponse 97 | func (c *Client) TransformResponse(h responseMiddlewareHandler) *Client { 98 | cli := c.getSetting() 99 | cli.resMdls = append(cli.resMdls, h) 100 | return cli 101 | } 102 | 103 | type ClientMiddleware interface { 104 | TransformRequest(*Client, *http.Request) *Client 105 | 
TransformResponse(*Client, *http.Request, *Response) *Response 106 | } 107 | 108 | // Use 应用中间件,实现了 TransformRequest 和 TransformResponse 接口的中间件,如 http.Use(http.AccessLogger()),通常用于成对的请求响应处理 109 | func (c *Client) Use(mdl ClientMiddleware) *Client { 110 | cli := c.getSetting() 111 | cli = cli.TransformRequest(mdl.TransformRequest) 112 | cli = cli.TransformResponse(mdl.TransformResponse) 113 | return cli 114 | } 115 | 116 | // Type 请求提交方式,默认json 117 | func (c *Client) Type(ty string) *Client { 118 | cli := c.getSetting() 119 | ct, ok := Types[ty] 120 | if !ok { 121 | ct = Types[TypeJSON] 122 | } 123 | cli = cli.Header(headerContentTypeKey, ct) 124 | return cli 125 | } 126 | 127 | // UserAgent 设置请求 user-agent 128 | func (c *Client) UserAgent(name string) *Client { 129 | cli := c.getSetting() 130 | cli = cli.Header("User-Agent", name) 131 | return cli 132 | } 133 | 134 | // Cookie 设置请求 Cookie 135 | func (c *Client) Cookie(k string, v string) *Client { 136 | cli := c.getSetting() 137 | cli.cookies[k] = v 138 | return cli 139 | } 140 | 141 | // Proxy 设置请求代理 142 | func (c *Client) Proxy(url string) *Client { 143 | cli := c.getSetting() 144 | cli.proxy = url 145 | return cli 146 | } 147 | 148 | // Query 增加查询参数, 如果设置过多次,将会叠加拼接 149 | func (c *Client) Query(query interface{}) *Client { 150 | cli := c.getSetting() 151 | cli.queryArgs = append(cli.queryArgs, query) 152 | return cli 153 | } 154 | 155 | // Timeout 请求超时时间 156 | func (c *Client) Timeout(timeout time.Duration) *Client { 157 | cli := c.getSetting() 158 | cli.timeout = timeout 159 | return cli 160 | } 161 | 162 | // BaseURL 设置url前缀 163 | func (c *Client) BaseURL(url string) *Client { 164 | cli := c.getSetting() 165 | cli.baseURL += url 166 | return cli 167 | } 168 | 169 | // BaseURL 设置url前缀 170 | func (c *Client) Retry(n int) *Client { 171 | cli := c.getSetting() 172 | cli.retryCount = n 173 | return cli 174 | } 175 | 176 | // Head 发起 head 请求 177 | func (c *Client) Head(url string, data ...interface{}) (*Response, 
error) { 178 | cli := c.getSetting() 179 | return cli.DoRequest("HEAD", url, data...) 180 | } 181 | 182 | // Get 发起 get 请求 183 | func (c *Client) Get(url string, args ...interface{}) (*Response, error) { 184 | cli := c.getSetting() 185 | return cli.DoRequest("GET", url, args...) 186 | } 187 | 188 | // Post 发起 post 请求 189 | func (c *Client) Post(url string, args ...interface{}) (*Response, error) { 190 | cli := c.getSetting() 191 | return cli.DoRequest("POST", url, args...) 192 | } 193 | 194 | // Put 发起 put 请求 195 | func (c *Client) Put(url string, args ...interface{}) (*Response, error) { 196 | cli := c.getSetting() 197 | return cli.DoRequest("Put", url, args...) 198 | } 199 | 200 | // Del 发起 del 请求 201 | func (c *Client) Del(url string, args ...interface{}) (*Response, error) { 202 | cli := c.getSetting() 203 | return cli.DoRequest("DELETE", url, args...) 204 | } 205 | 206 | // Patch 发起 patch 请求 207 | func (c *Client) Patch(url string, args ...interface{}) (*Response, error) { 208 | cli := c.getSetting() 209 | return cli.DoRequest("Patch", url, args...) 210 | } 211 | 212 | // Options 发起 get 请求 213 | func (c *Client) Options(url string, args ...interface{}) (*Response, error) { 214 | cli := c.getSetting() 215 | return cli.DoRequest("Options", url, args...) 
216 | } 217 | -------------------------------------------------------------------------------- /http/v2/response.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "encoding/xml" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | ) 10 | 11 | type ResponseError struct { 12 | errs []error 13 | } 14 | func (e *ResponseError) Error() string { 15 | msgs := "" 16 | for _, err := range e.errs { 17 | if len(msgs) > 0 { 18 | msgs += "\n" 19 | } 20 | msgs += err.Error() 21 | } 22 | return msgs 23 | } 24 | 25 | func (e *ResponseError) Add(err error) { 26 | if err != nil { 27 | e.errs = append(e.errs, err) 28 | } 29 | } 30 | 31 | type Response struct { 32 | Client *Client 33 | Request *http.Request 34 | Response *http.Response 35 | body []byte 36 | Err *ResponseError 37 | IsRead bool 38 | } 39 | 40 | func newResponse(request *Client, r *http.Request) *Response { 41 | return &Response{ 42 | Client: request, 43 | Request: r, 44 | Err: &ResponseError{errs: make([]error, 0)}, 45 | } 46 | } 47 | 48 | func (rp *Response) ready() { 49 | //if rp.Response == nil { 50 | // return 51 | //} 52 | if code := rp.StatusCode(); code >= 400 { 53 | rp.AddError(fmt.Errorf("http status code %d", code)) 54 | } 55 | } 56 | 57 | // StatusCode 获取 HTTP 响应状态码 58 | func (rp *Response) StatusCode() int { 59 | if rp.Response == nil { 60 | return 0 61 | } 62 | return rp.Response.StatusCode 63 | } 64 | 65 | // Status StatusCode 别名 66 | func (rp *Response) Status() int { 67 | return rp.StatusCode() 68 | } 69 | 70 | // GetError 获取错误 71 | func (rp *Response) GetError() error { 72 | if len(rp.Err.errs) > 0 { 73 | return rp.Err 74 | } 75 | return nil 76 | } 77 | 78 | // AddError 手动增加错误 79 | func (rp *Response) AddError(err error) { 80 | rp.Err.Add(err) 81 | } 82 | 83 | // ReadAllBody 读取响应的 Body 流 84 | func (rp *Response) ReadAllBody() (bt []byte, err error) { 85 | if rp.Response != nil && !rp.IsRead { 86 | bt, err = 
ioutil.ReadAll(rp.Response.Body) 87 | rp.body = bt 88 | rp.IsRead = true 89 | return 90 | } 91 | bt = rp.body 92 | return 93 | } 94 | 95 | // Body 获取响应的原始数据 96 | func (rp *Response) Body() (bt []byte) { 97 | bt, _ = rp.ReadAllBody() 98 | return 99 | } 100 | 101 | // Byte 同 Body() 102 | func (rp *Response) Byte() (bt []byte) { 103 | bt, _ = rp.ReadAllBody() 104 | return 105 | } 106 | 107 | // SetBody 重置响应的 Body 108 | func (rp *Response) SetBody(bt []byte) { 109 | rp.IsRead = true 110 | rp.body = bt 111 | } 112 | 113 | // String 将响应的 Body 数据转成字符串 114 | func (rp *Response) String() string { 115 | if !rp.IsRead { rp.ReadAllBody() } 116 | return string(rp.Body()) 117 | } 118 | 119 | // Error 实现 error interface 120 | func (rp *Response) Error() string { 121 | return rp.Err.Error() 122 | } 123 | 124 | func (rp *Response) Header() http.Header { 125 | if rp.Response != nil { 126 | return rp.Response.Header 127 | } 128 | return http.Header{} 129 | } 130 | 131 | func (rp *Response) Cookie() []*http.Cookie { 132 | if rp.Response != nil { 133 | return rp.Response.Cookies() 134 | } 135 | return nil 136 | } 137 | 138 | // JSON 使用 JSON 解析响应 Body 数据 139 | func (rp *Response) JSON(v interface{}) error { 140 | if !rp.IsRead { rp.ReadAllBody() } 141 | return json.Unmarshal(rp.body, v) 142 | } 143 | 144 | // XML 使用 XML 解析响应 Body 数据 145 | func (rp *Response) XML(v interface{}) error { 146 | if !rp.IsRead { rp.ReadAllBody() } 147 | return xml.Unmarshal(rp.body, v) 148 | } 149 | 150 | -------------------------------------------------------------------------------- /http/v2/transform.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | ) 9 | 10 | type printLogger interface{ 11 | Debug(...interface{}) 12 | } 13 | 14 | type consoleLogger struct {} 15 | func (consoleLogger) Debug(a ...interface{}) { 16 | fmt.Println(a...) 
17 | } 18 | 19 | // 打印 HTTP 请求响应日志 20 | type HttpLogger struct { 21 | Logger printLogger 22 | MaxLoggerBodyLen int64 23 | } 24 | 25 | func AccessLogger(logger printLogger, bodyLimit ...int64) *HttpLogger { 26 | var limit int64 = 2048 27 | if len(bodyLimit) > 0 { 28 | limit = bodyLimit[0] 29 | } 30 | return &HttpLogger{Logger: logger, MaxLoggerBodyLen: limit} 31 | } 32 | 33 | var Logger = AccessLogger(consoleLogger{}) 34 | 35 | // 打印 HTTP 请求日志 36 | func (l *HttpLogger) TransformRequest(c *Client, req *http.Request) *Client { 37 | var body []byte 38 | logtext := fmt.Sprintf("HTTP SEND %s %s header=%v", req.Method, req.URL, req.Header) 39 | if req.Method != "GET" && req.Body != nil { 40 | logtext = fmt.Sprintf("%s size=%d", logtext, req.ContentLength) 41 | // 如果body太长,估计是文件上传,不打印,也不侵入,并且太长的body也会妨碍控制台输出 42 | if req.ContentLength < l.MaxLoggerBodyLen { 43 | body, _ = ioutil.ReadAll(req.Body) 44 | req.Body = ioutil.NopCloser(bytes.NewReader(body)) 45 | logtext = fmt.Sprintf("%s body=%s", logtext, string(body)) 46 | } 47 | } 48 | l.Logger.Debug(logtext) 49 | return c 50 | } 51 | 52 | // 打印 HTTP 响应日志 53 | // warning: 会先读取一遍 Response.Body,如果该中间件导致了 http 下载异常问题,请关闭该中间件 54 | func (l *HttpLogger) TransformResponse(c *Client, req *http.Request, resp *Response) *Response { 55 | logText := fmt.Sprintf("HTTP RECV %s %s %d", req.Method, req.URL, resp.StatusCode()) 56 | if resp.IsRead { 57 | resp.ReadAllBody() 58 | } 59 | logText += " " + resp.String() 60 | l.Logger.Debug(logText) 61 | return resp 62 | } 63 | -------------------------------------------------------------------------------- /http/v2/types.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | // Types we support. 
const (
	TypeJSON       = "json"
	TypeXML        = "xml"
	TypeUrlencoded = "urlencoded"
	TypeForm       = "form"
	TypeFormData   = "form-data"
	TypeHTML       = "html"
	TypeText       = "text"
	TypeMultipart  = "multipart"
)

// Types maps the short type names accepted by Client.Type to MIME content
// types. "form", "form-data" and "urlencoded" are aliases of one another.
var Types = map[string]string{
	TypeJSON:       "application/json",
	TypeXML:        "application/xml",
	TypeForm:       "application/x-www-form-urlencoded",
	TypeFormData:   "application/x-www-form-urlencoded",
	TypeUrlencoded: "application/x-www-form-urlencoded",
	TypeHTML:       "text/html",
	TypeText:       "text/plain",
	TypeMultipart:  "multipart/form-data",
}

// headerContentTypeKey is the header Client.Type writes the MIME type into.
var headerContentTypeKey = "Content-Type"

// ---- /http/v2/util.go ----
// package http
// imports: "encoding/json", "fmt", "net/url", "os", "strconv"

// toMap converts v into a generic map by round-tripping it through JSON, so
// struct fields are keyed by their `json` tags. Errors are deliberately
// ignored: a value that does not encode to a JSON object yields an empty map
// (or nil when v encodes to JSON null).
func toMap(v interface{}) map[string]interface{} {
	m := map[string]interface{}{}
	bt, _ := json.Marshal(v)
	_ = json.Unmarshal(bt, &m)
	return m
}

// toString renders v as a string: strings are returned as-is, byte slices
// are converted, and anything else is formatted with %v.
func toString(v interface{}) string {
	switch s := v.(type) {
	case string:
		return s
	case []byte:
		return string(s)
	default:
		return fmt.Sprintf("%v", v)
	}
}

// toUrlEncoding merges every argument into one URL-encoded query string
// (keys sorted by url.Values.Encode). Each argument is flattened with toMap,
// so maps and structs (via their json tags) are accepted. Strings, numbers
// and booleans are encoded directly; an array contributes one key=value pair
// per scalar element; nil values and nested objects are skipped.
func toUrlEncoding(data ...interface{}) string {
	if len(data) == 0 {
		return ""
	}
	urlVals := url.Values{}
	for _, q := range data {
		for k, v := range toMap(q) {
			if list, ok := v.([]interface{}); ok {
				for _, item := range list {
					if s, ok := urlQueryScalar(item); ok {
						urlVals.Add(k, s)
					}
				}
				continue
			}
			if s, ok := urlQueryScalar(v); ok {
				urlVals.Add(k, s)
			}
		}
	}
	return urlVals.Encode()
}

// urlQueryScalar renders a JSON-decoded scalar as a query value. It reports
// false for nil, objects and arrays, which have no single-value form.
func urlQueryScalar(v interface{}) (string, bool) {
	switch x := v.(type) {
	case string:
		return x, true
	case float64:
		// JSON numbers always decode as float64; 'f' formatting avoids the
		// exponent form %v produces for large values (1700000000 -> "1.7e+09").
		return strconv.FormatFloat(x, 'f', -1, 64), true
	case bool:
		return strconv.FormatBool(x), true
	default:
		return "", false
	}
}

// fileExist reports whether p exists. Any Stat error other than "does not
// exist" (e.g. permission denied) still reports true, so a later open
// surfaces the real error.
func fileExist(p string) bool {
	_, err := os.Stat(p)
	return !os.IsNotExist(err)
}
-------------------------------------------------------------------------------- /log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "path/filepath" 8 | "time" 9 | 10 | rotatelogs "github.com/lestrrat-go/file-rotatelogs" 11 | "go.uber.org/zap" 12 | "go.uber.org/zap/zapcore" 13 | ) 14 | 15 | var defaultLogger = zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), zapcore.AddSync(os.Stdout), zap.DebugLevel)) 16 | 17 | var SugaredLogger = defaultLogger.Sugar() 18 | var Logger = defaultLogger 19 | 20 | // LogConfig 日志配置 21 | type LogConfig struct { 22 | Level string // 日志级别 23 | Path string // 路径 24 | Name string // 文件名称 25 | Console bool // 是否输出到控制台 26 | MaxAge time.Duration // 保存多久的日志,默认15天 27 | RotationTime time.Duration // 多久分割一次日志 28 | Caller bool // 是否打印文件行号 29 | SplitLevel bool // 是否把不同级别的日志打到不同文件 30 | } 31 | 32 | var printCaller = false 33 | 34 | // Init 初始化日志库 35 | func Init(conf *LogConfig) error { 36 | // 默认保存最近15天日志 37 | if conf.MaxAge == 0 { 38 | conf.MaxAge = time.Hour * 24 * 15 39 | } 40 | if conf.RotationTime == 0 { 41 | conf.RotationTime = time.Hour 42 | } 43 | printCaller = conf.Caller 44 | return newLog(conf) 45 | } 46 | 47 | func newLog(conf *LogConfig) error { 48 | // 建立日志目录 49 | if err := os.MkdirAll(conf.Path+"/", os.ModePerm); err != nil { 50 | fmt.Println("init log path error.") 51 | return err 52 | } 53 | // 设置一些基本日志格式 具体含义还比较好理解,直接看zap源码也不难懂 54 | encoder := zapcore.NewConsoleEncoder(zapcore.EncoderConfig{ 55 | MessageKey: "msg", 56 | LevelKey: "level", 57 | EncodeLevel: zapcore.CapitalLevelEncoder, 58 | TimeKey: "ts", 59 | EncodeTime: func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { 60 | enc.AppendString(t.Format("2006-01-02 15:04:05.000")) 61 | }, 62 | CallerKey: "file", 63 | EncodeCaller: zapcore.ShortCallerEncoder, 64 | EncodeDuration: func(d time.Duration, enc zapcore.PrimitiveArrayEncoder) { 65 
| enc.AppendInt64(int64(d) / 1000000) 66 | }, 67 | }) 68 | 69 | level := new(zapcore.Level) 70 | err := level.Set(conf.Level) 71 | if err != nil { 72 | return err 73 | } 74 | 75 | lv := *level 76 | 77 | cores := []zapcore.Core{} 78 | 79 | if conf.SplitLevel { 80 | if lv <= zapcore.DebugLevel { 81 | debugLevel := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 82 | return lvl <= zapcore.DebugLevel 83 | }) 84 | 85 | debugWriter, err := getWriter(conf.Path+"/"+conf.Name+"_debug", conf) 86 | if err != nil { 87 | return err 88 | } 89 | cores = append(cores, zapcore.NewCore(encoder, zapcore.AddSync(debugWriter), debugLevel)) 90 | } 91 | if lv <= zapcore.InfoLevel { 92 | infoLevel := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 93 | return lvl < zapcore.WarnLevel && lvl > zapcore.DebugLevel 94 | }) 95 | 96 | infoWriter, err := getWriter(conf.Path+"/"+conf.Name+"_info", conf) 97 | if err != nil { 98 | return err 99 | } 100 | 101 | cores = append(cores, zapcore.NewCore(encoder, zapcore.AddSync(infoWriter), infoLevel)) 102 | } 103 | 104 | if lv <= zapcore.WarnLevel { 105 | warnLevel := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 106 | return lvl >= zapcore.WarnLevel 107 | }) 108 | warnWriter, err := getWriter(conf.Path+"/"+conf.Name+"_error", conf) 109 | if err != nil { 110 | return err 111 | } 112 | cores = append(cores, zapcore.NewCore(encoder, zapcore.AddSync(warnWriter), warnLevel)) 113 | } else { 114 | // 级别是 error 以上 115 | errorLevel := lv 116 | errorWriter, err := getWriter(conf.Path+"/"+conf.Name+"_error", conf) 117 | if err != nil { 118 | return err 119 | } 120 | cores = append(cores, zapcore.NewCore(encoder, zapcore.AddSync(errorWriter), errorLevel)) 121 | } 122 | } else { 123 | writer, err := getWriter(conf.Path+"/"+conf.Name, conf) 124 | if err != nil { 125 | return err 126 | } 127 | cores = append(cores, zapcore.NewCore(encoder, zapcore.AddSync(writer), lv)) 128 | } 129 | 130 | if conf.Console { 131 | consoleWriter := os.Stdout 132 | cores = 
append(cores, zapcore.NewCore(encoder, zapcore.AddSync(consoleWriter), lv)) 133 | } 134 | 135 | // 最后创建具体的Logger 136 | core := zapcore.NewTee(cores...) 137 | 138 | Logger = zap.New(core) 139 | SugaredLogger = Logger.Sugar() 140 | 141 | return nil 142 | } 143 | 144 | func getWriter(filename string, conf *LogConfig) (io.Writer, error) { 145 | hook, err := rotatelogs.New( 146 | filename+".%Y-%m-%d/%H.log", 147 | rotatelogs.WithLinkName(filepath.Join(conf.Path, filename+".log")), 148 | rotatelogs.WithMaxAge(conf.MaxAge), 149 | rotatelogs.WithRotationTime(conf.RotationTime), 150 | ) 151 | 152 | if err != nil { 153 | return nil, err 154 | } 155 | return hook, nil 156 | } 157 | -------------------------------------------------------------------------------- /log/log_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestLog(t *testing.T) { 8 | err := Init(&LogConfig{ 9 | Level: "debug", 10 | Path: ".logs", 11 | Name: "api", 12 | Console: true, 13 | Caller: true, 14 | SplitLevel: true, 15 | }) 16 | if err != nil { 17 | panic(err) 18 | } 19 | 20 | Debug("is debug log") 21 | Info("is info log") 22 | Warn("is warn log") 23 | Error("is error log") 24 | // Panic("is panic log") 25 | 26 | Debugf("is debug log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 27 | Infof("is info log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 28 | Warnf("is warn log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 29 | Errorf("is error log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 30 | // Fatalf("is fatal log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 31 | // Panicf("is panic log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 32 | } 33 | -------------------------------------------------------------------------------- /log/logger.go: 
-------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | ) 7 | 8 | func getCaller() string { 9 | if !printCaller { 10 | return "" 11 | } 12 | funcName, _, line, ok := runtime.Caller(2) 13 | caller := "" 14 | if ok { 15 | caller = fmt.Sprintf("%s:%d", runtime.FuncForPC(funcName).Name(), line) 16 | } 17 | 18 | return caller 19 | } 20 | 21 | // Debugf 格式化日志 22 | func Debugf(s string, v ...interface{}) { 23 | if printCaller { 24 | SugaredLogger.Debugf("%s %s", getCaller(), fmt.Sprintf(s, v...)) 25 | } else { 26 | SugaredLogger.Debugf(s, v...) 27 | } 28 | } 29 | 30 | // Infof 格式化日志 31 | func Infof(s string, v ...interface{}) { 32 | if printCaller { 33 | SugaredLogger.Infof("%s %s", getCaller(), fmt.Sprintf(s, v...)) 34 | } else { 35 | SugaredLogger.Infof(s, v...) 36 | } 37 | } 38 | 39 | // Warnf 格式化日志 40 | func Warnf(s string, v ...interface{}) { 41 | if printCaller { 42 | SugaredLogger.Warnf("%s %s", getCaller(), fmt.Sprintf(s, v...)) 43 | } else { 44 | SugaredLogger.Warnf(s, v...) 45 | } 46 | } 47 | 48 | // Errorf 格式化日志 49 | func Errorf(s string, v ...interface{}) { 50 | if printCaller { 51 | SugaredLogger.Errorf("%s %s", getCaller(), fmt.Sprintf(s, v...)) 52 | } else { 53 | SugaredLogger.Errorf(s, v...) 54 | } 55 | } 56 | 57 | // Fatalf 格式化日志 58 | func Fatalf(s string, v ...interface{}) { 59 | if printCaller { 60 | SugaredLogger.Fatalf("%s %s", getCaller(), fmt.Sprintf(s, v...)) 61 | } else { 62 | SugaredLogger.Fatalf(s, v...) 63 | } 64 | } 65 | 66 | // Panicf 格式化日志 67 | func Panicf(s string, v ...interface{}) { 68 | if printCaller { 69 | SugaredLogger.Panicf("%s %s", getCaller(), fmt.Sprintf(s, v...)) 70 | } else { 71 | SugaredLogger.Panicf(s, v...) 72 | } 73 | } 74 | 75 | // Debug 打日志 76 | func Debug(v ...interface{}) { 77 | if printCaller { 78 | SugaredLogger.Debugf("%s %s", getCaller(), fmt.Sprint(v...)) 79 | } else { 80 | SugaredLogger.Debug(v...) 
81 | } 82 | } 83 | 84 | // Info 打日志 85 | func Info(v ...interface{}) { 86 | if printCaller { 87 | SugaredLogger.Infof("%s %s", getCaller(), fmt.Sprint(v...)) 88 | } else { 89 | SugaredLogger.Info(v...) 90 | } 91 | } 92 | 93 | // Warn 打日志 94 | func Warn(v ...interface{}) { 95 | if printCaller { 96 | SugaredLogger.Warnf("%s %s", getCaller(), fmt.Sprint(v...)) 97 | } else { 98 | SugaredLogger.Warn(v...) 99 | } 100 | } 101 | 102 | // Error 打日志 103 | func Error(v ...interface{}) { 104 | if printCaller { 105 | SugaredLogger.Errorf("%s %s", getCaller(), fmt.Sprint(v...)) 106 | } else { 107 | SugaredLogger.Error(v...) 108 | } 109 | } 110 | 111 | // Panic 打日志 112 | func Panic(v ...interface{}) { 113 | if printCaller { 114 | SugaredLogger.Panicf("%s %s", getCaller(), fmt.Sprint(v...)) 115 | } else { 116 | SugaredLogger.Panic(v...) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /log/readme.md: -------------------------------------------------------------------------------- 1 | # 日志库 2 | 3 | * 日志滚动,保存15天内的日志,每1小时(整点)分割一次日志(可配置) 4 | * 日志文件分级保存 debug, info, error 5 | * 可选是否输出日志到控制台 6 | * 支持输出打日志的文件文件行号 7 | 8 | ## 使用 9 | 10 | 只支持单例log 11 | 12 | ```go 13 | 14 | // 使用前必须先初始化 15 | log.Init(&log.LogConfig{ 16 | Level: "info", // 日志级别 17 | Path: ".runtime/logs", // 日志保存路径 18 | Name: "api", // 日志文件名 19 | Console: true, // 是否把日志输出到控制台 20 | Caller: true, // 是否输出打日志的文件和行号,会影响性能 21 | MaxAge: time.Hour * 24 * 15, // 保存多久的日志,默认15天 22 | RotationTime: time.Hour, // 多久分割一次日志,默认一小时 23 | SplitLevel: true, // 是否把不同级别的日志打到不同文件,如果为false 则所有级别日志打到同一个文件 24 | }) 25 | 26 | log.Debug("is debug log") 27 | log.Info("is info log") 28 | log.Warn("is warn log") 29 | log.Error("is error log") 30 | log.Panic("is panic log") 31 | 32 | log.Debugf("is debug log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 33 | log.Infof("is info log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 34 | log.Warnf("is warn log %s %d %v", "string", 
123, map[string]string{"test": "hello"}) 35 | log.Errorf("is error log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 36 | log.Fatalf("is fatal log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 37 | log.Panicf("is panic log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 38 | 39 | ``` 40 | 41 | ## godoc 42 | 43 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/log) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # go 工具箱 2 | 3 | 为了快速使用通用功能,做一次通用封装 4 | 5 | # 使用 6 | 7 | 8 | ``` 9 | go get -u -v github.com/go-eyas/toolkit 10 | ``` 11 | 12 | # [HTTP 客户端 http](./http) 13 | 14 | ```go 15 | import "github.com/go-eyas/toolkit/http" 16 | 17 | github := http.BaseURL("https://api.github.com") 18 | res, err := github.Get("/repos/eyasliu/blog/issues") 19 | var data interface{} 20 | res.JSON(&data) 21 | ``` 22 | 23 | # [日志 log](./log) 24 | 25 | ```go 26 | import "github.com/go-eyas/toolkit/log" 27 | 28 | log.Init(&log.Config{}) 29 | log.Info("log init ok") 30 | log.Infof("is info log %s %d %v", "string", 123, map[string]string{"test": "hello"}) 31 | ``` 32 | 33 | # [Redis](./redis) 34 | 35 | ```go 36 | import "github.com/go-eyas/toolkit/redis" 37 | 38 | r, err := redis.New(&redis.Config{ 39 | Cluster: false, // 是否集群 40 | Addrs: []string{"127.0.0.1:6379"}, // redis 地址,如果是集群则在数组上写多个元素 41 | Password: "", 42 | DB: 1, 43 | }) 44 | r.Set("toolkit:test", `{"hello": "world"}`) 45 | str, err := r.Get("toolkit:test") 46 | 47 | ``` 48 | 49 | # [TCP](./tcp) 50 | 51 | **server** 52 | 53 | ```go 54 | import "github.com/go-eyas/toolkit/tcp" 55 | 56 | func main() { 57 | server, err := tcp.NewServer(&tcp.Config{ 58 | Network: "tcp", // 网络类型,不填默认 tcp 59 | // tcp 监听地址 60 | Addr: "127.0.0.1:6600", 61 | 62 | // 私有协议实现,不传将使用默认的私有协议实现 63 | // Parser: func(*Conn, []byte) ([][]byte, error) {}, 64 | // Packer: func([]byte)
([]byte, error) {}, 65 | }) 66 | 67 | // 接收数据 68 | ch := server.Receive() 69 | for data := range ch { 70 | fmt.Printf("server receive: %v", data.Data) 71 | 72 | // 服务器收到数据后,响应发送一条数据到客户端 73 | err := data.Response([]byte("server receive your message")) 74 | } 75 | 76 | // 给所有连接都发送消息 77 | for connID, conn := range server.Sockets { 78 | fmt.Println("connID: ", connID) 79 | server.Send(conn, []byte("broadcast some message")) 80 | // or 81 | // server.SendConnID(connID, []byte("broadcast some message")) 82 | } 83 | } 84 | ``` 85 | 86 | **client** 87 | 88 | ```go 89 | import "github.com/go-eyas/toolkit/tcp" 90 | 91 | func main() { 92 | client, err := tcp.NewClient(&tcp.Config{ 93 | Network: "tcp", // 网络类型,不填默认 tcp 94 | // tcp 服务端地址 95 | Addr: "127.0.0.1:6600", 96 | 97 | // 私有协议实现,不传将使用默认的私有协议实现 98 | // Parser: func(*Conn, []byte) ([][]byte, error) {}, 99 | // Packer: func([]byte) ([]byte, error) {}, 100 | }) 101 | 102 | // 接收数据 103 | ch := client.Receive() 104 | go func() { 105 | for msg := range ch { 106 | // msg.Data 经过 Parser 处理过的数据 107 | // msg.Conn tcp 连接实例 108 | fmt.Println("client receive:", string(msg.Data)) 109 | } 110 | }() 111 | 112 | // 发送数据,send 后将立马把数据传给 Packer 处理后,再发送到 tcp 连接 113 | err = client.Send([]byte("hello world1")) 114 | } 115 | ``` 116 | 117 | 还有一个开箱即用的 [tcp 服务](./tcp/tcpsrv) 118 | 119 | # [长连接 Websocket](./websocket) 120 | 121 | ```go 122 | import "github.com/go-eyas/toolkit/websocket" 123 | 124 | ws := websocket.New(&websocket.Config{}) 125 | http.HandleFunc("/ws", ws.HTTPHandler) 126 | go func() { 127 | rec := ws.Receive() 128 | for { 129 | req, _ := <-rec 130 | req.Response([]byte("1234556")) 131 | } 132 | }() 133 | http.ListenAndServe("127.0.0.1:8800", nil) 134 | ``` 135 | 136 | 还有一个类似 http api 的[开箱即用服务](./websocket/wsrv) 137 | 138 | # [RabbitMQ amqp](./amqp) 139 | 140 | ```go 141 | import "github.com/go-eyas/toolkit/amqp" 142 | 143 | mq := amqp.New(&amqp.Config{ 144 | Addr: "amqp://guest:guest@127.0.0.1:5672", 145 | ExchangeName:
"toolkit.exchange.test", 146 | }) 147 | queue := &amqp.Queue{Name: "toolkit.queue.test"} 148 | err := mq.Pub(queue, &amqp.Message{Data: []byte("{\"hello\":\"world\"}")}) 149 | 150 | msgch, err := mq.Sub(queue) 151 | for msg := range msgch { 152 | fmt.Printf("%s", string(msg.Data)) 153 | } 154 | 155 | ``` 156 | 157 | # [配置项 config](./config) 158 | 159 | ```go 160 | import "github.com/go-eyas/toolkit/config" 161 | 162 | conf := struct { 163 | Host string 164 | Port int 165 | }{} 166 | config.Init("config", &conf) 167 | ``` 168 | 169 | # [数据库 ORM](./db) 170 | 171 | ```go 172 | import "github.com/go-eyas/toolkit/db" 173 | 174 | var db *gorm.DB = db.Gorm(&db.Config{"mysql", "username:password@127.0.0.1:3306/test"}) 175 | var db *xorm.Engine = db.Xorm(&db.Config{"mysql", "username:password@127.0.0.1:3306/test"}) 176 | 177 | defer db.Close() 178 | ``` 179 | 180 | # [资源模型CRUD](./db/resource) 181 | 182 | 资源自动 crud 183 | 184 | ```go 185 | import "github.com/go-eyas/toolkit/db" 186 | import "github.com/go-eyas/toolkit/db/resource" 187 | 188 | type Article struct { 189 | ID int64 `resource:"pk;search:none"` 190 | Title string `resource:"create;update;search:like"` 191 | Status byte `resource:"search:="` 192 | } 193 | 194 | var r, db, err = resource.New(&db.Config{"mysql", "username:password@127.0.0.1:3306/test"}, Article{}) 195 | 196 | r.Create(map[string]string{"title": "hello eyas"}) // 增 197 | r.Delete(1) // 删 198 | r.Update(1, map[string]int{"status": 1}) // 改 199 | 200 | // 查,指定主键 201 | var m = &Article{} 202 | r.Detail(1, m) 203 | 204 | // 查,指定查询条件查列表 205 | var list = []*Article{} 206 | r.List(&list, map[string]interface{}{"title": "he"}, []string{"id DESC"}) 207 | 208 | 209 | ``` 210 | 211 | # [Gin 中间件 & 工具](./gin) 212 | 213 | ```go 214 | import "github.com/go-eyas/toolkit/gin/util" // 工具函数 215 | import "github.com/go-eyas/toolkit/gin/middleware" // 中间件 216 | ``` 217 | 218 | # [事件分发器 Emitter](./emit) 219 | 220 | ```go 221 | import "github.com/go-eyas/toolkit/emit" 222 | 
fn1 := func(data interface{}) { 223 | fmt.Printf("fn1 receive data: %v", data) 224 | } 225 | 226 | emit.On("evt", fn1).Off("evt", fn1) 227 | emit.Emit("evt", "hello emitter") 228 | 229 | ``` 230 | 231 | # [邮件发送 Email](./email) 232 | 233 | ```go 234 | import ( 235 | "github.com/go-eyas/toolkit/email" 236 | "github.com/BurntSushi/toml" 237 | ) 238 | 239 | func ExampleSample() { 240 | tomlConfig := ` 241 | host = "smtp.qq.com" 242 | port = "465" 243 | account = "893521870@qq.com" 244 | password = "haha, wo cai bu gao su ni ne" 245 | name = "unit test" 246 | secure = true 247 | [tpl.a] 248 | bcc = ["Jeason "] # 抄送 249 | cc = [] # 抄送人 250 | subject = "Welcome, {{.Name}}" # 主题 251 | text = "Hello, I am {{.Name}}" # 文本 252 | html = "

Hello, I am {{.Name}}

" # html 内容 253 | ` 254 | conf := &Config{} 255 | toml.Decode(tomlConfig, conf) 256 | email := New(conf) 257 | email.SendByTpl("Yuesong Liu ", "a", struct{ Name string }{"Batman"}) 258 | } 259 | ``` 260 | 261 | # [工具函数 util](./util) 262 | 263 | ```go 264 | import "github.com/go-eyas/toolkit/util" 265 | ``` 266 | -------------------------------------------------------------------------------- /redis/cmd.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/go-eyas/toolkit/types" 8 | "github.com/go-redis/redis" 9 | ) 10 | 11 | var redisSetMu sync.Mutex 12 | 13 | func (r *RedisClient) Expire(key string, expiration time.Duration) (bool, error) { 14 | c := r.Client.Expire(r.Prefix+key, expiration) 15 | return c.Result() 16 | } 17 | 18 | // Get 获取字符串值 19 | func (r *RedisClient) Get(key string) (string, error) { 20 | v, err := r.Client.Get(r.Prefix + key).Result() 21 | if err == redis.Nil { 22 | err = nil 23 | } 24 | return v, err 25 | } 26 | 27 | // Set 设置字符串值,有效期默认 24 小时 28 | func (r *RedisClient) Set(key string, value interface{}, expiration ...time.Duration) error { 29 | redisSetMu.Lock() 30 | defer redisSetMu.Unlock() 31 | expire := RedisTTL 32 | if len(expiration) > 0 { 33 | expire = expiration[0] 34 | } 35 | s := value 36 | cmd := r.Client.Set(r.Prefix+key, s, expire) 37 | return cmd.Err() 38 | } 39 | 40 | // Del 删除键 41 | func (r *RedisClient) Del(keys ...string) error { 42 | ks := make([]string, len(keys)) 43 | for i, k := range keys { 44 | ks[i] = r.Prefix + k 45 | } 46 | cmd := r.Client.Del(ks...) 
47 | return cmd.Err() 48 | } 49 | 50 | // HGet 获取 Hash 的字段值 51 | func (r *RedisClient) HGet(key string, field string) (string, error) { 52 | cmd := r.Client.HGet(r.Prefix+key, field) 53 | // log.Debugf("redis get hash key=%s, field=%s", r.Prefix+key, field) 54 | v, err := cmd.Result() 55 | if err == redis.Nil { 56 | return "", nil 57 | } 58 | return v, err 59 | } 60 | 61 | // HGetAll 获取 Hash 的所有字段 62 | func (r *RedisClient) HGetAll(key string) (map[string]string, error) { 63 | cmd := r.Client.HGetAll(r.Prefix + key) 64 | // log.Debugf("redis get all hash key=%s", r.Prefix+key) 65 | v, err := cmd.Result() 66 | mp := make(map[string]string) 67 | if err == redis.Nil { 68 | return mp, nil 69 | } 70 | for k, sv := range v { 71 | mp[k] = sv 72 | } 73 | return mp, err 74 | } 75 | 76 | // HSet 设置hash值 77 | func (r *RedisClient) HSet(key, field string, val interface{}, expiration ...time.Duration) error { 78 | redisSetMu.Lock() 79 | defer redisSetMu.Unlock() 80 | cmd := r.Client.HSet(r.Prefix+key, field, val) 81 | 82 | expire := RedisTTL 83 | if len(expiration) > 0 { 84 | expire = expiration[0] 85 | } 86 | r.Expire(key, expire) 87 | // log.Debugf("redis set hash key=%s, field=%s", r.Prefix+key, field) 88 | return cmd.Err() 89 | } 90 | 91 | // HDel 删除hash的键 92 | func (r *RedisClient) HDel(key string, field ...string) error { 93 | k := key 94 | cmd := r.Client.HDel(r.Prefix+k, field...) 
95 | // log.Debugf("redis set hash key=%s, field=%s", r.Prefix+k, field) 96 | err := cmd.Err() 97 | if err != nil { 98 | return err 99 | } 100 | // 是否键全删完了,如果是就清理掉这个key 101 | length, err := r.Client.HLen(r.Prefix + k).Result() 102 | if err != nil { 103 | return err 104 | } 105 | if length == 0 { 106 | if err = r.Del(k); err != nil { 107 | return err 108 | } 109 | } 110 | 111 | return nil 112 | } 113 | 114 | type Message struct { 115 | Channel string 116 | Pattern string 117 | Payload string 118 | } 119 | 120 | // JSON 绑定json对象 121 | func (msg *Message) JSON(v interface{}) error { 122 | return types.JSONString(msg.Payload).JSON(v) 123 | } 124 | 125 | // Sub 监听通道,有数据时触发回调 handler 126 | // example: 127 | // redis.Sub("chat")(func(msg *redis.Message) { 128 | // fmt.Printf("receive message: %#v", msg) 129 | // }) 130 | func (r *RedisClient) Sub(channel string, handler func(*Message)) { 131 | pb := r.Client.Subscribe(channel) 132 | ch := pb.Channel() 133 | 134 | for msg := range ch { 135 | handler(&Message{msg.Channel, msg.Pattern, msg.Payload}) 136 | } 137 | 138 | defer pb.Close() 139 | } 140 | 141 | // Pub 发布事件 142 | // example: 143 | // Redis.Pub("chat", "this is a test message") 144 | func (r *RedisClient) Pub(channel string, msg string) error { 145 | cmd := r.Client.Publish(channel, msg) 146 | _, err := cmd.Result() 147 | return err 148 | } 149 | -------------------------------------------------------------------------------- /redis/readme.md: -------------------------------------------------------------------------------- 1 | # redis 2 | 3 | 4 | ## 初始化 5 | 只支持单例 redis 6 | 7 | ```go 8 | import "github.com/go-eyas/toolkit/redis" 9 | 10 | func main() { 11 | // 使用前必须先初始化 12 | r, err := redis.Init(&redis.Config{ 13 | Cluster: false, // 是否集群 14 | Addrs: []string{"10.0.3.252:6379"}, // redis 地址,如果是集群则在数组上写多个元素 15 | Password: "", 16 | DB: 1, 17 | }) 18 | if err != nil { 19 | panic(err) 20 | } 21 | 22 | err = r.Set("tookit:test", `{"hello": "world"}`) 23 | 24 | v, err = 
r.Get("tookit:test") 25 | fmt.Printf("v: %s", v) // v: {"hello": "world"} 26 | 27 | err = r.Del("tookit:test") 28 | 29 | 30 | r.Expire("tookit:test", time.Hour * 24) 31 | 32 | r // *redis.RedisClient 33 | r.Client // redis.UniversalClient from github.com/go-redis/redis 34 | } 35 | ``` 36 | 37 | ## godoc 38 | 39 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/redis) -------------------------------------------------------------------------------- /redis/redis.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/go-eyas/toolkit/log" 8 | "github.com/go-redis/redis" 9 | ) 10 | // Config configures New; Addrs must not be empty and DB is only used for a non-cluster client. 11 | type Config struct { 12 | Cluster bool 13 | Addrs []string 14 | Password string 15 | DB int 16 | Prefix string 17 | } 18 | 19 | // redisClientInterface lists the capabilities a redis client instance provides. 20 | // type redisClientInterface interface { 21 | // redis.Cmdable 22 | // Subscribe(...string) *redis.PubSub 23 | // Close() error 24 | // } 25 | 26 | // RedisClient redis client wrapper 27 | type RedisClient struct { 28 | isCluster bool 29 | Namespace string 30 | Client redis.UniversalClient 31 | Prefix string 32 | } 33 | 34 | // RedisTTL is the default expiration (24 hours) applied by Set and HSet. 35 | var RedisTTL = time.Hour * 24 36 | 37 | // Redis is the exposed redis wrapper (currently disabled) 38 | // var Redis *RedisClient 39 | 40 | // redis client instance (currently disabled) 41 | // var Client redis.UniversalClient 42 | 43 | // New creates a RedisClient from redisConf (a cluster client when Cluster is set, otherwise a single-node client on Addrs[0]) and checks connectivity with PING. It fails on an empty Addrs list; when PING fails it returns the client together with the error so the caller can still Close it. 44 | func New(redisConf *Config) (*RedisClient, error) { 45 | r := &RedisClient{} 46 | r.isCluster = redisConf.Cluster 47 | r.Prefix = redisConf.Prefix 48 | 49 | if len(redisConf.Addrs) == 0 { 50 | return nil, errors.New("empty addrs") 51 | } 52 | 53 | if redisConf.Cluster { 54 | r.Client = redis.NewClusterClient(&redis.ClusterOptions{ 55 | Addrs: redisConf.Addrs, 56 | Password: redisConf.Password, 57 | }) 58 | } else { 59 | r.Client = redis.NewClient(&redis.Options{ 60 | Addr: redisConf.Addrs[0], 61 | Password: redisConf.Password, 62 | DB: redisConf.DB, 63 | }) 64 | }
65 | _, err := r.Client.Ping().Result() 66 | if err != nil { 67 | log.Errorf("redis 连接失败, err=%v", err) 68 | return r, err 69 | } 70 | // Redis = r 71 | // Client = r.Client 72 | return r, nil 73 | } 74 | 75 | // Close closes the underlying client; it does nothing on a nil receiver or an unset Client. 76 | func (r *RedisClient) Close() { 77 | if r != nil && r.Client != nil { 78 | r.Client.Close() 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /redis/redis_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/go-eyas/toolkit/types" 8 | ) 9 | 10 | func TestRedisConnError(t *testing.T) { 11 | _, err := New(&Config{}) 12 | if err != nil { 13 | t.Log("redis conn error success") 14 | } else { 15 | t.Fatal("empty addrs should error") 16 | } 17 | } 18 | func TestRedis(t *testing.T) { 19 | r, err := New(&Config{ 20 | Cluster: false, 21 | Addrs: []string{"10.0.2.252:6379"}, 22 | DB: 1, 23 | Prefix: "test:prefix:", 24 | }) 25 | if err != nil { 26 | t.Skipf("redis unavailable, skipping integration test: %v", err) 27 | } 28 | 29 | key := "tookit:test" 30 | val := `{"hello": "world"}` 31 | data := "world" 32 | 33 | // test Set 34 | err = r.Set(key, val) 35 | if err != nil { 36 | t.Fatalf("set redis fail: %v", err) 37 | } 38 | // test Get 39 | v, err := r.Get(key) 40 | if err != nil { 41 | t.Fatalf("get redis fail: %v", err) 42 | } 43 | 44 | // test Bind json 45 | res := struct { 46 | Hello string `json:"hello"` 47 | }{} 48 | err = types.JSONString(v).JSON(&res) 49 | 50 | if err == nil && res.Hello == data { 51 | t.Log("get redis success") 52 | } else { 53 | t.Fatalf("bind json fail: err=%v, got %+v", err, res) 54 | } 55 | 56 | // test Del 57 | err = r.Del(key) 58 | if err != nil { 59 | t.Fatalf("del redis key error: %v", err) 60 | } 61 | v, err = r.Get(key) 62 | if err == nil && v == "" { 63 | t.Log("del key success") 64 | } else { 65 | t.Fatalf("get after del: v=%q err=%v", v, err) 66 | } 67 | 68 | // test sub/pub 69 | pbChan := make(chan *Message) 70 | go r.Sub("tookit-pub", func(msg *Message) { 71 | t.Logf("sub receive: %v", msg)
72 | pbChan <- msg 73 | }) 74 | <-time.After(time.Second) 75 | err = r.Pub("tookit-pub", val) 76 | if err != nil { 77 | t.Fatal(err) 78 | } else { 79 | t.Logf("pub success") 80 | } 81 | 82 | msg := <-pbChan 83 | 84 | res2 := struct { 85 | Hello string `json:"hello"` 86 | }{} 87 | err = msg.JSON(&res2) 88 | if err != nil || res2.Hello != data { 89 | t.Fatalf("sub receive wrong: err=%v, got %+v", err, res2) 90 | } else { 91 | t.Logf("sub receive success: %v", res2) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /tcp/client.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "time" 7 | ) 8 | 9 | // Client is a TCP client that reconnects automatically after its connection drops. 10 | type Client struct { 11 | Conn *Conn 12 | Config *Config 13 | recChan chan *Message // decoded packets (output of Parser) 14 | socketCount uint64 15 | autoReconnect bool 16 | 17 | createConnHandlers []connHandler // run whenever a connection is established 18 | closeConnHandlers []connHandler // run whenever a connection is closed 19 | 20 | closeNotify chan *Conn // wakes the reconnect goroutine: a closed conn, or nil from Destroy 21 | isConnect bool // whether a connection is currently established 22 | } 23 | 24 | // NewClient dials conf.Addr over conf.Network (defaulting Network to "tcp" and Packer/Parser to the built-in protocol) and starts the auto-reconnect watcher. 25 | func NewClient(conf *Config) (*Client, error) { 26 | var defaultParsePoll map[uint64][]byte 27 | if conf.Packer == nil && conf.Parser == nil { 28 | conf.Packer = Packer 29 | defaultParsePoll, conf.Parser = Parser() 30 | } else if conf.Packer == nil || conf.Parser == nil { 31 | return nil, errors.New("the Packer and Parser must be specified together") 32 | } 33 | 34 | if conf.Logger == nil { 35 | conf.Logger = EmptyLogger 36 | } 37 | 38 | if conf.Network == "" { 39 | conf.Network = "tcp" 40 | } 41 | 42 | 43 | client := &Client{ 44 | autoReconnect: true, 45 | Config: conf, 46 | recChan: make(chan *Message, 2), 47 | closeNotify: make(chan *Conn), 48 | } 49 | 50 | client.HandleCreate(func(conn *Conn) { 51 | client.isConnect = true 52 | }) 53 | 54 | // On close: mark disconnected, drop the parse buffer and wake the reconnect goroutine. 55 | client.HandleClose(func(conn *Conn) { 56 | client.isConnect = false
57 | delete(defaultParsePoll, conn.ID) 58 | if client.autoReconnect { 59 | client.closeNotify <- conn 60 | } 61 | }) 62 | 63 | err := client.connect() 64 | 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | go client.reconnect() 70 | 71 | return client, nil 72 | } 73 | 74 | // connect dials the server, installs the new Conn and starts reading from it. 75 | func (c *Client) connect() error { 76 | dial, err := net.Dial(c.Config.Network, c.Config.Addr) 77 | if err != nil { 78 | return err 79 | } 80 | c.socketCount++ 81 | conn := &Conn{Conn: dial, ID: c.socketCount} 82 | c.Conn = conn 83 | conn.client = c 84 | 85 | go c.reader() 86 | return nil 87 | } 88 | 89 | // reconnect waits for the current connection to close, then redials once per second until it succeeds or Destroy turns autoReconnect off. 90 | func (c *Client) reconnect() { 91 | // Wait until the close handler reports a dropped connection. Destroy 92 | // also sends nil here (after clearing autoReconnect) so that a parked 93 | // goroutine is released instead of leaking. 94 | <-c.closeNotify 95 | 96 | // Re-check autoReconnect before every attempt so that Destroy also 97 | // stops a retry loop that is already running; otherwise a destroyed 98 | // client could be reconnected. 99 | for c.autoReconnect { 100 | time.Sleep(1 * time.Second) 101 | if err := c.connect(); err != nil { 102 | // dial failed; retry after the next sleep 103 | continue 104 | } 105 | // connected: watch the new connection with a fresh goroutine 106 | go c.reconnect() 107 | return 108 | } 109 | // autoReconnect is off: exit without redialing. 110 | } 111 | 112 | // reader starts the connection's read loop and runs the create handlers. 113 | func (c *Client) reader() { 114 | go c.Conn.reader() 115 | for _, h := range c.createConnHandlers { 116 | h(c.Conn) 117 | } 118 | } 119 | 120 | // HandleCreate registers h to run whenever a connection is established. 121 | func (c *Client) HandleCreate(h connHandler) { 122 | c.createConnHandlers = append(c.createConnHandlers, h) 123 | // If already connected, run h once now so it does not miss the first connection. 124 | if c.isConnect { 125 | h(c.Conn) 126 | } 127 | } 128 | 129 | // HandleClose registers h to run whenever a connection is closed. 130 | func (c *Client) HandleClose(h connHandler) { 131 | c.closeConnHandlers = append(c.closeConnHandlers, h) 132 | } 133 | 134 | // Receive returns the channel of decoded packets. 135 | func (c *Client) Receive() <-chan *Message { 136 | return c.recChan 137 | } 138 | 139 | // Send packs msg and writes it to the current connection. 140 | func (c *Client) Send(msg []byte) error { 141 | return c.Conn.Send(msg) 142 | } 143 | 144 | // Destroy turns off auto-reconnect and closes the current connection.
145 | func (c *Client) Destroy() error { 146 | c.autoReconnect = false 147 | // Wake a reconnect goroutine parked on closeNotify so it sees autoReconnect == false and exits; skip if none is waiting. 148 | select { 149 | case c.closeNotify <- nil: 150 | default: 151 | } 152 | return c.Conn.Destroy() 153 | } -------------------------------------------------------------------------------- /tcp/client_test.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestClient(t *testing.T) { 10 | client, err := NewClient(&Config{ 11 | Network: "tcp", 12 | Addr: ":6600", 13 | 14 | // 解析私有协议为结构体,如果当前没有解析到,返回nil 15 | // Parser: func(conn *Conn, bt []byte) (interface{}, error) { 16 | // return nil, nil 17 | // }, 18 | // 19 | // // 将数据转换成字节数组,发送时就发该段数据 20 | // Packer: func(data interface{}) ([]byte, error) { 21 | // return nil, nil 22 | // }, 23 | }) 24 | if err != nil { 25 | t.Skipf("no tcp server on :6600, skipping integration test: %v", err) 26 | } 27 | t.Log("connect ok") 28 | 29 | // 发送数据 30 | t.Log("send 1") 31 | err = client.Send([]byte("hello world1")) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | t.Log("send 2") 36 | err = client.Send([]byte("hello world2")) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | t.Log("send 3") 41 | err = client.Send([]byte("hello world3")) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | t.Log("send 4") 46 | err = client.Send([]byte("hello world4")) 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | t.Log("send 5") 51 | err = client.Send([]byte("hello world5")) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | t.Log("send 6") 56 | err = client.Send([]byte("hello world6")) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | 61 | go func() { 62 | for i := 0; true; i++ { 63 | time.Sleep(2 * time.Second) 64 | err := client.Send([]byte(fmt.Sprintf("auto send %d", i))) 65 | 66 | if i > 20 { 67 | client.Destroy() 68 | } 69 | if err != nil { 70 | t.Log(err) 71 | } 72 | } 73 | }() 74 | 75 | 76 | ch := client.Receive() 77 | 78 | for data := range ch { 79 | fmt.Println("receive data:", string(data.Data)) 80 | // data.Response(map[string]interface{}{ 81 | // "msg": "I'm fine.", 82 | // }) 83 | } 84 | 85 | } 86 |
-------------------------------------------------------------------------------- /tcp/config.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | // Config holds the options shared by Server and Client. 4 | type Config struct { 5 | Addr string // address to dial (Client) or to listen on (Server) 6 | Network string // network type: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"; NewClient/NewServer default it to "tcp" 7 | Packer func([]byte) ([]byte, error) // frames an outgoing business payload into the bytes written to the connection 8 | Parser func(*Conn, []byte) ([][]byte, error) // turns received bytes into complete payloads per the private protocol, handling sticky and partial packets 9 | Logger LoggerI // logger; NewClient/NewServer default it to EmptyLogger 10 | } 11 | -------------------------------------------------------------------------------- /tcp/conn.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "sync" 7 | ) 8 | // Conn is a single TCP connection that belongs to either a Server or a Client. 9 | type Conn struct { 10 | writeMu sync.Mutex // serializes Send 11 | ID uint64 12 | Conn net.Conn 13 | server *Server 14 | client *Client 15 | closeOnce sync.Once // lets Destroy run the close handlers at most once 16 | } 17 | // IsServer reports whether the connection belongs to a Server. 18 | func (conn *Conn) IsServer() bool { 19 | return conn.server != nil 20 | } 21 | 22 | // IsClient reports whether the connection belongs to a Client. 23 | func (conn *Conn) IsClient() bool { 24 | return conn.client != nil 25 | } 26 | 27 | // reader reads from the socket, decodes the bytes with the owner's Parser and forwards each decoded packet to the owner's receive channel; a read or parse error destroys the connection. 28 | func (conn *Conn) reader() { 29 | var parser func(*Conn, []byte) ([][]byte, error) 30 | if conn.IsClient() { 31 | parser = conn.client.Config.Parser 32 | } else if conn.IsServer() { 33 | parser = conn.server.Config.Parser 34 | } 35 | for { 36 | _buf := make([]byte, 1024) // fresh buffer per read: a custom Parser may keep references to it 37 | buflen, err := conn.Conn.Read(_buf) 38 | if err != nil { 39 | // read failed (peer closed or network error): tear the connection down 40 | conn.Destroy() 41 | break 42 | } 43 | buf := _buf[:buflen] 44 | body, err := parser(conn, buf) 45 | if err != nil { 46 | // malformed data: tear the connection down 47 | conn.Destroy() 48 | break 49 | } 50 | for _, pkt := range body { 51 | msg := &Message{ 52 | Data: pkt, 53 | Conn: conn, 54 | } 55 | if conn.IsServer() { 56 | conn.server.recChan <- msg 57 | } else if conn.IsClient() { 58 | conn.client.recChan <- msg 59 | } 60 | } 61 |
} 62 | } 63 | 64 | // Send packs msg with the owner's Packer and writes it to the socket; concurrent calls are serialized by writeMu. 65 | func (conn *Conn) Send(msg []byte) error { 66 | conn.writeMu.Lock() 67 | defer conn.writeMu.Unlock() 68 | var err error 69 | var pack []byte 70 | 71 | if conn.IsClient() { 72 | pack, err = conn.client.Config.Packer(msg) 73 | } else if conn.IsServer() { 74 | pack, err = conn.server.Config.Packer(msg) 75 | } else { 76 | err = errors.New("the connection is invalid") 77 | } 78 | 79 | if err != nil { 80 | return err 81 | } 82 | _, err = conn.Conn.Write(pack) 83 | return err 84 | } 85 | 86 | // Destroy runs the owner's close handlers (and removes a server connection from Sockets) at most once per connection, then closes the socket. It is reached both from explicit calls and from reader() once Read fails on the closed socket, which previously fired every close handler twice. Close handlers must not call Destroy on the same connection. 87 | func (conn *Conn) Destroy() error { 88 | conn.closeOnce.Do(func() { 89 | if conn.IsServer() { 90 | for _, h := range conn.server.closeConnHandlers { 91 | h(conn) 92 | } 93 | delete(conn.server.Sockets, conn.ID) 94 | } else if conn.IsClient() { 95 | for _, h := range conn.client.closeConnHandlers { 96 | h(conn) 97 | } 98 | } 99 | }) 100 | return conn.Conn.Close() 101 | } 102 | -------------------------------------------------------------------------------- /tcp/logger.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | type LoggerI interface { 4 | Info(...interface{}) 5 | Infof(string, ...interface{}) 6 | Error(...interface{}) 7 | Errorf(string, ...interface{}) 8 | } 9 | 10 | type l struct{} 11 | 12 | func (l) Info(v ...interface{}) {} 13 | func (l) Infof(s string, v ...interface{}) {} 14 | func (l) Error(v ...interface{}) {} 15 | func (l) Errorf(s string, v ...interface{}) {} 16 | 17 | var EmptyLogger = &l{} 18 | 19 | var logger LoggerI = EmptyLogger 20 | -------------------------------------------------------------------------------- /tcp/message.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | type Message struct { 4 | Data []byte 5 | Conn *Conn 6 | } 7 | 8 | func (m *Message) Response(data []byte) error { 9 | return m.Conn.Send(data) 10 | } 11 | --------------------------------------------------------------------------------
/tcp/package.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "encoding/binary" 5 | "sync" 6 | 7 | "github.com/go-eyas/toolkit/util" 8 | ) 9 | 10 | // Default private protocol implementation, used when Config leaves both Packer and Parser nil. 11 | // 12 | // Frame layout: a 4-byte big-endian body length followed by the body 13 | // (for example a JSON document such as {"cmd": "test", "data": {}}). 14 | 15 | // Packer frames data as a 4-byte big-endian length header followed by data. 16 | func Packer(data []byte) ([]byte, error) { 17 | bodyLen := uint32(len(data)) 18 | header := make([]byte, 4) 19 | binary.BigEndian.PutUint32(header, bodyLen) 20 | 21 | pkg := util.BytesCombine(header, data) 22 | return pkg, nil 23 | } 24 | 25 | // Parser returns the default unpacker together with the map in which it 26 | // buffers incomplete frames per connection (keyed by Conn.ID); callers use 27 | // the map to drop the buffer of a closed connection. 28 | // 29 | // A Server calls the unpacker from one reader goroutine per connection, so 30 | // the unpacker guards the map with a mutex. NOTE(review): deletes that 31 | // callers perform on the returned map are not covered by that mutex. 32 | func Parser() (map[uint64][]byte, func(conn *Conn, bt []byte) ([][]byte, error)) { 33 | var mu sync.Mutex 34 | parserBuf := make(map[uint64][]byte) 35 | return parserBuf, func(conn *Conn, bt []byte) ([][]byte, error) { 36 | mu.Lock() 37 | defer mu.Unlock() 38 | 39 | preBuf, ok := parserBuf[conn.ID] 40 | if !ok { 41 | preBuf = make([]byte, 0) 42 | } 43 | buf := util.BytesCombine(preBuf, bt) 44 | datas := make([][]byte, 0) 45 | 46 | for len(buf) >= 4 { 47 | bodyLen := binary.BigEndian.Uint32(buf[:4]) 48 | // Do the bounds arithmetic in uint64: a peer-supplied length close to 49 | // MaxUint32 would wrap 4+bodyLen around in uint32 and make the slice 50 | // expressions below panic. NOTE(review): bodyLen is otherwise unbounded, 51 | // so a peer can make this buffer grow to about 4 GiB per connection. 52 | end := 4 + uint64(bodyLen) 53 | if uint64(len(buf)) < end { 54 | break 55 | } 56 | datas = append(datas, buf[4:end]) 57 | buf = buf[end:] 58 | } 59 | parserBuf[conn.ID] = buf 60 | 61 | return datas, nil 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /tcp/readme.md: -------------------------------------------------------------------------------- 1 | # TCP 2 | 3 | 基于底层的 tcp 封装,有客户端和服务端,本工具的服务端和客户端封装了通用的操作,均可单独用于各个项目。 4 | 5 | 只简化了常用的 tcp 操作 6 | 7 | # 配置 8 | 9 | 实例化 tcp 的服务端和客户端的时候要传入配置,服务端和客户端配置共用同个结构体 `*tcp.Config` ,说明如下 10 | 11 | ```go 12 | // Config 配置项 13 | type Config struct { 14 | Addr string // tcp 地址,在客户端使用为需要连接的地址,在服务端使用为监听的地址 15 | Network string // tcp 的网络类型,可选值为 "tcp", "tcp4", "tcp6", "unix" or "unixpacket" 16 | Packer func([]byte) ([]byte, error) // tcp 数据包的封装函数,传入的数据是需要发送的业务数据,返回发送给
tcp 的数据 17 | Parser func(*Conn, []byte) ([][]byte, error) // 将收到的数据包,根据私有协议转换成业务数据,在这里处理粘包,半包等数据包问题,返回处理好的数据包 18 | Logger LoggerI // 打日志的实例 19 | } 20 | ``` 21 | 22 | Packer 和 Parser 是 tcp 数据包的处理,是同时对等出现的。用于实现 tcp 的私有协议。 23 | 24 | 如果 Packer 和 Parser 两个都不传,会使用默认的私有协议实现,默认私有协议的数据包组成是 25 | 26 | ``` 27 | header[4字节标识body字节长度] + body[任意长度] 28 | ``` 29 | 30 | # 客户端 Client 31 | 32 | tcp 的客户端封装,带有自动重连,简易的 api 封装,只关注收发数据即可 33 | 34 | ```go 35 | package main 36 | 37 | import ( 38 | "fmt" 39 | "github.com/go-eyas/toolkit/tcp" 40 | ) 41 | 42 | func main() { 43 | client, err := tcp.NewClient(&tcp.Config{ 44 | Network: "tcp", // 网络类型,不填默认 tcp 45 | // tcp 服务端地址 46 | Addr: "127.0.0.1:6600", 47 | 48 | // 私有协议实现,不传将使用默认的私有协议实现 49 | // Parser: func(*Conn, []byte) ([][]byte, error) {}, 50 | // Packer: func([]byte) ([]byte, error) {}, 51 | }) 52 | if err != nil { 53 | panic(err) 54 | } 55 | 56 | // 接收数据 57 | ch := client.Receive() 58 | go func() { 59 | for msg := range ch { 60 | // msg.Data 经过 Parser 处理过的数据 61 | // msg.Conn tcp 连接实例 62 | fmt.Println("client receive:", string(msg.Data)) 63 | } 64 | }() 65 | 66 | // 发送数据,send 后将立马把数据传给 Packer 处理后,再发送到 tcp 连接 67 | err = client.Send([]byte("hello world1")) 68 | } 69 | 70 | 71 | ``` 72 | 73 | # 服务端 Server 74 | 75 | ```go 76 | package main 77 | 78 | import ( 79 | "fmt" 80 | "github.com/go-eyas/toolkit/tcp" 81 | ) 82 | 83 | func main() { 84 | server, err := tcp.NewServer(&tcp.Config{ 85 | Network: "tcp", // 网络类型,不填默认 tcp 86 | // tcp 监听地址 87 | Addr: "127.0.0.1:6600", 88 | 89 | // 私有协议实现,不传将使用默认的私有协议实现 90 | // Parser: func(*Conn, []byte) ([][]byte, error) {}, 91 | // Packer: func([]byte) ([]byte, error) {}, 92 | }) 93 | 94 | if err != nil { 95 | panic(err) 96 | } 97 | 98 | // 接收数据 99 | ch := server.Receive() 100 | for data := range ch { 101 | fmt.Printf("server receive: %v", data.Data) 102 | 103 | // 服务器收到数据后,响应发送一条数据到客户端 104 | err := data.Response([]byte("server receive your message")) 105 | if err != nil { 106 | panic(err) 107 | } 108 | } 109 |
110 | // 给所有连接都发送消息 111 | for connID, conn := range server.Sockets { 112 | fmt.Println("connID: ", connID) 113 | server.Send(conn, []byte("broadcast some message")) 114 | // or 115 | // server.SendConnID(connID, []byte("broadcast some message")) 116 | } 117 | } 118 | 119 | ``` 120 | 121 | # API 122 | 123 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/tcp) -------------------------------------------------------------------------------- /tcp/server.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | ) 7 | 8 | type connHandler func(*Conn) 9 | 10 | // Server is a TCP server that accepts connections and delivers their decoded packets on one receive channel. 11 | type Server struct { 12 | Listener net.Listener // listening socket 13 | Sockets map[uint64]*Conn // live client connections keyed by Conn.ID 14 | Config *Config // configuration 15 | 16 | recChan chan *Message // decoded packets (output of Parser) 17 | socketCount uint64 // connection ID counter 18 | createConnHandlers []connHandler // run whenever a connection is accepted 19 | closeConnHandlers []connHandler // run whenever a connection is closed 20 | } 21 | 22 | // NewServer listens on conf.Network/conf.Addr (defaulting Network to "tcp" and Packer/Parser to the built-in protocol), starts accepting connections in the background and returns the server. 23 | func NewServer(conf *Config) (*Server, error) { 24 | var defaultParsePoll map[uint64][]byte 25 | if conf.Packer == nil && conf.Parser == nil { 26 | conf.Packer = Packer 27 | defaultParsePoll, conf.Parser = Parser() 28 | } else if conf.Packer == nil || conf.Parser == nil { 29 | return nil, errors.New("the Packer and Parser must be specified together") 30 | } 31 | 32 | if conf.Logger == nil { 33 | conf.Logger = EmptyLogger 34 | } 35 | 36 | if conf.Network == "" { 37 | conf.Network = "tcp" 38 | } 39 | 40 | listener, err := net.Listen(conf.Network, conf.Addr) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | server := &Server{ 46 | Listener: listener, 47 | Config: conf, 48 | Sockets: make(map[uint64]*Conn), 49 | recChan: make(chan *Message, 2), 50 | // sendChan: make(chan *Message, 2), 51 | createConnHandlers: make([]connHandler, 0), 52 | closeConnHandlers: make([]connHandler, 0), 53 | } 54 | 55 | go server.accept() 56 |
57 | // Drop the default parser's buffer when a connection closes (this handler used to be a no-op, leaking one entry per connection). 58 | if defaultParsePoll != nil { 59 | server.HandleClose(func(conn *Conn) { 60 | // NOTE(review): not synchronized with concurrent parser calls on other connections' reader goroutines. 61 | delete(defaultParsePoll, conn.ID) 62 | }) 63 | } 64 | return server, nil 65 | } 66 | 67 | // accept registers every accepted connection via newConn. NOTE(review): Accept errors are retried immediately, so this loop spins if the Listener is closed. 68 | func (sv *Server) accept() { 69 | for { 70 | conn, err := sv.Listener.Accept() 71 | if err != nil { 72 | continue 73 | } 74 | sv.newConn(conn) 75 | } 76 | } 77 | 78 | // Receive returns the channel of decoded packets from all connections. 79 | func (sv *Server) Receive() <-chan *Message { 80 | return sv.recChan 81 | } 82 | 83 | // Send packs data and writes it to conn. 84 | func (sv *Server) Send(conn *Conn, data []byte) error { 85 | return conn.Send(data) 86 | } 87 | 88 | // SendConnID sends data to the connection with the given ID, failing if no such connection is registered. 89 | func (sv *Server) SendConnID(id uint64, data []byte) error { 90 | conn, ok := sv.Sockets[id] 91 | if !ok || conn == nil { 92 | return errors.New("invalid connection") 93 | } 94 | return sv.Send(conn, data) 95 | } 96 | 97 | // HandleCreate registers h to run whenever a connection is accepted. 98 | func (sv *Server) HandleCreate(h connHandler) { 99 | sv.createConnHandlers = append(sv.createConnHandlers, h) 100 | } 101 | 102 | // HandleClose registers h to run whenever a connection is closed. 103 | func (sv *Server) HandleClose(h connHandler) { 104 | sv.closeConnHandlers = append(sv.closeConnHandlers, h) 105 | } 106 | // newConn registers conn under a new ID, runs the create handlers and starts its read loop. 107 | func (sv *Server) newConn(conn net.Conn) *Conn { 108 | sv.socketCount++ 109 | c := &Conn{ 110 | ID: sv.socketCount, 111 | Conn: conn, 112 | server: sv, 113 | } 114 | sv.Sockets[c.ID] = c 115 | 116 | // run the create handlers before the reader starts 117 | for _, h := range sv.createConnHandlers { 118 | h(c) 119 | } 120 | 121 | go c.reader() 122 | return c 123 | } 124 | -------------------------------------------------------------------------------- /tcp/server_test.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestServer(t *testing.T) { 8 | server, err := NewServer(&Config{ 9 | Network: "tcp", 10 | Addr: ":6600", 11 | 12 | // // 解析私有协议为结构体,如果当前没有解析到,返回 error 13 | // Parser: func(conn *Conn, bt []byte) (interface{}, error) { 14 |
// return nil, nil 15 | // }, 16 | // 17 | // // 将数据转换成字节数组,发送时就发该段数据,如果解析错误返回 error 18 | // Packer: func(data interface{}) ([]byte, error) { 19 | // return nil, nil 20 | // }, 21 | }) 22 | 23 | if err != nil { 24 | panic(err) 25 | } 26 | 27 | ch := server.Receive() 28 | 29 | for data := range ch { 30 | t.Logf("receive: %v", string(data.Data)) 31 | err := data.Response([]byte(`server receive: ` + string(data.Data))) 32 | if err != nil { 33 | t.Log(err) 34 | } 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /tcp/tcpsrv/client.go: -------------------------------------------------------------------------------- 1 | package tcpsrv 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "github.com/go-eyas/toolkit/emit" 7 | "github.com/go-eyas/toolkit/tcp" 8 | "github.com/go-eyas/toolkit/util" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // WSRequest 请求数据 14 | type TCPClientRequest struct { 15 | CMD string `json:"cmd"` 16 | Seqno string `json:"seqno"` 17 | Data interface{} `json:"data"` 18 | } 19 | 20 | type ClientSrv struct { 21 | sendTimeout time.Duration 22 | Engine *tcp.Client 23 | Emitter *emit.Emitter 24 | sendMutex sync.Mutex 25 | sendProcess map[string]chan *TCPResponse 26 | } 27 | 28 | const onMessageEvt = "__ON_ALL_MESSAGE__" 29 | 30 | // NewClientSrv 实例化客户端服务 31 | func NewClientSrv(conf *tcp.Config) (*ClientSrv, error) { 32 | engine, err := tcp.NewClient(conf) 33 | if err != nil { 34 | return nil, err 35 | } 36 | srv := &ClientSrv{ 37 | Engine: engine, 38 | Emitter: emit.New(), 39 | sendTimeout: 10 * time.Second, 40 | sendProcess: make(map[string]chan *TCPResponse), 41 | } 42 | 43 | go srv.reader() 44 | go srv.heartbeat() 45 | 46 | return srv, nil 47 | } 48 | 49 | func (cs *ClientSrv) heartbeat() { 50 | for { 51 | time.Sleep(5 * time.Second) 52 | cs.Engine.Send([]byte{}) 53 | } 54 | } 55 | 56 | func (cs *ClientSrv) reader() { 57 | ch := cs.Engine.Receive() 58 | for msg := range ch { 59 | res := &TCPResponse{} 
60 | err := json.Unmarshal(msg.Data, res) 61 | if err != nil { 62 | cs.Emitter.Emit("error", err) 63 | continue 64 | } 65 | cs.sendMutex.Lock() 66 | ch, ok := cs.sendProcess[res.Seqno] 67 | cs.sendMutex.Unlock() 68 | if ok { 69 | ch <- res 70 | } 71 | cs.Emitter.Emit(res.CMD, res) 72 | cs.Emitter.Emit(onMessageEvt, res) 73 | } 74 | } 75 | 76 | // On 监听服务器响应数据,每当服务器有数据发送过来,都会以 cmd 为事件名触发监听函数 77 | func (cs *ClientSrv) On(cmd string, h func(*TCPResponse)) { 78 | cs.Emitter.On(cmd, func(_res interface{}) { 79 | res, ok := _res.(*TCPResponse) 80 | if ok { 81 | h(res) 82 | } 83 | }) 84 | } 85 | 86 | func (cs *ClientSrv) OnMessage(h func(*TCPResponse)) { 87 | cs.Emitter.On(onMessageEvt, func(_res interface{}) { 88 | res, ok := _res.(*TCPResponse) 89 | if ok { 90 | h(res) 91 | } 92 | }) 93 | } 94 | 95 | // Pub 给服务器发送消息 96 | func (cs *ClientSrv) Pub(cmd string, data interface{}) error { 97 | _, err := cs.writeSend(cmd, data) 98 | return err 99 | } 100 | 101 | // Send 给服务器发送消息,并等待服务器的响应数据,10秒超时 102 | func (cs *ClientSrv) Send(cmd string, datas ...interface{}) (*TCPResponse, error) { 103 | var data interface{} 104 | if len(datas) > 0 { 105 | data = datas[0] 106 | } 107 | body, err := cs.writeSend(cmd, data) 108 | if err != nil { 109 | return nil, err 110 | } 111 | cs.sendMutex.Lock() 112 | cs.sendProcess[body.Seqno] = make(chan *TCPResponse) 113 | cs.sendMutex.Unlock() 114 | ticker := time.Tick(cs.sendTimeout) 115 | select { 116 | case res := <-cs.sendProcess[body.Seqno]: 117 | cs.sendMutex.Lock() 118 | close(cs.sendProcess[body.Seqno]) 119 | delete(cs.sendProcess, body.Seqno) 120 | cs.sendMutex.Unlock() 121 | return res, nil 122 | case <-ticker: 123 | cs.sendMutex.Lock() 124 | close(cs.sendProcess[body.Seqno]) 125 | delete(cs.sendProcess, body.Seqno) 126 | cs.sendMutex.Unlock() 127 | return nil, errors.New("request timeout") 128 | } 129 | 130 | } 131 | 132 | // 封装数据,并发送数据到服务端 133 | func (cs *ClientSrv) writeSend(cmd string, data interface{}) (*TCPRequest, error) { 134 | body 
:= &TCPRequest{ 135 | CMD: cmd, 136 | Seqno: util.RandomStr(8), 137 | } 138 | err := body.SetJSON(data) 139 | if err != nil { 140 | return nil, err 141 | } 142 | 143 | raw, err := json.Marshal(body) 144 | if err != nil { 145 | return nil, err 146 | } 147 | return body, cs.Engine.Send(raw) 148 | } 149 | -------------------------------------------------------------------------------- /tcp/tcpsrv/client_test.go: -------------------------------------------------------------------------------- 1 | package tcpsrv 2 | 3 | import ( 4 | "fmt" 5 | "github.com/go-eyas/toolkit/log" 6 | "github.com/go-eyas/toolkit/tcp" 7 | "testing" 8 | ) 9 | 10 | func TestClient(t *testing.T) { 11 | client, err := NewClientSrv(&tcp.Config{ 12 | Addr: ":6601", 13 | Logger: log.SugaredLogger, 14 | }) 15 | if err != nil { 16 | panic(err) 17 | } 18 | 19 | client.On("register", func(response *TCPResponse) { 20 | fmt.Println("on receive register msg:", response) 21 | }) 22 | 23 | client.On("userinfo", func(response *TCPResponse) { 24 | fmt.Println("on receive userinfo msg:", response) 25 | }) 26 | 27 | res, err := client.Send("register", map[string]interface{}{ 28 | "uid": 1234, 29 | }) 30 | if err != nil { 31 | panic(err) 32 | } 33 | fmt.Println("send register response: ", res) 34 | 35 | res, err = client.Send("userinfo") 36 | if err != nil { 37 | panic(err) 38 | } 39 | fmt.Println("send userinfo response: ", res) 40 | 41 | // go func() { 42 | // ch := client.Receive() 43 | // for data := range ch { 44 | // fmt.Println("receive data:", string(data.Data)) 45 | // // data.Response(map[string]interface{}{ 46 | // // "msg": "I'm fine.", 47 | // // }) 48 | // } 49 | // }() 50 | 51 | // send register 52 | // data, _ := json.Marshal(map[string]interface{}{ 53 | // "cmd": "register", 54 | // }) 55 | // err = client.Send(data) 56 | // if err != nil { 57 | // panic(err) 58 | // } 59 | // time.Sleep(1 * time.Second) 60 | // 61 | // // send userinfo 62 | // data, _ = json.Marshal(map[string]interface{}{ 63 | 
// "cmd": "userinfo", 64 | // }) 65 | // err = client.Send(data) 66 | // if err != nil { 67 | // panic(err) 68 | // } 69 | // 70 | c := make(chan bool, 0) 71 | <- c // blocks forever so the client keeps running; this test never finishes on its own 72 | } 73 | -------------------------------------------------------------------------------- /tcp/tcpsrv/context.go: -------------------------------------------------------------------------------- 1 | package tcpsrv 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/go-eyas/toolkit/gin/util" 6 | "github.com/go-eyas/toolkit/tcp" 7 | "github.com/go-playground/validator/v10" 8 | "reflect" 9 | "sync" 10 | ) 11 | // validate is the shared go-playground validator that BindJSON runs on decoded structs. 12 | var validate = validator.New() 13 | 14 | // TCPRequest is the request envelope: CMD selects the route, Seqno identifies the request and is echoed in the response, Data is kept as raw JSON. 15 | type TCPRequest struct { 16 | CMD string `json:"cmd"` 17 | Seqno string `json:"seqno"` 18 | Data json.RawMessage `json:"data"` 19 | } 20 | // SetJSON stores v as the request data, encoded by convertDataToJsonByte. 21 | func (r *TCPRequest) SetJSON(v interface{}) (err error) { 22 | r.Data, err = convertDataToJsonByte(v) 23 | return 24 | } 25 | // BindJSON decodes Data into v and, when v is a struct or a pointer to one, validates it. 26 | func (r *TCPRequest) BindJSON(v interface{}) error { 27 | err := json.Unmarshal(r.Data, v) 28 | if err != nil { 29 | return err 30 | } 31 | rt := reflect.TypeOf(v) 32 | if rt.Kind() == reflect.Ptr { 33 | rt = rt.Elem() 34 | } 35 | if rt.Kind() == reflect.Struct { 36 | return validate.Struct(v) 37 | } 38 | return nil 39 | } 40 | 41 | // TCPResponse is the response envelope: CMD and Seqno echo the request, Status is 0 on success (per the protocol readme), Msg describes the result and Data is raw JSON. 42 | type TCPResponse struct { 43 | CMD string `json:"cmd"` 44 | Seqno string `json:"seqno"` 45 | Status int `json:"status"` 46 | Msg string `json:"msg"` 47 | Data json.RawMessage `json:"data"` 48 | } 49 | // SetJSON stores v as the response data, encoded by convertDataToJsonByte. 50 | func (r *TCPResponse) SetJSON(v interface{}) (err error) { 51 | r.Data, err = convertDataToJsonByte(v) 52 | return 53 | } 54 | // BindJSON decodes Data into v and, when v is a struct or a pointer to one, validates it. 55 | func (r *TCPResponse) BindJSON(v interface{}) error { 56 | err := json.Unmarshal(r.Data, v) 57 | if err != nil { 58 | return err 59 | } 60 | rt := reflect.TypeOf(v) 61 | if rt.Kind() == reflect.Ptr { 62 | rt = rt.Elem() 63 | } 64 | if rt.Kind() == reflect.Struct { 65 | return validate.Struct(v) 66 | } 67 | return nil 68 | } 69 | 70 | // Context carries a single request through the middleware and handler chain, together with the session of its connection. 71 | type Context struct
{ 72 | Values map[string]interface{} // session values; the same map is shared by every Context of this connection 73 | valMu sync.RWMutex // guards Values, but only for this Context (see the NOTE on Get) 74 | CMD string // command name 75 | Seqno string // request identifier 76 | RawData json.RawMessage // raw request data; NOTE(review): ServerSrv.handlerReceive never sets it 77 | SessionID uint64 // session (connection) ID 78 | Socket *tcp.Conn // long-lived connection object 79 | RawMessage *tcp.Message // raw tcp message this request came from 80 | Engine *tcp.Server // underlying tcp engine 81 | Server *ServerSrv // server that dispatched the request 82 | Payload []byte // raw request packet 83 | Request *TCPRequest // decoded request 84 | Response *TCPResponse // response, written back after the handler chain finishes 85 | logger tcp.LoggerI // logger for request/response tracing 86 | handlers []TCPHandler // middlewares followed by the route handlers of this request 87 | handlerIndex int // index of the running handler (-1 before the first) 88 | isAbort bool // set by Abort, or by Next when no handler is left, to stop the chain 89 | sendMu sync.Mutex // serializes writeResponse 90 | } 91 | 92 | // OK marks the response successful (Status util.CodeSuccess, Msg "ok") and sets its data to args[0], or to JSON null when no argument is given; only the encoding error of args[0] is returned. 93 | func (c *Context) OK(args ...interface{}) error { 94 | c.Response.Status = util.CodeSuccess 95 | c.Response.Msg = "ok" 96 | 97 | if len(args) > 0 { 98 | return c.Response.SetJSON(args[0]) 99 | } else { 100 | c.Response.Data = []byte("null") 101 | return nil 102 | } 103 | 104 | } 105 | 106 | // Bind decodes the request data into v and validates it (see TCPRequest.BindJSON). 107 | func (c *Context) Bind(v interface{}) error { 108 | return c.Request.BindJSON(v) 109 | } 110 | // Get returns the session value stored under key, or nil. 111 | // NOTE(review): Values is shared by every Context of the connection and handlerReceive runs each message in its own goroutine, but valMu belongs to one Context, so Get/Set from concurrent requests on one connection are not synchronized. 112 | func (c *Context) Get(key string) interface{} { 113 | c.valMu.RLock() 114 | defer c.valMu.RUnlock() 115 | return c.Values[key] 116 | } 117 | 118 | // Set stores a session value: it stays visible to all later requests on this connection, not only to the current request. 119 | func (c *Context) Set(key string, v interface{}) { 120 | c.valMu.Lock() 121 | c.Values[key] = v 122 | c.valMu.Unlock() 123 | } 124 | 125 | // Abort stops the remaining middlewares and handlers from running. 126 | func (c *Context) Abort() { 127 | c.isAbort = true 128 | } 129 | 130 | // Push sends data to this connection's client outside the request/response cycle (see ServerSrv.Push). 131 | func (c *Context) Push(data *TCPResponse) error { 132 | return c.Server.Push(c.SessionID, data) 133 | } 134 | // Next runs the next handler of the chain; when no handler is left it marks the context aborted, which ends the dispatch loop in handlerReceive. 135 | func (c *Context) Next() { 136 | if c.isAbort { 137 | return 138 | } 139 | if c.handlerIndex < len(c.handlers)-1 { 140 | c.handlerIndex++ 141 | c.handlers[c.handlerIndex](c) 142 | } else { 143 | c.Abort() 144 | } 145 | } 146 | // writeResponse JSON-encodes Response, logs it and writes it back on the connection the request arrived on. 147 | func (c *Context) writeResponse() error { 148 | c.sendMu.Lock() 149 | defer
c.sendMu.Unlock() 150 | payload, err := json.Marshal(c.Response) 151 | if err != nil { 152 | return err 153 | } 154 | c.logger.Infof("[TCP] --> SEND CMD=%s data=%s", c.CMD, string(payload)) 155 | return c.RawMessage.Response(payload) 156 | } 157 | 158 | func convertDataToJsonByte(data interface{}) ([]byte, error) { 159 | var bodyData []byte 160 | if data == nil { 161 | bodyData = []byte("null") 162 | } else if _bodyData, ok := data.([]byte); ok { 163 | bodyData = _bodyData 164 | } else if _bodyData, ok := data.(string); ok { 165 | bodyData = []byte(_bodyData) 166 | } else { 167 | _bodyData, err := json.Marshal(data) 168 | if err != nil { 169 | return nil, err 170 | } 171 | bodyData = _bodyData 172 | } 173 | return bodyData, nil 174 | } -------------------------------------------------------------------------------- /tcp/tcpsrv/readme.md: -------------------------------------------------------------------------------- 1 | # TCP 服务 2 | 3 | 开箱即用的 TCP 服务 4 | 5 | ## 协议 6 | 7 | #### 心跳 8 | 9 | 心跳包为长度为 0 的空数据包,最长时间 30s 发一次,否则链接将会被断开 10 | 11 | #### 请求响应数据 12 | 13 | 请求和响应的数据必须按照该协议来 14 | 15 | **请求数据** 16 | 17 | ```json 18 | { 19 | "cmd": "register", 20 | "seqno": "unique string", 21 | "data": {} 22 | } 23 | ``` 24 | 25 | * cmd 命令名称 26 | * seqno 请求标识符 27 | * data 请求数据 28 | 29 | 30 | **响应数据** 31 | 32 | ```json 33 | { 34 | "cmd": "register", 35 | "seqno": "unique string", 36 | "msg": "ok", 37 | "status": 0, 38 | "data": {} 39 | } 40 | ``` 41 | 42 | * cmd 命令名称,原样返回 43 | * seqno 请求标识符,原样返回 44 | * msg 处理后的消息,如果消息是处理成功的,默认都是 ok 45 | * status 错误状态码,0 为成功,非 0 为失败 46 | * data 响应数据 47 | 48 | # 使用 49 | 50 | 示例概览 51 | 52 | ### 服务器 53 | 54 | ```go 55 | package main 56 | 57 | import ( 58 | "github.com/go-eyas/toolkit/tcp" 59 | "github.com/go-eyas/toolkit/tcp/tcpsrv" 60 | "fmt" 61 | ) 62 | 63 | func main() { 64 | server, err := tcpsrv.NewServerSrv(&tcp.Config{ 65 | Network:":6700", 66 | 67 | // 自定义tcp数据包协议,实现下方两个方法即可 68 | // 将业务数据封装成tcp数据包 69 | // Packer: func(data []byte) ([]byte, error) 
{}, 70 | // 将 tcp 连接收到的数据包解析成业务数据,返回的业务数据必须符合上方定义的 json 数据 71 | // Parser: func(conn *tcp.Conn, pack []byte) ( [][]byte, error) {}, 72 | }) 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | // log 中间件 78 | server.Use(func(c *tcpsrv.Context) { 79 | fmt.Printf("TCP 收到 cmd=%s seqno=%s data=%s\n", c.CMD, c.Seqno, string(c.Payload)) 80 | c.Next() 81 | fmt.Printf("TCP 响应 cmd=%s seqno=%s data=%s\n", c.CMD, c.Seqno, string(c.Response.Data)) 82 | }) 83 | 84 | // 验证中间件 85 | server.Use(func(c *tcpsrv.Context) { 86 | if c.CMD != "register" { 87 | _, ok := c.Get("uid").(int64) 88 | if !ok { 89 | c.Response.Msg = "permission defined" 90 | c.Response.Status = 401 91 | c.Abort() // 停止后面的中间件执行 92 | return 93 | } 94 | } 95 | c.Next() // 如后续无操作,可省略 96 | }) 97 | 98 | server.Handle("register", func(c *tcpsrv.Context) { 99 | body := &struct { 100 | UID int64 `json:"uid"` 101 | }{} 102 | err := c.Bind(body) // 绑定json数据 103 | if err != nil { 104 | panic(err) // 在 Handle panic 后不会导致程序异常,会响应错误数据到客户端 105 | } 106 | c.Set("uid", body.UID) // 设置该连接的会话值 107 | c.OK() 108 | }) 109 | 110 | server.Handle("userinfo", func(c *tcpsrv.Context) { 111 | uid := c.Get("uid").(int64) // 获取会话值 112 | c.OK(findUserByUID(uid)) // OK 可设置响应数据,如果不设置 113 | }) 114 | } 115 | ``` 116 | 117 | ### 客户端 118 | 119 | 客户端是该协议的实现,在符合上述协议的服务器都可使用 120 | 121 | ```go 122 | package main 123 | 124 | import ( 125 | "github.com/go-eyas/toolkit/tcp" 126 | "github.com/go-eyas/toolkit/tcp/tcpsrv" 127 | "fmt" 128 | ) 129 | 130 | func main() { 131 | client, err := tcpsrv.NewClientSrv(&tcp.Config{ 132 | Addr: ":6601", 133 | 134 | // 自定义tcp数据包协议,实现下方两个方法即可 135 | // 将业务数据封装成tcp数据包 136 | // Packer: func(data []byte) ([]byte, error) {}, 137 | // 将 tcp 连接收到的数据包解析成业务数据,返回的业务数据必须符合上方定义的 json 数据 138 | // Parser: func(conn *tcp.Conn, pack []byte) ( [][]byte, error) {}, 139 | }) 140 | if err != nil { 141 | panic(err) 142 | } 143 | 144 | // 每当服务器发送了数据过来,都会以 cmd 作为时间名触发事件 145 | client.On("register", func(response *tcpsrv.TCPResponse) { 146 | 
fmt.Println("on receive register msg:", response) 147 | }) 148 | 149 | client.On("userinfo", func(response *tcpsrv.TCPResponse) { 150 | fmt.Println("on receive userinfo msg:", response) 151 | }) 152 | 153 | // send 发送后,会等待服务器的响应,res 为服务器的响应数据 154 | res, err := client.Send("register", map[string]interface{}{ 155 | "uid": 1234, 156 | }) 157 | if err != nil { 158 | panic(err) 159 | } 160 | fmt.Println("send register response: ", res) 161 | 162 | res, err = client.Send("userinfo") 163 | if err != nil { 164 | panic(err) 165 | } 166 | // 响可直接解析绑定 data 数据 167 | res.BindJSON(&struct { 168 | UID int64 169 | }{}) 170 | fmt.Println("send userinfo response: ", res) 171 | } 172 | ``` 173 | 174 | # API 175 | 176 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/tcp/tcpsrv) -------------------------------------------------------------------------------- /tcp/tcpsrv/server.go: -------------------------------------------------------------------------------- 1 | package tcpsrv 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/go-eyas/toolkit/tcp" 7 | "github.com/go-eyas/toolkit/util" 8 | "runtime/debug" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | type TCPHandler func(*Context) 14 | 15 | type ServerSrv struct { 16 | Engine *tcp.Server 17 | Config *tcp.Config 18 | logger tcp.LoggerI 19 | routes map[string][]TCPHandler // 路由 20 | Session map[uint64]map[string]interface{} // map[sid]SessionData 21 | sessionMu sync.Mutex 22 | heartbeat *sync.Map // 心跳 23 | handlerMiddlewares []TCPHandler // 中间件 24 | } 25 | 26 | func NewServerSrv(conf *tcp.Config) (*ServerSrv, error) { 27 | engine, err := tcp.NewServer(conf) 28 | if err != nil { 29 | return nil, err 30 | } 31 | srv := &ServerSrv{ 32 | Engine: engine, 33 | Config: conf, 34 | logger: engine.Config.Logger, 35 | routes: make(map[string][]TCPHandler), 36 | Session: make(map[uint64]map[string]interface{}), 37 | heartbeat: &sync.Map{}, 38 | handlerMiddlewares: make([]TCPHandler, 0), 39 | } 40 | 41 | 
srv.Engine.HandleCreate(srv.onCreate) 42 | srv.Engine.HandleClose(srv.onClose) 43 | 44 | go srv.receive() 45 | go srv.checkHeartbeat() 46 | 47 | return srv, nil 48 | } 49 | 50 | func (srv *ServerSrv) onCreate(conn *tcp.Conn) { 51 | sid := conn.ID 52 | srv.heartbeat.Store(sid, time.Now().Unix()) 53 | srv.sessionMu.Lock() 54 | srv.Session[sid] = make(map[string]interface{}) 55 | srv.sessionMu.Unlock() 56 | srv.logger.Infof("[TCP] New conn id=%d", conn.ID) 57 | } 58 | func (srv *ServerSrv) onClose(conn *tcp.Conn) { 59 | sid := conn.ID 60 | srv.Destroy(sid) 61 | srv.logger.Infof("[TCP] CLOSE conn id=%d", conn.ID) 62 | } 63 | 64 | func (srv *ServerSrv) handlerReceive(req *tcp.Message) { 65 | conn := req.Conn 66 | srv.heartbeat.Store(conn.ID, time.Now().Unix()) 67 | 68 | // 心跳包 69 | if len(req.Data) == 0 { 70 | return 71 | } 72 | 73 | 74 | ctx := &Context{ 75 | SessionID: conn.ID, 76 | Socket: conn, 77 | RawMessage: req, 78 | Engine: srv.Engine, 79 | Payload: req.Data, 80 | Request: &TCPRequest{}, 81 | Server: srv, 82 | logger: srv.logger, 83 | } 84 | srv.sessionMu.Lock() 85 | vals, ok := srv.Session[conn.ID] 86 | if !ok { 87 | srv.Session[conn.ID] = make(map[string]interface{}) 88 | vals = srv.Session[conn.ID] 89 | } 90 | srv.sessionMu.Unlock() 91 | ctx.Values = vals 92 | ctx.logger.Infof("[TCP] <-- RECV CMD=%s data=%s", ctx.CMD, string(ctx.Payload)) 93 | 94 | err := json.Unmarshal(ctx.Payload, ctx.Request) 95 | if err != nil { 96 | ctx.logger.Errorf("TCP request json parse error: %v", err) 97 | return 98 | } 99 | ctx.CMD = ctx.Request.CMD 100 | ctx.Seqno = ctx.Request.Seqno 101 | ctx.Response = &TCPResponse{ 102 | CMD: ctx.Request.CMD, 103 | Seqno: ctx.Request.Seqno, 104 | Status: -1, 105 | Msg: "not implement", 106 | // Data: map[string]interface{}{}, 107 | } 108 | 109 | defer func() { 110 | if err := recover(); err != nil { 111 | srv.logger.Errorf("%v", err) 112 | debug.PrintStack() 113 | r := util.ParseError(err) 114 | ctx.Response.Status = r.Status 115 | 
ctx.Response.Msg = r.Msg 116 | ctx.Response.SetJSON(r.Data) 117 | } 118 | ctx.writeResponse() 119 | }() 120 | 121 | handlers := append([]TCPHandler{}, srv.handlerMiddlewares...) 122 | handler, ok := srv.routes[ctx.CMD] 123 | if ok { 124 | handlers = append(handlers, handler...) 125 | } 126 | ctx.handlers = handlers 127 | ctx.handlerIndex = -1 128 | 129 | for !ctx.isAbort && ctx.handlerIndex < len(ctx.handlers) { 130 | ctx.Next() 131 | } 132 | } 133 | 134 | func (srv *ServerSrv) receive() { 135 | ch := srv.Engine.Receive() 136 | for { 137 | req := <-ch 138 | go srv.handlerReceive(req) 139 | } 140 | } 141 | 142 | // 处理器中间件 143 | func (srv *ServerSrv) Use(h ...TCPHandler) { 144 | srv.handlerMiddlewares = append(srv.handlerMiddlewares, h...) 145 | } 146 | 147 | // Handle 注册 CMD 路由监听器 148 | func (srv *ServerSrv) Handle(cmd string, handlers ...TCPHandler) { 149 | h, ok := srv.routes[cmd] 150 | if !ok { 151 | h = handlers 152 | } else { 153 | h = append(h, handlers...) 154 | } 155 | srv.routes[cmd] = h 156 | } 157 | 158 | // Push 服务器推送消息到客户端 159 | func (srv *ServerSrv) Push(sid uint64, data *TCPResponse) error { 160 | conn, ok := srv.Engine.Sockets[sid] 161 | if !ok { 162 | return fmt.Errorf("sid=%d is invalid", sid) 163 | } 164 | if data.Seqno == "" { 165 | data.Seqno = util.RandomStr(8) 166 | } 167 | if data.Status == 0 && data.Msg == "" { 168 | data.Msg = "ok" 169 | } 170 | 171 | payload, err := json.Marshal(data) 172 | if err != nil { 173 | return err 174 | } 175 | return conn.Send(payload) 176 | } 177 | 178 | func (srv *ServerSrv) checkHeartbeat() { 179 | for { 180 | time.Sleep(time.Second) 181 | now := time.Now().Unix() - 30 182 | srv.heartbeat.Range(func(key, val interface{}) bool { 183 | sid, ok := key.(uint64) 184 | if !ok { 185 | return true 186 | } 187 | hbTime, ok := val.(int64) 188 | if !ok { 189 | return true 190 | } 191 | 192 | if hbTime < now { 193 | go srv.onClose(srv.Engine.Sockets[sid]) 194 | } 195 | return true 196 | }) 197 | } 198 | } 199 | func (srv 
*ServerSrv) Destroy(sid uint64) { 200 | conn, ok := srv.Engine.Sockets[sid] 201 | if !ok { 202 | return 203 | } 204 | srv.heartbeat.Delete(conn.ID) 205 | srv.sessionMu.Lock() 206 | delete(srv.Session, sid) 207 | srv.sessionMu.Unlock() 208 | } 209 | -------------------------------------------------------------------------------- /tcp/tcpsrv/server_test.go: -------------------------------------------------------------------------------- 1 | package tcpsrv 2 | 3 | import ( 4 | "fmt" 5 | "github.com/go-eyas/toolkit/log" 6 | "github.com/go-eyas/toolkit/tcp" 7 | "testing" 8 | ) 9 | 10 | func TestServerSrv(t *testing.T) { 11 | srv, err := NewServerSrv(&tcp.Config{ 12 | Addr: ":6601", 13 | Logger: log.SugaredLogger, 14 | }) 15 | 16 | if err != nil { 17 | panic(err) 18 | } 19 | 20 | srv.Use(func(c *Context) { 21 | fmt.Printf("TCP 收到 cmd=%s seqno=%s data=%s\n", c.CMD, c.Seqno, string(c.Payload)) 22 | c.Next() 23 | fmt.Printf("TCP 响应 cmd=%s seqno=%s data=%s\n", c.CMD, c.Seqno, string(c.Response.Data)) 24 | }) 25 | srv.Use(func(c *Context) { 26 | if c.CMD != "register" { 27 | _, ok := c.Get("uid").(int64) 28 | if !ok { 29 | c.Response.Msg = "this connection is not register" 30 | c.Response.Status = 401 31 | c.Abort() 32 | return 33 | } 34 | } 35 | c.Next() // 如后续无操作,可省略 36 | }) 37 | 38 | srv.Handle("register", func(c *Context) { 39 | body := &struct { 40 | UID int64 `json:"uid"` 41 | }{} 42 | err := c.Bind(body) 43 | if err != nil { 44 | panic(err) 45 | } 46 | c.Set("uid", body.UID) 47 | c.OK() 48 | c.Next() 49 | }) 50 | 51 | srv.Handle("userinfo", func(c *Context) { 52 | uid := c.Get("uid").(int64) 53 | c.OK(uid) 54 | }) 55 | 56 | c := make(chan bool, 0) 57 | <- c 58 | } -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | set -e 4 | 5 | go test -v -count=1 -timeout=80s ./amqp 6 | go test -v -count=1 -timeout=80s ./config 7 | go 
test -v -count=1 -timeout=80s ./db 8 | go test -v -count=1 -timeout=80s ./email 9 | go test -v -count=1 -timeout=80s ./emit 10 | go test -v -count=1 -timeout=80s ./http 11 | go test -v -count=1 -timeout=80s ./log 12 | go test -v -count=1 -timeout=80s ./redis 13 | #go test -v -count=1 -timeout=80s ./tcp 14 | # go test -v -count=1 -timeout=80s ./websocket -------------------------------------------------------------------------------- /types/json.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import "encoding/json" 4 | 5 | // JSONString holds raw JSON text; when marshaled it is embedded as the JSON value it contains (for example an object) instead of as a quoted string. 6 | // It therefore has to contain valid JSON. 7 | type JSONString string 8 | 9 | // MarshalJSON re-encodes the contained JSON. If s is not valid JSON the decode error is ignored and null is produced. 10 | func (s JSONString) MarshalJSON() ([]byte, error) { 11 | var data interface{} 12 | _ = json.Unmarshal([]byte(s), &data) // invalid JSON deliberately encodes as null 13 | return json.Marshal(data) 14 | } 15 | // JSON decodes the contained JSON into v, which must be a pointer. 16 | func (s JSONString) JSON(v interface{}) error { 17 | return json.Unmarshal([]byte(s), v) 18 | } 19 | 20 | // JSONObj is a generic JSON object. It marshals as a regular JSON object; use String to get its JSON text. 21 | type JSONObj map[string]interface{} 22 | // String encodes m as a JSONString; an encoding error is ignored and yields an empty string. 23 | func (m JSONObj) String() JSONString { 24 | b, _ := json.Marshal(m) 25 | return JSONString(b) 26 | } 27 | // JSON converts m into v (a pointer) by round-tripping through m's JSON text. 28 | func (m JSONObj) JSON(v interface{}) error { 29 | return m.String().JSON(v) 30 | } 31 | -------------------------------------------------------------------------------- /types/json_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | ) 7 | 8 | func TestJSONString(t *testing.T) { 9 | str := JSONString(`{"demo": true, "num": 123}`) 10 | data := struct { 11 | S JSONString 12 | }{str} 13 | raw, err := json.Marshal(data) 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | t.Logf("JSONString marshal: %s", string(raw)) 18 | if want := `{"S":{"demo":true,"num":123}}`; string(raw) != want { 19 | t.Fatalf("marshal: got %s, want %s", raw, want) 20 | } 21 | 22 | data2 := struct { 23 | Demo bool 24 | Num int 25 | }{} 26 | if err := str.JSON(&data2); err != nil { 27 | t.Fatal(err) 28 | } 29 | t.Logf("JSONString unmarshal: %#v", data2) 30 | if !data2.Demo || data2.Num != 123 { 31 | t.Fatalf("unmarshal: got %#v", data2) 32 | } 33 | } 34 | 35 | func TestJSONObj(t *testing.T) { 36 | obj := JSONObj{ 37 | "demo": true, 38 | "num": 123, 39 | } 40 | 41 | data1 := struct { 42 | Demo bool 43 | Num int 44 | }{} 45 | // JSON decodes into its argument, so it must be given a pointer 46 | if err := obj.JSON(&data1); err != nil { 47 | t.Fatal(err) 48 | } 49 | t.Logf("JSONObj to json: %#v", data1) 50 | if !data1.Demo || data1.Num != 123 { 51 | t.Fatalf("got %#v", data1) 52 | } 53 | } 54 |
-------------------------------------------------------------------------------- /types/readme.md: -------------------------------------------------------------------------------- 1 | # 黑魔法类型 2 | 3 | 一些带有特殊用途的类型 4 | 5 | ### JSONString 6 | 7 | json 字符串,该类型在转成json字符串的时候,会自动转成object类型,所以要求该类型的值是一个合法的json字符串 8 | 9 | 使用场景:扩展字段 10 | 11 | ```go 12 | import "github.com/go-eyas/toolkit/types" 13 | 14 | var str = types.JSONString(`{"demo": true, "num": 123}`) 15 | 16 | data := struct { 17 | S types.JSONString 18 | }{str} 19 | 20 | json.Marshal(data) // {"S":{"demo":true,"num":123}} 21 | 22 | data2 := struct{ 23 | Demo bool 24 | Num int 25 | }{} 26 | str.JSON(&data2) // 也可以直接 json 反序列化 27 | ``` 28 | 29 | ### Time 30 | 31 | 时间类型 `time.Time` 的别名,该类型在转成 json 字符串的时候,会把时间格式化成这种格式 2006-01-02 15:04:05 32 | 33 | 结合 gorm 使用,存在数据库的是时间类型,转到接口的是上述时间格式 -------------------------------------------------------------------------------- /types/time.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "time" 5 | 6 | "encoding/json" 7 | ) 8 | 9 | // Time is a time.Time that marshals to JSON as a "2006-01-02 15:04:05" string in its own location, without zone information. It defines no UnmarshalJSON, so decoding that string back into a Time fails. 10 | type Time time.Time 11 | // MarshalJSON formats tm with the "2006-01-02 15:04:05" layout and encodes the result as a JSON string. 12 | func (tm Time) MarshalJSON() ([]byte, error) { 13 | s := time.Time(tm).Format("2006-01-02 15:04:05") 14 | return json.Marshal(s) 15 | } 16 | -------------------------------------------------------------------------------- /types/time_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestTime(t *testing.T) { 10 | data := struct { 11 | TM Time 12 | }{Time(time.Now())} 13 | 14 | b, _ := json.Marshal(data) 15 | t.Logf("test time format: %s", string(b)) 16 | } 17 |
-------------------------------------------------------------------------------- /util/bytes.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | ) 8 | 9 | // ByteToReader returns an io.Reader over b. 10 | func ByteToReader(b []byte) io.Reader { 11 | return bytes.NewReader(b) 12 | } 13 | 14 | // ByteToReadCloser returns an io.ReadCloser over b whose Close is a no-op. 15 | func ByteToReadCloser(b []byte) io.ReadCloser { 16 | return ioutil.NopCloser(ByteToReader(b)) 17 | } 18 | // BytesCombine concatenates all given slices into a new slice. 19 | func BytesCombine(pBytes ...[]byte) []byte { 20 | return bytes.Join(pBytes, []byte("")) 21 | } 22 |
-------------------------------------------------------------------------------- /util/crypto.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | 8 | "golang.org/x/crypto/bcrypt" 9 | ) 10 | 11 | // BcryptHash hashes src with bcrypt at bcrypt.MinCost. If hashing fails the error is dropped and "" is returned, which BcryptVerify never accepts. 12 | func BcryptHash(src string) string { 13 | bt, _ := bcrypt.GenerateFromPassword([]byte(src), bcrypt.MinCost) 14 | return string(bt) 15 | } 16 | 17 | // BcryptVerify reports whether src matches the bcrypt hash. 18 | func BcryptVerify(hash, src string) bool { 19 | return bcrypt.CompareHashAndPassword([]byte(hash), []byte(src)) == nil 20 | } 21 | 22 | // AesEncrypt encrypts src with AES-CBC and PKCS#7 padding; key must be 16, 24 or 32 bytes long. src is not modified. 23 | // SECURITY: the IV is the first block of the key rather than a random value, so equal plaintexts yield equal ciphertexts; it is kept that way so existing ciphertexts still decrypt. 24 | func AesEncrypt(src, key []byte) ([]byte, error) { 25 | block, err := aes.NewCipher(key) 26 | if err != nil { 27 | return nil, err 28 | } 29 | bs := block.BlockSize() 30 | // _padding returns a fresh slice, so encrypting in place cannot touch the caller's array 31 | buf := _padding(src, bs) 32 | // CBC needs a one-block IV; key[:bs] equals key for 16-byte keys and stops 24/32-byte keys from panicking 33 | cipher.NewCBCEncrypter(block, key[:bs]).CryptBlocks(buf, buf) 34 | return buf, nil 35 | } 36 | 37 | // AesDecrypt reverses AesEncrypt. It returns nil when key is not a valid AES key, when src is empty or not a whole number of blocks, or when the decrypted padding is invalid (for example because the key is wrong). src is not modified. 38 | func AesDecrypt(src []byte, key []byte) []byte { 39 | block, err := aes.NewCipher(key) 40 | if err != nil { 41 | return nil 42 | } 43 | bs := block.BlockSize() 44 | if len(src) == 0 || len(src)%bs != 0 { 45 | return nil 46 | } 47 | out := make([]byte, len(src)) 48 | cipher.NewCBCDecrypter(block, key[:bs]).CryptBlocks(out, src) 49 | return _unpadding(out) 50 | } 51 | 52 | // _padding returns a copy of src with PKCS#7 padding up to a multiple of blocksize (a full block of padding when src is already aligned). 53 | func _padding(src []byte, blocksize int) []byte { 54 | padnum := blocksize - len(src)%blocksize 55 | out := make([]byte, len(src), len(src)+padnum) 56 | copy(out, src) 57 | return append(out, bytes.Repeat([]byte{byte(padnum)}, padnum)...) 58 | } 59 | 60 | // _unpadding strips PKCS#7 padding, returning nil for malformed padding instead of panicking on an out-of-range length. 61 | func _unpadding(src []byte) []byte { 62 | n := len(src) 63 | if n == 0 { 64 | return nil 65 | } 66 | unpadnum := int(src[n-1]) 67 | if unpadnum == 0 || unpadnum > n { 68 | return nil 69 | } 70 | for _, b := range src[n-unpadnum:] { 71 | if int(b) != unpadnum { 72 | return nil 73 | } 74 | } 75 | return src[:n-unpadnum] 76 | } 77 |
-------------------------------------------------------------------------------- /util/crypto_test.go: -------------------------------------------------------------------------------- 1 | package util_test 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/go-eyas/toolkit/util" 10 | ) 11 | 12 | func TestAesEncode(t *testing.T) { 13 | password := []byte("asdfghjkqwertyui") 14 | for i := 0; i < 100; i++ { 15 | plain := []byte(fmt.Sprintf("%d", i)) 16 | result, err := util.AesEncrypt(plain, password) 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | t.Log(base64.StdEncoding.EncodeToString(result)) 21 | if got := util.AesDecrypt(result, password); !bytes.Equal(got, plain) { 22 | t.Fatalf("round trip of %q: got %q", plain, got) 23 | } 24 | } 25 | } 26 | 27 | func TestAesKeySizesAndInputIsolation(t *testing.T) { 28 | for _, size := range []int{16, 24, 32} { 29 | key := bytes.Repeat([]byte{'k'}, size) 30 | // spare capacity would let a careless append write into the caller's array 31 | plain := make([]byte, 5, 64) 32 | copy(plain, "hello") 33 | enc, err := util.AesEncrypt(plain, key) 34 | if err != nil { 35 | t.Fatalf("key size %d: %v", size, err) 36 | } 37 | if string(plain) != "hello" { 38 | t.Fatalf("key size %d: plaintext modified to %q", size, plain) 39 | } 40 | if got := util.AesDecrypt(enc, key); string(got) != "hello" { 41 | t.Fatalf("key size %d: round trip got %q", size, got) 42 | } 43 | } 44 | if got := util.AesDecrypt([]byte("short"), []byte("asdfghjkqwertyui")); got != nil { 45 | t.Fatalf("misaligned ciphertext: got %q, want nil", got) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /util/readme.md: -------------------------------------------------------------------------------- 1 | # 工具函数 2 | 3 | ## godoc 4 | 5 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/util)
-------------------------------------------------------------------------------- /util/resp.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | // ErrorData is the normalized error/response shape produced by ParseError. 11 | type ErrorData struct { 12 | Code int `json:"-"` // HTTP status code; not serialized 13 | Status int `json:"status"` 14 | Msg string `json:"msg"` 15 | Data interface{} `json:"data"` 16 | } 17 | 18 | // ParseError normalizes v (typically a value recovered from a panic) into an ErrorData: 19 | // - error: Code 500, Status 999999999, Msg err.Error(), Data {} 20 | // - string: Code 200, Status 0, Msg "ok", Data the string 21 | // - gin.H, *gin.H, map[string]interface{}: read from the "code", "status", "msg" and "data" keys (see errorDataFromMap) 22 | // - ErrorData, *ErrorData: used as is (a nil *ErrorData yields the defaults) 23 | // - anything else: Code 200, Status 0, Msg "ok", Data v 24 | // It never panics on ill-typed map values, since it is called from inside recover handlers. 25 | func ParseError(v interface{}) *ErrorData { 26 | data := &ErrorData{ 27 | Code: http.StatusOK, 28 | Msg: "ok", 29 | Status: 0, 30 | } 31 | switch val := v.(type) { 32 | case error: 33 | data.Code = http.StatusInternalServerError 34 | data.Msg = val.Error() 35 | data.Status = 999999999 36 | data.Data = gin.H{} 37 | 38 | case string: 39 | data.Data = val 40 | 41 | case gin.H: 42 | data = errorDataFromMap(val) 43 | case map[string]interface{}: 44 | data = errorDataFromMap(val) 45 | case *gin.H: 46 | if val != nil { 47 | data = errorDataFromMap(*val) 48 | } 49 | 50 | case ErrorData: 51 | data = &val 52 | case *ErrorData: 53 | if val != nil { 54 | data = val 55 | } 56 | default: 57 | data.Data = v 58 | } 59 | 60 | return data 61 | } 62 | 63 | // errorDataFromMap reads an ErrorData from the "code", "status", "msg" and "data" keys of e. Missing or ill-typed values fall back to 200, 0, "ok" and an empty gin.H; an error msg uses its Error text and any other non-string msg is formatted with fmt.Sprint. 64 | func errorDataFromMap(e gin.H) *ErrorData { 65 | data := &ErrorData{ 66 | Code: http.StatusOK, 67 | Status: 0, 68 | Msg: "ok", 69 | Data: gin.H{}, 70 | } 71 | if code, ok := intValue(e["code"]); ok { 72 | data.Code = code 73 | } 74 | if status, ok := intValue(e["status"]); ok { 75 | data.Status = status 76 | } 77 | switch msg := e["msg"].(type) { 78 | case nil: 79 | case string: 80 | data.Msg = msg 81 | case error: 82 | data.Msg = msg.Error() 83 | default: 84 | data.Msg = fmt.Sprint(msg) 85 | } 86 | if d := e["data"]; d != nil { 87 | data.Data = d 88 | } 89 | return data 90 | } 91 | 92 | // intValue converts the integer representations that commonly end up in such maps (including float64 from encoding/json) to int. 93 | func intValue(v interface{}) (int, bool) { 94 | switch n := v.(type) { 95 | case int: 96 | return n, true 97 | case int32: 98 | return int(n), true 99 | case int64: 100 | return int(n), true 101 | case float64: 102 | return int(n), true 103 | } 104 | return 0, false 105 | } 106 |
-------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "math/rand" 7 | "os" 8 | "reflect" 9 | "runtime" 10 | 11 | "github.com/rs/xid" 12 | ) 13 | 14 | // Assert panics with msg when err is not nil. 15 | func Assert(err error, msg interface{}) { 16 | if err != nil { 17 | panic(msg) 18 | } 19 | } 20 | 21 | var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$@") 22 | 23 | // RandomStr returns length characters drawn from letterRunes with math/rand; it is not suitable for secrets or tokens. 24 | func RandomStr(length int) string { 25 | var lenthLetter = len(letterRunes) 26 | 27 | b := make([]rune, length) 28 | for i := range b { 29 | b[i] = letterRunes[rand.Intn(lenthLetter)] 30 | } 31 | return string(b) 32 | } 33 | 34 | // Base64Encoding returns the standard, padded base64 encoding of str. 35 | func Base64Encoding(str string) string { 36 | encoded := base64.StdEncoding.EncodeToString([]byte(str)) 37 | return encoded 38 | } 39 | 40 | // Base64Decoding decodes the standard base64 string enc. 41 | func Base64Decoding(enc string) (string, error) { 42 | decoded, err := base64.StdEncoding.DecodeString(enc) 43 | if err != nil { 44 | return "", err 45 | } 46 | return string(decoded), nil 47 | } 48 | 49 | // FuncName returns the fully qualified name of f, which is expected to be a function value. 50 | func FuncName(f interface{}) string { 51 | return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() 52 | } 53 | 54 | // AssignMap merges maps into a new map; for duplicate keys the later map wins. 55 | func AssignMap(maps ...map[interface{}]interface{}) map[interface{}]interface{} { 56 | m := map[interface{}]interface{}{} 57 | 58 | for _, mp := range maps { 59 | for key, val := range mp { 60 | m[key] = val 61 | } 62 | } 63 | 64 | return m 65 | } 66 | 67 | // ToString returns the JSON encoding of v, or an empty string if v cannot be encoded. 68 | func ToString(v interface{}) string { 69 | bt, _ := json.Marshal(v) 70 | return string(bt) 71 | } 72 | 73 | // HasFile reports whether f exists; note that any Stat error other than "not exist" (e.g. permission denied) also reports true. 74 | func HasFile(f string) bool { 75 | if _, err := os.Stat(f); !os.IsNotExist(err) { 76 | return true 77 | } 78 | return false 79 | } 80 | 81 | // StructToMap converts v into a map keyed by its JSON field names; encoding errors are ignored and yield an empty map. 82 | func StructToMap(v interface{}) map[string]interface{} { 83 | data := map[string]interface{}{} 84 | bt, _ := json.Marshal(v) 85 | _ = json.Unmarshal(bt, &data) 86 | return data 87 | } 88 | 89 | // ToStruct decodes raw into v through JSON: a string or []byte is taken as JSON text, anything else is marshaled first. 90 | func ToStruct(raw interface{}, v interface{}) error { 91 | var err error 92 | var bt []byte 93 | if sraw, ok := raw.(string); ok { 94 | bt = []byte(sraw) 95 | } else if braw, ok := raw.([]byte); ok { 96 | bt = braw 97 | } else { 98 | bt, err = json.Marshal(raw) 99 | if err != nil { 100 | return err 101 | } 102 | } 103 | return json.Unmarshal(bt, v) 104 | } 105 | 106 | // XID returns a new globally unique id string generated by github.com/rs/xid. 107 | // xid ids are 20 characters long and roughly sortable by creation time; they are not UUIDs. 108 | func XID() string { 109 | return xid.New().String() 110 | } 111 |
-------------------------------------------------------------------------------- 1 | package util 2 | -------------------------------------------------------------------------------- /websocket/conn.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "github.com/gorilla/websocket" 5 | ) 6 | 7 | // Conn 连接实例 8 | type Conn struct { 9 | Socket *websocket.Conn // 连接 10 | ws *WS // ws 服务 11 | isClose bool 12 | ID uint64 13 | } 14 | 15 | // Init 初始化该连接 16 | func (c *Conn) Init() { 17 | c.reader() 18 | } 19 | 20 | func (c *Conn) reader() error { 21 | for { 22 | mType, mRaw, err := c.Socket.ReadMessage() 23 | if err != nil { 24 | return err 25 | } 26 | logger.Infof("websocket: receive data=%s", string(mRaw)) 27 | msg := &Message{c.ID, mRaw, c.ws, c, mType} 28 | c.ws.recC <- msg 29 | } 30 | } 31 | 32 | // Send 往该连接发送数据 33 | func (c *Conn) Send(msg *Message) error { 34 | m := &(*msg) 35 | m.ws = c.ws 36 | m.Socket = c 37 | return m.writer() 38 | } 39 | 40 | // Destroy 销毁该连接 41 | func (c *Conn) Destroy() error { 42 | if c.isClose { 43 | return nil 44 | } 45 | err := c.Socket.Close() 46 | if err != nil { 47 | return err 48 | } 49 | c.isClose = true 50 | 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /websocket/logger.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | type LoggerI interface { 4 | Info(...interface{}) 5 | Infof(string, ...interface{}) 6 | Error(...interface{}) 7 | Errorf(string, ...interface{}) 8 | } 9 | 10 | type l struct{} 11 | 12 | func (l) Info(v ...interface{}) {} 13 | func (l) Infof(s string, v ...interface{}) {} 14 | func (l) Error(v ...interface{}) {} 15 | func (l) Errorf(s string, v ...interface{}) {} 16 | 17 | var EmptyLogger = &l{} 18 | 19 | var logger LoggerI = EmptyLogger 20 | 
-------------------------------------------------------------------------------- /websocket/message.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | // Message ws 接收到的消息 8 | type Message struct { 9 | SID uint64 10 | Payload []byte 11 | ws *WS 12 | Socket *Conn 13 | MsgType int 14 | } 15 | 16 | func (m Message) clone() *Message { 17 | m1 := m 18 | return &m1 19 | } 20 | 21 | func (m *Message) writer() error { 22 | if m.Socket.isClose { 23 | return errors.New("socket is close") 24 | } 25 | err := m.Socket.Socket.WriteMessage(m.MsgType, m.Payload) 26 | if err != nil { 27 | return err 28 | } 29 | return nil 30 | } 31 | 32 | // Response 在发送本消息的当前连接发送数据 33 | func (m *Message) Response(v []byte) error { 34 | resMsg := m.clone() 35 | resMsg.Payload = v 36 | return resMsg.writer() 37 | } 38 | -------------------------------------------------------------------------------- /websocket/readme.md: -------------------------------------------------------------------------------- 1 | # websocket 2 | 3 | # 使用 4 | 5 | ```go 6 | import ( 7 | "net/http" 8 | "github.com/go-eyas/toolkit/websocket" 9 | ) 10 | func main() { 11 | ws := websocket.New(&Config{ 12 | MsgType: websocket.TextMessage, // 消息类型 websocket.TextMessage | websocke.BinaryMessage 13 | }) 14 | 15 | http.HandleFunc("/ws", ws.HTTPHandler) 16 | 17 | go func() { 18 | rec := ws.Receive() 19 | for { 20 | req, _ := <-rec 21 | req.Response([]byte("1234556")) 22 | } 23 | }() 24 | 25 | http.ListenAndServe("127.0.0.1:8800", nil) 26 | } 27 | ``` 28 | 29 | # 服务 30 | 31 | 已经准备了一个开箱即用的服务,该服务按照特定协议工作,[详情请查看](./wsrv) 32 | 33 | 示例概览 34 | 35 | ```go 36 | import ( 37 | "net/http" 38 | "github.com/go-eyas/toolkit/websocket" 39 | "github.com/go-eyas/toolkit/websocket/wsrv" 40 | ) 41 | func main() { 42 | server := wsrv.New(&Config{ 43 | MsgType: websocket.TextMessage, // 消息类型 websocket.TextMessage | websocke.BinaryMessage 44 | }) 45 | 
server.Use(func(c *wsrv.Context) { 46 | if c.CMD != "register" { 47 | _, ok := c.Get("uid").(int) 48 | if !ok { 49 | c.Abort() 50 | } 51 | } 52 | }) 53 | 54 | server.Handle("register", func(c *wsrv.Context) { 55 | c.Set("uid", 1001) 56 | c.OK() 57 | }) 58 | server,Handle("userinfo", func(c *wsrv.Context) { 59 | uid := c.Get("uid").(int) 60 | c.OK(GetUserInfoByID(uid)) 61 | }) 62 | 63 | http.HandleFunc("/ws", server.Engine.HTTPHandler) 64 | http.ListenAndServe("127.0.0.1:8800", nil) 65 | } 66 | ``` 67 | 68 | ## 协议 69 | 70 | 71 | 72 | ## godoc 73 | 74 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/websocket) -------------------------------------------------------------------------------- /websocket/websocket.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "html/template" 5 | "net/http" 6 | "sync" 7 | 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | const ( 12 | // 文本消息 13 | TextMessage = 1 14 | 15 | // 二进制数据消息 16 | BinaryMessage = 2 17 | ) 18 | 19 | // Config 配置项 20 | type Config struct { 21 | MsgType int // 消息类型 TextMessage | BinaryMessage 22 | ReadBufferSize int // 读取缓存大小 23 | WriteBufferSize int // 写入缓存大小 24 | CheckOrigin func(*http.Request) bool // 检查跨域来源是否允许建立连接 25 | Logger LoggerI // 用于打印内部产生日志 26 | } 27 | 28 | // New 新建 websocket 服务 29 | func New(conf *Config) *WS { 30 | if conf.MsgType == 0 { 31 | conf.MsgType = websocket.TextMessage 32 | } 33 | if conf.CheckOrigin == nil { 34 | conf.CheckOrigin = func(r *http.Request) bool { return true } 35 | } 36 | 37 | ws := &WS{ 38 | MsgType: conf.MsgType, 39 | Clients: make(map[uint64]*Conn), 40 | recC: make(chan *Message, 1024), 41 | logger: conf.Logger, 42 | createHandlers: make([]EventHandle, 0), 43 | closeHandlers: make([]EventHandle, 0), 44 | } 45 | 46 | if ws.logger == nil { 47 | ws.logger = EmptyLogger 48 | } 49 | 50 | ws.Upgrader = &websocket.Upgrader{ 51 | ReadBufferSize: conf.ReadBufferSize, 52 | WriteBufferSize: 
conf.WriteBufferSize, 53 | CheckOrigin: conf.CheckOrigin, 54 | } 55 | ws.logger.Info("websocket: init websocket") 56 | 57 | return ws 58 | } 59 | 60 | type EventHandle func(*Conn) 61 | 62 | // WS ws 连接 63 | type WS struct { 64 | Clients map[uint64]*Conn 65 | Upgrader *websocket.Upgrader 66 | id uint64 67 | MsgType int 68 | recC chan *Message 69 | logger LoggerI 70 | createHandlers []EventHandle 71 | closeHandlers []EventHandle 72 | } 73 | 74 | var connMu sync.RWMutex 75 | 76 | // HTTPHandler 给 http 控制器绑定使用 77 | func (ws *WS) HTTPHandler(w http.ResponseWriter, r *http.Request) { 78 | conn, err := ws.Connect(w, r) 79 | if err != nil { 80 | return 81 | } 82 | conn.Init() 83 | defer ws.DestroyConn(ws.id) 84 | } 85 | 86 | // 从 http 连接获取连接实例 87 | func (ws *WS) Connect(w http.ResponseWriter, r *http.Request) (*Conn, error) { 88 | socket, err := ws.Upgrader.Upgrade(w, r, nil) 89 | if err != nil { 90 | return nil, err 91 | } 92 | ws.id++ 93 | 94 | conn := &Conn{ 95 | Socket: socket, 96 | ws: ws, 97 | ID: ws.id, 98 | } 99 | 100 | ws.logger.Infof("websocket: new websocket connect create: sid=%d", conn.ID) 101 | 102 | connMu.Lock() 103 | ws.Clients[conn.ID] = conn 104 | connMu.Unlock() 105 | 106 | // send init message 107 | for _, createH := range ws.createHandlers { 108 | createH(conn) 109 | } 110 | 111 | return conn, nil 112 | } 113 | 114 | var page = template.Must(template.New("").Parse(` 115 | 116 | 117 | 118 | 119 | 167 | 168 | 169 | 170 |
171 |

Click "Open" to create a connection to the server, 172 | "Send" to send a message to the server and "Close" to close the connection. 173 | You can change the message and send multiple times. 174 |

175 |

176 |

addr: 177 | 178 |

179 | 180 | 181 |

182 | 183 |

184 |
185 |
186 |
187 | 188 | 189 | `)) 190 | 191 | func (ws *WS) Playground(w http.ResponseWriter, r *http.Request) { 192 | w.Header().Add("Content-Type", "text/html") 193 | err := page.Execute(w, "ws://"+r.Host+"/ws") 194 | if err != nil { 195 | panic(err) 196 | } 197 | } 198 | 199 | // Receive 获取接收数据的 chan 200 | func (ws *WS) Receive() <-chan *Message { 201 | return ws.recC 202 | } 203 | 204 | // Send 发送数据 205 | func (ws *WS) Send(msg *Message) error { 206 | m := &(*msg) 207 | m.ws = ws 208 | return m.writer() 209 | } 210 | 211 | func (ws *WS) HandleClose(fn EventHandle) { 212 | ws.closeHandlers = append(ws.closeHandlers, fn) 213 | } 214 | 215 | func (ws *WS) HandleCreate(fn EventHandle) { 216 | ws.createHandlers = append(ws.createHandlers, fn) 217 | } 218 | 219 | // DestroyConn 销毁连接 220 | func (ws *WS) DestroyConn(cid uint64) { 221 | conn, ok := ws.Clients[ws.id] 222 | if !ok { 223 | return 224 | } 225 | for _, closeH := range ws.closeHandlers { 226 | closeH(conn) 227 | } 228 | 229 | conn.Destroy() 230 | 231 | connMu.Lock() 232 | delete(ws.Clients, cid) 233 | connMu.Unlock() 234 | ws.logger.Info("websocket: destroy ws connect") 235 | } 236 | -------------------------------------------------------------------------------- /websocket/websocket_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | 8 | func TestWS(t *testing.T) { 9 | ws := New(&Config{ 10 | Logger: logger, 11 | }) 12 | 13 | http.HandleFunc("/ws", ws.HTTPHandler) 14 | http.HandleFunc("/", ws.Playground) 15 | 16 | go func() { 17 | rec := ws.Receive() 18 | for { 19 | req, _ := <-rec 20 | req.Response([]byte("1234556")) 21 | } 22 | }() 23 | 24 | // 浏览器打开 http://127.0.0.1:8800 测试 25 | t.Fatal(http.ListenAndServe("127.0.0.1:8800", nil)) 26 | } 27 | -------------------------------------------------------------------------------- /websocket/wsrv/context.go: 
-------------------------------------------------------------------------------- 1 | package wsrv 2 | 3 | import ( 4 | "encoding/json" 5 | "sync" 6 | 7 | "github.com/go-eyas/toolkit/gin/util" 8 | "github.com/go-eyas/toolkit/websocket" 9 | "github.com/go-playground/validator/v10" 10 | ) 11 | 12 | var validate = validator.New() 13 | 14 | // WSRequest 请求数据 15 | type WSRequest struct { 16 | CMD string `json:"cmd"` 17 | Seqno string `json:"seqno"` 18 | Data json.RawMessage `json:"data"` 19 | } 20 | 21 | // WSResponse 响应数据 22 | type WSResponse struct { 23 | CMD string `json:"cmd"` 24 | Seqno string `json:"seqno"` 25 | Status int `json:"status"` 26 | Msg string `json:"msg"` 27 | Data interface{} `json:"data"` 28 | } 29 | 30 | // Context 请求上下文 31 | type Context struct { 32 | Values map[string]interface{} // 该会话注册的值 33 | valMu sync.RWMutex 34 | CMD string // 命令名称 35 | Seqno string // 请求唯一标识符 36 | RawData json.RawMessage // 请求原始数据 data 37 | SessionID uint64 // 会话ID 38 | Socket *websocket.Conn // 长连接对象 39 | RawMessage *websocket.Message // 原始消息对象 40 | Engine *websocket.WS // 引擎 41 | Server *WebsocketServer // 服务器对象 42 | Payload []byte // 请求原始消息报文 43 | Request *WSRequest // 已解析的请求数据 44 | Response *WSResponse // 响应数据 45 | logger websocket.LoggerI 46 | handlers []WSHandler // 当前请求上下文的处理器 47 | handlerIndex int // 当前中间件处理 48 | isAbort bool // 是否已停止继续执行中间件和处理函数 49 | sendMu sync.Mutex 50 | } 51 | 52 | // OK 响应成功数据 53 | func (c *Context) OK(args ...interface{}) { 54 | c.Response.Status = util.CodeSuccess 55 | c.Response.Msg = "ok" 56 | 57 | if len(args) > 0 { 58 | c.Response.Data = args[0] 59 | } 60 | } 61 | 62 | // Bind 解析并 JSON 绑定 data 数据到结构体,并验证数据正确性 63 | func (c *Context) Bind(v interface{}) error { 64 | err := json.Unmarshal(c.Request.Data, v) 65 | if err != nil { 66 | return err 67 | } 68 | return validate.Struct(v) 69 | } 70 | 71 | // Get 获取会话的值 72 | func (c *Context) Get(key string) interface{} { 73 | c.valMu.RLock() 74 | defer c.valMu.RUnlock() 75 | return c.Values[key] 76 | 
} 77 | 78 | // Set 设置会话的上下文的值,注意设置的值在整个会话生效,不仅仅在本次上下文请求而已 79 | func (c *Context) Set(key string, v interface{}) { 80 | c.valMu.Lock() 81 | c.Values[key] = v 82 | c.valMu.Unlock() 83 | } 84 | 85 | // Abort 停止后面的处理函数和中间件执行 86 | func (c *Context) Abort() { 87 | c.isAbort = true 88 | } 89 | 90 | // Push 服务器主动推送消息至该连接的客户端 91 | func (c *Context) Push(data *WSResponse) error { 92 | return c.Server.Push(c.SessionID, data) 93 | } 94 | 95 | // 调用下个中间件 96 | func (c *Context) Next() { 97 | if c.isAbort { 98 | return 99 | } 100 | if c.handlerIndex < len(c.handlers)-1 { 101 | c.handlerIndex++ 102 | c.handlers[c.handlerIndex](c) 103 | } else { 104 | c.Abort() 105 | } 106 | } 107 | 108 | func (c *Context) writeResponse() error { 109 | c.sendMu.Lock() 110 | defer c.sendMu.Unlock() 111 | payload, err := json.Marshal(c.Response) 112 | if err != nil { 113 | return err 114 | } 115 | c.logger.Infof("[WS] --> SEND CMD=%s data=%s", c.CMD, string(payload)) 116 | return c.RawMessage.Response(payload) 117 | } 118 | -------------------------------------------------------------------------------- /websocket/wsrv/readme.md: -------------------------------------------------------------------------------- 1 | # websocket 服务 2 | 3 | 开箱即用的 websocket 服务 4 | 5 | ## 协议 6 | 7 | #### 心跳 8 | 9 | 心跳包为长度为 0 的空数据包,最长时间 30s 发一次,否则链接将会被断开 10 | 11 | #### 请求响应数据 12 | 13 | 请求和响应的数据必须按照该协议来 14 | 15 | **请求数据** 16 | 17 | ```json 18 | { 19 | "cmd": "register", 20 | "seqno": "unique string", 21 | "data": {} 22 | } 23 | ``` 24 | 25 | * cmd 命令名称 26 | * seqno 请求标识符 27 | * data 请求数据 28 | 29 | 30 | **响应数据** 31 | 32 | ```json 33 | { 34 | "cmd": "register", 35 | "seqno": "unique string", 36 | "msg": "ok", 37 | "status": 0, 38 | "data": {} 39 | } 40 | ``` 41 | 42 | * cmd 命令名称,原样返回 43 | * seqno 请求标识符,原样返回 44 | * msg 处理后的消息,如果消息是处理成功的,默认都是 ok 45 | * status 错误状态码,0 为成功,非 0 为失败 46 | * data 响应数据 47 | 48 | ## 使用 49 | 50 | 示例概览 51 | 52 | ```go 53 | package main 54 | 55 | import ( 56 | "net/http" 57 | 
"github.com/go-eyas/toolkit/log" 58 | "github.com/go-eyas/toolkit/websocket" 59 | "github.com/go-eyas/toolkit/websocket/wsrv" 60 | ) 61 | func main() { 62 | server := wsrv.New(&websocket.Config{ 63 | MsgType: websocket.TextMessage, // 消息类型 websocket.TextMessage | websocke.BinaryMessage 64 | }) 65 | 66 | server.Use(func(c *wsrv.Context) { 67 | log.Debugf("ws request middleware, sid=%d, cmd=%s, data=%s", c.SessionID, c.CMD, string(c.Request.Data)) 68 | c.Next() 69 | log.Debugf("ws response middleware, sid=%d cmd=%s, data=%v", c.SessionID, c.CMD, c.Response.Data) 70 | }) 71 | server.Use(func(c *wsrv.Context) { 72 | if c.CMD != "register" { 73 | _, ok := c.Get("uid").(int) 74 | if !ok { 75 | 76 | c.Abort() 77 | } 78 | } 79 | }) 80 | 81 | server.Handle("register", func(c *wsrv.Context) { 82 | body := &struct { 83 | UID int64 84 | }{} 85 | err := c.Bind(body) 86 | if err != nil { 87 | panic(err) 88 | } 89 | c.Set("uid", body.UID) 90 | 91 | // server push 92 | for sid, vals := range server.Session { 93 | if uid, ok := vals["uid"]; ok { 94 | server.Push(sid, &wsrv.WSResponse{ 95 | CMD: "have_user_register", 96 | Data: map[string]interface{}{ 97 | "uid": uid, 98 | }, 99 | }) 100 | } 101 | } 102 | 103 | // server push current connection 104 | c.Push(&wsrv.WSResponse{ 105 | CMD: "user_register", 106 | Data: map[string]interface{}{ 107 | "uid": body.UID, 108 | }, 109 | }) 110 | 111 | c.OK() 112 | }) 113 | server.Handle("userinfo", func(c *wsrv.Context) { 114 | uid := c.Get("uid").(int) 115 | c.OK(GetUserInfoByID(uid)) 116 | }) 117 | 118 | http.HandleFunc("/ws", server.Engine.HTTPHandler) 119 | http.ListenAndServe("127.0.0.1:8800", nil) 120 | } 121 | ``` 122 | 123 | ## API 124 | 125 | [API 文档](https://gowalker.org/github.com/go-eyas/toolkit/websocket/wsrv) -------------------------------------------------------------------------------- /websocket/wsrv/srv.go: -------------------------------------------------------------------------------- 1 | package wsrv 2 | 3 | import ( 4 | 
"encoding/json" 5 | "fmt" 6 | "runtime/debug" 7 | "sync" 8 | "time" 9 | 10 | "github.com/go-eyas/toolkit/util" 11 | "github.com/go-eyas/toolkit/websocket" 12 | ) 13 | 14 | // WebsocketServer 服务器 15 | type WebsocketServer struct { 16 | Engine *websocket.WS 17 | Config *websocket.Config 18 | logger websocket.LoggerI 19 | routes map[string][]WSHandler // 路由 20 | Session map[uint64]map[string]interface{} // map[sid]SessionData 21 | heartbeat *sync.Map // 心跳 22 | // requestMiddlewares []WSHandler // 请求中间件 23 | // responseMiddlewares []WSHandler // 响应中间件 24 | handlerMiddlewares []WSHandler // 中间件 25 | } 26 | 27 | // WSHandler 请求处理器 28 | type WSHandler func(*Context) 29 | 30 | var sessionMu sync.Mutex 31 | 32 | // New 新建服务器实例 33 | func New(conf *websocket.Config) *WebsocketServer { 34 | if conf.Logger == nil { 35 | conf.Logger = websocket.EmptyLogger 36 | } 37 | if conf.MsgType == 0 { 38 | conf.MsgType = websocket.BinaryMessage 39 | } 40 | ws := websocket.New(conf) 41 | server := &WebsocketServer{ 42 | Config: conf, 43 | logger: conf.Logger, 44 | Engine: ws, 45 | routes: make(map[string][]WSHandler), 46 | Session: make(map[uint64]map[string]interface{}), 47 | heartbeat: &sync.Map{}, 48 | handlerMiddlewares: make([]WSHandler, 0), 49 | // requestMiddlewares: make([]WSHandler, 0), 50 | // responseMiddlewares: make([]WSHandler, 0), 51 | } 52 | 53 | ws.HandleClose(server.onClose) 54 | ws.HandleCreate(server.onCreate) 55 | 56 | go server.receive() 57 | go server.checkHeartbeat() 58 | 59 | return server 60 | } 61 | 62 | func (ws *WebsocketServer) receive() { 63 | ch := ws.Engine.Receive() 64 | for { 65 | req := <-ch 66 | go func(req *websocket.Message) { 67 | ws.heartbeat.Store(req.SID, time.Now().Unix()) 68 | 69 | // 心跳包 70 | if len(req.Payload) == 0 { 71 | return 72 | } 73 | 74 | ctx := &Context{ 75 | SessionID: req.SID, 76 | Socket: req.Socket, 77 | RawMessage: req, 78 | Engine: ws.Engine, 79 | Payload: req.Payload, 80 | Request: &WSRequest{}, 81 | Server: ws, 82 | logger: 
ws.logger, 83 | } 84 | sessionMu.Lock() 85 | vals, ok := ws.Session[req.SID] 86 | if !ok { 87 | ws.Session[req.SID] = make(map[string]interface{}) 88 | vals = ws.Session[req.SID] 89 | } 90 | sessionMu.Unlock() 91 | ctx.Values = vals 92 | ctx.logger.Infof("[WS] <-- RECV CMD=%s data=%s", ctx.CMD, string(ctx.Payload)) 93 | 94 | err := json.Unmarshal(ctx.Payload, ctx.Request) 95 | if err != nil { 96 | ctx.logger.Errorf("WS request json parse error: %v", err) 97 | return 98 | } 99 | ctx.CMD = ctx.Request.CMD 100 | ctx.Seqno = ctx.Request.Seqno 101 | ctx.Response = &WSResponse{ 102 | CMD: ctx.Request.CMD, 103 | Seqno: ctx.Request.Seqno, 104 | Status: -1, 105 | Msg: "not implement", 106 | Data: map[string]interface{}{}, 107 | } 108 | 109 | defer func() { 110 | if err := recover(); err != nil { 111 | ws.logger.Errorf("%v", err) 112 | debug.PrintStack() 113 | r := util.ParseError(err) 114 | ctx.Response.Status = r.Status 115 | ctx.Response.Msg = r.Msg 116 | ctx.Response.Data = r.Data 117 | } 118 | ctx.writeResponse() 119 | }() 120 | 121 | handlers := append([]WSHandler{}, ws.handlerMiddlewares...) 122 | handler, ok := ws.routes[ctx.CMD] 123 | if ok { 124 | handlers = append(handlers, handler...) 
125 | } 126 | ctx.handlers = handlers 127 | ctx.handlerIndex = -1 128 | 129 | for !ctx.isAbort && ctx.handlerIndex < len(ctx.handlers) { 130 | ctx.Next() 131 | } 132 | 133 | // for _, mdl := range ws.requestMiddlewares { 134 | // mdl(ctx) 135 | // if ctx.isAbort { 136 | // break 137 | // } 138 | // } 139 | // if !ctx.isAbort { 140 | // handler, ok := ws.routes[ctx.CMD] 141 | // if !ok { 142 | // return 143 | // } else { 144 | // ctx.Response.Status = 1 145 | // ctx.Response.Msg = "empty implement" 146 | // } 147 | // if !ctx.isAbort { 148 | // for _, h := range handler { 149 | // h(ctx) 150 | // if ctx.isAbort { 151 | // break 152 | // } 153 | // } 154 | // } 155 | // } 156 | // 157 | // for _, mdl := range ws.responseMiddlewares { 158 | // mdl(ctx) 159 | // if ctx.isAbort { 160 | // break 161 | // } 162 | // } 163 | }(req) 164 | } 165 | } 166 | 167 | // // UseRequest 请求中间件 168 | // func (ws *WebsocketServer) UseRequest(h WSHandler) { 169 | // ws.requestMiddlewares = append(ws.requestMiddlewares, h) 170 | // } 171 | // 172 | // // UseResponse 响应中间件 173 | // func (ws *WebsocketServer) UseResponse(h WSHandler) { 174 | // ws.responseMiddlewares = append(ws.responseMiddlewares, h) 175 | // } 176 | 177 | // 处理器中间件 178 | func (srv *WebsocketServer) Use(h ...WSHandler) { 179 | srv.handlerMiddlewares = append(srv.handlerMiddlewares, h...) 180 | } 181 | 182 | // Handle 注册 CMD 路由监听器 183 | func (ws *WebsocketServer) Handle(cmd string, handlers ...WSHandler) { 184 | h, ok := ws.routes[cmd] 185 | if !ok { 186 | h = handlers 187 | } else { 188 | h = append(h, handlers...) 
189 | } 190 | ws.routes[cmd] = h 191 | } 192 | 193 | // Push 服务器推送消息到客户端 194 | func (ws *WebsocketServer) Push(sid uint64, data *WSResponse) error { 195 | conn, ok := ws.Engine.Clients[sid] 196 | if !ok { 197 | return fmt.Errorf("sid=%d is invalid", sid) 198 | } 199 | if data.Seqno == "" { 200 | data.Seqno = util.RandomStr(8) 201 | } 202 | if data.Status == 0 && data.Msg == "" { 203 | data.Msg = "ok" 204 | } 205 | 206 | payload, err := json.Marshal(data) 207 | if err != nil { 208 | return err 209 | } 210 | return conn.Send(&websocket.Message{ 211 | SID: sid, 212 | Payload: payload, 213 | Socket: conn, 214 | MsgType: ws.Config.MsgType, 215 | }) 216 | } 217 | 218 | // Destroy 销毁清理连接 219 | func (ws *WebsocketServer) Destroy(sid uint64) { 220 | conn, ok := ws.Engine.Clients[sid] 221 | if ok { 222 | conn.Destroy() 223 | } 224 | ws.heartbeat.Delete(sid) 225 | sessionMu.Lock() 226 | delete(ws.Session, sid) 227 | sessionMu.Unlock() 228 | } 229 | 230 | func (ws *WebsocketServer) onCreate(conn *websocket.Conn) { 231 | sid := conn.ID 232 | ws.heartbeat.Store(sid, time.Now().Unix()) 233 | sessionMu.Lock() 234 | ws.Session[sid] = make(map[string]interface{}) 235 | sessionMu.Unlock() 236 | } 237 | 238 | func (ws *WebsocketServer) onClose(conn *websocket.Conn) { 239 | sid := conn.ID 240 | ws.Destroy(sid) 241 | } 242 | 243 | func (ws *WebsocketServer) checkHeartbeat() { 244 | for { 245 | time.Sleep(time.Second) 246 | now := time.Now().Unix() - 30 247 | ws.heartbeat.Range(func(key, val interface{}) bool { 248 | sid, ok := key.(uint64) 249 | if !ok { 250 | return true 251 | } 252 | hbTime, ok := val.(int64) 253 | if !ok { 254 | return true 255 | } 256 | 257 | if hbTime < now { 258 | go ws.Destroy(sid) 259 | } 260 | return true 261 | }) 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /websocket/wsrv/srv_test.go: -------------------------------------------------------------------------------- 1 | package wsrv 2 | 3 | import ( 4 | 
"net/http" 5 | "testing" 6 | 7 | "github.com/go-eyas/toolkit/log" 8 | "github.com/go-eyas/toolkit/websocket" 9 | ) 10 | 11 | func TestSrv(t *testing.T) { 12 | server := New(&websocket.Config{ 13 | Logger: log.SugaredLogger, 14 | }) 15 | // server.UseRequest(func(c *Context) { 16 | // log.Debugf("ws request middleware, sid=%d", c.SessionID) 17 | // }) 18 | // server.UseResponse(func(c *Context) { 19 | // log.Debugf("ws response middleware, sid=%d", c.SessionID) 20 | // }) 21 | // server.UseRequest(func(c *Context) { 22 | // // uid, ok := c.Get("uid").(int64) 23 | // // if !ok || uid == 0 { 24 | // // c.Abort() 25 | // // } 26 | // }) 27 | server.Use(func(c *Context) { 28 | log.Debugf("ws request middleware, sid=%d, cmd=%s, data=%s", c.SessionID, c.CMD, string(c.Request.Data)) 29 | c.Next() 30 | log.Debugf("ws response middleware, sid=%d cmd=%s, data=%v", c.SessionID, c.CMD, c.Response.Data) 31 | }) 32 | server.Handle("register") 33 | server.Handle("register", func(c *Context) { 34 | c.Set("uid", int(123)) 35 | for sid, vals := range server.Session { 36 | if uid, ok := vals["uid"]; ok { 37 | server.Push(sid, &WSResponse{ 38 | CMD: "have_user_register", 39 | Data: map[string]interface{}{ 40 | "uid": uid, 41 | }, 42 | }) 43 | } 44 | } 45 | c.OK() 46 | }) 47 | 48 | t.Log("init ws srv ok ") 49 | http.HandleFunc("/ws", server.Engine.HTTPHandler) 50 | http.HandleFunc("/play", server.Engine.Playground) 51 | http.ListenAndServe("127.0.0.1:9000", nil) 52 | 53 | } 54 | --------------------------------------------------------------------------------