├── LICENSE ├── README.md ├── client.go ├── client_hub.go ├── code.go ├── example └── ws.go ├── go.mod ├── go.sum ├── log.go ├── node.go ├── response.go └── server.go /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MQEnergy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-websocket 2 | 基于gorilla/websocket封装的websocket库,实现基于系统维度的消息推送,基于群组维度的消息推送,基于单个和多个客户端消息推送。 3 | 4 | [![GoDoc](https://godoc.org/github.com/MQEnergy/go-websocket/?status.svg)](https://pkg.go.dev/github.com/MQEnergy/go-websocket) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/MQEnergy/go-websocket)](https://goreportcard.com/report/github.com/MQEnergy/go-websocket) 6 | [![codebeat badge](https://codebeat.co/badges/063ec0b6-5059-4b1b-92c0-4f750438faa8)](https://codebeat.co/projects/github-com-mqenergy-go-websocket-main) 7 | [![GitHub license](https://img.shields.io/github/license/MQEnergy/go-websocket)](https://github.com/MQEnergy/go-websocket/blob/main/LICENSE) 8 | 9 | ## 一、目录结构 10 | ``` 11 | ├── LICENSE 12 | ├── README.md 13 | ├── client.go // 客户端 14 | ├── client_hub.go // 客户端集线器 15 | ├── code.go // 状态码 16 | ├── example // 案例 17 | │   └── ws.go 18 | ├── go.mod 19 | ├── go.sum 20 | ├── log.go // 日志 21 | ├── node.go // 节点(用于在分布式系统生成基于节点的客户端连接ID) 22 | ├── response.go // 客户端发送消息 23 | └── server.go // 服务 24 | 25 | ``` 26 | ## 二、在项目中安装使用 27 | ```shell 28 | go get -u github.com/MQEnergy/go-websocket 29 | ``` 30 | ## 三、运行example 31 | ### 1、开启服务 32 | ```shell 33 | go run example/ws.go 34 | ``` 35 | ``` 36 | 服务启动成功。端口号 :9991 37 | ``` 38 | 代表启动成功 39 | 40 | ### 2、案例 41 | 具体查看example目录 42 | 43 | #### 1)连接ws并加群组 44 | system_id为系统ID(不必填 不填默认当前节点ip的int值) 45 | group_id为群组ID(不必填 不填连接不加群组 注意:群组id为全局唯一ID 不然可能会出现不同系统的相同群组都推送消息) 46 | 47 | 请求 48 | ``` 49 | ws://127.0.0.1:9991/ws?system_id=123&group_id=test 50 | ``` 51 | 可选多种返回方式 如: Text,Json,Binary(二进制方式) 52 | 返回如下json示例: 53 | ``` 54 | { 55 | "code": 0, 56 | "msg": "客户端连接成功", 57 | "data": { 58 | "client_id": "1589962851152388096", 59 | "group_id": "test", 60 | "system_id": "123" 61 | }, 62 | "params": {"type": "connected"} 63 | } 64 | ``` 65 | 66 | #### 
2)全局广播消息群发 67 | 请求 68 | ``` 69 | http://127.0.0.1:9991/push_to_system?system_id=123&data={"hello":"world"} 70 | ``` 71 | 返回 72 | ``` 73 | { 74 | "msg": "系统消息发送成功", 75 | } 76 | ``` 77 | 78 | #### 3)单个系统消息群发 79 | 请求 80 | ``` 81 | http://127.0.0.1:9991/push_to_system?system_id=123&data={"hello":"world"} 82 | ``` 83 | 返回 84 | ``` 85 | { 86 | "msg": "系统消息发送成功", 87 | } 88 | ``` 89 | 90 | #### 4)推送消息到群组 91 | 请求 92 | ``` 93 | http://127.0.0.1:9991/push_to_group?system_id=123&group_id=test&data={"hello":"world1"} 94 | ``` 95 | 返回 96 | ``` 97 | { 98 | "msg": "群组消息发送成功", 99 | } 100 | ``` 101 | 102 | #### 5)单个客户端消息发送 103 | 请求 104 | ``` 105 | http://127.0.0.1:9991/push_to_client?client_id=123&data={"hello":"world"} 106 | ``` 107 | 返回 108 | ``` 109 | { 110 | "msg": "客户端消息发送成功", 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | import ( 4 | "github.com/bwmarrin/snowflake" 5 | "github.com/gorilla/websocket" 6 | "strconv" 7 | "time" 8 | ) 9 | 10 | const ( 11 | // Time allowed to write a message to the peer. 12 | writeWait = 10 * time.Second 13 | 14 | // Time allowed to read the next pong message from the peer. 15 | pongWait = 60 * time.Second 16 | 17 | // Send pings to peer with this period. Must be less than pongWait. 18 | pingPeriod = (pongWait * 9) / 10 19 | 20 | // Maximum message size allowed from peer. 
21 | maxMessageSize = 5012 22 | 23 | readBufferSize = 1024 // 读缓冲区大小 24 | writeBufferSize = 1024 // 写缓冲区大小 25 | ) 26 | 27 | var ( 28 | newline = []byte{'\n'} 29 | space = []byte{' '} 30 | ) 31 | 32 | type Client struct { 33 | ClientId string `json:"client_id"` // 客户端连接ID 34 | GroupId string `json:"group_id"` // 群组id 35 | SystemId string `json:"system_id"` // 系统ID 为分布式做准备的 36 | Conn *websocket.Conn 37 | send chan []byte 38 | hub *Hub 39 | } 40 | 41 | // GenerateUuid 生成唯一ID 42 | func GenerateUuid(node *snowflake.Node) string { 43 | if node == nil { 44 | var err error 45 | node, err = snowflake.NewNode(1) 46 | if err != nil { 47 | return "" 48 | } 49 | } 50 | id := node.Generate() 51 | return strconv.FormatInt(id.Int64(), 10) 52 | } 53 | -------------------------------------------------------------------------------- /client_hub.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | ) 7 | 8 | type BroadcastChan struct { 9 | Name string `json:"name"` 10 | Msg []byte `json:"msg"` 11 | } 12 | 13 | type Hub struct { 14 | Clients map[*Client]bool // 全部客户端列表 {*Client1: bool, *Client2: bool...} 15 | SystemClients map[string][]*Client // 全部系统列表 {"systemId1": []*Clients{*Client1, *Client2...}, "systemId2": []*Clients{*Client1, *Client2...}} 16 | GroupClients map[string][]*Client // 全部群组列表 {"groupId": []*Clients{*Client1, *Client2...}} 17 | 18 | ClientRegister chan *Client // 客户端连接处理 19 | ClientUnregister chan *Client // 客户端断开连接处理 20 | ClientLock sync.RWMutex // 客户端列表读写锁 21 | Broadcast chan []byte // 来自广播的入站消息 22 | SystemBroadcast chan *BroadcastChan // 来自系统的入站消息 {Name:"systemId", Msg:"msg"} 23 | GroupBroadcast chan *BroadcastChan // 来自群组的入站消息 {Name:"groupId", Msg:"msg"} 24 | ClientBroadcast chan *BroadcastChan // 来自客户端的入站消息 {Name:"clientId", Msg:"msg"} 25 | } 26 | 27 | // NewHub 实例化 28 | func NewHub() *Hub { 29 | return &Hub{ 30 | Clients: make(map[*Client]bool), 31 | 
GroupClients: make(map[string][]*Client, 1000), 32 | SystemClients: make(map[string][]*Client, 1000), 33 | ClientRegister: make(chan *Client), 34 | ClientUnregister: make(chan *Client), 35 | Broadcast: make(chan []byte), 36 | SystemBroadcast: make(chan *BroadcastChan, 1000), 37 | GroupBroadcast: make(chan *BroadcastChan, 1000), 38 | ClientBroadcast: make(chan *BroadcastChan, 1000), 39 | } 40 | } 41 | 42 | // Run run chan listener 43 | func (m *Hub) Run() { 44 | for { 45 | select { 46 | case client := <-m.ClientRegister: 47 | m.handleClientRegister(client) 48 | 49 | case client := <-m.ClientUnregister: 50 | m.handleClientUnregister(client) 51 | close(client.send) 52 | 53 | // 全局广播 54 | case message := <-m.Broadcast: 55 | m.AllBroadcastHandle(message) 56 | // 系统广播 57 | case systems := <-m.SystemBroadcast: 58 | m.SystemBroadcastHandle(systems.Name, systems.Msg) 59 | // 群组广播 60 | case groups := <-m.GroupBroadcast: 61 | m.GroupBroadcastHandle(groups.Name, groups.Msg) 62 | // 客户端推送 63 | case clients := <-m.ClientBroadcast: 64 | m.ClientBroadcastHandle(clients.Name, clients.Msg) 65 | } 66 | } 67 | } 68 | 69 | // handleClientRegister 客户端连接处理 70 | func (m *Hub) handleClientRegister(client *Client) { 71 | m.ClientLock.Lock() 72 | m.SystemClients[client.SystemId] = append(m.SystemClients[client.SystemId], client) 73 | if client.GroupId != "" { 74 | m.GroupClients[client.GroupId] = append(m.GroupClients[client.GroupId], client) 75 | } 76 | m.Clients[client] = true 77 | m.ClientLock.Unlock() 78 | } 79 | 80 | // handleClientUnregister 客户端断开连接处理 81 | func (m *Hub) handleClientUnregister(client *Client) { 82 | m.ClientLock.Lock() 83 | if _, ok := m.Clients[client]; ok { 84 | delete(m.Clients, client) 85 | } 86 | for index, _client := range m.SystemClients[client.SystemId] { 87 | if _client.ClientId == client.ClientId { 88 | m.SystemClients[client.SystemId] = append(m.SystemClients[client.SystemId][:index], m.SystemClients[client.SystemId][index+1:]...) 
89 | break 90 | } 91 | } 92 | clients, ok := m.GroupClients[client.GroupId] 93 | if ok { 94 | for index, _client := range clients { 95 | if _client.ClientId == client.ClientId { 96 | m.GroupClients[client.GroupId] = append(m.GroupClients[client.GroupId][:index], m.GroupClients[client.GroupId][index+1:]...) 97 | } 98 | } 99 | } 100 | m.ClientLock.Unlock() 101 | } 102 | 103 | // AllBroadcastHandle 全局广播 104 | func (m *Hub) AllBroadcastHandle(msg []byte) { 105 | for client := range m.Clients { 106 | select { 107 | case client.send <- msg: 108 | default: 109 | close(client.send) 110 | m.handleClientUnregister(client) 111 | } 112 | } 113 | } 114 | 115 | // SystemBroadcastHandle 系统广播处理 116 | func (m *Hub) SystemBroadcastHandle(systemId string, msg []byte) { 117 | clients, err := m.GetSystemClients(systemId) 118 | if err != nil { 119 | m.RemoveSystem(systemId) 120 | } 121 | for _, client := range clients { 122 | select { 123 | case client.send <- msg: 124 | default: 125 | close(client.send) 126 | m.handleClientUnregister(client) 127 | } 128 | } 129 | } 130 | 131 | // GroupBroadcastHandle 群组消息通道处理 132 | func (m *Hub) GroupBroadcastHandle(groupId string, msg []byte) { 133 | clients, err := m.GetGroupClients(groupId) 134 | if err != nil { 135 | m.RemoveGroup(groupId) 136 | } 137 | for _, client := range clients { 138 | select { 139 | case client.send <- msg: 140 | default: 141 | close(client.send) 142 | m.handleClientUnregister(client) 143 | } 144 | } 145 | } 146 | 147 | // ClientBroadcastHandle 单客户端通道处理 148 | func (m *Hub) ClientBroadcastHandle(clientId string, msg []byte) { 149 | var _client *Client 150 | for client := range m.Clients { 151 | if client.ClientId == clientId { 152 | _client = client 153 | break 154 | } 155 | } 156 | if _client != nil { 157 | select { 158 | case _client.send <- msg: 159 | break 160 | default: 161 | close(_client.send) 162 | m.handleClientUnregister(_client) 163 | } 164 | } 165 | } 166 | 167 | // SetClientToGroups 添加客户端到分组 168 | func (m *Hub) 
SetClientToGroups(groupId string, client *Client) bool { 169 | clients, ok := m.GroupClients[groupId] 170 | if !ok { 171 | return false 172 | } 173 | for _, _client := range clients { 174 | if _client.ClientId == client.ClientId { 175 | return false 176 | } 177 | } 178 | m.ClientLock.Lock() 179 | m.GroupClients[groupId] = append(m.GroupClients[groupId], client) 180 | m.ClientLock.Unlock() 181 | return true 182 | } 183 | 184 | // GetSystemClients 获取系统的客户端列表 185 | func (m *Hub) GetSystemClients(name string) ([]*Client, error) { 186 | clients, ok := m.SystemClients[name] 187 | if !ok { 188 | return []*Client{}, errors.New("group does not exist") 189 | } 190 | return clients, nil 191 | } 192 | 193 | // GetGroupClients 获取群组的客户端列表 194 | func (m *Hub) GetGroupClients(name string) ([]*Client, error) { 195 | clients, ok := m.GroupClients[name] 196 | if !ok { 197 | return []*Client{}, errors.New("group does not exist") 198 | } 199 | return clients, nil 200 | } 201 | 202 | // RemoveSystem 删除system和系统中的client 203 | func (m *Hub) RemoveSystem(name string) { 204 | delete(m.SystemClients, name) 205 | } 206 | 207 | // RemoveGroup 删除group和群组中的client 208 | func (m *Hub) RemoveGroup(name string) { 209 | delete(m.GroupClients, name) 210 | } 211 | 212 | // RemoveClientByGroup 从群组删除客户端 213 | func (m *Hub) RemoveClientByGroup(client *Client) error { 214 | m.ClientLock.Lock() 215 | clients, ok := m.GroupClients[client.GroupId] 216 | if !ok { 217 | return errors.New("group does not exist") 218 | } 219 | for index, _client := range clients { 220 | if _client.ClientId == client.ClientId { 221 | m.GroupClients[client.GroupId] = append(m.GroupClients[client.GroupId][:index], m.GroupClients[client.GroupId][index+1:]...) 
222 | } 223 | } 224 | m.ClientLock.Unlock() 225 | return nil 226 | } 227 | -------------------------------------------------------------------------------- /code.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | type Code int 4 | 5 | const ( 6 | Success Code = 0 7 | Failed Code = 10001 + iota 8 | ClientFailed 9 | ClientNotExist 10 | ClientCloseSuccess 11 | ClientCloseFailed 12 | ReadMsgErr 13 | ReadMsgSuccess 14 | SendMsgErr 15 | SendMsgSuccess 16 | HeartbeatErr 17 | SystemErr 18 | BindGroupSuccess 19 | BindGroupErr 20 | UnAuthed 21 | InternalErr 22 | RequestMethodErr 23 | RequestParamErr 24 | ) 25 | 26 | var CodeMap = map[Code]string{ 27 | Success: "客户端连接成功", 28 | Failed: "客户端连接失败", 29 | ClientFailed: "客户端主动断连", 30 | ClientNotExist: "客户端不存在", 31 | ClientCloseSuccess: "客户端关闭成功", 32 | ClientCloseFailed: "客户端关闭失败", 33 | ReadMsgErr: "读取消息体失败", 34 | ReadMsgSuccess: "读取消息体成功", 35 | SendMsgErr: "发送消息体失败", 36 | SendMsgSuccess: "发送消息体成功", 37 | HeartbeatErr: "心跳检测失败", 38 | SystemErr: "系统不能为空", 39 | BindGroupSuccess: "绑定群组成功", 40 | BindGroupErr: "绑定群组失败", 41 | UnAuthed: "用户未认证", 42 | InternalErr: "服务器内部错误", 43 | RequestMethodErr: "请求方式错误", 44 | RequestParamErr: "请求参数错误", 45 | } 46 | 47 | // Msg 返回错误码对应的说明 48 | func (c Code) Msg() string { 49 | if v, ok := CodeMap[c]; ok { 50 | return v 51 | } 52 | return `` 53 | } 54 | -------------------------------------------------------------------------------- /example/ws.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/MQEnergy/go-websocket" 5 | "github.com/bwmarrin/snowflake" 6 | "github.com/sirupsen/logrus" 7 | "log" 8 | "net/http" 9 | ) 10 | 11 | var ( 12 | Node *snowflake.Node 13 | ) 14 | 15 | func init() { 16 | // 日志注入 17 | go_websocket.Logger = logrus.New() 18 | go_websocket.Logger.SetFormatter(&logrus.JSONFormatter{ 19 | TimestampFormat: "2006-01-02 15:04:05", 20 | }) 21 | 
localIp, err := go_websocket.GetLocalIpToInt() 22 | if err != nil { 23 | panic(err) 24 | } 25 | Node, err = snowflake.NewNode(int64(localIp) % 1023) 26 | if err != nil { 27 | panic(err) 28 | } 29 | } 30 | 31 | func main() { 32 | hub := go_websocket.NewHub() 33 | go hub.Run() 34 | 35 | // ws连接 36 | http.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) { 37 | _, err := go_websocket.WsServer(hub, writer, request, go_websocket.Json) 38 | if err != nil { 39 | return 40 | } 41 | }) 42 | 43 | // 推送到所有连接的客户端 44 | http.HandleFunc("/push_to_all", func(writer http.ResponseWriter, request *http.Request) { 45 | data := request.FormValue("data") 46 | writer.Header().Set("Content-Type", "application/json; charset=utf-8") 47 | if data == "" { 48 | writer.Write([]byte("{\"msg\":\"参数错误\"}")) 49 | return 50 | } 51 | hub.Broadcast <- []byte(data) 52 | writer.Write([]byte("{\"msg\":\"全局消息发送成功\"}")) 53 | return 54 | }) 55 | 56 | // 推送到所在系统的客户端 57 | http.HandleFunc("/push_to_system", func(writer http.ResponseWriter, request *http.Request) { 58 | systemId := request.FormValue("system_id") 59 | data := request.FormValue("data") 60 | writer.Header().Set("Content-Type", "application/json; charset=utf-8") 61 | if systemId == "" || data == "" { 62 | writer.Write([]byte("{\"msg\":\"参数错误\"}")) 63 | return 64 | } 65 | hub.SystemBroadcast <- &go_websocket.BroadcastChan{ 66 | Name: systemId, 67 | Msg: []byte(data), 68 | } 69 | writer.Write([]byte("{\"msg\":\"系统消息发送成功\"}")) 70 | return 71 | }) 72 | 73 | // 推送到群组 74 | http.HandleFunc("/push_to_group", func(writer http.ResponseWriter, request *http.Request) { 75 | groupId := request.FormValue("group_id") 76 | data := request.FormValue("data") 77 | writer.Header().Set("Content-Type", "application/json; charset=utf-8") 78 | if groupId == "" || data == "" { 79 | writer.Write([]byte("{\"msg\":\"参数错误\"}")) 80 | return 81 | } 82 | hub.GroupBroadcast <- &go_websocket.BroadcastChan{ 83 | Name: groupId, 84 | Msg: []byte(data), 85 | } 
86 | //hub.GroupBroadcastHandle(groupId, []byte(data)) 87 | writer.Write([]byte("{\"msg\":\"群组消息发送成功\"}")) 88 | return 89 | }) 90 | 91 | // 推送到单个客户端 92 | http.HandleFunc("/push_to_client", func(writer http.ResponseWriter, request *http.Request) { 93 | clientId := request.FormValue("client_id") 94 | data := request.FormValue("data") 95 | writer.Header().Set("Content-Type", "application/json; charset=utf-8") 96 | if clientId == "" || data == "" { 97 | writer.Write([]byte("{\"msg\":\"参数错误\"}")) 98 | return 99 | } 100 | hub.ClientBroadcast <- &go_websocket.BroadcastChan{ 101 | Name: clientId, 102 | Msg: []byte(data), 103 | } 104 | writer.Write([]byte("{\"msg\":\"客户端消息发送成功\"}")) 105 | return 106 | }) 107 | 108 | log.Println("服务启动成功。端口号 :9991") 109 | if err := http.ListenAndServe(":9991", nil); err != nil { 110 | log.Println("ListenAndServe: ", err) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/MQEnergy/go-websocket 2 | 3 | go 1.18 4 | 5 | require github.com/gorilla/websocket v1.5.0 6 | 7 | require ( 8 | github.com/bwmarrin/snowflake v0.3.0 9 | github.com/sirupsen/logrus v1.9.0 10 | ) 11 | 12 | require golang.org/x/sys v0.1.0 // indirect 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0= 2 | github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/gorilla/websocket v1.5.0 
h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 7 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 8 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 9 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 10 | github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= 11 | github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 12 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 13 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 14 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 15 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 16 | golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= 17 | golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 19 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 20 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 21 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | import ( 4 | "github.com/sirupsen/logrus" 5 | ) 6 | 7 | var ( 8 | Logger *logrus.Logger 9 | ) 10 | 11 | // TraceLog 写日志 12 | func TraceLog(code Code, params, data, err interface{}, level logrus.Level) { 13 | Logger.WithFields(logrus.Fields{ 14 | "code": code, 15 | "err": err, 16 | "params": params, 17 | "data": data, 18 | }).Log(level, code.Msg()) 19 | } 20 | 21 | // TraceHeartbeatErrdLog 心跳检测失败消息 22 
| func TraceHeartbeatErrdLog(params, data, err interface{}, level logrus.Level) { 23 | TraceLog(HeartbeatErr, params, data, err, level) 24 | } 25 | 26 | // TraceClientCloseFailedLog 客户端关闭失败消息 27 | func TraceClientCloseFailedLog(params, data, err interface{}, level logrus.Level) { 28 | TraceLog(ClientCloseFailed, params, data, err, level) 29 | } 30 | 31 | // TraceClientCloseSuccessLog 客户端关闭成功消息 32 | func TraceClientCloseSuccessLog(params, data, err interface{}, level logrus.Level) { 33 | TraceLog(ClientCloseSuccess, params, data, err, level) 34 | } 35 | 36 | // TraceSuccessLog 客户端连接成功消息 37 | func TraceSuccessLog(params, data interface{}, level logrus.Level) { 38 | TraceLog(Success, params, data, nil, level) 39 | } 40 | 41 | // TraceReadMsgSuccessLog 读取消息体成功消息 42 | func TraceReadMsgSuccessLog(params, data interface{}, level logrus.Level) { 43 | TraceLog(ReadMsgSuccess, params, data, nil, level) 44 | } 45 | 46 | // TraceSendMsgErrLog 发送消息体失败 47 | func TraceSendMsgErrLog(params, data, err interface{}, level logrus.Level) { 48 | TraceLog(SendMsgErr, params, data, err, level) 49 | } 50 | -------------------------------------------------------------------------------- /node.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "net" 7 | ) 8 | 9 | // convertToIntIP 转换ip为int 10 | func convertToIntIP(ip net.IP) uint32 { 11 | if len(ip) == 16 { 12 | return binary.BigEndian.Uint32(ip[12:16]) 13 | } 14 | return binary.BigEndian.Uint32(ip) 15 | } 16 | 17 | // GetLocalIpToInt 获取本机IP转成int 18 | func GetLocalIpToInt() (uint32, error) { 19 | addrs, err := net.InterfaceAddrs() 20 | if err != nil { 21 | return 0, err 22 | } 23 | for _, address := range addrs { 24 | // 检查ip地址判断是否回环地址 25 | if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { 26 | if ipnet.IP.To4() != nil { 27 | return convertToIntIP(ipnet.IP), nil 28 | } 29 | } 30 | } 31 | return 0, 
errors.New("can not find the client ip address") 32 | } 33 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/gorilla/websocket" 6 | ) 7 | 8 | // responseData 响应结构体 9 | type responseData struct { 10 | Code Code `json:"code"` 11 | Msg string `json:"msg"` 12 | Data interface{} `json:"data"` 13 | Params interface{} `json:"params"` // 自定义参数 14 | } 15 | type MsgType int 16 | 17 | const ( 18 | Text MsgType = 1 19 | Json = iota + Text 20 | Binary 21 | ) 22 | 23 | // WriteMessage 返回给客户端的信息 24 | func WriteMessage(conn *websocket.Conn, code Code, message string, data, params interface{}, msgType MsgType) error { 25 | r := responseData{ 26 | Code: code, 27 | Msg: message, 28 | Params: params, 29 | Data: data, 30 | } 31 | switch msgType { 32 | case Text: 33 | marshal, _ := json.Marshal(r) 34 | return conn.WriteMessage(1, marshal) 35 | case Binary: 36 | marshal, _ := json.Marshal(r) 37 | return conn.WriteMessage(2, marshal) 38 | case Json: 39 | return conn.WriteJSON(r) 40 | } 41 | return nil 42 | } 43 | 44 | // WriteJson 返回给客户端的信息 45 | func WriteJson(conn *websocket.Conn, code Code, message string, data, params interface{}) error { 46 | return WriteMessage(conn, code, message, data, params, Json) 47 | } 48 | 49 | // WriteSuccessJson 返回客户端连接成功 50 | func WriteSuccessJson(conn *websocket.Conn, data, params interface{}) error { 51 | return WriteJson(conn, Success, Success.Msg(), data, params) 52 | } 53 | 54 | // WriteFailedJson 返回客户端连接失败 55 | func WriteFailedJson(conn *websocket.Conn, data, params interface{}) error { 56 | return WriteJson(conn, Failed, Failed.Msg(), data, params) 57 | } 58 | 59 | // WriteClientFailedJson 返回客户端主动断连 60 | func WriteClientFailedJson(conn *websocket.Conn, data, params interface{}) error { 61 | return WriteJson(conn, ClientFailed, ClientFailed.Msg(), 
data, params) 62 | } 63 | 64 | // WriteClientNotExistJson 返回客户端不存在 65 | func WriteClientNotExistJson(conn *websocket.Conn, data, params interface{}) error { 66 | return WriteJson(conn, ClientNotExist, ClientNotExist.Msg(), data, params) 67 | } 68 | 69 | // WriteClientCloseSuccessJson 返回客户端关闭成功 70 | func WriteClientCloseSuccessJson(conn *websocket.Conn, data, params interface{}) error { 71 | return WriteJson(conn, ClientCloseSuccess, ClientCloseSuccess.Msg(), data, params) 72 | } 73 | 74 | // WriteClientCloseFailedJson 返回客户端关闭失败 75 | func WriteClientCloseFailedJson(conn *websocket.Conn, data, params interface{}) error { 76 | return WriteJson(conn, ClientCloseFailed, ClientCloseFailed.Msg(), data, params) 77 | } 78 | 79 | // WriteReadMsgErrJson 返回读取消息体失败 80 | func WriteReadMsgErrJson(conn *websocket.Conn, data, params interface{}) error { 81 | return WriteJson(conn, ReadMsgErr, ReadMsgErr.Msg(), data, params) 82 | } 83 | 84 | // WriteReadMsgSuccessJson 返回读取消息体成功 85 | func WriteReadMsgSuccessJson(conn *websocket.Conn, data, params interface{}) error { 86 | return WriteJson(conn, ReadMsgSuccess, ReadMsgSuccess.Msg(), data, params) 87 | } 88 | 89 | // WriteSendMsgErrJson 返回发送消息体失败 90 | func WriteSendMsgErrJson(conn *websocket.Conn, data, params interface{}) error { 91 | return WriteJson(conn, SendMsgErr, SendMsgErr.Msg(), data, params) 92 | } 93 | 94 | // WriteSendMsgSuccessJson 返回发送消息体成功 95 | func WriteSendMsgSuccessJson(conn *websocket.Conn, data, params interface{}) error { 96 | return WriteJson(conn, SendMsgSuccess, SendMsgSuccess.Msg(), data, params) 97 | } 98 | 99 | // WriteHeartbeatErrJson 返回心跳检测失败 100 | func WriteHeartbeatErrJson(conn *websocket.Conn, data, params interface{}) error { 101 | return WriteJson(conn, HeartbeatErr, HeartbeatErr.Msg(), data, params) 102 | } 103 | 104 | // WriteBindGroupSuccessJson 返回绑定群组成功 105 | func WriteBindGroupSuccessJson(conn *websocket.Conn, data, params interface{}) error { 106 | return WriteJson(conn, BindGroupSuccess, 
BindGroupSuccess.Msg(), data, params) 107 | } 108 | 109 | // WriteRequestParamErrJson 返回请求参数错误 110 | func WriteRequestParamErrJson(conn *websocket.Conn, data, params interface{}) error { 111 | return WriteJson(conn, RequestParamErr, RequestParamErr.Msg(), data, params) 112 | } 113 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package go_websocket 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "github.com/bwmarrin/snowflake" 7 | "github.com/gorilla/websocket" 8 | "net/http" 9 | "strconv" 10 | "time" 11 | ) 12 | 13 | var ( 14 | Node *snowflake.Node 15 | upgrader = websocket.Upgrader{ 16 | ReadBufferSize: readBufferSize, 17 | WriteBufferSize: writeBufferSize, 18 | // 解决跨域问题 19 | CheckOrigin: func(r *http.Request) bool { 20 | return true 21 | }, 22 | } 23 | ) 24 | 25 | func init() { 26 | localIp, err := GetLocalIpToInt() 27 | if err != nil { 28 | panic(err) 29 | } 30 | Node, err = snowflake.NewNode(int64(localIp) % 1023) 31 | if err != nil { 32 | panic(err) 33 | } 34 | } 35 | 36 | // ReadMessageHandler 将来自 websocket 连接的消息推送到集线器。 37 | func (c *Client) ReadMessageHandler() { 38 | if c.Conn != nil { 39 | defer func() { 40 | c.hub.ClientUnregister <- c 41 | c.Conn.Close() 42 | }() 43 | 44 | c.Conn.SetReadLimit(maxMessageSize) 45 | c.Conn.SetReadDeadline(time.Now().Add(pongWait)) 46 | c.Conn.SetPongHandler(func(appData string) error { 47 | c.Conn.SetReadDeadline(time.Now().Add(pongWait)) 48 | return nil 49 | }) 50 | for { 51 | _, message, err := c.Conn.ReadMessage() 52 | if err != nil { 53 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 54 | TraceClientCloseSuccessLog("", "", err.Error(), 4) 55 | } 56 | break 57 | } 58 | message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) 59 | c.hub.ClientBroadcast <- &BroadcastChan{Name: c.ClientId, Msg: message} 60 | } 61 | } 62 
| } 63 | 64 | // WriteMessageHandler 将消息从集线器发送到 websocket 连接 65 | func (c *Client) WriteMessageHandler(msgtype MsgType) { 66 | if c.Conn != nil { 67 | ticker := time.NewTicker(pingPeriod) 68 | defer func() { 69 | ticker.Stop() 70 | if c.Conn != nil { 71 | c.Conn.Close() 72 | } 73 | }() 74 | 75 | for { 76 | select { 77 | case message, ok := <-c.send: 78 | c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) 79 | if !ok { 80 | c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) 81 | return 82 | } 83 | data := make(map[string]interface{}, 0) 84 | if err := json.Unmarshal(message, &data); err != nil { 85 | return 86 | } 87 | c.Conn.SetWriteDeadline(time.Time{}) 88 | WriteMessage(c.Conn, SendMsgSuccess, SendMsgSuccess.Msg(), data, data, msgtype) 89 | 90 | case <-ticker.C: 91 | c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) 92 | if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { 93 | return 94 | } 95 | } 96 | } 97 | } 98 | } 99 | 100 | // WsServer 处理websocket请求 101 | func WsServer(hub *Hub, w http.ResponseWriter, r *http.Request, msgtype MsgType) (*Client, error) { 102 | conn, err := upgrader.Upgrade(w, r, nil) 103 | if err != nil { 104 | return nil, err 105 | } 106 | systemId := r.FormValue("system_id") 107 | groupId := r.FormValue("group_id") 108 | if systemId == "" { 109 | sid, err := GetLocalIpToInt() 110 | if err != nil { 111 | return nil, err 112 | } 113 | systemId = strconv.Itoa(int(sid)) 114 | } 115 | client := &Client{ 116 | SystemId: systemId, 117 | GroupId: groupId, 118 | ClientId: GenerateUuid(Node), 119 | hub: hub, 120 | Conn: conn, 121 | send: make(chan []byte, 256), 122 | } 123 | client.hub.ClientRegister <- client 124 | 125 | // 连接成功返回消息 126 | data := map[string]string{"system_id": systemId, "client_id": client.ClientId, "group_id": groupId} 127 | params := map[string]interface{}{"type": "connected"} 128 | if err := WriteMessage(conn, Success, Success.Msg(), data, params, msgtype); err != nil { 129 | return nil, err 130 | } 131 | 
132 | // 监听客户端发送的消息 133 | go client.WriteMessageHandler(msgtype) 134 | go client.ReadMessageHandler() 135 | 136 | return client, nil 137 | } 138 | --------------------------------------------------------------------------------