├── .gitignore
├── example
│   ├── websocket_service
│   │   ├── macBuildLiunx.sh
│   │   ├── fireTower.toml
│   │   └── main.go
│   └── web
│       ├── firetower_v0.1.js
│       └── index.html
├── config
│   ├── topicmanage.toml
│   ├── fireTower.toml
│   └── config.go
├── store
│   ├── single
│   │   ├── cluster.go
│   │   ├── cluster.conn.go
│   │   ├── cluster.topic.go
│   │   └── provider.go
│   ├── redis
│   │   ├── cluster.go
│   │   ├── provider.go
│   │   ├── redis_test.go
│   │   ├── cluster.topic.go
│   │   └── cluster.conn.go
│   └── store.go
├── protocol
│   ├── errorinfo.go
│   ├── coder.go
│   ├── const.go
│   ├── pusher.go
│   └── protocol.go
├── service
│   └── tower
│       ├── error_info.go
│       ├── brazier_test.go
│       ├── serverside.go
│       ├── brazier.go
│       ├── bucket.go
│       └── tower.go
├── pkg
│   └── nats
│       ├── nats_test.go
│       └── nats.go
├── utils
│   └── utils.go
├── LICENSE
├── go.mod
├── README.md
├── e2e
│   └── tower_test.go
└── go.sum
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | experiment
--------------------------------------------------------------------------------
/example/websocket_service/macBuildLiunx.sh:
--------------------------------------------------------------------------------
1 | CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build
2 |
--------------------------------------------------------------------------------
/config/topicmanage.toml:
--------------------------------------------------------------------------------
1 | # internal socket communication configuration
2 |
3 | heartbeat = 5 # heartbeat interval in seconds
4 | servicetimeout = 8 # server-side timeout in seconds; the connection is closed if no client heartbeat arrives within this window
5 |
6 | [socket]
7 | port = 6666
--------------------------------------------------------------------------------
/store/single/cluster.go:
--------------------------------------------------------------------------------
1 | package single
2 |
3 | type ClusterStore struct {
4 | }
5 |
6 | func (s *ClusterStore) ClusterNumber() (int64, error) {
7 | return 1, nil
8 | }
9 |
--------------------------------------------------------------------------------
/protocol/errorinfo.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import "errors"
4 |
5 | var (
6 | // ErrorClose returned when the firetower connection has been closed
7 | ErrorClose = errors.New("firetower is collapsed")
8 | // ErrorBlock returned when the send channel is blocked (network congestion)
9 | ErrorBlock = errors.New("network congestion")
10 | )
11 |
--------------------------------------------------------------------------------
/example/websocket_service/fireTower.toml:
--------------------------------------------------------------------------------
1 | # connection configuration
2 |
3 | chanLens = 1000 # channel buffer size
4 | heartbeat = 30 # heartbeat interval in seconds
5 |
6 | [grpc]
7 | address = "localhost:6667"
8 |
9 | [bucket]
10 | Num = 4 # number of buckets to start
11 | CentralChanCount = 100000 # capacity of the server-wide central message channel
12 | BuffChanCount = 10000 # message channel capacity of each bucket
13 | ConsumerNum = 2 # number of consumers per bucket pushing messages to sockets concurrently
--------------------------------------------------------------------------------
/config/fireTower.toml:
--------------------------------------------------------------------------------
1 | # connection configuration
2 |
3 | chanLens = 1000 # channel buffer size
4 | heartbeat = 600 # heartbeat interval in seconds
5 |
6 | topicServiceAddr = "0.0.0.0:6666"
7 |
8 | clusterMode = "single"
9 |
10 | [bucket]
11 | num = 4 # number of buckets to start
12 | centralChanCount = 100000 # capacity of the server-wide central message channel
13 | buffChanCount = 1000 # message channel capacity of each bucket
14 | consumerNum = 32 # number of consumers per bucket pushing messages to sockets concurrently
--------------------------------------------------------------------------------
/protocol/coder.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import "github.com/vmihailenco/msgpack/v5"
4 |
5 | type DefaultCoder[T any] struct{}
6 |
7 | func (c *DefaultCoder[T]) Decode(data []byte, fire *FireInfo[T]) error {
8 | return msgpack.Unmarshal(data, fire)
9 | }
10 |
11 | func (c *DefaultCoder[T]) Encode(msg *FireInfo[T]) []byte {
12 | b, _ := msgpack.Marshal(msg)
13 | return b
14 | }
15 |
--------------------------------------------------------------------------------
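Note: DefaultCoder encodes fires with msgpack, but any type that satisfies the Coder[T] interface declared in protocol/protocol.go can be swapped in when building a pusher. A minimal sketch of a hypothetical JSON-based coder (not part of the repository, shown only for illustration):

``` golang
package coderexample

import (
	"github.com/holdno/firetower/protocol"

	jsoniter "github.com/json-iterator/go"
)

// JSONCoder is a hypothetical drop-in alternative to protocol.DefaultCoder
// that uses JSON instead of msgpack, e.g. to make wire payloads easy to read.
type JSONCoder[T any] struct{}

func (c *JSONCoder[T]) Decode(data []byte, fire *protocol.FireInfo[T]) error {
	return jsoniter.Unmarshal(data, fire)
}

func (c *JSONCoder[T]) Encode(msg *protocol.FireInfo[T]) []byte {
	b, _ := jsoniter.Marshal(msg)
	return b
}

// compile-time check that JSONCoder satisfies protocol.Coder
var _ protocol.Coder[jsoniter.RawMessage] = (*JSONCoder[jsoniter.RawMessage])(nil)
```

JSON is easier to inspect on the wire; msgpack, as used by DefaultCoder, stays smaller and faster.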
/service/tower/error_info.go:
--------------------------------------------------------------------------------
1 | package tower
2 |
3 | import "errors"
4 |
5 | var (
6 | // ErrorClosed returned when the gateway connection has already been closed
7 | ErrorClosed = errors.New("firetower is collapsed")
8 | // ErrorTopicEmpty returned when the topic does not exist
9 | ErrorTopicEmpty = errors.New("topic is empty")
10 | // ErrorServerSideMode a server-side tower can not send messages to itself
11 | ErrorServerSideMode = errors.New("server side tower can not send to self")
12 | )
13 |
--------------------------------------------------------------------------------
/store/redis/cluster.go:
--------------------------------------------------------------------------------
1 | package redis
2 |
3 | import "context"
4 |
5 | type ClusterStore struct {
6 | provider *RedisProvider
7 | }
8 |
9 | func newClusterStore(provider *RedisProvider) *ClusterStore {
10 | return &ClusterStore{
11 | provider: provider,
12 | }
13 | }
14 |
15 | const (
16 | ClusterKey = "firetower_cluster_number"
17 | )
18 |
19 | func (s *ClusterStore) ClusterNumber() (int64, error) {
20 | res := s.provider.dbconn.Incr(context.TODO(), s.provider.keyPrefix+ClusterKey)
21 | return res.Val(), res.Err()
22 | }
23 |
--------------------------------------------------------------------------------
/store/store.go:
--------------------------------------------------------------------------------
1 | package store
2 |
3 | type ClusterConnStore interface {
4 | OneClientAtomicAddBy(clientIP string, num int64) error
5 | GetAllConnNum() (uint64, error)
6 | RemoveClient(clientIP string) error
7 | ClusterMembers() ([]string, error)
8 | }
9 |
10 | type ClusterTopicStore interface {
11 | TopicConnAtomicAddBy(topic string, num int64) error
12 | GetTopicConnNum(topic string) (uint64, error)
13 | RemoveTopic(topic string) error
14 | Topics() (map[string]uint64, error)
15 | ClusterTopics() (map[string]uint64, error)
16 | }
17 |
18 | type ClusterStore interface {
19 | ClusterNumber() (int64, error)
20 | }
21 |
--------------------------------------------------------------------------------
/pkg/nats/nats_test.go:
--------------------------------------------------------------------------------
1 | package nats
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "testing"
7 | "time"
8 |
9 | "github.com/nats-io/nats.go"
10 | )
11 |
12 | func TestNats(t *testing.T) {
13 | nc, err := nats.Connect("nats://localhost:4222", nats.Name("FireTower"), nats.UserInfo("firetower", "firetower"))
14 | if err != nil {
15 | log.Fatal(err)
16 | }
17 |
18 | topic := "chat.world."
19 |
20 | go func() {
21 | nc.Subscribe(topic+">", func(msg *nats.Msg) {
22 | fmt.Println("received", string(msg.Data))
23 | })
24 | }()
25 |
26 | time.Sleep(time.Second)
27 | if err := nc.Publish(topic+"1", []byte(`{"message":"hello"}`)); err != nil {
28 | t.Fatal(err)
29 | }
30 |
31 | time.Sleep(time.Second * 10)
32 | }
33 |
--------------------------------------------------------------------------------
/utils/utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "net"
5 |
6 | "github.com/holdno/snowFlakeByGo"
7 | )
8 |
9 | var (
10 | // idWorker the globally unique id generator instance
11 | idWorker *snowFlakeByGo.Worker
12 | )
13 |
14 | func SetupIDWorker(clusterID int64) {
15 | idWorker, _ = snowFlakeByGo.NewWorker(clusterID)
16 | }
17 |
18 | func IDWorker() *snowFlakeByGo.Worker {
19 | return idWorker
20 | }
21 |
22 | // GetIP returns the IP address of the current server
23 | func GetIP() (string, error) {
24 | addrs, err := net.InterfaceAddrs()
25 | if err != nil {
26 | return "", nil
27 | }
28 | for _, a := range addrs {
29 | if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
30 | if ipnet.IP.To4() != nil {
31 | return ipnet.IP.String(), nil
32 | }
33 | }
34 | }
35 | return "127.0.0.1", nil
36 | }
37 |
--------------------------------------------------------------------------------
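A short usage sketch of the utils package. The snowflake worker has to be initialized before protocol.FireLife.Reset runs, since Reset reads utils.IDWorker(); note that service/tower.Setup may already do this internally, so this only illustrates the API:

``` golang
package main

import (
	"fmt"

	"github.com/holdno/firetower/utils"
)

func main() {
	// initialize the global snowflake worker (cluster id range is 1-1024)
	utils.SetupIDWorker(1)

	id := utils.IDWorker().GetId() // globally unique int64 id
	ip, err := utils.GetIP()       // first non-loopback IPv4, used as the cluster member key
	if err != nil {
		panic(err)
	}
	fmt.Println(id, ip)
}
```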
/service/tower/brazier_test.go:
--------------------------------------------------------------------------------
1 | package tower
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "sync"
7 | "sync/atomic"
8 | "testing"
9 | )
10 |
11 | func TestConnID(t *testing.T) {
12 | connId = math.MaxUint64
13 |
14 | atomic.AddUint64(&connId, 1)
15 | if connId != 0 {
16 | t.Fatal("wrong connid")
17 | }
18 | atomic.AddUint64(&connId, 1)
19 | if connId != 1 {
20 | t.Fatal("wrong connid")
21 | }
22 | }
23 |
24 | func TestSyncPool(t *testing.T) {
25 | var pool = sync.Pool{
26 | New: func() interface{} {
27 | fmt.Println("new")
28 | return &struct{}{}
29 | },
30 | }
31 | n := &struct{ Name string }{
32 | Name: "hhhh",
33 | }
34 |
35 | pool.Put(n)
36 | pool.Put(n)
37 | pool.Put(n)
38 | pool.Put(n)
39 |
40 | a := pool.Get()
41 | b := pool.Get()
42 |
43 | a.(*struct{ Name string }).Name = "123"
44 | fmt.Println(b.(*struct{ Name string }).Name)
45 |
46 | pool.Get()
47 | pool.Get()
48 | pool.Get()
49 | }
50 |
--------------------------------------------------------------------------------
/store/single/cluster.conn.go:
--------------------------------------------------------------------------------
1 | package single
2 |
3 | import (
4 | "sync"
5 | )
6 |
7 | type ClusterConnStore struct {
8 | storage map[string]int64
9 |
10 | sync.RWMutex
11 | }
12 |
13 | func (s *ClusterConnStore) ClusterMembers() ([]string, error) {
14 | var list []string
15 | for k := range s.storage {
16 | list = append(list, k)
17 | }
18 | return list, nil
19 | }
20 |
21 | func (s *ClusterConnStore) OneClientAtomicAddBy(clientIP string, num int64) error {
22 | s.Lock()
23 | defer s.Unlock()
24 | s.storage[clientIP] += num
25 | return nil
26 | }
27 |
28 | func (s *ClusterConnStore) GetAllConnNum() (uint64, error) {
29 | s.RLock()
30 | defer s.RUnlock()
31 |
32 | var result int64
33 | for _, v := range s.storage {
34 | result += v
35 | }
36 | return uint64(result), nil
37 | }
38 |
39 | func (s *ClusterConnStore) RemoveClient(clientIP string) error {
40 | s.Lock()
41 | defer s.Unlock()
42 |
43 | delete(s.storage, clientIP)
44 | return nil
45 | }
46 |
--------------------------------------------------------------------------------
/store/single/cluster.topic.go:
--------------------------------------------------------------------------------
1 | package single
2 |
3 | import (
4 | "sync"
5 | )
6 |
7 | type ClusterTopicStore struct {
8 | storage map[string]int64
9 |
10 | sync.RWMutex
11 | }
12 |
13 | func (s *ClusterTopicStore) TopicConnAtomicAddBy(topic string, num int64) error {
14 | s.Lock()
15 | defer s.Unlock()
16 |
17 | s.storage[topic] += num
18 | return nil
19 | }
20 |
21 | func (s *ClusterTopicStore) RemoveTopic(topic string) error {
22 | s.Lock()
23 | defer s.Unlock()
24 | delete(s.storage, topic)
25 | return nil
26 | }
27 |
28 | func (s *ClusterTopicStore) GetTopicConnNum(topic string) (uint64, error) {
29 | s.RLock()
30 | defer s.RUnlock()
31 | return uint64(s.storage[topic]), nil
32 | }
33 |
34 | func (s *ClusterTopicStore) Topics() (map[string]uint64, error) {
35 | result := make(map[string]uint64)
36 | for k, v := range result {
37 | result[k] = uint64(v)
38 | }
39 | return result, nil
40 | }
41 |
42 | func (s *ClusterTopicStore) ClusterTopics() (map[string]uint64, error) {
43 | return s.Topics()
44 | }
45 |
--------------------------------------------------------------------------------
/store/single/provider.go:
--------------------------------------------------------------------------------
1 | package single
2 |
3 | import "github.com/holdno/firetower/store"
4 |
5 | var provider *SingleProvider
6 |
7 | type SingleProvider struct {
8 | clusterConnStore store.ClusterConnStore
9 | clusterTopicStore store.ClusterTopicStore
10 | clusterStore store.ClusterStore
11 | }
12 |
13 | func Setup() (*SingleProvider, error) {
14 | provider = &SingleProvider{
15 | clusterConnStore: &ClusterConnStore{
16 | storage: make(map[string]int64),
17 | },
18 | clusterTopicStore: &ClusterTopicStore{
19 | storage: make(map[string]int64),
20 | },
21 | clusterStore: &ClusterStore{},
22 | }
23 | return provider, nil
24 | }
25 |
26 | func Provider() *SingleProvider {
27 | return provider
28 | }
29 |
30 | func (s *SingleProvider) ClusterConnStore() store.ClusterConnStore {
31 | return s.clusterConnStore
32 | }
33 |
34 | func (s *SingleProvider) ClusterTopicStore() store.ClusterTopicStore {
35 | return s.clusterTopicStore
36 | }
37 |
38 | func (s *SingleProvider) ClusterStore() store.ClusterStore {
39 | return s.clusterStore
40 | }
41 |
--------------------------------------------------------------------------------
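A minimal sketch of using the in-memory single-node provider through the store interfaces defined in store/store.go (illustration only; the tower package normally drives these calls):

``` golang
package main

import (
	"fmt"

	"github.com/holdno/firetower/store/single"
)

func main() {
	p, _ := single.Setup() // in-memory provider; the error is always nil here

	// count one new connection on this node and one subscriber on a topic
	_ = p.ClusterConnStore().OneClientAtomicAddBy("127.0.0.1", 1)
	_ = p.ClusterTopicStore().TopicConnAtomicAddBy("/chat/world", 1)

	n, _ := p.ClusterTopicStore().GetTopicConnNum("/chat/world")
	fmt.Println(n) // 1
}
```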
/service/tower/serverside.go:
--------------------------------------------------------------------------------
1 | package tower
2 |
3 | import (
4 | "github.com/holdno/firetower/protocol"
5 | )
6 |
7 | type ServerSideTower[T any] interface {
8 | SetOnConnectHandler(fn func() bool)
9 | SetOnOfflineHandler(fn func())
10 | SetReceivedHandler(fn func(protocol.ReadOnlyFire[T]) bool)
11 | SetSubscribeHandler(fn func(context protocol.FireLife, topic []string))
12 | SetUnSubscribeHandler(fn func(context protocol.FireLife, topic []string))
13 | SetBeforeSubscribeHandler(fn func(context protocol.FireLife, topic []string) bool)
14 | SetOnSystemRemove(fn func(topic string))
15 | GetConnectNum(topic string) (uint64, error)
16 | Publish(fire *protocol.FireInfo[T]) error
17 | Subscribe(context protocol.FireLife, topics []string) error
18 | UnSubscribe(context protocol.FireLife, topics []string) error
19 | Logger() protocol.Logger
20 | TopicList() []string
21 | Run()
22 | Close()
23 | OnClose() chan struct{}
24 | }
25 |
26 | // BuildServerSideTower creates a server-side tower that is not bound to a websocket connection
27 | func (t *TowerManager[T]) BuildServerSideTower(clientId string) ServerSideTower[T] {
28 | return buildNewTower(t, nil, clientId)
29 | }
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 wby
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/holdno/firetower
2 |
3 | go 1.22
4 |
5 | require (
6 | github.com/gorilla/websocket v1.5.0
7 | github.com/holdno/snowFlakeByGo v1.0.0
8 | github.com/json-iterator/go v1.1.12
9 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
10 | github.com/modern-go/reflect2 v1.0.2 // indirect
11 | github.com/pelletier/go-toml v1.9.4
12 | )
13 |
14 | require (
15 | github.com/avast/retry-go/v4 v4.6.0
16 | github.com/go-redis/redis/v9 v9.0.0-beta.2
17 | github.com/nats-io/nats.go v1.16.0
18 | github.com/orcaman/concurrent-map/v2 v2.0.1
19 | github.com/vmihailenco/msgpack/v5 v5.3.5
20 | )
21 |
22 | require (
23 | github.com/cespare/xxhash/v2 v2.1.2 // indirect
24 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
25 | github.com/golang/protobuf v1.5.2 // indirect
26 | github.com/nats-io/nats-server/v2 v2.8.4 // indirect
27 | github.com/nats-io/nkeys v0.3.0 // indirect
28 | github.com/nats-io/nuid v1.0.1 // indirect
29 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
30 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect
31 | google.golang.org/protobuf v1.27.1 // indirect
32 | )
33 |
--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/pelletier/go-toml"
7 | )
8 |
9 | // FireTowerConfig per-connection configuration
10 | type FireTowerConfig struct {
11 | ReadChanLens int
12 | WriteChanLens int
13 | Heartbeat int
14 | ServiceMode string
15 | Bucket BucketConfig
16 | Cluster Cluster
17 | }
18 |
19 | type Cluster struct {
20 | RedisOption Redis
21 | NatsOption Nats
22 | }
23 |
24 | type Redis struct {
25 | KeyPrefix string
26 | Addr string
27 | Password string
28 | DB int
29 | }
30 |
31 | type Nats struct {
32 | Addr string
33 | ServerName string
34 | UserName string
35 | Password string
36 | SubjectPrefix string
37 | }
38 |
39 | type BucketConfig struct {
40 | Num int
41 | CentralChanCount int64
42 | BuffChanCount int64
43 | ConsumerNum int
44 | }
45 |
46 | const (
47 | SingleMode = "single"
48 | ClusterMode = "cluster"
49 | DefaultServiceMode = SingleMode
50 | )
51 |
52 | var (
53 | // ConfigTree holds the loaded configuration
54 | ConfigTree *toml.Tree
55 | )
56 |
57 | func loadConfig(path string) {
58 | var (
59 | err error
60 | )
61 | if ConfigTree, err = toml.LoadFile(path); err != nil {
62 | fmt.Println("config load failed:", err)
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
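The example services build FireTowerConfig directly in code instead of loading the TOML files above. A sketch mirroring config/fireTower.toml (mapping the single chanLens value onto both ReadChanLens and WriteChanLens is an assumption, following the e2e test):

``` golang
package main

import "github.com/holdno/firetower/config"

func main() {
	// values mirror config/fireTower.toml; pass this struct to the tower Setup call
	cfg := config.FireTowerConfig{
		ReadChanLens:  1000,
		WriteChanLens: 1000,
		Heartbeat:     600, // seconds
		ServiceMode:   config.SingleMode,
		Bucket: config.BucketConfig{
			Num:              4,
			CentralChanCount: 100000,
			BuffChanCount:    1000,
			ConsumerNum:      32,
		},
	}
	_ = cfg
}
```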
/service/tower/brazier.go:
--------------------------------------------------------------------------------
1 | package tower
2 |
3 | import (
4 | "sync"
5 |
6 | "github.com/holdno/firetower/protocol"
7 | )
8 |
9 | func newBrazier[T any]() *brazier[T] {
10 | return &brazier[T]{
11 | pool: &sync.Pool{
12 | New: func() interface{} {
13 | return &protocol.FireInfo[T]{
14 | Context: protocol.FireLife{},
15 | Message: protocol.TopicMessage[T]{},
16 | }
17 | },
18 | },
19 | }
20 | }
21 |
22 | type brazier[T any] struct {
23 | len int
24 | pool *sync.Pool
25 | }
26 |
27 | func (b *brazier[T]) Extinguished(fire *protocol.FireInfo[T]) {
28 | if fire == nil || b.len > 100000 {
29 | return
30 | }
31 | var empty T
32 | b.len++
33 | fire.MessageType = 0
34 | fire.Message.Data = empty
35 | fire.Message.Topic = ""
36 | fire.Message.Type = 0
37 | b.pool.Put(fire)
38 | }
39 |
40 | func (b *brazier[T]) LightAFire() *protocol.FireInfo[T] {
41 | b.len--
42 | return b.pool.Get().(*protocol.FireInfo[T])
43 | }
44 |
45 | type PusherInfo interface {
46 | ClientID() string
47 | UserID() string
48 | }
49 |
50 | func (t *TowerManager[T]) NewFire(source protocol.FireSource, tower PusherInfo) *protocol.FireInfo[T] {
51 | f := t.brazier.LightAFire()
52 | f.Message.Type = protocol.PublishOperation
53 | f.Context.Reset(source, tower.ClientID(), tower.UserID())
54 | return f
55 | }
56 |
--------------------------------------------------------------------------------
/protocol/const.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | type FireSource uint8
4 |
5 | const (
6 | SourceClient FireSource = 1
7 | SourceLogic FireSource = 2
8 | SourceSystem FireSource = 3
9 | )
10 |
11 | func (s FireSource) String() string {
12 | switch s {
13 | case SourceClient:
14 | return "client"
15 | case SourceLogic:
16 | return "logic"
17 | case SourceSystem:
18 | return "system"
19 | default:
20 | return "unknown"
21 | }
22 | }
23 |
24 | type FireOperation uint8
25 |
26 | const (
27 | // PublishOperation the publish operation agreed upon with the frontend (client)
28 | PublishOperation FireOperation = 1
29 | SubscribeOperation FireOperation = 2
30 | UnSubscribeOperation FireOperation = 3
31 | // OfflineTopicByUserIdOperation kick a specific user offline from a specific topic
32 | OfflineTopicByUserIdOperation FireOperation = 4
33 | // OfflineTopicOperation kick all connections subscribed to a specific topic offline
34 | OfflineTopicOperation FireOperation = 5
35 | // OfflineUserOperation kick a specific user offline
36 | OfflineUserOperation FireOperation = 6
37 | )
38 |
39 | func (o FireOperation) String() string {
40 | switch o {
41 | case PublishOperation:
42 | return "publish"
43 | case SubscribeOperation:
44 | return "subscribe"
45 | case UnSubscribeOperation:
46 | return "unSubscribe"
47 | case OfflineTopicByUserIdOperation:
48 | return "offlineTopicByUserid"
49 | case OfflineTopicOperation:
50 | return "offlineTopic"
51 | case OfflineUserOperation:
52 | return "offlineUser"
53 | default:
54 | return "unknown"
55 | }
56 | }
57 |
58 | const (
59 | SYSTEM_CMD_REMOVE_USER = "remove_user"
60 | )
61 |
--------------------------------------------------------------------------------
/protocol/pusher.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "log/slog"
5 | "sync"
6 | )
7 |
8 | type Pusher[T any] interface {
9 | Publish(fire *FireInfo[T]) error
10 | Receive() chan *FireInfo[T]
11 | }
12 |
13 | type Logger interface {
14 | Info(string, ...any)
15 | Debug(string, ...any)
16 | Error(string, ...any)
17 | }
18 |
19 | type SinglePusher[T any] struct {
20 | msg chan []byte
21 | b Brazier[T]
22 | once sync.Once
23 | coder Coder[T]
24 | fireChan chan *FireInfo[T]
25 | logger Logger
26 | }
27 |
28 | func (s *SinglePusher[T]) Publish(fire *FireInfo[T]) error {
29 | s.msg <- s.coder.Encode(fire)
30 | return nil
31 | }
32 |
33 | func (s *SinglePusher[T]) Receive() chan *FireInfo[T] {
34 | s.once.Do(func() {
35 | go func() {
36 | for {
37 | select {
38 | case m := <-s.msg:
39 | fire := new(FireInfo[T])
40 | err := s.coder.Decode(m, fire)
41 | if err != nil {
42 | s.logger.Error("failed to decode message", slog.String("data", string(m)), slog.String("error", err.Error()))
43 | continue
44 | }
45 | s.fireChan <- fire
46 | }
47 | }
48 | }()
49 | })
50 | return s.fireChan
51 | }
52 |
53 | type Brazier[T any] interface {
54 | Extinguished(fire *FireInfo[T])
55 | LightAFire() *FireInfo[T]
56 | }
57 |
58 | func DefaultPusher[T any](b Brazier[T], coder Coder[T], logger Logger) *SinglePusher[T] {
59 | return &SinglePusher[T]{
60 | msg: make(chan []byte, 100),
61 | b: b,
62 | coder: coder,
63 | fireChan: make(chan *FireInfo[T], 10000),
64 | logger: logger,
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/store/redis/provider.go:
--------------------------------------------------------------------------------
1 | package redis
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/go-redis/redis/v9"
7 | "github.com/holdno/firetower/store"
8 | )
9 |
10 | var provider *RedisProvider
11 |
12 | type RedisProvider struct {
13 | dbconn *redis.Client
14 | clientIP string
15 | keyPrefix string
16 |
17 | clusterConnStore store.ClusterConnStore
18 | clusterTopicStore store.ClusterTopicStore
19 | clusterStore store.ClusterStore
20 | }
21 |
22 | func Setup(addr, password string, db int, clientIP, keyPrefix string) (*RedisProvider, error) {
23 | if clientIP == "" {
24 | panic("setup redis store: required client IP")
25 | }
26 | provider = &RedisProvider{
27 | keyPrefix: keyPrefix,
28 | dbconn: redis.NewClient(&redis.Options{
29 | Addr: addr,
30 | Password: password,
31 | DB: db,
32 | }),
33 | clientIP: clientIP,
34 | }
35 |
36 | provider.clusterConnStore = newClusterConnStore(provider)
37 | provider.clusterTopicStore = newClusterTopicStore(provider)
38 | provider.clusterStore = newClusterStore(provider)
39 |
40 | res := provider.dbconn.Ping(context.TODO())
41 | if res.Err() != nil {
42 | return nil, res.Err()
43 | }
44 |
45 | return provider, nil
46 | }
47 |
48 | func (s *RedisProvider) ClusterConnStore() store.ClusterConnStore {
49 | return s.clusterConnStore
50 | }
51 |
52 | func (s *RedisProvider) ClusterTopicStore() store.ClusterTopicStore {
53 | return s.clusterTopicStore
54 | }
55 |
56 | func (s *RedisProvider) ClusterStore() store.ClusterStore {
57 | return s.clusterStore
58 | }
59 |
--------------------------------------------------------------------------------
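A sketch of wiring up the redis-backed provider (assumes a redis instance on localhost:6379; the node's IP, as returned by utils.GetIP, is used as the cluster member key):

``` golang
package main

import (
	"fmt"

	redisstore "github.com/holdno/firetower/store/redis"
	"github.com/holdno/firetower/utils"
)

func main() {
	ip, err := utils.GetIP()
	if err != nil {
		panic(err)
	}

	// arguments: addr, password, db, clientIP, keyPrefix
	provider, err := redisstore.Setup("localhost:6379", "", 0, ip, "firetower_")
	if err != nil {
		panic(err)
	}

	// ClusterNumber is backed by a redis INCR, so every Setup gets a new number
	n, err := provider.ClusterStore().ClusterNumber()
	if err != nil {
		panic(err)
	}
	fmt.Println("cluster number:", n)
}
```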
/pkg/nats/nats.go:
--------------------------------------------------------------------------------
1 | package nats
2 |
3 | import (
4 | "log/slog"
5 | "sync"
6 |
7 | "github.com/holdno/firetower/config"
8 | "github.com/holdno/firetower/protocol"
9 |
10 | "github.com/nats-io/nats.go"
11 | )
12 |
13 | var _ protocol.Pusher[any] = (*pusher[any])(nil)
14 |
15 | func MustSetupNatsPusher[T any](cfg config.Nats, coder protocol.Coder[T], logger protocol.Logger, topicFunc func() map[string]uint64) protocol.Pusher[T] {
16 | if cfg.SubjectPrefix == "" {
17 | cfg.SubjectPrefix = "firetower.topic."
18 | }
19 | p := &pusher[T]{
20 | subjectPrefix: cfg.SubjectPrefix,
21 | coder: coder,
22 | currentTopic: topicFunc,
23 | msg: make(chan *protocol.FireInfo[T], 10000),
24 | logger: logger,
25 | }
26 | var err error
27 | p.nats, err = nats.Connect(cfg.Addr, nats.Name(cfg.ServerName), nats.UserInfo(cfg.UserName, cfg.Password))
28 | if err != nil {
29 | panic(err)
30 | }
31 | return p
32 | }
33 |
34 | type pusher[T any] struct {
35 | subjectPrefix string
36 | msg chan *protocol.FireInfo[T]
37 | nats *nats.Conn
38 | once sync.Once
39 | b protocol.Brazier[T]
40 | coder protocol.Coder[T]
41 | currentTopic func() map[string]uint64
42 | logger protocol.Logger
43 | }
44 |
45 | func (p *pusher[T]) Publish(fire *protocol.FireInfo[T]) error {
46 | msg := nats.NewMsg(p.subjectPrefix + fire.Message.Topic)
47 | msg.Header.Set("topic", fire.Message.Topic)
48 | msg.Data = p.coder.Encode(fire)
49 | return p.nats.PublishMsg(msg)
50 | }
51 |
52 | func (p *pusher[T]) Receive() chan *protocol.FireInfo[T] {
53 | p.once.Do(func() {
54 | p.nats.Subscribe(p.subjectPrefix+">", func(msg *nats.Msg) {
55 | topic := msg.Header.Get("topic")
56 | if _, exist := p.currentTopic()[topic]; exist {
57 | fire := new(protocol.FireInfo[T])
58 | if err := p.coder.Decode(msg.Data, fire); err != nil {
59 | p.logger.Error("failed to decode message", slog.String("data", string(msg.Data)), slog.Any("error", err), slog.String("topic", topic))
60 | return
61 | }
62 | p.msg <- fire
63 | }
64 | msg.Ack()
65 | })
66 | })
67 |
68 | return p.msg
69 | }
70 |
--------------------------------------------------------------------------------
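A sketch of creating the NATS pusher used for cluster mode (assumes a local NATS server with the credentials from pkg/nats/nats_test.go; *slog.Logger satisfies the protocol.Logger interface, and the topic function shown here is a hard-coded stand-in for the tower manager's live topic table):

``` golang
package main

import (
	"log/slog"

	"github.com/holdno/firetower/config"
	"github.com/holdno/firetower/pkg/nats"
	"github.com/holdno/firetower/protocol"

	jsoniter "github.com/json-iterator/go"
)

func main() {
	pusher := nats.MustSetupNatsPusher(config.Nats{
		Addr:       "nats://localhost:4222",
		ServerName: "FireTower",
		UserName:   "firetower",
		Password:   "firetower",
		// SubjectPrefix defaults to "firetower.topic." when left empty
	}, &protocol.DefaultCoder[jsoniter.RawMessage]{}, slog.Default(), func() map[string]uint64 {
		// reports the topics this node currently subscribes to;
		// Receive() only delivers messages whose topic appears here
		return map[string]uint64{"/chat/world": 1}
	})

	_ = pusher.Publish // Publish/Receive satisfy protocol.Pusher
}
```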
/protocol/protocol.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "strconv"
5 | "time"
6 |
7 | json "github.com/json-iterator/go"
8 |
9 | "github.com/holdno/firetower/utils"
10 | )
11 |
12 | // type Message struct {
13 | // Ctx MessageContext `json:"c"`
14 | // Data []byte `json:"d"`
15 | // Topic string `json:"t"`
16 | // }
17 |
18 | // type MessageContext struct {
19 | // ID string `json:"i"`
20 | // MsgTime int64 `json:"m"`
21 | // Source string `json:"s"`
22 | // PushType int `json:"p"`
23 | // Type string `json:"t"`
24 | // }
25 |
26 | type Coder[T any] interface {
27 | Decode([]byte, *FireInfo[T]) error
28 | Encode(msg *FireInfo[T]) []byte
29 | }
30 |
31 | type WebSocketMessage struct {
32 | MessageType int
33 | Data []byte
34 | }
35 |
36 | // FireInfo the structure of a received message
37 | type FireInfo[T any] struct {
38 | Context FireLife `json:"c"`
39 | MessageType int `json:"t"`
40 | Message TopicMessage[T] `json:"m"`
41 | }
42 |
43 | func (f *FireInfo[T]) Copy() FireInfo[T] {
44 | return *f
45 | }
46 |
47 | func (f *FireInfo[T]) GetContext() FireLife {
48 | return f.Context
49 | }
50 |
51 | func (f *FireInfo[T]) GetMessage() TopicMessage[T] {
52 | return f.Message
53 | }
54 |
55 | type ReadOnlyFire[T any] interface {
56 | GetContext() FireLife
57 | GetMessage() TopicMessage[T]
58 | Copy() FireInfo[T]
59 | }
60 |
61 | // TopicMessage the topic message structure
62 | type TopicMessage[T any] struct {
63 | Topic string `json:"topic"`
64 | Data T `json:"data,omitempty"`
65 | Type FireOperation `json:"type"`
66 | }
67 |
68 | func (s *TopicMessage[T]) Json() []byte {
69 | raw, _ := json.Marshal(s)
70 | return raw
71 | }
72 |
73 | type TowerInfo interface {
74 | UserID() string
75 | ClientID() string
76 | }
77 |
78 | // FireLife the context carried by a message pushed from a client
79 | type FireLife struct {
80 | ID string `json:"i"`
81 | StartTime time.Time `json:"t"`
82 | ClientID string `json:"c"`
83 | UserID string `json:"u"`
84 | Source FireSource `json:"s"`
85 | ExtMeta map[string]string `json:"e,omitempty"`
86 | }
87 |
88 | func (f *FireLife) Reset(source FireSource, clientID, userID string) {
89 | f.StartTime = time.Now()
90 | f.ID = strconv.FormatInt(utils.IDWorker().GetId(), 10)
91 | f.ClientID = clientID
92 | f.UserID = userID
93 | f.Source = source
94 | }
95 |
--------------------------------------------------------------------------------
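The JSON emitted by TopicMessage.Json() is also the frame format the bundled web client (example/web/firetower_v0.1.js) writes to the websocket. A small sketch of building a subscribe frame:

``` golang
package main

import (
	"fmt"

	"github.com/holdno/firetower/protocol"

	jsoniter "github.com/json-iterator/go"
)

func main() {
	sub := protocol.TopicMessage[jsoniter.RawMessage]{
		Topic: "/chat/world",
		Type:  protocol.SubscribeOperation,
	}
	// Data is omitted because of the omitempty tag;
	// prints {"topic":"/chat/world","type":2}
	fmt.Println(string(sub.Json()))
}
```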
/example/web/firetower_v0.1.js:
--------------------------------------------------------------------------------
1 | function firetower(addr, onopen) {
2 | var ws = new WebSocket(addr);
3 | ws.onopen = onopen
4 |
5 | var _this = this
6 |
7 | this.publishKey = 1
8 | this.subscribeKey = 2
9 | this.unSubscribeKey = 3
10 |
11 | this.logopen = true // enable logging
12 |
13 | var logInfo = function(data){
14 | console.log('[firetower] INFO', data)
15 | }
16 |
17 | this.publish = function(topic, data){
18 | if (topic == '' || data == '') {
19 | return errorMessage('topic或data参数不能为空')
20 | }
21 |
22 | if (_this.logopen) {
23 | logInfo('publish topic:"' + topic + '", data:' + JSON.stringify(data))
24 | }
25 |
26 | ws.send(JSON.stringify({
27 | type: _this.publishKey,
28 | topic: topic,
29 | data: data
30 | }))
31 | return successMessage('发送成功')
32 | }
33 |
34 | this.onmessage = false
35 | ws.onmessage = function(event){
36 | if (_this.logopen) {
37 | logInfo('new message:' + JSON.stringify(event.data))
38 | }
39 |
40 | if (event.data == 'heartbeat') {
41 | return
42 | }
43 |
44 | if (_this.onmessage) {
45 | _this.onmessage(event)
46 | }
47 | }
48 |
49 | this.onclose = false
50 | ws.onclose = function(){
51 | if (_this.onclose) {
52 | _this.onclose()
53 | }
54 | }
55 |
56 | this.subscribe = function(topicArr){
57 | if (!Array.isArray(topicArr)) {
58 | topicArr = [topicArr]
59 | }
60 |
61 | if (_this.logopen) {
62 | logInfo('subscribe:"' + topicArr.join(',') + '"')
63 | }
64 |
65 | ws.send(JSON.stringify({
66 | type: _this.subscribeKey,
67 | topic: topicArr.join(','),
68 | data: ''
69 | }))
70 |
71 |
72 | }
73 |
74 | this.unsubscribe = function(topicArr){
75 | if (!Array.isArray(topicArr)) {
76 | topicArr = [topicArr]
77 | }
78 |
79 | if (_this.logopen) {
80 | logInfo('unSubscribe:"' + topicArr.join(',') + '"')
81 | }
82 |
83 | ws.send(JSON.stringify({
84 | type: _this.unSubscribeKey,
85 | topic: topicArr.join(','),
86 | data: ''
87 | }))
88 | }
89 |
90 | function errorMessage(info){
91 | return {
92 | type: 'error',
93 | info: info
94 | }
95 | }
96 |
97 | function successMessage(info){
98 | return {
99 | type: 'success',
100 | info: info
101 | }
102 | }
103 | }
104 |
105 |
--------------------------------------------------------------------------------
/store/redis/redis_test.go:
--------------------------------------------------------------------------------
1 | package redis
2 |
3 | import (
4 | "context"
5 | "encoding/binary"
6 | "fmt"
7 | "testing"
8 | "time"
9 |
10 | "github.com/go-redis/redis/v9"
11 | "github.com/holdno/firetower/utils"
12 | )
13 |
14 | func TestRedis_HSet(t *testing.T) {
15 | rdb := redis.NewClient(&redis.Options{
16 | Addr: "localhost:6379",
17 | Password: "", // no password set
18 | DB: 0, // use default DB
19 | })
20 |
21 | ip, err := utils.GetIP()
22 | if err != nil {
23 | t.Fatal(err)
24 | }
25 | res := rdb.HSet(context.Background(), ClusterConnKey, ip, fmt.Sprintf("%d|%d", 2315, time.Now().Unix()))
26 | if res.Err() != nil {
27 | t.Fatal(res.Err())
28 | }
29 |
30 | gRes := rdb.HGet(context.Background(), ClusterConnKey, ip)
31 | if gRes.Err() != nil {
32 | t.Fatal(gRes.Err())
33 | }
34 |
35 | t.Log(gRes.Val())
36 |
37 | if gRes.Val() != fmt.Sprintf("%d|%d", 2315, time.Now().Unix()) {
38 | t.Fatal(gRes.Val())
39 | }
40 | t.Log("successful")
41 | }
42 |
43 | func TestLua(t *testing.T) {
44 | rdb := redis.NewClient(&redis.Options{
45 | Addr: "localhost:6379",
46 | Password: "", // no password set
47 | DB: 0, // use default DB
48 | })
49 | res1 := rdb.Del(context.TODO(), ClusterConnKey)
50 | if res1.Err() != nil {
51 | t.Fatal(res1.Err())
52 | }
53 |
54 | b := make([]byte, 16)
55 | binary.LittleEndian.PutUint64(b, 2315)
56 | binary.LittleEndian.PutUint64(b[8:], uint64(time.Now().Unix()-10))
57 | res := rdb.HSet(context.TODO(), ClusterConnKey, "localhost", string(b))
58 | if res.Err() != nil {
59 | t.Fatal(res.Err())
60 | }
61 |
62 | result := rdb.Eval(context.TODO(), clusterShutdownCheckerScript, []string{ClusterConnKey, fmt.Sprintf("%d", time.Now().Unix())}, 2)
63 | if result.Err() != nil {
64 | t.Fatal(result.Err())
65 | }
66 | t.Log(result.Val())
67 | }
68 |
69 | func TestPack(t *testing.T) {
70 | var number uint64 = 3512
71 | r := packClientConnNumberNow(number)
72 | n, err := unpackClientConnNumberNow(r)
73 | if err != nil {
74 | t.Fatal(err)
75 | }
76 |
77 | if n != number {
78 | t.Fatal(n, number)
79 | }
80 | t.Log("successful")
81 | }
82 |
83 | func TestLock(t *testing.T) {
84 | lockKey := "ft_cluster_master"
85 | rdb := redis.NewClient(&redis.Options{
86 | Addr: "localhost:6379",
87 | Password: "", // no password set
88 | DB: 0, // use default DB
89 | })
90 | locker := func(clusterID string) {
91 | for {
92 | res := rdb.SetNX(context.Background(), lockKey, clusterID, time.Second*3)
93 | if res.Val() {
94 | for {
95 | ticker := time.NewTicker(time.Second * 1)
96 | select {
97 | case <-ticker.C:
98 | res := rdb.Eval(context.Background(), keepMasterScript, []string{lockKey, clusterID}, 2)
99 | if res.Val() != "success" || res.Err() != nil {
100 | ticker.Stop()
101 | break
102 | }
103 | fmt.Println(clusterID, "expire")
104 | }
105 | }
106 | }
107 | time.Sleep(time.Second)
108 | }
109 | }
110 |
111 | go locker("1")
112 | go locker("2")
113 | go locker("3")
114 | go locker("4")
115 |
116 | time.Sleep(time.Second * 20)
117 | }
118 |
--------------------------------------------------------------------------------
/store/redis/cluster.topic.go:
--------------------------------------------------------------------------------
1 | package redis
2 |
3 | import (
4 | "context"
5 | "strconv"
6 | "sync"
7 |
8 | "github.com/holdno/firetower/store"
9 | )
10 |
11 | type ClusterTopicStore struct {
12 | storage map[string]int64
13 | provider *RedisProvider
14 |
15 | sync.RWMutex
16 | }
17 |
18 | func newClusterTopicStore(provider *RedisProvider) store.ClusterTopicStore {
19 | s := &ClusterTopicStore{
20 | storage: make(map[string]int64),
21 | provider: provider,
22 | }
23 | if err := s.init(); err != nil {
24 | panic(err)
25 | }
26 | return s
27 | }
28 |
29 | const (
30 | ClusterTopicKeyPrefix = "ft_topics_conn_"
31 | )
32 |
33 | func (s *ClusterTopicStore) getTopicKey(clientIP string) string {
34 | return s.provider.keyPrefix + ClusterTopicKeyPrefix + clientIP
35 | }
36 |
37 | func (s *ClusterTopicStore) init() error {
38 | res := s.provider.dbconn.Del(context.TODO(), s.getTopicKey(s.provider.clientIP))
39 | return res.Err()
40 | }
41 |
42 | func (s *ClusterTopicStore) TopicConnAtomicAddBy(topic string, num int64) error {
43 | s.Lock()
44 | defer s.Unlock()
45 | s.storage[topic] += num
46 | res := s.provider.dbconn.HSet(context.TODO(), s.getTopicKey(s.provider.clientIP), topic, s.storage[topic])
47 | if res.Err() != nil {
48 | s.storage[topic] -= num
49 | return res.Err()
50 | }
51 |
52 | if s.storage[topic] == 0 {
53 | res := s.provider.dbconn.HDel(context.TODO(), s.getTopicKey(s.provider.clientIP), topic)
54 | if res.Err() != nil {
55 | return res.Err()
56 | }
57 | }
58 |
59 | return nil
60 | }
61 |
62 | func (s *ClusterTopicStore) RemoveTopic(topic string) error {
63 | res := s.provider.dbconn.HDel(context.TODO(), s.getTopicKey(s.provider.clientIP), topic)
64 | if res.Err() != nil {
65 | return res.Err()
66 | }
67 | s.Lock()
68 | defer s.Unlock()
69 | delete(s.storage, topic)
70 | return nil
71 | }
72 |
73 | func (s *ClusterTopicStore) GetTopicConnNum(topic string) (uint64, error) {
74 | members, err := s.provider.ClusterConnStore().ClusterMembers()
75 | if err != nil {
76 | return 0, err
77 | }
78 |
79 | var total uint64
80 | for _, v := range members {
81 | res := s.provider.dbconn.HGet(context.TODO(), s.getTopicKey(v), topic)
82 | if res.Err() != nil {
83 | return 0, res.Err()
84 | }
85 | num, _ := strconv.ParseUint(res.Val(), 10, 64)
86 | total += num
87 | }
88 |
89 | return total, nil
90 | }
91 |
92 | func (s *ClusterTopicStore) Topics() (map[string]uint64, error) {
93 | res := s.provider.dbconn.HGetAll(context.TODO(), s.getTopicKey(s.provider.clientIP))
94 | if res.Err() != nil {
95 | return nil, res.Err()
96 | }
97 |
98 | result := make(map[string]uint64)
99 | for k, v := range res.Val() {
100 | vInt, err := strconv.ParseUint(v, 10, 64)
101 | if err != nil {
102 | return nil, err
103 | }
104 | result[k] = vInt
105 | }
106 |
107 | return result, nil
108 | }
109 |
110 | func (s *ClusterTopicStore) ClusterTopics() (map[string]uint64, error) {
111 | members, err := s.provider.ClusterConnStore().ClusterMembers()
112 | if err != nil {
113 | return nil, err
114 | }
115 |
116 | result := make(map[string]uint64)
117 | for _, v := range members {
118 | res := s.provider.dbconn.HGetAll(context.TODO(), s.getTopicKey(v))
119 | if res.Err() != nil {
120 | return nil, res.Err()
121 | }
122 | for k, v := range res.Val() {
123 | vInt, err := strconv.ParseUint(v, 10, 64)
124 | if err != nil {
125 | return nil, err
126 | }
127 | result[k] += vInt
128 | }
129 | }
130 |
131 | return result, nil
132 | }
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
11 | # Firetower
12 | firetower is a distributed push (IM) service written in golang
13 |
14 | It is built entirely on top of websocket, with sub/pub organized around topics
15 | It implements its own subscription management service and does not depend on redis
16 | Chat room demo: http://chat.ojbk.io
17 | ### Available version
18 | go get github.com/holdno/firetower@v0.4.0
19 | ### Components
20 |
21 | The basic service consists of two parts
22 | - topic management service
23 | > see the example in example/topicService
24 |
25 | This service acts as the single topic management node in a cluster environment
26 | firetower must rely on this management node to work properly
27 | Large projects can deploy it on a dedicated server; small projects can run it on the same machine as the connection-layer service
28 | - connection-layer service (websocket service)
29 | > see the example in example/websocketService
30 |
31 | The websocket service is the business logic you build on top of firetower
32 | You implement your own logic through the callback methods firetower provides
33 | (the web client lives under example/web)
34 | ### Architecture
35 | (architecture diagram)
36 | ### Getting started
37 | ``` golang
38 | package main
39 |
40 | import (
41 | "fmt"
42 | "github.com/gorilla/websocket"
43 | "github.com/holdno/firetower/gateway"
44 | "github.com/holdno/snowFlakeByGo" // 这是一个分布式全局唯一id生成器
45 | "net/http"
46 | "strconv"
47 | )
48 |
49 | var upgrader = websocket.Upgrader{
50 | CheckOrigin: func(r *http.Request) bool {
51 | return true
52 | },
53 | }
54 |
55 | var GlobalIdWorker *snowFlakeByGo.Worker
56 |
57 | func main() {
58 | GlobalIdWorker, _ = snowFlakeByGo.NewWorker(1)
59 | // in a cluster environment every service instance MUST be given a unique id
60 | // valid range: 1-1024
61 | gateway.ClusterId = 1
62 | http.HandleFunc("/ws", Websocket)
63 | fmt.Println("websocket service start: 0.0.0.0:9999")
64 | http.ListenAndServe("0.0.0.0:9999", nil)
65 | }
66 |
67 | func Websocket(w http.ResponseWriter, r *http.Request) {
68 | // authenticate the user here
69 | ...
70 | // upgrade the connection only after authentication succeeds
71 | ws, _ := upgrader.Upgrade(w, r, nil)
72 | // generate a globally unique clientid; in a real application the frontend should provide it
73 | id := GlobalIdWorker.GetId()
74 | tower := gateway.BuildTower(ws, strconv.FormatInt(id, 10)) // build a beacon tower for this connection
75 | tower.Run()
76 | }
77 | ```
78 | ### Supported callbacks
79 | - ReadHandler: triggered when a message is received from the client
80 | ``` golang
81 | tower := gateway.BuildTower(ws, strconv.FormatInt(id, 10)) // create a beacon tower instance
82 | tower.SetReadHandler(func(fire *gateway.FireInfo) bool { // register the ReadHandler callback
83 | // message.Data is the payload sent by the client
84 | // message.Topic is the topic the message is published to
85 | // you can validate the message here,
86 | // e.g. check whether the sender is allowed to push to the target topic
87 | // Publish pushes the message to every connection subscribed to message.Topic
88 | tower.Publish(message)
89 | return true
90 | })
91 | ```
92 |
93 | - ReadTimeoutHandler: triggered when handling a client websocket message times out (messages are produced faster than they are consumed)
94 | ``` golang
95 | tower.SetReadTimeoutHandler(func(fire *gateway.FireInfo) {
96 | fmt.Println("read timeout:", fire.Message.Type, fire.Message.Topic, fire.Message.Data)
97 | })
98 | ```
99 |
100 | - BeforeSubscribeHandler: triggered when a client is about to subscribe to topics (the topics are not subscribed yet, hence "before subscribe")
101 | ``` golang
102 | tower.SetBeforeSubscribeHandler(func(context *gateway.FireLife, topic []string) bool {
103 | // decide here whether the current user is allowed to subscribe to the topic
104 | return true
105 | })
106 | ```
107 |
108 | - SubscribeHandler: triggered when a client has finished subscribing to topics (the topics are now registered with and managed by the topicService)
109 | ``` golang
110 | tower.SetSubscribeHandler(func(context *gateway.FireLife, topic []string) bool {
111 | // the chat room example relies on this callback:
112 | // when a chat room (topic) gets a new subscriber, the other members need to be told the online count went up by 1
113 | for _, v := range topic {
114 | num := tower.GetConnectNum(v)
115 | // inherit the context of the subscribe message
116 | var pushmsg = gateway.NewFireInfo(tower, context)
117 | pushmsg.Message.Topic = v
118 | pushmsg.Message.Data = []byte(fmt.Sprintf("{\"type\":\"onSubscribe\",\"data\":%d}", num))
119 | tower.Publish(pushmsg)
120 | }
121 | return true
122 | })
123 | ```
124 |
125 | - UnSubscribeHandler: triggered when a client has finished unsubscribing from topics (there is no "before" hook; no use case for a before-unsubscribe has come up so far, please open an issue if you have one)
126 | ``` golang
127 | tower.SetUnSubscribeHandler(func(context *gateway.FireLife, topic []string) bool {
128 | for _, v := range topic {
129 | num := tower.GetConnectNum(v)
130 | var pushmsg = gateway.NewFireInfo(tower, context)
131 | pushmsg.Message.Topic = v
132 | pushmsg.Message.Data = []byte(fmt.Sprintf("{\"type\":\"onUnsubscribe\",\"data\":%d}", num))
133 | tower.Publish(pushmsg)
134 | }
135 | return true
136 | })
137 | ```
138 | Note: when a client closes its websocket connection, firetower unsubscribes every topic it subscribed to while online, which triggers UnSubscribeHandler
139 |
140 | ## TODO
141 | - runtime web dashboard
142 | - http and grpc push APIs
143 |
144 | ## License
145 | [MIT](https://opensource.org/licenses/MIT)
146 |
--------------------------------------------------------------------------------
/store/redis/cluster.conn.go:
--------------------------------------------------------------------------------
1 | package redis
2 |
3 | import (
4 | "context"
5 | "encoding/binary"
6 | "fmt"
7 | "sync"
8 | "time"
9 | )
10 |
11 | type ClusterConnStore struct {
12 | storage map[string]int64
13 | provider *RedisProvider
14 | isMaster bool
15 |
16 | keepMasterScriptSHA string
17 | clusterShutdownCheckerScriptSHA string
18 |
19 | sync.RWMutex
20 | }
21 |
22 | const (
23 | ClusterConnKey = "ft_cluster_conn"
24 | )
25 |
26 | func newClusterConnStore(provider *RedisProvider) *ClusterConnStore {
27 | s := &ClusterConnStore{
28 | storage: make(map[string]int64),
29 | provider: provider,
30 | }
31 |
32 | res := s.provider.dbconn.ScriptLoad(context.TODO(), clusterShutdownCheckerScript)
33 | if res.Err() != nil {
34 | panic(res.Err())
35 | }
36 | s.clusterShutdownCheckerScriptSHA = res.Val()
37 |
38 | res = s.provider.dbconn.ScriptLoad(context.TODO(), keepMasterScript)
39 | if res.Err() != nil {
40 | panic(res.Err())
41 | }
42 | s.keepMasterScriptSHA = res.Val()
43 |
44 | go s.KeepClusterClear()
45 |
46 | if err := s.init(); err != nil {
47 | panic(err)
48 | }
49 | return s
50 | }
51 |
52 | func (s *ClusterConnStore) init() error {
53 | res := s.provider.dbconn.HDel(context.TODO(), s.provider.keyPrefix+ClusterConnKey, s.provider.clientIP)
54 | return res.Err()
55 | }
56 |
57 | func (s *ClusterConnStore) ClusterMembers() ([]string, error) {
58 | res := s.provider.dbconn.HGetAll(context.TODO(), s.provider.keyPrefix+ClusterConnKey)
59 | if res.Err() != nil {
60 | return nil, res.Err()
61 | }
62 | var list []string
63 | for k := range res.Val() {
64 | list = append(list, k)
65 | }
66 | return list, nil
67 | }
68 |
69 | func packClientConnNumberNow(num uint64) string {
70 | b := make([]byte, 16)
71 | binary.LittleEndian.PutUint64(b, num)
72 | binary.LittleEndian.PutUint64(b[8:], uint64(time.Now().Unix()))
73 | return string(b)
74 | }
75 |
76 | func unpackClientConnNumberNow(b string) (uint64, error) {
77 | if len([]byte(b)) != 16 {
78 | return 0, fmt.Errorf("wrong pack data, got length %d, need 16", len([]byte(b)))
79 | }
80 | return binary.LittleEndian.Uint64([]byte(b)[:8]), nil
81 | }
82 |
83 | func (s *ClusterConnStore) OneClientAtomicAddBy(clientIP string, num int64) error {
84 | s.Lock()
85 | defer s.Unlock()
86 | s.storage[clientIP] += num
87 |
88 | res := s.provider.dbconn.HSet(context.TODO(), s.provider.keyPrefix+ClusterConnKey, clientIP, packClientConnNumberNow(uint64(s.storage[clientIP])))
89 | if res.Err() != nil {
90 | s.storage[clientIP] -= num
91 | return res.Err()
92 | }
93 | return nil
94 | }
95 |
96 | func (s *ClusterConnStore) GetAllConnNum() (uint64, error) {
97 | res := s.provider.dbconn.HGetAll(context.TODO(), s.provider.keyPrefix+ClusterConnKey)
98 | if res.Err() != nil {
99 | return 0, res.Err()
100 | }
101 | var result uint64
102 | for _, v := range res.Val() {
103 | singleNumber, err := unpackClientConnNumberNow(v)
104 | if err != nil {
105 | return 0, err
106 | }
107 | result += singleNumber
108 | }
109 | return result, nil
110 | }
111 |
112 | func (s *ClusterConnStore) RemoveClient(clientIP string) error {
113 | res := s.provider.dbconn.HDel(context.TODO(), s.provider.keyPrefix+ClusterConnKey, clientIP)
114 | if res.Err() != nil {
115 | return res.Err()
116 | }
117 | s.Lock()
118 | defer s.Unlock()
119 | delete(s.storage, clientIP)
120 |
121 | return nil
122 | }
123 |
124 | const clusterShutdownCheckerScript = `
125 | local key = KEYS[1]
126 | local currentTime = KEYS[2]
127 | local clientConns = redis.call("hgetall", key)
128 | local delTable = {}
129 |
130 | for i = 1, #clientConns, 2 do
131 | local number, timestamp = struct.unpack("<I8I8", clientConns[i + 1])
132 | if (tostring(currentTime) > tostring(timestamp + 5))
133 | then
134 | table.insert(delTable, clientConns[i])
135 | end
136 | end
137 |
138 | return redis.call("hdel", key, unpack(delTable))
139 | `
140 |
141 | func (s *ClusterConnStore) KeepClusterClear() {
142 | go s.SelectMaster()
143 | for {
144 | time.Sleep(time.Second)
145 |
146 | if !s.isMaster {
147 | continue
148 | }
149 |
150 | result := s.provider.dbconn.EvalSha(context.TODO(), s.clusterShutdownCheckerScriptSHA, []string{s.provider.keyPrefix + ClusterConnKey, fmt.Sprintf("%d", time.Now().Unix())}, 2)
151 | if result.Err() != nil {
152 | // todo log
153 | continue
154 | }
155 | }
156 | }
157 |
158 | const keepMasterScript = `
159 | local lockKey = KEYS[1]
160 | local currentID = KEYS[2]
161 | local currentMaster = redis.call('get', lockKey)
162 | local success = "fail"
163 |
164 | if (currentID == currentMaster)
165 | then
166 | redis.call('expire', lockKey, 3)
167 | success = "success"
168 | end
169 | return success
170 | `
171 |
172 | func (s *ClusterConnStore) SelectMaster() {
173 | lockKey := s.provider.keyPrefix + "ft_cluster_master"
174 | for {
175 | res := s.provider.dbconn.SetNX(context.Background(), lockKey, s.provider.clientIP, time.Second*3)
176 | if res.Val() {
177 | s.isMaster = true
178 | ticker := time.NewTicker(time.Second * 1)
179 | keepAlive: for {
180 | select {
181 | case <-ticker.C:
182 | res := s.provider.dbconn.EvalSha(context.Background(), s.keepMasterScriptSHA, []string{lockKey, s.provider.clientIP}, 2)
183 | if res.Val() != "success" || res.Err() != nil {
184 | s.isMaster = false
185 | ticker.Stop()
186 | break keepAlive // lost mastership, go back to retrying SetNX
187 | }
188 | }
189 | }
190 | }
191 | time.Sleep(time.Second)
192 | }
193 | }
194 |
--------------------------------------------------------------------------------
/e2e/tower_test.go:
--------------------------------------------------------------------------------
1 | package e2e
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "log/slog"
7 | "net/http"
8 | "strconv"
9 | "testing"
10 | "time"
11 |
12 | "github.com/gorilla/websocket"
13 | "github.com/holdno/firetower/config"
14 | "github.com/holdno/firetower/protocol"
15 | towersvc "github.com/holdno/firetower/service/tower"
16 | "github.com/holdno/firetower/utils"
17 | "github.com/holdno/snowFlakeByGo"
18 | jsoniter "github.com/json-iterator/go"
19 | )
20 |
21 | var upgrader = websocket.Upgrader{
22 | CheckOrigin: func(r *http.Request) bool {
23 | return true
24 | },
25 | }
26 |
27 | type messageInfo struct {
28 | From string `json:"from"`
29 | Data json.RawMessage `json:"data"`
30 | Type string `json:"type"`
31 | }
32 |
33 | // GlobalIdWorker the globally unique id generator
34 | var GlobalIdWorker *snowFlakeByGo.Worker
35 |
36 | var _ towersvc.PusherInfo = (*SystemPusher)(nil)
37 |
38 | type SystemPusher struct {
39 | clientID string
40 | }
41 |
42 | func (s *SystemPusher) UserID() string {
43 | return "system"
44 | }
45 | func (s *SystemPusher) ClientID() string {
46 | return s.clientID
47 | }
48 |
49 | var systemer *SystemPusher
50 |
51 | const (
52 | listenAddress = "127.0.0.1:9999"
53 | websocketPath = "/ws"
54 | )
55 |
56 | func startTower() {
57 | // globally unique id generator
58 | tm, err := towersvc.Setup[jsoniter.RawMessage](config.FireTowerConfig{
59 | WriteChanLens: 1000,
60 | ReadChanLens: 1000,
61 | Heartbeat: 30,
62 | ServiceMode: config.SingleMode,
63 | Bucket: config.BucketConfig{
64 | Num: 4,
65 | CentralChanCount: 100000,
66 | BuffChanCount: 1000,
67 | ConsumerNum: 2,
68 | },
69 | })
70 |
71 | if err != nil {
72 | panic(err)
73 | }
74 |
75 | systemer = &SystemPusher{
76 | clientID: "1",
77 | }
78 |
79 | tower := &Tower{
80 | tm: tm,
81 | }
82 | http.HandleFunc(websocketPath, tower.Websocket)
83 | tm.Logger().Info("http server start", slog.String("address", listenAddress))
84 | if err := http.ListenAndServe(listenAddress, nil); err != nil {
85 | panic(err)
86 | }
87 | }
88 |
89 | type Tower struct {
90 | tm towersvc.Manager[jsoniter.RawMessage]
91 | }
92 |
93 | const (
94 | bindTopic = "bindtopic"
95 | )
96 |
97 | // Websocket upgrades the http request to a websocket connection and instantiates a firetower
98 | func (t *Tower) Websocket(w http.ResponseWriter, r *http.Request) {
99 | // authenticate the user here
100 |
101 | // upgrade the connection only after authentication succeeds
102 | ws, _ := upgrader.Upgrade(w, r, nil)
103 |
104 | id := utils.IDWorker().GetId()
105 | tower, err := t.tm.BuildTower(ws, strconv.FormatInt(id, 10))
106 | if err != nil {
107 | w.WriteHeader(http.StatusInternalServerError)
108 | w.Write([]byte(err.Error()))
109 | return
110 | }
111 |
112 | tower.SetReadHandler(func(fire protocol.ReadOnlyFire[jsoniter.RawMessage]) bool {
113 | return true
114 | })
115 |
116 | tower.SetReceivedHandler(func(fi protocol.ReadOnlyFire[jsoniter.RawMessage]) bool {
117 | return true
118 | })
119 |
120 | tower.SetReadTimeoutHandler(func(fire protocol.ReadOnlyFire[jsoniter.RawMessage]) {
121 | messageInfo := new(messageInfo)
122 | err := json.Unmarshal(fire.GetMessage().Data, messageInfo)
123 | if err != nil {
124 | return
125 | }
126 | messageInfo.Type = "timeout"
127 | b, _ := json.Marshal(messageInfo)
128 | tower.SendToClient(b)
129 | })
130 |
131 | tower.SetBeforeSubscribeHandler(func(context protocol.FireLife, topic []string) bool {
132 | // decide here whether the current user is allowed to subscribe to the topic
133 | for _, v := range topic {
134 | if v == bindTopic {
135 | messageInfo := new(messageInfo)
136 | messageInfo.From = "system"
137 | messageInfo.Type = "event"
138 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "bind", "topic": "%s"}`, v))
139 | msg, _ := json.Marshal(messageInfo)
140 | tower.SendToClient(msg)
141 | }
142 | }
143 | return true
144 | })
145 |
146 | tower.SetSubscribeHandler(func(context protocol.FireLife, topic []string) {
147 | for _, v := range topic {
148 | messageInfo := new(messageInfo)
149 | messageInfo.From = "system"
150 | messageInfo.Type = "event"
151 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "subscribe", "topic": "%s"}`, v))
152 | msg, _ := json.Marshal(messageInfo)
153 | tower.SendToClient(msg)
154 | }
155 | })
156 |
157 | go tower.Run()
158 |
159 | ws.Close()
160 |
161 | if err := tower.SendToClientBlock([]byte("test")); err != nil {
162 | fmt.Println("send error", err)
163 | }
164 | }
165 |
166 | func buildClient(t *testing.T) *websocket.Conn {
167 | url := fmt.Sprintf("ws://%s%s", listenAddress, websocketPath)
168 | client, _, err := websocket.DefaultDialer.Dial(url, nil)
169 | if err != nil {
170 | t.Fatal(err)
171 | }
172 |
173 | go func() {
174 | for {
175 | _, data, err := client.ReadMessage()
176 | if err != nil {
177 | return
178 | }
179 |
180 | fmt.Println("--- client receive message ---")
181 | fmt.Println(string(data))
182 | }
183 | }()
184 | return client
185 | }
186 |
187 | func TestBaseTower(t *testing.T) {
188 | go startTower()
189 | time.Sleep(time.Second)
190 |
191 | client1 := buildClient(t)
192 | subMsg := protocol.TopicMessage[jsoniter.RawMessage]{
193 | Topic: bindTopic,
194 | Type: protocol.SubscribeOperation,
195 | }
196 | if err := client1.WriteMessage(websocket.TextMessage, subMsg.Json()); err != nil {
197 | t.Fatal(err)
198 | }
199 |
200 | subMsg.Topic = "testtopic"
201 | if err := client1.WriteMessage(websocket.TextMessage, subMsg.Json()); err != nil {
202 | t.Fatal(err)
203 | }
204 |
205 | client2 := buildClient(t)
206 | if err := client2.WriteMessage(websocket.BinaryMessage, subMsg.Json()); err != nil {
207 | t.Fatal(err)
208 | }
209 |
210 | testMessage := protocol.TopicMessage[jsoniter.RawMessage]{
211 | Topic: subMsg.Topic,
212 | Type: protocol.PublishOperation,
213 | Data: jsoniter.RawMessage([]byte("\"hi\"")),
214 | }
215 |
216 | fire := new(protocol.FireInfo[jsoniter.RawMessage]) // build a message object (the pooled NewFire could be used instead to reduce GC pressure)
217 | fire.MessageType = 1
218 | if err := jsoniter.Unmarshal(testMessage.Json(), &fire.Message); err != nil {
219 | t.Fatal(err)
220 | }
221 | client1.WriteMessage(websocket.BinaryMessage, testMessage.Json())
222 |
223 | time.Sleep(time.Minute)
224 | }
225 |
--------------------------------------------------------------------------------
/example/websocket_service/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "log/slog"
6 | "net/http"
7 | _ "net/http/pprof"
8 | "strconv"
9 | "strings"
10 | "time"
11 |
12 | "github.com/gorilla/websocket"
13 | "github.com/holdno/firetower/config"
14 | "github.com/holdno/firetower/protocol"
15 | "github.com/holdno/firetower/service/tower"
16 | towersvc "github.com/holdno/firetower/service/tower"
17 | "github.com/holdno/firetower/utils"
18 | "github.com/holdno/snowFlakeByGo"
19 | json "github.com/json-iterator/go"
20 | )
21 |
22 | var upgrader = websocket.Upgrader{
23 | CheckOrigin: func(r *http.Request) bool {
24 | return true
25 | },
26 | }
27 |
28 | type messageInfo struct {
29 | From string `json:"from"`
30 | Data json.RawMessage `json:"data"`
31 | Type string `json:"type"`
32 | }
33 |
34 | // GlobalIdWorker the globally unique id generator
35 | var GlobalIdWorker *snowFlakeByGo.Worker
36 |
37 | var _ tower.PusherInfo = (*SystemPusher)(nil)
38 |
39 | type SystemPusher struct {
40 | clientID string
41 | }
42 |
43 | func (s *SystemPusher) UserID() string {
44 | return "system"
45 | }
46 | func (s *SystemPusher) ClientID() string {
47 | return s.clientID
48 | }
49 |
50 | var systemer *SystemPusher
51 |
52 | func main() {
53 | // globally unique id generator
54 | tm, err := towersvc.Setup[json.RawMessage](config.FireTowerConfig{
55 | WriteChanLens: 1000,
56 | Heartbeat: 30,
57 | ServiceMode: config.SingleMode,
58 | Bucket: config.BucketConfig{
59 | Num: 4,
60 | CentralChanCount: 100000,
61 | BuffChanCount: 1000,
62 | ConsumerNum: 1,
63 | },
64 | // Cluster: config.Cluster{
65 | // RedisOption: config.Redis{
66 | // Addr: "localhost:6379",
67 | // Password: "",
68 | // },
69 | // NatsOption: config.Nats{
70 | // Addr: "nats://localhost:4222",
71 | // UserName: "root",
72 | // Password: "",
73 | // ServerName: "",
74 | // },
75 | // },
76 | })
77 |
78 | if err != nil {
79 | panic(err)
80 | }
81 |
82 | systemer = &SystemPusher{
83 | clientID: "1",
84 | }
85 |
86 | go func() {
87 | for {
88 | time.Sleep(time.Second * 60)
89 | f := tm.NewFire(protocol.SourceSystem, systemer)
90 | f.Message.Topic = "/chat/world"
91 | f.Message.Data = []byte(fmt.Sprintf("{\"type\":\"publish\",\"data\":\"请通过 room 命令切换聊天室\",\"from\":\"system\"}"))
92 | tm.Publish(f)
93 | }
94 | }()
95 |
96 | tower := &Tower{
97 | tm: tm,
98 | }
99 | http.HandleFunc("/ws", tower.Websocket)
100 | tm.Logger().Info("http server start", slog.String("address", "0.0.0.0:9999"))
101 | if err := http.ListenAndServe("0.0.0.0:9999", nil); err != nil {
102 | panic(err)
103 | }
104 | }
105 |
106 | type Tower struct {
107 | tm tower.Manager[json.RawMessage]
108 | }
109 |
110 | // Websocket upgrades the http request to a websocket connection and instantiates a firetower
111 | func (t *Tower) Websocket(w http.ResponseWriter, r *http.Request) {
112 | // authenticate the user here
113 |
114 | // upgrade the connection only after authentication succeeds
115 | ws, _ := upgrader.Upgrade(w, r, nil)
116 |
117 | id := utils.IDWorker().GetId()
118 | tower, err := t.tm.BuildTower(ws, strconv.FormatInt(id, 10))
119 | if err != nil {
120 | 		ws.WriteMessage(websocket.TextMessage, []byte(err.Error())) // the ResponseWriter is hijacked after Upgrade, so reply on the websocket
121 | 		ws.Close()
122 | 		return
123 | }
124 |
125 | tower.SetReadHandler(func(fire protocol.ReadOnlyFire[json.RawMessage]) bool {
126 | 		// fire is recycled (returned to the pool) after the handler returns
127 | messageInfo := new(messageInfo)
128 | err := json.Unmarshal(fire.GetMessage().Data, messageInfo)
129 | if err != nil {
130 | return false
131 | }
132 | msg := strings.Trim(string(messageInfo.Data), "\"")
133 | 		switch {
134 | 		case strings.HasPrefix(msg, "/name "):
135 | 			tower.SetUserID(strings.TrimPrefix(msg, "/name ")) // TrimPrefix, not TrimLeft: TrimLeft treats its argument as a character set
136 | messageInfo.From = "system"
137 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "change_name", "name": "%s"}`, tower.UserID()))
138 | messageInfo.Type = "event"
139 | raw, _ := json.Marshal(messageInfo)
140 | tower.SendToClient(raw)
141 | return false
142 | case strings.HasPrefix(msg, "/room "):
143 | if err = tower.UnSubscribe(fire.GetContext(), tower.TopicList()); err != nil {
144 | messageInfo.From = "system"
145 | messageInfo.Type = "event"
146 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "error", "msg": "切换房间失败, %s"}`, err.Error()))
147 | raw, _ := json.Marshal(messageInfo)
148 | tower.SendToClient(raw)
149 | return false
150 | }
151 | 			roomCode := strings.TrimPrefix(msg, "/room ")
152 | if err = tower.Subscribe(fire.GetContext(), []string{"/chat/" + roomCode}); err != nil {
153 | messageInfo.From = "system"
154 | messageInfo.Type = "event"
155 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "error", "msg": "切换房间失败, %s, 请重新尝试"}`, err.Error()))
156 | raw, _ := json.Marshal(messageInfo)
157 | tower.SendToClient(raw)
158 | return false
159 | }
160 |
161 | return false
162 | }
163 |
164 | if tower.UserID() == "" {
165 | tower.SetUserID(messageInfo.From)
166 | }
167 | messageInfo.From = tower.UserID()
168 | f := fire.Copy()
169 | f.Message.Data, _ = json.Marshal(messageInfo)
170 |
171 | 		// validate the send here
172 | 		// e.g. check whether the sender is allowed to push to the target
173 | tower.SendToClient(f.Message.Json())
174 | return false
175 | })
176 |
177 | tower.SetReceivedHandler(func(fi protocol.ReadOnlyFire[json.RawMessage]) bool {
178 | return true
179 | })
180 |
181 | tower.SetReadTimeoutHandler(func(fire protocol.ReadOnlyFire[json.RawMessage]) {
182 | messageInfo := new(messageInfo)
183 | err := json.Unmarshal(fire.GetMessage().Data, messageInfo)
184 | if err != nil {
185 | return
186 | }
187 | messageInfo.Type = "timeout"
188 | b, _ := json.Marshal(messageInfo)
189 | tower.SendToClient(b)
190 | })
191 |
192 | tower.SetBeforeSubscribeHandler(func(context protocol.FireLife, topic []string) bool {
193 | 		// decide here whether the current user is allowed to subscribe to these topics
194 | return true
195 | })
196 |
197 | tower.SetSubscribeHandler(func(context protocol.FireLife, topic []string) {
198 | for _, v := range topic {
199 | if strings.HasPrefix(v, "/chat/") {
200 | roomCode := strings.TrimPrefix(v, "/chat/")
201 | messageInfo := new(messageInfo)
202 | messageInfo.From = "system"
203 | messageInfo.Type = "event"
204 | messageInfo.Data = []byte(fmt.Sprintf(`{"type": "change_room", "room": "%s"}`, roomCode))
205 | msg, _ := json.Marshal(messageInfo)
206 | tower.SendToClient(msg)
207 | }
208 | }
209 | })
210 |
211 | ticker := time.NewTicker(time.Millisecond * 500)
212 | go func() {
213 | topicConnCache := make(map[string]uint64)
214 | for {
215 | select {
216 | case <-tower.OnClose():
217 | return
218 | case <-ticker.C:
219 | for _, v := range tower.TopicList() {
220 | num, err := tower.GetConnectNum(v)
221 | if err != nil {
222 | tower.Logger().Error("failed to get connect number", slog.Any("error", err))
223 | continue
224 | }
225 | if topicConnCache[v] == num {
226 | continue
227 | }
228 |
229 | tower.SendToClient([]byte(fmt.Sprintf("{\"type\":\"onSubscribe\",\"data\":%d}", num)))
230 | topicConnCache[v] = num
231 | }
232 | }
233 | }
234 | }()
235 |
236 | tower.Run()
237 | }
238 |
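
The handler above wires read, received, subscribe and timeout callbacks; tower.go (later in this dump) also exposes connection-lifecycle hooks that the example leaves unset. A minimal sketch of registering them in the same Websocket handler, before tower.Run() is called; the log fields are illustrative, not part of the example:

// optional lifecycle hooks, registered before tower.Run()
tower.SetOnConnectHandler(func() bool {
	// returning false makes Run close the connection immediately
	tower.Logger().Info("client connected", slog.String("client_id", tower.ClientID()))
	return true
})
tower.SetOnOfflineHandler(func() {
	tower.Logger().Info("client offline", slog.String("client_id", tower.ClientID()))
})
tower.SetSendTimeoutHandler(func(fire protocol.ReadOnlyFire[json.RawMessage]) {
	tower.Logger().Info("send timed out", slog.String("topic", fire.GetMessage().Topic))
})
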
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA=
2 | github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE=
3 | github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
4 | github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
8 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
9 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
10 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
11 | github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
12 | github.com/go-redis/redis/v9 v9.0.0-beta.2 h1:ZSr84TsnQyKMAg8gnV+oawuQezeJR11/09THcWCQzr4=
13 | github.com/go-redis/redis/v9 v9.0.0-beta.2/go.mod h1:Bldcd/M/bm9HbnNPi/LUtYBSD8ttcZYBMupwMXhdU0o=
14 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
15 | github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
16 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
17 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
18 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
19 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
20 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
21 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
22 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
23 | github.com/holdno/snowFlakeByGo v1.0.0 h1:YbS4Jx78sF688XG8iwTRUDtIIH6+R+WWVWyS8uq8xOg=
24 | github.com/holdno/snowFlakeByGo v1.0.0/go.mod h1:aqAI0YiLKgShMi9R71i5S81IWfb0x2ghGF2w1RjyNbs=
25 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
26 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
27 | github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4=
28 | github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
29 | github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=
30 | github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
31 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
32 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
33 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
34 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
35 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
36 | github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I=
37 | github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
38 | github.com/nats-io/nats-server/v2 v2.8.4 h1:0jQzze1T9mECg8YZEl8+WYUXb9JKluJfCBriPUtluB4=
39 | github.com/nats-io/nats-server/v2 v2.8.4/go.mod h1:8zZa+Al3WsESfmgSs98Fi06dRWLH5Bnq90m5bKD/eT4=
40 | github.com/nats-io/nats.go v1.16.0 h1:zvLE7fGBQYW6MWaFaRdsgm9qT39PJDQoju+DS8KsO1g=
41 | github.com/nats-io/nats.go v1.16.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
42 | github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
43 | github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
44 | github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
45 | github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
46 | github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
47 | github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
48 | github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
49 | github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
50 | github.com/onsi/gomega v1.20.0 h1:8W0cWlwFkflGPLltQvLRB7ZVD5HuP6ng320w2IS245Q=
51 | github.com/onsi/gomega v1.20.0/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo=
52 | github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
53 | github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
54 | github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
55 | github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
56 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
57 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
58 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
59 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
60 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
61 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
62 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
63 | github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
64 | github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
65 | github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
66 | github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
67 | golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
68 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
69 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
70 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
71 | golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA=
72 | golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
73 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
74 | golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc=
75 | golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
76 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
77 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
78 | golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
79 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
80 | golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
81 | golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
82 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
83 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
84 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
85 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
86 | google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ=
87 | google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
88 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
89 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
90 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
91 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
92 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
93 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
94 |
--------------------------------------------------------------------------------
/example/web/index.html:
--------------------------------------------------------------------------------
[Only the text nodes of index.html survived extraction; the markup, styles and inline script were lost. The page is the demo chat client: title "firetower client", a "firetower" heading, a status line "您的昵称:,当前房间:world,当前在线:1" (your nickname / current room: world / currently online: 1) and a "切换房间" (switch room) control. Roughly lines 9-238 and 269-437 of the original file, which held the CSS and the script driving the websocket demo, are not recoverable from this dump.]
--------------------------------------------------------------------------------
/service/tower/bucket.go:
--------------------------------------------------------------------------------
1 | package tower
2 |
3 | import (
4 | "context"
5 | "log/slog"
6 | "os"
7 | "sync"
8 | "sync/atomic"
9 | "time"
10 |
11 | "github.com/avast/retry-go/v4"
12 | "github.com/gorilla/websocket"
13 | "github.com/holdno/firetower/config"
14 | "github.com/holdno/firetower/pkg/nats"
15 | "github.com/holdno/firetower/protocol"
16 | "github.com/holdno/firetower/store"
17 | "github.com/holdno/firetower/store/redis"
18 | "github.com/holdno/firetower/store/single"
19 | "github.com/holdno/firetower/utils"
20 |
21 | cmap "github.com/orcaman/concurrent-map/v2"
22 | )
23 |
24 | type Manager[T any] interface {
25 | protocol.Pusher[T]
26 | BuildTower(ws *websocket.Conn, clientId string) (tower *FireTower[T], err error)
27 | BuildServerSideTower(clientId string) ServerSideTower[T]
28 | NewFire(source protocol.FireSource, tower PusherInfo) *protocol.FireInfo[T]
29 | GetTopics() (map[string]uint64, error)
30 | ClusterID() int64
31 | Store() stores
32 | Logger() protocol.Logger
33 | }
34 |
35 | // TowerManager holds the central dispatch queue and a set of buckets
36 | // buckets spread one instance's connections evenly so that pushes can run concurrently
37 | type TowerManager[T any] struct {
38 | cfg config.FireTowerConfig
39 | bucket []*Bucket[T]
40 | 	centralChan chan *protocol.FireInfo[T] // central dispatch queue
41 | ip string
42 | clusterID int64
43 | timeout time.Duration
44 |
45 | stores stores
46 | logger protocol.Logger
47 | topicCounter chan counterMsg
48 | connCounter chan counterMsg
49 |
50 | coder protocol.Coder[T]
51 | protocol.Pusher[T]
52 |
53 | isClose bool
54 | closeChan chan struct{}
55 |
56 | brazier protocol.Brazier[T]
57 |
58 | onTopicCountChangedHandler func(Topic string)
59 | onConnCountChangedHandler func()
60 | }
61 |
62 | func (t *TowerManager[T]) SetTopicCountChangedHandler(f func(string)) {
63 | t.onTopicCountChangedHandler = f
64 | }
65 |
66 | func (t *TowerManager[T]) SetConnCountChangedHandler(f func()) {
67 | t.onConnCountChangedHandler = f
68 | }
69 |
70 | type counterMsg struct {
71 | Key string
72 | Num int64
73 | }
74 |
75 | type stores interface {
76 | ClusterConnStore() store.ClusterConnStore
77 | ClusterTopicStore() store.ClusterTopicStore
78 | ClusterStore() store.ClusterStore
79 | }
80 |
81 | // Bucket holds a share of this instance's connections; spreading them across buckets is what makes pushes concurrent
82 | type Bucket[T any] struct {
83 | tm *TowerManager[T]
84 | 	mu sync.RWMutex // read-write lock: concurrent reads are allowed, writes are exclusive
85 | id int64
86 | len int64
87 | // topicRelevance map[string]map[string]*FireTower // topic -> websocket clientid -> websocket conn
88 | topicRelevance cmap.ConcurrentMap[string, cmap.ConcurrentMap[string, *FireTower[T]]]
89 | 	BuffChan chan *protocol.FireInfo[T] // the bucket's message-processing queue
90 | sendTimeout time.Duration
91 | }
92 |
93 | type TowerOption[T any] func(t *TowerManager[T])
94 |
95 | func BuildWithPusher[T any](pusher protocol.Pusher[T]) TowerOption[T] {
96 | return func(t *TowerManager[T]) {
97 | t.Pusher = pusher
98 | }
99 | }
100 |
101 | func BuildWithMessageDisposeTimeout[T any](timeout time.Duration) TowerOption[T] {
102 | return func(t *TowerManager[T]) {
103 | if timeout > 0 {
104 | t.timeout = timeout
105 | }
106 | }
107 | }
108 |
109 | func BuildWithCoder[T any](coder protocol.Coder[T]) TowerOption[T] {
110 | return func(t *TowerManager[T]) {
111 | t.coder = coder
112 | }
113 | }
114 |
115 | func BuildWithStore[T any](store stores) TowerOption[T] {
116 | return func(t *TowerManager[T]) {
117 | t.stores = store
118 | }
119 | }
120 |
121 | func BuildWithClusterID[T any](id int64) TowerOption[T] {
122 | return func(t *TowerManager[T]) {
123 | t.clusterID = id
124 | }
125 | }
126 |
127 | func BuildWithLogger[T any](logger protocol.Logger) TowerOption[T] {
128 | return func(t *TowerManager[T]) {
129 | t.logger = logger
130 | }
131 | }
132 |
133 | func BuildFoundation[T any](cfg config.FireTowerConfig, opts ...TowerOption[T]) (Manager[T], error) {
134 | tm := &TowerManager[T]{
135 | cfg: cfg,
136 | bucket: make([]*Bucket[T], cfg.Bucket.Num),
137 | centralChan: make(chan *protocol.FireInfo[T], cfg.Bucket.CentralChanCount),
138 | topicCounter: make(chan counterMsg, 20000),
139 | connCounter: make(chan counterMsg, 20000),
140 | closeChan: make(chan struct{}),
141 | brazier: newBrazier[T](),
142 | logger: slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
143 | Level: slog.LevelDebug,
144 | })),
145 | timeout: time.Second,
146 | }
147 |
148 | for _, opt := range opts {
149 | opt(tm)
150 | }
151 |
152 | var err error
153 | if tm.ip, err = utils.GetIP(); err != nil {
154 | panic(err)
155 | }
156 |
157 | if tm.coder == nil {
158 | tm.coder = &protocol.DefaultCoder[T]{}
159 | }
160 |
161 | if tm.Pusher == nil {
162 | if cfg.ServiceMode == config.SingleMode {
163 | tm.Pusher = protocol.DefaultPusher(tm.brazier, tm.coder, tm.logger)
164 | } else {
165 | tm.Pusher = nats.MustSetupNatsPusher(cfg.Cluster.NatsOption, tm.coder, tm.logger, func() map[string]uint64 {
166 | m, err := tm.stores.ClusterTopicStore().Topics()
167 | if err != nil {
168 | tm.logger.Error("failed to get current node topics from nats", slog.Any("error", err))
169 | return map[string]uint64{}
170 | }
171 | return m
172 | })
173 | }
174 | }
175 |
176 | if tm.stores == nil {
177 | if cfg.ServiceMode == config.SingleMode {
178 | tm.stores, _ = single.Setup()
179 | } else {
180 | if tm.stores, err = redis.Setup(cfg.Cluster.RedisOption.Addr,
181 | cfg.Cluster.RedisOption.Password,
182 | cfg.Cluster.RedisOption.DB, tm.ip,
183 | cfg.Cluster.RedisOption.KeyPrefix); err != nil {
184 | panic(err)
185 | }
186 | }
187 | }
188 |
189 | if tm.clusterID == 0 {
190 | clusterID, err := tm.stores.ClusterStore().ClusterNumber()
191 | if err != nil {
192 | return nil, err
193 | }
194 | tm.clusterID = clusterID
195 | }
196 | utils.SetupIDWorker(tm.clusterID)
197 |
198 | for i := range tm.bucket {
199 | tm.bucket[i] = newBucket[T](tm, cfg.Bucket.BuffChanCount, cfg.Bucket.ConsumerNum)
200 | }
201 |
202 | go func() {
203 | var (
204 | connCounter int64
205 | topicCounter = make(map[string]int64)
206 | ticker = time.NewTicker(time.Millisecond * 500)
207 | clusterHeartbeat = time.NewTicker(time.Second * 3)
208 | )
209 |
210 | reportConn := func(counter int64) {
211 | err := retry.Do(func() error {
212 | return tm.stores.ClusterConnStore().OneClientAtomicAddBy(tm.ip, counter)
213 | }, retry.Attempts(3), retry.LastErrorOnly(true))
214 | if err != nil {
215 | tm.logger.Error("failed to update the number of redis websocket connections", slog.Any("error", err))
216 | tm.connCounter <- counterMsg{
217 | Key: tm.ip,
218 | Num: counter,
219 | }
220 | return
221 | }
222 | if tm.onConnCountChangedHandler != nil {
223 | tm.onConnCountChangedHandler()
224 | }
225 | }
226 |
227 | reportTopicConn := func(topicCounter map[string]int64) {
228 | for t, n := range topicCounter {
229 | err := retry.Do(func() error {
230 | return tm.stores.ClusterTopicStore().TopicConnAtomicAddBy(t, n)
231 | }, retry.Attempts(3), retry.LastErrorOnly(true))
232 | if err != nil {
233 | tm.logger.Error("failed to update the number of connections for the topic in redis", slog.Any("error", err))
234 | tm.topicCounter <- counterMsg{
235 | Key: t,
236 | Num: n,
237 | }
238 | return
239 | }
240 | if tm.onTopicCountChangedHandler != nil {
241 | tm.onTopicCountChangedHandler(t)
242 | }
243 | }
244 | }
245 | for {
246 | select {
247 | case msg := <-tm.connCounter:
248 | connCounter += msg.Num
249 | case msg := <-tm.topicCounter:
250 | topicCounter[msg.Key] += msg.Num
251 | case <-ticker.C:
252 | 				if connCounter != 0 { // also flush negative deltas (disconnects)
253 | go reportConn(connCounter)
254 | connCounter = 0
255 | clusterHeartbeat.Reset(time.Second * 3)
256 | }
257 | if len(topicCounter) > 0 {
258 | go reportTopicConn(topicCounter)
259 | topicCounter = make(map[string]int64)
260 | }
261 | case <-clusterHeartbeat.C:
262 | reportConn(0)
263 | case <-tm.closeChan:
264 | return
265 | }
266 | }
267 | }()
268 |
269 | 	// run the central dispatcher: fan every published message out to all buckets
270 | go func() {
271 | for {
272 | select {
273 | case fire := <-tm.Receive():
274 | for _, b := range tm.bucket {
275 | b.BuffChan <- fire
276 | }
277 | case <-tm.closeChan:
278 | return
279 | }
280 | }
281 | }()
282 |
283 | return tm, nil
284 | }
285 |
286 | func (t *TowerManager[T]) Logger() protocol.Logger {
287 | return t.logger
288 | }
289 |
290 | func (t *TowerManager[T]) Store() stores {
291 | return t.stores
292 | }
293 |
294 | func (t *TowerManager[T]) GetTopics() (map[string]uint64, error) {
295 | return t.stores.ClusterTopicStore().ClusterTopics()
296 | }
297 |
298 | func (t *TowerManager[T]) ClusterID() int64 {
299 | if t == nil {
300 | panic("firetower cluster not setup")
301 | }
302 | return t.clusterID
303 | }
304 |
305 | func newBucket[T any](tm *TowerManager[T], buff int64, consumerNum int) *Bucket[T] {
306 | b := &Bucket[T]{
307 | tm: tm,
308 | id: getNewBucketId(),
309 | len: 0,
310 | topicRelevance: cmap.New[cmap.ConcurrentMap[string, *FireTower[T]]](),
311 | BuffChan: make(chan *protocol.FireInfo[T], buff),
312 | sendTimeout: tm.timeout,
313 | }
314 |
315 | if consumerNum == 0 {
316 | consumerNum = 1
317 | }
318 | 	// start consumerNum consumers per bucket (concurrent processing)
319 | for i := 0; i < consumerNum; i++ {
320 | go b.consumer()
321 | }
322 | return b
323 | }
324 |
325 | var (
326 | bucketId int64
327 | connId uint64
328 | )
329 |
330 | func getNewBucketId() int64 {
331 | 	// return the result of the atomic add directly; reading the variable afterwards would race
332 | 	return atomic.AddInt64(&bucketId, 1)
333 | }
334 |
335 | func getConnId() uint64 {
336 | 	// same pattern as getNewBucketId: avoid a separate read after the add
337 | 	return atomic.AddUint64(&connId, 1)
338 | }
339 |
340 | // GetBucket returns the bucket that the given connection is assigned to
341 | func (t *TowerManager[T]) GetBucket(bt *FireTower[T]) (bucket *Bucket[T]) {
342 | bucket = t.bucket[bt.connID%uint64(len(t.bucket))]
343 | return
344 | }
345 |
346 | // consumes messages produced by Publish
347 | func (b *Bucket[T]) consumer() {
348 | for {
349 | select {
350 | case fire := <-b.BuffChan:
351 | switch fire.Message.Type {
352 | case protocol.OfflineTopicByUserIdOperation:
353 | 				// carries the topic and user_id that should be unsubscribed
354 | 				// todo use api
355 | b.unSubscribeByUserId(fire)
356 | case protocol.OfflineTopicOperation:
357 | // todo use api
358 | b.unSubscribeAll(fire)
359 | case protocol.OfflineUserOperation:
360 | // todo use api
361 | b.offlineUsers(fire)
362 | default:
363 | b.push(fire)
364 | }
365 | }
366 | }
367 | }
368 |
369 | // AddSubscribe records the topic->conn subscription on this instance
370 | func (b *Bucket[T]) AddSubscribe(topic string, bt *FireTower[T]) {
371 | if m, ok := b.topicRelevance.Get(topic); ok {
372 | m.Set(bt.ClientID(), bt)
373 | } else {
374 | inner := cmap.New[*FireTower[T]]()
375 | inner.Set(bt.ClientID(), bt)
376 | b.topicRelevance.Set(topic, inner)
377 | }
378 | b.tm.topicCounter <- counterMsg{
379 | Key: topic,
380 | Num: 1,
381 | }
382 | }
383 |
384 | // DelSubscribe removes the topic->conn subscription on this instance
385 | func (b *Bucket[T]) DelSubscribe(topic string, bt *FireTower[T]) {
386 | if inner, ok := b.topicRelevance.Get(topic); ok {
387 | inner.Remove(bt.clientID)
388 | if inner.IsEmpty() {
389 | b.topicRelevance.Remove(topic)
390 | }
391 | }
392 |
393 | b.tm.topicCounter <- counterMsg{
394 | Key: topic,
395 | Num: -1,
396 | }
397 | }
398 |
399 | // push walks this bucket and pushes the message
400 | // every bucket has its own push method
401 | // all buckets run push at the same time, which is what makes the overall push concurrent
402 | // it iterates the bucket's topic->conn subscriptions and writes to each websocket
403 | func (b *Bucket[T]) push(message *protocol.FireInfo[T]) error {
404 | if m, ok := b.topicRelevance.Get(message.Message.Topic); ok {
405 | for _, v := range m.Items() {
406 | if v.isClose {
407 | continue
408 | }
409 |
410 | if v.receivedHandler != nil && !v.receivedHandler(message) {
411 | continue
412 | }
413 |
414 | if v.ws != nil {
415 | func() {
416 | ctx, cancel := context.WithTimeout(context.Background(), b.sendTimeout)
417 | defer cancel()
418 | select {
419 | case v.sendOut <- message.Message.Json():
420 | case <-ctx.Done():
421 | if v.sendTimeoutHandler != nil {
422 | v.sendTimeoutHandler(message)
423 | }
424 | }
425 | }()
426 | }
427 | }
428 | }
429 | return nil
430 | }
431 |
432 | // unSubscribeByUserId lets the server side unsubscribe a specific user from a topic
433 | func (b *Bucket[T]) unSubscribeByUserId(fire *protocol.FireInfo[T]) error {
434 | if m, ok := b.topicRelevance.Get(fire.Message.Topic); ok {
435 | userId, ok := fire.Context.ExtMeta[protocol.SYSTEM_CMD_REMOVE_USER]
436 | if ok {
437 | for _, v := range m.Items() {
438 | if v.UserID() == userId {
439 | v.unbindTopic([]string{fire.Message.Topic})
440 | if v.unSubscribeHandler != nil {
441 | v.unSubscribeHandler(fire.Context, []string{fire.Message.Topic})
442 | }
443 | return nil
444 | }
445 | }
446 | return nil
447 | }
448 | }
449 |
450 | return ErrorTopicEmpty
451 | }
452 |
453 | // unSubscribeAll removes every subscription of the given topic
454 | func (b *Bucket[T]) unSubscribeAll(fire *protocol.FireInfo[T]) error {
455 | if m, ok := b.topicRelevance.Get(fire.Message.Topic); ok {
456 | for _, v := range m.Items() {
457 | v.unbindTopic([]string{fire.Message.Topic})
458 | 			// when removing everyone, the per-connection unsubscribe callback should not be needed
459 | if v.onSystemRemove != nil {
460 | v.onSystemRemove(fire.Message.Topic)
461 | }
462 | }
463 | return nil
464 | }
465 | return ErrorTopicEmpty
466 | }
467 |
468 | func (b *Bucket[T]) offlineUsers(fire *protocol.FireInfo[T]) error {
469 | if m, ok := b.topicRelevance.Get(fire.Message.Topic); ok {
470 | userId, ok := fire.Context.ExtMeta[protocol.SYSTEM_CMD_REMOVE_USER]
471 | if ok {
472 | for _, v := range m.Items() {
473 | if v.UserID() == userId {
474 | v.Close()
475 | return nil
476 | }
477 | }
478 | }
479 | }
480 | return ErrorTopicEmpty
481 | }
482 |
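
Setup (defined in tower.go below) simply forwards functional options to BuildFoundation, so everything above can be tuned at startup. A minimal sketch of a single-node setup that swaps in a custom logger and a longer message-dispose timeout via BuildWithLogger and BuildWithMessageDisposeTimeout; the concrete config values are illustrative only:

package main

import (
	"log/slog"
	"os"
	"time"

	"github.com/holdno/firetower/config"
	towersvc "github.com/holdno/firetower/service/tower"
	json "github.com/json-iterator/go"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	tm, err := towersvc.Setup[json.RawMessage](config.FireTowerConfig{
		WriteChanLens: 1000,
		Heartbeat:     30,
		ServiceMode:   config.SingleMode,
		Bucket: config.BucketConfig{
			Num:              4,
			CentralChanCount: 100000,
			BuffChanCount:    1000,
			ConsumerNum:      8,
		},
	},
		// the options are generic, so the message type has to be spelled out explicitly
		towersvc.BuildWithLogger[json.RawMessage](logger),
		towersvc.BuildWithMessageDisposeTimeout[json.RawMessage](2*time.Second),
	)
	if err != nil {
		panic(err)
	}
	logger.Info("firetower ready", slog.Int64("cluster_id", tm.ClusterID()))
}
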
--------------------------------------------------------------------------------
/service/tower/tower.go:
--------------------------------------------------------------------------------
1 | package tower
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log/slog"
8 | "strings"
9 | "sync"
10 | "time"
11 |
12 | "github.com/gorilla/websocket"
13 |
14 | "github.com/holdno/firetower/config"
15 | "github.com/holdno/firetower/protocol"
16 | json "github.com/json-iterator/go"
17 | )
18 |
19 | var (
20 | towerPool sync.Pool
21 | )
22 |
23 | type BlockMessage struct {
24 | sendOutMessage []byte
25 | response chan error
26 | }
27 |
28 | // FireTower is the client connection struct
29 | // it holds all the state of a single client connection
30 | type FireTower[T any] struct {
31 | tm *TowerManager[T]
32 | 	connID uint64 // connection id, auto-incremented from 1 on each server
33 | 	clientID string // client id, used by the business logic
34 | 	userID string // usually one connection is one user; lets the business identify them
35 | 	ext *sync.Map
36 | 	Cookie []byte // a slot for the business to stash data tied to this connection
37 | startTime time.Time
38 |
39 | logger protocol.Logger
40 | timeout time.Duration
41 | 	readIn chan *protocol.FireInfo[T] // read queue
42 | 	sendOut chan []byte // send queue
43 | 	sendOutBlock chan BlockMessage
44 | 	ws *websocket.Conn // the underlying websocket connection
45 | 	topic sync.Map // list of subscribed topics
46 | 	isClose bool // whether this websocket has been closed
47 | 	closeChan chan struct{} // used as the trigger for closing the websocket
48 | 	mutex sync.Mutex // guards against closing the channel concurrently
49 |
50 | onConnectHandler func() bool
51 | onOfflineHandler func()
52 | receivedHandler func(protocol.ReadOnlyFire[T]) bool
53 | readHandler func(protocol.ReadOnlyFire[T]) bool
54 | readTimeoutHandler func(protocol.ReadOnlyFire[T])
55 | sendTimeoutHandler func(protocol.ReadOnlyFire[T])
56 | subscribeHandler func(context protocol.FireLife, topic []string)
57 | unSubscribeHandler func(context protocol.FireLife, topic []string)
58 | beforeSubscribeHandler func(context protocol.FireLife, topic []string) bool
59 | onSystemRemove func(topic string)
60 | }
61 |
62 | func (t *FireTower[T]) ClientID() string {
63 | return t.clientID
64 | }
65 |
66 | func (t *FireTower[T]) UserID() string {
67 | return t.userID
68 | }
69 |
70 | func (t *FireTower[T]) SetUserID(id string) {
71 | t.userID = id
72 | }
73 |
74 | func (t *FireTower[T]) Ext() *sync.Map {
75 | return t.ext
76 | }
77 |
78 | // Setup initializes firetower
79 | // always call Setup before using anything else in firetower
80 | func Setup[T any](cfg config.FireTowerConfig, opts ...TowerOption[T]) (Manager[T], error) {
81 | towerPool.New = func() interface{} {
82 | return &FireTower[T]{}
83 | }
84 | 	// build the service foundation
85 | return BuildFoundation(cfg, opts...)
86 | }
87 |
88 | // BuildTower instantiates a websocket client
89 | func (t *TowerManager[T]) BuildTower(ws *websocket.Conn, clientId string) (tower *FireTower[T], err error) {
90 | if ws == nil {
91 | t.logger.Error("empty websocket connect")
92 | return nil, errors.New("empty websocket connect")
93 | }
94 | tower = buildNewTower(t, ws, clientId)
95 | return
96 | }
97 |
98 | func buildNewTower[T any](tm *TowerManager[T], ws *websocket.Conn, clientID string) *FireTower[T] {
99 | t := towerPool.Get().(*FireTower[T])
100 | t.tm = tm
101 | t.connID = getConnId()
102 | t.clientID = clientID
103 | t.startTime = time.Now()
104 | t.ext = &sync.Map{}
105 | t.readIn = make(chan *protocol.FireInfo[T], tm.cfg.ReadChanLens)
106 | t.sendOut = make(chan []byte, tm.cfg.WriteChanLens)
107 | t.sendOutBlock = make(chan BlockMessage)
108 | t.topic = sync.Map{}
109 | t.ws = ws
110 | t.isClose = false
111 | t.closeChan = make(chan struct{})
112 | t.timeout = time.Second * 3
113 |
114 | t.readHandler = nil
115 | t.readTimeoutHandler = nil
116 | t.sendTimeoutHandler = nil
117 | t.subscribeHandler = nil
118 | t.unSubscribeHandler = nil
119 | t.beforeSubscribeHandler = nil
120 |
121 | t.logger = tm.logger
122 |
123 | return t
124 | }
125 |
126 | func (t *FireTower[T]) OnClose() chan struct{} {
127 | return t.closeChan
128 | }
129 |
130 | // Run starts the websocket client
131 | func (t *FireTower[T]) Run() {
132 | 	t.logger.Debug("new tower built")
133 | t.tm.connCounter <- counterMsg{
134 | Key: t.tm.ip,
135 | Num: 1,
136 | }
137 |
138 | if t.ws == nil {
139 | // server side
140 | return
141 | }
142 |
143 | 	// read frames from the websocket
144 | 	go t.readLoop()
145 | 	// handle the read events
146 | 	go t.readDispose()
147 |
148 | if t.onConnectHandler != nil {
149 | ok := t.onConnectHandler()
150 | if !ok {
151 | t.Close()
152 | }
153 | }
154 | 	// write messages to the websocket
155 | t.sendLoop()
156 | }
157 |
158 | func (t *FireTower[T]) TopicList() []string {
159 | var topics []string
160 | t.topic.Range(func(key, value any) bool {
161 | topics = append(topics, key.(string))
162 | return true
163 | })
164 | return topics
165 | }
166 |
167 | // binding step of a topic subscription
168 | func (t *FireTower[T]) bindTopic(topic []string) ([]string, error) {
169 | var (
170 | addTopic []string
171 | )
172 | bucket := t.tm.GetBucket(t)
173 | for _, v := range topic {
174 | if _, loaded := t.topic.LoadOrStore(v, struct{}{}); !loaded {
175 | 			addTopic = append(addTopic, v) // topics that still need to be subscribed
176 | bucket.AddSubscribe(v, t)
177 | }
178 | // if _, ok := t.topic[v]; !ok {
179 | // addTopic = append(addTopic, v) // 待订阅的topic
180 | // t.topic[v] = true
181 | // bucket.AddSubscribe(v, t)
182 | // }
183 | }
184 | return addTopic, nil
185 | }
186 |
187 | func (t *FireTower[T]) unbindTopic(topic []string) []string {
188 | 	var delTopic []string // topics that will actually be unsubscribed
189 | bucket := t.tm.GetBucket(t)
190 | for _, v := range topic {
191 | if _, loaded := t.topic.LoadAndDelete(v); loaded {
192 | 			// only unsubscribe topics the client had actually subscribed to
193 | delTopic = append(delTopic, v)
194 | bucket.DelSubscribe(v, t)
195 | }
196 | }
197 | return delTopic
198 | }
199 |
200 | func (t *FireTower[T]) read() (*protocol.FireInfo[T], error) {
201 | if t.isClose {
202 | return nil, ErrorClosed
203 | }
204 | fire, ok := <-t.readIn
205 | if !ok {
206 | return nil, ErrorClosed
207 | }
208 | return fire, nil
209 | }
210 |
211 | // Close closes the client connection and deregisters it
212 | // calling it tears down everything that BuildTower created
213 | func (t *FireTower[T]) Close() {
214 | t.logger.Debug("close connect")
215 | t.mutex.Lock()
216 | defer t.mutex.Unlock()
217 | if !t.isClose {
218 | t.isClose = true
219 | var topicSlice []string
220 | t.topic.Range(func(key, value any) bool {
221 | topicSlice = append(topicSlice, key.(string))
222 | return true
223 | })
224 | if len(topicSlice) > 0 {
225 | delTopic := t.unbindTopic(topicSlice)
226 |
227 | fire := t.tm.NewFire(protocol.SourceSystem, t)
228 | defer t.tm.brazier.Extinguished(fire)
229 |
230 | if t.unSubscribeHandler != nil {
231 | t.unSubscribeHandler(fire.Context, delTopic)
232 | }
233 | }
234 |
235 | if t.ws != nil {
236 | t.ws.Close()
237 | }
238 |
239 | t.tm.connCounter <- counterMsg{
240 | Key: t.tm.ip,
241 | Num: -1,
242 | }
243 | close(t.closeChan)
244 | if t.onOfflineHandler != nil {
245 | t.onOfflineHandler()
246 | }
247 | towerPool.Put(t)
248 | t.logger.Debug("tower closed")
249 | }
250 | }
251 |
252 | var heartbeat = []byte{104, 101, 97, 114, 116, 98, 101, 97, 116} // []byte("heartbeat")
253 |
254 | func (t *FireTower[T]) sendLoop() {
255 | heartTicker := time.NewTicker(time.Duration(t.tm.cfg.Heartbeat) * time.Second)
256 | defer func() {
257 | heartTicker.Stop()
258 | t.Close()
259 | close(t.sendOut)
260 | }()
261 | for {
262 | select {
263 | case wsMsg := <-t.sendOut:
264 | if t.ws != nil {
265 | if err := t.sendToClient(wsMsg); err != nil {
266 | return
267 | }
268 | }
269 | case wsMsg := <-t.sendOutBlock:
270 | if t.ws == nil {
271 | wsMsg.response <- fmt.Errorf("send to none websocket client")
272 | return
273 | }
274 | if err := t.sendToClient(wsMsg.sendOutMessage); err != nil {
275 | wsMsg.response <- err
276 | return
277 | }
278 | wsMsg.response <- nil
279 | case <-heartTicker.C:
280 | // sendMessage.Data = []byte{104, 101, 97, 114, 116, 98, 101, 97, 116} // []byte("heartbeat")
281 | if err := t.sendToClient(heartbeat); err != nil {
282 | return
283 | }
284 | case <-t.closeChan:
285 | return
286 | }
287 | }
288 | }
289 |
290 | func (t *FireTower[T]) readLoop() {
291 | if t.ws == nil {
292 | return
293 | }
294 | defer func() {
295 | if err := recover(); err != nil {
296 | t.tm.logger.Error("readloop panic", slog.Any("error", err))
297 | }
298 | t.Close()
299 | close(t.readIn)
300 | }()
301 | for {
302 | messageType, data, err := t.ws.ReadMessage()
303 | 		if err != nil { // connection dropped
304 | return
305 | }
306 | 		fire := t.tm.NewFire(protocol.SourceClient, t) // take a message object from the pool to reduce GC pressure
307 | fire.MessageType = messageType
308 | if err := json.Unmarshal(data, &fire.Message); err != nil {
309 | t.logger.Error("failed to unmarshal client data, filtered", slog.Any("error", err))
310 | continue
311 | }
312 |
313 | func() {
314 | ctx, cancel := context.WithTimeout(context.Background(), t.timeout)
315 | defer cancel()
316 | select {
317 | case t.readIn <- fire:
318 | return
319 | case <-ctx.Done():
320 | if t.readTimeoutHandler != nil {
321 | t.readTimeoutHandler(fire)
322 | }
323 | t.logger.Error("readloop timeout", slog.Any("data", data))
324 | case <-t.closeChan:
325 | }
326 | t.tm.brazier.Extinguished(fire)
327 | }()
328 |
329 | }
330 | }
331 |
332 | // handles data sent by the client
333 | // the logic splits here: either the user is messaging or managing topic subscriptions
334 | func (t *FireTower[T]) readDispose() {
335 | for {
336 | fire, err := t.read()
337 | if err != nil {
338 | t.logger.Error("failed to read message from websocket", slog.Any("error", err))
339 | return
340 | }
341 | if fire != nil {
342 | go t.readLogic(fire)
343 | }
344 | }
345 | }
346 |
347 | func (t *FireTower[T]) readLogic(fire *protocol.FireInfo[T]) error {
348 | defer t.tm.brazier.Extinguished(fire)
349 | if t.isClose {
350 | return nil
351 | } else {
352 | if fire.Message.Topic == "" {
353 | t.logger.Error("the obtained topic is empty. this message will be filtered")
354 | 			return fmt.Errorf("%s:topic is empty, ClientId:%s, UserId:%s", fire.Message.Type, t.ClientID(), t.UserID())
355 | }
356 | switch fire.Message.Type {
357 | 		case protocol.SubscribeOperation: // the client subscribes to topics
358 | 			addTopics := strings.Split(fire.Message.Topic, ",")
359 | 			// carry the messageId along to make tracing easier
360 | err := t.Subscribe(fire.Context, addTopics)
361 | if err != nil {
362 | t.logger.Error("failed to subscribe topics", slog.Any("topics", addTopics), slog.Any("error", err))
363 | // TODO metrics
364 | return err
365 | }
366 | 		case protocol.UnSubscribeOperation: // the client unsubscribes from topics
367 | delTopic := strings.Split(fire.Message.Topic, ",")
368 | err := t.UnSubscribe(fire.Context, delTopic)
369 | if err != nil {
370 | t.logger.Error("failed to unsubscribe topics", slog.Any("topics", delTopic), slog.Any("error", err))
371 | // TODO metrics
372 | return err
373 | }
374 | default:
375 | if t.readHandler != nil && !t.readHandler(fire) {
376 | return nil
377 | }
378 | t.Publish(fire)
379 | }
380 | }
381 | return nil
382 | }
383 |
384 | func (t *FireTower[T]) Subscribe(context protocol.FireLife, topics []string) error {
385 | 	// call the before-subscribe hook if one is set
386 | if t.beforeSubscribeHandler != nil {
387 | ok := t.beforeSubscribeHandler(context, topics)
388 | if !ok {
389 | return nil
390 | }
391 | }
392 | addTopics, err := t.bindTopic(topics)
393 | if err != nil {
394 | return err
395 | }
396 | if t.subscribeHandler != nil {
397 | t.subscribeHandler(context, addTopics)
398 | }
399 | return nil
400 | }
401 |
402 | func (t *FireTower[T]) UnSubscribe(context protocol.FireLife, topics []string) error {
403 | delTopics := t.unbindTopic(topics)
404 |
405 | if t.unSubscribeHandler != nil {
406 | t.unSubscribeHandler(context, delTopics)
407 | }
408 | return nil
409 | }
410 |
411 | // Publish is the push entry point
412 | // any instance created by BuildTower can call it to publish messages
413 | func (t *FireTower[T]) Publish(fire *protocol.FireInfo[T]) error {
414 | err := t.tm.Publish(fire)
415 | if err != nil {
416 | t.logger.Error("failed to publish message", slog.Any("error", err))
417 | return err
418 | }
419 | return nil
420 | }
421 |
422 | // SendToClient pushes a message to this connection only
423 | // typical scenario:
424 | // call it when the push is meant for just the current client
425 | func (t *FireTower[T]) SendToClient(b []byte) {
426 | t.mutex.Lock()
427 | defer t.mutex.Unlock()
428 | if !t.isClose {
429 | t.sendOut <- b
430 | }
431 | }
432 |
433 | func (t *FireTower[T]) SendToClientBlock(b []byte) error {
434 | t.mutex.Lock()
435 | defer t.mutex.Unlock()
436 | if t.isClose {
437 | return fmt.Errorf("send to closed firetower")
438 | }
439 |
440 | msg := BlockMessage{
441 | sendOutMessage: b,
442 | response: make(chan error, 1),
443 | }
444 | t.sendOutBlock <- msg
445 |
446 | defer close(msg.response)
447 | return <-msg.response
448 | }
449 |
450 | func (t *FireTower[T]) sendToClient(b []byte) error {
451 | if t.ws == nil {
452 | return ErrorServerSideMode
453 | }
454 |
455 | 	if !t.isClose {
456 | return t.ws.WriteMessage(websocket.TextMessage, b)
457 | }
458 | return ErrorClosed
459 | }
460 |
461 | // SetOnConnectHandler sets the connection-established hook
462 | func (t *FireTower[T]) SetOnConnectHandler(fn func() bool) {
463 | t.onConnectHandler = fn
464 | }
465 |
466 | // SetOnOfflineHandler is triggered when the user's connection closes
467 | func (t *FireTower[T]) SetOnOfflineHandler(fn func()) {
468 | t.onOfflineHandler = fn
469 | }
470 |
471 | func (t *FireTower[T]) SetReceivedHandler(fn func(protocol.ReadOnlyFire[T]) bool) {
472 | t.receivedHandler = fn
473 | }
474 |
475 | // SetReadHandler sets the client-publish hook
476 | // triggered when a publish message is received from the user
477 | func (t *FireTower[T]) SetReadHandler(fn func(protocol.ReadOnlyFire[T]) bool) {
478 | t.readHandler = fn
479 | }
480 |
481 | // SetSubscribeHandler sets the subscribe hook
482 | // triggered after the user subscribes to topics
483 | func (t *FireTower[T]) SetSubscribeHandler(fn func(context protocol.FireLife, topic []string)) {
484 | t.subscribeHandler = fn
485 | }
486 |
487 | // SetUnSubscribeHandler sets the unsubscribe hook
488 | // triggered after the user unsubscribes from topics
489 | func (t *FireTower[T]) SetUnSubscribeHandler(fn func(context protocol.FireLife, topic []string)) {
490 | t.unSubscribeHandler = fn
491 | }
492 |
493 | // SetBeforeSubscribeHandler sets the before-subscribe hook
494 | // triggered before the user subscribes to topics
495 | func (t *FireTower[T]) SetBeforeSubscribeHandler(fn func(context protocol.FireLife, topic []string) bool) {
496 | t.beforeSubscribeHandler = fn
497 | }
498 |
499 | // SetReadTimeoutHandler sets the read-timeout callback
500 | // fires when the readIn channel is full, i.e. production outpaces consumption
501 | func (t *FireTower[T]) SetReadTimeoutHandler(fn func(protocol.ReadOnlyFire[T])) {
502 | t.readTimeoutHandler = fn
503 | }
504 |
505 | func (t *FireTower[T]) SetSendTimeoutHandler(fn func(protocol.ReadOnlyFire[T])) {
506 | t.sendTimeoutHandler = fn
507 | }
508 |
509 | // SetOnSystemRemove is triggered when the system removes one of the user's topic subscriptions
510 | func (t *FireTower[T]) SetOnSystemRemove(fn func(topic string)) {
511 | t.onSystemRemove = fn
512 | }
513 |
514 | // GetConnectNum returns the subscriber count of a topic from the cluster store
515 | func (t *FireTower[T]) GetConnectNum(topic string) (uint64, error) {
516 | number, err := t.tm.stores.ClusterTopicStore().GetTopicConnNum(topic)
517 | if err != nil {
518 | return 0, err
519 | }
520 | return number, nil
521 | }
522 |
523 | func (t *FireTower[T]) Logger() protocol.Logger {
524 | return t.logger
525 | }
526 |
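
SendToClient only enqueues the payload onto the buffered sendOut channel, while SendToClientBlock hands it to sendLoop and waits for the websocket write to succeed or fail. A minimal sketch of choosing between the two, assuming tw is a *FireTower[json.RawMessage] obtained from BuildTower and the payloads are illustrative:

// fire-and-forget: the frame is queued and written by sendLoop later
tw.SendToClient([]byte(`{"type":"event","data":"welcome"}`))

// blocking variant: returns only after the frame was written (or the write failed)
if err := tw.SendToClientBlock([]byte(`{"type":"event","data":"important"}`)); err != nil {
	tw.Logger().Error("blocking send failed", slog.Any("error", err))
}
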
--------------------------------------------------------------------------------