├── .gitignore ├── example ├── websocket_service │ ├── macBuildLiunx.sh │ ├── fireTower.toml │ └── main.go └── web │ ├── firetower_v0.1.js │ └── index.html ├── config ├── topicmanage.toml ├── fireTower.toml └── config.go ├── store ├── single │ ├── cluster.go │ ├── cluster.conn.go │ ├── cluster.topic.go │ └── provider.go ├── redis │ ├── cluster.go │ ├── provider.go │ ├── redis_test.go │ ├── cluster.topic.go │ └── cluster.conn.go └── store.go ├── protocol ├── errorinfo.go ├── coder.go ├── const.go ├── pusher.go └── protocol.go ├── service └── tower │ ├── error_info.go │ ├── brazier_test.go │ ├── serverside.go │ ├── brazier.go │ ├── bucket.go │ └── tower.go ├── pkg └── nats │ ├── nats_test.go │ └── nats.go ├── utils └── utils.go ├── LICENSE ├── go.mod ├── README.md ├── e2e └── tower_test.go └── go.sum /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | experiment -------------------------------------------------------------------------------- /example/websocket_service/macBuildLiunx.sh: -------------------------------------------------------------------------------- 1 | CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build 2 | -------------------------------------------------------------------------------- /config/topicmanage.toml: -------------------------------------------------------------------------------- 1 | # 内部socket通信配置 2 | 3 | heartbeat = 5 # 秒(s) 发送心跳包的间隔时间 4 | servicetimeout = 8 # 服务端超时时间 超过该时间没有收到client的心跳则断开连接 5 | 6 | [socket] 7 | port = 6666 -------------------------------------------------------------------------------- /store/single/cluster.go: -------------------------------------------------------------------------------- 1 | package single 2 | 3 | type ClusterStore struct { 4 | } 5 | 6 | func (s *ClusterStore) ClusterNumber() (int64, error) { 7 | return 1, nil 8 | } 9 | -------------------------------------------------------------------------------- /protocol/errorinfo.go: 
-------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrorClose is the error reported when the connection has been closed. 7 | ErrorClose = errors.New("firetower is collapsed") 8 | // ErrorBlock is the error reported for a blocked (congested) operation. 9 | ErrorBlock = errors.New("network congestion") 10 | ) 11 | -------------------------------------------------------------------------------- /example/websocket_service/fireTower.toml: -------------------------------------------------------------------------------- 1 | # Connection settings 2 | 3 | chanLens = 1000 # channel buffer size 4 | heartbeat = 30 # heartbeat interval, in seconds (s) 5 | 6 | [grpc] 7 | address = "localhost:6667" 8 | 9 | [bucket] 10 | Num = 4 # number of buckets to start 11 | CentralChanCount = 100000 # capacity of the server-wide central message channel 12 | BuffChanCount = 10000 # message channel capacity of each bucket 13 | ConsumerNum = 2 # number of consumers per bucket pushing messages to sockets concurrently -------------------------------------------------------------------------------- /config/fireTower.toml: -------------------------------------------------------------------------------- 1 | # Connection settings 2 | 3 | chanLens = 1000 # channel buffer size 4 | heartbeat = 600 # heartbeat interval, in seconds (s) 5 | 6 | topicServiceAddr = "0.0.0.0:6666" 7 | 8 | clusterMode = "single" 9 | 10 | [bucket] 11 | num = 4 # number of buckets to start 12 | centralChanCount = 100000 # capacity of the server-wide central message channel 13 | buffChanCount = 1000 # message channel capacity of each bucket 14 | consumerNum = 32 # number of consumers per bucket pushing messages to sockets concurrently -------------------------------------------------------------------------------- /protocol/coder.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import "github.com/vmihailenco/msgpack/v5" 4 | 5 | type DefaultCoder[T any] struct{} 6 | 7 | func (c *DefaultCoder[T]) Decode(data []byte, fire *FireInfo[T]) error { 8 | return msgpack.Unmarshal(data, fire) 9 | } 10 | 11 | func (c *DefaultCoder[T]) Encode(msg *FireInfo[T]) []byte { 12 | b, _ := msgpack.Marshal(msg) 13 | return b 14 | } 15 | 
-------------------------------------------------------------------------------- /service/tower/error_info.go: -------------------------------------------------------------------------------- 1 | package tower 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrorClosed gateway连接已经关闭的错误信息 7 | ErrorClosed = errors.New("firetower is collapsed") 8 | // ErrorTopicEmpty topic不存在的错误信息 9 | ErrorTopicEmpty = errors.New("topic is empty") 10 | // Server Side Mode Can not send to self 11 | ErrorServerSideMode = errors.New("server side tower can not send to self") 12 | ) 13 | -------------------------------------------------------------------------------- /store/redis/cluster.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import "context" 4 | 5 | type ClusterStore struct { 6 | provider *RedisProvider 7 | } 8 | 9 | func newClusterStore(provider *RedisProvider) *ClusterStore { 10 | return &ClusterStore{ 11 | provider: provider, 12 | } 13 | } 14 | 15 | const ( 16 | ClusterKey = "firetower_cluster_number" 17 | ) 18 | 19 | func (s *ClusterStore) ClusterNumber() (int64, error) { 20 | res := s.provider.dbconn.Incr(context.TODO(), s.provider.keyPrefix+ClusterKey) 21 | return res.Val(), res.Err() 22 | } 23 | -------------------------------------------------------------------------------- /store/store.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | type ClusterConnStore interface { 4 | OneClientAtomicAddBy(clientIP string, num int64) error 5 | GetAllConnNum() (uint64, error) 6 | RemoveClient(clientIP string) error 7 | ClusterMembers() ([]string, error) 8 | } 9 | 10 | type ClusterTopicStore interface { 11 | TopicConnAtomicAddBy(topic string, num int64) error 12 | GetTopicConnNum(topic string) (uint64, error) 13 | RemoveTopic(topic string) error 14 | Topics() (map[string]uint64, error) 15 | ClusterTopics() (map[string]uint64, error) 16 | } 17 | 18 | type ClusterStore 
interface { 19 | ClusterNumber() (int64, error) 20 | } 21 | -------------------------------------------------------------------------------- /pkg/nats/nats_test.go: -------------------------------------------------------------------------------- 1 | package nats 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "testing" 7 | "time" 8 | 9 | "github.com/nats-io/nats.go" 10 | ) 11 | 12 | func TestNats(t *testing.T) { 13 | nc, err := nats.Connect("nats://localhost:4222", nats.Name("FireTower"), nats.UserInfo("firetower", "firetower")) 14 | if err != nil { 15 | log.Fatal(err) 16 | } 17 | 18 | topic := "chat.world." 19 | 20 | go func() { 21 | nc.Subscribe(topic+">", func(msg *nats.Msg) { 22 | fmt.Println("received", string(msg.Data)) 23 | }) 24 | }() 25 | 26 | time.Sleep(time.Second) 27 | if err := nc.Publish(topic+"1", []byte(`{"message":"hello"}`)); err != nil { 28 | t.Fatal(err) 29 | } 30 | 31 | time.Sleep(time.Second * 10) 32 | } 33 | -------------------------------------------------------------------------------- /utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/holdno/snowFlakeByGo" 7 | ) 8 | 9 | var ( 10 | // idWorker is the process-wide unique ID generator instance. 11 | idWorker *snowFlakeByGo.Worker 12 | ) 13 | 14 | func SetupIDWorker(clusterID int64) { 15 | idWorker, _ = snowFlakeByGo.NewWorker(clusterID) 16 | } 17 | 18 | func IDWorker() *snowFlakeByGo.Worker { 19 | return idWorker 20 | } 21 | 22 | // GetIP returns the first non-loopback IPv4 interface address of this host, or "127.0.0.1" when none is found; a failed interface lookup is returned as the error. 23 | func GetIP() (string, error) { 24 | addrs, err := net.InterfaceAddrs() 25 | if err != nil { 26 | return "", err 27 | } 28 | for _, a := range addrs { 29 | if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { 30 | if ipnet.IP.To4() != nil { 31 | return ipnet.IP.String(), nil 32 | } 33 | } 34 | } 35 | return "127.0.0.1", nil 36 | } 37 | -------------------------------------------------------------------------------- /service/tower/brazier_test.go: 
-------------------------------------------------------------------------------- 1 | package tower 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | ) 10 | 11 | func TestConnID(t *testing.T) { 12 | connId = math.MaxUint64 13 | 14 | atomic.AddUint64(&connId, 1) 15 | if connId != 0 { 16 | t.Fatal("wrong connid") 17 | } 18 | atomic.AddUint64(&connId, 1) 19 | if connId != 1 { 20 | t.Fatal("wrong connid") 21 | } 22 | } 23 | 24 | func TestSyncPool(t *testing.T) { 25 | var pool = sync.Pool{ 26 | New: func() interface{} { 27 | fmt.Println("new") 28 | return &struct{}{} 29 | }, 30 | } 31 | n := &struct{ Name string }{ 32 | Name: "hhhh", 33 | } 34 | 35 | pool.Put(n) 36 | pool.Put(n) 37 | pool.Put(n) 38 | pool.Put(n) 39 | 40 | a := pool.Get() 41 | b := pool.Get() 42 | 43 | a.(*struct{ Name string }).Name = "123" 44 | fmt.Println(b.(*struct{ Name string }).Name) 45 | 46 | pool.Get() 47 | pool.Get() 48 | pool.Get() 49 | } 50 | -------------------------------------------------------------------------------- /store/single/cluster.conn.go: -------------------------------------------------------------------------------- 1 | package single 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type ClusterConnStore struct { 8 | storage map[string]int64 9 | 10 | sync.RWMutex 11 | } 12 | 13 | // ClusterMembers returns the keys currently held in storage. The read lock 14 | // keeps the iteration from racing with writers that hold the write lock. 15 | func (s *ClusterConnStore) ClusterMembers() ([]string, error) { 16 | s.RLock() 17 | defer s.RUnlock() 18 | 19 | var list []string 20 | for k := range s.storage { 21 | list = append(list, k) 22 | } 23 | return list, nil 24 | } 25 | 26 | func (s *ClusterConnStore) OneClientAtomicAddBy(clientIP string, num int64) error { 27 | s.Lock() 28 | defer s.Unlock() 29 | s.storage[clientIP] += num 30 | return nil 31 | } 32 | 33 | func (s *ClusterConnStore) GetAllConnNum() (uint64, error) { 34 | s.RLock() 35 | defer s.RUnlock() 36 | 37 | var result int64 38 | for _, v := range s.storage { 39 | result += v 40 | } 41 | return uint64(result), nil 42 | } 43 | 44 | func (s *ClusterConnStore) RemoveClient(clientIP string) error { 45 | 
s.Lock() 46 | defer s.Unlock() 47 | 48 | delete(s.storage, clientIP) 49 | return nil 50 | } 51 | -------------------------------------------------------------------------------- /store/single/cluster.topic.go: -------------------------------------------------------------------------------- 1 | package single 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type ClusterTopicStore struct { 8 | storage map[string]int64 9 | 10 | sync.RWMutex 11 | } 12 | 13 | func (s *ClusterTopicStore) TopicConnAtomicAddBy(topic string, num int64) error { 14 | s.Lock() 15 | defer s.Unlock() 16 | 17 | s.storage[topic] += num 18 | return nil 19 | } 20 | 21 | func (s *ClusterTopicStore) RemoveTopic(topic string) error { 22 | s.Lock() 23 | defer s.Unlock() 24 | delete(s.storage, topic) 25 | return nil 26 | } 27 | 28 | func (s *ClusterTopicStore) GetTopicConnNum(topic string) (uint64, error) { 29 | s.RLock() 30 | defer s.RUnlock() 31 | return uint64(s.storage[topic]), nil 32 | } 33 | 34 | // Topics returns a snapshot of every topic in storage mapped to its counter 35 | // value. The copy is made from s.storage while holding the read lock. 36 | func (s *ClusterTopicStore) Topics() (map[string]uint64, error) { 37 | s.RLock() 38 | defer s.RUnlock() 39 | 40 | result := make(map[string]uint64, len(s.storage)) 41 | for k, v := range s.storage { 42 | result[k] = uint64(v) 43 | } 44 | return result, nil 45 | } 46 | 47 | func (s *ClusterTopicStore) ClusterTopics() (map[string]uint64, error) { 48 | return s.Topics() 49 | } 50 | -------------------------------------------------------------------------------- /store/single/provider.go: -------------------------------------------------------------------------------- 1 | package single 2 | 3 | import "github.com/holdno/firetower/store" 4 | 5 | var provider *SingleProvider 6 | 7 | type SingleProvider struct { 8 | clusterConnStore store.ClusterConnStore 9 | clusterTopicStore store.ClusterTopicStore 10 | clusterStore store.ClusterStore 11 | } 12 | 13 | func Setup() (*SingleProvider, error) { 14 | provider = &SingleProvider{ 15 | clusterConnStore: &ClusterConnStore{ 16 | storage: make(map[string]int64), 17 | }, 18 | clusterTopicStore: &ClusterTopicStore{ 19 | storage: 
make(map[string]int64), 20 | }, 21 | clusterStore: &ClusterStore{}, 22 | } 23 | return provider, nil 24 | } 25 | 26 | func Provider() *SingleProvider { 27 | return provider 28 | } 29 | 30 | func (s *SingleProvider) ClusterConnStore() store.ClusterConnStore { 31 | return s.clusterConnStore 32 | } 33 | 34 | func (s *SingleProvider) ClusterTopicStore() store.ClusterTopicStore { 35 | return s.clusterTopicStore 36 | } 37 | 38 | func (s *SingleProvider) ClusterStore() store.ClusterStore { 39 | return s.clusterStore 40 | } 41 | -------------------------------------------------------------------------------- /service/tower/serverside.go: -------------------------------------------------------------------------------- 1 | package tower 2 | 3 | import ( 4 | "github.com/holdno/firetower/protocol" 5 | ) 6 | 7 | type ServerSideTower[T any] interface { 8 | SetOnConnectHandler(fn func() bool) 9 | SetOnOfflineHandler(fn func()) 10 | SetReceivedHandler(fn func(protocol.ReadOnlyFire[T]) bool) 11 | SetSubscribeHandler(fn func(context protocol.FireLife, topic []string)) 12 | SetUnSubscribeHandler(fn func(context protocol.FireLife, topic []string)) 13 | SetBeforeSubscribeHandler(fn func(context protocol.FireLife, topic []string) bool) 14 | SetOnSystemRemove(fn func(topic string)) 15 | GetConnectNum(topic string) (uint64, error) 16 | Publish(fire *protocol.FireInfo[T]) error 17 | Subscribe(context protocol.FireLife, topics []string) error 18 | UnSubscribe(context protocol.FireLife, topics []string) error 19 | Logger() protocol.Logger 20 | TopicList() []string 21 | Run() 22 | Close() 23 | OnClose() chan struct{} 24 | } 25 | 26 | // BuildServerSideTower returns a ServerSideTower for clientId, built by buildNewTower with a nil second argument (presumably "no websocket connection"; confirm against buildNewTower). 27 | func (t *TowerManager[T]) BuildServerSideTower(clientId string) ServerSideTower[T] { 28 | return buildNewTower(t, nil, clientId) 29 | } 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 
2 | 3 | Copyright (c) 2018 wby 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/holdno/firetower 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.5.0 7 | github.com/holdno/snowFlakeByGo v1.0.0 8 | github.com/json-iterator/go v1.1.12 9 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 10 | github.com/modern-go/reflect2 v1.0.2 // indirect 11 | github.com/pelletier/go-toml v1.9.4 12 | ) 13 | 14 | require ( 15 | github.com/avast/retry-go/v4 v4.6.0 16 | github.com/go-redis/redis/v9 v9.0.0-beta.2 17 | github.com/nats-io/nats.go v1.16.0 18 | github.com/orcaman/concurrent-map/v2 v2.0.1 19 | github.com/vmihailenco/msgpack/v5 v5.3.5 20 | ) 21 | 22 | require ( 23 | github.com/cespare/xxhash/v2 v2.1.2 // indirect 24 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 25 | github.com/golang/protobuf v1.5.2 // indirect 26 | github.com/nats-io/nats-server/v2 v2.8.4 // indirect 27 | github.com/nats-io/nkeys v0.3.0 // indirect 28 | github.com/nats-io/nuid v1.0.1 // indirect 29 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 30 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect 31 | google.golang.org/protobuf v1.27.1 // indirect 32 | ) 33 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pelletier/go-toml" 7 | ) 8 | 9 | // FireTowerConfig 每个连接的配置信息 10 | type FireTowerConfig struct { 11 | ReadChanLens int 12 | WriteChanLens int 13 | Heartbeat int 14 | ServiceMode string 15 | Bucket BucketConfig 16 | Cluster Cluster 17 | } 18 | 19 | type Cluster struct { 20 | RedisOption Redis 21 | NatsOption Nats 22 | } 23 | 24 | type Redis struct { 25 | KeyPrefix string 26 | Addr 
string 27 | Password string 28 | DB int 29 | } 30 | 31 | type Nats struct { 32 | Addr string 33 | ServerName string 34 | UserName string 35 | Password string 36 | SubjectPrefix string 37 | } 38 | 39 | type BucketConfig struct { 40 | Num int 41 | CentralChanCount int64 42 | BuffChanCount int64 43 | ConsumerNum int 44 | } 45 | 46 | const ( 47 | SingleMode = "single" 48 | ClusterMode = "cluster" 49 | DefaultServiceMode = SingleMode 50 | ) 51 | 52 | var ( 53 | // ConfigTree 保存配置 54 | ConfigTree *toml.Tree 55 | ) 56 | 57 | func loadConfig(path string) { 58 | var ( 59 | err error 60 | ) 61 | if ConfigTree, err = toml.LoadFile(path); err != nil { 62 | fmt.Println("config load failed:", err) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /service/tower/brazier.go: -------------------------------------------------------------------------------- 1 | package tower 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/holdno/firetower/protocol" 7 | ) 8 | 9 | func newBrazier[T any]() *brazier[T] { 10 | return &brazier[T]{ 11 | pool: &sync.Pool{ 12 | New: func() interface{} { 13 | return &protocol.FireInfo[T]{ 14 | Context: protocol.FireLife{}, 15 | Message: protocol.TopicMessage[T]{}, 16 | } 17 | }, 18 | }, 19 | } 20 | } 21 | 22 | type brazier[T any] struct { 23 | len int 24 | pool *sync.Pool 25 | } 26 | 27 | func (b *brazier[T]) Extinguished(fire *protocol.FireInfo[T]) { 28 | if fire == nil || b.len > 100000 { 29 | return 30 | } 31 | var empty T 32 | b.len++ 33 | fire.MessageType = 0 34 | fire.Message.Data = empty 35 | fire.Message.Topic = "" 36 | fire.Message.Type = 0 37 | b.pool.Put(fire) 38 | } 39 | 40 | func (b *brazier[T]) LightAFire() *protocol.FireInfo[T] { 41 | b.len-- 42 | return b.pool.Get().(*protocol.FireInfo[T]) 43 | } 44 | 45 | type PusherInfo interface { 46 | ClientID() string 47 | UserID() string 48 | } 49 | 50 | func (t *TowerManager[T]) NewFire(source protocol.FireSource, tower PusherInfo) *protocol.FireInfo[T] { 
51 | f := t.brazier.LightAFire() 52 | f.Message.Type = protocol.PublishOperation 53 | f.Context.Reset(source, tower.ClientID(), tower.UserID()) 54 | return f 55 | } 56 | -------------------------------------------------------------------------------- /protocol/const.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | type FireSource uint8 4 | 5 | const ( 6 | SourceClient FireSource = 1 7 | SourceLogic FireSource = 2 8 | SourceSystem FireSource = 3 9 | ) 10 | 11 | func (s FireSource) String() string { 12 | switch s { 13 | case SourceClient: 14 | return "client" 15 | case SourceLogic: 16 | return "logic" 17 | case SourceSystem: 18 | return "system" 19 | default: 20 | return "unknown" 21 | } 22 | } 23 | 24 | type FireOperation uint8 25 | 26 | const ( 27 | // PublishKey 与前端(客户端约定的推送关键字) 28 | PublishOperation FireOperation = 1 29 | SubscribeOperation FireOperation = 2 30 | UnSubscribeOperation FireOperation = 3 31 | // OfflineTopicByUserIdKey 踢除,将用户某个topic踢下线 32 | OfflineTopicByUserIdOperation FireOperation = 4 33 | // OfflineTopicKey 针对某个topic进行踢除 34 | OfflineTopicOperation FireOperation = 5 35 | // OfflineUserKey 将某个用户踢下线 36 | OfflineUserOperation FireOperation = 6 37 | ) 38 | 39 | func (o FireOperation) String() string { 40 | switch o { 41 | case PublishOperation: 42 | return "publish" 43 | case SubscribeOperation: 44 | return "subscribe" 45 | case UnSubscribeOperation: 46 | return "unSubscribe" 47 | case OfflineTopicByUserIdOperation: 48 | return "offlineTopicByUserid" 49 | case OfflineTopicOperation: 50 | return "offlineTopic" 51 | case OfflineUserOperation: 52 | return "offlineUser" 53 | default: 54 | return "unknown" 55 | } 56 | } 57 | 58 | const ( 59 | SYSTEM_CMD_REMOVE_USER = "remove_user" 60 | ) 61 | -------------------------------------------------------------------------------- /protocol/pusher.go: -------------------------------------------------------------------------------- 1 | package 
protocol 2 | 3 | import ( 4 | "log/slog" 5 | "sync" 6 | ) 7 | 8 | type Pusher[T any] interface { 9 | Publish(fire *FireInfo[T]) error 10 | Receive() chan *FireInfo[T] 11 | } 12 | 13 | type Logger interface { 14 | Info(string, ...any) 15 | Debug(string, ...any) 16 | Error(string, ...any) 17 | } 18 | 19 | type SinglePusher[T any] struct { 20 | msg chan []byte 21 | b Brazier[T] 22 | once sync.Once 23 | coder Coder[T] 24 | fireChan chan *FireInfo[T] 25 | logger Logger 26 | } 27 | 28 | func (s *SinglePusher[T]) Publish(fire *FireInfo[T]) error { 29 | s.msg <- s.coder.Encode(fire) 30 | return nil 31 | } 32 | 33 | func (s *SinglePusher[T]) Receive() chan *FireInfo[T] { 34 | s.once.Do(func() { 35 | go func() { 36 | for { 37 | select { 38 | case m := <-s.msg: 39 | fire := new(FireInfo[T]) 40 | err := s.coder.Decode(m, fire) 41 | if err != nil { 42 | s.logger.Error("failed to decode message", slog.String("data", string(m)), slog.String("error", err.Error())) 43 | continue 44 | } 45 | s.fireChan <- fire 46 | } 47 | } 48 | }() 49 | }) 50 | return s.fireChan 51 | } 52 | 53 | type Brazier[T any] interface { 54 | Extinguished(fire *FireInfo[T]) 55 | LightAFire() *FireInfo[T] 56 | } 57 | 58 | func DefaultPusher[T any](b Brazier[T], coder Coder[T], logger Logger) *SinglePusher[T] { 59 | return &SinglePusher[T]{ 60 | msg: make(chan []byte, 100), 61 | b: b, 62 | coder: coder, 63 | fireChan: make(chan *FireInfo[T], 10000), 64 | logger: logger, 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /store/redis/provider.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/go-redis/redis/v9" 7 | "github.com/holdno/firetower/store" 8 | ) 9 | 10 | var provider *RedisProvider 11 | 12 | type RedisProvider struct { 13 | dbconn *redis.Client 14 | clientIP string 15 | keyPrefix string 16 | 17 | clusterConnStore store.ClusterConnStore 18 | 
clusterTopicStore store.ClusterTopicStore 19 | clusterStore store.ClusterStore 20 | } 21 | 22 | func Setup(addr, password string, db int, clientIP, keyPrefix string) (*RedisProvider, error) { 23 | if clientIP == "" { 24 | panic("setup redis store: required client IP") 25 | } 26 | provider = &RedisProvider{ 27 | keyPrefix: keyPrefix, 28 | dbconn: redis.NewClient(&redis.Options{ 29 | Addr: addr, 30 | Password: password, // no password set 31 | DB: db, // use default DB 32 | }), 33 | clientIP: clientIP, 34 | } 35 | 36 | provider.clusterConnStore = newClusterConnStore(provider) 37 | provider.clusterTopicStore = newClusterTopicStore(provider) 38 | provider.clusterStore = newClusterStore(provider) 39 | 40 | res := provider.dbconn.Ping(context.TODO()) 41 | if res.Err() != nil { 42 | return nil, res.Err() 43 | } 44 | 45 | return provider, nil 46 | } 47 | 48 | func (s *RedisProvider) ClusterConnStore() store.ClusterConnStore { 49 | return s.clusterConnStore 50 | } 51 | 52 | func (s *RedisProvider) ClusterTopicStore() store.ClusterTopicStore { 53 | return s.clusterTopicStore 54 | } 55 | 56 | func (s *RedisProvider) ClusterStore() store.ClusterStore { 57 | return s.clusterStore 58 | } 59 | -------------------------------------------------------------------------------- /pkg/nats/nats.go: -------------------------------------------------------------------------------- 1 | package nats 2 | 3 | import ( 4 | "log/slog" 5 | "sync" 6 | 7 | "github.com/holdno/firetower/config" 8 | "github.com/holdno/firetower/protocol" 9 | 10 | "github.com/nats-io/nats.go" 11 | ) 12 | 13 | var _ protocol.Pusher[any] = (*pusher[any])(nil) 14 | 15 | func MustSetupNatsPusher[T any](cfg config.Nats, coder protocol.Coder[T], logger protocol.Logger, topicFunc func() map[string]uint64) protocol.Pusher[T] { 16 | if cfg.SubjectPrefix == "" { 17 | cfg.SubjectPrefix = "firetower.topic." 
18 | } 19 | p := &pusher[T]{ 20 | subjectPrefix: cfg.SubjectPrefix, 21 | coder: coder, 22 | currentTopic: topicFunc, 23 | msg: make(chan *protocol.FireInfo[T], 10000), 24 | logger: logger, 25 | } 26 | var err error 27 | p.nats, err = nats.Connect(cfg.Addr, nats.Name(cfg.ServerName), nats.UserInfo(cfg.UserName, cfg.Password)) 28 | if err != nil { 29 | panic(err) 30 | } 31 | return p 32 | } 33 | 34 | type pusher[T any] struct { 35 | subjectPrefix string 36 | msg chan *protocol.FireInfo[T] 37 | nats *nats.Conn 38 | once sync.Once 39 | b protocol.Brazier[T] 40 | coder protocol.Coder[T] 41 | currentTopic func() map[string]uint64 42 | logger protocol.Logger 43 | } 44 | 45 | func (p *pusher[T]) Publish(fire *protocol.FireInfo[T]) error { 46 | msg := nats.NewMsg(p.subjectPrefix + fire.Message.Topic) 47 | msg.Header.Set("topic", fire.Message.Topic) 48 | msg.Data = p.coder.Encode(fire) 49 | return p.nats.PublishMsg(msg) 50 | } 51 | 52 | func (p *pusher[T]) Receive() chan *protocol.FireInfo[T] { 53 | p.once.Do(func() { 54 | p.nats.Subscribe(p.subjectPrefix+">", func(msg *nats.Msg) { 55 | topic := msg.Header.Get("topic") 56 | if _, exist := p.currentTopic()[topic]; exist { 57 | fire := new(protocol.FireInfo[T]) 58 | if err := p.coder.Decode(msg.Data, fire); err != nil { 59 | p.logger.Error("failed to decode message", slog.String("data", string(msg.Data)), slog.Any("error", err), slog.String("topic", topic)) 60 | return 61 | } 62 | p.msg <- fire 63 | } 64 | msg.Ack() 65 | }) 66 | }) 67 | 68 | return p.msg 69 | } 70 | -------------------------------------------------------------------------------- /protocol/protocol.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | 7 | json "github.com/json-iterator/go" 8 | 9 | "github.com/holdno/firetower/utils" 10 | ) 11 | 12 | // type Message struct { 13 | // Ctx MessageContext `json:"c"` 14 | // Data []byte `json:"d"` 15 | // Topic string 
`json:"t"` 16 | // } 17 | 18 | // type MessageContext struct { 19 | // ID string `json:"i"` 20 | // MsgTime int64 `json:"m"` 21 | // Source string `json:"s"` 22 | // PushType int `json:"p"` 23 | // Type string `json:"t"` 24 | // } 25 | 26 | type Coder[T any] interface { 27 | Decode([]byte, *FireInfo[T]) error 28 | Encode(msg *FireInfo[T]) []byte 29 | } 30 | 31 | type WebSocketMessage struct { 32 | MessageType int 33 | Data []byte 34 | } 35 | 36 | // FireInfo 接收的消息结构体 37 | type FireInfo[T any] struct { 38 | Context FireLife `json:"c"` 39 | MessageType int `json:"t"` 40 | Message TopicMessage[T] `json:"m"` 41 | } 42 | 43 | func (f *FireInfo[T]) Copy() FireInfo[T] { 44 | return *f 45 | } 46 | 47 | func (f *FireInfo[T]) GetContext() FireLife { 48 | return f.Context 49 | } 50 | 51 | func (f *FireInfo[T]) GetMessage() TopicMessage[T] { 52 | return f.Message 53 | } 54 | 55 | type ReadOnlyFire[T any] interface { 56 | GetContext() FireLife 57 | GetMessage() TopicMessage[T] 58 | Copy() FireInfo[T] 59 | } 60 | 61 | // TopicMessage 话题信息结构体 62 | type TopicMessage[T any] struct { 63 | Topic string `json:"topic"` 64 | Data T `json:"data,omitempty"` 65 | Type FireOperation `json:"type"` 66 | } 67 | 68 | func (s *TopicMessage[T]) Json() []byte { 69 | raw, _ := json.Marshal(s) 70 | return raw 71 | } 72 | 73 | type TowerInfo interface { 74 | UserID() string 75 | ClientID() string 76 | } 77 | 78 | // FireLife 客户端推送消息的结构体 79 | type FireLife struct { 80 | ID string `json:"i"` 81 | StartTime time.Time `json:"t"` 82 | ClientID string `json:"c"` 83 | UserID string `json:"u"` 84 | Source FireSource `json:"s"` 85 | ExtMeta map[string]string `json:"e,omitempty"` 86 | } 87 | 88 | func (f *FireLife) Reset(source FireSource, clientID, userID string) { 89 | f.StartTime = time.Now() 90 | f.ID = strconv.FormatInt(utils.IDWorker().GetId(), 10) 91 | f.ClientID = clientID 92 | f.UserID = userID 93 | f.Source = source 94 | } 95 | 
-------------------------------------------------------------------------------- /example/web/firetower_v0.1.js: -------------------------------------------------------------------------------- 1 | function firetower(addr, onopen) { 2 | var ws = new WebSocket(addr); 3 | ws.onopen = onopen 4 | 5 | var _this = this 6 | 7 | this.publishKey = 1 8 | this.subscribeKey = 2 9 | this.unSubscribeKey = 3 10 | 11 | this.logopen = true // 开启log 12 | 13 | var logInfo = function(data){ 14 | console.log('[firetower] INFO', data) 15 | } 16 | 17 | this.publish = function(topic, data){ 18 | if (topic == '' || data == '') { 19 | return errorMessage('topic或data参数不能为空') 20 | } 21 | 22 | if (_this.logopen) { 23 | logInfo('publish topic:"' + topic + '", data:' + JSON.stringify(data)) 24 | } 25 | 26 | ws.send(JSON.stringify({ 27 | type: _this.publishKey, 28 | topic: topic, 29 | data: data 30 | })) 31 | return successMessage('发送成功') 32 | } 33 | 34 | this.onmessage = false 35 | ws.onmessage = function(event){ 36 | if (_this.logopen) { 37 | logInfo('new message:' + JSON.stringify(event.data)) 38 | } 39 | 40 | if (event.data == 'heartbeat') { 41 | return 42 | } 43 | 44 | if (_this.onmessage) { 45 | _this.onmessage(event) 46 | } 47 | } 48 | 49 | this.onclose = false 50 | ws.onclose = function(){ 51 | if (_this.onclose) { 52 | _this.onclose() 53 | } 54 | } 55 | 56 | this.subscribe = function(topicArr){ 57 | if (!Array.isArray(topicArr)) { 58 | topicArr = [topicArr] 59 | } 60 | 61 | if (_this.logopen) { 62 | logInfo('subscribe:"' + topicArr.join(',') + '"') 63 | } 64 | 65 | ws.send(JSON.stringify({ 66 | type: _this.subscribeKey, 67 | topic: topicArr.join(','), 68 | data: '' 69 | })) 70 | 71 | 72 | } 73 | 74 | this.unsubscribe = function(topicArr){ 75 | if (!Array.isArray(topicArr)) { 76 | topicArr = [topicArr] 77 | } 78 | 79 | if (_this.logopen) { 80 | logInfo('unSubscribe:"' + topicArr.join(',') + '"') 81 | } 82 | 83 | ws.send(JSON.stringify({ 84 | type: _this.unSubscribeKey, 85 | topic: 
topicArr.join(','), 86 | data: '' 87 | })) 88 | } 89 | 90 | function errorMessage(info){ 91 | return { 92 | type: 'error', 93 | info: info 94 | } 95 | } 96 | 97 | function successMessage(info){ 98 | return { 99 | type: 'success', 100 | info: info 101 | } 102 | } 103 | } 104 | 105 | -------------------------------------------------------------------------------- /store/redis/redis_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/go-redis/redis/v9" 11 | "github.com/holdno/firetower/utils" 12 | ) 13 | 14 | func TestRedis_HSet(t *testing.T) { 15 | rdb := redis.NewClient(&redis.Options{ 16 | Addr: "localhost:6379", 17 | Password: "", // no password set 18 | DB: 0, // use default DB 19 | }) 20 | 21 | ip, err := utils.GetIP() 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | res := rdb.HSet(context.Background(), ClusterConnKey, ip, fmt.Sprintf("%d|%d", 2315, time.Now().Unix())) 26 | if res.Err() != nil { 27 | t.Fatal(res.Err()) 28 | } 29 | 30 | gRes := rdb.HGet(context.Background(), ClusterConnKey, ip) 31 | if gRes.Err() != nil { 32 | t.Fatal(gRes.Err()) 33 | } 34 | 35 | t.Log(gRes.Val()) 36 | 37 | if gRes.Val() != fmt.Sprintf("%d|%d", 2315, time.Now().Unix()) { 38 | t.Fatal(gRes.Val()) 39 | } 40 | t.Log("successful") 41 | } 42 | 43 | func TestLua(t *testing.T) { 44 | rdb := redis.NewClient(&redis.Options{ 45 | Addr: "localhost:6379", 46 | Password: "", // no password set 47 | DB: 0, // use default DB 48 | }) 49 | res1 := rdb.Del(context.TODO(), ClusterConnKey) 50 | if res1.Err() != nil { 51 | t.Fatal(res1.Err()) 52 | } 53 | 54 | b := make([]byte, 16) 55 | binary.LittleEndian.PutUint64(b, 2315) 56 | binary.LittleEndian.PutUint64(b[8:], uint64(time.Now().Unix()-10)) 57 | res := rdb.HSet(context.TODO(), ClusterConnKey, "localhost", string(b)) 58 | if res.Err() != nil { 59 | t.Fatal(res.Err()) 60 | } 61 | 62 
| result := rdb.Eval(context.TODO(), clusterShutdownCheckerScript, []string{ClusterConnKey, fmt.Sprintf("%d", time.Now().Unix())}, 2) 63 | if result.Err() != nil { 64 | t.Fatal(result.Err()) 65 | } 66 | t.Log(result.Val()) 67 | } 68 | 69 | func TestPack(t *testing.T) { 70 | var number uint64 = 3512 71 | r := packClientConnNumberNow(number) 72 | n, err := unpackClientConnNumberNow(r) // round-trip the packed value unchanged; any extra byte makes unpack reject it 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | 77 | if n != number { 78 | t.Fatal(n, number) 79 | } 80 | t.Log("successful") 81 | } 82 | 83 | func TestLock(t *testing.T) { 84 | lockKey := "ft_cluster_master" 85 | rdb := redis.NewClient(&redis.Options{ 86 | Addr: "localhost:6379", 87 | Password: "", // no password set 88 | DB: 0, // use default DB 89 | }) 90 | locker := func(clusterID string) { 91 | for { 92 | res := rdb.SetNX(context.Background(), lockKey, clusterID, time.Second*3) 93 | if res.Val() { 94 | // One ticker per mastership term; ranging over it lets break leave the renew loop. 95 | ticker := time.NewTicker(time.Second * 1) 96 | for range ticker.C { 97 | res := rdb.Eval(context.Background(), keepMasterScript, []string{lockKey, clusterID}, 2) 98 | if res.Val() != "success" || res.Err() != nil { 99 | break 100 | } 101 | fmt.Println(clusterID, "expire") 102 | } 103 | ticker.Stop() 104 | } 105 | time.Sleep(time.Second) 106 | } 107 | } 108 | 109 | go locker("1") 110 | go locker("2") 111 | go locker("3") 112 | go locker("4") 113 | 114 | time.Sleep(time.Second * 20) 115 | } 116 | -------------------------------------------------------------------------------- /store/redis/cluster.topic.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | "sync" 7 | 8 | "github.com/holdno/firetower/store" 9 | ) 10 | 11 | type ClusterTopicStore struct { 12 | storage map[string]int64 13 | provider *RedisProvider 14 | 15 | sync.RWMutex 16 | } 17 | 18 | func newClusterTopicStore(provider *RedisProvider) store.ClusterTopicStore { 19 | s := 
&ClusterTopicStore{ 20 | storage: make(map[string]int64), 21 | provider: provider, 22 | } 23 | if err := s.init(); err != nil { 24 | panic(err) 25 | } 26 | return s 27 | } 28 | 29 | const ( 30 | ClusterTopicKeyPrefix = "ft_topics_conn_" 31 | ) 32 | 33 | func (s *ClusterTopicStore) getTopicKey(clientIP string) string { 34 | return s.provider.keyPrefix + ClusterTopicKeyPrefix + clientIP 35 | } 36 | 37 | func (s *ClusterTopicStore) init() error { 38 | res := s.provider.dbconn.Del(context.TODO(), s.getTopicKey(s.provider.clientIP)) 39 | return res.Err() 40 | } 41 | 42 | func (s *ClusterTopicStore) TopicConnAtomicAddBy(topic string, num int64) error { 43 | s.Lock() 44 | defer s.Unlock() 45 | s.storage[topic] += num 46 | res := s.provider.dbconn.HSet(context.TODO(), s.getTopicKey(s.provider.clientIP), topic, s.storage[topic]) 47 | if res.Err() != nil { 48 | s.storage[topic] -= num 49 | return res.Err() 50 | } 51 | 52 | if s.storage[topic] == 0 { 53 | res := s.provider.dbconn.HDel(context.TODO(), s.getTopicKey(s.provider.clientIP), topic) 54 | if res.Err() != nil { 55 | return res.Err() 56 | } 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func (s *ClusterTopicStore) RemoveTopic(topic string) error { 63 | res := s.provider.dbconn.HDel(context.TODO(), s.getTopicKey(s.provider.clientIP), topic) 64 | if res.Err() != nil { 65 | return res.Err() 66 | } 67 | s.Lock() 68 | defer s.Unlock() 69 | delete(s.storage, topic) 70 | return nil 71 | } 72 | 73 | func (s *ClusterTopicStore) GetTopicConnNum(topic string) (uint64, error) { 74 | members, err := s.provider.ClusterConnStore().ClusterMembers() 75 | if err != nil { 76 | return 0, err 77 | } 78 | 79 | var total uint64 80 | for _, v := range members { 81 | res := s.provider.dbconn.HGet(context.TODO(), s.getTopicKey(v), topic) 82 | if res.Err() != nil { 83 | return 0, res.Err() 84 | } 85 | num, _ := strconv.ParseUint(res.Val(), 10, 64) 86 | total += num 87 | } 88 | 89 | return total, nil 90 | } 91 | 92 | func (s *ClusterTopicStore) 
Topics() (map[string]uint64, error) { 93 | res := s.provider.dbconn.HGetAll(context.TODO(), s.getTopicKey(s.provider.clientIP)) 94 | if res.Err() != nil { 95 | return nil, res.Err() 96 | } 97 | 98 | result := make(map[string]uint64) 99 | for k, v := range res.Val() { 100 | vInt, err := strconv.ParseUint(v, 10, 64) 101 | if err != nil { 102 | return nil, err 103 | } 104 | result[k] = vInt 105 | } 106 | 107 | return result, nil 108 | } 109 | 110 | func (s *ClusterTopicStore) ClusterTopics() (map[string]uint64, error) { 111 | members, err := s.provider.ClusterConnStore().ClusterMembers() 112 | if err != nil { 113 | return nil, err 114 | } 115 | 116 | result := make(map[string]uint64) 117 | for _, v := range members { 118 | res := s.provider.dbconn.HGetAll(context.TODO(), s.getTopicKey(v)) 119 | if res.Err() != nil { 120 | return nil, res.Err() 121 | } 122 | for k, v := range res.Val() { 123 | vInt, err := strconv.ParseUint(v, 10, 64) 124 | if err != nil { 125 | return nil, err 126 | } 127 | result[k] += vInt 128 | } 129 | } 130 | 131 | return result, nil 132 | } 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

firetower logo

2 | 3 |

4 | Downloads 5 | 6 | Build Status 7 | Package Utilities 8 | Version 9 | license 10 |

11 |

Firetower

12 | firetower是一个用golang开发的分布式推送(IM)服务 13 | 14 | 完全基于websocket封装,围绕topic进行sub/pub 15 | 自身实现订阅管理服务,无需依赖redis 16 | 聊天室demo体验地址: http://chat.ojbk.io 17 | ### 可用版本 18 | go get github.com/holdno/firetower@v0.4.0 19 | ### 构成 20 | 21 | 基本服务由两点构成 22 | - topic管理服务 23 | > 详见示例 example/topicService 24 | 25 | 该服务主要作为集群环境下唯一的topic管理节点 26 | firetower一定要依赖这个管理节点才能正常工作 27 | 大型项目可以将该服务单独部署在一台独立的服务器上,小项目可以同连接层服务一起部署在一台机器上 28 | - 连接层服务(websocket服务) 29 | > 详见示例 example/websocketService 30 | 31 | websocket服务是用户基于firetower自定义开发的业务逻辑 32 | 可以通过firetower提供的回调方法来实现自己的业务逻辑 33 | (web client 在 example/web 下) 34 | ### 架构图 35 | ![beacontower](http://img.holdno.com/github/holdno/firetower_process.png) 36 | ### 接入姿势 37 | ``` golang 38 | package main 39 | 40 | import ( 41 | "fmt" 42 | "github.com/gorilla/websocket" 43 | "github.com/holdno/firetower/gateway" 44 | "github.com/holdno/snowFlakeByGo" // 这是一个分布式全局唯一id生成器 45 | "net/http" 46 | "strconv" 47 | ) 48 | 49 | var upgrader = websocket.Upgrader{ 50 | CheckOrigin: func(r *http.Request) bool { 51 | return true 52 | }, 53 | } 54 | 55 | var GlobalIdWorker *snowFlakeByGo.Worker 56 | 57 | func main() { 58 | GlobalIdWorker, _ = snowFlakeByGo.NewWorker(1) 59 | // 如果是集群环境 一定一定要给每个服务设置唯一的id 60 | // 取值范围 1-1024 61 | gateway.ClusterId = 1 62 | http.HandleFunc("/ws", Websocket) 63 | fmt.Println("websocket service start: 0.0.0.0:9999") 64 | http.ListenAndServe("0.0.0.0:9999", nil) 65 | } 66 | 67 | func Websocket(w http.ResponseWriter, r *http.Request) { 68 | // 做用户身份验证 69 | ... 
70 | // 验证成功才升级连接 71 | ws, _ := upgrader.Upgrade(w, r, nil) 72 | // 生成一个全局唯一的clientid 正常业务下这个clientid应该由前端传入 73 | id := GlobalIdWorker.GetId() 74 | tower := gateway.BuildTower(ws, strconv.FormatInt(id, 10)) // 生成一个烽火台 75 | tower.Run() 76 | } 77 | ``` 78 | ### 目前支持的回调方法 79 | - ReadHandler 收到客户端发送的消息时触发 80 | ``` golang 81 | tower := gateway.BuildTower(ws, strconv.FormatInt(id, 10)) // 创建firetower实例 82 | tower.SetReadHandler(func(fire *gateway.FireInfo) bool { // 绑定ReadHandler回调方法 83 | // fire.Message.Data 为客户端传来的信息 84 | // fire.Message.Topic 为消息传递的topic 85 | // 用户可在此做发送验证 86 | // 判断发送方是否有权限向到达方发送内容 87 | // 通过 Publish 方法将内容推送到所有订阅 fire.Message.Topic 的连接 88 | tower.Publish(fire) 89 | return true 90 | }) 91 | ``` 92 | 93 | - ReadTimeoutHandler 客户端websocket请求超时处理(生产速度高于消费速度) 94 | ``` golang 95 | tower.SetReadTimeoutHandler(func(fire *gateway.FireInfo) { 96 | fmt.Println("read timeout:", fire.Message.Type, fire.Message.Topic, fire.Message.Data) 97 | }) 98 | ``` 99 | 100 | - BeforeSubscribeHandler 客户端订阅某些topic时触发(这个时候topic还没有订阅,是before subscribe) 101 | ``` golang 102 | tower.SetBeforeSubscribeHandler(func(context *gateway.FireLife, topic []string) bool { 103 | // 这里用来判断当前用户是否允许订阅该topic 104 | return true 105 | }) 106 | ``` 107 | 108 | - SubscribeHandler 客户端完成某些topic的订阅时触发(topic已经被topicService收录并管理) 109 | ``` golang 110 | tower.SetSubscribeHandler(func(context *gateway.FireLife, topic []string) bool { 111 | // 我们给出的聊天室示例是需要用到这个回调方法 112 | // 当某个聊天室(topic)有新的订阅者,则需要通知其他已经在聊天室内的成员当前在线人数+1 113 | for _, v := range topic { 114 | num := tower.GetConnectNum(v) 115 | // 继承订阅消息的context 116 | var pushmsg = gateway.NewFireInfo(tower, context) 117 | pushmsg.Message.Topic = v 118 | pushmsg.Message.Data = []byte(fmt.Sprintf("{\"type\":\"onSubscribe\",\"data\":%d}", num)) 119 | tower.Publish(pushmsg) 120 | } 121 | return true 122 | }) 123 | ``` 124 | 125 | - UnSubscribeHandler 客户端取消订阅某些topic完成时触发 (这个回调方法没有设置before方法,目前没有想到什么场景会使用到before unsubscribe,如果有请issue联系) 126 | ``` golang 127 | 
tower.SetUnSubscribeHandler(func(context *gateway.FireLife, topic []string) bool { 128 | for _, v := range topic { 129 | num := tower.GetConnectNum(v) 130 | var pushmsg = gateway.NewFireInfo(tower, context) 131 | pushmsg.Message.Topic = v 132 | pushmsg.Message.Data = []byte(fmt.Sprintf("{\"type\":\"onUnsubscribe\",\"data\":%d}", num)) 133 | tower.Publish(pushmsg) 134 | } 135 | return true 136 | }) 137 | ``` 138 | 注意:当客户端断开websocket连接时firetower会将其在线时订阅的所有topic进行退订 会触发UnSubscribeHandler 139 | 140 | ## TODO 141 | - 运行时web看板 142 | - 提供推送相关http及grpc接口 143 | 144 | ## License 145 | [MIT](https://opensource.org/licenses/MIT) 146 | -------------------------------------------------------------------------------- /store/redis/cluster.conn.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "fmt" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | type ClusterConnStore struct { 12 | storage map[string]int64 13 | provider *RedisProvider 14 | isMaster bool 15 | 16 | keepMasterScriptSHA string 17 | clusterShutdownCheckerScriptSHA string 18 | 19 | sync.RWMutex 20 | } 21 | 22 | const ( 23 | ClusterConnKey = "ft_cluster_conn" 24 | ) 25 | 26 | func newClusterConnStore(provider *RedisProvider) *ClusterConnStore { 27 | s := &ClusterConnStore{ 28 | storage: make(map[string]int64), 29 | provider: provider, 30 | } 31 | 32 | res := s.provider.dbconn.ScriptLoad(context.TODO(), clusterShutdownCheckerScript) 33 | if res.Err() != nil { 34 | panic(res.Err()) 35 | } 36 | s.clusterShutdownCheckerScriptSHA = res.Val() 37 | 38 | res = s.provider.dbconn.ScriptLoad(context.TODO(), keepMasterScript) 39 | if res.Err() != nil { 40 | panic(res.Err()) 41 | } 42 | s.keepMasterScriptSHA = res.Val() 43 | 44 | go s.KeepClusterClear() 45 | 46 | if err := s.init(); err != nil { 47 | panic(err) 48 | } 49 | return s 50 | } 51 | 52 | func (s *ClusterConnStore) init() error { 53 | res := 
s.provider.dbconn.HDel(context.TODO(), s.provider.keyPrefix+ClusterConnKey, s.provider.clientIP) 54 | return res.Err() 55 | } 56 | 57 | func (s *ClusterConnStore) ClusterMembers() ([]string, error) { 58 | res := s.provider.dbconn.HGetAll(context.TODO(), s.provider.keyPrefix+ClusterConnKey) 59 | if res.Err() != nil { 60 | return nil, res.Err() 61 | } 62 | var list []string 63 | for k := range res.Val() { 64 | list = append(list, k) 65 | } 66 | return list, nil 67 | } 68 | 69 | func packClientConnNumberNow(num uint64) string { 70 | b := make([]byte, 16) // layout: bytes 0-7 hold num, bytes 8-15 hold the current Unix time in seconds, both little-endian 71 | binary.LittleEndian.PutUint64(b, num) 72 | binary.LittleEndian.PutUint64(b[8:], uint64(time.Now().Unix())) 73 | return string(b) 74 | } 75 | 76 | func unpackClientConnNumberNow(b string) (uint64, error) { 77 | if len(b) != 16 { // len of a string is already its byte length; reject anything that is not exactly one packed record 78 | return 0, fmt.Errorf("wrong pack data, got length %d, need 16", len(b)) 79 | } 80 | return binary.LittleEndian.Uint64([]byte(b[:8])), nil // only the leading count is returned; the trailing timestamp is ignored here 81 | } 82 | 83 | func (s *ClusterConnStore) OneClientAtomicAddBy(clientIP string, num int64) error { 84 | s.Lock() 85 | defer s.Unlock() 86 | s.storage[clientIP] += num 87 | 88 | res := s.provider.dbconn.HSet(context.TODO(), s.provider.keyPrefix+ClusterConnKey, clientIP, packClientConnNumberNow(uint64(s.storage[clientIP]))) 89 | if res.Err() != nil { 90 | s.storage[clientIP] -= num 91 | return res.Err() 92 | } 93 | return nil 94 | } 95 | 96 | func (s *ClusterConnStore) GetAllConnNum() (uint64, error) { 97 | res := s.provider.dbconn.HGetAll(context.TODO(), s.provider.keyPrefix+ClusterConnKey) 98 | if res.Err() != nil { 99 | return 0, res.Err() 100 | } 101 | var result uint64 102 | for _, v := range res.Val() { 103 | singleNumber, err := unpackClientConnNumberNow(v) 104 | if err != nil { 105 | return 0, err 106 | } 107 | result += singleNumber 108 | } 109 | return result, nil 110 | } 111 | 112 | func (s *ClusterConnStore) RemoveClient(clientIP string) error { 113 | res := s.provider.dbconn.HDel(context.TODO(), 
s.provider.keyPrefix+ClusterConnKey, clientIP) 114 | if res.Err() != nil { 115 | return res.Err() 116 | } 117 | s.Lock() 118 | defer s.Unlock() 119 | delete(s.storage, clientIP) 120 | 121 | return nil 122 | } 123 | 124 | const clusterShutdownCheckerScript = ` 125 | local key = KEYS[1] 126 | local currentTime = KEYS[2] 127 | local clientConns = redis.call("hgetall", key) 128 | local delTable = {} 129 | 130 | for i = 1, #clientConns, 2 do 131 | local number, timestamp = struct.unpack(" tostring(timestamp + 5)) 133 | then 134 | table.insert(delTable, clientConns[i]) 135 | end 136 | end 137 | 138 | return redis.call("hdel", key, unpack(delTable)) 139 | ` 140 | 141 | func (s *ClusterConnStore) KeepClusterClear() { 142 | go s.SelectMaster() 143 | for { 144 | time.Sleep(time.Second) 145 | 146 | if !s.isMaster { 147 | continue 148 | } 149 | 150 | result := s.provider.dbconn.EvalSha(context.TODO(), s.clusterShutdownCheckerScriptSHA, []string{s.provider.keyPrefix + ClusterConnKey, fmt.Sprintf("%d", time.Now().Unix())}, 2) 151 | if result.Err() != nil { 152 | // todo log 153 | continue 154 | } 155 | } 156 | } 157 | 158 | const keepMasterScript = ` 159 | local lockKey = KEYS[1] 160 | local currentID = KEYS[2] 161 | local currentMaster = redis.call('get', lockKey) 162 | local success = "fail" 163 | 164 | if (currentID == currentMaster) 165 | then 166 | redis.call('expire', lockKey, 3) 167 | success = "success" 168 | end 169 | return success 170 | ` 171 | 172 | func (s *ClusterConnStore) SelectMaster() { 173 | lockKey := s.provider.keyPrefix + "ft_cluster_master" 174 | for { 175 | res := s.provider.dbconn.SetNX(context.Background(), lockKey, s.provider.clientIP, time.Second*3) // SetNX with a 3s TTL is the election: whoever creates the key is master 176 | if res.Val() { 177 | s.isMaster = true 178 | ticker := time.NewTicker(time.Second * 1) 179 | for s.isMaster { // the break below only leaves the select; looping on isMaster ends renewal once the lock is lost so the outer loop can contend again 180 | select { 181 | case <-ticker.C: 182 | res := s.provider.dbconn.EvalSha(context.Background(), s.keepMasterScriptSHA, []string{lockKey, s.provider.clientIP}, 2) 183 | if res.Val() != "success" || 
res.Err() != nil { 184 | s.isMaster = false 185 | ticker.Stop() 186 | break 187 | } 188 | } 189 | } 190 | } 191 | time.Sleep(time.Second) 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /e2e/tower_test.go: -------------------------------------------------------------------------------- 1 | package e2e 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log/slog" 7 | "net/http" 8 | "strconv" 9 | "testing" 10 | "time" 11 | 12 | "github.com/gorilla/websocket" 13 | "github.com/holdno/firetower/config" 14 | "github.com/holdno/firetower/protocol" 15 | towersvc "github.com/holdno/firetower/service/tower" 16 | "github.com/holdno/firetower/utils" 17 | "github.com/holdno/snowFlakeByGo" 18 | jsoniter "github.com/json-iterator/go" 19 | ) 20 | 21 | var upgrader = websocket.Upgrader{ 22 | CheckOrigin: func(r *http.Request) bool { 23 | return true 24 | }, 25 | } 26 | 27 | type messageInfo struct { 28 | From string `json:"from"` 29 | Data json.RawMessage `json:"data"` 30 | Type string `json:"type"` 31 | } 32 | 33 | // GlobalIdWorker 全局唯一id生成器 34 | var GlobalIdWorker *snowFlakeByGo.Worker 35 | 36 | var _ towersvc.PusherInfo = (*SystemPusher)(nil) 37 | 38 | type SystemPusher struct { 39 | clientID string 40 | } 41 | 42 | func (s *SystemPusher) UserID() string { 43 | return "system" 44 | } 45 | func (s *SystemPusher) ClientID() string { 46 | return s.clientID 47 | } 48 | 49 | var systemer *SystemPusher 50 | 51 | const ( 52 | listenAddress = "127.0.0.1:9999" 53 | websocketPath = "/ws" 54 | ) 55 | 56 | func startTower() { 57 | // 全局唯一id生成器 58 | tm, err := towersvc.Setup[jsoniter.RawMessage](config.FireTowerConfig{ 59 | WriteChanLens: 1000, 60 | ReadChanLens: 1000, 61 | Heartbeat: 30, 62 | ServiceMode: config.SingleMode, 63 | Bucket: config.BucketConfig{ 64 | Num: 4, 65 | CentralChanCount: 100000, 66 | BuffChanCount: 1000, 67 | ConsumerNum: 2, 68 | }, 69 | }) 70 | 71 | if err != nil { 72 | panic(err) 73 | } 74 | 75 | systemer = 
&SystemPusher{ 76 | clientID: "1", 77 | } 78 | 79 | tower := &Tower{ 80 | tm: tm, 81 | } 82 | http.HandleFunc(websocketPath, tower.Websocket) 83 | tm.Logger().Info("http server start", slog.String("address", listenAddress)) 84 | if err := http.ListenAndServe(listenAddress, nil); err != nil { 85 | panic(err) 86 | } 87 | } 88 | 89 | type Tower struct { 90 | tm towersvc.Manager[jsoniter.RawMessage] 91 | } 92 | 93 | const ( 94 | bindTopic = "bindtopic" 95 | ) 96 | 97 | // Websocket http转websocket连接 并实例化firetower 98 | func (t *Tower) Websocket(w http.ResponseWriter, r *http.Request) { 99 | // 做用户身份验证 100 | 101 | // 验证成功才升级连接 102 | ws, _ := upgrader.Upgrade(w, r, nil) 103 | 104 | id := utils.IDWorker().GetId() 105 | tower, err := t.tm.BuildTower(ws, strconv.FormatInt(id, 10)) 106 | if err != nil { 107 | w.WriteHeader(http.StatusInternalServerError) 108 | w.Write([]byte(err.Error())) 109 | return 110 | } 111 | 112 | tower.SetReadHandler(func(fire protocol.ReadOnlyFire[jsoniter.RawMessage]) bool { 113 | return true 114 | }) 115 | 116 | tower.SetReceivedHandler(func(fi protocol.ReadOnlyFire[jsoniter.RawMessage]) bool { 117 | return true 118 | }) 119 | 120 | tower.SetReadTimeoutHandler(func(fire protocol.ReadOnlyFire[jsoniter.RawMessage]) { 121 | messageInfo := new(messageInfo) 122 | err := json.Unmarshal(fire.GetMessage().Data, messageInfo) 123 | if err != nil { 124 | return 125 | } 126 | messageInfo.Type = "timeout" 127 | b, _ := json.Marshal(messageInfo) 128 | tower.SendToClient(b) 129 | }) 130 | 131 | tower.SetBeforeSubscribeHandler(func(context protocol.FireLife, topic []string) bool { 132 | // 这里用来判断当前用户是否允许订阅该topic 133 | for _, v := range topic { 134 | if v == bindTopic { 135 | messageInfo := new(messageInfo) 136 | messageInfo.From = "system" 137 | messageInfo.Type = "event" 138 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "bind", "topic": "%s"}`, v)) 139 | msg, _ := json.Marshal(messageInfo) 140 | tower.SendToClient(msg) 141 | } 142 | } 143 | return true 144 | 
}) 145 | 146 | tower.SetSubscribeHandler(func(context protocol.FireLife, topic []string) { 147 | for _, v := range topic { 148 | messageInfo := new(messageInfo) 149 | messageInfo.From = "system" 150 | messageInfo.Type = "event" 151 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "subscribe", "topic": "%s"}`, v)) 152 | msg, _ := json.Marshal(messageInfo) 153 | tower.SendToClient(msg) 154 | } 155 | }) 156 | 157 | go tower.Run() 158 | 159 | ws.Close() 160 | 161 | if err := tower.SendToClientBlock([]byte("test")); err != nil { 162 | fmt.Println("send error", err) 163 | } 164 | } 165 | 166 | func buildClient(t *testing.T) *websocket.Conn { 167 | url := fmt.Sprintf("ws://%s%s", listenAddress, websocketPath) 168 | client, _, err := websocket.DefaultDialer.Dial(url, nil) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | go func() { 174 | for { 175 | _, data, err := client.ReadMessage() 176 | if err != nil { 177 | return 178 | } 179 | 180 | fmt.Println("--- client receive message ---") 181 | fmt.Println(string(data)) 182 | } 183 | }() 184 | return client 185 | } 186 | 187 | func TestBaseTower(t *testing.T) { 188 | go startTower() 189 | time.Sleep(time.Second) 190 | 191 | client1 := buildClient(t) 192 | subMsg := protocol.TopicMessage[jsoniter.RawMessage]{ 193 | Topic: bindTopic, 194 | Type: protocol.SubscribeOperation, 195 | } 196 | if err := client1.WriteMessage(websocket.TextMessage, subMsg.Json()); err != nil { 197 | t.Fatal(err) 198 | } 199 | 200 | subMsg.Topic = "testtopic" 201 | if err := client1.WriteMessage(websocket.TextMessage, subMsg.Json()); err != nil { 202 | t.Fatal(err) 203 | } 204 | 205 | client2 := buildClient(t) 206 | if err := client2.WriteMessage(websocket.BinaryMessage, subMsg.Json()); err != nil { 207 | t.Fatal(err) 208 | } 209 | 210 | testMessage := protocol.TopicMessage[jsoniter.RawMessage]{ 211 | Topic: subMsg.Topic, 212 | Type: protocol.PublishOperation, 213 | Data: jsoniter.RawMessage([]byte("\"hi\"")), 214 | } 215 | 216 | fire := 
new(protocol.FireInfo[jsoniter.RawMessage]) // 从对象池中获取消息对象 降低GC压力 217 | fire.MessageType = 1 218 | if err := jsoniter.Unmarshal(testMessage.Json(), &fire.Message); err != nil { 219 | t.Fatal(err) 220 | } 221 | client1.WriteMessage(websocket.BinaryMessage, testMessage.Json()) 222 | 223 | time.Sleep(time.Minute) 224 | } 225 | -------------------------------------------------------------------------------- /example/websocket_service/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | "net/http" 7 | _ "net/http/pprof" 8 | "strconv" 9 | "strings" 10 | "time" 11 | 12 | "github.com/gorilla/websocket" 13 | "github.com/holdno/firetower/config" 14 | "github.com/holdno/firetower/protocol" 15 | "github.com/holdno/firetower/service/tower" 16 | towersvc "github.com/holdno/firetower/service/tower" 17 | "github.com/holdno/firetower/utils" 18 | "github.com/holdno/snowFlakeByGo" 19 | json "github.com/json-iterator/go" 20 | ) 21 | 22 | var upgrader = websocket.Upgrader{ 23 | CheckOrigin: func(r *http.Request) bool { 24 | return true 25 | }, 26 | } 27 | 28 | type messageInfo struct { 29 | From string `json:"from"` 30 | Data json.RawMessage `json:"data"` 31 | Type string `json:"type"` 32 | } 33 | 34 | // GlobalIdWorker 全局唯一id生成器 35 | var GlobalIdWorker *snowFlakeByGo.Worker 36 | 37 | var _ tower.PusherInfo = (*SystemPusher)(nil) 38 | 39 | type SystemPusher struct { 40 | clientID string 41 | } 42 | 43 | func (s *SystemPusher) UserID() string { 44 | return "system" 45 | } 46 | func (s *SystemPusher) ClientID() string { 47 | return s.clientID 48 | } 49 | 50 | var systemer *SystemPusher 51 | 52 | func main() { 53 | // 全局唯一id生成器 54 | tm, err := towersvc.Setup[json.RawMessage](config.FireTowerConfig{ 55 | WriteChanLens: 1000, 56 | Heartbeat: 30, 57 | ServiceMode: config.SingleMode, 58 | Bucket: config.BucketConfig{ 59 | Num: 4, 60 | CentralChanCount: 100000, 61 | BuffChanCount: 1000, 62 | 
ConsumerNum: 1, 63 | }, 64 | // Cluster: config.Cluster{ 65 | // RedisOption: config.Redis{ 66 | // Addr: "localhost:6379", 67 | // Password: "", 68 | // }, 69 | // NatsOption: config.Nats{ 70 | // Addr: "nats://localhost:4222", 71 | // UserName: "root", 72 | // Password: "", 73 | // ServerName: "", 74 | // }, 75 | // }, 76 | }) 77 | 78 | if err != nil { 79 | panic(err) 80 | } 81 | 82 | systemer = &SystemPusher{ 83 | clientID: "1", 84 | } 85 | 86 | go func() { 87 | for { 88 | time.Sleep(time.Second * 60) 89 | f := tm.NewFire(protocol.SourceSystem, systemer) 90 | f.Message.Topic = "/chat/world" 91 | f.Message.Data = []byte(fmt.Sprintf("{\"type\":\"publish\",\"data\":\"请通过 room 命令切换聊天室\",\"from\":\"system\"}")) 92 | tm.Publish(f) 93 | } 94 | }() 95 | 96 | tower := &Tower{ 97 | tm: tm, 98 | } 99 | http.HandleFunc("/ws", tower.Websocket) 100 | tm.Logger().Info("http server start", slog.String("address", "0.0.0.0:9999")) 101 | if err := http.ListenAndServe("0.0.0.0:9999", nil); err != nil { 102 | panic(err) 103 | } 104 | } 105 | 106 | type Tower struct { 107 | tm tower.Manager[json.RawMessage] 108 | } 109 | 110 | // Websocket http转websocket连接 并实例化firetower 111 | func (t *Tower) Websocket(w http.ResponseWriter, r *http.Request) { 112 | // 做用户身份验证 113 | 114 | // 验证成功才升级连接 115 | ws, _ := upgrader.Upgrade(w, r, nil) 116 | 117 | id := utils.IDWorker().GetId() 118 | tower, err := t.tm.BuildTower(ws, strconv.FormatInt(id, 10)) 119 | if err != nil { 120 | w.WriteHeader(http.StatusInternalServerError) 121 | w.Write([]byte(err.Error())) 122 | return 123 | } 124 | 125 | tower.SetReadHandler(func(fire protocol.ReadOnlyFire[json.RawMessage]) bool { 126 | // fire将会在handler执行结束后被回收 127 | messageInfo := new(messageInfo) 128 | err := json.Unmarshal(fire.GetMessage().Data, messageInfo) 129 | if err != nil { 130 | return false 131 | } 132 | msg := strings.Trim(string(messageInfo.Data), "\"") 133 | switch true { 134 | case strings.HasPrefix(msg, "/name "): 135 | 
tower.SetUserID(strings.TrimPrefix(msg, "/name ")) // TrimPrefix drops the command once; TrimLeft takes a cutset and would also strip leading '/', 'n', 'a', 'm', 'e' or spaces from the name 136 | messageInfo.From = "system" 137 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "change_name", "name": "%s"}`, tower.UserID())) 138 | messageInfo.Type = "event" 139 | raw, _ := json.Marshal(messageInfo) 140 | tower.SendToClient(raw) 141 | return false 142 | case strings.HasPrefix(msg, "/room "): 143 | if err = tower.UnSubscribe(fire.GetContext(), tower.TopicList()); err != nil { 144 | messageInfo.From = "system" 145 | messageInfo.Type = "event" 146 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "error", "msg": "切换房间失败, %s"}`, err.Error())) 147 | raw, _ := json.Marshal(messageInfo) 148 | tower.SendToClient(raw) 149 | return false 150 | } 151 | roomCode := strings.TrimPrefix(msg, "/room ") // same cutset pitfall as above: keep the room code intact 152 | if err = tower.Subscribe(fire.GetContext(), []string{"/chat/" + roomCode}); err != nil { 153 | messageInfo.From = "system" 154 | messageInfo.Type = "event" 155 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "error", "msg": "切换房间失败, %s, 请重新尝试"}`, err.Error())) 156 | raw, _ := json.Marshal(messageInfo) 157 | tower.SendToClient(raw) 158 | return false 159 | } 160 | 161 | return false 162 | } 163 | 164 | if tower.UserID() == "" { 165 | tower.SetUserID(messageInfo.From) 166 | } 167 | messageInfo.From = tower.UserID() 168 | f := fire.Copy() 169 | f.Message.Data, _ = json.Marshal(messageInfo) 170 | 171 | // Send validation goes here: 172 | // decide whether the sender is allowed to deliver content to the recipient. 173 | tower.SendToClient(f.Message.Json()) 174 | return false 175 | }) 176 | 177 | tower.SetReceivedHandler(func(fi protocol.ReadOnlyFire[json.RawMessage]) bool { 178 | return true 179 | }) 180 | 181 | tower.SetReadTimeoutHandler(func(fire protocol.ReadOnlyFire[json.RawMessage]) { 182 | messageInfo := new(messageInfo) 183 | err := json.Unmarshal(fire.GetMessage().Data, messageInfo) 184 | if err != nil { 185 | return 186 | } 187 | messageInfo.Type = "timeout" 188 | b, _ := json.Marshal(messageInfo) 189 | tower.SendToClient(b) 190 | }) 191 | 192 | 
tower.SetBeforeSubscribeHandler(func(context protocol.FireLife, topic []string) bool { 193 | // 这里用来判断当前用户是否允许订阅该topic 194 | return true 195 | }) 196 | 197 | tower.SetSubscribeHandler(func(context protocol.FireLife, topic []string) { 198 | for _, v := range topic { 199 | if strings.HasPrefix(v, "/chat/") { 200 | roomCode := strings.TrimPrefix(v, "/chat/") 201 | messageInfo := new(messageInfo) 202 | messageInfo.From = "system" 203 | messageInfo.Type = "event" 204 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "change_room", "room": "%s"}`, roomCode)) 205 | msg, _ := json.Marshal(messageInfo) 206 | tower.SendToClient(msg) 207 | } 208 | } 209 | }) 210 | 211 | ticker := time.NewTicker(time.Millisecond * 500) 212 | go func() { 213 | topicConnCache := make(map[string]uint64) 214 | for { 215 | select { 216 | case <-tower.OnClose(): 217 | return 218 | case <-ticker.C: 219 | for _, v := range tower.TopicList() { 220 | num, err := tower.GetConnectNum(v) 221 | if err != nil { 222 | tower.Logger().Error("failed to get connect number", slog.Any("error", err)) 223 | continue 224 | } 225 | if topicConnCache[v] == num { 226 | continue 227 | } 228 | 229 | tower.SendToClient([]byte(fmt.Sprintf("{\"type\":\"onSubscribe\",\"data\":%d}", num))) 230 | topicConnCache[v] = num 231 | } 232 | } 233 | } 234 | }() 235 | 236 | tower.Run() 237 | } 238 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA= 2 | github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE= 3 | github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= 4 | github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | 
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 9 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 10 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= 11 | github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= 12 | github.com/go-redis/redis/v9 v9.0.0-beta.2 h1:ZSr84TsnQyKMAg8gnV+oawuQezeJR11/09THcWCQzr4= 13 | github.com/go-redis/redis/v9 v9.0.0-beta.2/go.mod h1:Bldcd/M/bm9HbnNPi/LUtYBSD8ttcZYBMupwMXhdU0o= 14 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 15 | github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= 16 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 17 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 18 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= 19 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 20 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 21 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 22 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 23 | github.com/holdno/snowFlakeByGo v1.0.0 h1:YbS4Jx78sF688XG8iwTRUDtIIH6+R+WWVWyS8uq8xOg= 24 | github.com/holdno/snowFlakeByGo v1.0.0/go.mod h1:aqAI0YiLKgShMi9R71i5S81IWfb0x2ghGF2w1RjyNbs= 25 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 26 | github.com/json-iterator/go v1.1.12/go.mod 
h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 27 | github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4= 28 | github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= 29 | github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= 30 | github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= 31 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 32 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 33 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 34 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 35 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 36 | github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I= 37 | github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= 38 | github.com/nats-io/nats-server/v2 v2.8.4 h1:0jQzze1T9mECg8YZEl8+WYUXb9JKluJfCBriPUtluB4= 39 | github.com/nats-io/nats-server/v2 v2.8.4/go.mod h1:8zZa+Al3WsESfmgSs98Fi06dRWLH5Bnq90m5bKD/eT4= 40 | github.com/nats-io/nats.go v1.16.0 h1:zvLE7fGBQYW6MWaFaRdsgm9qT39PJDQoju+DS8KsO1g= 41 | github.com/nats-io/nats.go v1.16.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= 42 | github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= 43 | github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= 44 | github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= 45 | github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= 46 | github.com/nxadm/tail v1.4.8 
h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= 47 | github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= 48 | github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= 49 | github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= 50 | github.com/onsi/gomega v1.20.0 h1:8W0cWlwFkflGPLltQvLRB7ZVD5HuP6ng320w2IS245Q= 51 | github.com/onsi/gomega v1.20.0/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= 52 | github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c= 53 | github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM= 54 | github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= 55 | github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= 56 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 57 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 58 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 59 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 60 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 61 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 62 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 63 | github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= 64 | github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= 65 | github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= 66 | github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= 67 | golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod 
h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= 68 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= 69 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 70 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 71 | golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA= 72 | golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= 73 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 74 | golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc= 75 | golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 76 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 77 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 78 | golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= 79 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 80 | golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= 81 | golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 82 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 83 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 84 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 85 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 86 | google.golang.org/protobuf v1.27.1 
h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= 87 | google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 88 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 89 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 90 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 91 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 92 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 93 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 94 | -------------------------------------------------------------------------------- /example/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | firetower client 6 | 7 | 8 | 239 | 240 |
241 |
242 | firetower 243 |
244 |
245 |
246 | 您的昵称:,当前房间:world,当前在线:1 247 |
248 |
249 | 切换房间 250 |
251 |
252 |
253 | 254 |
255 |
256 | 257 | 258 |
259 |
260 | 261 |
262 |
263 |
264 |
265 |
266 | 267 |
268 | 269 | 276 |
277 | 280 |
281 |
282 |
283 |
284 | 285 | 286 | 287 | 437 | 438 | -------------------------------------------------------------------------------- /service/tower/bucket.go: -------------------------------------------------------------------------------- 1 | package tower 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "os" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/avast/retry-go/v4" 12 | "github.com/gorilla/websocket" 13 | "github.com/holdno/firetower/config" 14 | "github.com/holdno/firetower/pkg/nats" 15 | "github.com/holdno/firetower/protocol" 16 | "github.com/holdno/firetower/store" 17 | "github.com/holdno/firetower/store/redis" 18 | "github.com/holdno/firetower/store/single" 19 | "github.com/holdno/firetower/utils" 20 | 21 | cmap "github.com/orcaman/concurrent-map/v2" 22 | ) 23 | 24 | type Manager[T any] interface { 25 | protocol.Pusher[T] 26 | BuildTower(ws *websocket.Conn, clientId string) (tower *FireTower[T], err error) 27 | BuildServerSideTower(clientId string) ServerSideTower[T] 28 | NewFire(source protocol.FireSource, tower PusherInfo) *protocol.FireInfo[T] 29 | GetTopics() (map[string]uint64, error) 30 | ClusterID() int64 31 | Store() stores 32 | Logger() protocol.Logger 33 | } 34 | 35 | // TowerManager 包含中心处理队列和多个bucket 36 | // bucket的作用是将一个实例的连接均匀的分布在多个bucket中来达到并发推送的目的 37 | type TowerManager[T any] struct { 38 | cfg config.FireTowerConfig 39 | bucket []*Bucket[T] 40 | centralChan chan *protocol.FireInfo[T] // 中心处理队列 41 | ip string 42 | clusterID int64 43 | timeout time.Duration 44 | 45 | stores stores 46 | logger protocol.Logger 47 | topicCounter chan counterMsg 48 | connCounter chan counterMsg 49 | 50 | coder protocol.Coder[T] 51 | protocol.Pusher[T] 52 | 53 | isClose bool 54 | closeChan chan struct{} 55 | 56 | brazier protocol.Brazier[T] 57 | 58 | onTopicCountChangedHandler func(Topic string) 59 | onConnCountChangedHandler func() 60 | } 61 | 62 | func (t *TowerManager[T]) SetTopicCountChangedHandler(f func(string)) { 63 | t.onTopicCountChangedHandler = f 
64 | } 65 | 66 | func (t *TowerManager[T]) SetConnCountChangedHandler(f func()) { 67 | t.onConnCountChangedHandler = f 68 | } 69 | 70 | type counterMsg struct { 71 | Key string 72 | Num int64 73 | } 74 | 75 | type stores interface { 76 | ClusterConnStore() store.ClusterConnStore 77 | ClusterTopicStore() store.ClusterTopicStore 78 | ClusterStore() store.ClusterStore 79 | } 80 | 81 | // Bucket 的作用是将一个实例的连接均匀的分布在多个bucket中来达到并发推送的目的 82 | type Bucket[T any] struct { 83 | tm *TowerManager[T] 84 | mu sync.RWMutex // 读写锁,可并发读不可并发读写 85 | id int64 86 | len int64 87 | // topicRelevance map[string]map[string]*FireTower // topic -> websocket clientid -> websocket conn 88 | topicRelevance cmap.ConcurrentMap[string, cmap.ConcurrentMap[string, *FireTower[T]]] 89 | BuffChan chan *protocol.FireInfo[T] // bucket的消息处理队列 90 | sendTimeout time.Duration 91 | } 92 | 93 | type TowerOption[T any] func(t *TowerManager[T]) 94 | 95 | func BuildWithPusher[T any](pusher protocol.Pusher[T]) TowerOption[T] { 96 | return func(t *TowerManager[T]) { 97 | t.Pusher = pusher 98 | } 99 | } 100 | 101 | func BuildWithMessageDisposeTimeout[T any](timeout time.Duration) TowerOption[T] { 102 | return func(t *TowerManager[T]) { 103 | if timeout > 0 { 104 | t.timeout = timeout 105 | } 106 | } 107 | } 108 | 109 | func BuildWithCoder[T any](coder protocol.Coder[T]) TowerOption[T] { 110 | return func(t *TowerManager[T]) { 111 | t.coder = coder 112 | } 113 | } 114 | 115 | func BuildWithStore[T any](store stores) TowerOption[T] { 116 | return func(t *TowerManager[T]) { 117 | t.stores = store 118 | } 119 | } 120 | 121 | func BuildWithClusterID[T any](id int64) TowerOption[T] { 122 | return func(t *TowerManager[T]) { 123 | t.clusterID = id 124 | } 125 | } 126 | 127 | func BuildWithLogger[T any](logger protocol.Logger) TowerOption[T] { 128 | return func(t *TowerManager[T]) { 129 | t.logger = logger 130 | } 131 | } 132 | 133 | func BuildFoundation[T any](cfg config.FireTowerConfig, opts ...TowerOption[T]) (Manager[T], 
error) { 134 | tm := &TowerManager[T]{ 135 | cfg: cfg, 136 | bucket: make([]*Bucket[T], cfg.Bucket.Num), 137 | centralChan: make(chan *protocol.FireInfo[T], cfg.Bucket.CentralChanCount), 138 | topicCounter: make(chan counterMsg, 20000), 139 | connCounter: make(chan counterMsg, 20000), 140 | closeChan: make(chan struct{}), 141 | brazier: newBrazier[T](), 142 | logger: slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ 143 | Level: slog.LevelDebug, 144 | })), 145 | timeout: time.Second, 146 | } 147 | 148 | for _, opt := range opts { 149 | opt(tm) 150 | } 151 | 152 | var err error 153 | if tm.ip, err = utils.GetIP(); err != nil { 154 | panic(err) 155 | } 156 | 157 | if tm.coder == nil { 158 | tm.coder = &protocol.DefaultCoder[T]{} 159 | } 160 | 161 | if tm.Pusher == nil { 162 | if cfg.ServiceMode == config.SingleMode { 163 | tm.Pusher = protocol.DefaultPusher(tm.brazier, tm.coder, tm.logger) 164 | } else { 165 | tm.Pusher = nats.MustSetupNatsPusher(cfg.Cluster.NatsOption, tm.coder, tm.logger, func() map[string]uint64 { 166 | m, err := tm.stores.ClusterTopicStore().Topics() 167 | if err != nil { 168 | tm.logger.Error("failed to get current node topics from nats", slog.Any("error", err)) 169 | return map[string]uint64{} 170 | } 171 | return m 172 | }) 173 | } 174 | } 175 | 176 | if tm.stores == nil { 177 | if cfg.ServiceMode == config.SingleMode { 178 | tm.stores, _ = single.Setup() 179 | } else { 180 | if tm.stores, err = redis.Setup(cfg.Cluster.RedisOption.Addr, 181 | cfg.Cluster.RedisOption.Password, 182 | cfg.Cluster.RedisOption.DB, tm.ip, 183 | cfg.Cluster.RedisOption.KeyPrefix); err != nil { 184 | panic(err) 185 | } 186 | } 187 | } 188 | 189 | if tm.clusterID == 0 { 190 | clusterID, err := tm.stores.ClusterStore().ClusterNumber() 191 | if err != nil { 192 | return nil, err 193 | } 194 | tm.clusterID = clusterID 195 | } 196 | utils.SetupIDWorker(tm.clusterID) 197 | 198 | for i := range tm.bucket { 199 | tm.bucket[i] = newBucket[T](tm, 
cfg.Bucket.BuffChanCount, cfg.Bucket.ConsumerNum) 200 | } 201 | 202 | go func() { 203 | var ( 204 | connCounter int64 205 | topicCounter = make(map[string]int64) 206 | ticker = time.NewTicker(time.Millisecond * 500) 207 | clusterHeartbeat = time.NewTicker(time.Second * 3) 208 | ) 209 | 210 | reportConn := func(counter int64) { 211 | err := retry.Do(func() error { 212 | return tm.stores.ClusterConnStore().OneClientAtomicAddBy(tm.ip, counter) 213 | }, retry.Attempts(3), retry.LastErrorOnly(true)) 214 | if err != nil { 215 | tm.logger.Error("failed to update the number of redis websocket connections", slog.Any("error", err)) 216 | tm.connCounter <- counterMsg{ 217 | Key: tm.ip, 218 | Num: counter, 219 | } 220 | return 221 | } 222 | if tm.onConnCountChangedHandler != nil { 223 | tm.onConnCountChangedHandler() 224 | } 225 | } 226 | 227 | reportTopicConn := func(topicCounter map[string]int64) { 228 | for t, n := range topicCounter { 229 | err := retry.Do(func() error { 230 | return tm.stores.ClusterTopicStore().TopicConnAtomicAddBy(t, n) 231 | }, retry.Attempts(3), retry.LastErrorOnly(true)) 232 | if err != nil { 233 | tm.logger.Error("failed to update the number of connections for the topic in redis", slog.Any("error", err)) 234 | tm.topicCounter <- counterMsg{ 235 | Key: t, 236 | Num: n, 237 | } 238 | return 239 | } 240 | if tm.onTopicCountChangedHandler != nil { 241 | tm.onTopicCountChangedHandler(t) 242 | } 243 | } 244 | } 245 | for { 246 | select { 247 | case msg := <-tm.connCounter: 248 | connCounter += msg.Num 249 | case msg := <-tm.topicCounter: 250 | topicCounter[msg.Key] += msg.Num 251 | case <-ticker.C: 252 | if connCounter > 0 { 253 | go reportConn(connCounter) 254 | connCounter = 0 255 | clusterHeartbeat.Reset(time.Second * 3) 256 | } 257 | if len(topicCounter) > 0 { 258 | go reportTopicConn(topicCounter) 259 | topicCounter = make(map[string]int64) 260 | } 261 | case <-clusterHeartbeat.C: 262 | reportConn(0) 263 | case <-tm.closeChan: 264 | return 265 | } 
266 | } 267 | }() 268 | 269 | // 执行中心处理器 将所有推送消息推送到bucketNum个bucket中 270 | go func() { 271 | for { 272 | select { 273 | case fire := <-tm.Receive(): 274 | for _, b := range tm.bucket { 275 | b.BuffChan <- fire 276 | } 277 | case <-tm.closeChan: 278 | return 279 | } 280 | } 281 | }() 282 | 283 | return tm, nil 284 | } 285 | 286 | func (t *TowerManager[T]) Logger() protocol.Logger { 287 | return t.logger 288 | } 289 | 290 | func (t *TowerManager[T]) Store() stores { 291 | return t.stores 292 | } 293 | 294 | func (t *TowerManager[T]) GetTopics() (map[string]uint64, error) { 295 | return t.stores.ClusterTopicStore().ClusterTopics() 296 | } 297 | 298 | func (t *TowerManager[T]) ClusterID() int64 { 299 | if t == nil { 300 | panic("firetower cluster not setup") 301 | } 302 | return t.clusterID 303 | } 304 | 305 | func newBucket[T any](tm *TowerManager[T], buff int64, consumerNum int) *Bucket[T] { 306 | b := &Bucket[T]{ 307 | tm: tm, 308 | id: getNewBucketId(), 309 | len: 0, 310 | topicRelevance: cmap.New[cmap.ConcurrentMap[string, *FireTower[T]]](), 311 | BuffChan: make(chan *protocol.FireInfo[T], buff), 312 | sendTimeout: tm.timeout, 313 | } 314 | 315 | if consumerNum == 0 { 316 | consumerNum = 1 317 | } 318 | // 每个bucket启动ConsumerNum个消费者(并发处理) 319 | for i := 0; i < consumerNum; i++ { 320 | go b.consumer() 321 | } 322 | return b 323 | } 324 | 325 | var ( 326 | bucketId int64 327 | connId uint64 328 | ) 329 | 330 | func getNewBucketId() int64 { 331 | atomic.AddInt64(&bucketId, 1) 332 | return bucketId 333 | } 334 | 335 | func getConnId() uint64 { 336 | atomic.AddUint64(&connId, 1) 337 | return connId 338 | } 339 | 340 | // GetBucket 获取一个可以分配当前连接的bucket 341 | func (t *TowerManager[T]) GetBucket(bt *FireTower[T]) (bucket *Bucket[T]) { 342 | bucket = t.bucket[bt.connID%uint64(len(t.bucket))] 343 | return 344 | } 345 | 346 | // 来自publish的消息 347 | func (b *Bucket[T]) consumer() { 348 | for { 349 | select { 350 | case fire := <-b.BuffChan: 351 | switch fire.Message.Type { 352 | 
case protocol.OfflineTopicByUserIdOperation: 353 | // 需要退订的topic和user_id 354 | // todo use api 355 | b.unSubscribeByUserId(fire) 356 | case protocol.OfflineTopicOperation: 357 | // todo use api 358 | b.unSubscribeAll(fire) 359 | case protocol.OfflineUserOperation: 360 | // todo use api 361 | b.offlineUsers(fire) 362 | default: 363 | b.push(fire) 364 | } 365 | } 366 | } 367 | } 368 | 369 | // AddSubscribe 添加当前实例中的topic->conn的订阅关系 370 | func (b *Bucket[T]) AddSubscribe(topic string, bt *FireTower[T]) { 371 | if m, ok := b.topicRelevance.Get(topic); ok { 372 | m.Set(bt.ClientID(), bt) 373 | } else { 374 | inner := cmap.New[*FireTower[T]]() 375 | inner.Set(bt.ClientID(), bt) 376 | b.topicRelevance.Set(topic, inner) 377 | } 378 | b.tm.topicCounter <- counterMsg{ 379 | Key: topic, 380 | Num: 1, 381 | } 382 | } 383 | 384 | // DelSubscribe 删除当前实例中的topic->conn的订阅关系 385 | func (b *Bucket[T]) DelSubscribe(topic string, bt *FireTower[T]) { 386 | if inner, ok := b.topicRelevance.Get(topic); ok { 387 | inner.Remove(bt.clientID) 388 | if inner.IsEmpty() { 389 | b.topicRelevance.Remove(topic) 390 | } 391 | } 392 | 393 | b.tm.topicCounter <- counterMsg{ 394 | Key: topic, 395 | Num: -1, 396 | } 397 | } 398 | 399 | // Push 桶内进行遍历push 400 | // 每个bucket有一个Push方法 401 | // 在推送时每个bucket同时调用Push方法 来达到并发推送 402 | // 该方法主要通过遍历桶中的topic->conn订阅关系来进行websocket写入 403 | func (b *Bucket[T]) push(message *protocol.FireInfo[T]) error { 404 | if m, ok := b.topicRelevance.Get(message.Message.Topic); ok { 405 | for _, v := range m.Items() { 406 | if v.isClose { 407 | continue 408 | } 409 | 410 | if v.receivedHandler != nil && !v.receivedHandler(message) { 411 | continue 412 | } 413 | 414 | if v.ws != nil { 415 | func() { 416 | ctx, cancel := context.WithTimeout(context.Background(), b.sendTimeout) 417 | defer cancel() 418 | select { 419 | case v.sendOut <- message.Message.Json(): 420 | case <-ctx.Done(): 421 | if v.sendTimeoutHandler != nil { 422 | v.sendTimeoutHandler(message) 423 | } 424 | } 425 | }() 
426 | } 427 | } 428 | } 429 | return nil 430 | } 431 | 432 | // UnSubscribeByUserId 服务端指定某个用户退订某个topic 433 | func (b *Bucket[T]) unSubscribeByUserId(fire *protocol.FireInfo[T]) error { 434 | if m, ok := b.topicRelevance.Get(fire.Message.Topic); ok { 435 | userId, ok := fire.Context.ExtMeta[protocol.SYSTEM_CMD_REMOVE_USER] 436 | if ok { 437 | for _, v := range m.Items() { 438 | if v.UserID() == userId { 439 | v.unbindTopic([]string{fire.Message.Topic}) 440 | if v.unSubscribeHandler != nil { 441 | v.unSubscribeHandler(fire.Context, []string{fire.Message.Topic}) 442 | } 443 | return nil 444 | } 445 | } 446 | return nil 447 | } 448 | } 449 | 450 | return ErrorTopicEmpty 451 | } 452 | 453 | // UnSubscribeAll 移除所有该topic的订阅关系 454 | func (b *Bucket[T]) unSubscribeAll(fire *protocol.FireInfo[T]) error { 455 | if m, ok := b.topicRelevance.Get(fire.Message.Topic); ok { 456 | for _, v := range m.Items() { 457 | v.unbindTopic([]string{fire.Message.Topic}) 458 | // 移除所有人的应该不需要执行回调方法 459 | if v.onSystemRemove != nil { 460 | v.onSystemRemove(fire.Message.Topic) 461 | } 462 | } 463 | return nil 464 | } 465 | return ErrorTopicEmpty 466 | } 467 | 468 | func (b *Bucket[T]) offlineUsers(fire *protocol.FireInfo[T]) error { 469 | if m, ok := b.topicRelevance.Get(fire.Message.Topic); ok { 470 | userId, ok := fire.Context.ExtMeta[protocol.SYSTEM_CMD_REMOVE_USER] 471 | if ok { 472 | for _, v := range m.Items() { 473 | if v.UserID() == userId { 474 | v.Close() 475 | return nil 476 | } 477 | } 478 | } 479 | } 480 | return ErrorTopicEmpty 481 | } 482 | -------------------------------------------------------------------------------- /service/tower/tower.go: -------------------------------------------------------------------------------- 1 | package tower 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log/slog" 8 | "strings" 9 | "sync" 10 | "time" 11 | 12 | "github.com/gorilla/websocket" 13 | 14 | "github.com/holdno/firetower/config" 15 | "github.com/holdno/firetower/protocol" 16 | 
json "github.com/json-iterator/go"
)

var (
	// towerPool is no longer used to recycle connections (see Close).
	// NOTE(review): kept, together with the New hook installed by Setup, in
	// case files outside this view still reference it — confirm and remove.
	towerPool sync.Pool
)

// BlockMessage is a payload for the blocking send path together with the
// channel on which the outcome of the write is reported.
type BlockMessage struct {
	sendOutMessage []byte
	response       chan error
}

// FireTower represents one client connection and carries everything that
// belongs to it.
type FireTower[T any] struct {
	tm        *TowerManager[T]
	connID    uint64 // connection id, incremented from 1 on every server
	clientID  string // client id, used by business logic
	userID    string // usually one connection is one user; lets business code identify the user
	ext       *sync.Map
	Cookie    []byte // free slot for business data related to this connection
	startTime time.Time

	logger       protocol.Logger
	timeout      time.Duration
	readIn       chan *protocol.FireInfo[T] // read queue
	sendOut      chan []byte                // send queue
	sendOutBlock chan BlockMessage
	ws           *websocket.Conn // underlying websocket connection
	topic        sync.Map        // subscribed topics
	isClose      bool            // whether this connection has been closed
	closeChan    chan struct{}   // closed when the connection is closed
	mutex        sync.Mutex      // guards isClose so closeChan is closed only once

	onConnectHandler       func() bool
	onOfflineHandler       func()
	receivedHandler        func(protocol.ReadOnlyFire[T]) bool
	readHandler            func(protocol.ReadOnlyFire[T]) bool
	readTimeoutHandler     func(protocol.ReadOnlyFire[T])
	sendTimeoutHandler     func(protocol.ReadOnlyFire[T])
	subscribeHandler       func(context protocol.FireLife, topic []string)
	unSubscribeHandler     func(context protocol.FireLife, topic []string)
	beforeSubscribeHandler func(context protocol.FireLife, topic []string) bool
	onSystemRemove         func(topic string)
}

// ClientID returns the client id given when the tower was built.
func (t *FireTower[T]) ClientID() string {
	return t.clientID
}

// UserID returns the user id set with SetUserID.
func (t *FireTower[T]) UserID() string {
	return t.userID
}

// SetUserID sets the user id of this connection.
func (t *FireTower[T]) SetUserID(id string) {
	t.userID = id
}

// Ext returns the per-connection extension map.
func (t *FireTower[T]) Ext() *sync.Map {
	return t.ext
}

// Setup initialises firetower and returns its Manager.
// It must be called before any tower is built.
func Setup[T any](cfg config.FireTowerConfig, opts ...TowerOption[T]) (Manager[T], error) {
	towerPool.New = func() any {
		return &FireTower[T]{}
	}
	// Build the service foundation.
	return BuildFoundation(cfg, opts...)
}

// BuildTower creates a tower for a websocket client. It returns an error when
// ws is nil.
func (t *TowerManager[T]) BuildTower(ws *websocket.Conn, clientId string) (tower *FireTower[T], err error) {
	if ws == nil {
		t.logger.Error("empty websocket connect")
		return nil, errors.New("empty websocket connect")
	}
	tower = buildNewTower(t, ws, clientId)
	return
}

// buildNewTower allocates a fresh FireTower for the connection.
//
// A new value is allocated for every connection instead of recycling one from
// the pool: the read/send loops and the bucket keep references to a tower
// after it is closed, so a recycled value could be closed or written to by
// goroutines belonging to the previous connection, and it would still carry
// the previous connection's user id, cookie and handlers.
func buildNewTower[T any](tm *TowerManager[T], ws *websocket.Conn, clientID string) *FireTower[T] {
	return &FireTower[T]{
		tm:           tm,
		connID:       getConnId(),
		clientID:     clientID,
		startTime:    time.Now(),
		ext:          &sync.Map{},
		readIn:       make(chan *protocol.FireInfo[T], tm.cfg.ReadChanLens),
		sendOut:      make(chan []byte, tm.cfg.WriteChanLens),
		sendOutBlock: make(chan BlockMessage),
		ws:           ws,
		closeChan:    make(chan struct{}),
		timeout:      time.Second * 3,
		logger:       tm.logger,
	}
}

// OnClose returns a channel that is closed when the tower is closed.
func (t *FireTower[T]) OnClose() chan struct{} {
	return t.closeChan
}

// Run starts the websocket client: it counts the connection, starts the read
// goroutines, calls the connect handler and then blocks in the send loop until
// the connection is closed. For a server-side tower (no websocket) it returns
// right after counting the connection.
func (t *FireTower[T]) Run() {
	t.logger.Debug("new tower builded")
	t.tm.connCounter <- counterMsg{
		Key: t.tm.ip,
		Num: 1,
	}

	if t.ws == nil {
		// server side
		return
	}

	// read messages from the websocket
	go t.readLoop()
	// process what was read
	go t.readDispose()

	if t.onConnectHandler != nil {
		ok := t.onConnectHandler()
		if !ok {
			t.Close()
		}
	}
	// write messages to the websocket
	t.sendLoop()
}

// TopicList returns the topics this connection is subscribed to.
func (t *FireTower[T]) TopicList() []string {
	var topics []string
	t.topic.Range(func(key, value any) bool {
		topics = append(topics, key.(string))
		return true
	})
	return topics
}

// bindTopic subscribes the connection to every topic it is not subscribed to
// yet and returns the topics that were actually added. The error is always nil.
func (t *FireTower[T]) bindTopic(topic []string) ([]string, error) {
	var (
		addTopic []string
	)
	bucket := t.tm.GetBucket(t)
	for _, v := range topic {
		if _, loaded := t.topic.LoadOrStore(v, struct{}{}); !loaded {
			addTopic = append(addTopic, v) // newly subscribed topic
			bucket.AddSubscribe(v, t)
		}
	}
	return addTopic, nil
}

// unbindTopic removes the given subscriptions and returns the topics that
// were actually removed.
func (t *FireTower[T]) unbindTopic(topic []string) []string {
	var delTopic []string // topics that were unsubscribed
	bucket := t.tm.GetBucket(t)
	for _, v := range topic {
		if _, loaded := t.topic.LoadAndDelete(v); loaded {
			// only unsubscribe topics the client had subscribed to
			delTopic = append(delTopic, v)
			bucket.DelSubscribe(v, t)
		}
	}
	return delTopic
}

// read returns the next message from the read queue, or ErrorClosed when the
// tower is closed or the queue has been closed.
func (t *FireTower[T]) read() (*protocol.FireInfo[T], error) {
	if t.isClose {
		return nil, ErrorClosed
	}
	fire, ok := <-t.readIn
	if !ok {
		return nil, ErrorClosed
	}
	return fire, nil
}

// Close closes the client connection and unregisters it: all topics are
// unbound (notifying the unsubscribe handler), the websocket is closed, the
// connection counter is decremented, closeChan is closed and the offline
// handler is called. Calling Close more than once is a no-op.
func (t *FireTower[T]) Close() {
	t.logger.Debug("close connect")
	t.mutex.Lock()
	defer t.mutex.Unlock()
	if t.isClose {
		return
	}
	t.isClose = true

	if topics := t.TopicList(); len(topics) > 0 {
		delTopic := t.unbindTopic(topics)

		fire := t.tm.NewFire(protocol.SourceSystem, t)
		defer t.tm.brazier.Extinguished(fire)

		if t.unSubscribeHandler != nil {
			t.unSubscribeHandler(fire.Context, delTopic)
		}
	}

	if t.ws != nil {
		t.ws.Close()
	}

	t.tm.connCounter <- counterMsg{
		Key: t.tm.ip,
		Num: -1,
	}
	close(t.closeChan)
	if t.onOfflineHandler != nil {
		t.onOfflineHandler()
	}
	// The tower is intentionally not returned to towerPool: other goroutines
	// may still hold this pointer (see buildNewTower).
	t.logger.Debug("tower closed")
}

var heartbeat = []byte("heartbeat")

// sendLoop writes queued messages and periodic heartbeats to the websocket
// until a write fails or the tower is closed, then closes the tower.
//
// The send queue is deliberately left open on exit: buckets and SendToClient
// may still be sending to it, and a send on a closed channel would panic.
func (t *FireTower[T]) sendLoop() {
	// time.NewTicker panics on a non-positive interval; with no heartbeat
	// configured the nil channel simply never fires.
	var heartbeatC <-chan time.Time
	if interval := time.Duration(t.tm.cfg.Heartbeat) * time.Second; interval > 0 {
		heartTicker := time.NewTicker(interval)
		defer heartTicker.Stop()
		heartbeatC = heartTicker.C
	}
	defer t.Close()

	for {
		select {
		case wsMsg := <-t.sendOut:
			if t.ws != nil {
				if err := t.sendToClient(wsMsg); err != nil {
					return
				}
			}
		case wsMsg := <-t.sendOutBlock:
			if t.ws == nil {
				wsMsg.response <- fmt.Errorf("send to none websocket client")
				return
			}
			if err := t.sendToClient(wsMsg.sendOutMessage); err != nil {
				wsMsg.response <- err
				return
			}
			wsMsg.response <- nil
		case <-heartbeatC:
			if err := t.sendToClient(heartbeat); err != nil {
				return
			}
		case <-t.closeChan:
			return
		}
	}
}

// readLoop reads messages from the websocket, decodes them and hands them to
// the read queue. It stops on the first read error, then closes the tower and
// the read queue.
func (t *FireTower[T]) readLoop() {
	if t.ws == nil {
		return
	}
	defer func() {
		if err := recover(); err != nil {
			t.tm.logger.Error("readloop panic", slog.Any("error", err))
		}
		t.Close()
		close(t.readIn)
	}()
	for {
		messageType, data, err := t.ws.ReadMessage()
		if err != nil { // connection lost
			return
		}
		fire := t.tm.NewFire(protocol.SourceClient, t) // taken from the object pool to reduce GC pressure
		fire.MessageType = messageType
		if err := json.Unmarshal(data, &fire.Message); err != nil {
			t.logger.Error("failed to unmarshal client data, filtered", slog.Any("error", err))
			// hand the unused message object back instead of leaking it
			t.tm.brazier.Extinguished(fire)
			continue
		}

		func() {
			ctx, cancel := context.WithTimeout(context.Background(), t.timeout)
			defer cancel()
			select {
			case t.readIn <- fire:
				// ownership moved to the consumer, which releases the fire
				return
			case <-ctx.Done():
				if t.readTimeoutHandler != nil {
					t.readTimeoutHandler(fire)
				}
				t.logger.Error("readloop timeout", slog.Any("data", data))
			case <-t.closeChan:
			}
			t.tm.brazier.Extinguished(fire)
		}()
	}
}

// readDispose takes messages sent by the client off the read queue and
// processes each one in its own goroutine. It returns once the queue is
// closed or the tower is closed.
func (t *FireTower[T]) readDispose() {
	for {
		fire, err := t.read()
		if err != nil {
			t.logger.Error("failed to read message from websocket", slog.Any("error", err))
			return
		}
		if fire != nil {
			go t.readLogic(fire)
		}
	}
}

// readLogic handles one client message: subscribe and unsubscribe operations
// change the topic bindings (the topic field is a comma separated list), any
// other message is published unless the read handler rejects it. The fire is
// released when the function returns.
func (t *FireTower[T]) readLogic(fire *protocol.FireInfo[T]) error {
	defer t.tm.brazier.Extinguished(fire)
	if t.isClose {
		return nil
	}
	if fire.Message.Topic == "" {
		t.logger.Error("the obtained topic is empty. this message will be filtered")
		return fmt.Errorf("%s:topic is empty, ClintId:%s, UserId:%s", fire.Message.Type, t.ClientID(), t.UserID())
	}

	switch fire.Message.Type {
	case protocol.SubscribeOperation: // client subscribes to topics
		addTopics := strings.Split(fire.Message.Topic, ",")
		if err := t.Subscribe(fire.Context, addTopics); err != nil {
			t.logger.Error("failed to subscribe topics", slog.Any("topics", addTopics), slog.Any("error", err))
			// TODO metrics
			return err
		}
	case protocol.UnSubscribeOperation: // client unsubscribes from topics
		delTopic := strings.Split(fire.Message.Topic, ",")
		if err := t.UnSubscribe(fire.Context, delTopic); err != nil {
			t.logger.Error("failed to unsubscribe topics", slog.Any("topics", delTopic), slog.Any("error", err))
			// TODO metrics
			return err
		}
	default:
		if t.readHandler != nil && !t.readHandler(fire) {
			return nil
		}
		t.Publish(fire)
	}
	return nil
}

// Subscribe binds the connection to the given topics. If a before-subscribe
// handler is set and returns false nothing is bound. The subscribe handler is
// called with the topics that were newly added.
func (t *FireTower[T]) Subscribe(context protocol.FireLife, topics []string) error {
	if t.beforeSubscribeHandler != nil {
		ok := t.beforeSubscribeHandler(context, topics)
		if !ok {
			return nil
		}
	}
	addTopics, err := t.bindTopic(topics)
	if err != nil {
		return err
	}
	if t.subscribeHandler != nil {
		t.subscribeHandler(context, addTopics)
	}
	return nil
}

// UnSubscribe removes the given subscriptions and calls the unsubscribe
// handler with the topics that were actually removed.
func (t *FireTower[T]) UnSubscribe(context protocol.FireLife, topics []string) error {
	delTopics := t.unbindTopic(topics)

	if t.unSubscribeHandler != nil {
		t.unSubscribeHandler(context, delTopics)
	}
	return nil
}

// Publish pushes the message through the manager's pusher. Every tower built
// by BuildTower can use it to publish.
func (t *FireTower[T]) Publish(fire *protocol.FireInfo[T]) error {
	err := t.tm.Publish(fire)
	if err != nil {
		t.logger.Error("failed to publish message", slog.Any("error", err))
		return err
	}
	return nil
}

// SendToClient queues b for this connection only. Use it for pushes that
// target just the current client. The message is dropped when the tower is
// closed, including when it is closed while the call waits for queue space.
func (t *FireTower[T]) SendToClient(b []byte) {
	// Only the closed check is done under the mutex: blocking on the queue
	// while holding it would keep Close from ever acquiring the lock.
	t.mutex.Lock()
	closed := t.isClose
	t.mutex.Unlock()
	if closed {
		return
	}

	select {
	case t.sendOut <- b:
	case <-t.closeChan:
	}
}

// SendToClientBlock hands b to the send loop and waits for the result of the
// websocket write. It returns an error when the tower is closed before the
// send loop accepted the message.
func (t *FireTower[T]) SendToClientBlock(b []byte) error {
	t.mutex.Lock()
	closed := t.isClose
	t.mutex.Unlock()
	if closed {
		return fmt.Errorf("send to closed firetower")
	}

	msg := BlockMessage{
		sendOutMessage: b,
		response:       make(chan error, 1),
	}
	select {
	case t.sendOutBlock <- msg:
	case <-t.closeChan:
		return fmt.Errorf("send to closed firetower")
	}
	// The send loop reports exactly one result for every message it accepts.
	return <-msg.response
}

// sendToClient writes b to the websocket as a text message. It returns
// ErrorServerSideMode for a tower without websocket and ErrorClosed for a
// closed tower.
func (t *FireTower[T]) sendToClient(b []byte) error {
	if t.ws == nil {
		return ErrorServerSideMode
	}
	if t.isClose {
		return ErrorClosed
	}
	return t.ws.WriteMessage(websocket.TextMessage, b)
}

// SetOnConnectHandler sets the handler called by Run once the connection is
// established; returning false closes the connection.
func (t *FireTower[T]) SetOnConnectHandler(fn func() bool) {
	t.onConnectHandler = fn
}

// SetOnOfflineHandler sets the handler called when the connection is closed.
func (t *FireTower[T]) SetOnOfflineHandler(fn func()) {
	t.onOfflineHandler = fn
}

// SetReceivedHandler sets the handler consulted before a pushed message is
// queued for this connection; returning false skips the message.
func (t *FireTower[T]) SetReceivedHandler(fn func(protocol.ReadOnlyFire[T]) bool) {
	t.receivedHandler = fn
}

// SetReadHandler sets the handler called when the client publishes a message;
// returning false stops the message from being published.
func (t *FireTower[T]) SetReadHandler(fn func(protocol.ReadOnlyFire[T]) bool) {
	t.readHandler = fn
}

// SetSubscribeHandler sets the handler called after topics were subscribed.
func (t *FireTower[T]) SetSubscribeHandler(fn func(context protocol.FireLife, topic []string)) {
	t.subscribeHandler = fn
}

// SetUnSubscribeHandler sets the handler called after topics were
// unsubscribed.
func (t *FireTower[T]) SetUnSubscribeHandler(fn func(context protocol.FireLife, topic []string)) {
	t.unSubscribeHandler = fn
}

// SetBeforeSubscribeHandler sets the handler called before topics are
// subscribed; returning false rejects the subscription.
func (t *FireTower[T]) SetBeforeSubscribeHandler(fn func(context protocol.FireLife, topic []string) bool) {
	t.beforeSubscribeHandler = fn
}

// SetReadTimeoutHandler sets the handler called when a message could not be
// placed on the read queue in time, i.e. the queue is full because messages
// are produced faster than they are consumed.
func (t *FireTower[T]) SetReadTimeoutHandler(fn func(protocol.ReadOnlyFire[T])) {
	t.readTimeoutHandler = fn
}

// SetSendTimeoutHandler sets the handler called when a pushed message could
// not be placed on this connection's send queue in time.
func (t *FireTower[T]) SetSendTimeoutHandler(fn func(protocol.ReadOnlyFire[T])) {
	t.sendTimeoutHandler = fn
}

// SetOnSystemRemove sets the handler called when the system removes one of
// this connection's topic subscriptions.
func (t *FireTower[T]) SetOnSystemRemove(fn func(topic string)) {
	t.onSystemRemove = fn
}

// GetConnectNum returns the number of connections subscribed to the topic as
// reported by the cluster topic store.
func (t *FireTower[T]) GetConnectNum(topic string) (uint64, error) {
	number, err := t.tm.stores.ClusterTopicStore().GetTopicConnNum(topic)
	if err != nil {
		return 0, err
	}
	return number, nil
}

// Logger returns the tower's logger.
func (t *FireTower[T]) Logger() protocol.Logger {
	return t.logger
}

--------------------------------------------------------------------------------