├── .gitignore ├── Dockerfile_gateway ├── Dockerfile_router ├── Dockerfile_royal ├── Dockerfile_server ├── LICENSE ├── README.md ├── build_images.sh ├── channel.go ├── channels.go ├── channels_test.go ├── container ├── clients.go ├── container.go ├── hash_selector.go ├── metrics.go └── selector.go ├── context.go ├── default_server.go ├── dispatcher.go ├── dispatcher_mock.go ├── docker-compose-kim.yml ├── docker-compose.yml ├── event.go ├── examples ├── benchmark │ └── server_test.go ├── dialer │ ├── client_dialer.go │ └── login.go ├── echo │ └── echo.go ├── kimbench │ ├── cmd.go │ ├── grouptalk.go │ ├── login.go │ └── usertalk.go ├── main.go ├── mock │ ├── client.go │ ├── cmd.go │ └── server.go └── unittest │ ├── chat_test.go │ ├── login_test.go │ └── offline_test.go ├── go.mod ├── go.sum ├── location.go ├── logger ├── logger.go └── setting.go ├── metrics.go ├── middleware └── recover.go ├── mock.sh ├── naming ├── consul │ ├── naming.go │ └── naming_test.go ├── naming.go └── service.go ├── net.go ├── report ├── report.go ├── report_test.go └── template.go ├── router.go ├── server.go ├── server_mock.go ├── services ├── gateway │ ├── conf.yaml │ ├── conf │ │ ├── config.go │ │ └── route.go │ ├── conf2.yaml │ ├── route.json │ ├── serv │ │ ├── dialer.go │ │ ├── handler.go │ │ ├── metrics.go │ │ ├── selector.go │ │ └── selector_test.go │ └── server.go ├── main.go ├── router │ ├── .DS_Store │ ├── api_test.go │ ├── apis │ │ ├── router.go │ │ └── router_test.go │ ├── conf │ │ ├── config.go │ │ └── router.go │ ├── data │ │ ├── ip2region.db │ │ ├── mapping.json │ │ └── regions.json │ ├── ipregion │ │ ├── ipregion.go │ │ └── ipregion_test.go │ └── server.go ├── server │ ├── conf.yaml │ ├── conf │ │ └── config.go │ ├── handler │ │ ├── chat_handler.go │ │ ├── group_handler.go │ │ ├── login_handler.go │ │ ├── login_handler_test.go │ │ ├── offline_handler.go │ │ └── offline_handler_test.go │ ├── serv │ │ └── handler.go │ ├── server.go │ └── service │ │ ├── group.go │ │ ├── 
group_test.go │ │ ├── message.go │ │ └── message_test.go └── service │ ├── conf.yaml │ ├── conf │ └── config.go │ ├── database │ ├── id_generator.go │ ├── model.go │ ├── mysql.go │ ├── mysql_test.go │ └── redis.go │ ├── handler │ ├── group_handler.go │ ├── message_handler.go │ └── message_handler_test.go │ └── server.go ├── storage.go ├── storage ├── redis_impl.go └── redis_test.go ├── storage_mock.go ├── tcp ├── client.go ├── connection.go └── server.go ├── websocket ├── client.go ├── connection.go └── server.go └── wire ├── build.sh ├── definitions.go ├── endian ├── big_endian_test.go └── helper.go ├── grpc_helper.go ├── pkt ├── basic_pkt.go ├── common.pb.go ├── packet.go ├── packet_test.go ├── protocol.pb.go ├── read_write.go └── read_write_test.go ├── proto ├── common.proto ├── protocol.proto └── rpc.proto ├── rpc └── rpc.pb.go ├── seq.go └── token ├── jwt.go └── jwt_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | services/data/ 2 | .vscode/ 3 | .idea -------------------------------------------------------------------------------- /Dockerfile_gateway: -------------------------------------------------------------------------------- 1 | FROM golang:alpine AS builder 2 | 3 | # 为我们的镜像设置必要的环境变量 4 | ENV GO111MODULE=on \ 5 | GOPROXY=https://goproxy.cn \ 6 | CGO_ENABLED=0 \ 7 | GOOS=linux \ 8 | GOARCH=amd64 9 | 10 | WORKDIR /build 11 | 12 | COPY . . 
13 | 14 | RUN go build -o app ./services 15 | 16 | FROM scratch 17 | 18 | # 从builder镜像中把/build/app 拷贝到当前目录 19 | COPY --from=builder /build/app / 20 | COPY --from=builder /build/services/gateway/route.json /route.json 21 | 22 | EXPOSE 8000 23 | 24 | # 需要运行的命令 25 | ENTRYPOINT ["/app", "gateway", "-r", "route.json"] -------------------------------------------------------------------------------- /Dockerfile_router: -------------------------------------------------------------------------------- 1 | FROM golang:alpine AS builder 2 | 3 | # 为我们的镜像设置必要的环境变量 4 | ENV GO111MODULE=on \ 5 | GOPROXY=https://goproxy.cn \ 6 | CGO_ENABLED=0 \ 7 | GOOS=linux \ 8 | GOARCH=amd64 9 | 10 | WORKDIR /build 11 | 12 | COPY . . 13 | 14 | RUN go build -o app ./services 15 | 16 | FROM scratch 17 | 18 | # 从builder镜像中把/build/app 拷贝到当前目录 19 | COPY --from=builder /build/app / 20 | COPY --from=builder /build/services/router/data /data 21 | 22 | EXPOSE 8000 23 | 24 | # 需要运行的命令 25 | ENTRYPOINT ["/app", "router", "-d", "/data"] -------------------------------------------------------------------------------- /Dockerfile_royal: -------------------------------------------------------------------------------- 1 | FROM golang AS builder 2 | 3 | # 为我们的镜像设置必要的环境变量 4 | ENV GO111MODULE=on \ 5 | GOPROXY=https://goproxy.cn \ 6 | CGO_ENABLED=0 \ 7 | GOOS=linux \ 8 | GOARCH=amd64 9 | 10 | WORKDIR /build 11 | 12 | COPY . . 
13 | 14 | # RUN go-wrapper install -ldflags "-linkmode external -extldflags -static" 15 | # RUN go build -a -ldflags '-linkmode external -extldflags "-static"' -o app ./services 16 | 17 | RUN go build -o app ./services 18 | 19 | FROM scratch 20 | 21 | # 从builder镜像中把/build/app 拷贝到当前目录 22 | COPY --from=builder /build/app / 23 | 24 | EXPOSE 8080 25 | 26 | # 需要运行的命令 27 | ENTRYPOINT ["/app", "royal"] -------------------------------------------------------------------------------- /Dockerfile_server: -------------------------------------------------------------------------------- 1 | FROM golang:alpine AS builder 2 | 3 | # 为我们的镜像设置必要的环境变量 4 | ENV GO111MODULE=on \ 5 | GOPROXY=https://goproxy.cn \ 6 | CGO_ENABLED=0 \ 7 | GOOS=linux \ 8 | GOARCH=amd64 9 | 10 | WORKDIR /build 11 | 12 | COPY . . 13 | 14 | RUN go build -o app ./services 15 | 16 | FROM scratch 17 | 18 | # 从builder镜像中把/build/app 拷贝到当前目录 19 | COPY --from=builder /build/app / 20 | 21 | EXPOSE 8005 22 | 23 | # 需要运行的命令 24 | ENTRYPOINT ["/app", "server"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KIM 2 | 3 | King IM Cloud 4 | 5 | ## 简介 6 | 7 | **kim 是一个高性能分布式即时通信系统。** 8 | 9 | ![structure.png](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/2633b07fd1a144d685ceed9be5f64911~tplv-k3u1fbpfcp-watermark.image) 10 | 11 | - Web SDK: [Typescript SDK](https://github.com/klintcheng/kim_web_sdk) 12 | - Flutter SDK: [Flutter SDK](https://github.com/szhua/KimSdk) 13 | - 由[@szhua](https://github.com/szhua)小友提供 14 | 15 | ## 环境准备 16 | 17 | ### 中间件安装 18 | 19 | Kim依赖mysql、Consul和Redis。因此,在本地测试时需要准备相应环境。这里提供两种方式: 20 | 21 | 方式一: 通过docker-compose启动 22 | 23 | > docker-compose -f "docker-compose.yml" up -d --build 24 | 25 | 方式二: docker分别启动 26 | 27 | ```cmd 28 | docker run -itd --name kim_mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=123456 mysql 29 | 30 | docker run \ 31 | -d \ 32 | -p 8500:8500 \ 33 | -p
8600:8600/udp \ 34 | --name=kim_consul \ 35 | consul agent -server -ui -node=server-1 -bootstrap-expect=1 -client=0.0.0.0 36 | 37 | docker run -itd --name kim_redis -p 6379:6379 redis 38 | ``` 39 | 40 | ### 数据准备 41 | 42 | 1. 进入Mysql,修改访问权限: 43 | 1. docker exec -it kim_mysql /bin/sh 44 | 2. mysql -uroot -p123456 45 | 3. GRANT ALL ON *.* TO 'root'@'%'; 46 | 4. flush privileges; 47 | 2. 创建数据库 48 | 1. create database kim_base default character set utf8mb4 collate utf8mb4_unicode_ci; 49 | 2. create database kim_message default character set utf8mb4 collate utf8mb4_unicode_ci; 50 | 51 | ## 启动服务 52 | 53 | 首先进入services中,分别启动三个服务: 54 | 55 | ``` 56 | go run main.go gateway 57 | go run main.go server 58 | go run main.go royal 59 | ``` 60 | 61 | 或者,通过docker-compose启动: 62 | 63 | ``` 64 | docker-compose -f "docker-compose-kim.yml" up -d --build 65 | ``` 66 | 67 | 访问Consul,可以查看服务启动状态: 68 | 69 | > http://localhost:8500/ui 70 | -------------------------------------------------------------------------------- /build_images.sh: -------------------------------------------------------------------------------- 1 | 2 | docker build --pull --rm -f "Dockerfile_royal" -t dockerklint/kim_royal:v1.4 "." 3 | 4 | docker build --pull --rm -f "Dockerfile_gateway" -t dockerklint/kim_gateway:v1.4 "." 5 | 6 | docker build --pull --rm -f "Dockerfile_server" -t dockerklint/kim_server:v1.4 "." 7 | 8 | docker build --pull --rm -f "Dockerfile_router" -t dockerklint/kim_router:v1.1 "." 
9 | 
10 | docker push dockerklint/kim_royal:v1.4
11 | docker push dockerklint/kim_gateway:v1.4
12 | docker push dockerklint/kim_server:v1.4
13 | docker push dockerklint/kim_router:v1.1
--------------------------------------------------------------------------------
/channels.go:
--------------------------------------------------------------------------------
1 | package kim
2 | 
3 | import (
4 | 	"sync"
5 | 
6 | 	"github.com/klintcheng/kim/logger"
7 | )
8 | 
9 | // ChannelMap is a registry of Channels keyed by Channel.ID().
10 | type ChannelMap interface {
11 | 	Add(channel Channel)
12 | 	Remove(id string)
13 | 	Get(id string) (channel Channel, ok bool)
14 | 	All() []Channel
15 | }
16 | 
17 | // ChannelsImpl implements ChannelMap on top of a sync.Map, so it may be used from multiple goroutines.
18 | type ChannelsImpl struct {
19 | 	channels *sync.Map
20 | }
21 | 
22 | // NewChannels returns an empty ChannelMap; num is currently unused (a sync.Map cannot be pre-sized).
23 | func NewChannels(num int) ChannelMap {
24 | 	return &ChannelsImpl{
25 | 		channels: new(sync.Map),
26 | 	}
27 | }
28 | 
29 | // Add stores channel under channel.ID(); a channel with an empty ID is logged and not stored.
30 | func (ch *ChannelsImpl) Add(channel Channel) {
31 | 	if channel.ID() == "" {
32 | 		logger.WithFields(logger.Fields{
33 | 			"module": "ChannelsImpl",
34 | 		}).Error("channel id is required")
35 | 		return
36 | 	}
37 | 
38 | 	ch.channels.Store(channel.ID(), channel)
39 | }
40 | 
41 | // Remove deletes the channel registered under id; removing an unknown id is a no-op.
42 | func (ch *ChannelsImpl) Remove(id string) {
43 | 	ch.channels.Delete(id)
44 | }
45 | 
46 | // Get returns the channel registered under id; an empty id is logged and reported as not found.
47 | func (ch *ChannelsImpl) Get(id string) (Channel, bool) {
48 | 	if id == "" {
49 | 		logger.WithFields(logger.Fields{
50 | 			"module": "ChannelsImpl",
51 | 		}).Error("channel id is required")
52 | 		return nil, false
53 | 	}
54 | 
55 | 	val, ok := ch.channels.Load(id)
56 | 	if !ok {
57 | 		return nil, false
58 | 	}
59 | 	return val.(Channel), true
60 | }
61 | 
62 | // All returns a new slice holding every registered channel, in no particular order.
63 | func (ch *ChannelsImpl) All() []Channel {
64 | 	arr := make([]Channel, 0)
65 | 	ch.channels.Range(func(key, val interface{}) bool {
66 | 		arr = append(arr, val.(Channel))
67 | 		return true
68 | 	})
69 | 	return arr
70 | }
71 | 
--------------------------------------------------------------------------------
/channels_test.go:
--------------------------------------------------------------------------------
1 | package kim
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/golang/mock/gomock"
7 | 	"github.com/segmentio/ksuid"
8 | )
9 | 
10 | func Benchmark_ChannelsAdd(b *testing.B) {
11 | 	ctrl := gomock.NewController(b)
12 | 
13 | 	chs := NewChannels(10)
14 | 
15 | 	b.ReportAllocs()
16 | 	b.ResetTimer()
17 | 	b.RunParallel(func(p *testing.PB) {
18 | 		for p.Next() {
19 | 			ch := NewMockChannel(ctrl)
20 | 			id := ksuid.New().String()
21 | 			ch.EXPECT().ID().AnyTimes().Return(id)
22 | 			chs.Add(ch)
23 | 			chs.Get(id)
24 | 		}
25 | 	})
26 | }
27 | 
--------------------------------------------------------------------------------
/container/clients.go:
--------------------------------------------------------------------------------
1 | package container
2 | 
3 | import (
4 | 	"sync"
5 | 
6 | 	"github.com/klintcheng/kim"
7 | 	"github.com/klintcheng/kim/logger"
8 | )
9 | 
10 | // ClientMap is a registry of kim.Client values keyed by their ServiceID.
11 | type ClientMap interface {
12 | 	Add(client kim.Client)
13 | 	Remove(id string)
14 | 	Get(id string) (client kim.Client, ok bool)
15 | 	// Find(name string) (client []kim.Client)
16 | 	Services(kvs ...string) []kim.Service
17 | }
18 | 
19 | // ClientsImpl implements ClientMap on top of a sync.Map.
20 | type ClientsImpl struct {
21 | 	clients *sync.Map
22 | }
23 | 
24 | // NewClients returns an empty ClientMap; num is currently unused.
25 | func NewClients(num int) ClientMap {
26 | 	return &ClientsImpl{
27 | 		clients: new(sync.Map),
28 | 	}
29 | }
30 | 
31 | // Add stores client under its ServiceID. A client with an empty ServiceID is
32 | // logged and not stored, rather than being registered under an empty key;
33 | // this matches how kim.ChannelsImpl.Add treats an empty channel ID.
34 | func (ch *ClientsImpl) Add(client kim.Client) {
35 | 	id := client.ServiceID()
36 | 	if id == "" {
37 | 		logger.WithFields(logger.Fields{
38 | 			"module": "ClientsImpl",
39 | 		}).Error("client id is required")
40 | 		return
41 | 	}
42 | 	ch.clients.Store(id, client)
43 | }
44 | 
45 | // Remove deletes the client registered under id; an unknown id is a no-op.
46 | func (ch *ClientsImpl) Remove(id string) {
47 | 	ch.clients.Delete(id)
48 | }
49 | 
50 | // Get returns the client registered under id. An empty id is logged and
51 | // reported as not found without consulting the map.
52 | func (ch *ClientsImpl) Get(id string) (kim.Client, bool) {
53 | 	if id == "" {
54 | 		logger.WithFields(logger.Fields{
55 | 			"module": "ClientsImpl",
56 | 		}).Error("client id is required")
57 | 		return nil, false
58 | 	}
59 | 
60 | 	val, ok := ch.clients.Load(id)
61 | 	if !ok {
62 | 		return nil, false
63 | 	}
64 | 	return val.(kim.Client), true
65 | }
66 | 
67 | // Services returns the registered clients as kim.Service values, in no
68 | // particular order. With no arguments every client is returned; with exactly
69 | // one key/value pair only clients whose GetMeta()[key] equals value are
70 | // returned; any other number of arguments returns nil.
71 | func (ch *ClientsImpl) Services(kvs ...string) []kim.Service {
72 | 	kvLen := len(kvs)
73 | 	if kvLen != 0 && kvLen != 2 {
74 | 		return nil
75 | 	}
76 | 	arr := make([]kim.Service, 0)
77 | 	ch.clients.Range(func(key, val interface{}) bool {
78 | 		ser := val.(kim.Service)
79 | 		if kvLen > 0 && ser.GetMeta()[kvs[0]] != kvs[1] {
80 | 			return true
81 | 		}
82 | 		arr = append(arr, ser)
83 | 		return true
84 | 	})
85 | 	return arr
86 | }
87 | 
--------------------------------------------------------------------------------
/container/hash_selector.go:
--------------------------------------------------------------------------------
1 | package container
2 | 
3 | import (
4 | 	"github.com/klintcheng/kim"
5 | 	"github.com/klintcheng/kim/wire/pkt"
6 | )
7 | 
8 | // HashSelector selects a service by hashing the packet header's ChannelId, so a
9 | // given channel keeps mapping to the same service while srvs is unchanged.
10 | type HashSelector struct {
11 | }
12 | 
13 | // Lookup returns the ServiceID of the service chosen for header.ChannelId, or
14 | // "" when srvs is empty.
15 | func (s *HashSelector) Lookup(header *pkt.Header, srvs []kim.Service) string {
16 | 	ll := len(srvs)
17 | 	if ll == 0 {
18 | 		// a modulo by zero would panic with an integer divide by zero
19 | 		return ""
20 | 	}
21 | 	// HashCode converts a uint32 CRC to int, which can be negative where int is
22 | 	// 32 bits wide; reducing in uint32 keeps the index within [0, ll).
23 | 	idx := uint32(HashCode(header.ChannelId)) % uint32(ll)
24 | 	return srvs[idx].ServiceID()
25 | }
26 | 
--------------------------------------------------------------------------------
/container/metrics.go:
--------------------------------------------------------------------------------
1 | package container
2 | 
3 | import (
4 | 	"github.com/prometheus/client_golang/prometheus"
5 | 	"github.com/prometheus/client_golang/prometheus/promauto"
6 | )
7 | 
8 | // messageOutFlowBytes counts message bytes sent out by the gateway, labelled by command.
9 | var messageOutFlowBytes = promauto.NewCounterVec(prometheus.CounterOpts{
10 | 	Namespace: "kim",
11 | 	Name:      "message_out_flow_bytes",
12 | 	Help:      "网关下发的消息字节数",
13 | }, []string{"command"})
14 | 
--------------------------------------------------------------------------------
/container/selector.go:
--------------------------------------------------------------------------------
1 | package container
2 | 
3 | import (
4 | 	"hash/crc32"
5 | 
6 | 	"github.com/klintcheng/kim"
7 | 	"github.com/klintcheng/kim/wire/pkt"
8 | )
9 | 
10 | // HashCode returns the CRC-32 (IEEE polynomial) checksum of key converted to
11 | // int. Where int is 32 bits wide the conversion can yield a negative value.
12 | func HashCode(key string) int {
13 | 	return int(crc32.ChecksumIEEE([]byte(key)))
14 | }
15 | 
16 | // Selector chooses one service from a candidate list for a packet and returns
17 | // its ServiceID.
18 | type Selector interface {
19 | 	Lookup(*pkt.Header, []kim.Service) string
20 | }
21 | 
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package kim
2 | 
3 | import (
4 | 	"sync"
5 | 
6 | 	"github.com/klintcheng/kim/logger"
7 | 	"github.com/klintcheng/kim/wire"
8 | 	"github.com/klintcheng/kim/wire/pkt"
9 | 	"google.golang.org/protobuf/proto"
10 | )
11 | 
12 | // Session is a read-only view of a client session (channel, gateway, account, ...).
13 | type Session interface {
14 | 	GetChannelId() string
15 | 	GetGateId() string
16 | 	GetAccount() string
17 | 	GetRemoteIP() string
18 | 	GetApp() string
19 | 	GetTags() []string
20 | }
21 | 
22 | // Context is handed to every HandlerFunc. It gives access to the current
23 | // request and its Session, and can respond to the sender (Resp, RespWithError)
24 | // or push to other receivers (Dispatch).
25 | type Context interface {
26 | 	Dispatcher
27 | 	SessionStorage
28 | 	Header() *pkt.Header
29 | 	ReadBody(val proto.Message) error
30 | 	Session() Session
31 | 	RespWithError(status pkt.Status, err error) error
32 | 	Resp(status pkt.Status, body proto.Message) error
33 | 	Dispatch(body proto.Message, recvs ...*Location) error
34 | 	Next()
35 | }
36 | 
37 | // HandlerFunc handles one request through its Context.
38 | type HandlerFunc func(Context)
39 | 
40 | // HandlersChain is an ordered list of handlers; Context.Next runs them one at a time.
41 | type HandlersChain []HandlerFunc
42 | 
43 | // ContextImpl is the default Context implementation. The embedded Dispatcher
44 | // provides Push, which Resp and Dispatch use to deliver packets.
45 | type ContextImpl struct {
46 | 	sync.Mutex
47 | 	Dispatcher
48 | 	SessionStorage
49 | 
50 | 	handlers HandlersChain
51 | 	index    int
52 | 	request  *pkt.LogicPkt
53 | 	session  Session
54 | }
55 | 
56 | // BuildContext returns a new, empty ContextImpl.
57 | func BuildContext() Context {
58 | 	return &ContextImpl{}
59 | }
60 | 
61 | // Next runs the next handler in the chain, if any. A nil handler is logged and
62 | // not called, and Next returns without running the handlers after it.
63 | func (c *ContextImpl) Next() {
64 | 	if c.index >= len(c.handlers) {
65 | 		return
66 | 	}
67 | 	f := c.handlers[c.index]
68 | 	c.index++
69 | 	if f == nil {
70 | 		logger.Warn("arrived unknown HandlerFunc")
71 | 		return
72 | 	}
73 | 	f(c)
74 | }
75 | 
76 | // RespWithError responds to the sender with status and an ErrorResp whose
77 | // Message is err.Error().
78 | func (c *ContextImpl) RespWithError(status pkt.Status, err error) error {
79 | 	return c.Resp(status, &pkt.ErrorResp{Message: err.Error()})
80 | }
81 | 
82 | // Resp sends a response to the sender's channel via the sender's gateway. The
83 | // packet header is copied from the request, Status is set to status and Flag
84 | // to pkt.Flag_Response. A Push error is logged and returned.
85 | func (c *ContextImpl) Resp(status pkt.Status, body proto.Message) error {
86 | 	packet := pkt.NewFrom(&c.request.Header)
87 | 	packet.Status = status
88 | 	packet.WriteBody(body)
89 | 	packet.Flag = pkt.Flag_Response
90 | 	logger.Debugf("<-- Resp to %s command:%s status: %v body: %s", c.Session().GetAccount(), &c.request.Header, status, body)
91 | 
92 | 	err := c.Push(c.Session().GetGateId(), []string{c.Session().GetChannelId()}, packet)
93 | 	if err != nil {
94 | 		logger.Error(err)
95 | 	}
96 | 	return err
97 | }
98 | 
99 | // Dispatch pushes body to recvs in a packet whose header is copied from the
100 | // request and whose Flag is pkt.Flag_Push. Receivers on the sender's own
101 | // channel are skipped. The rest are grouped by gateway and one Push is made
102 | // per gateway. Every gateway is attempted even if an earlier Push fails; each
103 | // failure is logged and the first one encountered is returned (gateways are
104 | // visited in map order, which is unspecified).
105 | func (c *ContextImpl) Dispatch(body proto.Message, recvs ...*Location) error {
106 | 	if len(recvs) == 0 {
107 | 		return nil
108 | 	}
109 | 	logger.Debugf("<-- Dispatch to %d users command:%s", len(recvs), &c.request.Header)
110 | 
111 | 	self := c.Session().GetChannelId()
112 | 	// group the receivers' channel ids by the gateway that holds them
113 | 	group := make(map[string][]string)
114 | 	for _, recv := range recvs {
115 | 		if recv.ChannelId == self {
116 | 			continue
117 | 		}
118 | 		group[recv.GateId] = append(group[recv.GateId], recv.ChannelId)
119 | 	}
120 | 
121 | 	var firstErr error
122 | 	for gateway, ids := range group {
123 | 		// Push takes a *pkt.LogicPkt and may attach routing metadata to it, so
124 | 		// every gateway gets its own packet rather than a shared one.
125 | 		packet := pkt.NewFrom(&c.request.Header)
126 | 		packet.Flag = pkt.Flag_Push
127 | 		packet.WriteBody(body)
128 | 		if err := c.Push(gateway, ids, packet); err != nil {
129 | 			logger.Error(err)
130 | 			if firstErr == nil {
131 | 				firstErr = err
132 | 			}
133 | 		}
134 | 	}
135 | 	return firstErr
136 | }
137 | 
138 | // reset clears the request, handler chain, handler index and cached session.
139 | func (c *ContextImpl) reset() {
140 | 	c.request = nil
141 | 	c.index = 0
142 | 	c.handlers = nil
143 | 	c.session = nil
144 | }
145 | 
146 | // Header returns the header of the current request packet.
147 | func (c *ContextImpl) Header() *pkt.Header {
148 | 	return &c.request.Header
149 | }
150 | 
151 | // ReadBody reads the current request's body into val.
152 | func (c *ContextImpl) ReadBody(val proto.Message) error {
153 | 	return c.request.ReadBody(val)
154 | }
155 | 
156 | // Session returns the request's session. When none is set, a minimal session is
157 | // built once and cached: ChannelId from the request, GateId from the
158 | // wire.MetaDestServer meta, and the tag "AutoGenerated". A missing or non-string
159 | // MetaDestServer value is logged and leaves GateId empty instead of panicking.
160 | func (c *ContextImpl) Session() Session {
161 | 	if c.session == nil {
162 | 		server, _ := c.request.GetMeta(wire.MetaDestServer)
163 | 		gateID, ok := server.(string)
164 | 		if !ok {
165 | 			logger.Warnf("meta %v is missing or not a string, session GateId is empty", wire.MetaDestServer)
166 | 		}
167 | 		c.session = &pkt.Session{
168 | 			ChannelId: c.request.ChannelId,
169 | 			GateId:    gateID,
170 | 			Tags:      []string{"AutoGenerated"},
171 | 		}
172 | 	}
173 | 	return c.session
174 | }
175 | 
--------------------------------------------------------------------------------
/dispatcher.go:
--------------------------------------------------------------------------------
1 | package kim
2 | 
3 | import "github.com/klintcheng/kim/wire/pkt"
4 | 
5 | // Dispatcher delivers packet p to the given channels attached to the named gateway.
6 | type Dispatcher interface {
7 | 	Push(gateway string, channels []string, p *pkt.LogicPkt) error
8 | }
9 | 
--------------------------------------------------------------------------------
/dispatcher_mock.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: dispatcher.go
3 | 
4 | // Package kim is a generated GoMock package.
5 | package kim
6 | 
7 | import (
8 | 	reflect "reflect"
9 | 
10 | 	gomock "github.com/golang/mock/gomock"
11 | 	pkt "github.com/klintcheng/kim/wire/pkt"
12 | )
13 | 
14 | // MockDispatcher is a mock of Dispatcher interface.
15 | type MockDispatcher struct {
16 | 	ctrl     *gomock.Controller
17 | 	recorder *MockDispatcherMockRecorder
18 | }
19 | 
20 | // MockDispatcherMockRecorder is the mock recorder for MockDispatcher.
21 | type MockDispatcherMockRecorder struct {
22 | 	mock *MockDispatcher
23 | }
24 | 
25 | // NewMockDispatcher creates a new mock instance.
26 | func NewMockDispatcher(ctrl *gomock.Controller) *MockDispatcher {
27 | 	mock := &MockDispatcher{ctrl: ctrl}
28 | 	mock.recorder = &MockDispatcherMockRecorder{mock}
29 | 	return mock
30 | }
31 | 
32 | // EXPECT returns an object that allows the caller to indicate expected use.
33 | func (m *MockDispatcher) EXPECT() *MockDispatcherMockRecorder {
34 | 	return m.recorder
35 | }
36 | 
37 | // Push mocks base method.
38 | func (m *MockDispatcher) Push(gateway string, channels []string, p *pkt.LogicPkt) error {
39 | 	m.ctrl.T.Helper()
40 | 	ret := m.ctrl.Call(m, "Push", gateway, channels, p)
41 | 	ret0, _ := ret[0].(error)
42 | 	return ret0
43 | }
44 | 
45 | // Push indicates an expected call of Push.
46 | func (mr *MockDispatcherMockRecorder) Push(gateway, channels, p interface{}) *gomock.Call {
47 | 	mr.mock.ctrl.T.Helper()
48 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockDispatcher)(nil).Push), gateway, channels, p)
49 | }
50 | 
--------------------------------------------------------------------------------
/docker-compose-kim.yml:
--------------------------------------------------------------------------------
1 | version: '3.1'
2 | services:
3 |   router:
4 |     image: docker.io/dockerklint/kim_router:v1.1
5 |     container_name: router
6 |     restart: always
7 |     networks:
8 |       - kimnet
9 |     ports:
10 |       - "8100:8100"
11 |     environment:
12 |       KIM_CONSULURL: consul:8500
13 |       KIM_LOGLEVEL: INFO
14 |   royal:
15 |     image: docker.io/dockerklint/kim_royal:v1.4
16 |     container_name: royal
17 |     restart: always
18 |     networks:
19 |       - kimnet
20 |     ports:
21 |       - "8080:8080"
22 |     environment:
23 |       KIM_PUBLICADDRESS: royal
24 |       KIM_CONSULURL: consul:8500
25 |       KIM_REDISADDRS: redis:6379
26 |       KIM_BASEDB: root:123456@tcp(mysql:3306)/kim_base?charset=utf8mb4&parseTime=True&loc=Local
27 |       KIM_MESSAGEDB: root:123456@tcp(mysql:3306)/kim_message?charset=utf8mb4&parseTime=True&loc=Local
28 |       KIM_LOGLEVEL: DEBUG
29 |   gateway:
30 |     image: docker.io/dockerklint/kim_gateway:v1.4
31 |     container_name: wgateway
32 |     restart: always
33 |     networks:
34 |       - kimnet
35 |     ports:
36 |       - "8000:8000"
37 |       - "8001:8001"
38 |     environment:
39 |       KIM_PUBLICADDRESS: gateway
40 |       KIM_CONSULURL: consul:8500
41 |       KIM_LOGLEVEL: DEBUG
42 |       KIM_TAGS: IDC:SH_ALI
43 |       KIM_DOMAIN: ws://localhost:8000
44 |   server:
45 |     image: docker.io/dockerklint/kim_server:v1.4
46 |     container_name: chat
47 |     restart: always
48 |     networks:
49 |       - kimnet
50 |     ports:
51 |       - "8005:8005"
52 |       - "8006:8006"
53 |     environment:
54 |       KIM_PUBLICADDRESS: server
55 |       KIM_CONSULURL: consul:8500
56 |       KIM_REDISADDRS: redis:6379
57 |       KIM_LOGLEVEL: DEBUG
58 |       KIM_ROYALURL: http://royal:8080
59 |     dns: consul
60 | networks:
61 |   kimnet: {}
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.1'
2 | 
3 | services:
4 |   mysql:
5 |     image: docker.io/library/mysql:8.0
6 |     container_name: kim_mysql
7 |     command: --default-authentication-plugin=mysql_native_password
8 |     restart: always
9 |     networks:
10 |       - kimnet
11 |     ports:
12 |       - "3306:3306"
13 |     volumes:
14 |       - ~/data/mysql:/var/lib/mysql
15 |     environment:
16 |       MYSQL_ROOT_PASSWORD: 123456
17 |   redis:
18 |     image: docker.io/library/redis:6.2
19 |     container_name: kim_redis
20 |     command: redis-server
21 |     networks:
22 |       - kimnet
23 |     ports:
24 |       - "6379:6379"
25 |     volumes:
26 |       - ~/data/redis:/data
27 |   consul:
28 |     image: docker.io/library/consul:latest
29 |     container_name: kim_consul
30 |     networks:
31 |       - kimnet
32 |     ports:
33 |       - '8300:8300'
34 |       - '8301:8301'
35 |       - '8301:8301/udp'
36 |       - '8500:8500'
37 |       - '53:53'
38 |       - '53:53/udp'
39 |     command: agent -dev -dns-port=53 -recursor=8.8.8.8 -ui -client=0.0.0.0
40 |     environment:
41 |       CONSUL_BIND_INTERFACE: eth0
42 |       CONSUL_ALLOW_PRIVILEGED_PORTS: 53
43 | networks:
44 |   kimnet: {}
--------------------------------------------------------------------------------
/event.go:
--------------------------------------------------------------------------------
1 | package kim
2 | 
3 | import (
4 | 	"sync"
5 | 	"sync/atomic"
6 | )
7 | 
8 | // Event represents a one-time event that may occur in the future.
9 | //
10 | // Create Events with NewEvent: the zero value has a nil channel, so Fire
11 | // would panic when closing it and receiving from Done would block forever.
12 | type Event struct {
13 | 	fired int32
14 | 	c     chan struct{}
15 | 	o     sync.Once
16 | }
17 | 
18 | // Fire causes e to complete. It is safe to call multiple times, and
19 | // concurrently. It returns true iff this call to Fire caused the signaling
20 | // channel returned by Done to close.
21 | func (e *Event) Fire() bool {
22 | 	ret := false
23 | 	e.o.Do(func() {
24 | 		atomic.StoreInt32(&e.fired, 1)
25 | 		close(e.c)
26 | 		ret = true
27 | 	})
28 | 	return ret
29 | }
30 | 
31 | // Done returns a channel that will be closed when Fire is called.
32 | func (e *Event) Done() <-chan struct{} {
33 | 	return e.c
34 | }
35 | 
36 | // HasFired returns true if Fire has been called.
37 | func (e *Event) HasFired() bool {
38 | 	return atomic.LoadInt32(&e.fired) == 1
39 | }
40 | 
41 | // NewEvent returns a new, ready-to-use Event.
42 | func NewEvent() *Event {
43 | 	return &Event{c: make(chan struct{})}
44 | }
45 | 
--------------------------------------------------------------------------------
/examples/benchmark/server_test.go:
--------------------------------------------------------------------------------
1 | package benchmark
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 	"sync"
7 | 	"testing"
8 | 	"time"
9 | 
10 | 	"github.com/klintcheng/kim"
11 | 	"github.com/klintcheng/kim/examples/mock"
12 | 	"github.com/klintcheng/kim/logger"
13 | 	"github.com/klintcheng/kim/websocket"
14 | 	"github.com/panjf2000/ants/v2"
15 | )
16 | 
17 | const wsurl = "ws://localhost:8000"
18 | 
19 | func Test_Parallel(t *testing.T) {
20 | 	const count = 10000
21 | 	gpool, _ := ants.NewPool(50, ants.WithPreAlloc(true))
22 | 	defer gpool.Release()
23 | 	var wg sync.WaitGroup
24 | 	wg.Add(count)
25 | 
26 | 	clis := make([]kim.Client, count)
27 | 	t0 := time.Now()
28 | 	for i := 0; i < count; i++ {
29 | 		idx := i
30 | 		_ = gpool.Submit(func() {
31 | 			cli := websocket.NewClient(fmt.Sprintf("test_%v", idx), "client", websocket.ClientOptions{
32 | 				Heartbeat: kim.DefaultHeartbeat,
33 | 			})
34 | 			// set dialer
35 | 			cli.SetDialer(&mock.WebsocketDialer{})
36 | 
37 | 			// step2: connect
38 | 			err := cli.Connect(wsurl)
39 | 			if err != nil {
40 | 				logger.Error(err)
41 | 			}
42 | 			clis[idx] = cli
43 | 			wg.Done()
44 | 		})
45 | 	}
46 | 	wg.Wait()
47 | 	t.Logf("logined %d cost %v", count, time.Since(t0))
48 | 	t.Logf("done connecting")
49 | 	time.Sleep(time.Second * 5)
50 | 	t.Logf("closed")
51 | 
52 | 	for i := 0; i < count; i++ {
53 | 		clis[i].Close()
54 | 	}
55 | }
56 | 
57 | func Test_Message(t *testing.T) {
58 | 	const count = 1000 * 100
59 | 	cli := websocket.NewClient(fmt.Sprintf("test_%v", 1), "client", websocket.ClientOptions{
60 | 		Heartbeat: kim.DefaultHeartbeat,
61 | 	})
62 | 	// set dialer
63 | 	cli.SetDialer(&mock.WebsocketDialer{})
64 | 
65 | 	// step2: connect; without a connection the send/read loop below measures nothing
66 | 	err := cli.Connect(wsurl)
67 | 	if err != nil {
68 | 		t.Fatalf("connect %s: %v", wsurl, err)
69 | 	}
70 | 	defer cli.Close()
71 | 	msg := []byte(strings.Repeat("hello", 10))
72 | 	t0 := time.Now()
73 | 	go func() {
74 | 		for i := 0; i < count; i++ {
75 | 			_ = cli.Send(msg)
76 | 		}
77 | 	}()
78 | 	recv := 0
79 | 	for {
80 | 		frame, err := cli.Read()
81 | 		if err != nil {
82 | 			logger.Info("time", time.Now().UnixNano(), err)
83 | 			break
84 | 		}
85 | 		if frame.GetOpCode() != kim.OpBinary {
86 | 			continue
87 | 		}
88 | 		recv++
89 | 		if recv == count { // all messages received
90 | 			break
91 | 		}
92 | 	}
93 | 
94 | 	t.Logf("message %d cost %v", count, time.Since(t0))
95 | }
96 | 
--------------------------------------------------------------------------------
/examples/dialer/client_dialer.go:
--------------------------------------------------------------------------------
1 | package dialer
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"fmt"
7 | 	"net"
8 | 	"time"
9 | 
10 | 	"github.com/gobwas/ws"
11 | 	"github.com/gobwas/ws/wsutil"
12 | 	"github.com/klintcheng/kim"
13 | 	"github.com/klintcheng/kim/logger"
14 | 	"github.com/klintcheng/kim/wire"
15 | 	"github.com/klintcheng/kim/wire/pkt"
16 | 	"github.com/klintcheng/kim/wire/token"
17 | )
18 | 
19 | // ClientDialer dials a kim gateway over websocket and logs in with a JWT
20 | // signed by AppSecret (token.DefaultSecret when empty).
21 | type ClientDialer struct {
22 | 	AppSecret string
23 | }
24 | 
25 | // DialAndHandshake connects to ctx.Address, sends a CommandLoginSignIn packet
26 | // carrying a token for account ctx.Id, and waits up to ctx.Timeout for the
27 | // login response. The connection is returned only if login succeeded; on any
28 | // failure after dialing it is closed before the error is returned. An empty
29 | // AppSecret is replaced with token.DefaultSecret on d itself.
30 | func (d *ClientDialer) DialAndHandshake(ctx kim.DialerContext) (net.Conn, error) {
31 | 	// 1. dial
32 | 	conn, _, _, err := ws.Dial(context.TODO(), ctx.Address)
33 | 	if err != nil {
34 | 		return nil, err
35 | 	}
36 | 	loggedIn := false
37 | 	defer func() {
38 | 		// don't leak the socket when any step below fails
39 | 		if !loggedIn {
40 | 			_ = conn.Close()
41 | 		}
42 | 	}()
43 | 	if d.AppSecret == "" {
44 | 		d.AppSecret = token.DefaultSecret
45 | 	}
46 | 	// 2. generate a token with the wrapped JWT package
47 | 	tk, err := token.Generate(d.AppSecret, &token.Token{
48 | 		Account: ctx.Id,
49 | 		App:     "kim",
50 | 		Exp:     time.Now().AddDate(0, 0, 1).Unix(),
51 | 	})
52 | 	if err != nil {
53 | 		return nil, err
54 | 	}
55 | 	// 3. send a CommandLoginSignIn packet
56 | 	loginreq := pkt.New(wire.CommandLoginSignIn).WriteBody(&pkt.LoginReq{
57 | 		Token: tk,
58 | 	})
59 | 	err = wsutil.WriteClientBinary(conn, pkt.Marshal(loginreq))
60 | 	if err != nil {
61 | 		return nil, err
62 | 	}
63 | 
64 | 	// wait resp
65 | 	_ = conn.SetReadDeadline(time.Now().Add(ctx.Timeout))
66 | 	frame, err := ws.ReadFrame(conn)
67 | 	if err != nil {
68 | 		return nil, err
69 | 	}
70 | 	ack, err := pkt.MustReadLogicPkt(bytes.NewBuffer(frame.Payload))
71 | 	if err != nil {
72 | 		return nil, err
73 | 	}
74 | 	// 4. check whether the login succeeded
75 | 	if ack.Status != pkt.Status_Success {
76 | 		return nil, fmt.Errorf("login failed: %v", &ack.Header)
77 | 	}
78 | 	var resp = new(pkt.LoginResp)
79 | 	_ = ack.ReadBody(resp)
80 | 
81 | 	// the handshake deadline must not carry over to reads made by the caller
82 | 	_ = conn.SetReadDeadline(time.Time{})
83 | 	logger.Debug("logined ", resp.GetChannelId())
84 | 	loggedIn = true
85 | 	return conn, nil
86 | }
87 | 
--------------------------------------------------------------------------------
/examples/dialer/login.go:
--------------------------------------------------------------------------------
1 | package dialer
2 | 
3 | import (
4 | 	"github.com/klintcheng/kim"
5 | 	"github.com/klintcheng/kim/websocket"
6 | 	"github.com/klintcheng/kim/wire/token"
7 | )
8 | 
9 | // Login connects a websocket client for account to wsurl, logging in through a
10 | // ClientDialer, and returns the connected client. The first of appSecrets, if
11 | // any, signs the login token; otherwise token.DefaultSecret is used.
12 | func Login(wsurl, account string, appSecrets ...string) (kim.Client, error) {
13 | 	cli := websocket.NewClient(account, "unittest", websocket.ClientOptions{})
14 | 	secret := token.DefaultSecret
15 | 	if len(appSecrets) > 0 {
16 | 		secret = appSecrets[0]
17 | 	}
18 | 	cli.SetDialer(&ClientDialer{
19 | 		AppSecret: secret,
20 | 	})
21 | 	err := cli.Connect(wsurl)
22 | 	if err != nil {
23 | 		return nil, err
24 | 	}
25 | 	return cli, nil
26 | }
27 | 
--------------------------------------------------------------------------------
/examples/echo/echo.go:
--------------------------------------------------------------------------------
1 | package echo
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"time"
7 | 
8 | 	"github.com/klintcheng/kim"
9 | 	"github.com/klintcheng/kim/examples/dialer"
10 | 	"github.com/klintcheng/kim/logger"
11 | 	"github.com/klintcheng/kim/websocket"
12 | 	"github.com/klintcheng/kim/wire"
13 | 	"github.com/klintcheng/kim/wire/pkt"
14 | 	"github.com/spf13/cobra"
15 | )
16 | 
17 | // StartOptions holds options for the echo command; it currently has no fields.
18 | type StartOptions struct {
19 | }
20 | 
21 | // NewCmd returns the "echo" command, whose RunE calls run with ctx.
22 | func NewCmd(ctx context.Context) *cobra.Command {
23 | 	opts := &StartOptions{}
24 | 
25 | 	cmd := &cobra.Command{
26 | 		Use:   "echo",
27 | 		Short: "Start echo client",
28 | 		RunE: func(cmd *cobra.Command, args []string) error {
29 | 			return run(ctx, opts)
30 | 		},
31 | 	}
32 | 
33 | 	return cmd
34 | }
35 | 
36 | // run connects as "test1" to ws://localhost:8000, sends count CommandChatUserTalk
37 | // messages addressed to "test1" (itself) one second apart, and logs the error
38 | // responses, acks and pushes it reads until count*2 binary frames have arrived
39 | // (presumably one ack plus one push per message — confirm) or a read/decode
40 | // fails. ctx and opts are currently unused.
41 | func run(ctx context.Context, opts *StartOptions) error {
42 | 	cli := websocket.NewClient("test1", "echo", websocket.ClientOptions{
43 | 		Heartbeat: time.Second * 30,
44 | 		ReadWait:  time.Minute * 3,
45 | 		WriteWait: time.Second * 10,
46 | 	})
47 | 
48 | 	cli.SetDialer(&dialer.ClientDialer{})
49 | 
50 | 	err := cli.Connect("ws://localhost:8000")
51 | 	if err != nil {
52 | 		return err
53 | 	}
54 | 	count := 5
55 | 
56 | 	go func() {
57 | 		// step3: send the messages, then exit
58 | 		for i := 0; i < count; i++ {
59 | 			p := pkt.New(wire.CommandChatUserTalk, pkt.WithDest("test1"))
60 | 			p.WriteBody(&pkt.MessageReq{
61 | 				Type: 1,
62 | 				Body: "hello world",
63 | 			})
64 | 			err := cli.Send(pkt.Marshal(p))
65 | 			if err != nil {
66 | 				logger.Error(err)
67 | 				return
68 | 			}
69 | 			time.Sleep(time.Second)
70 | 		}
71 | 	}()
72 | 
73 | 	// step4: receive Ack messages (and pushes)
74 | 	recv := 0
75 | 	for {
76 | 		frame, err := cli.Read()
77 | 		if err != nil {
78 | 			logger.Info(err)
79 | 			break
80 | 		}
81 | 		if frame.GetOpCode() != kim.OpBinary {
82 | 			continue
83 | 		}
84 | 		recv++
85 | 
86 | 		p, err := pkt.MustReadLogicPkt(bytes.NewBuffer(frame.GetPayload()))
87 | 		if err != nil {
88 | 			logger.Info(err)
89 | 			break
90 | 		}
91 | 		if p.Status != pkt.Status_Success {
92 | 			var errResp pkt.ErrorResp
93 | 			_ = p.ReadBody(&errResp)
94 | 
95 | 			logger.Warnf("%s error:%s", cli.ServiceID(), errResp.Message)
96 | 		} else {
97 | 			if p.Flag == pkt.Flag_Response {
98 | 				var ack = new(pkt.MessageResp)
99 | 				_ = p.ReadBody(ack)
100 | 
101 | 				logger.Warnf("%s receive Ack [%d]", cli.ServiceID(), ack.GetMessageId())
102 | 			} else if p.Flag == pkt.Flag_Push {
103 | 				var push = new(pkt.MessagePush)
104 | 				_ = p.ReadBody(push)
105 | 
106 | 				logger.Warnf("%s receive message [%d] %s", cli.ServiceID(), push.GetMessageId(), push.Body)
107 | 			}
108 | 
109 | 		}
110 | 
111 | 		if recv == count*2 { // all messages received
112 | 			break
113 | 		}
114 | 	}
115 | 	cli.Close()
116 | 
117 | 	return nil
118 | }
119 | 
--------------------------------------------------------------------------------
/examples/kimbench/cmd.go:
--------------------------------------------------------------------------------
1 | package kimbench
2 | 
3 | import (
4 | 	"context"
5 | 	"time"
6 | 
7 | 	"github.com/klintcheng/kim/wire/token"
8 | 	"github.com/spf13/cobra"
9 | )
10 | 
11 | // Options holds the persistent flags shared by all benchmark subcommands.
12 | type Options struct {
13 | 	Addr      string
14 | 	AppSecret string
15 | 	Count     int
16 | 	Threads   int
17 | }
18 | 
19 | // NewBenchmarkCmd builds the "benchmark" command with its user-talk, group-talk and login subcommands.
20 | func NewBenchmarkCmd(ctx context.Context) *cobra.Command {
21 | 	cmd := &cobra.Command{
22 | 		Use:   "benchmark",
23 | 		Short: "kim benchmark tools",
24 | 	}
25 | 	var opts = &Options{}
26 | 	cmd.PersistentFlags().StringVarP(&opts.Addr, "address", "a", "ws://localhost:8000", "server address")
27 | 	cmd.PersistentFlags().StringVarP(&opts.AppSecret, "appSecret", "s", token.DefaultSecret, "app secret")
28 | 	cmd.PersistentFlags().IntVarP(&opts.Count, "count", "c", 100, "request count")
29 | 	cmd.PersistentFlags().IntVarP(&opts.Threads, "thread", "t", 10, "thread count")
30 | 
31 | 	cmd.AddCommand(NewUserTalkCmd(opts))
32 | 	cmd.AddCommand(NewGroupTalkCmd(opts))
33 | 	cmd.AddCommand(NewLoginCmd(opts))
34 | 	return cmd
35 | }
36 | 
37 | type UserOptions struct {
38 | 	online bool
39 | }
40
| // NewUserTalkCmd returns the "user" subcommand, which runs the one-to-one chat benchmark (usertalk). 41 | func NewUserTalkCmd(opts *Options) *cobra.Command { 42 | var options = &UserOptions{} 43 | 44 | cmd := &cobra.Command{ 45 | Use: "user", 46 | RunE: func(cmd *cobra.Command, args []string) error { 47 | err := usertalk(opts.Addr, opts.AppSecret, opts.Threads, opts.Count, options.online) 48 | if err != nil { 49 | return err 50 | } 51 | return nil 52 | }, 53 | } 54 | 55 | cmd.PersistentFlags().BoolVarP(&options.online, "online", "o", false, "set if receiver is online") 56 | return cmd 57 | } 58 | // LoginOptions holds the flags of the "login" subcommand. 59 | type LoginOptions struct { 60 | keep time.Duration // how long the logged-in clients are kept before being closed 61 | } 62 | // NewLoginCmd returns the "login" subcommand, which runs the concurrent login benchmark (login). 63 | func NewLoginCmd(opts *Options) *cobra.Command { 64 | var options = &LoginOptions{} 65 | cmd := &cobra.Command{ 66 | Use: "login", 67 | RunE: func(cmd *cobra.Command, args []string) error { 68 | err := login(opts.Addr, opts.AppSecret, opts.Threads, opts.Count, options.keep) 69 | if err != nil { 70 | return err 71 | } 72 | return nil 73 | }, 74 | } 75 | cmd.PersistentFlags().DurationVarP(&options.keep, "keep", "k", time.Millisecond*10, "the duration of keeping the client connection") 76 | return cmd 77 | } 78 | // GroupOptions holds the flags of the "group" subcommand. 79 | type GroupOptions struct { 80 | MemberCount int 81 | OnlinePercent float32 // fraction (0-1) of MemberCount expected to be online 82 | } 83 | // NewGroupTalkCmd returns the "group" subcommand, which runs the group chat benchmark (grouptalk). 84 | func NewGroupTalkCmd(opts *Options) *cobra.Command { 85 | var options = &GroupOptions{} 86 | 87 | cmd := &cobra.Command{ 88 | Use: "group", 89 | RunE: func(cmd *cobra.Command, args []string) error { 90 | err := grouptalk(opts.Addr, opts.AppSecret, opts.Threads, opts.Count, options.MemberCount, options.OnlinePercent) 91 | if err != nil { 92 | return err 93 | } 94 | return nil 95 | }, 96 | } 97 | 98 | cmd.PersistentFlags().IntVarP(&options.MemberCount, "memcount", "m", 20, "member count") 99 | cmd.PersistentFlags().Float32VarP(&options.OnlinePercent, "percet", "p", 0.5, "online percent of group members (0-1)") // the misspelled flag name "percet" is kept so existing scripts keep working 100 | return cmd 101 | } 102 | -------------------------------------------------------------------------------- /examples/kimbench/grouptalk.go:
-------------------------------------------------------------------------------- 1 | package kimbench 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "os" 7 | "sync" 8 | "time" 9 | 10 | "github.com/klintcheng/kim" 11 | "github.com/klintcheng/kim/examples/dialer" 12 | "github.com/klintcheng/kim/report" 13 | "github.com/klintcheng/kim/wire" 14 | "github.com/klintcheng/kim/wire/pkt" 15 | "github.com/panjf2000/ants/v2" 16 | ) 17 | // grouptalk benchmarks group chat: test1 creates a group with members test_1..test_<memberCount>, some members are logged in as passive receivers, then threads sender clients (test2..test<threads+1>) send count group messages through a pool of threads workers and the send-to-response latency is reported. A failed login, a failed read/decode of the create-group response or a non-success create status aborts the run with an error. 18 | func grouptalk(wsurl, appSecret string, threads, count int, memberCount int, onlinePercent float32) error { 19 | cli1, err := dialer.Login(wsurl, "test1", appSecret) 20 | if err != nil { 21 | return err 22 | } 23 | var members = make([]string, memberCount) 24 | for i := 0; i < memberCount; i++ { 25 | members[i] = fmt.Sprintf("test_%d", i+1) 26 | } 27 | // create the group 28 | p := pkt.New(wire.CommandGroupCreate) 29 | p.WriteBody(&pkt.GroupCreateReq{ 30 | Name: "group1", 31 | Owner: "test1", 32 | Members: members, 33 | }) 34 | if err = cli1.Send(pkt.Marshal(p)); err != nil { 35 | return err 36 | } 37 | // read the create-group response; each step is checked because a failed read or decode yields a nil value 38 | ack, err := cli1.Read() 39 | if err != nil { 40 | return err 41 | } 42 | ackp, err := pkt.MustReadLogicPkt(bytes.NewBuffer(ack.GetPayload())) 43 | if err != nil { 44 | return err 45 | } 46 | if pkt.Status_Success != ackp.GetStatus() { 47 | return fmt.Errorf("create group failed") 48 | } 49 | var createresp pkt.GroupCreateResp 50 | _ = ackp.ReadBody(&createresp) 51 | group := createresp.GetGroupId() 52 | 53 | onlines := int(float32(memberCount) * onlinePercent) 54 | if onlines < 1 { 55 | onlines = 1 56 | } 57 | for i := 1; i < onlines; i++ { // NOTE(review): logs in test_1..test_<onlines-1>; presumably the owner test1 counts as the remaining online member — confirm 58 | clix, err := dialer.Login(wsurl, fmt.Sprintf("test_%d", i), appSecret) 59 | if err != nil { 60 | return err 61 | } 62 | go func(cli kim.Client) { // drain the receiver's frames until its connection fails 63 | for { 64 | _, err := cli.Read() 65 | if err != nil { 66 | return 67 | } 68 | } 69 | }(clix) 70 | } 71 | 72 | clis, err := loginMulti(wsurl, appSecret, 2, threads) 73 | if err != nil { 74 | return err 75 | } 76 | 77 | pool, _ := ants.NewPool(threads, ants.WithPreAlloc(true)) 78 | defer pool.Release() 79 | 80 | r := report.New(os.Stdout, count)
81 | t1 := time.Now() 82 | 83 | var wg sync.WaitGroup 84 | wg.Add(count) 85 | for i := 0; i < count; i++ { 86 | cli := clis[i%threads] // NOTE(review): tasks i and i+threads share this client and may run concurrently, so a Read can pick up another task's response 87 | _ = pool.Submit(func() { 88 | defer func() { 89 | wg.Done() 90 | }() 91 | 92 | t0 := time.Now() 93 | p := pkt.New(wire.CommandChatGroupTalk, pkt.WithDest(group)) 94 | p.WriteBody(&pkt.MessageReq{ 95 | Type: 1, 96 | Body: "hello world", 97 | }) 98 | // send the message 99 | err := cli.Send(pkt.Marshal(p)) 100 | if err != nil { 101 | r.Add(&report.Result{ 102 | Err: err, 103 | ContentLength: 11, 104 | }) 105 | return 106 | } 107 | // read the response frame (its content is not inspected) 108 | _, err = cli.Read() 109 | if err != nil { 110 | r.Add(&report.Result{ 111 | Err: err, 112 | ContentLength: 11, 113 | }) 114 | return 115 | } 116 | r.Add(&report.Result{ 117 | Duration: time.Since(t0), 118 | Err: err, 119 | StatusCode: 0, 120 | }) 121 | }) 122 | } 123 | 124 | wg.Wait() 125 | r.Finalize(time.Since(t1)) 126 | return nil 127 | } 128 | -------------------------------------------------------------------------------- /examples/kimbench/login.go: -------------------------------------------------------------------------------- 1 | package kimbench 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "sync" 7 | "time" 8 | 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/examples/dialer" 11 | "github.com/klintcheng/kim/logger" 12 | "github.com/klintcheng/kim/report" 13 | "github.com/panjf2000/ants/v2" 14 | ) 15 | // login benchmarks logins: count clients (test1..test<count>) log in concurrently through a pool of threads workers, each login latency is reported, and the connections are kept for keep before being closed. 16 | func login(wsurl, appSecret string, threads int, count int, keep time.Duration) error { 17 | p, _ := ants.NewPool(threads, ants.WithPreAlloc(true)) 18 | defer p.Release() 19 | 20 | r := report.New(os.Stdout, count) 21 | t1 := time.Now() 22 | 23 | var wg sync.WaitGroup 24 | wg.Add(count) 25 | clis := make([]kim.Client, count) 26 | for i := 0; i < count; i++ { 27 | idx := i 28 | _ = p.Submit(func() { 29 | t0 := time.Now() 30 | cli, err := dialer.Login(wsurl, fmt.Sprintf("test%d", idx+1), appSecret) 31 | r.Add(&report.Result{ 32 | Duration: time.Since(t0), 33 | Err: err, 34 |
StatusCode: 0, 35 | }) 36 | if err != nil { 37 | logger.Error(err) 38 | } else { 39 | clis[idx] = cli 40 | } 41 | wg.Done() 42 | }) 43 | } 44 | wg.Wait() 45 | 46 | r.Finalize(time.Since(t1)) 47 | 48 | logger.Infof("keep login for %v", keep) 49 | time.Sleep(keep) 50 | 51 | for _, cli := range clis { 52 | if cli != nil { // nil when that client's login failed; calling Close on it would panic 53 | cli.Close() 54 | } 55 | } 56 | logger.Infoln("shutdown..") 57 | return nil 58 | } 59 | -------------------------------------------------------------------------------- /examples/kimbench/usertalk.go: -------------------------------------------------------------------------------- 1 | package kimbench 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "sync" 7 | "time" 8 | 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/examples/dialer" 11 | "github.com/klintcheng/kim/report" 12 | "github.com/klintcheng/kim/wire" 13 | "github.com/klintcheng/kim/wire/pkt" 14 | "github.com/panjf2000/ants/v2" 15 | ) 16 | // loginMulti logs in count clients one after another as test<start>..test<start+count-1> and returns them; it stops at the first failed login (clients already logged in are not closed). 17 | func loginMulti(wsurl, appSecret string, start, count int) ([]kim.Client, error) { 18 | clis := make([]kim.Client, count) 19 | for i := 0; i < count; i++ { 20 | account := fmt.Sprintf("test%d", start) 21 | start++ 22 | cli, err := dialer.Login(wsurl, account, appSecret) 23 | if err != nil { 24 | return nil, err 25 | } 26 | clis[i] = cli 27 | } 28 | return clis, nil 29 | } 30 | // usertalk benchmarks one-to-one chat: threads sender clients (test2..test<threads+1>) send count messages to test1 through a pool of threads workers and the send-to-response latency is reported. With online set, test1 is first logged in with appSecret (a failure aborts the run) and its frames are drained. 31 | func usertalk(wsurl, appSecret string, threads, count int, online bool) error { 32 | p, _ := ants.NewPool(threads, ants.WithPreAlloc(true)) 33 | defer p.Release() 34 | 35 | if online { 36 | cli2, err := dialer.Login(wsurl, "test1", appSecret) 37 | if err != nil { 38 | return err 39 | } 40 | go func() { // drain test1's frames until its connection fails 41 | for { 42 | if _, err := cli2.Read(); err != nil { 43 | return 44 | } 45 | } 46 | }() 47 | } 48 | clis, err := loginMulti(wsurl, appSecret, 2, threads) 49 | if err != nil { 50 | return err 51 | } 52 | 53 | r := report.New(os.Stdout, count) 54 | t1 := time.Now() 55 | 56 | var wg sync.WaitGroup 57 | wg.Add(count) 58 | for i := 0; i < count; i++ { 59 | cli := clis[i%threads] 60 | _ = p.Submit(func() { 61 | defer func() { 62 |
wg.Done() 63 | }() 64 | 65 | t0 := time.Now() 66 | p := pkt.New(wire.CommandChatUserTalk, pkt.WithDest("test1")) 67 | p.WriteBody(&pkt.MessageReq{ 68 | Type: 1, 69 | Body: "hello world", 70 | }) 71 | // send the message 72 | err := cli.Send(pkt.Marshal(p)) 73 | if err != nil { 74 | r.Add(&report.Result{ 75 | Err: err, 76 | ContentLength: 11, 77 | }) 78 | return 79 | } 80 | // read the response frame (its content is not inspected) 81 | _, err = cli.Read() 82 | if err != nil { 83 | r.Add(&report.Result{ 84 | Err: err, 85 | ContentLength: 11, 86 | }) 87 | return 88 | } 89 | r.Add(&report.Result{ 90 | Duration: time.Since(t0), 91 | Err: err, 92 | StatusCode: 0, 93 | }) 94 | }) 95 | } 96 | wg.Wait() 97 | 98 | r.Finalize(time.Since(t1)) 99 | return nil 100 | } 101 | -------------------------------------------------------------------------------- /examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/klintcheng/kim/examples/echo" 7 | "github.com/klintcheng/kim/examples/kimbench" 8 | "github.com/klintcheng/kim/examples/mock" 9 | "github.com/klintcheng/kim/logger" 10 | "github.com/spf13/cobra" 11 | ) 12 | 13 | // version is the value given to the root command's Version field. 14 | const version = "v1" 15 | 16 | // main builds the "kim" root command with the echo, mock and benchmark subcommands and executes it. The standard flag package is intentionally not parsed: cobra owns argument parsing, and flag.Parse would exit on cobra's root flags such as --version before cobra saw them. 17 | func main() { 18 | root := &cobra.Command{ 19 | Use: "kim", 20 | Version: version, 21 | Short: "tools", 22 | } 23 | ctx := context.Background() 24 | 25 | // run echo client 26 | root.AddCommand(echo.NewCmd(ctx)) 27 | 28 | // mock client and server, plus the benchmark tools 29 | root.AddCommand(mock.NewClientCmd(ctx)) 30 | root.AddCommand(mock.NewServerCmd(ctx)) 31 | root.AddCommand(kimbench.NewBenchmarkCmd(ctx)) 32 | 33 | if err := root.Execute(); err != nil { 34 | logger.WithError(err).Fatal("Could not run command") 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /examples/mock/client.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 |
"time" 8 | 9 | "github.com/gobwas/ws" 10 | "github.com/gobwas/ws/wsutil" 11 | "github.com/klintcheng/kim" 12 | "github.com/klintcheng/kim/logger" 13 | "github.com/klintcheng/kim/tcp" 14 | "github.com/klintcheng/kim/websocket" 15 | ) 16 | 17 | // ClientDemo is a minimal demo client: it sends ten messages and logs the binary replies. 18 | type ClientDemo struct { 19 | } 20 | // Start connects to addr as userID over protocol ("ws" or "tcp"; anything else is logged and ignored), sends count messages from a goroutine, reads until count binary replies arrive or a read fails, then closes the client. 21 | func (c *ClientDemo) Start(userID, protocol, addr string) { 22 | var cli kim.Client 23 | 24 | // step1: create the client for the requested protocol; both protocols authenticate as userID 25 | if protocol == "ws" { 26 | cli = websocket.NewClient(userID, "client", websocket.ClientOptions{}) 27 | cli.SetDialer(&WebsocketDialer{}) // set dialer 28 | } else if protocol == "tcp" { 29 | cli = tcp.NewClient(userID, "client", tcp.ClientOptions{}) 30 | cli.SetDialer(&TCPDialer{}) 31 | } else { 32 | logger.Error("unsupported protocol: ", protocol) 33 | return 34 | } 35 | err := cli.Connect(addr) // step2: establish the connection 36 | if err != nil { 37 | logger.Error(err) 38 | return 39 | } 40 | count := 10 41 | go func() { 42 | // step3: send count messages, then exit the goroutine 43 | for i := 0; i < count; i++ { 44 | err := cli.Send([]byte(fmt.Sprintf("hello_%d", i))) 45 | if err != nil { 46 | logger.Error(err) 47 | return 48 | } 49 | time.Sleep(time.Millisecond * 10) 50 | } 51 | }() 52 | 53 | // step4: receive the replies 54 | recv := 0 55 | for { 56 | frame, err := cli.Read() 57 | if err != nil { 58 | logger.Info(err) 59 | break 60 | } 61 | if frame.GetOpCode() != kim.OpBinary { 62 | continue 63 | } 64 | recv++ 65 | logger.Infof("%s receive message [%s]", cli.ServiceID(), frame.GetPayload()) 66 | if recv == count { // all replies received 67 | break 68 | } 69 | } 70 | // exit 71 | cli.Close() 72 | } 73 | 74 | // WebsocketDialer dials a websocket server and authenticates by sending the user id as the first binary frame. 75 | type WebsocketDialer struct { 76 | } 77 | 78 | // DialAndHandshake dials ctx.Address within ctx.Timeout with ws.Dial and sends ctx.Id as the handshake payload. 79 | func (d *WebsocketDialer) DialAndHandshake(ctx kim.DialerContext) (net.Conn, error) { 80 | logger.Info("start ws dial: ", ctx.Address) 81 | // 1. dial with ws.Dial 82 | ctxWithTimeout, cancel := context.WithTimeout(context.TODO(), ctx.Timeout) 83 | defer cancel() 84 | 85 | conn, _, _, err := ws.Dial(ctxWithTimeout, ctx.Address) 86 | if err != nil { 87 |
return nil, err 88 | } 89 | // 2. send the authentication payload; in this demo it is just the user id 90 | err = wsutil.WriteClientBinary(conn, []byte(ctx.Id)) 91 | if err != nil { 92 | _ = conn.Close() // release the socket when the handshake write fails 93 | return nil, err 94 | } 95 | return conn, nil // 3. return the authenticated conn 96 | } 97 | 98 | // TCPDialer dials a TCP server and authenticates by sending the user id as the first binary frame. 99 | type TCPDialer struct { 100 | } 101 | 102 | // DialAndHandshake dials ctx.Address within ctx.Timeout with net.DialTimeout and sends ctx.Id as the handshake frame. 103 | func (d *TCPDialer) DialAndHandshake(ctx kim.DialerContext) (net.Conn, error) { 104 | logger.Info("start tcp dial: ", ctx.Address) 105 | // 1. dial with net.DialTimeout 106 | conn, err := net.DialTimeout("tcp", ctx.Address, ctx.Timeout) 107 | if err != nil { 108 | return nil, err 109 | } 110 | // 2. send the authentication payload; in this demo it is just the user id 111 | err = tcp.WriteFrame(conn, kim.OpBinary, []byte(ctx.Id)) 112 | if err != nil { 113 | _ = conn.Close() // release the socket when the handshake write fails 114 | return nil, err 115 | } 116 | return conn, nil // 3. return the authenticated conn 117 | } 118 | -------------------------------------------------------------------------------- /examples/mock/cmd.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/segmentio/ksuid" 9 | "github.com/spf13/cobra" 10 | ) 11 | 12 | // StartOptions holds the flags shared by the mock client and server commands. 13 | type StartOptions struct { 14 | addr string 15 | protocol string 16 | } 17 | 18 | // NewClientCmd returns the "mock_cli" command, which runs ClientDemo against --address. 19 | func NewClientCmd(ctx context.Context) *cobra.Command { 20 | opts := &StartOptions{} 21 | 22 | cmd := &cobra.Command{ 23 | Use: "mock_cli", 24 | Short: "start client", 25 | RunE: func(cmd *cobra.Command, args []string) error { 26 | return runcli(ctx, opts) 27 | }, 28 | } 29 | cmd.PersistentFlags().StringVarP(&opts.addr, "address", "a", "localhost:8000", "server address") 30 | cmd.PersistentFlags().StringVarP(&opts.protocol, "protocol", "p", "ws", "protocol ws or tcp") 31 | return cmd 32 | } 33 | // runcli starts a ClientDemo under a random ksuid user id; for ws, an address with neither a ws: nor a wss: scheme gets ws:// prepended. 34 | func runcli(ctx context.Context, opts *StartOptions) error { 35 | cli := ClientDemo{} 36 | if opts.protocol == "ws" && !strings.HasPrefix(opts.addr, "ws:") && !strings.HasPrefix(opts.addr, "wss:") { 37 | opts.addr =
fmt.Sprintf("ws://%s", opts.addr) 38 | } 39 | cli.Start(ksuid.New().String(), opts.protocol, opts.addr) 40 | return nil 41 | } 42 | 43 | // NewServerCmd returns the "mock_srv" command, which runs ServerDemo on --address. 44 | func NewServerCmd(ctx context.Context) *cobra.Command { 45 | opts := &StartOptions{} 46 | 47 | cmd := &cobra.Command{ 48 | Use: "mock_srv", 49 | Short: "start server", 50 | RunE: func(cmd *cobra.Command, args []string) error { 51 | return runsrv(ctx, opts) 52 | }, 53 | } 54 | cmd.PersistentFlags().StringVarP(&opts.addr, "address", "a", ":8000", "listen address") 55 | cmd.PersistentFlags().StringVarP(&opts.protocol, "protocol", "p", "ws", "protocol ws or tcp") 56 | return cmd 57 | } 58 | // runsrv starts a ServerDemo with service id "srv1" using the command's protocol and address. 59 | func runsrv(ctx context.Context, opts *StartOptions) error { 60 | srv := ServerDemo{} 61 | srv.Start("srv1", opts.protocol, opts.addr) 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /examples/mock/server.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | _ "net/http/pprof" 7 | "time" 8 | 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/logger" 11 | "github.com/klintcheng/kim/naming" 12 | "github.com/klintcheng/kim/tcp" 13 | "github.com/klintcheng/kim/websocket" 14 | ) 15 | // ServerDemo is a minimal demo server; its ServerHandler pushes "ok" back for every message. 16 | type ServerDemo struct{} 17 | // Start serves protocol ("ws" or "tcp") on addr with the ServerHandler callbacks, and serves the default net/http mux (with the pprof handlers) on 0.0.0.0:6060. It panics on an unsupported protocol or when the server fails to start. 18 | func (s *ServerDemo) Start(id, protocol, addr string) { 19 | go func() { 20 | _ = http.ListenAndServe("0.0.0.0:6060", nil) // best effort: the error is ignored 21 | }() 22 | 23 | var srv kim.Server 24 | service := &naming.DefaultService{ 25 | Id: id, 26 | Protocol: protocol, 27 | } 28 | if protocol == "ws" { 29 | srv = websocket.NewServer(addr, service) 30 | } else if protocol == "tcp" { 31 | srv = tcp.NewServer(addr, service) 32 | } else { 33 | panic(errors.New("unsupported protocol: " + protocol)) // fail clearly instead of with a nil-pointer panic on srv below 34 | } 35 | handler := &ServerHandler{} 36 | srv.SetReadWait(time.Minute) 37 | srv.SetAcceptor(handler) 38 | srv.SetMessageListener(handler) 39 | srv.SetStateListener(handler) 40 | 41 | err := srv.Start() 42 | if err != nil { 43 | panic(err) 44
| } 45 | } 46 | 47 | // ServerHandler implements the acceptor, message listener and state listener that ServerDemo installs. 48 | type ServerHandler struct { 49 | } 50 | 51 | // Accept reads the first frame from conn and uses its payload as the user id; an empty id is rejected. timeout is not used. 52 | func (h *ServerHandler) Accept(conn kim.Conn, timeout time.Duration) (string, kim.Meta, error) { 53 | // 1. read the authentication frame sent by the client 54 | frame, err := conn.ReadFrame() 55 | if err != nil { 56 | return "", nil, err 57 | } 58 | // 2. parse: the frame payload is the userId 59 | userID := string(frame.GetPayload()) 60 | // 3. authenticate: a fake check for demo purposes that only requires a non-empty id 61 | if userID == "" { 62 | return "", nil, errors.New("user id is invalid") 63 | } 64 | logger.Infof("logined %s", userID) 65 | return userID, nil, nil 66 | } 67 | 68 | // Receive logs the payload and pushes "ok" back through ag; the push error is ignored. 69 | func (h *ServerHandler) Receive(ag kim.Agent, payload []byte) { 70 | logger.Infof("srv received %s", string(payload)) 71 | _ = ag.Push([]byte("ok")) 72 | } 73 | 74 | // Disconnect logs the disconnected id and always returns nil. 75 | func (h *ServerHandler) Disconnect(id string) error { 76 | logger.Infof("disconnect %s", id) 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /examples/unittest/chat_test.go: -------------------------------------------------------------------------------- 1 | package unittest 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | "github.com/klintcheng/kim" 9 | "github.com/klintcheng/kim/examples/dialer" 10 | "github.com/klintcheng/kim/wire" 11 | "github.com/klintcheng/kim/wire/pkt" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | // Test_Usertalk needs a server at wsurl: test1 sends a message to test2, then the test checks test1's response and test2's push. 15 | func Test_Usertalk(t *testing.T) { 16 | cli1, err := dialer.Login(wsurl, "test1") 17 | assert.Nil(t, err) 18 | if err != nil { 19 | return 20 | } 21 | cli2, err := dialer.Login(wsurl, "test2") 22 | assert.Nil(t, err) 23 | if err != nil { 24 | return 25 | } 26 | p := pkt.New(wire.CommandChatUserTalk, pkt.WithDest("test2")) 27 | p.WriteBody(&pkt.MessageReq{ 28 | Type: 1, 29 | Body: "hello world", 30 | }) 31 | err = cli1.Send(pkt.Marshal(p)) 32 | assert.Nil(t, err) 33 | 34 | // response to the sender 35 | frame, _ := cli1.Read() // NOTE(review): the read error is ignored; a failed read would leave frame nil 36 |
assert.Equal(t, kim.OpBinary, frame.GetOpCode()) 37 | packet, err := pkt.MustReadLogicPkt(bytes.NewBuffer(frame.GetPayload())) 38 | assert.Nil(t, err) 39 | assert.Equal(t, pkt.Status_Success, packet.Header.Status) 40 | var resp pkt.MessageResp 41 | _ = packet.ReadBody(&resp) 42 | assert.Greater(t, resp.MessageId, int64(1000)) 43 | assert.Greater(t, resp.SendTime, int64(1000)) 44 | t.Log(&resp) 45 | 46 | // push delivered to the receiver; it must carry the same id and send time as the response 47 | frame, err = cli2.Read() 48 | assert.Nil(t, err) 49 | packet, err = pkt.MustReadLogicPkt(bytes.NewBuffer(frame.GetPayload())) 50 | assert.Nil(t, err) 51 | var push pkt.MessagePush 52 | _ = packet.ReadBody(&push) 53 | assert.Equal(t, resp.MessageId, push.MessageId) 54 | assert.Equal(t, resp.SendTime, push.SendTime) 55 | assert.Equal(t, "hello world", push.Body) 56 | assert.Equal(t, int32(1), push.Type) 57 | t.Log(&push) 58 | } 59 | // Test_grouptalk needs a server at wsurl: test1 creates a group with test1..test4 and sends a group message, and the online members test2 and test3 must receive it. 60 | func Test_grouptalk(t *testing.T) { 61 | // 1. log in as test1 62 | cli1, err := dialer.Login(wsurl, "test1") 63 | assert.Nil(t, err) 64 | 65 | // 2. create the group 66 | p := pkt.New(wire.CommandGroupCreate) 67 | p.WriteBody(&pkt.GroupCreateReq{ 68 | Name: "group1", 69 | Owner: "test1", 70 | Members: []string{"test1", "test2", "test3", "test4"}, 71 | }) 72 | err = cli1.Send(pkt.Marshal(p)) 73 | assert.Nil(t, err) 74 | 75 | // 3. read the create-group response 76 | ack, err := cli1.Read() 77 | assert.Nil(t, err) 78 | ackp, _ := pkt.MustReadLogicPkt(bytes.NewBuffer(ack.GetPayload())) 79 | assert.Equal(t, pkt.Status_Success, ackp.GetStatus()) 80 | assert.Equal(t, wire.CommandGroupCreate, ackp.GetCommand()) 81 | // 4. decode the response body 82 | var createresp pkt.GroupCreateResp 83 | err = ackp.ReadBody(&createresp) 84 | assert.Nil(t, err) 85 | group := createresp.GetGroupId() 86 | assert.NotEmpty(t, group) 87 | if group == "" { 88 | return 89 | } 90 | // 5. log in group members test2 and test3 91 | cli2, err := dialer.Login(wsurl, "test2") 92 | assert.Nil(t, err) 93 | cli3, err := dialer.Login(wsurl, "test3") 94 | assert.Nil(t, err) 95 | t1 := time.Now() 96 | 97 | // 6.
发送群消息 CommandChatGroupTalk 98 | gtalk := pkt.New(wire.CommandChatGroupTalk, pkt.WithDest(group)).WriteBody(&pkt.MessageReq{ 99 | Type: 1, 100 | Body: "hellogroup", 101 | }) 102 | err = cli1.Send(pkt.Marshal(gtalk)) 103 | assert.Nil(t, err) 104 | // 7. 读取resp消息,确认消息发送成功 105 | ack, _ = cli1.Read() 106 | ackp, _ = pkt.MustReadLogicPkt(bytes.NewBuffer(ack.GetPayload())) 107 | assert.Equal(t, pkt.Status_Success, ackp.GetStatus()) 108 | 109 | // 7. test2 读取消息 110 | notify1, _ := cli2.Read() 111 | n1, _ := pkt.MustReadLogicPkt(bytes.NewBuffer(notify1.GetPayload())) 112 | assert.Equal(t, wire.CommandChatGroupTalk, n1.GetCommand()) 113 | var notify pkt.MessagePush 114 | _ = n1.ReadBody(¬ify) 115 | // 8. 校验消息内容 116 | assert.Equal(t, "hellogroup", notify.Body) 117 | assert.Equal(t, int32(1), notify.Type) 118 | assert.Empty(t, notify.Extra) 119 | assert.Greater(t, notify.SendTime, t1.UnixNano()) 120 | assert.Greater(t, notify.MessageId, int64(10000)) 121 | 122 | // 9. test3 读取消息 123 | notify2, _ := cli3.Read() 124 | n2, _ := pkt.MustReadLogicPkt(bytes.NewBuffer(notify2.GetPayload())) 125 | _ = n2.ReadBody(¬ify) 126 | assert.Equal(t, "hellogroup", notify.Body) 127 | 128 | t.Logf("cost %v", time.Since(t1)) 129 | } 130 | -------------------------------------------------------------------------------- /examples/unittest/login_test.go: -------------------------------------------------------------------------------- 1 | package unittest 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/klintcheng/kim/examples/dialer" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // const wsurl = "ws://119.3.4.216:8000" 12 | const wsurl = "ws://localhost:8000" 13 | 14 | func Test_login(t *testing.T) { 15 | cli, err := dialer.Login(wsurl, "test1") 16 | assert.Nil(t, err) 17 | time.Sleep(time.Second * 2) 18 | cli.Close() 19 | } 20 | -------------------------------------------------------------------------------- /examples/unittest/offline_test.go: 
-------------------------------------------------------------------------------- 1 | package unittest 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/examples/dialer" 11 | "github.com/klintcheng/kim/logger" 12 | "github.com/klintcheng/kim/wire" 13 | "github.com/klintcheng/kim/wire/pkt" 14 | "github.com/stretchr/testify/assert" 15 | "google.golang.org/protobuf/proto" 16 | ) 17 | 18 | func Test_offline(t *testing.T) { 19 | src := fmt.Sprintf("u%d", time.Now().Unix()) 20 | cli, err := dialer.Login(wsurl, src) 21 | assert.Nil(t, err) 22 | if err != nil { 23 | return 24 | } 25 | dest := fmt.Sprintf("u%d", time.Now().Unix()+1) 26 | count := 10 27 | for i := 0; i < count; i++ { 28 | p := pkt.New(wire.CommandChatUserTalk, pkt.WithDest(dest)) 29 | p.WriteBody(&pkt.MessageReq{ 30 | Type: 1, 31 | Body: "hello world", 32 | }) 33 | err := cli.Send(pkt.Marshal(p)) 34 | if err != nil { 35 | logger.Error(err) 36 | return 37 | } 38 | // wait ack 39 | _, _ = cli.Read() 40 | } 41 | 42 | destcli, err := dialer.Login(wsurl, dest) 43 | assert.Nil(t, err) 44 | 45 | // request offline message index 46 | p := pkt.New(wire.CommandOfflineIndex) 47 | p.WriteBody(&pkt.MessageIndexReq{}) 48 | _ = destcli.Send(pkt.Marshal(p)) 49 | 50 | var indexResp pkt.MessageIndexResp 51 | err = Read(destcli, &indexResp) 52 | assert.Nil(t, err) 53 | 54 | assert.Equal(t, count, len(indexResp.Indexes)) 55 | assert.Equal(t, src, indexResp.Indexes[0].AccountB) 56 | assert.Equal(t, int32(0), indexResp.Indexes[0].Direction) 57 | t.Log(indexResp.Indexes) 58 | 59 | var ids = make([]int64, count) 60 | for i, idx := range indexResp.Indexes { 61 | ids[i] = idx.MessageId 62 | } 63 | t.Log(ids) 64 | 65 | lastMessageId := ids[count-1] 66 | 67 | // read again 68 | p = pkt.New(wire.CommandOfflineIndex) 69 | p.WriteBody(&pkt.MessageIndexReq{ 70 | MessageId: lastMessageId, 71 | }) 72 | _ = destcli.Send(pkt.Marshal(p)) 73 | 74 | var indexResp2 
pkt.MessageIndexResp 75 | err = Read(destcli, &indexResp2) 76 | assert.Nil(t, err) 77 | assert.Equal(t, 0, len(indexResp2.Indexes)) 78 | 79 | // request offline message content 80 | p = pkt.New(wire.CommandOfflineContent) 81 | p.WriteBody(&pkt.MessageContentReq{ 82 | MessageIds: ids, 83 | }) 84 | _ = destcli.Send(pkt.Marshal(p)) 85 | var contentResp pkt.MessageContentResp 86 | err = Read(destcli, &contentResp) 87 | assert.Nil(t, err) 88 | t.Log(contentResp.Contents) 89 | assert.Equal(t, count, len(contentResp.Contents)) 90 | assert.Equal(t, "hello world", contentResp.Contents[0].Body) 91 | assert.Equal(t, int32(1), contentResp.Contents[0].Type) 92 | } 93 | 94 | func Read(cli kim.Client, body proto.Message) error { 95 | frame, err := cli.Read() 96 | if err != nil { 97 | return err 98 | } 99 | packet, _ := pkt.MustReadLogicPkt(bytes.NewBuffer(frame.GetPayload())) 100 | if packet.GetStatus() != pkt.Status_Success { 101 | return fmt.Errorf("received status :%v", packet.GetStatus()) 102 | } 103 | return packet.ReadBody(body) 104 | } 105 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/klintcheng/kim 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/Joker/hpp v1.0.0 // indirect 7 | github.com/bwmarrin/snowflake v0.3.0 8 | github.com/dgrijalva/jwt-go v3.2.0+incompatible 9 | github.com/go-redis/redis/v7 v7.4.0 10 | github.com/go-resty/resty/v2 v2.6.0 11 | github.com/gobwas/pool v0.2.1 12 | github.com/gobwas/ws v1.0.4 13 | github.com/golang/mock v1.6.0 14 | github.com/golang/protobuf v1.4.3 15 | github.com/hashicorp/consul/api v1.8.1 16 | github.com/kataras/iris/v12 v12.2.0-alpha2.0.20210705170737-afb15b860124 17 | github.com/kelseyhightower/envconfig v1.4.0 18 | github.com/kr/text v0.2.0 // indirect 19 | github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible 20 | github.com/lestrrat-go/strftime v1.0.4 // indirect 21 | 
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible 22 | github.com/mattn/go-colorable v0.1.8 // indirect 23 | github.com/mattn/go-isatty v0.0.13 // indirect 24 | github.com/panjf2000/ants/v2 v2.4.6 25 | github.com/prometheus/client_golang v1.11.0 26 | github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 27 | github.com/segmentio/ksuid v1.0.3 28 | github.com/sirupsen/logrus v1.7.0 29 | github.com/spf13/cobra v0.0.5 30 | github.com/spf13/viper v1.7.1 31 | github.com/stretchr/testify v1.7.0 32 | golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e // indirect 33 | golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d // indirect 34 | golang.org/x/sys v0.0.0-20211020174200-9d6173849985 // indirect 35 | google.golang.org/grpc v1.33.2 36 | google.golang.org/protobuf v1.26.0-rc.1 37 | gorm.io/driver/mysql v1.1.1 38 | gorm.io/gorm v1.21.15 39 | ) 40 | -------------------------------------------------------------------------------- /location.go: -------------------------------------------------------------------------------- 1 | package kim 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | 7 | "github.com/klintcheng/kim/wire/endian" 8 | ) 9 | // Location pairs a ChannelId with a GateId (presumably the gateway that holds the channel — confirm). 10 | type Location struct { 11 | ChannelId string 12 | GateId string 13 | } 14 | // Bytes encodes ChannelId then GateId with endian.WriteShortBytes, ignoring write errors; a nil receiver yields an empty slice. 15 | func (loc *Location) Bytes() []byte { 16 | if loc == nil { 17 | return []byte{} 18 | } 19 | buf := new(bytes.Buffer) 20 | _ = endian.WriteShortBytes(buf, []byte(loc.ChannelId)) 21 | _ = endian.WriteShortBytes(buf, []byte(loc.GateId)) 22 | return buf.Bytes() 23 | } 24 | // Unmarshal reads ChannelId then GateId with endian.ReadShortString (presumably the counterpart of the WriteShortBytes calls in Bytes); empty data is rejected and the first read error is returned. 25 | func (loc *Location) Unmarshal(data []byte) (err error) { 26 | if len(data) == 0 { 27 | return errors.New("data is empty") 28 | } 29 | buf := bytes.NewBuffer(data) 30 | loc.ChannelId, err = endian.ReadShortString(buf) 31 | if err != nil { 32 | return 33 | } 34 | loc.GateId, err = endian.ReadShortString(buf) 35 | if err != nil { 36 | return 37 | } 38 | return 39 | } 40 | -------------------------------------------------------------------------------- 
/logger/setting.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "time" 5 | 6 | rotatelogs "github.com/lestrrat-go/file-rotatelogs" 7 | "github.com/rifflock/lfshook" 8 | "github.com/sirupsen/logrus" 9 | ) 10 | // Settings configures Init. Level defaults to "debug". A non-empty Filename adds a hook that writes to a file rotated every 24 hours and keeps RollingDays files (default 7); Format "json" selects JSON for that file, anything else uncolored text. 11 | type Settings struct { 12 | Filename string 13 | Level string 14 | RollingDays uint 15 | Format string 16 | } 17 | // Init applies settings to the package logger std. An unparsable level is logged and otherwise ignored; the only returned error comes from creating the rotating file writer. 18 | func Init(settings Settings) error { 19 | 20 | if settings.Level == "" { 21 | settings.Level = "debug" 22 | } 23 | ll, err := logrus.ParseLevel(settings.Level) 24 | if err == nil { 25 | std.SetLevel(ll) 26 | } else { 27 | std.Error("Invalid log level") 28 | } 29 | 30 | if settings.Filename == "" { 31 | return nil 32 | } 33 | 34 | if settings.RollingDays == 0 { 35 | settings.RollingDays = 7 36 | } 37 | 38 | writer, err := rotatelogs.New( 39 | settings.Filename+".%Y%m%d", 40 | // WithLinkName keeps a symlink at Filename pointing to the newest log file, so the current file is easy to find 41 | rotatelogs.WithLinkName(settings.Filename), 42 | 43 | // WithRotationTime sets the rotation interval 44 | rotatelogs.WithRotationTime(time.Hour*24), 45 | 46 | // Only one of WithMaxAge and WithRotationCount may be set: 47 | // WithMaxAge sets the maximum age of a file before it is cleaned up, 48 | // WithRotationCount sets the maximum number of files kept before cleanup. 49 | //rotatelogs.WithMaxAge(time.Hour*24), 50 | rotatelogs.WithRotationCount(settings.RollingDays), 51 | ) 52 | if err != nil { 53 | return err 54 | } 55 | 56 | var logfr logrus.Formatter 57 | if settings.Format == "json" { 58 | logfr = &logrus.JSONFormatter{ 59 | DisableTimestamp: false, 60 | } 61 | } else { 62 | logfr = &logrus.TextFormatter{ 63 | DisableColors: true, 64 | } 65 | } 66 | 67 | lfsHook := lfshook.NewHook(lfshook.WriterMap{ 68 | logrus.DebugLevel: writer, 69 | logrus.InfoLevel: writer, 70 | logrus.WarnLevel: writer, 71 | logrus.ErrorLevel: writer, 72 | // logrus.FatalLevel: writer, 73 | // logrus.PanicLevel: writer, 74 | }, logfr) 75 | // NOTE(review): Fatal and Panic entries are not routed to the file (commented out above) — confirm this is intended 76 | std.AddHook(lfsHook) 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- 
/metrics.go: -------------------------------------------------------------------------------- 1 | package kim 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "github.com/prometheus/client_golang/prometheus/promauto" 6 | ) 7 | 8 | var channelTotalGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{ 9 | Namespace: "kim", 10 | Name: "channel_total", 11 | Help: "网关并发数", 12 | }, []string{"serviceId", "serviceName"}) 13 | -------------------------------------------------------------------------------- /middleware/recover.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "strings" 7 | 8 | "github.com/klintcheng/kim" 9 | "github.com/klintcheng/kim/logger" 10 | "github.com/klintcheng/kim/wire/pkt" 11 | ) 12 | 13 | func Recover() kim.HandlerFunc { 14 | return func(ctx kim.Context) { 15 | defer func() { 16 | if err := recover(); err != nil { 17 | var callers []string 18 | for i := 1; ; i++ { 19 | _, file, line, got := runtime.Caller(i) 20 | if !got { 21 | break 22 | } 23 | callers = append(callers, fmt.Sprintf("%s:%d", file, line)) 24 | } 25 | logger.WithFields(logger.Fields{ 26 | "ChannelId": ctx.Header().ChannelId, 27 | "Command": ctx.Header().Command, 28 | "Seq": ctx.Header().Sequence, 29 | }).Error(err, strings.Join(callers, "\n")) 30 | 31 | _ = ctx.Resp(pkt.Status_SystemException, &pkt.ErrorResp{Message: "SystemException"}) 32 | } 33 | }() 34 | 35 | ctx.Next() 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /mock.sh: -------------------------------------------------------------------------------- 1 | export GOPATH=/Users/klint/go 2 | export PATH=$PATH:$(go env GOPATH)/bin 3 | 4 | go get -u github.com/golang/mock/gomock 5 | go get -u github.com/golang/mock/mockgen 6 | 7 | mockgen --source server.go -package kim -destination server_mock.go 8 | mockgen --source storage.go -package kim 
-destination storage_mock.go 9 | mockgen --source dispatcher.go -package kim -destination dispatcher_mock.go 10 | -------------------------------------------------------------------------------- /naming/consul/naming_test.go: -------------------------------------------------------------------------------- 1 | package consul 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/klintcheng/kim" 9 | "github.com/klintcheng/kim/naming" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_Naming(t *testing.T) { 14 | ns, err := NewNaming("localhost:8500") 15 | assert.Nil(t, err) 16 | 17 | // 准备工作 18 | _ = ns.Deregister("test_1") 19 | _ = ns.Deregister("test_2") 20 | 21 | serviceName := "for_test" 22 | // 1. 注册 test_1 23 | err = ns.Register(&naming.DefaultService{ 24 | Id: "test_1", 25 | Name: serviceName, 26 | Namespace: "", 27 | Address: "localhost", 28 | Port: 8000, 29 | Protocol: "ws", 30 | Tags: []string{"tab1", "gate"}, 31 | }) 32 | assert.Nil(t, err) 33 | 34 | // 2. 服务发现 35 | servs, err := ns.Find(serviceName) 36 | assert.Nil(t, err) 37 | assert.Equal(t, 1, len(servs)) 38 | t.Log(servs) 39 | 40 | wg := sync.WaitGroup{} 41 | wg.Add(1) 42 | 43 | // 3. 监听服务实时变化(新增) 44 | _ = ns.Subscribe(serviceName, func(services []kim.ServiceRegistration) { 45 | t.Log(len(services)) 46 | 47 | assert.Equal(t, 2, len(services)) 48 | assert.Equal(t, "test_2", services[1].ServiceID()) 49 | wg.Done() 50 | }) 51 | time.Sleep(time.Second) 52 | 53 | // 4. 注册 test_2 用于验证第3步 54 | err = ns.Register(&naming.DefaultService{ 55 | Id: "test_2", 56 | Name: serviceName, 57 | Namespace: "", 58 | Address: "localhost", 59 | Port: 8001, 60 | Protocol: "ws", 61 | Tags: []string{"tab2", "gate"}, 62 | }) 63 | assert.Nil(t, err) 64 | 65 | // 等 Watch 回调中的方法执行完成 66 | wg.Wait() 67 | 68 | _ = ns.Unsubscribe(serviceName) 69 | 70 | // 5. 服务发现 71 | servs, _ = ns.Find(serviceName, "gate") 72 | assert.Equal(t, 2, len(servs)) // <-- 必须有两个 73 | 74 | // 6. 
服务发现, 验证tag查询 75 | servs, _ = ns.Find(serviceName, "tab2") 76 | assert.Equal(t, 1, len(servs)) // <-- 必须有1个 77 | assert.Equal(t, "test_2", servs[0].ServiceID()) 78 | 79 | // 7. 注销test_2 80 | err = ns.Deregister("test_2") 81 | assert.Nil(t, err) 82 | 83 | // 8. 服务发现 84 | servs, err = ns.Find(serviceName) 85 | assert.Nil(t, err) 86 | assert.Equal(t, 1, len(servs)) 87 | assert.Equal(t, "test_1", servs[0].ServiceID()) 88 | 89 | // 9. 注销test_1 90 | err = ns.Deregister("test_1") 91 | assert.Nil(t, err) 92 | 93 | } 94 | -------------------------------------------------------------------------------- /naming/naming.go: -------------------------------------------------------------------------------- 1 | package naming 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/klintcheng/kim" 7 | ) 8 | 9 | // errors 10 | var ( 11 | ErrNotFound = errors.New("service no found") 12 | ) 13 | 14 | // Naming defined methods of the naming service 15 | type Naming interface { 16 | Find(serviceName string, tags ...string) ([]kim.ServiceRegistration, error) 17 | Subscribe(serviceName string, callback func(services []kim.ServiceRegistration)) error 18 | Unsubscribe(serviceName string) error 19 | Register(service kim.ServiceRegistration) error 20 | Deregister(serviceID string) error 21 | } 22 | -------------------------------------------------------------------------------- /naming/service.go: -------------------------------------------------------------------------------- 1 | package naming 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/klintcheng/kim" 7 | ) 8 | 9 | // "ID": "qa-dfirst-zfirst-tgateway-172.16.235.145-0-8000", 10 | // "Service": "tgateway", 11 | // "Tags": [ 12 | // "ZONE:qa-dfirst-zfirst", 13 | // "TMC_REGION:SH", 14 | // "TMC_DOMAIN:g002-qa.tutormeetplus.com" 15 | // ], 16 | // "Address": "172.16.235.145", 17 | // "Port": 8000, 18 | 19 | //Service define a Service 20 | 21 | // DefaultService Service Impl 22 | type DefaultService struct { 23 | Id string 24 | Name string 25 | 
Address string 26 | Port int 27 | Protocol string 28 | Namespace string 29 | Tags []string 30 | Meta map[string]string 31 | } 32 | 33 | // NewEntry NewEntry 34 | func NewEntry(id, name, protocol string, address string, port int) kim.ServiceRegistration { 35 | return &DefaultService{ 36 | Id: id, 37 | Name: name, 38 | Address: address, 39 | Port: port, 40 | Protocol: protocol, 41 | } 42 | } 43 | 44 | // ID returns the ServiceImpl ID 45 | func (e *DefaultService) ServiceID() string { 46 | return e.Id 47 | } 48 | 49 | // Name Name 50 | func (e *DefaultService) ServiceName() string { return e.Name } 51 | 52 | // Namespace Namespace 53 | func (e *DefaultService) GetNamespace() string { return e.Namespace } 54 | 55 | // Address Address 56 | func (e *DefaultService) PublicAddress() string { 57 | return e.Address 58 | } 59 | 60 | func (e *DefaultService) PublicPort() int { return e.Port } 61 | 62 | // Protocol Protocol 63 | func (e *DefaultService) GetProtocol() string { return e.Protocol } 64 | 65 | func (e *DefaultService) DialURL() string { 66 | if e.Protocol == "tcp" { 67 | return fmt.Sprintf("%s:%d", e.Address, e.Port) 68 | } 69 | return fmt.Sprintf("%s://%s:%d", e.Protocol, e.Address, e.Port) 70 | } 71 | 72 | // Tags Tags 73 | func (e *DefaultService) GetTags() []string { return e.Tags } 74 | 75 | // Meta Meta 76 | func (e *DefaultService) GetMeta() map[string]string { return e.Meta } 77 | 78 | func (e *DefaultService) String() string { 79 | return fmt.Sprintf("Id:%s,Name:%s,Address:%s,Port:%d,Ns:%s,Tags:%v,Meta:%v", e.Id, e.Name, e.Address, e.Port, e.Namespace, e.Tags, e.Meta) 80 | } 81 | -------------------------------------------------------------------------------- /net.go: -------------------------------------------------------------------------------- 1 | package kim 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "net/http" 7 | "strings" 8 | ) 9 | 10 | // GetLocalIP GetLocalIP 11 | func GetLocalIP() string { 12 | addrs, err := net.InterfaceAddrs() 13 | if err != 
nil { 14 | return "" 15 | } 16 | for _, address := range addrs { 17 | // check the address type and if it is not a loopback the display it 18 | if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { 19 | if ipnet.IP.To4() != nil { 20 | return ipnet.IP.String() 21 | } 22 | } 23 | } 24 | return "" 25 | } 26 | 27 | var cidrs []*net.IPNet 28 | 29 | func init() { 30 | maxCidrBlocks := []string{ 31 | "127.0.0.1/8", // localhost 32 | "10.0.0.0/8", // 24-bit block 33 | "172.16.0.0/12", // 20-bit block 34 | "192.168.0.0/16", // 16-bit block 35 | "169.254.0.0/16", // link local address 36 | "::1/128", // localhost IPv6 37 | "fc00::/7", // unique local address IPv6 38 | "fe80::/10", // link local address IPv6 39 | } 40 | 41 | cidrs = make([]*net.IPNet, len(maxCidrBlocks)) 42 | for i, maxCidrBlock := range maxCidrBlocks { 43 | _, cidr, _ := net.ParseCIDR(maxCidrBlock) 44 | cidrs[i] = cidr 45 | } 46 | } 47 | 48 | func isPrivateAddress(address string) (bool, error) { 49 | ipAddress := net.ParseIP(address) 50 | if ipAddress == nil { 51 | return false, errors.New("address is not valid") 52 | } 53 | 54 | for i := range cidrs { 55 | if cidrs[i].Contains(ipAddress) { 56 | return true, nil 57 | } 58 | } 59 | 60 | return false, nil 61 | } 62 | 63 | // FromRequest return client's real public IP address from http request headers. 
64 | func FromRequest(r *http.Request) string { 65 | // Fetch header value 66 | xRealIP := r.Header.Get("X-Real-Ip") 67 | xForwardedFor := r.Header.Get("X-Forwarded-For") 68 | 69 | // If both empty, return IP from remote address 70 | if xRealIP == "" && xForwardedFor == "" { 71 | var remoteIP string 72 | 73 | // If there are colon in remote address, remove the port number 74 | // otherwise, return remote address as is 75 | if strings.ContainsRune(r.RemoteAddr, ':') { 76 | remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr) 77 | } else { 78 | remoteIP = r.RemoteAddr 79 | } 80 | 81 | return remoteIP 82 | } 83 | 84 | // Check list of IP in X-Forwarded-For and return the first global address 85 | for _, address := range strings.Split(xForwardedFor, ",") { 86 | address = strings.TrimSpace(address) 87 | isPrivate, err := isPrivateAddress(address) 88 | if !isPrivate && err == nil { 89 | return address 90 | } 91 | } 92 | 93 | // If nothing succeed, return X-Real-IP 94 | return xRealIP 95 | } 96 | 97 | // RealIP is depreciated, use FromRequest instead 98 | func RealIP(r *http.Request) string { 99 | return FromRequest(r) 100 | } 101 | -------------------------------------------------------------------------------- /report/report_test.go: -------------------------------------------------------------------------------- 1 | package report 2 | 3 | import ( 4 | "math/rand" 5 | "os" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestReport(t *testing.T) { 11 | r := New(os.Stdout, 100) 12 | t1 := time.Now() 13 | defer func() { 14 | r.Finalize(time.Since(t1)) 15 | }() 16 | 17 | for i := 0; i < 500; i++ { 18 | r.Add(&Result{ 19 | StatusCode: 200, 20 | Duration: time.Millisecond * time.Duration(1+rand.Intn(20)*100), 21 | }) 22 | } 23 | for i := 0; i < 10; i++ { 24 | r.Add(&Result{ 25 | StatusCode: 100, 26 | }) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /report/template.go: 
-------------------------------------------------------------------------------- 1 | package report 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "text/template" 7 | ) 8 | 9 | var ( 10 | defaultTmpl = ` 11 | Summary: 12 | Total: {{ formatNumber .Total.Seconds }} secs 13 | Slowest: {{ formatNumber .Slowest }} secs 14 | Fastest: {{ formatNumber .Fastest }} secs 15 | Average: {{ formatNumber .Average }} secs 16 | Requests/sec: {{ formatNumber .Rps }} 17 | {{ if gt .SizeTotal 0 }} 18 | Total data: {{ .SizeTotal }} bytes{{ end }} 19 | Response time histogram: 20 | {{ histogram .Histogram }} 21 | Latency distribution:{{ range .LatencyDistribution }} 22 | {{ .Percentage }}%% in {{ formatNumber .Latency }} secs{{ end }} 23 | 24 | Status code distribution:{{ range $code, $num := .StatusCodeDist }} 25 | [{{ $code }}] {{ $num }} responses{{ end }} 26 | {{ if gt (len .ErrorDist) 0 }}Error distribution:{{ range $err, $num := .ErrorDist }} 27 | [{{ $num }}] {{ $err }}{{ end }}{{ end }} 28 | ` 29 | ) 30 | 31 | const ( 32 | barChar = "■" 33 | ) 34 | 35 | func newTemplate() *template.Template { 36 | return template.Must(template.New("tmpl").Funcs(tmplFuncMap).Parse(defaultTmpl)) 37 | } 38 | 39 | var tmplFuncMap = template.FuncMap{ 40 | "formatNumber": formatNumber, 41 | "formatNumberInt": formatNumberInt, 42 | "histogram": histogram, 43 | "jsonify": jsonify, 44 | } 45 | 46 | func jsonify(v interface{}) string { 47 | d, _ := json.Marshal(v) 48 | return string(d) 49 | } 50 | 51 | func formatNumber(duration float64) string { 52 | return fmt.Sprintf("%4.4f", duration) 53 | } 54 | 55 | func formatNumberInt(duration int) string { 56 | return fmt.Sprintf("%d", duration) 57 | } 58 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package kim 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/klintcheng/kim/wire/pkt" 9 | ) 10 | 11 | var 
ErrSessionLost = errors.New("err:session lost") 12 | 13 | // Router defines 14 | type Router struct { 15 | middlewares []HandlerFunc 16 | handlers *FuncTree 17 | pool sync.Pool 18 | } 19 | 20 | // NewRouter NewRouter 21 | func NewRouter() *Router { 22 | r := &Router{ 23 | handlers: NewTree(), 24 | middlewares: make([]HandlerFunc, 0), 25 | } 26 | r.pool.New = func() interface{} { 27 | return BuildContext() 28 | } 29 | return r 30 | } 31 | 32 | func (r *Router) Use(handlers ...HandlerFunc) { 33 | r.middlewares = append(r.middlewares, handlers...) 34 | } 35 | 36 | // Handle register a command handler 37 | func (r *Router) Handle(command string, handlers ...HandlerFunc) { 38 | r.handlers.Add(command, r.middlewares...) 39 | r.handlers.Add(command, handlers...) 40 | } 41 | 42 | // Serve a packet from client 43 | func (r *Router) Serve(packet *pkt.LogicPkt, dispatcher Dispatcher, cache SessionStorage, session Session) error { 44 | if dispatcher == nil { 45 | return fmt.Errorf("dispatcher is nil") 46 | } 47 | if cache == nil { 48 | return fmt.Errorf("cache is nil") 49 | } 50 | ctx := r.pool.Get().(*ContextImpl) 51 | ctx.reset() 52 | ctx.request = packet 53 | ctx.Dispatcher = dispatcher 54 | ctx.SessionStorage = cache 55 | ctx.session = session 56 | 57 | r.serveContext(ctx) 58 | // Put Context to Pool 59 | r.pool.Put(ctx) 60 | return nil 61 | } 62 | 63 | func (r *Router) serveContext(ctx *ContextImpl) { 64 | chain, ok := r.handlers.Get(ctx.Header().Command) 65 | if !ok { 66 | ctx.handlers = []HandlerFunc{handleNoFound} 67 | ctx.Next() 68 | return 69 | } 70 | ctx.handlers = chain 71 | ctx.Next() 72 | } 73 | 74 | func handleNoFound(ctx Context) { 75 | _ = ctx.Resp(pkt.Status_NotImplemented, &pkt.ErrorResp{Message: "NotImplemented"}) 76 | } 77 | 78 | // FuncTree is a tree structure 79 | type FuncTree struct { 80 | nodes map[string]HandlersChain 81 | } 82 | 83 | // NewTree NewTree 84 | func NewTree() *FuncTree { 85 | return &FuncTree{nodes: make(map[string]HandlersChain, 10)} 
86 | } 87 | 88 | // Add a handler to tree 89 | func (t *FuncTree) Add(path string, handlers ...HandlerFunc) { 90 | if t.nodes[path] == nil { 91 | t.nodes[path] = HandlersChain{} 92 | } 93 | 94 | t.nodes[path] = append(t.nodes[path], handlers...) 95 | } 96 | 97 | // Get a handler from tree 98 | func (t *FuncTree) Get(path string) (HandlersChain, bool) { 99 | f, ok := t.nodes[path] 100 | return f, ok 101 | } 102 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package kim 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "time" 7 | ) 8 | 9 | const ( 10 | DefaultReadWait = time.Minute * 3 11 | DefaultWriteWait = time.Second * 10 12 | DefaultLoginWait = time.Second * 10 13 | DefaultHeartbeat = time.Second * 55 14 | ) 15 | 16 | const ( 17 | // 定义读取消息的默认goroutine池大小 18 | DefaultMessageReadPool = 5000 19 | DefaultConnectionPool = 5000 20 | ) 21 | 22 | // 定义了基础服务的抽象接口 23 | type Service interface { 24 | ServiceID() string 25 | ServiceName() string 26 | GetMeta() map[string]string 27 | } 28 | 29 | // 定义服务注册的抽象接口 30 | type ServiceRegistration interface { 31 | Service 32 | PublicAddress() string 33 | PublicPort() int 34 | DialURL() string 35 | GetTags() []string 36 | GetProtocol() string 37 | GetNamespace() string 38 | String() string 39 | } 40 | 41 | // Server 定义了一个tcp/websocket不同协议通用的服务端的接口 42 | type Server interface { 43 | ServiceRegistration 44 | // SetAcceptor 设置Acceptor 45 | SetAcceptor(Acceptor) 46 | //SetMessageListener 设置上行消息监听器 47 | SetMessageListener(MessageListener) 48 | //SetStateListener 设置连接状态监听服务 49 | SetStateListener(StateListener) 50 | // SetReadWait 设置读超时 51 | SetReadWait(time.Duration) 52 | // SetChannelMap 设置Channel管理服务 53 | SetChannelMap(ChannelMap) 54 | 55 | // Start 用于在内部实现网络端口的监听和接收连接, 56 | // 并完成一个Channel的初始化过程。 57 | Start() error 58 | // Push 消息到指定的Channel中 59 | // string channelID 60 | // []byte 序列化之后的消息数据 61 | Push(string, 
[]byte) error 62 | // Shutdown 服务下线,关闭连接 63 | Shutdown(context.Context) error 64 | } 65 | 66 | // Acceptor 连接接收器 67 | type Acceptor interface { 68 | // Accept 返回一个握手完成的Channel对象或者一个error。 69 | // 业务层需要处理不同协议和网络环境下的连接握手协议 70 | Accept(Conn, time.Duration) (string, Meta, error) 71 | } 72 | 73 | // MessageListener 监听消息 74 | type MessageListener interface { 75 | // 收到消息回调 76 | Receive(Agent, []byte) 77 | } 78 | 79 | // StateListener 状态监听器 80 | type StateListener interface { 81 | // 连接断开回调 82 | Disconnect(string) error 83 | } 84 | 85 | type Meta map[string]string 86 | 87 | // Agent is interface of client side 88 | type Agent interface { 89 | ID() string 90 | Push([]byte) error 91 | GetMeta() Meta 92 | } 93 | 94 | // Conn Connection 95 | type Conn interface { 96 | net.Conn 97 | ReadFrame() (Frame, error) 98 | WriteFrame(OpCode, []byte) error 99 | Flush() error 100 | } 101 | 102 | // Channel is interface of client side 103 | type Channel interface { 104 | Conn 105 | Agent 106 | // Close 关闭连接 107 | Close() error 108 | Readloop(lst MessageListener) error 109 | // SetWriteWait 设置写超时 110 | SetWriteWait(time.Duration) 111 | SetReadWait(time.Duration) 112 | } 113 | 114 | // Client is interface of client side 115 | type Client interface { 116 | Service 117 | // connect to server 118 | Connect(string) error 119 | // SetDialer 设置拨号处理器 120 | SetDialer(Dialer) 121 | Send([]byte) error 122 | Read() (Frame, error) 123 | // Close 关闭 124 | Close() 125 | } 126 | 127 | // Dialer Dialer 128 | type Dialer interface { 129 | DialAndHandshake(DialerContext) (net.Conn, error) 130 | } 131 | 132 | type DialerContext struct { 133 | Id string 134 | Name string 135 | Address string 136 | Timeout time.Duration 137 | } 138 | 139 | // OpCode OpCode 140 | type OpCode byte 141 | 142 | // Opcode type 143 | const ( 144 | OpContinuation OpCode = 0x0 145 | OpText OpCode = 0x1 146 | OpBinary OpCode = 0x2 147 | OpClose OpCode = 0x8 148 | OpPing OpCode = 0x9 149 | OpPong OpCode = 0xa 150 | ) 151 | 152 | // Frame 
Frame 153 | type Frame interface { 154 | SetOpCode(OpCode) 155 | GetOpCode() OpCode 156 | SetPayload([]byte) 157 | GetPayload() []byte 158 | } 159 | -------------------------------------------------------------------------------- /services/gateway/conf.yaml: -------------------------------------------------------------------------------- 1 | ServiceID: gate01 2 | ServiceName: wgateway 3 | Listen: ":8000" 4 | MonitorPort: 8001 5 | PublicPort: 8000 6 | Tags: 7 | - IDC:SH_ALI 8 | Domain: ws://kingimcloud.com 9 | ConsulURL: localhost:8500 10 | AppSecret: "" 11 | MessageGPool: 5000 12 | ConnectionGPool: 15000 -------------------------------------------------------------------------------- /services/gateway/conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/kelseyhightower/envconfig" 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/logger" 11 | "github.com/spf13/viper" 12 | ) 13 | 14 | // Config Config 15 | type Config struct { 16 | ServiceID string 17 | ServiceName string `default:"wgateway"` 18 | Listen string `default:":8000"` 19 | PublicAddress string 20 | PublicPort int `default:"8000"` 21 | Tags []string 22 | Domain string 23 | ConsulURL string 24 | MonitorPort int `default:"8001"` 25 | AppSecret string 26 | LogLevel string `default:"DEBUG"` 27 | MessageGPool int `default:"10000"` 28 | ConnectionGPool int `default:"15000"` 29 | } 30 | 31 | func (c Config) String() string { 32 | bts, _ := json.Marshal(c) 33 | return string(bts) 34 | } 35 | 36 | // Init InitConfig 37 | func Init(file string) (*Config, error) { 38 | viper.SetConfigFile(file) 39 | viper.AddConfigPath(".") 40 | viper.AddConfigPath("/etc/conf") 41 | 42 | var config Config 43 | 44 | err := envconfig.Process("kim", &config) 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | if err := viper.ReadInConfig(); err != nil { 50 | logger.Warn(err) 51 
| } else { 52 | if err := viper.Unmarshal(&config); err != nil { 53 | return nil, err 54 | } 55 | } 56 | 57 | if config.ServiceID == "" { 58 | localIP := kim.GetLocalIP() 59 | config.ServiceID = fmt.Sprintf("gate_%s", strings.ReplaceAll(localIP, ".", "")) 60 | } 61 | if config.PublicAddress == "" { 62 | config.PublicAddress = kim.GetLocalIP() 63 | } 64 | logger.Info(config) 65 | return &config, nil 66 | } 67 | -------------------------------------------------------------------------------- /services/gateway/conf/route.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | 7 | "github.com/sirupsen/logrus" 8 | ) 9 | 10 | type Zone struct { 11 | ID string 12 | Weight int 13 | } 14 | 15 | type Route struct { 16 | RouteBy string 17 | Zones []Zone 18 | Whitelist map[string]string 19 | Slots []int 20 | } 21 | 22 | func ReadRoute(path string) (*Route, error) { 23 | var conf struct { 24 | RouteBy string `json:"route_by,omitempty"` 25 | Zones []Zone `json:"zones,omitempty"` 26 | Whitelist []struct { 27 | Key string `json:"key,omitempty"` 28 | Value string `json:"value,omitempty"` 29 | } `json:"whitelist,omitempty"` 30 | } 31 | 32 | bts, err := ioutil.ReadFile(path) 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | err = json.Unmarshal(bts, &conf) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | var rt = Route{ 43 | RouteBy: conf.RouteBy, 44 | Zones: conf.Zones, 45 | Whitelist: make(map[string]string, len(conf.Whitelist)), 46 | Slots: make([]int, 0), 47 | } 48 | // build slots 49 | for i, zone := range conf.Zones { 50 | // 1.通过权重生成分片中的slots 51 | shard := make([]int, zone.Weight) 52 | // 2. 给当前slots设置值,指向索引i 53 | for j := 0; j < zone.Weight; j++ { 54 | shard[j] = i 55 | } 56 | // 2. 追加到Slots中 57 | rt.Slots = append(rt.Slots, shard...) 
58 | } 59 | for _, wl := range conf.Whitelist { 60 | rt.Whitelist[wl.Key] = wl.Value 61 | } 62 | logrus.Infoln(rt) 63 | return &rt, nil 64 | } 65 | -------------------------------------------------------------------------------- /services/gateway/conf2.yaml: -------------------------------------------------------------------------------- 1 | ServiceID: gate02 2 | ServiceName: wgateway 3 | Listen: ":8010" 4 | MonitorPort: 8011 5 | PublicPort: 8010 6 | Tags: 7 | - IDC:HZ_ALI 8 | Domain: ws://kingimcloud.com 9 | ConsulURL: localhost:8500 10 | AppSecret: "" 11 | MessageGPool: 5000 12 | ConnectionGPool: 15000 -------------------------------------------------------------------------------- /services/gateway/route.json: -------------------------------------------------------------------------------- 1 | { 2 | "routeBy": "app", 3 | "zones": [ 4 | { 5 | "id": "zone_ali_01", 6 | "weight": 80 7 | }, 8 | { 9 | "id": "zone_ali_02", 10 | "weight": 10 11 | }, 12 | { 13 | "id": "zone_ali_03", 14 | "weight": 10 15 | } 16 | ], 17 | "whitelist": [ 18 | { 19 | "key": "kim", 20 | "value": "zone_ali_03" 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /services/gateway/serv/dialer.go: -------------------------------------------------------------------------------- 1 | package serv 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/klintcheng/kim" 7 | "github.com/klintcheng/kim/logger" 8 | "github.com/klintcheng/kim/tcp" 9 | "github.com/klintcheng/kim/wire/pkt" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | type TcpDialer struct { 14 | ServiceId string 15 | } 16 | 17 | func NewDialer(serviceId string) kim.Dialer { 18 | return &TcpDialer{ 19 | ServiceId: serviceId, 20 | } 21 | } 22 | 23 | // DialAndHandshake(context.Context, string) (net.Conn, error) 24 | func (d *TcpDialer) DialAndHandshake(ctx kim.DialerContext) (net.Conn, error) { 25 | // 1. 
拨号建立连接 26 | conn, err := net.DialTimeout("tcp", ctx.Address, ctx.Timeout) 27 | if err != nil { 28 | return nil, err 29 | } 30 | req := &pkt.InnerHandshakeReq{ 31 | ServiceId: d.ServiceId, 32 | } 33 | logger.Infof("send req %v", req) 34 | // 2. 把自己的ServiceId发送给对方 35 | bts, _ := proto.Marshal(req) 36 | err = tcp.WriteFrame(conn, kim.OpBinary, bts) 37 | if err != nil { 38 | return nil, err 39 | } 40 | return conn, nil 41 | } 42 | -------------------------------------------------------------------------------- /services/gateway/serv/handler.go: -------------------------------------------------------------------------------- 1 | package serv 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "regexp" 7 | "time" 8 | 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/container" 11 | "github.com/klintcheng/kim/logger" 12 | "github.com/klintcheng/kim/wire" 13 | "github.com/klintcheng/kim/wire/pkt" 14 | "github.com/klintcheng/kim/wire/token" 15 | ) 16 | 17 | const ( 18 | MetaKeyApp = "app" 19 | MetaKeyAccount = "account" 20 | ) 21 | 22 | var log = logger.WithFields(logger.Fields{ 23 | "service": "gateway", 24 | "pkg": "serv", 25 | }) 26 | 27 | // Handler Handler 28 | type Handler struct { 29 | ServiceID string 30 | AppSecret string 31 | } 32 | 33 | // Accept this connection 34 | func (h *Handler) Accept(conn kim.Conn, timeout time.Duration) (string, kim.Meta, error) { 35 | // 1. 读取登录包 36 | _ = conn.SetReadDeadline(time.Now().Add(timeout)) 37 | frame, err := conn.ReadFrame() 38 | if err != nil { 39 | return "", nil, err 40 | } 41 | 42 | buf := bytes.NewBuffer(frame.GetPayload()) 43 | req, err := pkt.MustReadLogicPkt(buf) 44 | if err != nil { 45 | log.Error(err) 46 | return "", nil, err 47 | } 48 | // 2. 
必须是登录包 49 | if req.Command != wire.CommandLoginSignIn { 50 | resp := pkt.NewFrom(&req.Header) 51 | resp.Status = pkt.Status_InvalidCommand 52 | _ = conn.WriteFrame(kim.OpBinary, pkt.Marshal(resp)) 53 | return "", nil, fmt.Errorf("must be a SignIn command") 54 | } 55 | 56 | // 3. 反序列化Body 57 | var login pkt.LoginReq 58 | err = req.ReadBody(&login) 59 | if err != nil { 60 | return "", nil, err 61 | } 62 | secret := h.AppSecret 63 | if secret == "" { 64 | secret = token.DefaultSecret 65 | } 66 | // 4. 使用默认的DefaultSecret 解析token 67 | tk, err := token.Parse(secret, login.Token) 68 | if err != nil { 69 | // 5. 如果token无效,就返回SDK一个Unauthorized消息 70 | resp := pkt.NewFrom(&req.Header) 71 | resp.Status = pkt.Status_Unauthorized 72 | _ = conn.WriteFrame(kim.OpBinary, pkt.Marshal(resp)) 73 | return "", nil, err 74 | } 75 | // 6. 生成一个全局唯一的ChannelID 76 | id := generateChannelID(h.ServiceID, tk.Account) 77 | log.Infof("accept %v channel:%s", tk, id) 78 | 79 | req.ChannelId = id 80 | req.WriteBody(&pkt.Session{ 81 | Account: tk.Account, 82 | ChannelId: id, 83 | GateId: h.ServiceID, 84 | App: tk.App, 85 | RemoteIP: getIP(conn.RemoteAddr().String()), 86 | }) 87 | req.AddStringMeta(MetaKeyApp, tk.App) 88 | req.AddStringMeta(MetaKeyAccount, tk.Account) 89 | 90 | // 7. 
把login.转发给Login服务 91 | err = container.Forward(wire.SNLogin, req) 92 | if err != nil { 93 | log.Errorf("container.Forward :%v", err) 94 | return "", nil, err 95 | } 96 | return id, kim.Meta{ 97 | MetaKeyApp: tk.App, 98 | MetaKeyAccount: tk.Account, 99 | }, nil 100 | } 101 | 102 | // Receive default listener 103 | func (h *Handler) Receive(ag kim.Agent, payload []byte) { 104 | buf := bytes.NewBuffer(payload) 105 | packet, err := pkt.Read(buf) 106 | if err != nil { 107 | log.Error(err) 108 | return 109 | } 110 | if basicPkt, ok := packet.(*pkt.BasicPkt); ok { 111 | if basicPkt.Code == pkt.CodePing { 112 | _ = ag.Push(pkt.Marshal(&pkt.BasicPkt{Code: pkt.CodePong})) 113 | } 114 | return 115 | } 116 | if logicPkt, ok := packet.(*pkt.LogicPkt); ok { 117 | logicPkt.ChannelId = ag.ID() 118 | 119 | messageInTotal.WithLabelValues(h.ServiceID, wire.SNTGateway, logicPkt.Command).Inc() 120 | messageInFlowBytes.WithLabelValues(h.ServiceID, wire.SNTGateway, logicPkt.Command).Add(float64(len(payload))) 121 | 122 | // 把meta注入到header中 123 | if ag.GetMeta() != nil { 124 | logicPkt.AddStringMeta(MetaKeyApp, ag.GetMeta()[MetaKeyApp]) 125 | logicPkt.AddStringMeta(MetaKeyAccount, ag.GetMeta()[MetaKeyAccount]) 126 | } 127 | 128 | err = container.Forward(logicPkt.ServiceName(), logicPkt) 129 | if err != nil { 130 | logger.WithFields(logger.Fields{ 131 | "module": "handler", 132 | "id": ag.ID(), 133 | "cmd": logicPkt.Command, 134 | "dest": logicPkt.Dest, 135 | }).Error(err) 136 | } 137 | } 138 | 139 | } 140 | 141 | // Disconnect default listener 142 | func (h *Handler) Disconnect(id string) error { 143 | log.Infof("disconnect %s", id) 144 | 145 | logout := pkt.New(wire.CommandLoginSignOut, pkt.WithChannel(id)) 146 | err := container.Forward(wire.SNLogin, logout) 147 | if err != nil { 148 | logger.WithFields(logger.Fields{ 149 | "module": "handler", 150 | "id": id, 151 | }).Error(err) 152 | } 153 | return nil 154 | } 155 | 156 | var ipExp = regexp.MustCompile(string("\\:[0-9]+$")) 157 | 158 
| func getIP(remoteAddr string) string { 159 | if remoteAddr == "" { 160 | return "" 161 | } 162 | return ipExp.ReplaceAllString(remoteAddr, "") 163 | } 164 | 165 | func generateChannelID(serviceID, account string) string { 166 | return fmt.Sprintf("%s_%s_%d", serviceID, account, wire.Seq.Next()) 167 | } 168 | -------------------------------------------------------------------------------- /services/gateway/serv/metrics.go: -------------------------------------------------------------------------------- 1 | package serv 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "github.com/prometheus/client_golang/prometheus/promauto" 6 | ) 7 | 8 | var messageInTotal = promauto.NewCounterVec(prometheus.CounterOpts{ 9 | Namespace: "kim", 10 | Name: "message_in_total", 11 | Help: "网关接收消息总数", 12 | }, []string{"serviceId", "serviceName", "command"}) 13 | 14 | var messageInFlowBytes = promauto.NewCounterVec(prometheus.CounterOpts{ 15 | Namespace: "kim", 16 | Name: "message_in_flow_bytes", 17 | Help: "网关接收消息字节数", 18 | }, []string{"serviceId", "serviceName", "command"}) 19 | 20 | var noServerFoundErrorTotal = promauto.NewCounterVec(prometheus.CounterOpts{ 21 | Namespace: "kim", 22 | Name: "no_server_found_error_total", 23 | Help: "查找zone分区中服务失败的次数", 24 | }, []string{"zone"}) 25 | -------------------------------------------------------------------------------- /services/gateway/serv/selector.go: -------------------------------------------------------------------------------- 1 | package serv 2 | 3 | import ( 4 | "hash/crc32" 5 | "math/rand" 6 | 7 | "github.com/klintcheng/kim" 8 | "github.com/klintcheng/kim/logger" 9 | "github.com/klintcheng/kim/services/gateway/conf" 10 | "github.com/klintcheng/kim/wire/pkt" 11 | ) 12 | 13 | // RouteSelector RouteSelector 14 | type RouteSelector struct { 15 | route *conf.Route 16 | } 17 | 18 | func NewRouteSelector(configPath string) (*RouteSelector, error) { 19 | route, err := conf.ReadRoute(configPath) 20 | if err != nil { 
21 | return nil, err 22 | } 23 | return &RouteSelector{ 24 | route: route, 25 | }, nil 26 | } 27 | 28 | // Lookup a server 29 | func (s *RouteSelector) Lookup(header *pkt.Header, srvs []kim.Service) string { 30 | // 1. 从header中读取Meta信息 31 | app, _ := pkt.FindMeta(header.Meta, MetaKeyApp) 32 | account, _ := pkt.FindMeta(header.Meta, MetaKeyAccount) 33 | if app == nil || account == nil { 34 | ri := rand.Intn(len(srvs)) 35 | return srvs[ri].ServiceID() 36 | } 37 | log := logger.WithFields(logger.Fields{ 38 | "app": app, 39 | "account": account, 40 | }) 41 | 42 | // 2. 判断是否命中白名单 43 | zone, ok := s.route.Whitelist[app.(string)] 44 | if !ok { // 未命中情况 45 | var key string 46 | switch s.route.RouteBy { 47 | case MetaKeyApp: 48 | key = app.(string) 49 | case MetaKeyAccount: 50 | key = account.(string) 51 | default: 52 | key = account.(string) 53 | } 54 | // 3. 通过权重计算出zone 55 | slot := hashcode(key) % len(s.route.Slots) 56 | i := s.route.Slots[slot] 57 | zone = s.route.Zones[i].ID 58 | } else { 59 | log.Infoln("hit a zone in whitelist", zone) 60 | } 61 | // 4. 过滤出当前zone的servers 62 | zoneSrvs := filterSrvs(srvs, zone) 63 | if len(zoneSrvs) == 0 { 64 | noServerFoundErrorTotal.WithLabelValues(zone).Inc() 65 | log.Warnf("select a random service from all due to no service found in zone %s", zone) 66 | ri := rand.Intn(len(srvs)) 67 | return srvs[ri].ServiceID() 68 | } 69 | // 5. 
从zoneSrvs中选中一个服务 70 | srv := selectSrvs(zoneSrvs, account.(string)) 71 | return srv.ServiceID() 72 | } 73 | 74 | func filterSrvs(srvs []kim.Service, zone string) []kim.Service { 75 | var res = make([]kim.Service, 0, len(srvs)) 76 | for _, srv := range srvs { 77 | if zone == srv.GetMeta()["zone"] { 78 | res = append(res, srv) 79 | } 80 | } 81 | return res 82 | } 83 | 84 | func selectSrvs(srvs []kim.Service, account string) kim.Service { 85 | slots := make([]int, 0, len(srvs)*10) 86 | for i := range srvs { 87 | for j := 0; j < 10; j++ { 88 | slots = append(slots, i) 89 | } 90 | } 91 | slot := hashcode(account) % len(slots) 92 | return srvs[slots[slot]] 93 | } 94 | 95 | func hashcode(key string) int { 96 | hash32 := crc32.NewIEEE() 97 | hash32.Write([]byte(key)) 98 | return int(hash32.Sum32()) 99 | } 100 | -------------------------------------------------------------------------------- /services/gateway/serv/selector_test.go: -------------------------------------------------------------------------------- 1 | package serv 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/klintcheng/kim" 7 | "github.com/klintcheng/kim/naming" 8 | "github.com/klintcheng/kim/wire" 9 | "github.com/klintcheng/kim/wire/pkt" 10 | "github.com/segmentio/ksuid" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestRouteSelector_Lookup(t *testing.T) { 15 | 16 | srvs := []kim.Service{ 17 | &naming.DefaultService{ 18 | Id: "s1", 19 | Meta: map[string]string{"zone": "zone_ali_01"}, 20 | }, 21 | &naming.DefaultService{ 22 | Id: "s2", 23 | Meta: map[string]string{"zone": "zone_ali_01"}, 24 | }, 25 | &naming.DefaultService{ 26 | Id: "s3", 27 | Meta: map[string]string{"zone": "zone_ali_01"}, 28 | }, 29 | &naming.DefaultService{ 30 | Id: "s4", 31 | Meta: map[string]string{"zone": "zone_ali_02"}, 32 | }, 33 | &naming.DefaultService{ 34 | Id: "s5", 35 | Meta: map[string]string{"zone": "zone_ali_02"}, 36 | }, 37 | &naming.DefaultService{ 38 | Id: "s6", 39 | Meta: map[string]string{"zone": 
"zone_ali_03"}, 40 | }, 41 | } 42 | 43 | rs, err := NewRouteSelector("../route.json") 44 | assert.Nil(t, err) 45 | 46 | packet := pkt.New(wire.CommandChatUserTalk, pkt.WithChannel(ksuid.New().String())) 47 | packet.AddStringMeta(MetaKeyApp, "kim") 48 | packet.AddStringMeta(MetaKeyAccount, "test1") 49 | hit := rs.Lookup(&packet.Header, srvs) 50 | assert.Equal(t, "s6", hit) 51 | 52 | hits := make(map[string]int) 53 | for i := 0; i < 100; i++ { 54 | header := pkt.Header{ 55 | ChannelId: ksuid.New().String(), 56 | Meta: []*pkt.Meta{ 57 | { 58 | Type: pkt.MetaType_string, 59 | Key: MetaKeyApp, 60 | Value: ksuid.New().String(), 61 | }, 62 | { 63 | Type: pkt.MetaType_string, 64 | Key: MetaKeyAccount, 65 | Value: ksuid.New().String(), 66 | }, 67 | }, 68 | } 69 | hit = rs.Lookup(&header, srvs) 70 | hits[hit]++ 71 | } 72 | t.Log(hits) 73 | } 74 | -------------------------------------------------------------------------------- /services/gateway/server.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | _ "net/http/pprof" 7 | "time" 8 | 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/container" 11 | "github.com/klintcheng/kim/logger" 12 | "github.com/klintcheng/kim/naming" 13 | "github.com/klintcheng/kim/naming/consul" 14 | "github.com/klintcheng/kim/services/gateway/conf" 15 | "github.com/klintcheng/kim/services/gateway/serv" 16 | "github.com/klintcheng/kim/tcp" 17 | "github.com/klintcheng/kim/websocket" 18 | "github.com/klintcheng/kim/wire" 19 | "github.com/spf13/cobra" 20 | ) 21 | 22 | // const logName = "logs/gateway" 23 | 24 | // ServerStartOptions ServerStartOptions 25 | type ServerStartOptions struct { 26 | config string 27 | protocol string 28 | route string 29 | } 30 | 31 | // NewServerStartCmd creates a new http server command 32 | func NewServerStartCmd(ctx context.Context, version string) *cobra.Command { 33 | opts := &ServerStartOptions{} 34 | 35 | cmd 
:= &cobra.Command{ 36 | Use: "gateway", 37 | Short: "Start a gateway", 38 | RunE: func(cmd *cobra.Command, args []string) error { 39 | return RunServerStart(ctx, opts, version) 40 | }, 41 | } 42 | cmd.PersistentFlags().StringVarP(&opts.config, "config", "c", "./gateway/conf.yaml", "Config file") 43 | cmd.PersistentFlags().StringVarP(&opts.route, "route", "r", "./gateway/route.json", "route file") 44 | cmd.PersistentFlags().StringVarP(&opts.protocol, "protocol", "p", "ws", "protocol of ws or tcp") 45 | return cmd 46 | } 47 | 48 | // RunServerStart run http server 49 | func RunServerStart(ctx context.Context, opts *ServerStartOptions, version string) error { 50 | config, err := conf.Init(opts.config) 51 | if err != nil { 52 | return err 53 | } 54 | _ = logger.Init(logger.Settings{ 55 | Level: "trace", 56 | Filename: "./data/gateway.log", 57 | }) 58 | 59 | handler := &serv.Handler{ 60 | ServiceID: config.ServiceID, 61 | AppSecret: config.AppSecret, 62 | } 63 | meta := make(map[string]string) 64 | meta[consul.KeyHealthURL] = fmt.Sprintf("http://%s:%d/health", config.PublicAddress, config.MonitorPort) 65 | meta["domain"] = config.Domain 66 | 67 | var srv kim.Server 68 | service := &naming.DefaultService{ 69 | Id: config.ServiceID, 70 | Name: config.ServiceName, 71 | Address: config.PublicAddress, 72 | Port: config.PublicPort, 73 | Protocol: opts.protocol, 74 | Tags: config.Tags, 75 | Meta: meta, 76 | } 77 | srvOpts := []kim.ServerOption{ 78 | kim.WithConnectionGPool(config.ConnectionGPool), kim.WithMessageGPool(config.MessageGPool), 79 | } 80 | if opts.protocol == "ws" { 81 | srv = websocket.NewServer(config.Listen, service, srvOpts...) 82 | } else if opts.protocol == "tcp" { 83 | srv = tcp.NewServer(config.Listen, service, srvOpts...) 
84 | } 85 | 86 | srv.SetReadWait(time.Minute * 2) 87 | srv.SetAcceptor(handler) 88 | srv.SetMessageListener(handler) 89 | srv.SetStateListener(handler) 90 | 91 | _ = container.Init(srv, wire.SNChat, wire.SNLogin) 92 | container.EnableMonitor(fmt.Sprintf(":%d", config.MonitorPort)) 93 | 94 | ns, err := consul.NewNaming(config.ConsulURL) 95 | if err != nil { 96 | return err 97 | } 98 | container.SetServiceNaming(ns) 99 | // set a dialer 100 | container.SetDialer(serv.NewDialer(config.ServiceID)) 101 | // use routeSelector 102 | selector, err := serv.NewRouteSelector(opts.route) 103 | if err != nil { 104 | return err 105 | } 106 | container.SetSelector(selector) 107 | return container.Start() 108 | } 109 | -------------------------------------------------------------------------------- /services/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | 7 | "github.com/klintcheng/kim/logger" 8 | "github.com/klintcheng/kim/services/gateway" 9 | "github.com/klintcheng/kim/services/router" 10 | "github.com/klintcheng/kim/services/server" 11 | "github.com/klintcheng/kim/services/service" 12 | "github.com/spf13/cobra" 13 | ) 14 | 15 | const version = "v1" 16 | 17 | func main() { 18 | flag.Parse() 19 | 20 | root := &cobra.Command{ 21 | Use: "kim", 22 | Version: version, 23 | Short: "King IM Cloud", 24 | } 25 | ctx := context.Background() 26 | 27 | root.AddCommand(gateway.NewServerStartCmd(ctx, version)) 28 | root.AddCommand(server.NewServerStartCmd(ctx, version)) 29 | root.AddCommand(service.NewServerStartCmd(ctx, version)) 30 | root.AddCommand(router.NewServerStartCmd(ctx, version)) 31 | 32 | if err := root.Execute(); err != nil { 33 | logger.WithError(err).Fatal("Could not run command") 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /services/router/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/klintcheng/kim/69995ca15cf6e22dd8e8ed4f858e0106dbd28b84/services/router/.DS_Store -------------------------------------------------------------------------------- /services/router/api_test.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/go-resty/resty/v2" 9 | "github.com/klintcheng/kim/services/router/apis" 10 | "github.com/segmentio/ksuid" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func Test_Lookup(t *testing.T) { 15 | cli := resty.New() 16 | cli.SetHeader("Content-Type", "application/json") 17 | 18 | domains := make(map[string]int) 19 | for i := 0; i < 1000; i++ { 20 | url := fmt.Sprintf("http://localhost:8100/api/lookup/%s", ksuid.New().String()) 21 | 22 | var res apis.LookUpResp 23 | resp, err := cli.R().SetResult(&res).Get(url) 24 | assert.Equal(t, http.StatusOK, resp.StatusCode()) 25 | assert.Nil(t, err) 26 | if len(res.Domains) > 0 { 27 | domain := res.Domains[0] 28 | domains[domain]++ 29 | } 30 | } 31 | for domain, hit := range domains { 32 | fmt.Printf("domain: %s ;hit count: %d\n", domain, hit) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /services/router/apis/router.go: -------------------------------------------------------------------------------- 1 | package apis 2 | 3 | import ( 4 | "fmt" 5 | "hash/crc32" 6 | "time" 7 | 8 | "github.com/kataras/iris/v12" 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/naming" 11 | "github.com/klintcheng/kim/services/router/conf" 12 | "github.com/klintcheng/kim/services/router/ipregion" 13 | "github.com/klintcheng/kim/wire" 14 | "github.com/sirupsen/logrus" 15 | ) 16 | 17 | const DefaultLocation = "中国" 18 | 19 | type RouterApi struct { 20 | Naming naming.Naming 21 | IpRegion 
ipregion.IpRegion 22 | Config conf.Router 23 | } 24 | 25 | type LookUpResp struct { 26 | UTC int64 `json:"utc"` 27 | Location string `json:"location"` 28 | Domains []string `json:"domains"` 29 | } 30 | 31 | func (r *RouterApi) Lookup(c iris.Context) { 32 | ip := kim.RealIP(c.Request()) 33 | token := c.Params().Get("token") 34 | 35 | // step 1 36 | var location conf.Country 37 | ipinfo, err := r.IpRegion.Search(ip) 38 | if err != nil || ipinfo.Country == "0" { 39 | location = DefaultLocation 40 | } else { 41 | location = conf.Country(ipinfo.Country) 42 | } 43 | 44 | // step 2 45 | regionId, ok := r.Config.Mapping[location] 46 | if !ok { 47 | c.StopWithError(iris.StatusForbidden, err) 48 | return 49 | } 50 | 51 | // step 3 52 | region, ok := r.Config.Regions[regionId] 53 | if !ok { 54 | c.StopWithError(iris.StatusInternalServerError, err) 55 | return 56 | } 57 | 58 | // step 4 59 | idc := selectIdc(token, region) 60 | 61 | // step 5 62 | gateways, err := r.Naming.Find(wire.SNWGateway, fmt.Sprintf("IDC:%s", idc.ID)) 63 | if err != nil { 64 | c.StopWithError(iris.StatusInternalServerError, err) 65 | return 66 | } 67 | 68 | // step 6 69 | hits := selectGateways(token, gateways, 3) 70 | domains := make([]string, len(hits)) 71 | for i, h := range hits { 72 | domains[i] = h.GetMeta()["domain"] 73 | } 74 | 75 | logrus.WithFields(logrus.Fields{ 76 | "country": location, 77 | "regionId": regionId, 78 | "idc": idc.ID, 79 | }).Infof("lookup domain %v", domains) 80 | 81 | _, _ = c.JSON(LookUpResp{ 82 | UTC: time.Now().Unix(), 83 | Location: string(location), 84 | Domains: domains, 85 | }) 86 | } 87 | 88 | func selectIdc(token string, region *conf.Region) *conf.IDC { 89 | slot := hashcode(token) % len(region.Slots) 90 | i := region.Slots[slot] 91 | return ®ion.Idcs[i] 92 | } 93 | 94 | func selectGateways(token string, gateways []kim.ServiceRegistration, num int) []kim.ServiceRegistration { 95 | if len(gateways) <= num { 96 | return gateways 97 | } 98 | slots := make([]int, 0, 
len(gateways)*10) 99 | for i := range gateways { 100 | for j := 0; j < 10; j++ { 101 | slots = append(slots, i) 102 | } 103 | } 104 | slot := hashcode(token) % len(slots) 105 | i := slots[slot] 106 | res := make([]kim.ServiceRegistration, 0, num) 107 | for len(res) < num { 108 | res = append(res, gateways[i]) 109 | i++ 110 | if i >= len(gateways) { 111 | i = 0 112 | } 113 | } 114 | return res 115 | } 116 | 117 | func hashcode(key string) int { 118 | hash32 := crc32.NewIEEE() 119 | hash32.Write([]byte(key)) 120 | return int(hash32.Sum32()) 121 | } 122 | -------------------------------------------------------------------------------- /services/router/apis/router_test.go: -------------------------------------------------------------------------------- 1 | package apis 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/klintcheng/kim" 7 | "github.com/klintcheng/kim/naming" 8 | "github.com/klintcheng/kim/services/router/conf" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func Test_selectIdc(t *testing.T) { 13 | got := selectIdc("test1", &conf.Region{ 14 | Idcs: []conf.IDC{ 15 | {ID: "SH_ALI"}, 16 | {ID: "HZ_ALI"}, 17 | {ID: "SH_TENCENT"}, 18 | }, 19 | Slots: []byte{0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2}, 20 | }) 21 | assert.NotNil(t, got) 22 | t.Log(got) 23 | } 24 | 25 | func Test_selectGateways(t *testing.T) { 26 | got := selectGateways("test11", []kim.ServiceRegistration{ 27 | &naming.DefaultService{Id: "g1"}, 28 | &naming.DefaultService{Id: "g2"}, 29 | }, 3) 30 | assert.Equal(t, len(got), 2) 31 | 32 | got = selectGateways("test11", []kim.ServiceRegistration{ 33 | &naming.DefaultService{Id: "g1"}, 34 | &naming.DefaultService{Id: "g2"}, 35 | &naming.DefaultService{Id: "g3"}, 36 | }, 3) 37 | assert.Equal(t, len(got), 3) 38 | 39 | got = selectGateways("test11", []kim.ServiceRegistration{ 40 | &naming.DefaultService{Id: "g1"}, 41 | &naming.DefaultService{Id: "g2"}, 42 | &naming.DefaultService{Id: "g3"}, 43 | &naming.DefaultService{Id: "g4"}, 44 | 
&naming.DefaultService{Id: "g5"}, 45 | &naming.DefaultService{Id: "g6"}, 46 | }, 3) 47 | 48 | t.Log(got) 49 | assert.Equal(t, len(got), 3) 50 | } 51 | -------------------------------------------------------------------------------- /services/router/conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/kelseyhightower/envconfig" 7 | "github.com/klintcheng/kim/logger" 8 | "github.com/spf13/viper" 9 | ) 10 | 11 | // Config Config 12 | type Config struct { 13 | Listen string `default:":8100"` 14 | ConsulURL string `default:"localhost:8500"` 15 | LogLevel string `default:"INFO"` 16 | } 17 | 18 | func (c Config) String() string { 19 | bts, _ := json.Marshal(c) 20 | return string(bts) 21 | } 22 | 23 | // Init InitConfig 24 | func Init(file string) (*Config, error) { 25 | viper.SetConfigFile(file) 26 | viper.AddConfigPath(".") 27 | viper.AddConfigPath("/etc/conf") 28 | 29 | var config Config 30 | 31 | err := envconfig.Process("kim", &config) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | if err := viper.ReadInConfig(); err != nil { 37 | logger.Warn(err) 38 | } else { 39 | if err := viper.Unmarshal(&config); err != nil { 40 | return nil, err 41 | } 42 | } 43 | 44 | return &config, nil 45 | } 46 | -------------------------------------------------------------------------------- /services/router/conf/router.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | ) 7 | 8 | type IDC struct { 9 | ID string 10 | Weight int 11 | } 12 | 13 | type Region struct { 14 | ID string 15 | Idcs []IDC 16 | Slots []byte 17 | } 18 | 19 | type Country string 20 | 21 | type Mapping struct { 22 | Region string 23 | Locations []string 24 | } 25 | 26 | type Router struct { 27 | Mapping map[Country]string 28 | Regions map[string]*Region 29 | } 30 | 31 | func 
LoadMapping(path string) (map[Country]string, error) { 32 | bts, err := ioutil.ReadFile(path) 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | var mps []Mapping 38 | err = json.Unmarshal(bts, &mps) 39 | if err != nil { 40 | return nil, err 41 | } 42 | mp := make(map[Country]string) 43 | for _, v := range mps { 44 | region := v.Region 45 | for _, loc := range v.Locations { 46 | mp[Country(loc)] = region 47 | } 48 | } 49 | return mp, nil 50 | } 51 | 52 | func LoadRegions(path string) (map[string]*Region, error) { 53 | bts, err := ioutil.ReadFile(path) 54 | if err != nil { 55 | return nil, err 56 | } 57 | var regions []*Region 58 | err = json.Unmarshal(bts, ®ions) 59 | if err != nil { 60 | return nil, err 61 | } 62 | res := make(map[string]*Region) 63 | for _, region := range regions { 64 | res[region.ID] = region 65 | for i, idc := range region.Idcs { 66 | // 1.通过权重生成分片中的slots 67 | shard := make([]byte, idc.Weight) 68 | // 2. 给当前slots设置值,指向索引i 69 | for j := 0; j < idc.Weight; j++ { 70 | shard[j] = byte(i) 71 | } 72 | // 2. 追加到Slots中 73 | region.Slots = append(region.Slots, shard...) 
74 | } 75 | } 76 | return res, nil 77 | } 78 | -------------------------------------------------------------------------------- /services/router/data/ip2region.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/klintcheng/kim/69995ca15cf6e22dd8e8ed4f858e0106dbd28b84/services/router/data/ip2region.db -------------------------------------------------------------------------------- /services/router/data/mapping.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "region": "EC", 4 | "locations": [ 5 | "中国" 6 | ] 7 | }, 8 | { 9 | "region": "HK", 10 | "locations": [ 11 | "香港特别行政区", 12 | "日本", 13 | "菲律宾" 14 | ] 15 | }, 16 | { 17 | "region": "TW", 18 | "locations": [ 19 | "台湾省" 20 | ] 21 | } 22 | ] -------------------------------------------------------------------------------- /services/router/data/regions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "EC", 4 | "idcs": [ 5 | { 6 | "id": "SH_ALI", 7 | "weight": 100 8 | }, 9 | { 10 | "id": "HZ_ALI", 11 | "weight": 0 12 | } 13 | ] 14 | } 15 | ] -------------------------------------------------------------------------------- /services/router/ipregion/ipregion.go: -------------------------------------------------------------------------------- 1 | package ipregion 2 | 3 | import ( 4 | "github.com/lionsoul2014/ip2region/binding/golang/ip2region" 5 | ) 6 | 7 | type IpInfo struct { 8 | Country string 9 | Region string 10 | City string 11 | ISP string 12 | } 13 | 14 | type IpRegion interface { 15 | Search(ip string) (*IpInfo, error) 16 | } 17 | 18 | type Ip2region struct { 19 | region *ip2region.Ip2Region 20 | } 21 | 22 | func NewIp2region(path string) (IpRegion, error) { 23 | if path == "" { 24 | path = "ip2region.db" 25 | } 26 | region, err := ip2region.New(path) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | return &Ip2region{ 32 | 
region: region, 33 | }, nil 34 | } 35 | 36 | func (r *Ip2region) Search(ip string) (*IpInfo, error) { 37 | info, err := r.region.MemorySearch(ip) 38 | if err != nil { 39 | return nil, err 40 | } 41 | return &IpInfo{ 42 | Country: info.Country, 43 | Region: info.Region, 44 | City: info.City, 45 | ISP: info.ISP, 46 | }, nil 47 | } 48 | -------------------------------------------------------------------------------- /services/router/ipregion/ipregion_test.go: -------------------------------------------------------------------------------- 1 | package ipregion 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestIp2region_Search(t *testing.T) { 10 | region, err := NewIp2region("../ip2region.db") 11 | assert.Nil(t, err) 12 | 13 | got, err := region.Search("3.166.231.6") 14 | assert.Nil(t, err) 15 | t.Log(got) 16 | } 17 | -------------------------------------------------------------------------------- /services/router/server.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "context" 5 | "path" 6 | 7 | "github.com/kataras/iris/v12" 8 | "github.com/klintcheng/kim/logger" 9 | "github.com/klintcheng/kim/naming/consul" 10 | "github.com/klintcheng/kim/services/router/apis" 11 | "github.com/klintcheng/kim/services/router/conf" 12 | "github.com/klintcheng/kim/services/router/ipregion" 13 | "github.com/sirupsen/logrus" 14 | "github.com/spf13/cobra" 15 | ) 16 | 17 | // ServerStartOptions ServerStartOptions 18 | type ServerStartOptions struct { 19 | config string 20 | data string 21 | } 22 | 23 | // NewServerStartCmd creates a new http server command 24 | func NewServerStartCmd(ctx context.Context, version string) *cobra.Command { 25 | opts := &ServerStartOptions{} 26 | 27 | cmd := &cobra.Command{ 28 | Use: "router", 29 | Short: "Start a router", 30 | RunE: func(cmd *cobra.Command, args []string) error { 31 | return RunServerStart(ctx, opts, version) 32 | }, 33 
| } 34 | cmd.PersistentFlags().StringVarP(&opts.config, "config", "c", "./router/conf.yaml", "Config file") 35 | cmd.PersistentFlags().StringVarP(&opts.data, "data", "d", "./router/data", "data path") 36 | return cmd 37 | } 38 | 39 | // RunServerStart run http server 40 | func RunServerStart(ctx context.Context, opts *ServerStartOptions, version string) error { 41 | config, err := conf.Init(opts.config) 42 | if err != nil { 43 | return err 44 | } 45 | _ = logger.Init(logger.Settings{ 46 | Level: "info", 47 | Filename: "./data/router.log", 48 | }) 49 | 50 | mappings, err := conf.LoadMapping(path.Join(opts.data, "mapping.json")) 51 | if err != nil { 52 | return err 53 | } 54 | logrus.Infof("load mappings - %v", mappings) 55 | regions, err := conf.LoadRegions(path.Join(opts.data, "regions.json")) 56 | if err != nil { 57 | return err 58 | } 59 | logrus.Infof("load regions - %v", regions) 60 | 61 | region, err := ipregion.NewIp2region(path.Join(opts.data, "ip2region.db")) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | ns, err := consul.NewNaming(config.ConsulURL) 67 | if err != nil { 68 | return err 69 | } 70 | 71 | router := apis.RouterApi{ 72 | Naming: ns, 73 | IpRegion: region, 74 | Config: conf.Router{ 75 | Mapping: mappings, 76 | Regions: regions, 77 | }, 78 | } 79 | 80 | app := iris.Default() 81 | 82 | app.Get("/health", func(ctx iris.Context) { 83 | _, _ = ctx.WriteString("ok") 84 | }) 85 | routerAPI := app.Party("/api/lookup") 86 | { 87 | routerAPI.Get("/:token", router.Lookup) 88 | } 89 | 90 | // Start server 91 | return app.Listen(config.Listen, iris.WithOptimizations) 92 | } 93 | -------------------------------------------------------------------------------- /services/server/conf.yaml: -------------------------------------------------------------------------------- 1 | ServiceID: chat01 2 | Listen: ":8005" 3 | MonitorPort: 8006 4 | PublicPort: 8005 5 | Tags: 6 | - server 7 | Zone: zone_ali_03 8 | ConsulURL: localhost:8500 9 | RedisAddrs: 
localhost:6379 10 | RoyalURL: http://localhost:8080 11 | MessageGPool: 5000 12 | ConnectionGPool: 500 -------------------------------------------------------------------------------- /services/server/conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log" 7 | "strings" 8 | "time" 9 | 10 | "github.com/go-redis/redis/v7" 11 | "github.com/kelseyhightower/envconfig" 12 | "github.com/klintcheng/kim" 13 | "github.com/klintcheng/kim/logger" 14 | "github.com/sirupsen/logrus" 15 | "github.com/spf13/viper" 16 | ) 17 | 18 | type Server struct { 19 | } 20 | 21 | // Config Config 22 | type Config struct { 23 | ServiceID string 24 | Listen string `default:":8005"` 25 | MonitorPort int `default:"8006"` 26 | PublicAddress string 27 | PublicPort int `default:"8005"` 28 | Tags []string 29 | Zone string `default:"zone_ali_03"` 30 | ConsulURL string 31 | RedisAddrs string 32 | RoyalURL string 33 | LogLevel string `default:"DEBUG"` 34 | MessageGPool int `default:"5000"` 35 | ConnectionGPool int `default:"500"` 36 | } 37 | 38 | func (c Config) String() string { 39 | bts, _ := json.Marshal(c) 40 | return string(bts) 41 | } 42 | 43 | // Init InitConfig 44 | func Init(file string) (*Config, error) { 45 | viper.SetConfigFile(file) 46 | viper.AddConfigPath(".") 47 | viper.AddConfigPath("/etc/conf") 48 | 49 | var config Config 50 | err := envconfig.Process("kim", &config) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | if err := viper.ReadInConfig(); err != nil { 56 | logger.Warn(err) 57 | } else { 58 | if err := viper.Unmarshal(&config); err != nil { 59 | return nil, err 60 | } 61 | } 62 | 63 | if config.ServiceID == "" { 64 | localIP := kim.GetLocalIP() 65 | config.ServiceID = fmt.Sprintf("server_%s", strings.ReplaceAll(localIP, ".", "")) 66 | } 67 | if config.PublicAddress == "" { 68 | config.PublicAddress = kim.GetLocalIP() 69 | } 70 | logger.Info(config) 71 | 
return &config, nil 72 | } 73 | 74 | func InitRedis(addr string, pass string) (*redis.Client, error) { 75 | redisdb := redis.NewClient(&redis.Options{ 76 | Addr: addr, 77 | Password: pass, 78 | DialTimeout: time.Second * 5, 79 | ReadTimeout: time.Second * 5, 80 | WriteTimeout: time.Second * 5, 81 | }) 82 | 83 | _, err := redisdb.Ping().Result() 84 | if err != nil { 85 | log.Println(err) 86 | return nil, err 87 | } 88 | return redisdb, nil 89 | } 90 | 91 | // InitFailoverRedis init redis with sentinels 92 | func InitFailoverRedis(masterName string, sentinelAddrs []string, password string, timeout time.Duration) (*redis.Client, error) { 93 | redisdb := redis.NewFailoverClient(&redis.FailoverOptions{ 94 | MasterName: masterName, 95 | SentinelAddrs: sentinelAddrs, 96 | Password: password, 97 | DialTimeout: time.Second * 5, 98 | ReadTimeout: timeout, 99 | WriteTimeout: timeout, 100 | }) 101 | 102 | _, err := redisdb.Ping().Result() 103 | if err != nil { 104 | logrus.Warn(err) 105 | } 106 | return redisdb, nil 107 | } 108 | -------------------------------------------------------------------------------- /services/server/handler/group_handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "github.com/klintcheng/kim" 5 | "github.com/klintcheng/kim/services/server/service" 6 | "github.com/klintcheng/kim/wire/pkt" 7 | "github.com/klintcheng/kim/wire/rpc" 8 | ) 9 | 10 | type GroupHandler struct { 11 | groupService service.Group 12 | } 13 | 14 | func NewGroupHandler(groupService service.Group) *GroupHandler { 15 | return &GroupHandler{ 16 | groupService: groupService, 17 | } 18 | } 19 | 20 | func (h *GroupHandler) DoCreate(ctx kim.Context) { 21 | var req pkt.GroupCreateReq 22 | if err := ctx.ReadBody(&req); err != nil { 23 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 24 | return 25 | } 26 | resp, err := h.groupService.Create(ctx.Session().GetApp(), &rpc.CreateGroupReq{ 27 | Name: 
req.GetName(), 28 | Avatar: req.GetAvatar(), 29 | Introduction: req.GetIntroduction(), 30 | Owner: req.GetOwner(), 31 | Members: req.GetMembers(), 32 | }) 33 | if err != nil { 34 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 35 | return 36 | } 37 | 38 | locs, err := ctx.GetLocations(req.GetMembers()...) 39 | if err != nil && err != kim.ErrSessionNil { 40 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 41 | return 42 | } 43 | 44 | // push to receiver 45 | if len(locs) > 0 { 46 | if err = ctx.Dispatch(&pkt.GroupCreateNotify{ 47 | GroupId: resp.GroupId, 48 | Members: req.GetMembers(), 49 | }, locs...); err != nil { 50 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 51 | return 52 | } 53 | } 54 | 55 | _ = ctx.Resp(pkt.Status_Success, &pkt.GroupCreateResp{ 56 | GroupId: resp.GroupId, 57 | }) 58 | } 59 | 60 | func (h *GroupHandler) DoJoin(ctx kim.Context) { 61 | var req pkt.GroupJoinReq 62 | if err := ctx.ReadBody(&req); err != nil { 63 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 64 | return 65 | } 66 | err := h.groupService.Join(ctx.Session().GetApp(), &rpc.JoinGroupReq{ 67 | Account: req.Account, 68 | GroupId: req.GetGroupId(), 69 | }) 70 | if err != nil { 71 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 72 | return 73 | } 74 | 75 | _ = ctx.Resp(pkt.Status_Success, nil) 76 | } 77 | 78 | func (h *GroupHandler) DoQuit(ctx kim.Context) { 79 | var req pkt.GroupQuitReq 80 | if err := ctx.ReadBody(&req); err != nil { 81 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 82 | return 83 | } 84 | err := h.groupService.Quit(ctx.Session().GetApp(), &rpc.QuitGroupReq{ 85 | Account: req.Account, 86 | GroupId: req.GetGroupId(), 87 | }) 88 | if err != nil { 89 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 90 | return 91 | } 92 | _ = ctx.Resp(pkt.Status_Success, nil) 93 | } 94 | 95 | func (h *GroupHandler) DoDetail(ctx kim.Context) { 96 | var req pkt.GroupGetReq 97 | if err := ctx.ReadBody(&req); err != nil { 98 
| _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 99 | return 100 | } 101 | resp, err := h.groupService.Detail(ctx.Session().GetApp(), &rpc.GetGroupReq{ 102 | GroupId: req.GetGroupId(), 103 | }) 104 | if err != nil { 105 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 106 | return 107 | } 108 | membersResp, err := h.groupService.Members(ctx.Session().GetApp(), &rpc.GroupMembersReq{ 109 | GroupId: req.GetGroupId(), 110 | }) 111 | if err != nil { 112 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 113 | return 114 | } 115 | var members = make([]*pkt.Member, len(membersResp.GetUsers())) 116 | for i, m := range membersResp.GetUsers() { 117 | members[i] = &pkt.Member{ 118 | Account: m.Account, 119 | Alias: m.Alias, 120 | JoinTime: m.JoinTime, 121 | Avatar: m.Avatar, 122 | } 123 | } 124 | _ = ctx.Resp(pkt.Status_Success, &pkt.GroupGetResp{ 125 | Id: resp.Id, 126 | Name: resp.Name, 127 | Introduction: resp.Introduction, 128 | Avatar: resp.Avatar, 129 | Owner: resp.Owner, 130 | Members: members, 131 | }) 132 | } 133 | -------------------------------------------------------------------------------- /services/server/handler/login_handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "github.com/klintcheng/kim" 5 | "github.com/klintcheng/kim/logger" 6 | "github.com/klintcheng/kim/wire/pkt" 7 | ) 8 | 9 | type LoginHandler struct { 10 | } 11 | 12 | func NewLoginHandler() *LoginHandler { 13 | return &LoginHandler{} 14 | } 15 | 16 | func (h *LoginHandler) DoSysLogin(ctx kim.Context) { 17 | log := logger.WithField("func", "DoSysLogin") 18 | // 1. 序列化 19 | var session pkt.Session 20 | if err := ctx.ReadBody(&session); err != nil { 21 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 22 | return 23 | } 24 | 25 | log.Infof("do login of %v ", session.String()) 26 | // 2. 
检查当前账号是否已经登录在其它地方 27 | old, err := ctx.GetLocation(session.Account, "") 28 | if err != nil && err != kim.ErrSessionNil { 29 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 30 | return 31 | } 32 | 33 | if old != nil { 34 | // 3. 通知这个用户下线 35 | _ = ctx.Dispatch(&pkt.KickoutNotify{ 36 | ChannelId: old.ChannelId, 37 | }, old) 38 | } 39 | 40 | // 4. 添加到会话管理器中 41 | err = ctx.Add(&session) 42 | if err != nil { 43 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 44 | return 45 | } 46 | // 5. 返回一个登录成功的消息 47 | var resp = &pkt.LoginResp{ 48 | ChannelId: session.ChannelId, 49 | Account: session.Account, 50 | } 51 | _ = ctx.Resp(pkt.Status_Success, resp) 52 | } 53 | 54 | func (h *LoginHandler) DoSysLogout(ctx kim.Context) { 55 | logger.WithField("func", "DoSysLogout").Infof("do Logout of %s %s ", ctx.Session().GetChannelId(), ctx.Session().GetAccount()) 56 | 57 | err := ctx.Delete(ctx.Session().GetAccount(), ctx.Session().GetChannelId()) 58 | if err != nil { 59 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 60 | return 61 | } 62 | 63 | _ = ctx.Resp(pkt.Status_Success, nil) 64 | } 65 | -------------------------------------------------------------------------------- /services/server/handler/login_handler_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_handleSysLogin(t *testing.T) { 8 | 9 | } 10 | -------------------------------------------------------------------------------- /services/server/handler/offline_handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/klintcheng/kim" 7 | "github.com/klintcheng/kim/services/server/service" 8 | "github.com/klintcheng/kim/wire/pkt" 9 | "github.com/klintcheng/kim/wire/rpc" 10 | ) 11 | 12 | type OfflineHandler struct { 13 | msgService service.Message 14 | } 15 | 16 | func 
NewOfflineHandler(message service.Message) *OfflineHandler { 17 | return &OfflineHandler{ 18 | msgService: message, 19 | } 20 | } 21 | 22 | func (h *OfflineHandler) DoSyncIndex(ctx kim.Context) { 23 | var req pkt.MessageIndexReq 24 | if err := ctx.ReadBody(&req); err != nil { 25 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 26 | return 27 | } 28 | resp, err := h.msgService.GetMessageIndex(ctx.Session().GetApp(), &rpc.GetOfflineMessageIndexReq{ 29 | Account: ctx.Session().GetAccount(), 30 | MessageId: req.GetMessageId(), 31 | }) 32 | if err != nil { 33 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 34 | return 35 | } 36 | var list = make([]*pkt.MessageIndex, len(resp.List)) 37 | for i, val := range resp.List { 38 | list[i] = &pkt.MessageIndex{ 39 | MessageId: val.MessageId, 40 | Direction: val.Direction, 41 | SendTime: val.SendTime, 42 | AccountB: val.AccountB, 43 | Group: val.Group, 44 | } 45 | } 46 | _ = ctx.Resp(pkt.Status_Success, &pkt.MessageIndexResp{ 47 | Indexes: list, 48 | }) 49 | } 50 | 51 | func (h *OfflineHandler) DoSyncContent(ctx kim.Context) { 52 | var req pkt.MessageContentReq 53 | if err := ctx.ReadBody(&req); err != nil { 54 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, err) 55 | return 56 | } 57 | if len(req.MessageIds) == 0 { 58 | _ = ctx.RespWithError(pkt.Status_InvalidPacketBody, errors.New("empty MessageIds")) 59 | return 60 | } 61 | resp, err := h.msgService.GetMessageContent(ctx.Session().GetApp(), &rpc.GetOfflineMessageContentReq{ 62 | MessageIds: req.MessageIds, 63 | }) 64 | if err != nil { 65 | _ = ctx.RespWithError(pkt.Status_SystemException, err) 66 | return 67 | } 68 | var list = make([]*pkt.MessageContent, len(resp.List)) 69 | for i, val := range resp.List { 70 | list[i] = &pkt.MessageContent{ 71 | MessageId: val.Id, 72 | Type: val.Type, 73 | Body: val.Body, 74 | Extra: val.Extra, 75 | } 76 | } 77 | _ = ctx.Resp(pkt.Status_Success, &pkt.MessageContentResp{ 78 | Contents: list, 79 | }) 80 | } 81 | 
-------------------------------------------------------------------------------- /services/server/handler/offline_handler_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | -------------------------------------------------------------------------------- /services/server/serv/handler.go: -------------------------------------------------------------------------------- 1 | package serv 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | "time" 7 | 8 | "github.com/klintcheng/kim" 9 | "github.com/klintcheng/kim/container" 10 | "github.com/klintcheng/kim/logger" 11 | "github.com/klintcheng/kim/wire" 12 | "github.com/klintcheng/kim/wire/pkt" 13 | "google.golang.org/protobuf/proto" 14 | ) 15 | 16 | var log = logger.WithFields(logger.Fields{ 17 | "service": wire.SNChat, 18 | "pkg": "serv", 19 | }) 20 | 21 | // ServHandler ServHandler 22 | type ServHandler struct { 23 | r *kim.Router 24 | cache kim.SessionStorage 25 | dispatcher *ServerDispatcher 26 | } 27 | 28 | func NewServHandler(r *kim.Router, cache kim.SessionStorage) *ServHandler { 29 | return &ServHandler{ 30 | r: r, 31 | dispatcher: &ServerDispatcher{}, 32 | cache: cache, 33 | } 34 | } 35 | 36 | // Accept this connection 37 | func (h *ServHandler) Accept(conn kim.Conn, timeout time.Duration) (string, kim.Meta, error) { 38 | _ = conn.SetReadDeadline(time.Now().Add(timeout)) 39 | frame, err := conn.ReadFrame() 40 | if err != nil { 41 | return "", nil, err 42 | } 43 | 44 | var req pkt.InnerHandshakeReq 45 | _ = proto.Unmarshal(frame.GetPayload(), &req) 46 | log.Info("Accept -- ", req.ServiceId) 47 | 48 | return req.ServiceId, nil, nil 49 | } 50 | 51 | // Receive default listener 52 | func (h *ServHandler) Receive(ag kim.Agent, payload []byte) { 53 | buf := bytes.NewBuffer(payload) 54 | packet, err := pkt.MustReadLogicPkt(buf) 55 | if err != nil { 56 | log.Error(err) 57 | return 58 | } 59 | var session *pkt.Session 60 | if packet.Command == wire.CommandLoginSignIn 
{ 61 | server, _ := packet.GetMeta(wire.MetaDestServer) 62 | session = &pkt.Session{ 63 | ChannelId: packet.ChannelId, 64 | GateId: server.(string), 65 | Tags: []string{"AutoGenerated"}, 66 | } 67 | } else { 68 | // TODO:优化点 69 | session, err = h.cache.Get(packet.ChannelId) 70 | if err == kim.ErrSessionNil { 71 | _ = RespErr(ag, packet, pkt.Status_SessionNotFound) 72 | return 73 | } else if err != nil { 74 | _ = RespErr(ag, packet, pkt.Status_SystemException) 75 | return 76 | } 77 | } 78 | log.Debugf("recv a message from %s %s", session, &packet.Header) 79 | err = h.r.Serve(packet, h.dispatcher, h.cache, session) 80 | if err != nil { 81 | log.Warn(err) 82 | } 83 | 84 | } 85 | 86 | func RespErr(ag kim.Agent, p *pkt.LogicPkt, status pkt.Status) error { 87 | packet := pkt.NewFrom(&p.Header) 88 | packet.Status = status 89 | packet.Flag = pkt.Flag_Response 90 | 91 | packet.AddStringMeta(wire.MetaDestChannels, p.Header.ChannelId) 92 | return container.Push(ag.ID(), packet) 93 | } 94 | 95 | type ServerDispatcher struct { 96 | } 97 | 98 | func (d *ServerDispatcher) Push(gateway string, channels []string, p *pkt.LogicPkt) error { 99 | p.AddStringMeta(wire.MetaDestChannels, strings.Join(channels, ",")) 100 | return container.Push(gateway, p) 101 | } 102 | 103 | // Disconnect default listener 104 | func (h *ServHandler) Disconnect(id string) error { 105 | logger.Warnf("close event of %s", id) 106 | return nil 107 | } 108 | -------------------------------------------------------------------------------- /services/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/go-resty/resty/v2" 9 | "github.com/klintcheng/kim" 10 | "github.com/klintcheng/kim/container" 11 | "github.com/klintcheng/kim/logger" 12 | "github.com/klintcheng/kim/middleware" 13 | "github.com/klintcheng/kim/naming" 14 | "github.com/klintcheng/kim/naming/consul" 15 | 
"github.com/klintcheng/kim/services/server/conf" 16 | "github.com/klintcheng/kim/services/server/handler" 17 | "github.com/klintcheng/kim/services/server/serv" 18 | "github.com/klintcheng/kim/services/server/service" 19 | "github.com/klintcheng/kim/storage" 20 | "github.com/klintcheng/kim/tcp" 21 | "github.com/klintcheng/kim/wire" 22 | "github.com/spf13/cobra" 23 | ) 24 | 25 | // ServerStartOptions ServerStartOptions 26 | type ServerStartOptions struct { 27 | config string 28 | serviceName string 29 | } 30 | 31 | // NewServerStartCmd creates a new http server command 32 | func NewServerStartCmd(ctx context.Context, version string) *cobra.Command { 33 | opts := &ServerStartOptions{} 34 | 35 | cmd := &cobra.Command{ 36 | Use: "server", 37 | Short: "Start a server", 38 | RunE: func(cmd *cobra.Command, args []string) error { 39 | return RunServerStart(ctx, opts, version) 40 | }, 41 | } 42 | cmd.PersistentFlags().StringVarP(&opts.config, "config", "c", "./server/conf.yaml", "Config file") 43 | cmd.PersistentFlags().StringVarP(&opts.serviceName, "serviceName", "s", "chat", "defined a service name,option is login or chat") 44 | return cmd 45 | } 46 | 47 | // RunServerStart run http server 48 | func RunServerStart(ctx context.Context, opts *ServerStartOptions, version string) error { 49 | config, err := conf.Init(opts.config) 50 | if err != nil { 51 | return err 52 | } 53 | _ = logger.Init(logger.Settings{ 54 | Level: config.LogLevel, 55 | Filename: "./data/server.log", 56 | }) 57 | 58 | var groupService service.Group 59 | var messageService service.Message 60 | if strings.TrimSpace(config.RoyalURL) != "" { 61 | groupService = service.NewGroupService(config.RoyalURL) 62 | messageService = service.NewMessageService(config.RoyalURL) 63 | } else { 64 | srvRecord := &resty.SRVRecord{ 65 | Domain: "consul", 66 | Service: wire.SNService, 67 | } 68 | groupService = service.NewGroupServiceWithSRV("http", srvRecord) 69 | messageService = service.NewMessageServiceWithSRV("http", 
srvRecord) 70 | } 71 | 72 | r := kim.NewRouter() 73 | r.Use(middleware.Recover()) 74 | 75 | // login 76 | loginHandler := handler.NewLoginHandler() 77 | r.Handle(wire.CommandLoginSignIn, loginHandler.DoSysLogin) 78 | r.Handle(wire.CommandLoginSignOut, loginHandler.DoSysLogout) 79 | // talk 80 | chatHandler := handler.NewChatHandler(messageService, groupService) 81 | r.Handle(wire.CommandChatUserTalk, chatHandler.DoUserTalk) 82 | r.Handle(wire.CommandChatGroupTalk, chatHandler.DoGroupTalk) 83 | r.Handle(wire.CommandChatTalkAck, chatHandler.DoTalkAck) 84 | // group 85 | groupHandler := handler.NewGroupHandler(groupService) 86 | r.Handle(wire.CommandGroupCreate, groupHandler.DoCreate) 87 | r.Handle(wire.CommandGroupJoin, groupHandler.DoJoin) 88 | r.Handle(wire.CommandGroupQuit, groupHandler.DoQuit) 89 | r.Handle(wire.CommandGroupDetail, groupHandler.DoDetail) 90 | 91 | // offline 92 | offlineHandler := handler.NewOfflineHandler(messageService) 93 | r.Handle(wire.CommandOfflineIndex, offlineHandler.DoSyncIndex) 94 | r.Handle(wire.CommandOfflineContent, offlineHandler.DoSyncContent) 95 | 96 | rdb, err := conf.InitRedis(config.RedisAddrs, "") 97 | if err != nil { 98 | return err 99 | } 100 | cache := storage.NewRedisStorage(rdb) 101 | servhandler := serv.NewServHandler(r, cache) 102 | 103 | meta := make(map[string]string) 104 | meta[consul.KeyHealthURL] = fmt.Sprintf("http://%s:%d/health", config.PublicAddress, config.MonitorPort) 105 | meta["zone"] = config.Zone 106 | 107 | service := &naming.DefaultService{ 108 | Id: config.ServiceID, 109 | Name: opts.serviceName, 110 | Address: config.PublicAddress, 111 | Port: config.PublicPort, 112 | Protocol: string(wire.ProtocolTCP), 113 | Tags: config.Tags, 114 | Meta: meta, 115 | } 116 | srvOpts := []kim.ServerOption{ 117 | kim.WithConnectionGPool(config.ConnectionGPool), kim.WithMessageGPool(config.MessageGPool), 118 | } 119 | srv := tcp.NewServer(config.Listen, service, srvOpts...) 
120 | 121 | srv.SetReadWait(kim.DefaultReadWait) 122 | srv.SetAcceptor(servhandler) 123 | srv.SetMessageListener(servhandler) 124 | srv.SetStateListener(servhandler) 125 | 126 | if err := container.Init(srv); err != nil { 127 | return err 128 | } 129 | container.EnableMonitor(fmt.Sprintf(":%d", config.MonitorPort)) 130 | 131 | ns, err := consul.NewNaming(config.ConsulURL) 132 | if err != nil { 133 | return err 134 | } 135 | container.SetServiceNaming(ns) 136 | 137 | return container.Start() 138 | } 139 | -------------------------------------------------------------------------------- /services/server/service/group.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/go-resty/resty/v2" 8 | "github.com/klintcheng/kim/logger" 9 | "github.com/klintcheng/kim/wire/rpc" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | type Group interface { 14 | Create(app string, req *rpc.CreateGroupReq) (*rpc.CreateGroupResp, error) 15 | Members(app string, req *rpc.GroupMembersReq) (*rpc.GroupMembersResp, error) 16 | Join(app string, req *rpc.JoinGroupReq) error 17 | Quit(app string, req *rpc.QuitGroupReq) error 18 | Detail(app string, req *rpc.GetGroupReq) (*rpc.GetGroupResp, error) 19 | } 20 | 21 | type GroupHttp struct { 22 | url string 23 | cli *resty.Client 24 | srv *resty.SRVRecord 25 | } 26 | 27 | func NewGroupService(url string) Group { 28 | cli := resty.New().SetRetryCount(3).SetTimeout(time.Second * 5) 29 | cli.SetHeader("Content-Type", "application/x-protobuf") 30 | cli.SetHeader("Accept", "application/x-protobuf") 31 | cli.SetScheme("http") 32 | return &GroupHttp{ 33 | url: url, 34 | cli: cli, 35 | } 36 | } 37 | 38 | func NewGroupServiceWithSRV(scheme string, srv *resty.SRVRecord) Group { 39 | cli := resty.New().SetRetryCount(3).SetTimeout(time.Second * 5) 40 | cli.SetHeader("Content-Type", "application/x-protobuf") 41 | cli.SetHeader("Accept", 
"application/x-protobuf") 42 | cli.SetScheme("http") 43 | 44 | return &GroupHttp{ 45 | url: "", 46 | cli: cli, 47 | srv: srv, 48 | } 49 | } 50 | 51 | func (g *GroupHttp) Create(app string, req *rpc.CreateGroupReq) (*rpc.CreateGroupResp, error) { 52 | path := fmt.Sprintf("%s/api/%s/group", g.url, app) 53 | 54 | body, _ := proto.Marshal(req) 55 | response, err := g.Req().SetBody(body).Post(path) 56 | if err != nil { 57 | return nil, err 58 | } 59 | if response.StatusCode() != 200 { 60 | return nil, fmt.Errorf("GroupHttp.Create response.StatusCode() = %d, want 200", response.StatusCode()) 61 | } 62 | var resp rpc.CreateGroupResp 63 | _ = proto.Unmarshal(response.Body(), &resp) 64 | logger.Debugf("GroupHttp.Create resp: %v", &resp) 65 | return &resp, nil 66 | } 67 | 68 | func (g *GroupHttp) Members(app string, req *rpc.GroupMembersReq) (*rpc.GroupMembersResp, error) { 69 | path := fmt.Sprintf("%s/api/%s/group/members/%s", g.url, app, req.GroupId) 70 | 71 | response, err := g.Req().Get(path) 72 | if err != nil { 73 | return nil, err 74 | } 75 | if response.StatusCode() != 200 { 76 | return nil, fmt.Errorf("GroupHttp.Members response.StatusCode() = %d, want 200", response.StatusCode()) 77 | } 78 | var resp rpc.GroupMembersResp 79 | _ = proto.Unmarshal(response.Body(), &resp) 80 | logger.Debugf("GroupHttp.Members resp: %v", &resp) 81 | return &resp, nil 82 | } 83 | 84 | func (g *GroupHttp) Join(app string, req *rpc.JoinGroupReq) error { 85 | path := fmt.Sprintf("%s/api/%s/group/member", g.url, app) 86 | body, _ := proto.Marshal(req) 87 | response, err := g.Req().SetBody(body).Post(path) 88 | if err != nil { 89 | return err 90 | } 91 | if response.StatusCode() != 200 { 92 | return fmt.Errorf("GroupHttp.Join response.StatusCode() = %d, want 200", response.StatusCode()) 93 | } 94 | return nil 95 | } 96 | 97 | func (g *GroupHttp) Quit(app string, req *rpc.QuitGroupReq) error { 98 | path := fmt.Sprintf("%s/api/%s/group/member", g.url, app) 99 | body, _ := proto.Marshal(req) 
100 | response, err := g.Req().SetBody(body).Delete(path) 101 | if err != nil { 102 | return err 103 | } 104 | if response.StatusCode() != 200 { 105 | return fmt.Errorf("GroupHttp.Quit response.StatusCode() = %d, want 200", response.StatusCode()) 106 | } 107 | return nil 108 | } 109 | 110 | func (g *GroupHttp) Detail(app string, req *rpc.GetGroupReq) (*rpc.GetGroupResp, error) { 111 | path := fmt.Sprintf("%s/api/%s/group/%s", g.url, app, req.GroupId) 112 | response, err := g.Req().Get(path) 113 | if err != nil { 114 | return nil, err 115 | } 116 | if response.StatusCode() != 200 { 117 | return nil, fmt.Errorf("GroupHttp.Detail response.StatusCode() = %d, want 200", response.StatusCode()) 118 | } 119 | var resp rpc.GetGroupResp 120 | _ = proto.Unmarshal(response.Body(), &resp) 121 | logger.Debugf("GroupHttp.Detail resp: %v", &resp) 122 | return &resp, nil 123 | } 124 | 125 | func (g *GroupHttp) Req() *resty.Request { 126 | if g.srv == nil { 127 | return g.cli.R() 128 | } 129 | return g.cli.R().SetSRV(g.srv) 130 | } 131 | -------------------------------------------------------------------------------- /services/server/service/group_test.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/klintcheng/kim/wire/rpc" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | const app = "kim_t" 11 | 12 | var groupService = NewGroupService("http://localhost:8080") 13 | 14 | func TestGroupService(t *testing.T) { 15 | 16 | resp, err := groupService.Create(app, &rpc.CreateGroupReq{ 17 | Name: "test", 18 | Owner: "test1", 19 | Members: []string{"test1", "test2"}, 20 | }) 21 | assert.Nil(t, err) 22 | assert.NotEmpty(t, resp.GroupId) 23 | t.Log(resp.GroupId) 24 | 25 | mresp, err := groupService.Members(app, &rpc.GroupMembersReq{ 26 | GroupId: resp.GroupId, 27 | }) 28 | assert.Nil(t, err) 29 | 30 | assert.Equal(t, 2, len(mresp.Users)) 31 | assert.Equal(t, "test1", 
mresp.Users[0].Account) 32 | assert.Equal(t, "test2", mresp.Users[1].Account) 33 | 34 | err = groupService.Join(app, &rpc.JoinGroupReq{ 35 | Account: "test3", 36 | GroupId: resp.GroupId, 37 | }) 38 | assert.Nil(t, err) 39 | 40 | mresp, err = groupService.Members(app, &rpc.GroupMembersReq{ 41 | GroupId: resp.GroupId, 42 | }) 43 | assert.Nil(t, err) 44 | 45 | assert.Equal(t, 3, len(mresp.Users)) 46 | assert.Equal(t, "test3", mresp.Users[2].Account) 47 | assert.Equal(t, "test2", mresp.Users[1].Account) 48 | 49 | err = groupService.Quit(app, &rpc.QuitGroupReq{ 50 | Account: "test2", 51 | GroupId: resp.GroupId, 52 | }) 53 | assert.Nil(t, err) 54 | 55 | mresp, err = groupService.Members(app, &rpc.GroupMembersReq{ 56 | GroupId: resp.GroupId, 57 | }) 58 | assert.Nil(t, err) 59 | 60 | assert.Equal(t, 2, len(mresp.Users)) 61 | assert.Equal(t, "test1", mresp.Users[0].Account) 62 | assert.Equal(t, "test3", mresp.Users[1].Account) 63 | } 64 | -------------------------------------------------------------------------------- /services/server/service/message.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/go-resty/resty/v2" 8 | "google.golang.org/protobuf/proto" 9 | 10 | "github.com/klintcheng/kim/logger" 11 | "github.com/klintcheng/kim/wire/rpc" 12 | ) 13 | 14 | type Message interface { 15 | InsertUser(app string, req *rpc.InsertMessageReq) (*rpc.InsertMessageResp, error) 16 | InsertGroup(app string, req *rpc.InsertMessageReq) (*rpc.InsertMessageResp, error) 17 | SetAck(app string, req *rpc.AckMessageReq) error 18 | GetMessageIndex(app string, req *rpc.GetOfflineMessageIndexReq) (*rpc.GetOfflineMessageIndexResp, error) 19 | GetMessageContent(app string, req *rpc.GetOfflineMessageContentReq) (*rpc.GetOfflineMessageContentResp, error) 20 | } 21 | 22 | type MessageHttp struct { 23 | url string 24 | cli *resty.Client 25 | srv *resty.SRVRecord 26 | } 27 | 28 | func 
NewMessageService(url string) Message { 29 | cli := resty.New().SetRetryCount(3).SetTimeout(time.Second * 5) 30 | cli.SetHeader("Content-Type", "application/x-protobuf") 31 | cli.SetHeader("Accept", "application/x-protobuf") 32 | return &MessageHttp{ 33 | url: url, 34 | cli: cli, 35 | } 36 | } 37 | 38 | func NewMessageServiceWithSRV(scheme string, srv *resty.SRVRecord) Message { 39 | cli := resty.New().SetRetryCount(3).SetTimeout(time.Second * 5) 40 | cli.SetHeader("Content-Type", "application/x-protobuf") 41 | cli.SetHeader("Accept", "application/x-protobuf") 42 | cli.SetScheme("http") 43 | 44 | return &MessageHttp{ 45 | url: "", 46 | cli: cli, 47 | srv: srv, 48 | } 49 | } 50 | 51 | func (m *MessageHttp) InsertUser(app string, req *rpc.InsertMessageReq) (*rpc.InsertMessageResp, error) { 52 | path := fmt.Sprintf("%s/api/%s/message/user", m.url, app) 53 | t1 := time.Now() 54 | 55 | body, _ := proto.Marshal(req) 56 | response, err := m.Req().SetBody(body).Post(path) 57 | if err != nil { 58 | return nil, err 59 | } 60 | if response.StatusCode() != 200 { 61 | return nil, fmt.Errorf("MessageHttp.InsertUser response.StatusCode() = %d, want 200", response.StatusCode()) 62 | } 63 | var resp rpc.InsertMessageResp 64 | _ = proto.Unmarshal(response.Body(), &resp) 65 | logger.Debugf("MessageHttp.InsertUser cost %v resp: %v", time.Since(t1), &resp) 66 | return &resp, nil 67 | } 68 | 69 | func (m *MessageHttp) InsertGroup(app string, req *rpc.InsertMessageReq) (*rpc.InsertMessageResp, error) { 70 | path := fmt.Sprintf("%s/api/%s/message/group", m.url, app) 71 | t1 := time.Now() 72 | body, _ := proto.Marshal(req) 73 | response, err := m.Req().SetBody(body).Post(path) 74 | if err != nil { 75 | return nil, err 76 | } 77 | if response.StatusCode() != 200 { 78 | return nil, fmt.Errorf("MessageHttp.InsertGroup response.StatusCode() = %d, want 200", response.StatusCode()) 79 | } 80 | var resp rpc.InsertMessageResp 81 | _ = proto.Unmarshal(response.Body(), &resp) 82 | 
logger.Debugf("MessageHttp.InsertGroup cost %v resp: %v", time.Since(t1), &resp) 83 | return &resp, nil 84 | } 85 | 86 | func (m *MessageHttp) SetAck(app string, req *rpc.AckMessageReq) error { 87 | path := fmt.Sprintf("%s/api/%s/message/ack", m.url, app) 88 | body, _ := proto.Marshal(req) 89 | response, err := m.Req().SetBody(body).Post(path) 90 | if err != nil { 91 | return err 92 | } 93 | if response.StatusCode() != 200 { 94 | return fmt.Errorf("MessageHttp.SetAck response.StatusCode() = %d, want 200", response.StatusCode()) 95 | } 96 | return nil 97 | } 98 | 99 | func (m *MessageHttp) GetMessageIndex(app string, req *rpc.GetOfflineMessageIndexReq) (*rpc.GetOfflineMessageIndexResp, error) { 100 | path := fmt.Sprintf("%s/api/%s/offline/index", m.url, app) 101 | body, _ := proto.Marshal(req) 102 | 103 | response, err := m.Req().SetBody(body).Post(path) 104 | if err != nil { 105 | return nil, err 106 | } 107 | if response.StatusCode() != 200 { 108 | return nil, fmt.Errorf("MessageHttp.GetMessageIndex response.StatusCode() = %d, want 200", response.StatusCode()) 109 | } 110 | var resp rpc.GetOfflineMessageIndexResp 111 | _ = proto.Unmarshal(response.Body(), &resp) 112 | return &resp, nil 113 | } 114 | 115 | func (m *MessageHttp) GetMessageContent(app string, req *rpc.GetOfflineMessageContentReq) (*rpc.GetOfflineMessageContentResp, error) { 116 | path := fmt.Sprintf("%s/api/%s/offline/content", m.url, app) 117 | body, _ := proto.Marshal(req) 118 | response, err := m.Req().SetBody(body).Post(path) 119 | if err != nil { 120 | return nil, err 121 | } 122 | if response.StatusCode() != 200 { 123 | return nil, fmt.Errorf("MessageHttp.GetMessageContent response.StatusCode() = %d, want 200", response.StatusCode()) 124 | } 125 | var resp rpc.GetOfflineMessageContentResp 126 | _ = proto.Unmarshal(response.Body(), &resp) 127 | return &resp, nil 128 | } 129 | 130 | func (m *MessageHttp) Req() *resty.Request { 131 | if m.srv == nil { 132 | return m.cli.R() 133 | } 134 | return 
m.cli.R().SetSRV(m.srv) 135 | } 136 | -------------------------------------------------------------------------------- /services/server/service/message_test.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/go-resty/resty/v2" 9 | "github.com/klintcheng/kim/wire/rpc" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | var messageService = NewMessageServiceWithSRV("http", &resty.SRVRecord{ 14 | Domain: "consul", 15 | Service: "royal", 16 | }) 17 | 18 | func Test_Message(t *testing.T) { 19 | 20 | m := rpc.Message{ 21 | Type: 1, 22 | Body: "hello world", 23 | } 24 | dest := fmt.Sprintf("u%d", time.Now().Unix()) 25 | _, err := messageService.InsertUser(app, &rpc.InsertMessageReq{ 26 | Sender: "test1", 27 | Dest: dest, 28 | SendTime: time.Now().UnixNano(), 29 | Message: &m, 30 | }) 31 | assert.Nil(t, err) 32 | 33 | resp, err := messageService.GetMessageIndex(app, &rpc.GetOfflineMessageIndexReq{ 34 | Account: dest, 35 | }) 36 | assert.Nil(t, err) 37 | assert.Equal(t, 1, len(resp.List)) 38 | 39 | index := resp.List[0] 40 | assert.Equal(t, "test1", index.AccountB) 41 | 42 | resp2, err := messageService.GetMessageContent(app, &rpc.GetOfflineMessageContentReq{ 43 | MessageIds: []int64{index.MessageId}, 44 | }) 45 | assert.Nil(t, err) 46 | 47 | assert.Equal(t, 1, len(resp2.List)) 48 | content := resp2.List[0] 49 | assert.Equal(t, m.Body, content.Body) 50 | assert.Equal(t, m.Type, content.Type) 51 | assert.Equal(t, index.MessageId, content.Id) 52 | 53 | //again 54 | resp, err = messageService.GetMessageIndex(app, &rpc.GetOfflineMessageIndexReq{ 55 | Account: dest, 56 | MessageId: index.MessageId, 57 | }) 58 | assert.Nil(t, err) 59 | assert.Equal(t, 0, len(resp.List)) 60 | 61 | resp, err = messageService.GetMessageIndex(app, &rpc.GetOfflineMessageIndexReq{ 62 | Account: dest, 63 | }) 64 | assert.Nil(t, err) 65 | assert.Equal(t, 0, len(resp.List)) 
66 | } 67 | 68 | func Test_Group_Message(t *testing.T) { 69 | resp, err := groupService.Create(app, &rpc.CreateGroupReq{ 70 | Name: "test", 71 | Owner: "test1", 72 | Members: []string{"test1", "test2", "test3"}, 73 | }) 74 | assert.Nil(t, err) 75 | assert.NotEmpty(t, resp.GroupId) 76 | 77 | m := rpc.Message{ 78 | Type: 1, 79 | Body: "hello world", 80 | } 81 | dest := resp.GroupId 82 | _, err = messageService.InsertGroup(app, &rpc.InsertMessageReq{ 83 | Sender: "test1", 84 | Dest: dest, 85 | SendTime: time.Now().UnixNano(), 86 | Message: &m, 87 | }) 88 | assert.Nil(t, err) 89 | 90 | indexresp, err := messageService.GetMessageIndex(app, &rpc.GetOfflineMessageIndexReq{ 91 | Account: "test1", 92 | }) 93 | assert.Nil(t, err) 94 | assert.Equal(t, 1, len(indexresp.List)) 95 | assert.Equal(t, int32(1), indexresp.List[0].Direction) 96 | 97 | indexresp2, err := messageService.GetMessageIndex(app, &rpc.GetOfflineMessageIndexReq{ 98 | Account: "test2", 99 | }) 100 | assert.Nil(t, err) 101 | assert.Equal(t, 1, len(indexresp2.List)) 102 | assert.Equal(t, int32(0), indexresp2.List[0].Direction) 103 | } 104 | -------------------------------------------------------------------------------- /services/service/conf.yaml: -------------------------------------------------------------------------------- 1 | ServiceID: royal01 2 | Listen: ":8080" 3 | PublicPort: 8080 4 | Tags: 5 | - royal 6 | ConsulURL: localhost:8500 7 | RedisAddrs: localhost:6379 8 | BaseDb: root:123456@tcp(127.0.0.1:3306)/kim_base?charset=utf8mb4&parseTime=True&loc=Local 9 | MessageDb: root:123456@tcp(127.0.0.1:3306)/kim_message?charset=utf8mb4&parseTime=True&loc=Local -------------------------------------------------------------------------------- /services/service/conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log" 7 | "os" 8 | "strconv" 9 | "strings" 10 | "time" 11 | 12 | 
"github.com/go-redis/redis/v7"
	"github.com/kataras/iris/v12/middleware/accesslog"
	"github.com/kelseyhightower/envconfig"
	"github.com/klintcheng/kim"
	"github.com/klintcheng/kim/logger"
	"github.com/sirupsen/logrus"
	"github.com/spf13/viper"
)

// Config is the configuration of this service. Fields are filled from the
// config file by viper, then overridden from the environment by envconfig with
// the "kim" prefix. The `default` tags are envconfig defaults.
type Config struct {
	ServiceID string
	// NodeID is derived by Init from the local IPv4 address only when
	// ServiceID is empty (presumably used as the id-generator node — confirm).
	NodeID        int64
	Listen        string `default:":8080"`
	PublicAddress string
	PublicPort    int `default:"8080"`
	Tags          []string
	ConsulURL     string
	RedisAddrs    string
	Driver        string `default:"mysql"`
	BaseDb        string
	MessageDb     string
	LogLevel      string `default:"INFO"`
}

// String renders the config as JSON; a marshal error yields an empty string.
func (c Config) String() string {
	bts, _ := json.Marshal(c)
	return string(bts)
}

// Init loads the configuration from file. If the file cannot be read, the
// error is only logged and the config starts empty. Environment variables are
// applied afterwards in either case. A missing ServiceID is derived from the
// local IP ("royal_" + IP without dots), together with NodeID (last IPv4
// octet). A missing PublicAddress defaults to the local IP.
func Init(file string) (*Config, error) {
	viper.SetConfigFile(file)
	// NOTE(review): with an explicit SetConfigFile, viper does not search the
	// config paths, so these two calls appear to have no effect — confirm.
	viper.AddConfigPath(".")
	viper.AddConfigPath("/etc/conf")

	var config Config
	if err := viper.ReadInConfig(); err != nil {
		logger.Warn(err)
	} else {
		if err := viper.Unmarshal(&config); err != nil {
			return nil, err
		}
	}
	err := envconfig.Process("kim", &config)
	if err != nil {
		return nil, err
	}
	if config.ServiceID == "" {
		localIP := kim.GetLocalIP()
		config.ServiceID = fmt.Sprintf("royal_%s", strings.ReplaceAll(localIP, ".", ""))
		// For a dotted IPv4 address the last octet becomes NodeID; otherwise NodeID is left unchanged.
		arr := strings.Split(localIP, ".")
		if len(arr) == 4 {
			suffix, _ := strconv.Atoi(arr[3])
			config.NodeID = int64(suffix)
		}
	}
	if config.PublicAddress == "" {
		config.PublicAddress = kim.GetLocalIP()
	}
	logger.Info(config)
	return &config, nil
}

// InitRedis connects to the single redis server at addr with 5s
// dial/read/write timeouts and verifies it with PING. A failed PING is logged
// and returned as the error.
func InitRedis(addr string, pass string) (*redis.Client, error) {
	redisdb := redis.NewClient(&redis.Options{
		Addr:         addr,
		Password:     pass,
		DialTimeout:  time.Second * 5,
		ReadTimeout:  time.Second * 5,
		WriteTimeout: time.Second * 5,
	})

	_, err := redisdb.Ping().Result()
	if err != nil {
		log.Println(err)
		return nil, err
	}
	return redisdb, nil
}

// InitFailoverRedis init redis with sentinels: it creates a failover client
// for masterName discovered through sentinelAddrs. A failed PING is only
// logged as a warning; the client is still returned with a nil error.
func InitFailoverRedis(masterName string, sentinelAddrs []string, password string, timeout time.Duration) (*redis.Client, error) {
	redisdb := redis.NewFailoverClient(&redis.FailoverOptions{
		MasterName:    masterName,
		SentinelAddrs: sentinelAddrs,
		Password:      password,
		DialTimeout:   time.Second * 5,
		ReadTimeout:   timeout,
		WriteTimeout:  timeout,
	})

	_, err := redisdb.Ping().Result()
	if err != nil {
		logrus.Warn(err)
	}
	return redisdb, nil
}

// MakeAccessLog builds the iris access-log middleware. It writes synchronously
// to ./access.log and stdout, with '|'-delimited lines that include request
// bodies but not response bodies.
func MakeAccessLog() *accesslog.AccessLog {
	// Initialize a new access log middleware.
	ac := accesslog.File("./access.log")
	// Remove this line to disable logging to console:
	ac.AddOutput(os.Stdout)

	// The default configuration:
	ac.Delim = '|'
	ac.TimeFormat = "2006-01-02 15:04:05"
	ac.Async = false
	ac.IP = true
	ac.BytesReceivedBody = true
	ac.BytesSentBody = true
	ac.BytesReceived = false
	ac.BytesSent = false
	ac.BodyMinify = true
	ac.RequestBody = true
	ac.ResponseBody = false
	ac.KeepMultiLineError = true
	ac.PanicLog = accesslog.LogHandler

	// Default line format if formatter is missing:
	// Time|Latency|Code|Method|Path|IP|Path Params Query Fields|Bytes Received|Bytes Sent|Request|Response|
	//
	// Set Custom Formatter:
	// ac.SetFormatter(&accesslog.JSON{
	//   Indent:    "  ",
	//   HumanTime: true,
	// })
	// ac.SetFormatter(&accesslog.CSV{})
	// ac.SetFormatter(&accesslog.Template{Text: "{{.Code}}"})

	return ac
}

--------------------------------------------------------------------------------
/services/service/database/id_generator.go:
-------------------------------------------------------------------------------- 1 | package database 2 | 3 | import "github.com/bwmarrin/snowflake" 4 | 5 | // IDGenerator generate unique id 6 | type IDGenerator struct { 7 | node *snowflake.Node 8 | } 9 | 10 | // NewIDGenerator NewIDGenerator 11 | func NewIDGenerator(nodeID int64) (*IDGenerator, error) { 12 | node, err := snowflake.NewNode(nodeID) 13 | if err != nil { 14 | return nil, err 15 | } 16 | return &IDGenerator{node: node}, nil 17 | } 18 | 19 | // Next Generate a new id 20 | func (g *IDGenerator) Next() snowflake.ID { 21 | return g.node.Generate() 22 | } 23 | 24 | // ParseBase36 ParseBase36 25 | func (g *IDGenerator) ParseBase36(id string) (snowflake.ID, error) { 26 | return snowflake.ParseBase36(id) 27 | } 28 | 29 | func (g *IDGenerator) Parse(id int64) snowflake.ID { 30 | return snowflake.ParseInt64(id) 31 | } 32 | -------------------------------------------------------------------------------- /services/service/database/model.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // create database kim_base default character set utf8mb4 collate utf8mb4_unicode_ci; 8 | // create database kim_message default character set utf8mb4 collate utf8mb4_unicode_ci; 9 | 10 | type Model struct { 11 | ID int64 `gorm:"primarykey"` 12 | CreatedAt time.Time 13 | UpdatedAt time.Time 14 | } 15 | 16 | type MessageIndex struct { 17 | ID int64 `gorm:"primarykey"` 18 | AccountA string `gorm:"index;size:60;not null;comment:队列唯一标识"` 19 | AccountB string `gorm:"size:60;not null;comment:另一方"` 20 | Direction byte `gorm:"default:0;not null;comment:1表示AccountA为发送者"` 21 | MessageID int64 `gorm:"not null;comment:关联消息内容表中的ID"` 22 | Group string `gorm:"size:30;comment:群ID,单聊情况为空"` 23 | SendTime int64 `gorm:"index;not null;comment:消息发送时间"` 24 | } 25 | 26 | type MessageContent struct { 27 | ID int64 `gorm:"primarykey"` 28 | Type byte 
`gorm:"default:0"` 29 | Body string `gorm:"size:5000;not null"` 30 | Extra string `gorm:"size:500"` 31 | SendTime int64 `gorm:"index"` 32 | } 33 | 34 | type User struct { 35 | Model 36 | App string `gorm:"size:30"` 37 | Account string `gorm:"uniqueIndex;size:60"` 38 | Password string `gorm:"size:30"` 39 | Avatar string `gorm:"size:200"` 40 | Nickname string `gorm:"size:20"` 41 | } 42 | 43 | type Group struct { 44 | Model 45 | Group string `gorm:"uniqueIndex;size:30"` 46 | App string `gorm:"size:30"` 47 | Name string `gorm:"size:50"` 48 | Owner string `gorm:"size:60"` 49 | Avatar string `gorm:"size:200"` 50 | Introduction string `gorm:"size:300"` 51 | } 52 | 53 | // GroupMember GroupMember 54 | type GroupMember struct { 55 | Model 56 | Account string `gorm:"uniqueIndex:uni_gp_acc;size:60"` 57 | Group string `gorm:"uniqueIndex:uni_gp_acc;index;size:30"` 58 | Alias string `gorm:"size:30"` 59 | } 60 | -------------------------------------------------------------------------------- /services/service/database/mysql.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | 5 | // just init 6 | 7 | "log" 8 | "os" 9 | "strings" 10 | "time" 11 | 12 | "gorm.io/driver/mysql" 13 | // "gorm.io/driver/sqlite" 14 | "gorm.io/gorm" 15 | "gorm.io/gorm/logger" 16 | "gorm.io/gorm/schema" 17 | ) 18 | 19 | // InitMysqlDb init mysql database 20 | func InitDb(driver string, dsn string) (*gorm.DB, error) { 21 | // dsn := "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local" 22 | 23 | defaultLogger := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ 24 | SlowThreshold: 200 * time.Millisecond, 25 | LogLevel: logger.Warn, 26 | Colorful: true, 27 | }) 28 | 29 | var dialector gorm.Dialector 30 | if driver == "mysql" { 31 | dialector = mysql.Open(dsn) 32 | } 33 | // else if driver == "sqlite" { 34 | // dialector = sqlite.Open(dsn) 35 | // } 36 | 37 | db, err := gorm.Open(dialector, 
&gorm.Config{ 38 | Logger: defaultLogger, 39 | NamingStrategy: schema.NamingStrategy{ 40 | TablePrefix: "t_", // table name prefix, table for `User` would be `t_users` 41 | SingularTable: true, // use singular table name, table for `User` would be `user` with this option enabled 42 | NameReplacer: strings.NewReplacer("CID", "Cid"), // use name replacer to change struct/field name before convert it to db name 43 | }}) 44 | 45 | return db, err 46 | } 47 | -------------------------------------------------------------------------------- /services/service/database/mysql_test.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "gorm.io/gorm" 9 | ) 10 | 11 | var db *gorm.DB 12 | var idgen *IDGenerator 13 | 14 | func init() { 15 | db, _ = InitDb("sqlite", "msg.db") 16 | 17 | _ = db.AutoMigrate(&MessageIndex{}) 18 | _ = db.AutoMigrate(&MessageContent{}) 19 | 20 | idgen, _ = NewIDGenerator(1) 21 | } 22 | 23 | func Benchmark_insert(b *testing.B) { 24 | sendTime := time.Now().UnixNano() 25 | b.ResetTimer() 26 | b.SetBytes(1024) 27 | b.ReportAllocs() 28 | b.RunParallel(func(pb *testing.PB) { 29 | for pb.Next() { 30 | idxs := make([]MessageIndex, 100) 31 | cid := idgen.Next().Int64() 32 | for i := 0; i < len(idxs); i++ { 33 | idxs[i] = MessageIndex{ 34 | ID: idgen.Next().Int64(), 35 | AccountA: fmt.Sprintf("test_%d", cid), 36 | AccountB: fmt.Sprintf("test_%d", i), 37 | SendTime: sendTime, 38 | MessageID: cid, 39 | } 40 | } 41 | db.Create(&idxs) 42 | } 43 | }) 44 | } 45 | -------------------------------------------------------------------------------- /services/service/database/redis.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "time" 7 | 8 | "github.com/go-redis/redis/v7" 9 | "github.com/sirupsen/logrus" 10 | ) 11 | 12 | // KeyMessageAckIndex return a redis key of 
the read index 13 | func KeyMessageAckIndex(account string) string { 14 | return fmt.Sprintf("chat:ack:%s", account) 15 | } 16 | 17 | // InitRedis return a redis instance 18 | func InitRedis(addr string, pass string) (*redis.Client, error) { 19 | redisdb := redis.NewClient(&redis.Options{ 20 | Addr: addr, 21 | Password: pass, 22 | DialTimeout: time.Second * 5, 23 | ReadTimeout: time.Second * 5, 24 | WriteTimeout: time.Second * 5, 25 | }) 26 | 27 | _, err := redisdb.Ping().Result() 28 | if err != nil { 29 | log.Println(err) 30 | return nil, err 31 | } 32 | return redisdb, nil 33 | } 34 | 35 | // InitFailoverRedis init redis with sentinels 36 | func InitFailoverRedis(masterName string, sentinelAddrs []string, password string, timeout time.Duration) (*redis.Client, error) { 37 | redisdb := redis.NewFailoverClient(&redis.FailoverOptions{ 38 | MasterName: masterName, 39 | SentinelAddrs: sentinelAddrs, 40 | Password: password, 41 | DialTimeout: time.Second * 5, 42 | ReadTimeout: timeout, 43 | WriteTimeout: timeout, 44 | }) 45 | 46 | _, err := redisdb.Ping().Result() 47 | if err != nil { 48 | logrus.Warn(err) 49 | } 50 | return redisdb, nil 51 | } 52 | -------------------------------------------------------------------------------- /services/service/handler/group_handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/bwmarrin/snowflake" 7 | "github.com/kataras/iris/v12" 8 | "github.com/klintcheng/kim/services/service/database" 9 | "github.com/klintcheng/kim/wire/rpc" 10 | "gorm.io/gorm" 11 | ) 12 | 13 | // var log = logger.WithField("module", "service.handler") 14 | 15 | func (h *ServiceHandler) GroupCreate(c iris.Context) { 16 | app := c.Params().Get("app") 17 | var req rpc.CreateGroupReq 18 | if err := c.ReadBody(&req); err != nil { 19 | c.StopWithError(iris.StatusBadRequest, err) 20 | return 21 | } 22 | req.App = app 23 | groupId, err := h.groupCreate(&req) 24 | 
if err != nil { 25 | c.StopWithError(iris.StatusInternalServerError, err) 26 | return 27 | } 28 | _, _ = c.Negotiate(&rpc.CreateGroupResp{ 29 | GroupId: groupId.Base36(), 30 | }) 31 | } 32 | 33 | func (h *ServiceHandler) groupCreate(req *rpc.CreateGroupReq) (snowflake.ID, error) { 34 | groupId := h.Idgen.Next() 35 | g := &database.Group{ 36 | Model: database.Model{ 37 | ID: groupId.Int64(), 38 | }, 39 | App: req.App, 40 | Group: groupId.Base36(), 41 | Name: req.Name, 42 | Avatar: req.Avatar, 43 | Owner: req.Owner, 44 | Introduction: req.Introduction, 45 | } 46 | members := make([]database.GroupMember, len(req.Members)) 47 | for i, user := range req.Members { 48 | members[i] = database.GroupMember{ 49 | Model: database.Model{ 50 | ID: h.Idgen.Next().Int64(), 51 | }, 52 | Account: user, 53 | Group: groupId.Base36(), 54 | } 55 | } 56 | 57 | err := h.BaseDb.Transaction(func(tx *gorm.DB) error { 58 | if err := tx.Create(g).Error; err != nil { 59 | // return anywill rollback 60 | return err 61 | } 62 | if err := tx.Create(&members).Error; err != nil { 63 | return err 64 | } 65 | // return nil will commit the whole transaction 66 | return nil 67 | }) 68 | if err != nil { 69 | return 0, err 70 | } 71 | return groupId, nil 72 | } 73 | 74 | func (h *ServiceHandler) GroupJoin(c iris.Context) { 75 | // app := c.Param("app") 76 | var req rpc.JoinGroupReq 77 | if err := c.ReadBody(&req); err != nil { 78 | c.StopWithError(iris.StatusBadRequest, err) 79 | return 80 | } 81 | gm := &database.GroupMember{ 82 | Model: database.Model{ 83 | ID: h.Idgen.Next().Int64(), 84 | }, 85 | Account: req.Account, 86 | Group: req.GroupId, 87 | } 88 | err := h.BaseDb.Create(gm).Error 89 | if err != nil { 90 | c.StopWithError(iris.StatusInternalServerError, err) 91 | return 92 | } 93 | } 94 | 95 | func (h *ServiceHandler) GroupQuit(c iris.Context) { 96 | // app := c.Param("app") 97 | var req rpc.QuitGroupReq 98 | if err := c.ReadBody(&req); err != nil { 99 | c.StopWithError(iris.StatusBadRequest, 
err) 100 | return 101 | } 102 | gm := &database.GroupMember{ 103 | Account: req.Account, 104 | Group: req.GroupId, 105 | } 106 | err := h.BaseDb.Delete(&database.GroupMember{}, gm).Error 107 | if err != nil { 108 | c.StopWithError(iris.StatusInternalServerError, err) 109 | return 110 | } 111 | } 112 | 113 | func (h *ServiceHandler) GroupMembers(c iris.Context) { 114 | group := c.Params().Get("id") 115 | if group == "" { 116 | c.StopWithError(iris.StatusBadRequest, errors.New("group is null")) 117 | return 118 | } 119 | var members []database.GroupMember 120 | err := h.BaseDb.Order("Updated_At asc").Find(&members, database.GroupMember{Group: group}).Error 121 | if err != nil { 122 | c.StopWithError(iris.StatusInternalServerError, err) 123 | return 124 | } 125 | var users = make([]*rpc.Member, len(members)) 126 | for i, m := range members { 127 | users[i] = &rpc.Member{ 128 | Account: m.Account, 129 | Alias: m.Alias, 130 | JoinTime: m.CreatedAt.Unix(), 131 | } 132 | } 133 | _, _ = c.Negotiate(&rpc.GroupMembersResp{ 134 | Users: users, 135 | }) 136 | } 137 | 138 | func (h *ServiceHandler) GroupGet(c iris.Context) { 139 | groupId := c.Params().Get("id") 140 | if groupId == "" { 141 | c.StopWithError(iris.StatusBadRequest, errors.New("group is null")) 142 | return 143 | } 144 | id, err := h.Idgen.ParseBase36(groupId) 145 | if err != nil { 146 | c.StopWithError(iris.StatusBadRequest, errors.New("group is invaild:"+groupId)) 147 | return 148 | } 149 | var group database.Group 150 | err = h.BaseDb.First(&group, id.Int64()).Error 151 | if err != nil { 152 | c.StopWithError(iris.StatusInternalServerError, err) 153 | return 154 | } 155 | _, _ = c.Negotiate(&rpc.GetGroupResp{ 156 | Id: groupId, 157 | Name: group.Name, 158 | Avatar: group.Avatar, 159 | Introduction: group.Introduction, 160 | Owner: group.Owner, 161 | CreatedAt: group.CreatedAt.Unix(), 162 | }) 163 | } 164 | -------------------------------------------------------------------------------- 
/services/service/handler/message_handler_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/klintcheng/kim/services/service/database" 9 | "github.com/klintcheng/kim/wire/rpc" 10 | "github.com/segmentio/ksuid" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var handler ServiceHandler 15 | 16 | func init() { 17 | baseDb, _ := database.InitDb("mysql", "root:123456@tcp(127.0.0.1:3306)/kim_base?charset=utf8mb4&parseTime=True&loc=Local") 18 | messageDb, _ := database.InitDb("mysql", "root:123456@tcp(127.0.0.1:3306)/kim_message?charset=utf8mb4&parseTime=True&loc=Local") 19 | idgen, _ := database.NewIDGenerator(1) 20 | handler = ServiceHandler{ 21 | MessageDb: messageDb, 22 | BaseDb: baseDb, 23 | Idgen: idgen, 24 | } 25 | } 26 | 27 | func Benchmark_InsertUserMessage(b *testing.B) { 28 | 29 | b.ResetTimer() 30 | b.SetBytes(1024) 31 | b.ReportAllocs() 32 | b.RunParallel(func(pb *testing.PB) { 33 | for pb.Next() { 34 | _, _ = handler.insertUserMessage(&rpc.InsertMessageReq{ 35 | Sender: "test1", 36 | Dest: ksuid.New().String(), 37 | SendTime: time.Now().UnixNano(), 38 | Message: &rpc.Message{ 39 | Type: 1, 40 | Body: "hello", 41 | }, 42 | }) 43 | } 44 | }) 45 | } 46 | 47 | func Benchmark_InsertGroup10Message(b *testing.B) { 48 | memberCount := 10 49 | 50 | var members = make([]string, memberCount) 51 | for i := 0; i < memberCount; i++ { 52 | members[i] = fmt.Sprintf("test%d", i+1) 53 | } 54 | 55 | groupId, err := handler.groupCreate(&rpc.CreateGroupReq{ 56 | App: "kim_t", 57 | Name: "testg", 58 | Owner: "test1", 59 | Members: members, 60 | }) 61 | assert.Nil(b, err) 62 | 63 | b.ResetTimer() 64 | b.SetBytes(1024) 65 | b.ReportAllocs() 66 | b.RunParallel(func(pb *testing.PB) { 67 | for pb.Next() { 68 | _, _ = handler.insertGroupMessage(&rpc.InsertMessageReq{ 69 | Sender: "test1", 70 | Dest: groupId.Base36(), 71 | SendTime: 
time.Now().UnixNano(), 72 | Message: &rpc.Message{ 73 | Type: 1, 74 | Body: "hello", 75 | }, 76 | }) 77 | } 78 | }) 79 | } 80 | 81 | func Benchmark_InsertGroup50Message(b *testing.B) { 82 | memberCount := 50 83 | 84 | var members = make([]string, memberCount) 85 | for i := 0; i < memberCount; i++ { 86 | members[i] = fmt.Sprintf("test%d", i+1) 87 | } 88 | 89 | groupId, err := handler.groupCreate(&rpc.CreateGroupReq{ 90 | App: "kim_t", 91 | Name: "testg", 92 | Owner: "test1", 93 | Members: members, 94 | }) 95 | assert.Nil(b, err) 96 | 97 | b.ResetTimer() 98 | b.SetBytes(1024) 99 | b.ReportAllocs() 100 | b.RunParallel(func(pb *testing.PB) { 101 | for pb.Next() { 102 | _, _ = handler.insertGroupMessage(&rpc.InsertMessageReq{ 103 | Sender: "test1", 104 | Dest: groupId.Base36(), 105 | SendTime: time.Now().UnixNano(), 106 | Message: &rpc.Message{ 107 | Type: 1, 108 | Body: "hello", 109 | }, 110 | }) 111 | } 112 | }) 113 | } 114 | -------------------------------------------------------------------------------- /storage.go: -------------------------------------------------------------------------------- 1 | package kim 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/klintcheng/kim/wire/pkt" 7 | ) 8 | 9 | // ErrNil 10 | var ErrSessionNil = errors.New("err:session nil") 11 | 12 | // SessionStorage defined a session storage which provides based functions as save,delete,find a session 13 | type SessionStorage interface { 14 | // Add a session 15 | Add(session *pkt.Session) error 16 | // Delete a session 17 | Delete(account string, channelId string) error 18 | // Get session by channelId 19 | Get(channelId string) (*pkt.Session, error) 20 | // Get Locations by accounts 21 | GetLocations(account ...string) ([]*Location, error) 22 | // Get Location by account and device 23 | GetLocation(account string, device string) (*Location, error) 24 | } 25 | -------------------------------------------------------------------------------- /storage/redis_impl.go: 
-------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/go-redis/redis/v7" 8 | "github.com/klintcheng/kim" 9 | "github.com/klintcheng/kim/wire/pkt" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | const ( 14 | LocationExpired = time.Hour * 48 15 | ) 16 | 17 | type RedisStorage struct { 18 | cli *redis.Client 19 | } 20 | 21 | func NewRedisStorage(cli *redis.Client) kim.SessionStorage { 22 | return &RedisStorage{ 23 | cli: cli, 24 | } 25 | } 26 | 27 | func (r *RedisStorage) Add(session *pkt.Session) error { 28 | // save kim.Location 29 | loc := kim.Location{ 30 | ChannelId: session.ChannelId, 31 | GateId: session.GateId, 32 | } 33 | locKey := KeyLocation(session.Account, "") 34 | err := r.cli.Set(locKey, loc.Bytes(), LocationExpired).Err() 35 | if err != nil { 36 | return err 37 | } 38 | // save session 39 | snKey := KeySession(session.ChannelId) 40 | buf, _ := proto.Marshal(session) 41 | err = r.cli.Set(snKey, buf, LocationExpired).Err() 42 | if err != nil { 43 | return err 44 | } 45 | return nil 46 | } 47 | 48 | // Delete a session 49 | func (r *RedisStorage) Delete(account string, channelId string) error { 50 | locKey := KeyLocation(account, "") 51 | err := r.cli.Del(locKey).Err() 52 | if err != nil { 53 | return err 54 | } 55 | 56 | snKey := KeySession(channelId) 57 | err = r.cli.Del(snKey).Err() 58 | if err != nil { 59 | return err 60 | } 61 | return nil 62 | } 63 | 64 | // GetByID get session by sessionID 65 | func (r *RedisStorage) Get(channelId string) (*pkt.Session, error) { 66 | snKey := KeySession(channelId) 67 | bts, err := r.cli.Get(snKey).Bytes() 68 | if err != nil { 69 | if err == redis.Nil { 70 | return nil, kim.ErrSessionNil 71 | } 72 | return nil, err 73 | } 74 | var session pkt.Session 75 | _ = proto.Unmarshal(bts, &session) 76 | return &session, nil 77 | } 78 | 79 | func (r *RedisStorage) GetLocations(accounts ...string) ([]*kim.Location, 
error) { 80 | keys := KeyLocations(accounts...) 81 | list, err := r.cli.MGet(keys...).Result() 82 | if err != nil { 83 | return nil, err 84 | } 85 | var result = make([]*kim.Location, 0) 86 | for _, l := range list { 87 | if l == nil { 88 | continue 89 | } 90 | var loc kim.Location 91 | _ = loc.Unmarshal([]byte(l.(string))) 92 | result = append(result, &loc) 93 | } 94 | if len(result) == 0 { 95 | return nil, kim.ErrSessionNil 96 | } 97 | return result, nil 98 | } 99 | 100 | func (r *RedisStorage) GetLocation(account string, device string) (*kim.Location, error) { 101 | key := KeyLocation(account, device) 102 | bts, err := r.cli.Get(key).Bytes() 103 | if err != nil { 104 | if err == redis.Nil { 105 | return nil, kim.ErrSessionNil 106 | } 107 | return nil, err 108 | } 109 | var loc kim.Location 110 | _ = loc.Unmarshal(bts) 111 | return &loc, nil 112 | } 113 | 114 | func KeySession(channel string) string { 115 | return fmt.Sprintf("login:sn:%s", channel) 116 | } 117 | 118 | func KeyLocation(account, device string) string { 119 | if device == "" { 120 | return fmt.Sprintf("login:loc:%s", account) 121 | } 122 | return fmt.Sprintf("login:loc:%s:%s", account, device) 123 | } 124 | 125 | func KeyLocations(accounts ...string) []string { 126 | arr := make([]string, len(accounts)) 127 | for i, account := range accounts { 128 | arr[i] = KeyLocation(account, "") 129 | } 130 | return arr 131 | } 132 | -------------------------------------------------------------------------------- /storage/redis_test.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | "testing" 8 | "time" 9 | 10 | "github.com/go-redis/redis/v7" 11 | "github.com/klintcheng/kim" 12 | "github.com/klintcheng/kim/wire/pkt" 13 | "github.com/segmentio/ksuid" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func Test_crud(t *testing.T) { 18 | cli, err := InitRedis("localhost:6379", "") 19 | 
assert.Nil(t, err) 20 | cc := NewRedisStorage(cli) 21 | err = cc.Add(&pkt.Session{ 22 | ChannelId: "ch1", 23 | GateId: "gateway1", 24 | Account: "test1", 25 | Device: "Phone", 26 | }) 27 | assert.Nil(t, err) 28 | 29 | _ = cc.Add(&pkt.Session{ 30 | ChannelId: "ch2", 31 | GateId: "gateway1", 32 | Account: "test2", 33 | Device: "Pc", 34 | }) 35 | 36 | session, err := cc.Get("ch1") 37 | assert.Nil(t, err) 38 | t.Log(session) 39 | assert.Equal(t, "ch1", session.ChannelId) 40 | assert.Equal(t, "gateway1", session.GateId) 41 | assert.Equal(t, "test1", session.Account) 42 | 43 | arr, err := cc.GetLocations("test1", "test2") 44 | assert.Nil(t, err) 45 | t.Log(arr) 46 | loc := arr[1] 47 | 48 | arr, err = cc.GetLocations("test6") 49 | assert.Equal(t, kim.ErrSessionNil, err) 50 | assert.Equal(t, 0, len(arr)) 51 | 52 | assert.Equal(t, "ch2", loc.ChannelId) 53 | assert.Equal(t, "gateway1", loc.GateId) 54 | } 55 | 56 | func Benchmark_MGET(b *testing.B) { 57 | cli, err := InitRedis("localhost:6379", "") 58 | assert.Nil(b, err) 59 | cc := NewRedisStorage(cli) 60 | count := 5 61 | accounts := make([]string, count) 62 | for i := 0; i < count; i++ { 63 | accounts[i] = fmt.Sprintf("account_%d", i) 64 | err = cc.Add(&pkt.Session{ 65 | ChannelId: ksuid.New().String(), 66 | GateId: "gateway1", 67 | Account: accounts[i], 68 | }) 69 | assert.Nil(b, err) 70 | } 71 | 72 | b.ResetTimer() 73 | b.ReportAllocs() 74 | 75 | b.RunParallel(func(pb *testing.PB) { 76 | for pb.Next() { 77 | _, err := cc.GetLocations(accounts...) 
78 | assert.Nil(b, err) 79 | } 80 | }) 81 | } 82 | 83 | func Benchmark_getLocation(b *testing.B) { 84 | cli, err := InitRedis("localhost:6379", "") 85 | assert.Nil(b, err) 86 | cc := NewRedisStorage(cli) 87 | 88 | accs := make([]string, 100) 89 | for i := 0; i < 100; i++ { 90 | accs[i] = ksuid.New().String() 91 | _ = cc.Add(&pkt.Session{ 92 | ChannelId: ksuid.New().String(), 93 | GateId: "127_0_0_1_gateway1", 94 | Account: accs[i], 95 | Zone: "testtesttesttest", 96 | Isp: "moblie", 97 | RemoteIP: "127.0.0.1", 98 | App: "kim", 99 | Tags: []string{"tag1", "tag2"}, 100 | }) 101 | } 102 | 103 | b.ResetTimer() 104 | b.ReportAllocs() 105 | 106 | b.RunParallel(func(pb *testing.PB) { 107 | for pb.Next() { 108 | account := accs[rand.Intn(100)] 109 | _, err := cc.GetLocation(account, "") 110 | assert.Nil(b, err) 111 | } 112 | }) 113 | } 114 | 115 | func Benchmark_getSession(b *testing.B) { 116 | cli, err := InitRedis("localhost:6379", "") 117 | assert.Nil(b, err) 118 | cc := NewRedisStorage(cli) 119 | 120 | ids := make([]string, 100) 121 | for i := 0; i < 100; i++ { 122 | ids[i] = ksuid.New().String() 123 | _ = cc.Add(&pkt.Session{ 124 | ChannelId: ids[i], 125 | GateId: "127_0_0_1_gateway1", 126 | Account: ksuid.New().String(), 127 | Zone: "testtesttesttest", 128 | Isp: "moblie", 129 | RemoteIP: "127.0.0.1", 130 | App: "kim", 131 | Tags: []string{"tag1", "tag2"}, 132 | }) 133 | } 134 | 135 | b.ResetTimer() 136 | b.ReportAllocs() 137 | 138 | b.RunParallel(func(pb *testing.PB) { 139 | for pb.Next() { 140 | _, err := cc.Get(ids[rand.Intn(100)]) 141 | assert.Nil(b, err) 142 | } 143 | }) 144 | } 145 | 146 | func InitRedis(addr string, pass string) (*redis.Client, error) { 147 | redisdb := redis.NewClient(&redis.Options{ 148 | Addr: addr, 149 | Password: pass, 150 | DialTimeout: time.Second * 5, 151 | ReadTimeout: time.Second * 5, 152 | WriteTimeout: time.Second * 5, 153 | }) 154 | 155 | _, err := redisdb.Ping().Result() 156 | if err != nil { 157 | log.Println(err) 158 | return 
nil, err 159 | } 160 | return redisdb, nil 161 | } 162 | -------------------------------------------------------------------------------- /storage_mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: storage.go 3 | 4 | // Package kim is a generated GoMock package. 5 | package kim 6 | 7 | import ( 8 | reflect "reflect" 9 | 10 | gomock "github.com/golang/mock/gomock" 11 | pkt "github.com/klintcheng/kim/wire/pkt" 12 | ) 13 | 14 | // MockSessionStorage is a mock of SessionStorage interface. 15 | type MockSessionStorage struct { 16 | ctrl *gomock.Controller 17 | recorder *MockSessionStorageMockRecorder 18 | } 19 | 20 | // MockSessionStorageMockRecorder is the mock recorder for MockSessionStorage. 21 | type MockSessionStorageMockRecorder struct { 22 | mock *MockSessionStorage 23 | } 24 | 25 | // NewMockSessionStorage creates a new mock instance. 26 | func NewMockSessionStorage(ctrl *gomock.Controller) *MockSessionStorage { 27 | mock := &MockSessionStorage{ctrl: ctrl} 28 | mock.recorder = &MockSessionStorageMockRecorder{mock} 29 | return mock 30 | } 31 | 32 | // EXPECT returns an object that allows the caller to indicate expected use. 33 | func (m *MockSessionStorage) EXPECT() *MockSessionStorageMockRecorder { 34 | return m.recorder 35 | } 36 | 37 | // Add mocks base method. 38 | func (m *MockSessionStorage) Add(session *pkt.Session) error { 39 | m.ctrl.T.Helper() 40 | ret := m.ctrl.Call(m, "Add", session) 41 | ret0, _ := ret[0].(error) 42 | return ret0 43 | } 44 | 45 | // Add indicates an expected call of Add. 46 | func (mr *MockSessionStorageMockRecorder) Add(session interface{}) *gomock.Call { 47 | mr.mock.ctrl.T.Helper() 48 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionStorage)(nil).Add), session) 49 | } 50 | 51 | // Delete mocks base method. 
52 | func (m *MockSessionStorage) Delete(account, channelId string) error { 53 | m.ctrl.T.Helper() 54 | ret := m.ctrl.Call(m, "Delete", account, channelId) 55 | ret0, _ := ret[0].(error) 56 | return ret0 57 | } 58 | 59 | // Delete indicates an expected call of Delete. 60 | func (mr *MockSessionStorageMockRecorder) Delete(account, channelId interface{}) *gomock.Call { 61 | mr.mock.ctrl.T.Helper() 62 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSessionStorage)(nil).Delete), account, channelId) 63 | } 64 | 65 | // Get mocks base method. 66 | func (m *MockSessionStorage) Get(channelId string) (*pkt.Session, error) { 67 | m.ctrl.T.Helper() 68 | ret := m.ctrl.Call(m, "Get", channelId) 69 | ret0, _ := ret[0].(*pkt.Session) 70 | ret1, _ := ret[1].(error) 71 | return ret0, ret1 72 | } 73 | 74 | // Get indicates an expected call of Get. 75 | func (mr *MockSessionStorageMockRecorder) Get(channelId interface{}) *gomock.Call { 76 | mr.mock.ctrl.T.Helper() 77 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionStorage)(nil).Get), channelId) 78 | } 79 | 80 | // GetLocation mocks base method. 81 | func (m *MockSessionStorage) GetLocation(account, device string) (*Location, error) { 82 | m.ctrl.T.Helper() 83 | ret := m.ctrl.Call(m, "GetLocation", account, device) 84 | ret0, _ := ret[0].(*Location) 85 | ret1, _ := ret[1].(error) 86 | return ret0, ret1 87 | } 88 | 89 | // GetLocation indicates an expected call of GetLocation. 90 | func (mr *MockSessionStorageMockRecorder) GetLocation(account, device interface{}) *gomock.Call { 91 | mr.mock.ctrl.T.Helper() 92 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLocation", reflect.TypeOf((*MockSessionStorage)(nil).GetLocation), account, device) 93 | } 94 | 95 | // GetLocations mocks base method. 
96 | func (m *MockSessionStorage) GetLocations(account ...string) ([]*Location, error) { 97 | m.ctrl.T.Helper() 98 | varargs := []interface{}{} 99 | for _, a := range account { 100 | varargs = append(varargs, a) 101 | } 102 | ret := m.ctrl.Call(m, "GetLocations", varargs...) 103 | ret0, _ := ret[0].([]*Location) 104 | ret1, _ := ret[1].(error) 105 | return ret0, ret1 106 | } 107 | 108 | // GetLocations indicates an expected call of GetLocations. 109 | func (mr *MockSessionStorageMockRecorder) GetLocations(account ...interface{}) *gomock.Call { 110 | mr.mock.ctrl.T.Helper() 111 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLocations", reflect.TypeOf((*MockSessionStorage)(nil).GetLocations), account...) 112 | } 113 | -------------------------------------------------------------------------------- /tcp/client.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | 10 | "github.com/klintcheng/kim" 11 | "github.com/klintcheng/kim/logger" 12 | ) 13 | 14 | // ClientOptions ClientOptions 15 | type ClientOptions struct { 16 | Heartbeat time.Duration //登录超时 17 | ReadWait time.Duration //读超时 18 | WriteWait time.Duration //写超时 19 | } 20 | 21 | // Client is a websocket implement of the terminal 22 | type Client struct { 23 | sync.Mutex 24 | kim.Dialer 25 | once sync.Once 26 | id string 27 | name string 28 | conn kim.Conn 29 | state int32 30 | options ClientOptions 31 | Meta map[string]string 32 | } 33 | 34 | // NewClient NewClient 35 | func NewClient(id, name string, opts ClientOptions) kim.Client { 36 | return NewClientWithProps(id, name, make(map[string]string), opts) 37 | } 38 | 39 | func NewClientWithProps(id, name string, meta map[string]string, opts ClientOptions) kim.Client { 40 | if opts.WriteWait == 0 { 41 | opts.WriteWait = kim.DefaultWriteWait 42 | } 43 | if opts.ReadWait == 0 { 44 | opts.ReadWait = kim.DefaultReadWait 45 | } 
46 | 47 | cli := &Client{ 48 | id: id, 49 | name: name, 50 | options: opts, 51 | Meta: meta, 52 | } 53 | return cli 54 | } 55 | 56 | // Connect to server 57 | func (c *Client) Connect(addr string) error { 58 | // 这里是一个CAS原子操作,对比并设置值,是并发安全的。 59 | if !atomic.CompareAndSwapInt32(&c.state, 0, 1) { 60 | return fmt.Errorf("client has connected") 61 | } 62 | 63 | rawconn, err := c.Dialer.DialAndHandshake(kim.DialerContext{ 64 | Id: c.id, 65 | Name: c.name, 66 | Address: addr, 67 | Timeout: kim.DefaultLoginWait, 68 | }) 69 | if err != nil { 70 | atomic.CompareAndSwapInt32(&c.state, 1, 0) 71 | return err 72 | } 73 | if rawconn == nil { 74 | return fmt.Errorf("conn is nil") 75 | } 76 | c.conn = NewConn(rawconn) 77 | 78 | if c.options.Heartbeat > 0 { 79 | go func() { 80 | err := c.heartbeatloop() 81 | if err != nil { 82 | logger.WithField("module", "tcp.client").Warn("heartbeatloop stopped - ", err) 83 | } 84 | }() 85 | } 86 | return nil 87 | } 88 | 89 | // SetDialer 设置握手逻辑 90 | func (c *Client) SetDialer(dialer kim.Dialer) { 91 | c.Dialer = dialer 92 | } 93 | 94 | //Send data to connection 95 | func (c *Client) Send(payload []byte) error { 96 | if atomic.LoadInt32(&c.state) == 0 { 97 | return fmt.Errorf("connection is nil") 98 | } 99 | c.Lock() 100 | defer c.Unlock() 101 | err := c.conn.WriteFrame(kim.OpBinary, payload) 102 | if err != nil { 103 | return err 104 | } 105 | return c.conn.Flush() 106 | } 107 | 108 | // Close 关闭 109 | func (c *Client) Close() { 110 | c.once.Do(func() { 111 | if c.conn == nil { 112 | return 113 | } 114 | // graceful close connection 115 | _ = c.conn.WriteFrame(kim.OpClose, nil) 116 | c.conn.Flush() 117 | 118 | c.conn.Close() 119 | atomic.CompareAndSwapInt32(&c.state, 1, 0) 120 | }) 121 | } 122 | 123 | func (c *Client) Read() (kim.Frame, error) { 124 | if c.conn == nil { 125 | return nil, errors.New("connection is nil") 126 | } 127 | if c.options.Heartbeat > 0 { 128 | _ = c.conn.SetReadDeadline(time.Now().Add(c.options.ReadWait)) 129 | } 130 | 
frame, err := c.conn.ReadFrame() 131 | if err != nil { 132 | return nil, err 133 | } 134 | if frame.GetOpCode() == kim.OpClose { 135 | return nil, errors.New("remote side close the channel") 136 | } 137 | return frame, nil 138 | } 139 | 140 | func (c *Client) heartbeatloop() error { 141 | tick := time.NewTicker(c.options.Heartbeat) 142 | for range tick.C { 143 | // 发送一个ping的心跳包给服务端 144 | if err := c.ping(); err != nil { 145 | return err 146 | } 147 | } 148 | return nil 149 | } 150 | 151 | func (c *Client) ping() error { 152 | logger.WithField("module", "tcp.client").Tracef("%s send ping to server", c.id) 153 | 154 | err := c.conn.WriteFrame(kim.OpPing, nil) 155 | if err != nil { 156 | return err 157 | } 158 | return c.conn.Flush() 159 | } 160 | 161 | // ID return id 162 | func (c *Client) ServiceID() string { 163 | return c.id 164 | } 165 | 166 | // Name Name 167 | func (c *Client) ServiceName() string { 168 | return c.name 169 | } 170 | func (c *Client) GetMeta() map[string]string { return c.Meta } 171 | -------------------------------------------------------------------------------- /tcp/connection.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "net" 7 | 8 | "github.com/klintcheng/kim" 9 | "github.com/klintcheng/kim/wire/endian" 10 | ) 11 | 12 | // Frame Frame 13 | type Frame struct { 14 | OpCode kim.OpCode 15 | Payload []byte 16 | } 17 | 18 | // SetOpCode SetOpCode 19 | func (f *Frame) SetOpCode(code kim.OpCode) { 20 | f.OpCode = code 21 | } 22 | 23 | // GetOpCode GetOpCode 24 | func (f *Frame) GetOpCode() kim.OpCode { 25 | return f.OpCode 26 | } 27 | 28 | // SetPayload SetPayload 29 | func (f *Frame) SetPayload(payload []byte) { 30 | f.Payload = payload 31 | } 32 | 33 | // GetPayload GetPayload 34 | func (f *Frame) GetPayload() []byte { 35 | return f.Payload 36 | } 37 | 38 | // TcpConn Conn 39 | type TcpConn struct { 40 | net.Conn 41 | rd *bufio.Reader 42 | wr 
*bufio.Writer 43 | } 44 | 45 | // NewConn NewConn 46 | 47 | func NewConn(conn net.Conn) kim.Conn { 48 | return &TcpConn{ 49 | Conn: conn, 50 | rd: bufio.NewReaderSize(conn, 4096), 51 | wr: bufio.NewWriterSize(conn, 1024), 52 | } 53 | } 54 | 55 | func NewConnWithRW(conn net.Conn, rd *bufio.Reader, wr *bufio.Writer) *TcpConn { 56 | return &TcpConn{ 57 | Conn: conn, 58 | rd: rd, 59 | wr: wr, 60 | } 61 | } 62 | 63 | // ReadFrame ReadFrame 64 | func (c *TcpConn) ReadFrame() (kim.Frame, error) { 65 | opcode, err := endian.ReadUint8(c.rd) 66 | if err != nil { 67 | return nil, err 68 | } 69 | payload, err := endian.ReadBytes(c.rd) 70 | if err != nil { 71 | return nil, err 72 | } 73 | return &Frame{ 74 | OpCode: kim.OpCode(opcode), 75 | Payload: payload, 76 | }, nil 77 | } 78 | 79 | // WriteFrame WriteFrame 80 | func (c *TcpConn) WriteFrame(code kim.OpCode, payload []byte) error { 81 | return WriteFrame(c.wr, code, payload) 82 | } 83 | 84 | // Flush Flush 85 | func (c *TcpConn) Flush() error { 86 | return c.wr.Flush() 87 | } 88 | 89 | // WriteFrame write a frame to w 90 | func WriteFrame(w io.Writer, code kim.OpCode, payload []byte) error { 91 | if err := endian.WriteUint8(w, uint8(code)); err != nil { 92 | return err 93 | } 94 | if err := endian.WriteBytes(w, payload); err != nil { 95 | return err 96 | } 97 | return nil 98 | } 99 | -------------------------------------------------------------------------------- /tcp/server.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "bufio" 5 | "net" 6 | 7 | "github.com/klintcheng/kim" 8 | ) 9 | 10 | // Server is a websocket implement of the Server 11 | type Upgrader struct { 12 | } 13 | 14 | // NewServer NewServer 15 | func NewServer(listen string, service kim.ServiceRegistration, options ...kim.ServerOption) kim.Server { 16 | return kim.NewServer(listen, service, new(Upgrader), options...) 
17 | } 18 | 19 | func (u *Upgrader) Name() string { 20 | return "tcp.Server" 21 | } 22 | 23 | func (u *Upgrader) Upgrade(rawconn net.Conn, rd *bufio.Reader, wr *bufio.Writer) (kim.Conn, error) { 24 | conn := NewConnWithRW(rawconn, rd, wr) 25 | return conn, nil 26 | } 27 | -------------------------------------------------------------------------------- /websocket/client.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "net/url" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/gobwas/ws" 13 | "github.com/gobwas/ws/wsutil" 14 | "github.com/klintcheng/kim" 15 | "github.com/klintcheng/kim/logger" 16 | ) 17 | 18 | // ClientOptions ClientOptions 19 | type ClientOptions struct { 20 | Heartbeat time.Duration //登录超时 21 | ReadWait time.Duration //读超时 22 | WriteWait time.Duration //写超时 23 | } 24 | 25 | // Client is a websocket implement of the terminal 26 | type Client struct { 27 | sync.Mutex 28 | kim.Dialer 29 | once sync.Once 30 | id string 31 | name string 32 | conn net.Conn 33 | state int32 34 | options ClientOptions 35 | Meta map[string]string 36 | } 37 | 38 | // NewClient NewClient 39 | func NewClient(id, name string, opts ClientOptions) kim.Client { 40 | return NewClientWithProps(id, name, make(map[string]string), opts) 41 | } 42 | 43 | func NewClientWithProps(id, name string, meta map[string]string, opts ClientOptions) kim.Client { 44 | if opts.WriteWait == 0 { 45 | opts.WriteWait = kim.DefaultWriteWait 46 | } 47 | if opts.ReadWait == 0 { 48 | opts.ReadWait = kim.DefaultReadWait 49 | } 50 | 51 | cli := &Client{ 52 | id: id, 53 | name: name, 54 | options: opts, 55 | Meta: meta, 56 | } 57 | return cli 58 | } 59 | 60 | // Connect to server 61 | func (c *Client) Connect(addr string) error { 62 | _, err := url.Parse(addr) 63 | if err != nil { 64 | return err 65 | } 66 | if !atomic.CompareAndSwapInt32(&c.state, 0, 1) { 67 | return fmt.Errorf("client 
has connected") 68 | } 69 | // step 1 拨号及握手 70 | conn, err := c.Dialer.DialAndHandshake(kim.DialerContext{ 71 | Id: c.id, 72 | Name: c.name, 73 | Address: addr, 74 | Timeout: kim.DefaultLoginWait, 75 | }) 76 | if err != nil { 77 | atomic.CompareAndSwapInt32(&c.state, 1, 0) 78 | return err 79 | } 80 | if conn == nil { 81 | return fmt.Errorf("conn is nil") 82 | } 83 | c.conn = conn 84 | 85 | if c.options.Heartbeat > 0 { 86 | go func() { 87 | err := c.heartbeatloop(conn) 88 | if err != nil { 89 | logger.Error("heartbeatloop stopped ", err) 90 | } 91 | }() 92 | } 93 | return nil 94 | } 95 | 96 | // SetDialer 设置握手逻辑 97 | func (c *Client) SetDialer(dialer kim.Dialer) { 98 | c.Dialer = dialer 99 | } 100 | 101 | //Send data to connection 102 | func (c *Client) Send(payload []byte) error { 103 | if atomic.LoadInt32(&c.state) == 0 { 104 | return fmt.Errorf("connection is nil") 105 | } 106 | c.Lock() 107 | defer c.Unlock() 108 | err := c.conn.SetWriteDeadline(time.Now().Add(c.options.WriteWait)) 109 | if err != nil { 110 | return err 111 | } 112 | // 客户端消息需要使用MASK 113 | return wsutil.WriteClientMessage(c.conn, ws.OpBinary, payload) 114 | } 115 | 116 | // Close 关闭 117 | func (c *Client) Close() { 118 | c.once.Do(func() { 119 | if c.conn == nil { 120 | return 121 | } 122 | // graceful close connection 123 | _ = wsutil.WriteClientMessage(c.conn, ws.OpClose, nil) 124 | 125 | c.conn.Close() 126 | atomic.CompareAndSwapInt32(&c.state, 1, 0) 127 | }) 128 | } 129 | 130 | // Read a frame ,this function is not safely for concurrent 131 | func (c *Client) Read() (kim.Frame, error) { 132 | if c.conn == nil { 133 | return nil, errors.New("connection is nil") 134 | } 135 | if c.options.Heartbeat > 0 { 136 | _ = c.conn.SetReadDeadline(time.Now().Add(c.options.ReadWait)) 137 | } 138 | frame, err := ws.ReadFrame(c.conn) 139 | if err != nil { 140 | return nil, err 141 | } 142 | if frame.Header.OpCode == ws.OpClose { 143 | return nil, errors.New("remote side close the channel") 144 | } 145 | 
return &Frame{ 146 | raw: frame, 147 | }, nil 148 | } 149 | 150 | func (c *Client) heartbeatloop(conn net.Conn) error { 151 | tick := time.NewTicker(c.options.Heartbeat) 152 | for range tick.C { 153 | // 发送一个ping的心跳包给服务端 154 | if err := c.ping(conn); err != nil { 155 | return err 156 | } 157 | } 158 | return nil 159 | } 160 | 161 | func (c *Client) ping(conn net.Conn) error { 162 | c.Lock() 163 | defer c.Unlock() 164 | err := conn.SetWriteDeadline(time.Now().Add(c.options.WriteWait)) 165 | if err != nil { 166 | return err 167 | } 168 | logger.Tracef("%s send ping to server", c.id) 169 | return wsutil.WriteClientMessage(conn, ws.OpPing, nil) 170 | } 171 | 172 | // ID return id 173 | func (c *Client) ServiceID() string { 174 | return c.id 175 | } 176 | 177 | // Name Name 178 | func (c *Client) ServiceName() string { 179 | return c.name 180 | } 181 | 182 | func (c *Client) GetMeta() map[string]string { 183 | return c.Meta 184 | } 185 | -------------------------------------------------------------------------------- /websocket/connection.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "net" 6 | 7 | "github.com/gobwas/ws" 8 | "github.com/klintcheng/kim" 9 | ) 10 | 11 | type Frame struct { 12 | raw ws.Frame 13 | } 14 | 15 | func (f *Frame) SetOpCode(code kim.OpCode) { 16 | f.raw.Header.OpCode = ws.OpCode(code) 17 | } 18 | 19 | func (f *Frame) GetOpCode() kim.OpCode { 20 | return kim.OpCode(f.raw.Header.OpCode) 21 | } 22 | 23 | func (f *Frame) SetPayload(payload []byte) { 24 | f.raw.Payload = payload 25 | } 26 | 27 | func (f *Frame) GetPayload() []byte { 28 | if f.raw.Header.Masked { 29 | ws.Cipher(f.raw.Payload, f.raw.Header.Mask, 0) 30 | } 31 | f.raw.Header.Masked = false 32 | return f.raw.Payload 33 | } 34 | 35 | type WsConn struct { 36 | net.Conn 37 | rd *bufio.Reader 38 | wr *bufio.Writer 39 | } 40 | 41 | func NewConn(conn net.Conn) kim.Conn { 42 | return &WsConn{ 43 | Conn: 
conn, 44 | rd: bufio.NewReaderSize(conn, 4096), 45 | wr: bufio.NewWriterSize(conn, 1024), 46 | } 47 | } 48 | 49 | func NewConnWithRW(conn net.Conn, rd *bufio.Reader, wr *bufio.Writer) *WsConn { 50 | return &WsConn{ 51 | Conn: conn, 52 | rd: rd, 53 | wr: wr, 54 | } 55 | } 56 | 57 | func (c *WsConn) ReadFrame() (kim.Frame, error) { 58 | f, err := ws.ReadFrame(c.rd) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return &Frame{raw: f}, nil 63 | } 64 | 65 | func (c *WsConn) WriteFrame(code kim.OpCode, payload []byte) error { 66 | f := ws.NewFrame(ws.OpCode(code), true, payload) 67 | return ws.WriteFrame(c.wr, f) 68 | } 69 | 70 | func (c *WsConn) Flush() error { 71 | return c.wr.Flush() 72 | } 73 | -------------------------------------------------------------------------------- /websocket/server.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "net" 6 | 7 | "github.com/gobwas/ws" 8 | "github.com/klintcheng/kim" 9 | ) 10 | 11 | // Server is a websocket implement of the Server 12 | type Upgrader struct { 13 | } 14 | 15 | // NewServer NewServer 16 | func NewServer(listen string, service kim.ServiceRegistration, options ...kim.ServerOption) kim.Server { 17 | return kim.NewServer(listen, service, new(Upgrader), options...) 18 | } 19 | 20 | func (u *Upgrader) Name() string { 21 | return "websocket.Server" 22 | } 23 | 24 | func (u *Upgrader) Upgrade(rawconn net.Conn, rd *bufio.Reader, wr *bufio.Writer) (kim.Conn, error) { 25 | _, err := ws.Upgrade(rawconn) 26 | if err != nil { 27 | return nil, err 28 | } 29 | conn := NewConnWithRW(rawconn, rd, wr) 30 | return conn, nil 31 | } 32 | -------------------------------------------------------------------------------- /wire/build.sh: -------------------------------------------------------------------------------- 1 | # export PATH="$PATH:$(go env GOPATH)/bin" 2 | protoc -I proto/ --go_out=. 
proto/*.proto -------------------------------------------------------------------------------- /wire/definitions.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import "time" 4 | 5 | // algorithm in routing 6 | const ( 7 | AlgorithmHashSlots = "hashslots" 8 | ) 9 | 10 | // Command defined data type between client and server 11 | const ( 12 | // login 13 | CommandLoginSignIn = "login.signin" 14 | CommandLoginSignOut = "login.signout" 15 | 16 | // chat 17 | CommandChatUserTalk = "chat.user.talk" 18 | CommandChatGroupTalk = "chat.group.talk" 19 | CommandChatTalkAck = "chat.talk.ack" 20 | 21 | // 离线 22 | CommandOfflineIndex = "chat.offline.index" 23 | CommandOfflineContent = "chat.offline.content" 24 | 25 | // 群管理 26 | CommandGroupCreate = "chat.group.create" 27 | CommandGroupJoin = "chat.group.join" 28 | CommandGroupQuit = "chat.group.quit" 29 | CommandGroupMembers = "chat.group.members" 30 | CommandGroupDetail = "chat.group.detail" 31 | ) 32 | 33 | // Meta Key of a packet 34 | const ( 35 | // 消息将要送达的网关的ServiceName 36 | MetaDestServer = "dest.server" 37 | // 消息将要送达的channels 38 | MetaDestChannels = "dest.channels" 39 | ) 40 | 41 | // Protocol Protocol 42 | type Protocol string 43 | 44 | // Protocol 45 | const ( 46 | ProtocolTCP Protocol = "tcp" 47 | ProtocolWebsocket Protocol = "websocket" 48 | ) 49 | 50 | // Service Name 定义统一的服务名 51 | const ( 52 | SNWGateway = "wgateway" 53 | SNTGateway = "tgateway" 54 | SNLogin = "chat" //login 55 | SNChat = "chat" //chat 56 | SNService = "royal" //rpc service 57 | ) 58 | 59 | // ServiceID ServiceID 60 | type ServiceID string 61 | 62 | // SessionID SessionID 63 | type SessionID string 64 | 65 | type Magic [4]byte 66 | 67 | var ( 68 | MagicLogicPkt = Magic{0xc3, 0x11, 0xa3, 0x65} 69 | MagicBasicPkt = Magic{0xc3, 0x15, 0xa7, 0x65} 70 | ) 71 | 72 | const ( 73 | OfflineReadIndexExpiresIn = time.Hour * 24 * 30 // 读索引在缓存中的过期时间 74 | OfflineSyncIndexCount = 2000 //单次同步消息索引的数量 75 | 
OfflineMessageExpiresIn = 15 // 离线消息过期时间 76 | MessageMaxCountPerPage = 200 // 同步消息内容时每页的最大数据 77 | ) 78 | 79 | const ( 80 | MessageTypeText = 1 81 | MessageTypeImage = 2 82 | MessageTypeVoice = 3 83 | MessageTypeVideo = 4 84 | ) 85 | -------------------------------------------------------------------------------- /wire/endian/big_endian_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2013-2016 The btcsuite developers 2 | // Use of this source code is governed by an ISC 3 | // license that can be found in the LICENSE file. 4 | 5 | package endian 6 | 7 | import ( 8 | "encoding/binary" 9 | "testing" 10 | ) 11 | 12 | func TestReadUint32(t *testing.T) { 13 | a := uint32(0x01020304) 14 | arr := make([]byte, 4) 15 | binary.BigEndian.PutUint32(arr, a) 16 | t.Log(arr) //[1 2 3 4] 17 | 18 | binary.LittleEndian.PutUint32(arr, a) 19 | t.Log(arr) //[4 3 2 1] 20 | } 21 | 22 | func TestSerial(t *testing.T) { 23 | var pkt = struct { 24 | Source uint16 25 | Destination uint16 26 | Sequence uint32 27 | Acknowledgment uint32 // 28 | Data []byte 29 | }{ 30 | Source: 4000, 31 | Destination: 80, 32 | Sequence: 100, 33 | Acknowledgment: 1, 34 | Data: []byte("hello world"), 35 | } 36 | 37 | // 为了方便观看,使用大端序 38 | endian := binary.BigEndian 39 | 40 | buf := make([]byte, 1024) // buffer 41 | i := 0 42 | endian.PutUint16(buf[i:i+2], pkt.Source) 43 | i += 2 // 移动指针2个字节 44 | endian.PutUint16(buf[i:i+2], pkt.Destination) 45 | i += 2 46 | endian.PutUint32(buf[i:i+4], pkt.Sequence) 47 | i += 4 48 | endian.PutUint32(buf[i:i+4], pkt.Acknowledgment) 49 | i += 4 50 | // 由于data长度不确定,必须先把长度写入buf, 这样在反序列化时就可以正确的解析出data 51 | dataLen := len(pkt.Data) 52 | endian.PutUint32(buf[i:i+4], uint32(dataLen)) 53 | i += 4 54 | // 写入数据data 55 | copy(buf[i:i+dataLen], pkt.Data) 56 | i += dataLen 57 | t.Log(buf[0:i]) 58 | } 59 | 60 | func TestDecode(t *testing.T) { 61 | var pkt struct { 62 | Source uint16 63 | Destination uint16 64 | Sequence uint32 65 | 
Acknowledgment uint32 // 66 | Data []byte 67 | } 68 | 69 | recv := []byte{15, 160, 0, 80, 0, 0, 0, 100, 0, 0, 0, 1, 0, 0, 0, 11, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100} 70 | // 为了方便观看,使用大端序 71 | endian := binary.BigEndian 72 | i := 0 73 | pkt.Source = endian.Uint16(recv[i : i+2]) 74 | i += 2 // 移动指针2个字节 75 | pkt.Destination = endian.Uint16(recv[i : i+2]) 76 | i += 2 77 | pkt.Sequence = endian.Uint32(recv[i : i+4]) 78 | i += 4 79 | pkt.Acknowledgment = endian.Uint32(recv[i : i+4]) 80 | i += 4 81 | dataLen := endian.Uint32(recv[i : i+4]) 82 | i += 4 83 | pkt.Data = make([]byte, dataLen) 84 | copy(pkt.Data, recv[i:i+int(dataLen)]) 85 | t.Logf("Src:%d Dest:%d Seq:%d Ack:%d Data:%s", pkt.Source, pkt.Destination, pkt.Sequence, pkt.Acknowledgment, pkt.Data) 86 | // t.Log(pkt) 87 | } 88 | -------------------------------------------------------------------------------- /wire/endian/helper.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2013-2016 The btcsuite developers 2 | // Use of this source code is governed by an ISC 3 | // license that can be found in the LICENSE file. 
4 | 5 | package endian 6 | 7 | import ( 8 | "encoding/binary" 9 | "io" 10 | ) 11 | 12 | var Default = binary.BigEndian 13 | 14 | // ReadUint8 从 reader 中读取一个 uint8 15 | func ReadUint8(r io.Reader) (uint8, error) { 16 | var bytes = make([]byte, 1) 17 | if _, err := io.ReadFull(r, bytes); err != nil { 18 | return 0, err 19 | } 20 | return uint8(bytes[0]), nil 21 | } 22 | 23 | // ReadUint32 从 reader 中读取一个 uint32 24 | func ReadUint32(r io.Reader) (uint32, error) { 25 | var bytes = make([]byte, 4) 26 | if _, err := io.ReadFull(r, bytes); err != nil { 27 | return 0, err 28 | } 29 | return Default.Uint32(bytes), nil 30 | } 31 | 32 | // ReadUint16 从 reader 中读取一个 uint16 33 | func ReadUint16(r io.Reader) (uint16, error) { 34 | var bytes = make([]byte, 2) 35 | if _, err := io.ReadFull(r, bytes); err != nil { 36 | return 0, err 37 | } 38 | return Default.Uint16(bytes), nil 39 | } 40 | 41 | // ReadUint64 从 reader 中读取一个 uint64 42 | func ReadUint64(r io.Reader) (uint64, error) { 43 | var bytes = make([]byte, 8) 44 | if _, err := io.ReadFull(r, bytes); err != nil { 45 | return 0, err 46 | } 47 | return Default.Uint64(bytes), nil 48 | } 49 | 50 | // ReadString 从 reader 中读取一个 string 51 | func ReadString(r io.Reader) (string, error) { 52 | buf, err := ReadBytes(r) 53 | if err != nil { 54 | return "", err 55 | } 56 | return string(buf), nil 57 | } 58 | 59 | // ReadBytes 从 reader 中读取一个 []byte, reader中前4byte 必须是[]byte 的长度 60 | func ReadBytes(r io.Reader) ([]byte, error) { 61 | bufLen, err := ReadUint32(r) 62 | if err != nil { 63 | return nil, err 64 | } 65 | buf := make([]byte, bufLen) 66 | _, err = io.ReadFull(r, buf) 67 | if err != nil { 68 | return nil, err 69 | } 70 | return buf, nil 71 | } 72 | 73 | //ReadFixedBytes 读取固定长度的字节 74 | func ReadFixedBytes(len int, r io.Reader) ([]byte, error) { 75 | buf := make([]byte, len) 76 | _, err := io.ReadFull(r, buf) 77 | if err != nil { 78 | return nil, err 79 | } 80 | return buf, nil 81 | } 82 | 83 | // WriteUint8 写一个 uint8到 writer 中 84 | func 
WriteUint8(w io.Writer, val uint8) error { 85 | buf := []byte{byte(val)} 86 | if _, err := w.Write(buf); err != nil { 87 | return err 88 | } 89 | return nil 90 | } 91 | 92 | // WriteUint16 写一个 int16到 writer 中 93 | func WriteUint16(w io.Writer, val uint16) error { 94 | buf := make([]byte, 2) 95 | Default.PutUint16(buf, val) 96 | if _, err := w.Write(buf); err != nil { 97 | return err 98 | } 99 | return nil 100 | } 101 | 102 | // WriteUint32 写一个 int32到 writer 中 103 | func WriteUint32(w io.Writer, val uint32) error { 104 | buf := make([]byte, 4) 105 | Default.PutUint32(buf, val) 106 | if _, err := w.Write(buf); err != nil { 107 | return err 108 | } 109 | return nil 110 | } 111 | 112 | // WriteUint64 写一个 int64到 writer 中 113 | func WriteUint64(w io.Writer, val uint64) error { 114 | buf := make([]byte, 8) 115 | Default.PutUint64(buf, val) 116 | if _, err := w.Write(buf); err != nil { 117 | return err 118 | } 119 | return nil 120 | } 121 | 122 | // WriteString 写一个 string 到 writer 中 123 | func WriteString(w io.Writer, str string) error { 124 | if err := WriteBytes(w, []byte(str)); err != nil { 125 | return err 126 | } 127 | return nil 128 | } 129 | 130 | // WriteBytes 写一个 buf []byte 到 writer 中 131 | func WriteBytes(w io.Writer, buf []byte) error { 132 | bufLen := len(buf) 133 | 134 | if err := WriteUint32(w, uint32(bufLen)); err != nil { 135 | return err 136 | } 137 | if _, err := w.Write(buf); err != nil { 138 | return err 139 | } 140 | return nil 141 | } 142 | 143 | func WriteShortBytes(w io.Writer, buf []byte) error { 144 | bufLen := len(buf) 145 | 146 | if err := WriteUint16(w, uint16(bufLen)); err != nil { 147 | return err 148 | } 149 | if _, err := w.Write(buf); err != nil { 150 | return err 151 | } 152 | return nil 153 | } 154 | 155 | func ReadShortBytes(r io.Reader) ([]byte, error) { 156 | bufLen, err := ReadUint16(r) 157 | if err != nil { 158 | return nil, err 159 | } 160 | buf := make([]byte, bufLen) 161 | _, err = io.ReadFull(r, buf) 162 | if err != nil { 163 | 
return nil, err 164 | } 165 | return buf, nil 166 | } 167 | 168 | func ReadShortString(r io.Reader) (string, error) { 169 | buf, err := ReadShortBytes(r) 170 | if err != nil { 171 | return "", err 172 | } 173 | return string(buf), nil 174 | } 175 | -------------------------------------------------------------------------------- /wire/grpc_helper.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "google.golang.org/grpc/codes" 5 | "google.golang.org/grpc/status" 6 | ) 7 | 8 | // IsGrpcError check err type with a grpc status 9 | func IsGrpcError(err error, code codes.Code) bool { 10 | if err == nil { 11 | return false 12 | } 13 | if st, ok := status.FromError(err); ok { 14 | return st.Code() == code 15 | } 16 | return false 17 | } 18 | -------------------------------------------------------------------------------- /wire/pkt/basic_pkt.go: -------------------------------------------------------------------------------- 1 | package pkt 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/klintcheng/kim/wire/endian" 7 | ) 8 | 9 | // basic pkt code 10 | const ( 11 | CodePing = uint16(1) 12 | CodePong = uint16(2) 13 | ) 14 | 15 | type BasicPkt struct { 16 | Code uint16 17 | Length uint16 18 | Body []byte 19 | } 20 | 21 | func (p *BasicPkt) Decode(r io.Reader) error { 22 | var err error 23 | if p.Code, err = endian.ReadUint16(r); err != nil { 24 | return err 25 | } 26 | if p.Length, err = endian.ReadUint16(r); err != nil { 27 | return err 28 | } 29 | if p.Length > 0 { 30 | if p.Body, err = endian.ReadFixedBytes(int(p.Length), r); err != nil { 31 | return err 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | func (p *BasicPkt) Encode(w io.Writer) error { 38 | if err := endian.WriteUint16(w, p.Code); err != nil { 39 | return err 40 | } 41 | if err := endian.WriteUint16(w, p.Length); err != nil { 42 | return err 43 | } 44 | if p.Length > 0 { 45 | if _, err := w.Write(p.Body); err != nil { 46 | return err 47 | } 48 
| } 49 | return nil 50 | } 51 | -------------------------------------------------------------------------------- /wire/pkt/packet.go: -------------------------------------------------------------------------------- 1 | package pkt 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | 8 | "io" 9 | 10 | "github.com/klintcheng/kim/wire" 11 | "github.com/klintcheng/kim/wire/endian" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | // LogicPkt 定义了网关对外的client消息结构 16 | type LogicPkt struct { 17 | Header 18 | Body []byte `json:"body,omitempty"` 19 | } 20 | 21 | // HeaderOption HeaderOption 22 | type HeaderOption func(*Header) 23 | 24 | // WithStatus WithStatus 25 | func WithStatus(status Status) HeaderOption { 26 | return func(h *Header) { 27 | h.Status = status 28 | } 29 | } 30 | 31 | // WithSeq WithSeq 32 | func WithSeq(seq uint32) HeaderOption { 33 | return func(h *Header) { 34 | h.Sequence = seq 35 | } 36 | } 37 | 38 | // WithChannel set channelID 39 | func WithChannel(channelID string) HeaderOption { 40 | return func(h *Header) { 41 | h.ChannelId = channelID 42 | } 43 | } 44 | 45 | // WithDest WithDest 46 | func WithDest(dest string) HeaderOption { 47 | return func(h *Header) { 48 | h.Dest = dest 49 | } 50 | } 51 | 52 | // New new a empty payload message 53 | func New(command string, options ...HeaderOption) *LogicPkt { 54 | pkt := &LogicPkt{} 55 | pkt.Command = command 56 | 57 | for _, option := range options { 58 | option(&pkt.Header) 59 | } 60 | if pkt.Sequence == 0 { 61 | pkt.Sequence = wire.Seq.Next() 62 | } 63 | return pkt 64 | } 65 | 66 | // NewFrom new packet from a header 67 | func NewFrom(header *Header) *LogicPkt { 68 | pkt := &LogicPkt{} 69 | pkt.Header = Header{ 70 | Command: header.Command, 71 | Sequence: header.Sequence, 72 | ChannelId: header.ChannelId, 73 | Status: header.Status, 74 | Dest: header.Dest, 75 | } 76 | return pkt 77 | } 78 | 79 | // ReadPkt read bytes to LogicPkt from a reader 80 | func (p *LogicPkt) Decode(r io.Reader) error 
{ 81 | headerBytes, err := endian.ReadBytes(r) 82 | if err != nil { 83 | return err 84 | } 85 | if err := proto.Unmarshal(headerBytes, &p.Header); err != nil { 86 | return err 87 | } 88 | // read body 89 | p.Body, err = endian.ReadBytes(r) 90 | if err != nil { 91 | return err 92 | } 93 | return nil 94 | } 95 | 96 | // Encode Encode Header to writer 97 | func (p *LogicPkt) Encode(w io.Writer) error { 98 | headerBytes, err := proto.Marshal(&p.Header) 99 | if err != nil { 100 | return err 101 | } 102 | if err := endian.WriteBytes(w, headerBytes); err != nil { 103 | return err 104 | } 105 | if err := endian.WriteBytes(w, p.Body); err != nil { 106 | return err 107 | } 108 | return nil 109 | } 110 | 111 | // ReadBody val must be a pointer 112 | func (p *LogicPkt) ReadBody(val proto.Message) error { 113 | return proto.Unmarshal(p.Body, val) 114 | } 115 | 116 | // WritePb WritePb 117 | func (p *LogicPkt) WriteBody(val proto.Message) *LogicPkt { 118 | if val == nil { 119 | return p 120 | } 121 | p.Body, _ = proto.Marshal(val) 122 | return p 123 | } 124 | 125 | // StringBody return string body 126 | func (p *LogicPkt) StringBody() string { 127 | return string(p.Body) 128 | } 129 | 130 | func (p *LogicPkt) String() string { 131 | return fmt.Sprintf("header:%v body:%dbits", &p.Header, len(p.Body)) 132 | } 133 | 134 | func (h *Header) ServiceName() string { 135 | arr := strings.SplitN(h.Command, ".", 2) 136 | if len(arr) <= 1 { 137 | return "default" 138 | } 139 | return arr[0] 140 | } 141 | 142 | // AddMeta AddMeta 143 | func (p *LogicPkt) AddMeta(m ...*Meta) { 144 | p.Meta = append(p.Meta, m...) 
145 | } 146 | 147 | // AddStringMeta AddStringMeta 148 | func (p *LogicPkt) AddStringMeta(key, value string) { 149 | p.AddMeta(&Meta{ 150 | Key: key, 151 | Value: value, 152 | Type: MetaType_string, 153 | }) 154 | } 155 | 156 | // GetMeta extra value 157 | func (p *LogicPkt) GetMeta(key string) (interface{}, bool) { 158 | return FindMeta(p.Meta, key) 159 | } 160 | 161 | func FindMeta(meta []*Meta, key string) (interface{}, bool) { 162 | for _, m := range meta { 163 | if m.Key == key { 164 | switch m.Type { 165 | case MetaType_int: 166 | v, _ := strconv.Atoi(m.Value) 167 | return v, true 168 | case MetaType_float: 169 | v, _ := strconv.ParseFloat(m.Value, 64) 170 | return v, true 171 | } 172 | return m.Value, true 173 | } 174 | } 175 | return nil, false 176 | } 177 | 178 | // DelMeta DelMeta 179 | func (p *LogicPkt) DelMeta(key string) { 180 | for i, m := range p.Meta { 181 | if m.Key == key { 182 | length := len(p.Meta) 183 | if i < length-1 { 184 | copy(p.Meta[i:], p.Meta[i+1:]) 185 | } 186 | p.Meta = p.Meta[:length-1] 187 | } 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /wire/pkt/read_write.go: -------------------------------------------------------------------------------- 1 | package pkt 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "reflect" 8 | 9 | "github.com/klintcheng/kim/wire" 10 | ) 11 | 12 | type Packet interface { 13 | Decode(r io.Reader) error 14 | Encode(w io.Writer) error 15 | } 16 | 17 | func MustReadLogicPkt(r io.Reader) (*LogicPkt, error) { 18 | val, err := Read(r) 19 | if err != nil { 20 | return nil, err 21 | } 22 | if lp, ok := val.(*LogicPkt); ok { 23 | return lp, nil 24 | } 25 | return nil, fmt.Errorf("packet is not a logic packet") 26 | } 27 | 28 | func MustReadBasicPkt(r io.Reader) (*BasicPkt, error) { 29 | val, err := Read(r) 30 | if err != nil { 31 | return nil, err 32 | } 33 | if bp, ok := val.(*BasicPkt); ok { 34 | return bp, nil 35 | } 36 | return nil, fmt.Errorf("packet 
is not a basic packet") 37 | } 38 | 39 | func Read(r io.Reader) (interface{}, error) { 40 | magic := wire.Magic{} 41 | _, err := io.ReadFull(r, magic[:]) 42 | if err != nil { 43 | return nil, err 44 | } 45 | switch magic { 46 | case wire.MagicLogicPkt: 47 | p := new(LogicPkt) 48 | if err := p.Decode(r); err != nil { 49 | return nil, err 50 | } 51 | return p, nil 52 | case wire.MagicBasicPkt: 53 | p := new(BasicPkt) 54 | if err := p.Decode(r); err != nil { 55 | return nil, err 56 | } 57 | return p, nil 58 | default: 59 | return nil, fmt.Errorf("magic code %s is incorrect", magic) 60 | } 61 | } 62 | 63 | func Marshal(p Packet) []byte { 64 | buf := new(bytes.Buffer) 65 | kind := reflect.TypeOf(p).Elem() 66 | 67 | if kind.AssignableTo(reflect.TypeOf(LogicPkt{})) { 68 | _, _ = buf.Write(wire.MagicLogicPkt[:]) 69 | } else if kind.AssignableTo(reflect.TypeOf(BasicPkt{})) { 70 | _, _ = buf.Write(wire.MagicBasicPkt[:]) 71 | } 72 | _ = p.Encode(buf) 73 | return buf.Bytes() 74 | } 75 | -------------------------------------------------------------------------------- /wire/pkt/read_write_test.go: -------------------------------------------------------------------------------- 1 | package pkt 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/klintcheng/kim/wire" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMarshal(t *testing.T) { 11 | bp := &BasicPkt{ 12 | Code: CodePing, 13 | } 14 | 15 | bts := Marshal(bp) 16 | t.Log(bts) 17 | 18 | assert.Equal(t, wire.MagicBasicPkt[1], bts[1]) 19 | assert.Equal(t, wire.MagicBasicPkt[2], bts[2]) 20 | 21 | lp := New("login.signin") 22 | bts2 := Marshal(lp) 23 | t.Log(bts2) 24 | 25 | assert.Equal(t, wire.MagicLogicPkt[1], bts2[1]) 26 | assert.Equal(t, wire.MagicLogicPkt[2], bts2[2]) 27 | } 28 | -------------------------------------------------------------------------------- /wire/proto/common.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package pkt; 3 | 
option go_package = "./pkt"; 4 | 5 | // status is a uint16 value 6 | enum Status { 7 | Success = 0; 8 | // client defined 9 | 10 | // client error 100-200 11 | NoDestination = 100; 12 | InvalidPacketBody = 101; 13 | InvalidCommand = 103; 14 | Unauthorized = 105 ; 15 | // server error 300-400 16 | SystemException = 300; 17 | NotImplemented = 301; 18 | //specific error 19 | SessionNotFound = 404; // session lost 20 | } 21 | 22 | enum MetaType { 23 | int = 0; 24 | string = 1; 25 | float = 2; 26 | } 27 | 28 | enum ContentType { 29 | Protobuf = 0; 30 | Json = 1; 31 | } 32 | 33 | enum Flag { 34 | Request = 0; 35 | Response = 1; 36 | Push = 2; 37 | } 38 | 39 | message Meta { 40 | string key = 1; 41 | string value = 2; 42 | MetaType type = 3; 43 | } 44 | 45 | message Header { 46 | string command = 1; 47 | // sender channel id 48 | string channelId = 2; 49 | uint32 sequence = 3; 50 | Flag flag = 4; 51 | Status status = 5; 52 | // destination is defined as a account,group or room 53 | string dest = 6; 54 | repeated Meta meta = 7; 55 | } 56 | 57 | message InnerHandshakeReq{ 58 | string ServiceId = 1; 59 | } 60 | 61 | message InnerHandshakeResponse{ 62 | uint32 Code = 1; 63 | string Error = 2; 64 | } -------------------------------------------------------------------------------- /wire/proto/protocol.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package pkt; 3 | option go_package = "./pkt"; 4 | 5 | message LoginReq { 6 | string token = 1; 7 | string isp = 2; 8 | string zone = 3; // location code 9 | repeated string tags = 4; 10 | } 11 | 12 | message LoginResp { 13 | string channelId = 1; 14 | string account = 2; 15 | } 16 | 17 | message KickoutNotify { 18 | string channelId = 1; 19 | } 20 | 21 | message Session { 22 | string channelId = 1;// session id 23 | string gateId = 2; // gateway ID 24 | string account = 3; 25 | string zone = 4; 26 | string isp = 5; 27 | string remoteIP = 6; 28 | string device = 7; 29 
| string app = 8; 30 | repeated string tags = 9; 31 | } 32 | 33 | // chat message 34 | message MessageReq { 35 | int32 type = 1; 36 | string body = 2; 37 | string extra = 3; 38 | } 39 | 40 | message MessageResp { 41 | int64 messageId = 1; 42 | int64 sendTime = 2; 43 | } 44 | 45 | message MessagePush { 46 | int64 messageId = 1; 47 | int32 type = 2; 48 | string body = 3; 49 | string extra = 4; 50 | string sender = 5; 51 | int64 sendTime = 6; 52 | } 53 | 54 | message ErrorResp { 55 | string message= 1; 56 | } 57 | 58 | message MessageAckReq { 59 | int64 messageId = 1; 60 | } 61 | 62 | message GroupCreateReq { 63 | string name = 1; 64 | string avatar = 2; 65 | string introduction = 3; 66 | string owner = 4; 67 | repeated string members = 5; 68 | } 69 | 70 | message GroupCreateResp { 71 | string group_id = 1; 72 | } 73 | 74 | message GroupCreateNotify { 75 | string group_id = 1; 76 | repeated string members = 2; 77 | } 78 | 79 | message GroupJoinReq { 80 | string account = 1; 81 | string group_id = 2; 82 | } 83 | 84 | message GroupQuitReq { 85 | string account = 1; 86 | string group_id = 2; 87 | } 88 | 89 | message GroupGetReq { 90 | string group_id = 1; 91 | } 92 | 93 | message Member { 94 | string account = 1; 95 | string alias = 2; 96 | string avatar = 3; 97 | int64 join_time = 4; 98 | } 99 | 100 | message GroupGetResp { 101 | string id = 1; 102 | string name = 2; 103 | string avatar = 3; 104 | string introduction = 4; 105 | string owner = 5; 106 | repeated Member members = 6; 107 | int64 created_at = 7; 108 | } 109 | 110 | message GroupJoinNotify { 111 | string group_id = 1; 112 | string account = 2; 113 | } 114 | 115 | message GroupQuitNotify { 116 | string group_id = 1; 117 | string account = 2; 118 | } 119 | 120 | message MessageIndexReq { 121 | int64 message_id = 1; 122 | } 123 | 124 | message MessageIndexResp { 125 | repeated MessageIndex indexes = 1; 126 | } 127 | 128 | message MessageIndex { 129 | int64 message_id = 1; 130 | int32 direction = 2; 131 | int64 
send_time = 3; 132 | string accountB = 4; 133 | string group = 5; 134 | } 135 | 136 | message MessageContentReq { 137 | repeated int64 message_ids = 1; 138 | } 139 | 140 | message MessageContent { 141 | int64 messageId = 1; 142 | int32 type = 2; 143 | string body = 3; 144 | string extra = 4; 145 | } 146 | 147 | message MessageContentResp { 148 | repeated MessageContent contents = 1; 149 | } 150 | 151 | // message Pkt { 152 | // uint32 Source = 1; 153 | // uint64 Sequence = 3; 154 | // bytes Data = 5; 155 | // } -------------------------------------------------------------------------------- /wire/proto/rpc.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package rpc; 3 | option go_package = "./rpc"; 4 | 5 | 6 | message User { 7 | string account = 1; 8 | string alias = 2; 9 | string avatar = 3; 10 | int64 created_at = 4; 11 | } 12 | 13 | message Message { 14 | int64 id = 1; 15 | int32 type = 2; 16 | string body = 3; 17 | string extra = 4; 18 | } 19 | 20 | message Member { 21 | string account = 1; 22 | string alias = 2; 23 | string avatar = 3; 24 | int64 join_time = 4; 25 | } 26 | 27 | // service 28 | 29 | message InsertMessageReq { 30 | string sender = 1; 31 | string dest = 2; 32 | int64 send_time = 3; 33 | Message message = 4; 34 | } 35 | 36 | message InsertMessageResp { 37 | int64 message_id = 1; 38 | } 39 | 40 | message AckMessageReq { 41 | string account = 1; 42 | int64 message_id = 2; 43 | } 44 | 45 | message CreateGroupReq { 46 | string app = 1; 47 | string name = 2; 48 | string avatar = 3; 49 | string introduction = 4; 50 | string owner = 5; 51 | repeated string members = 6; 52 | } 53 | 54 | message CreateGroupResp { 55 | string group_id = 1; 56 | } 57 | 58 | message JoinGroupReq { 59 | string account = 1; 60 | string group_id = 2; 61 | } 62 | 63 | message QuitGroupReq { 64 | string account = 1; 65 | string group_id = 2; 66 | } 67 | 68 | message GetGroupReq { 69 | string group_id = 1; 70 | } 71 
| 72 | message GetGroupResp { 73 | string id = 1; 74 | string name = 2; 75 | string avatar = 3; 76 | string introduction = 4; 77 | string owner = 5; 78 | int64 created_at = 6; 79 | } 80 | 81 | message GroupMembersReq { 82 | string group_id = 1; 83 | } 84 | 85 | message GroupMembersResp { 86 | repeated Member users = 1; 87 | } 88 | 89 | message GetOfflineMessageIndexReq { 90 | string account = 1; 91 | int64 message_id = 2; 92 | } 93 | 94 | message GetOfflineMessageIndexResp { 95 | repeated MessageIndex list = 1; 96 | } 97 | 98 | message MessageIndex { 99 | int64 message_id = 1; 100 | int32 direction = 2; 101 | int64 send_time = 3; 102 | string accountB = 4; 103 | string group = 5; 104 | } 105 | 106 | message GetOfflineMessageContentReq { 107 | repeated int64 message_ids = 1; 108 | } 109 | 110 | message GetOfflineMessageContentResp { 111 | repeated Message list = 1; 112 | } -------------------------------------------------------------------------------- /wire/seq.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "math" 5 | "sync/atomic" 6 | ) 7 | 8 | // Sequence Sequence 9 | type sequence struct { 10 | num uint32 11 | } 12 | 13 | // Next return Next Seq id 14 | func (s *sequence) Next() uint32 { 15 | next := atomic.AddUint32(&s.num, 1) 16 | if next == math.MaxUint32 { 17 | if atomic.CompareAndSwapUint32(&s.num, next, 1) { 18 | return 1 19 | } 20 | return s.Next() 21 | } 22 | return next 23 | } 24 | 25 | // Seq Seq 26 | var Seq = sequence{num: 1} 27 | -------------------------------------------------------------------------------- /wire/token/jwt.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | jwtgo "github.com/dgrijalva/jwt-go" 8 | ) 9 | 10 | // DefaultSecret 测试使用 11 | const ( 12 | DefaultSecret = "jwt-1sNzdiSgnNuxyq2g7xml2JvLArU" 13 | ) 14 | 15 | // Token Token 16 | type Token struct { 17 | 
Account string `json:"acc,omitempty"` 18 | App string `json:"app,omitempty"` 19 | Exp int64 `json:"exp,omitempty"` 20 | } 21 | 22 | var errExpiredToken = errors.New("expired token") 23 | 24 | // Valid Valid 25 | func (t *Token) Valid() error { 26 | if t.Exp < time.Now().Unix() { 27 | return errExpiredToken 28 | } 29 | return nil 30 | } 31 | 32 | // Parse ParseJwtToken 33 | func Parse(secret, tk string) (*Token, error) { 34 | var token = new(Token) 35 | _, err := jwtgo.ParseWithClaims(tk, token, func(jwttk *jwtgo.Token) (interface{}, error) { 36 | return []byte(secret), nil 37 | }) 38 | if err != nil { 39 | return nil, err 40 | } 41 | return token, nil 42 | } 43 | 44 | // Generate a JWT token 45 | func Generate(secret string, token *Token) (string, error) { 46 | jtk := jwtgo.NewWithClaims(jwtgo.SigningMethodHS256, token) 47 | return jtk.SignedString([]byte(secret)) 48 | } 49 | -------------------------------------------------------------------------------- /wire/token/jwt_test.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestParseJwtToken(t *testing.T) { 11 | tk1 := &Token{ 12 | Account: "test1", 13 | App: "kim", 14 | Exp: time.Now().Add(time.Hour * 24 * 7).Unix(), 15 | } 16 | secret := "123456" 17 | 18 | tokenString, err := Generate(secret, tk1) 19 | assert.Nil(t, err) 20 | t.Log(tokenString) 21 | 22 | tk2, err := Parse(secret, tokenString) 23 | assert.Nil(t, err) 24 | assert.Equal(t, "test1", tk2.Account) 25 | } 26 | --------------------------------------------------------------------------------