├── slot.go ├── list_test.go ├── task.go ├── list.go ├── README.md ├── tw_test.go └── tw.go /slot.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | // 时间槽 4 | type twSlot struct { 5 | id int 6 | tasks *twList 7 | } 8 | 9 | func newSlot(id int) *twSlot { 10 | return &twSlot{id: id, tasks: newList()} 11 | } 12 | -------------------------------------------------------------------------------- /list_test.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestList(t *testing.T) { 9 | l := newList() 10 | l.push(1) 11 | n2 := l.push("2") 12 | if l.size != 2 { 13 | t.Fatalf("invalid list size, want %d, got %d", 2, l.size) 14 | } 15 | vs := make([]interface{}, 0, l.size) 16 | for cur := l.head; cur != nil; cur = cur.next { 17 | vs = append(vs, cur) 18 | } 19 | if len(vs) != 2 { 20 | t.Fatalf("invalid values size, want %d, got %d", 2, len(vs)) 21 | } 22 | 23 | if v, ok := vs[0].(int); !ok || v != 1 { 24 | t.Fatalf("invalid value: %v", vs[0]) 25 | } 26 | 27 | if v, ok := vs[1].(string); !ok || v != "2" { 28 | t.Fatalf("invalid value: %v", vs[1]) 29 | } 30 | 31 | l.remove(n2) 32 | if l.size != 1 { 33 | t.Fatalf("invalid size after remove") 34 | } 35 | 36 | fmt.Println(l) 37 | } 38 | -------------------------------------------------------------------------------- /task.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type doTask func() interface{} 9 | 10 | // 每个 slot 链表中的 task 11 | type twTask struct { 12 | id int64 // 在 slot 中的索引位置 13 | slotIdx int // 所属 slot 14 | interval time.Duration // 任务执行间隔 15 | cycles int64 // 延迟指定圈后执行 16 | do doTask // 执行任务 17 | resCh chan interface{} // 传递任务执行结果 18 | repeat int64 // 任务重复执行次数 19 | } 20 | 21 | func newTask(interval time.Duration, repeat int64, do func() interface{}) *twTask { 22 | return &twTask{ 23 | interval: interval, 24 | cycles: cycle(interval), 25 | repeat: repeat, 26 | do: do, 27 | resCh: make(chan interface{}, 1), 28 | } 29 | } 30 | 31 | // 计算 timeout 应在第几圈被执行 32 | func cycle(interval time.Duration) (n int64) { 33 | n = 1 + int64(interval)/cycleCost 34 | return 35 | } 36 | 37 | func (t *twTask) String() string { 38 | return fmt.Sprintf("[slot]:%d [interval]:%.fs [repeat]:%d [cycle]:%dth [idx]:%d ", 39 | t.slotIdx, t.interval.Seconds(), t.repeat, t.cycles, t.id) 40 | } 41 | -------------------------------------------------------------------------------- /list.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import "fmt" 4 | 5 | type twNode struct { 6 | value interface{} 7 | prev, next *twNode 8 | } 9 | 10 | func newNode(v interface{}) *twNode { 11 | return &twNode{value: v} 12 | } 13 | 14 | type twList struct { 15 | head, tail *twNode 16 | size int 17 | } 18 | 19 | func newList() *twList { 20 | return new(twList) 21 | } 22 | 23 | func (l *twList) push(v interface{}) *twNode { 24 | n := newNode(v) 25 | if l.head == nil { 26 | l.head, l.tail = n, n 27 | l.size++ 28 | return n 29 | } 30 | 31 | n.prev = l.tail 32 | n.next = nil 33 | 34 | l.tail.next = n 35 | l.tail = n 36 | l.size++ 37 | return n 38 | } 39 | 40 | func (l *twList) remove(n *twNode) { 41 | if n == nil { 42 | return 43 | } 44 | 45 | prev, next := n.prev, n.next 46 | if prev == nil { 47 | l.head = next 48 | } else { 49 | prev.next = next 50 | } 51 | 52 | if next == nil { 53 | l.tail = prev 54 | } else { 55 | next.prev = prev 56 | } 57 | n = nil // 主动释放内存 58 | l.size-- 59 | } 60 | 61 | func (l *twList) String() (s string) { 62 | s = fmt.Sprintf("[%d]: ", l.size) 63 | for cur := l.head; cur != nil; cur = cur.next { 64 | s += fmt.Sprintf("%v <-> ", cur.value) 65 | } 66 | s += "" 67 | 68 | return s 69 | } 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # timewheel 2 | 3 | Golang 实现的时间轮算法 4 | 5 | ## 功能 6 | 7 | - 执行定时任务 8 | - 执行指定次数的重复任务 9 | - 任务中断和更新 10 | 11 | ## 使用 12 | 13 | ```go 14 | package main 15 | 16 | import ( 17 | "fmt" 18 | "github.com/wuYin/timewheel" 19 | "time" 20 | ) 21 | 22 | func main() { 23 | tw := timewheel.NewTimeWheel(100*time.Millisecond, 600) // 周期为一分钟 24 | 25 | // 执行定时任务 26 | tid, _ := tw.After(5*time.Second, func() { 27 | fmt.Println("after 5 seconds, task1 executed") 28 | }) 29 | 30 | // 执行指定次数的重复任务 31 | _, allDone := tw.Repeat(1*time.Second, 3, func() { 32 | fmt.Println("per 1 second, task2 executed") 33 | }) 34 | <-allDone 35 | 36 | // 中途取消任务 37 | tw.Cancel(tid) 38 | } 39 | 40 | ``` 41 | 42 | 43 | 44 | ## 原理 45 | 46 | 使用双向链表存储提交的 **Task**,当 **Ticker** 扫到当前 **Slot** 后,将符合条件的 **Task** 放到新 goroutine 执行即可。 47 | 48 | 49 | 50 | ## 场景:定时保活 51 | 52 | 在 [wuYin/tron](https://github.com/wuYin/tron) 网络框架中,一个 Server 端需对已连接的多个 Client 定时发送 Ping 心跳包,若在超时时间内收到 Pong 包则认为连接有效,若未收到则二次规避重试一定次数后主动断开连接。实现方案: 53 | 54 | - 简单实现:为每个连接会话都分配一个 `Ticker` 定时保活,但连接过多后会占用 Server 过多内存资源 55 | - 时间轮实现:为每个 Server 配置一个时间轮,将保活任务作为指定次数的重复任务统一管理,安全高效 56 | 57 | ## 误差 58 | 59 | `time.Ticker ` 的粒度为 1ns,时间轮的粒度由用户指定。当新增任务时可能还未开始下次 tick,当 tick 粒度较大如 1s 时,任务执行时间将出现 `[0,1)s` 的明显误差。 60 | 61 | 为减少误差,tick 粒度向下取重复间隔 2 个量级较好。但粒度越细,时间轮占用 CPU 的频率越高,需做好权衡。 -------------------------------------------------------------------------------- /tw_test.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestAfter(t *testing.T) { 10 | tw := NewTimeWheel(10*time.Millisecond, 6000) 11 | start := time.Now() 12 | _, resCh := tw.After(2*time.Second, func() interface{} { 13 | fmt.Println(fmt.Sprintf("spent: %.2fs", time.Now().Sub(start).Seconds())) 14 | return true 15 | }) 16 | for res := range resCh { 17 | succ, ok := res.(bool) 18 | if !ok || succ { 19 | t.Fail() 20 | } 21 | } 22 | } 23 | 24 | func TestAfterPoints(t *testing.T) { 25 | tw := NewTimeWheel(100*time.Millisecond, 600) 26 | points := []int64{1, 2, 4, 8, 16} 27 | start := time.Now() 28 | _, resChs := tw.AfterPoints(1*time.Second, points, func() interface{} { 29 | fmt.Println(fmt.Sprintf("spent: %.2fs", time.Now().Sub(start).Seconds())) 30 | return true 31 | }) 32 | for _, resCh := range resChs { 33 | for res := range resCh { 34 | succ, ok := res.(bool) 35 | if !ok || succ { 36 | t.Fail() 37 | } 38 | } 39 | } 40 | } 41 | 42 | func TestRepeat(t *testing.T) { 43 | tw := NewTimeWheel(10*time.Millisecond, 6000) 44 | start := time.Now() 45 | _, allDoneCh := tw.Repeat(1*time.Second, 5, func() interface{} { 46 | fmt.Println(fmt.Sprintf("spent: %.2fs", time.Now().Sub(start).Seconds())) 47 | return true 48 | }) 49 | <-allDoneCh 50 | } 51 | 52 | func TestCancel(t *testing.T) { 53 | tw := NewTimeWheel(1*time.Second, 3) 54 | tid, _ := tw.After(4*time.Second, func() interface{} { 55 | fmt.Println("after 4s, task executed") 56 | return true 57 | }) 58 | time.Sleep(3 * time.Second) 59 | if !tw.Cancel(tid) { 60 | t.Fail() 61 | } 62 | if len(tw.taskMap) != 0 { 63 | t.Fail() 64 | } 65 | } 66 | 67 | func TestUpdate(t *testing.T) { 68 | tw := NewTimeWheel(10*time.Millisecond, 6000) 69 | start := time.Now() 70 | tids, _ := tw.Repeat(1*time.Second, 2, func() interface{} { 71 | fmt.Println(fmt.Sprintf("[origin] spent: %.2fs", time.Now().Sub(start).Seconds())) 72 | return true 73 | }) 74 | time.Sleep(2500 * time.Millisecond) 75 | _, allDoneCh := tw.Update(tids, 1*time.Second, 4, func() interface{} { 76 | fmt.Println(fmt.Sprintf("[updated] spent: %.2fs", time.Now().Sub(start).Seconds())) 77 | return true 78 | }) 79 | <-allDoneCh 80 | } 81 | -------------------------------------------------------------------------------- /tw.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | ) 10 | 11 | type TimeWheel struct { 12 | ticker *time.Ticker 13 | tickGap time.Duration // 每次 tick 时长 14 | slotNum int // slot 数量 15 | curSlot int // 当前 slot 序号 16 | slots []*twSlot // 槽数组 17 | taskMap map[int64]*twNode // taskId -> taskPtr 18 | incrId int64 // 自增 id 19 | taskCh chan *twTask // task 缓冲 channel 20 | lock sync.RWMutex // 数据读写锁 21 | } 22 | 23 | var cycleCost int64 // 周期耗时 24 | 25 | // 生成 slotNum 个以 tickGap 为时间间隔的时间轮 26 | func NewTimeWheel(tickGap time.Duration, slotNum int) *TimeWheel { 27 | tw := &TimeWheel{ 28 | ticker: time.NewTicker(tickGap), 29 | tickGap: tickGap, 30 | slotNum: slotNum, 31 | slots: make([]*twSlot, 0, slotNum), 32 | taskMap: make(map[int64]*twNode), 33 | taskCh: make(chan *twTask, 100), 34 | lock: sync.RWMutex{}, 35 | } 36 | cycleCost = int64(tw.tickGap * time.Duration(tw.slotNum)) 37 | for i := 0; i < slotNum; i++ { 38 | tw.slots = append(tw.slots, newSlot(i)) 39 | } 40 | 41 | go tw.turn() 42 | 43 | return tw 44 | } 45 | 46 | // 执行延时任务 47 | func (tw *TimeWheel) After(timeout time.Duration, do doTask) (int64, chan interface{}) { 48 | if timeout < 0 { 49 | return -1, nil 50 | } 51 | 52 | t := newTask(timeout, 1, do) 53 | tw.locate(t, t.interval, false) 54 | tw.taskCh <- t 55 | return t.id, t.resCh 56 | } 57 | 58 | // 执行指定重试逻辑的重复任务 59 | func (tw *TimeWheel) AfterPoints(timeoutUnit time.Duration, points []int64, do doTask) ([]int64, []chan interface{}) { 60 | if timeoutUnit < 0 || len(points) == 0 { 61 | return nil, nil 62 | } 63 | 64 | var tids []int64 65 | var resChs []chan interface{} 66 | for _, point := range points { 67 | timeout := timeoutUnit * time.Duration(point) 68 | tid, resCh := tw.After(timeout, do) 69 | tids = append(tids, tid) 70 | resChs = append(resChs, resCh) 71 | } 72 | 73 | return tids, resChs 74 | } 75 | 76 | // 执行重复任务 77 | func (tw *TimeWheel) Repeat(interval time.Duration, repeatN int64, do doTask) ([]int64, chan interface{}) { 78 | if interval <= 0 || repeatN < 1 { 79 | return nil, nil 80 | } 81 | 82 | costSum := repeatN * int64(interval) // 全部任务耗时 83 | cycleSum := costSum / cycleCost // 全部任务执行总圈数 84 | trip := cycleSum / cycle(interval) // 每个任务多少圈才执行一次 85 | 86 | var tids []int64 87 | var resChs []chan interface{} 88 | if trip > 0 { 89 | gap := interval 90 | for step := int64(0); step < cycleCost; step += int64(interval) { // 每隔 interval 放置执行 trip 次的 task 91 | t := newTask(interval, trip, do) 92 | tw.locate(t, gap, false) 93 | tw.taskCh <- t 94 | gap += interval 95 | tids = append(tids, t.id) 96 | resChs = append(resChs, t.resCh) 97 | } 98 | } 99 | 100 | // 计算余下几个任务时需重头开始计算 101 | gap := time.Duration(0) 102 | remain := (costSum % cycleCost) / int64(interval) 103 | for i := 0; i < int(remain); i++ { 104 | t := newTask(interval, 1, do) 105 | t.cycles = trip + 1 106 | tw.locate(t, gap, true) 107 | tw.taskCh <- t 108 | gap += interval 109 | tids = append(tids, t.id) 110 | resChs = append(resChs, t.resCh) 111 | } 112 | 113 | allDone := make(chan interface{}, 1) 114 | go func(doneChs []chan interface{}) { 115 | for _, ch := range doneChs { 116 | for range ch { 117 | } 118 | } 119 | allDone <- nil // 等待全部子任务完成 120 | }(resChs) 121 | return tids, allDone 122 | } 123 | 124 | // 更新任务 125 | func (tw *TimeWheel) Update(tids []int64, interval time.Duration, repeatN int64, do doTask) ([]int64, chan interface{}) { 126 | if len(tids) == 0 || interval <= 0 || repeatN < 1 { 127 | return nil, nil 128 | } 129 | 130 | if repeatN == 1 { 131 | if !tw.Cancel(tids[0]) { 132 | // return nil, nil // 按需处理 133 | } 134 | newTid, resCh := tw.After(interval, do) 135 | return []int64{newTid}, resCh 136 | } 137 | 138 | // 重复任务需全部取消 139 | for _, tid := range tids { 140 | if !tw.Cancel(tid) { 141 | // return nil, nil // 按需处理 142 | } 143 | } 144 | return tw.Repeat(interval, repeatN, do) 145 | } 146 | 147 | // 取消任务 148 | func (tw *TimeWheel) Cancel(tid int64) bool { 149 | tw.lock.Lock() 150 | defer tw.lock.Unlock() 151 | 152 | node, ok := tw.taskMap[tid] 153 | if !ok { 154 | return false // 任务已执行完毕或不存在 155 | } 156 | 157 | t := node.value.(*twTask) 158 | t.resCh <- nil 159 | close(t.resCh) // 避免资源泄漏 160 | 161 | slot := tw.slots[t.slotIdx] 162 | slot.tasks.remove(node) 163 | delete(tw.taskMap, tid) 164 | return true 165 | } 166 | 167 | // 接收 task 并定时运行 slot 中的任务 168 | func (tw *TimeWheel) turn() { 169 | idx := 0 170 | for { 171 | select { 172 | case <-tw.ticker.C: 173 | idx %= tw.slotNum 174 | tw.lock.Lock() 175 | tw.curSlot = idx // 锁粒度要细,不要重叠 176 | tw.lock.Unlock() 177 | tw.handleSlotTasks(idx) 178 | idx++ 179 | case t := <-tw.taskCh: 180 | tw.lock.Lock() 181 | // fmt.Println(t) 182 | slot := tw.slots[t.slotIdx] 183 | tw.taskMap[t.id] = slot.tasks.push(t) 184 | tw.lock.Unlock() 185 | } 186 | } 187 | } 188 | 189 | // 计算 task 所在 slot 的编号 190 | func (tw *TimeWheel) locate(t *twTask, gap time.Duration, restart bool) { 191 | tw.lock.Lock() 192 | defer tw.lock.Unlock() 193 | if restart { 194 | t.slotIdx = tw.convSlotIdx(gap) 195 | } else { 196 | t.slotIdx = tw.curSlot + tw.convSlotIdx(gap) 197 | } 198 | t.id = tw.slot2Task(t.slotIdx) 199 | } 200 | 201 | // 执行指定 slot 中的所有任务 202 | func (tw *TimeWheel) handleSlotTasks(idx int) { 203 | var expNodes []*twNode 204 | 205 | tw.lock.RLock() 206 | slot := tw.slots[idx] 207 | for node := slot.tasks.head; node != nil; node = node.next { 208 | task := node.value.(*twTask) 209 | task.cycles-- 210 | if task.cycles > 0 { 211 | continue 212 | } 213 | // 重复任务恢复 cycle 214 | if task.repeat > 0 { 215 | task.cycles = cycle(task.interval) 216 | task.repeat-- 217 | } 218 | 219 | // 不重复任务或重复任务最后一次执行都将移除 220 | if task.repeat == 0 { 221 | expNodes = append(expNodes, node) 222 | } 223 | go func() { 224 | defer func() { 225 | if err := recover(); err != nil { 226 | log.Printf("task exec paic: %v", err) // 出错暂只记录 227 | } 228 | }() 229 | 230 | var res interface{} 231 | if task.do != nil { 232 | res = task.do() 233 | } 234 | task.resCh <- res 235 | if task.repeat == 0 { 236 | close(task.resCh) 237 | } 238 | }() 239 | } 240 | tw.lock.RUnlock() 241 | 242 | tw.lock.Lock() 243 | for _, n := range expNodes { 244 | slot.tasks.remove(n) // 剔除过期任务 245 | delete(tw.taskMap, n.value.(*twTask).id) // 246 | } 247 | tw.lock.Unlock() 248 | } 249 | 250 | // 在指定 slot 中无重复生成新 task id 251 | func (tw *TimeWheel) slot2Task(slotIdx int) int64 { 252 | return int64(slotIdx)<<32 + atomic.AddInt64(&tw.incrId, 1) // 保证去重优先 253 | } 254 | 255 | // 反向获取 task 所在的 slot 256 | func (tw *TimeWheel) task2Slot(taskIdx int64) int { 257 | return int(taskIdx >> 32) 258 | } 259 | 260 | // 将指定间隔计算到指定的 slot 中 261 | func (tw *TimeWheel) convSlotIdx(gap time.Duration) int { 262 | timeGap := gap % time.Duration(cycleCost) 263 | slotGap := int(timeGap / tw.tickGap) 264 | return int(slotGap % tw.slotNum) 265 | } 266 | 267 | func (tw *TimeWheel) String() (s string) { 268 | for _, slot := range tw.slots { 269 | if slot.tasks.size > 0 { 270 | s += fmt.Sprintf("[%v]\t", slot.tasks) 271 | } 272 | } 273 | return 274 | } 275 | --------------------------------------------------------------------------------