├── README.md ├── discordhandlers.go ├── example └── main.go └── dshardmanager.go /README.md: -------------------------------------------------------------------------------- 1 | # dshardmanager 2 | 3 | Simple shard manager for discord bots 4 | 5 | Status: 6 | 7 | - [x] Core funcitonality, add handlers, log shard events to discord and to custom outputs 8 | - [x] Fancy status message in discord that gets updated live 9 | - [x] Use the recommded shard count by discord 10 | - [ ] Warn when getting close to the cap 11 | - [ ] Automatically re-scale the sharding when needed 12 | Needed being when a shard with +2500 guilds disconnects and fails to resume, this shard will no longer be able to identify afaik 13 | - [ ] Simple api? maybe in an extras package. -------------------------------------------------------------------------------- /discordhandlers.go: -------------------------------------------------------------------------------- 1 | package dshardmanager 2 | 3 | import ( 4 | "github.com/bwmarrin/discordgo" 5 | ) 6 | 7 | func (m *Manager) OnDiscordConnected(s *discordgo.Session, evt *discordgo.Connect) { 8 | m.handleEvent(EventConnected, s.ShardID, "") 9 | } 10 | 11 | func (m *Manager) OnDiscordDisconnected(s *discordgo.Session, evt *discordgo.Disconnect) { 12 | m.handleEvent(EventDisconnected, s.ShardID, "") 13 | } 14 | 15 | func (m *Manager) OnDiscordReady(s *discordgo.Session, evt *discordgo.Ready) { 16 | m.handleEvent(EventReady, s.ShardID, "") 17 | } 18 | 19 | func (m *Manager) OnDiscordResumed(s *discordgo.Session, evt *discordgo.Resumed) { 20 | m.handleEvent(EventResumed, s.ShardID, "") 21 | } 22 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | // "github.com/bwmarrin/discordgo" 6 | "github.com/jonas747/dshardmanager" 7 | "log" 8 | "net/http" 9 | _ "net/http/pprof" 10 | "os" 11 | "strings" 12 | ) 13 | 14 | var ( 15 | FlagToken string 16 | FlagLogChannel string 17 | ) 18 | 19 | func main() { 20 | 21 | flag.StringVar(&FlagToken, "t", "", "Discord token") 22 | flag.StringVar(&FlagLogChannel, "c", "", "Log channel, optional") 23 | flag.Parse() 24 | 25 | log.Println("Starting v" + dshardmanager.VersionString) 26 | if FlagToken == "" { 27 | FlagToken = os.Getenv("DG_TOKEN") 28 | if FlagToken == "" { 29 | log.Fatal("No token specified") 30 | } 31 | } 32 | 33 | if !strings.HasPrefix(FlagToken, "Bot ") { 34 | log.Fatal("dshardmanager only works on bot accounts, did you maybe forgot to add `Bot ` before the token?") 35 | } 36 | 37 | manager := dshardmanager.New(FlagToken) 38 | manager.Name = "ExampleBot" 39 | manager.LogChannel = FlagLogChannel 40 | manager.StatusMessageChannel = FlagLogChannel 41 | 42 | recommended, err := manager.GetRecommendedCount() 43 | if err != nil { 44 | log.Fatal("Failed getting recommended shard count") 45 | } 46 | if recommended < 2 { 47 | manager.SetNumShards(5) 48 | } 49 | 50 | log.Println("Starting the shard manager") 51 | err = manager.Start() 52 | if err != nil { 53 | log.Fatal("Faled to start: ", err) 54 | } 55 | 56 | log.Println("Started!") 57 | 58 | log.Fatal(http.ListenAndServe(":7441", nil)) 59 | select {} 60 | } 61 | -------------------------------------------------------------------------------- /dshardmanager.go: -------------------------------------------------------------------------------- 1 | package dshardmanager 2 | 3 | import ( 4 | "fmt" 5 | "github.com/bwmarrin/discordgo" 6 | "github.com/pkg/errors" 7 | "log" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | const ( 15 | VersionMajor = 0 16 | VersionMinor = 2 17 | VersionPath = 0 18 | ) 19 | 20 | var ( 21 | VersionString = strconv.Itoa(VersionMajor) + "." + strconv.Itoa(VersionMinor) + "." + strconv.Itoa(VersionPath) 22 | ) 23 | 24 | type SessionFunc func(token string) (*discordgo.Session, error) 25 | 26 | type Manager struct { 27 | sync.RWMutex 28 | 29 | // Name of the bot, to appear before log messages as a prefix 30 | // and in the title of the updated status message 31 | Name string 32 | 33 | // All the shard sessions 34 | Sessions []*discordgo.Session 35 | eventHandlers []interface{} 36 | 37 | // If set logs connection status events to this channel 38 | LogChannel string 39 | 40 | // If set keeps an updated satus message in this channel 41 | StatusMessageChannel string 42 | 43 | // The function that provides the guild counts per shard, used fro the updated status message 44 | // Should return a slice of guild counts, with the index being the shard number 45 | GuildCountsFunc func() []int 46 | 47 | // Called on events, by default this is set to a function that logs it to log.Printf 48 | // You can override this if you want another behaviour, or just set it to nil for nothing. 49 | OnEvent func(e *Event) 50 | 51 | // SessionFunc creates a new session and returns it, override the default one if you have your own 52 | // session settings to apply 53 | SessionFunc SessionFunc 54 | 55 | nextStatusUpdate time.Time 56 | statusUpdaterStarted bool 57 | 58 | numShards int 59 | token string 60 | 61 | bareSession *discordgo.Session 62 | started bool 63 | } 64 | 65 | // New creates a new shard manager with the defaults set, after you have created this you call Manager.Start 66 | // To start connecting 67 | // dshardmanager.New("Bot asd", OptLogChannel(someChannel), OptLogEventsToDiscord(true, true)) 68 | func New(token string) *Manager { 69 | // Setup defaults 70 | manager := &Manager{ 71 | token: token, 72 | numShards: -1, 73 | } 74 | 75 | manager.OnEvent = manager.LogConnectionEventStd 76 | manager.SessionFunc = manager.StdSessionFunc 77 | 78 | manager.bareSession, _ = discordgo.New(token) 79 | 80 | return manager 81 | } 82 | 83 | // GetRecommendedCount gets the recommended sharding count from discord, this will also 84 | // set the shard count internally if called 85 | // Should not be called after calling Start(), will have undefined behaviour 86 | func (m *Manager) GetRecommendedCount() (int, error) { 87 | resp, err := m.bareSession.GatewayBot() 88 | if err != nil { 89 | return 0, errors.WithMessage(err, "GetRecommendedCount()") 90 | } 91 | 92 | m.numShards = resp.Shards 93 | if m.numShards < 1 { 94 | m.numShards = 1 95 | } 96 | 97 | return m.numShards, nil 98 | } 99 | 100 | // GetNumShards returns the current set number of shards 101 | func (m *Manager) GetNumShards() int { 102 | return m.numShards 103 | } 104 | 105 | // SetNumShards sets the number of shards to use, if you want to override the recommended count 106 | // Should not be called after calling Start(), will panic 107 | func (m *Manager) SetNumShards(n int) { 108 | m.Lock() 109 | defer m.Unlock() 110 | if m.started { 111 | panic("Can't set num shard after started") 112 | } 113 | 114 | m.numShards = n 115 | } 116 | 117 | // Adds an event handler to all shards 118 | // All event handlers will be added to new sessions automatically. 119 | func (m *Manager) AddHandler(handler interface{}) { 120 | m.Lock() 121 | defer m.Unlock() 122 | m.eventHandlers = append(m.eventHandlers, handler) 123 | 124 | if len(m.Sessions) > 0 { 125 | for _, v := range m.Sessions { 126 | v.AddHandler(handler) 127 | } 128 | } 129 | } 130 | 131 | // Init initializesthe manager, retreiving the recommended shard count if needed 132 | // and initalizes all the shards 133 | func (m *Manager) Init() error { 134 | m.Lock() 135 | if m.numShards < 1 { 136 | _, err := m.GetRecommendedCount() 137 | if err != nil { 138 | return errors.WithMessage(err, "Start") 139 | } 140 | } 141 | 142 | m.Sessions = make([]*discordgo.Session, m.numShards) 143 | for i := 0; i < m.numShards; i++ { 144 | err := m.initSession(i) 145 | if err != nil { 146 | m.Unlock() 147 | return errors.WithMessage(err, "initSession") 148 | } 149 | } 150 | 151 | if !m.statusUpdaterStarted { 152 | m.statusUpdaterStarted = true 153 | go m.statusRoutine() 154 | } 155 | 156 | m.nextStatusUpdate = time.Now() 157 | 158 | m.Unlock() 159 | 160 | return nil 161 | } 162 | 163 | // Start starts the shard manager, opening all gateway connections 164 | func (m *Manager) Start() error { 165 | 166 | m.Lock() 167 | if m.Sessions == nil { 168 | m.Unlock() 169 | err := m.Init() 170 | if err != nil { 171 | return err 172 | } 173 | m.Lock() 174 | } 175 | 176 | m.Unlock() 177 | 178 | for i := 0; i < m.numShards; i++ { 179 | if i != 0 { 180 | // One indentify every 5 seconds 181 | time.Sleep(time.Second * 5) 182 | } 183 | 184 | m.Lock() 185 | err := m.startSession(i) 186 | m.Unlock() 187 | if err != nil { 188 | return errors.WithMessage(err, fmt.Sprintf("Failed starting shard %d", i)) 189 | } 190 | } 191 | 192 | return nil 193 | } 194 | 195 | // StopAll stops all the shard sessions and returns the last error that occured 196 | func (m *Manager) StopAll() (err error) { 197 | m.Lock() 198 | for _, v := range m.Sessions { 199 | if e := v.Close(); e != nil { 200 | err = e 201 | } 202 | } 203 | m.Unlock() 204 | 205 | return 206 | } 207 | 208 | func (m *Manager) initSession(shard int) error { 209 | session, err := m.SessionFunc(m.token) 210 | if err != nil { 211 | return errors.WithMessage(err, "startSession.SessionFunc") 212 | } 213 | 214 | session.ShardCount = m.numShards 215 | session.ShardID = shard 216 | 217 | session.AddHandler(m.OnDiscordConnected) 218 | session.AddHandler(m.OnDiscordDisconnected) 219 | session.AddHandler(m.OnDiscordReady) 220 | session.AddHandler(m.OnDiscordResumed) 221 | 222 | // Add the user event handlers retroactively 223 | for _, v := range m.eventHandlers { 224 | session.AddHandler(v) 225 | } 226 | 227 | m.Sessions[shard] = session 228 | return nil 229 | } 230 | 231 | func (m *Manager) startSession(shard int) error { 232 | 233 | err := m.Sessions[shard].Open() 234 | if err != nil { 235 | return errors.Wrap(err, "startSession.Open") 236 | } 237 | m.handleEvent(EventOpen, shard, "") 238 | 239 | return nil 240 | } 241 | 242 | // SessionForGuildS is the same as SessionForGuild but accepts the guildID as a string for convenience 243 | func (m *Manager) SessionForGuildS(guildID string) *discordgo.Session { 244 | // Question is, should we really ignore this error? 245 | // In reality, the guildID should never be invalid but... 246 | parsed, _ := strconv.ParseInt(guildID, 10, 64) 247 | return m.SessionForGuild(parsed) 248 | } 249 | 250 | // SessionForGuild returns the session for the specified guild 251 | func (m *Manager) SessionForGuild(guildID int64) *discordgo.Session { 252 | // (guild_id >> 22) % num_shards == shard_id 253 | // That formula is taken from the sharding issue on the api docs repository on github 254 | m.RLock() 255 | defer m.RUnlock() 256 | shardID := (guildID >> 22) % int64(m.numShards) 257 | return m.Sessions[shardID] 258 | } 259 | 260 | // Session retrieves a session from the sessions map, rlocking it in the process 261 | func (m *Manager) Session(shardID int) *discordgo.Session { 262 | m.RLock() 263 | defer m.RUnlock() 264 | return m.Sessions[shardID] 265 | } 266 | 267 | // LogConnectionEventStd is the standard connection event logger, it logs it to whatever log.output is set to. 268 | func (m *Manager) LogConnectionEventStd(e *Event) { 269 | log.Printf("[Shard Manager] %s", e.String()) 270 | } 271 | 272 | func (m *Manager) handleError(err error, shard int, msg string) bool { 273 | if err == nil { 274 | return false 275 | } 276 | 277 | m.handleEvent(EventError, shard, msg+": "+err.Error()) 278 | return true 279 | } 280 | 281 | func (m *Manager) handleEvent(typ EventType, shard int, msg string) { 282 | if m.OnEvent == nil { 283 | return 284 | } 285 | 286 | evt := &Event{ 287 | Type: typ, 288 | Shard: shard, 289 | NumShards: m.numShards, 290 | Msg: msg, 291 | Time: time.Now(), 292 | } 293 | 294 | go m.OnEvent(evt) 295 | 296 | if m.LogChannel != "" { 297 | go m.logEventToDiscord(evt) 298 | } 299 | 300 | go func() { 301 | m.Lock() 302 | m.nextStatusUpdate = time.Now().Add(time.Second * 2) 303 | m.Unlock() 304 | }() 305 | } 306 | 307 | // StdSessionFunc is the standard session provider, it does nothing to the actual session 308 | func (m *Manager) StdSessionFunc(token string) (*discordgo.Session, error) { 309 | s, err := discordgo.New(token) 310 | if err != nil { 311 | return nil, errors.WithMessage(err, "StdSessionFunc") 312 | } 313 | return s, nil 314 | } 315 | 316 | func (m *Manager) logEventToDiscord(evt *Event) { 317 | if evt.Type == EventError { 318 | return 319 | } 320 | 321 | prefix := "" 322 | if m.Name != "" { 323 | prefix = m.Name + ": " 324 | } 325 | 326 | str := evt.String() 327 | embed := &discordgo.MessageEmbed{ 328 | Description: prefix + str, 329 | Timestamp: evt.Time.Format(time.RFC3339), 330 | Color: eventColors[evt.Type], 331 | } 332 | 333 | _, err := m.bareSession.ChannelMessageSendEmbed(m.LogChannel, embed) 334 | m.handleError(err, evt.Shard, "Failed sending event to discord") 335 | } 336 | 337 | func (m *Manager) statusRoutine() { 338 | if m.StatusMessageChannel == "" { 339 | return 340 | } 341 | 342 | mID := "" 343 | 344 | // Find the initial message id and reuse that message if found 345 | msgs, err := m.bareSession.ChannelMessages(m.StatusMessageChannel, 50, "", "", "") 346 | if err != nil { 347 | m.handleError(err, -1, "Failed requesting message history in channel") 348 | } else { 349 | for _, msg := range msgs { 350 | // Dunno our own bot id so best we can do is bot 351 | if !msg.Author.Bot || len(msg.Embeds) < 1 { 352 | continue 353 | } 354 | 355 | nameStr := "" 356 | if m.Name != "" { 357 | nameStr = " for " + m.Name 358 | } 359 | 360 | embed := msg.Embeds[0] 361 | if embed.Title == "Sharding status"+nameStr { 362 | // Found it sucessfully 363 | mID = msg.ID 364 | break 365 | } 366 | } 367 | } 368 | 369 | ticker := time.NewTicker(time.Second) 370 | for { 371 | select { 372 | case <-ticker.C: 373 | m.RLock() 374 | after := time.Now().After(m.nextStatusUpdate) 375 | m.RUnlock() 376 | if after { 377 | m.Lock() 378 | m.nextStatusUpdate = time.Now().Add(time.Minute) 379 | m.Unlock() 380 | 381 | nID, err := m.updateStatusMessage(mID) 382 | if !m.handleError(err, -1, "Failed updating status message") { 383 | mID = nID 384 | } 385 | } 386 | } 387 | } 388 | } 389 | 390 | func (m *Manager) updateStatusMessage(mID string) (string, error) { 391 | content := "" 392 | 393 | status := m.GetFullStatus() 394 | for _, shard := range status.Shards { 395 | emoji := "" 396 | if !shard.Started { 397 | emoji = "🕒" 398 | } else if shard.OK { 399 | emoji = "👌" 400 | } else { 401 | emoji = "🔥" 402 | } 403 | content += fmt.Sprintf("[%d/%d]: %s (%d,%d)\n", shard.Shard, m.numShards, emoji, shard.NumGuilds, status.NumGuilds) 404 | } 405 | 406 | nameStr := "" 407 | if m.Name != "" { 408 | nameStr = " for " + m.Name 409 | } 410 | 411 | embed := &discordgo.MessageEmbed{ 412 | Title: "Sharding status" + nameStr, 413 | Description: content, 414 | Color: 0x4286f4, 415 | Timestamp: time.Now().Format(time.RFC3339), 416 | } 417 | 418 | if mID == "" { 419 | msg, err := m.bareSession.ChannelMessageSendEmbed(m.StatusMessageChannel, embed) 420 | if err != nil { 421 | return "", err 422 | } 423 | 424 | return msg.ID, err 425 | } 426 | 427 | _, err := m.bareSession.ChannelMessageEditEmbed(m.StatusMessageChannel, mID, embed) 428 | return mID, err 429 | } 430 | 431 | // GetFullStatus retrieves the full status at this instant 432 | func (m *Manager) GetFullStatus() *Status { 433 | var shardGuilds []int 434 | if m.GuildCountsFunc != nil { 435 | shardGuilds = m.GuildCountsFunc() 436 | } else { 437 | shardGuilds = m.StdGuildCountsFunc() 438 | } 439 | 440 | m.RLock() 441 | 442 | result := make([]*ShardStatus, len(m.Sessions)) 443 | for i, shard := range m.Sessions { 444 | result[i] = &ShardStatus{ 445 | Shard: i, 446 | } 447 | 448 | if shard != nil { 449 | result[i].Started = true 450 | 451 | shard.RLock() 452 | result[i].OK = shard.DataReady 453 | shard.RUnlock() 454 | } 455 | } 456 | m.RUnlock() 457 | 458 | totalGuilds := 0 459 | for shard, guilds := range shardGuilds { 460 | totalGuilds += guilds 461 | result[shard].NumGuilds = guilds 462 | } 463 | 464 | return &Status{ 465 | Shards: result, 466 | NumGuilds: totalGuilds, 467 | } 468 | } 469 | 470 | // StdGuildsFunc uses the standard states to return the guilds 471 | func (m *Manager) StdGuildCountsFunc() []int { 472 | 473 | m.RLock() 474 | nShards := m.numShards 475 | result := make([]int, nShards) 476 | 477 | for i, session := range m.Sessions { 478 | if session == nil { 479 | continue 480 | } 481 | session.State.RLock() 482 | result[i] = len(session.State.Guilds) 483 | session.State.RUnlock() 484 | } 485 | 486 | m.RUnlock() 487 | return result 488 | } 489 | 490 | type Status struct { 491 | Shards []*ShardStatus `json:"shards"` 492 | NumGuilds int `json:"num_guilds"` 493 | } 494 | 495 | type ShardStatus struct { 496 | Shard int `json:"shard"` 497 | OK bool `json:"ok"` 498 | Started bool `json:"started"` 499 | NumGuilds int `json:"num_guilds"` 500 | } 501 | 502 | // Event holds data for an event 503 | type Event struct { 504 | Type EventType 505 | 506 | Shard int 507 | NumShards int 508 | 509 | Msg string 510 | 511 | // When this event occured 512 | Time time.Time 513 | } 514 | 515 | func (c *Event) String() string { 516 | prefix := "" 517 | if c.Shard > -1 { 518 | prefix = fmt.Sprintf("[%d/%d] ", c.Shard, c.NumShards) 519 | } 520 | 521 | s := fmt.Sprintf("%s%s", prefix, strings.Title(c.Type.String())) 522 | if c.Msg != "" { 523 | s += ": " + c.Msg 524 | } 525 | 526 | return s 527 | } 528 | 529 | type EventType int 530 | 531 | const ( 532 | // Sent when the connection to the gateway was established 533 | EventConnected EventType = iota 534 | 535 | // Sent when the connection is lose 536 | EventDisconnected 537 | 538 | // Sent when the connection was sucessfully resumed 539 | EventResumed 540 | 541 | // Sent on ready 542 | EventReady 543 | 544 | // Sent when Open() is called 545 | EventOpen 546 | 547 | // Sent when Close() is called 548 | EventClose 549 | 550 | // Sent when an error occurs 551 | EventError 552 | ) 553 | 554 | var ( 555 | eventStrings = map[EventType]string{ 556 | EventOpen: "opened", 557 | EventClose: "closed", 558 | EventConnected: "connected", 559 | EventDisconnected: "disconnected", 560 | EventResumed: "resumed", 561 | EventReady: "ready", 562 | EventError: "error", 563 | } 564 | 565 | eventColors = map[EventType]int{ 566 | EventOpen: 0xec58fc, 567 | EventClose: 0xff7621, 568 | EventConnected: 0x54d646, 569 | EventDisconnected: 0xcc2424, 570 | EventResumed: 0x5985ff, 571 | EventReady: 0x00ffbf, 572 | EventError: 0x7a1bad, 573 | } 574 | ) 575 | 576 | func (c EventType) String() string { 577 | return eventStrings[c] 578 | } 579 | --------------------------------------------------------------------------------