├── .gitignore ├── README.md ├── cmd └── taskino │ └── main.go ├── cron.go ├── executor.go ├── go.mod ├── go.sum ├── scheduler.go ├── store.go ├── store_test.go ├── task.go ├── trigger.go └── trigger_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | taskino* 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## taskino 2 | Micro Distributed Task Scheduler using Redis 3 | 4 | ### Install 5 | 6 | ``` 7 | go get -u github.com/pyloque/taskino 8 | ``` 9 | 10 | ### Example 11 | 12 | ```go 13 | // 构造 Redis 连接 14 | c:= redis.NewClient(&redis.Options{ 15 | Addr: "localhost:6379", 16 | DB: 0, 17 | }) 18 | store := taskino.NewRedisTaskStore(taskino.NewRedisStore(c), "sample", 5) 19 | // 日志 20 | logger := log.New(os.Stdout, "taskino", log.LstdFlags) 21 | // 创建调度器 22 | scheduler := taskino.NewDistributedScheduler(store, logger) 23 | 24 | // hello 循环任务 25 | hello := taskino.NewTask("hello", false, func() { 26 | fmt.Println("hello world") 27 | }) 28 | scheduler.Register(taskino.PeriodOfDelay(1, 5), hello) 29 | 30 | // stopper 停止任务(30s后停止调度器) 31 | stopper := taskino.NewTask("stopper", true, func() { 32 | scheduler.Stop() 33 | }) 34 | scheduler.Register(taskino.OnceOfDelay(30), stopper) 35 | 36 | // 设置任务全局版本号 37 | scheduler.SetVersion(1) 38 | // 开启调度 39 | err := scheduler.Start() 40 | if err != nil { 41 | panic(err) 42 | } 43 | // 等待退出 44 | scheduler.WaitForever() 45 | ``` 46 | 47 | ### 解决单点故障 48 | 多进程调度,挂掉一个其它进程可以继续调度 49 | 50 | ### 分布式任务锁 51 | 多进程同时调度,只有一个进程可以夺取任务执行权,这里使用 Redis 分布式锁来控制并发冲突 52 | 如果 `task.Concurrent=true` 那么多进程可以并行运行 53 | 54 | ### 任务重加载 55 | 使用全局版本号来监听任务变更,用来刷新任务调度时间(代码升级) 56 | 当任务有变更时,版本号发生变动,老代码进程会自动从 Redis 中同步新的任务调度时间 57 | 对有变动的任务进行重新调度 58 | 59 | ### 事件回调 60 | 监听任务运行时间,观察任务运行状态 61 | 62 | ```go 63 | type SampleListener struct { 64 | scheduler 
*taskino.DistributedScheduler 65 | } 66 | 67 | func NewSampleListener(scheduler *taskino.DistributedScheduler) *SampleListener { 68 | return &SampleListener{scheduler} 69 | } 70 | 71 | func (l *SampleListener) OnComplete(ctx *taskino.TaskContext) { 72 | fmt.Printf("task %s cost %d millis\n", ctx.Task.Name, ctx.CostInMillis) 73 | } 74 | 75 | func (l *SampleListener) OnStartup() { 76 | fmt.Println("scheduler started") 77 | } 78 | 79 | func (l *SampleListener) OnStop() { 80 | fmt.Println("scheduler stopped") 81 | } 82 | 83 | func (l *SampleListener) OnReload() { 84 | fmt.Println("scheduler reloaded") 85 | } 86 | 87 | ... 88 | scheduler.SetVersion(2) 89 | scheduler.AddListener(NewSampleListener(scheduler)) 90 | scheduler.Start() 91 | scheduler.WaitForever() 92 | ``` 93 | 94 | ### 支持三种任务类型 95 | 96 | 1. 单次任务(OnceTrigger):固定时间运行一次即结束 97 | 2. 循环任务(PeriodTrigger):从起始时间开始间隔循环到结束时间 98 | 3. CRON任务(CronTrigger):CRON表达式控制任务运行时间(最低精度 1 分钟) 99 | 100 | ```go 101 | taskino.OnceOf(startTime time.Time) *OnceTrigger 102 | taskino.PeriodOf(startTime time.Time, endTime time.Time, period int) *PeriodTrigger 103 | taskino.CronOf(expr string) *CronTrigger 104 | ``` 105 | 106 | ### 任务手动运行 107 | 108 | ```go 109 | scheduler.TriggerTask(name string) 110 | ``` 111 | 112 | ### 获取任务上次运行时间 113 | 114 | ```go 115 | scheduler.GetLastRunTime(name string) (*time.Time, error) 116 | scheduler.GetAllLastRunTimes() (map[string]*time.Time, error) 117 | ``` 118 | 119 | ### 注意点 120 | 121 | ``` 122 | 1. 如果在任务调度点发生网络抖动,Redis 读写出错,可能会引发任务的miss,需要监控 123 | 2. 
多机器部署时务必保持时间同步,如果时间差异过大(5s),会导致任务重复执行 124 | ``` 125 | 126 | ### Example 127 | 128 | [入门实例](https://github.com/pyloque/taskino/blob/master/cmd/taskino/main.go) 129 | 130 | 131 | ### Java 版 132 | [jtaskino](https://github.com/pyloque/jtaskino) 133 | -------------------------------------------------------------------------------- /cmd/taskino/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | 8 | "github.com/go-redis/redis" 9 | "github.com/pyloque/taskino" 10 | ) 11 | 12 | type SampleListener struct { 13 | scheduler *taskino.DistributedScheduler 14 | } 15 | 16 | func NewSampleListener(scheduler *taskino.DistributedScheduler) *SampleListener { 17 | return &SampleListener{scheduler} 18 | } 19 | 20 | func (l *SampleListener) OnComplete(ctx *taskino.TaskContext) { 21 | fmt.Printf("task %s cost %d millis\n", ctx.Task.Name, ctx.CostInMillis) 22 | } 23 | 24 | func (l *SampleListener) OnStartup() { 25 | fmt.Println("scheduler started") 26 | } 27 | 28 | func (l *SampleListener) OnStop() { 29 | fmt.Println("scheduler stopped") 30 | } 31 | 32 | func (l *SampleListener) OnReload() { 33 | fmt.Println("scheduler reloaded") 34 | } 35 | 36 | func main() { 37 | c := redis.NewClient(&redis.Options{ 38 | Addr: "localhost:6379", 39 | DB: 0, 40 | }) 41 | store := taskino.NewRedisTaskStore(taskino.NewRedisStore(c), "sample", 5) 42 | logger := log.New(os.Stdout, "", log.LstdFlags) 43 | scheduler := taskino.NewDistributedScheduler(store, logger) 44 | once1 := taskino.TaskOf("once1", func() { 45 | fmt.Println("once1") 46 | }) 47 | scheduler.Register(taskino.OnceOfDelay(5), once1) 48 | period2 := taskino.TaskOf("period2", func() { 49 | fmt.Println("period2") 50 | }) 51 | scheduler.Register(taskino.PeriodOfDelay(5, 5), period2) 52 | cron3 := taskino.TaskOf("cron3", func() { 53 | fmt.Println("cron3") 54 | }) 55 | scheduler.Register(taskino.CronOfMinutes(1), cron3) 56 | period4 := 
taskino.TaskOf("period4", func() { 57 | fmt.Println("period4") 58 | }) 59 | scheduler.Register(taskino.PeriodOfDelay(5, 5), period4) 60 | stopper := taskino.ConcurrentTask("stopper", func() { 61 | scheduler.Stop() 62 | }) 63 | scheduler.Register(taskino.OnceOfDelay(70), stopper) 64 | scheduler.SetVersion(3) 65 | scheduler.AddListener(NewSampleListener(scheduler)) 66 | err := scheduler.Start() 67 | if err != nil { 68 | panic(err) 69 | } 70 | scheduler.WaitForever() 71 | } 72 | -------------------------------------------------------------------------------- /cron.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "github.com/robfig/cron" 5 | "time" 6 | ) 7 | 8 | type CronPattern struct { 9 | s cron.Schedule 10 | } 11 | 12 | func NewCronPattern(expr string) (*CronPattern, bool) { 13 | s, err := cron.ParseStandard(expr) 14 | if err != nil { 15 | return nil, false 16 | } 17 | return &CronPattern{s}, true 18 | } 19 | // Matches reports whether the schedule has an activation in the one-minute window (millis-60s, millis], where millis is a Unix timestamp in milliseconds. 20 | func (p *CronPattern) Matches(millis int64) bool { 21 | var minuteAgoMillis = millis - 60*1000 22 | var minuteAgoTime = time.Unix(minuteAgoMillis/1000, (minuteAgoMillis%1000)*1000000) 23 | var minAgoNextTime = p.s.Next(minuteAgoTime) 24 | var gap = minAgoNextTime.Sub(minuteAgoTime) // Next is after minuteAgoTime, so gap > 0; the old reversed operands made gap negative and matched every minute 25 | return !minAgoNextTime.IsZero() && gap <= time.Minute // a zero Next means the schedule never fires again 26 | } 27 | -------------------------------------------------------------------------------- /executor.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | ) 7 | 8 | type Executor struct { 9 | wg *sync.WaitGroup 10 | logger *log.Logger 11 | stop bool 12 | } 13 | 14 | func NewExecutor(logger *log.Logger) *Executor { 15 | return &Executor{ 16 | wg: &sync.WaitGroup{}, 17 | logger: logger, 18 | } 19 | } 20 | 21 | func (e *Executor) submit(runner func()) { 22 | if e.stop { 23 | return 24 | } 25 | e.wg.Add(1) 26 | go func() { 27 | defer e.wg.Done() 28 | defer func() { 29 | if r :=
recover(); r != nil { 30 | e.logger.Printf("executor run error %s\n", r) 31 | } 32 | }() 33 | runner() 34 | }() 35 | } 36 | 37 | func (e *Executor) shutdown() { 38 | e.stop = true 39 | e.wg.Wait() 40 | } 41 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pyloque/taskino 2 | 3 | require ( 4 | github.com/go-redis/redis v6.15.2+incompatible 5 | github.com/onsi/ginkgo v1.8.0 // indirect 6 | github.com/onsi/gomega v1.5.0 // indirect 7 | github.com/robfig/cron v1.1.0 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= 2 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 3 | github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDAhzyXg+Bs+0Sb4= 4 | github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= 5 | github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= 6 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 7 | github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= 8 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 9 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 10 | github.com/onsi/ginkgo v1.8.0 h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w= 11 | github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 12 | github.com/onsi/gomega v1.5.0 h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo= 13 | github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= 14 | github.com/robfig/cron 
v1.1.0 h1:jk4/Hud3TTdcrJgUOBgsqrZBarcxl6ADIjSC2iniwLY= 15 | github.com/robfig/cron v1.1.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k= 16 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 17 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 18 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 19 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 20 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 21 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 22 | gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= 23 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 24 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 25 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 26 | gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= 27 | gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 28 | -------------------------------------------------------------------------------- /scheduler.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "time" 7 | ) 8 | 9 | type taskGrabber func(task *Task) bool 10 | 11 | type SchedulerListener interface { 12 | OnComplete(*TaskContext) 13 | OnStartup() 14 | OnStop() 15 | OnReload() 16 | } 17 | 18 | type DistributedScheduler struct { 19 | Store TaskStore 20 | Version int64 21 | AllTasks map[string]*Task 22 | Triggers map[string]Trigger 23 | executor *Executor 24 | reloadingTriggers map[string]Trigger 25 | 
listeners []SchedulerListener 26 | stop chan bool 27 | logger *log.Logger 28 | } 29 | 30 | func NewDistributedScheduler(store TaskStore, logger *log.Logger) *DistributedScheduler { 31 | var scheduler = &DistributedScheduler{ 32 | Store: store, 33 | AllTasks: make(map[string]*Task), 34 | Triggers: make(map[string]Trigger), 35 | reloadingTriggers: make(map[string]Trigger), 36 | listeners: make([]SchedulerListener, 0), 37 | stop: make(chan bool, 1), 38 | } 39 | if logger != nil { 40 | scheduler.logger = logger 41 | } else { 42 | scheduler.logger = log.New(ioutil.Discard, "scheduler", 0) 43 | } 44 | scheduler.executor = NewExecutor(scheduler.logger) 45 | return scheduler 46 | } 47 | 48 | func (s *DistributedScheduler) AddListener(listener SchedulerListener) *DistributedScheduler { 49 | s.listeners = append(s.listeners, listener) 50 | return s 51 | } 52 | 53 | func (s *DistributedScheduler) callListener(method string, consumer func(listener SchedulerListener)) { 54 | for _, listener := range s.listeners { 55 | func() { 56 | defer func() { 57 | if e := recover(); e != nil { 58 | s.logger.Printf("invoke task listener %s error %s\n", method, e) 59 | } 60 | }() 61 | consumer(listener) 62 | }() 63 | } 64 | } 65 | 66 | func (s *DistributedScheduler) Register(trigger Trigger, task *Task) *DistributedScheduler { 67 | if s.Triggers[task.Name] != nil { 68 | panic("task name duplicated!") 69 | } 70 | s.Triggers[task.Name] = trigger 71 | s.AllTasks[task.Name] = task 72 | task.callback = func(ctx *TaskContext) { 73 | func() { 74 | defer func() { 75 | if e := recover(); e != nil { 76 | s.logger.Printf("save task %s lastrun time error %s\n", task.Name, e) 77 | } 78 | }() 79 | var now = time.Now() 80 | if err := s.Store.SaveLastRunTime(task.Name, &now); err != nil { 81 | s.logger.Printf("save task %s last run time error %s\n", task.Name, err) 82 | } 83 | }() 84 | s.callListener("OnCompelete", func(listener SchedulerListener) { 85 | listener.OnComplete(ctx) 86 | }) 87 | } 88 | return s 
89 | } 90 | 91 | func (s *DistributedScheduler) TriggerTask(name string) { 92 | var task = s.AllTasks[name] 93 | if task != nil { 94 | task.run() 95 | } 96 | } 97 | 98 | func (s *DistributedScheduler) GetLastRunTime(name string) (*time.Time, error) { 99 | return s.Store.GetLastRunTime(name) 100 | } 101 | 102 | func (s *DistributedScheduler) GetAllLastRunTimes() (map[string]*time.Time, error) { 103 | return s.Store.GetAllLastRunTimes() 104 | } 105 | 106 | func (s *DistributedScheduler) SetVersion(version int64) *DistributedScheduler { 107 | if version < 0 { 108 | panic("illegal version!") 109 | } 110 | s.Version = version 111 | return s 112 | } 113 | 114 | func (s *DistributedScheduler) Start() error { 115 | if err := s.saveTriggers(); err != nil { 116 | return err 117 | } 118 | s.scheduleTasks() 119 | go s.scheduleReload() 120 | s.callListener("OnStartup", func(listener SchedulerListener) { 121 | listener.OnStartup() 122 | }) 123 | return nil 124 | } 125 | 126 | func (s *DistributedScheduler) WaitForever() { 127 | <-s.stop 128 | } 129 | 130 | func (s *DistributedScheduler) saveTriggers() error { 131 | var triggersRaw = map[string]string{} 132 | for name, trigger := range s.Triggers { 133 | triggersRaw[name] = SerializeTrigger(trigger) 134 | } 135 | err := s.Store.SaveAllTriggers(s.Version, triggersRaw) 136 | if err != nil { 137 | s.logger.Printf("save task triggers error %s", err) 138 | } 139 | return err 140 | } 141 | 142 | func (s *DistributedScheduler) scheduleTasks() { 143 | for name, trigger := range s.Triggers { 144 | var task = s.AllTasks[name] 145 | if task == nil { 146 | continue 147 | } 148 | s.logger.Printf("scheduling task %s\n", name) 149 | trigger.schedule(s.executor, s.grabTask, task) 150 | } 151 | } 152 | 153 | func (s *DistributedScheduler) grabTask(task *Task) bool { 154 | if task.Concurrent { 155 | return true 156 | } 157 | r, e := s.Store.GrabTask(task.Name) 158 | if e != nil { 159 | s.logger.Printf("grab task %s error %s\n", task.Name, e) 160 | 
} 161 | return r 162 | } 163 | 164 | func (s *DistributedScheduler) scheduleReload() { 165 | var ticker = time.NewTicker(time.Second) 166 | for { 167 | select { 168 | case <-ticker.C: 169 | if s.reloadIfChanged() { 170 | s.rescheduleTasks() 171 | } 172 | break 173 | case <-s.stop: 174 | ticker.Stop() 175 | return 176 | } 177 | } 178 | } 179 | // reloadIfChanged compares the store's version with the local one and, when they differ, stages the changed triggers via reload; it returns true only once a changed version has been loaded successfully. 180 | func (s *DistributedScheduler) reloadIfChanged() bool { 181 | defer func() { 182 | if e := recover(); e != nil { 183 | s.logger.Printf("reloading task error %s\n", e) 184 | } 185 | }() 186 | remoteVersion, err := s.Store.GetRemoteVersion() 187 | if err != nil { 188 | s.logger.Printf("get remote version error %s\n", err) 189 | return false 190 | } 191 | if remoteVersion != s.Version && s.reload() { 192 | // Advance the local version only after the triggers were loaded, so a failed load is retried on the next tick instead of being silently skipped. 193 | s.Version = remoteVersion 194 | return true 195 | } 196 | return false 197 | } 198 | // reload reads all triggers from the store and stages in s.reloadingTriggers those that differ from the local schedule (a nil entry marks a local task missing from the store); it returns false if the store read fails. 199 | func (s *DistributedScheduler) reload() bool { 200 | raws, err := s.Store.GetAllTriggers() 201 | if err != nil { 202 | s.logger.Printf("load triggers error %s\n", err) 203 | return false 204 | } 205 | var reloadings = map[string]Trigger{} 206 | for name, raw := range raws { 207 | if s.AllTasks[name] != nil { 208 | var trigger = ParseTrigger(raw) 209 | if oldTrigger := s.Triggers[name]; oldTrigger == nil || !oldTrigger.equals(trigger) { 210 | reloadings[name] = trigger 211 | } 212 | } 213 | } 214 | for name := range s.Triggers { 215 | if raws[name] == "" { 216 | reloadings[name] = nil 217 | } 218 | } 219 | s.reloadingTriggers = reloadings 220 | return true 221 | } 222 | 223 | func (s *DistributedScheduler) rescheduleTasks() { 224 | for name, trigger := range s.reloadingTriggers { 225 | var task = s.AllTasks[name] 226 | if trigger == nil { 227 | s.logger.Printf("cancelling task %s\n", name) 228 | s.Triggers[name].cancel() 229 | delete(s.Triggers, name) 230 | } else { 231 | var oldTrigger = s.Triggers[name] 232 | if oldTrigger != nil { 233 | s.logger.Printf("cancelling task %s\n", name) 234 | oldTrigger.cancel() 235 | } 236 | s.Triggers[name] = trigger 237 |
s.logger.Printf("scheduling task %s\n", name) 238 | trigger.schedule(s.executor, s.grabTask, task) 239 | } 240 | } 241 | s.reloadingTriggers = map[string]Trigger{} 242 | s.callListener("OnReload", func(listener SchedulerListener) { 243 | listener.OnReload() 244 | }) 245 | } 246 | 247 | func (s *DistributedScheduler) cancelAllTasks() { 248 | for name, trigger := range s.Triggers { 249 | s.logger.Printf("cancelling task %s\n", name) 250 | trigger.cancel() 251 | } 252 | s.Triggers = map[string]Trigger{} 253 | } 254 | 255 | func (s *DistributedScheduler) Stop() { 256 | close(s.stop) 257 | s.cancelAllTasks() 258 | s.executor.shutdown() 259 | s.callListener("OnStop", func(listener SchedulerListener) { 260 | listener.OnStop() 261 | }) 262 | } 263 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "fmt" 5 | "github.com/go-redis/redis" 6 | "math/rand" 7 | "strings" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | type TaskStore interface { 13 | GetRemoteVersion() (int64, error) 14 | GetAllTriggers() (map[string]string, error) 15 | SaveAllTriggers(version int64, triggers map[string]string) error 16 | GrabTask(name string) (bool, error) 17 | SaveLastRunTime(name string, lastRun *time.Time) error 18 | GetLastRunTime(name string) (*time.Time, error) 19 | GetAllLastRunTimes() (map[string]*time.Time, error) 20 | } 21 | 22 | type MemoryTaskStore struct { 23 | triggers map[string]string 24 | lastRuns map[string]*time.Time 25 | version int64 26 | *sync.Mutex // protect lastRuns 27 | } 28 | 29 | func NewMemoryTaskStore() *MemoryTaskStore { 30 | return &MemoryTaskStore{ 31 | Mutex: &sync.Mutex{}, 32 | triggers: make(map[string]string), 33 | lastRuns: make(map[string]*time.Time), 34 | version: 0, 35 | } 36 | } 37 | 38 | func (s *MemoryTaskStore) GetRemoteVersion() (int64, error) { 39 | return s.version, nil 40 | } 41 | 42 | func (s 
*MemoryTaskStore) GetAllTriggers() (map[string]string, error) { 43 | return s.triggers, nil 44 | } 45 | 46 | func (s *MemoryTaskStore) SaveAllTriggers(version int64, triggers map[string]string) error { 47 | s.triggers = triggers 48 | s.version = version 49 | return nil 50 | } 51 | 52 | func (s *MemoryTaskStore) GrabTask(name string) (bool, error) { 53 | return true, nil 54 | } 55 | 56 | func (s *MemoryTaskStore) SaveLastRunTime(name string, lastRun *time.Time) error { 57 | s.Lock() 58 | defer s.Unlock() 59 | s.lastRuns[name] = lastRun 60 | return nil 61 | } 62 | 63 | func (s *MemoryTaskStore) GetLastRunTime(name string) (*time.Time, error) { 64 | s.Lock() 65 | defer s.Unlock() 66 | return s.lastRuns[name], nil 67 | } 68 | 69 | func (s *MemoryTaskStore) GetAllLastRunTimes() (map[string]*time.Time, error) { 70 | s.Lock() 71 | defer s.Unlock() 72 | r := map[string]*time.Time{} 73 | for k, v := range s.lastRuns { 74 | r[k] = v 75 | } 76 | return r, nil 77 | } 78 | 79 | type RedisStore struct { 80 | clients []*redis.Client 81 | } 82 | 83 | func NewRedisStore(clients ...*redis.Client) *RedisStore { 84 | return &RedisStore{clients: clients} 85 | } 86 | 87 | func (s *RedisStore) execute(consumer func(*redis.Client)) { 88 | var i = rand.Int31n(int32(len(s.clients))) 89 | var client = s.clients[i] 90 | consumer(client) 91 | } 92 | 93 | type RedisTaskStore struct { 94 | redis *RedisStore 95 | group string 96 | lockAge int 97 | } 98 | 99 | func NewRedisTaskStore(redis *RedisStore, group string, lockAge int) *RedisTaskStore { 100 | return &RedisTaskStore{redis, group, lockAge} 101 | } 102 | 103 | func (s *RedisTaskStore) GrabTask(name string) (r bool, e error) { 104 | s.redis.execute(func(redis *redis.Client) { 105 | var key = s.keyFor("task_lock", name) 106 | var cmd = redis.SetNX(key, "true", time.Second*time.Duration(s.lockAge)) 107 | if cmd.Err() != nil { 108 | e = cmd.Err() 109 | return 110 | } 111 | r = cmd.Val() 112 | }) 113 | return 114 | } 115 | 116 | func (s 
*RedisTaskStore) keyFor(args ...interface{}) string { 117 | var params = make([]string, len(args)+1) 118 | params[0] = s.group 119 | for i := 0; i < len(args); i++ { 120 | params[i+1] = fmt.Sprintf("%s", args[i]) 121 | } 122 | return strings.Join(params, "_") 123 | } 124 | 125 | func (s *RedisTaskStore) GetRemoteVersion() (r int64, e error) { 126 | r = 0 127 | s.redis.execute(func(redis *redis.Client) { 128 | var key = s.keyFor("version") 129 | var cmd = redis.IncrBy(key, 0) 130 | if cmd.Err() != nil { 131 | e = cmd.Err() 132 | return 133 | } 134 | r = cmd.Val() 135 | }) 136 | return 137 | } 138 | 139 | func (s *RedisTaskStore) GetAllTriggers() (r map[string]string, e error) { 140 | s.redis.execute(func(redis *redis.Client) { 141 | var key = s.keyFor("triggers") 142 | var cmd = redis.HGetAll(key) 143 | if cmd.Err() != nil { 144 | e = cmd.Err() 145 | return 146 | } 147 | r = cmd.Val() 148 | }) 149 | return 150 | } 151 | 152 | func (s *RedisTaskStore) SaveAllTriggers(version int64, triggers map[string]string) (r error) { 153 | var triggersGeneric = make(map[string]interface{}) 154 | for key, value := range triggers { 155 | triggersGeneric[key] = value 156 | } 157 | s.redis.execute(func(redis *redis.Client) { 158 | var triggersKey = s.keyFor("triggers") 159 | var lastRunKey = s.keyFor("lastruns") 160 | var versionKey = s.keyFor("version") 161 | cmd := redis.HMSet(triggersKey, triggersGeneric) 162 | if cmd.Err() != nil { 163 | r = cmd.Err() 164 | return 165 | } 166 | for _, name := range redis.HKeys(triggersKey).Val() { 167 | if triggersGeneric[name] == nil { 168 | cmd := redis.HDel(triggersKey, name) 169 | if cmd.Err() != nil { 170 | r = cmd.Err() 171 | return 172 | } 173 | cmd = redis.HDel(lastRunKey, name) 174 | if cmd.Err() != nil { 175 | r = cmd.Err() 176 | return 177 | } 178 | } 179 | } 180 | cmd = redis.Set(versionKey, version, 0) 181 | r = cmd.Err() 182 | }) 183 | return 184 | } 185 | 186 | func (s *RedisTaskStore) SaveLastRunTime(name string, lastRun 
*time.Time) (r error) { 187 | s.redis.execute(func(redis *redis.Client) { 188 | var key = s.keyFor("lastruns") 189 | var raw = lastRun.Format(LayoutISO) 190 | cmd := redis.HSet(key, name, raw) 191 | r = cmd.Err() 192 | }) 193 | return 194 | } 195 | // GetLastRunTime returns the last run time saved for the named task, or (nil, nil) when none has been saved yet. 196 | func (s *RedisTaskStore) GetLastRunTime(name string) (r *time.Time, e error) { 197 | s.redis.execute(func(client *redis.Client) { // named client, not redis, so the redis package (and redis.Nil) stays in scope 198 | var key = s.keyFor("lastruns") 199 | var cmd = client.HGet(key, name) 200 | if err := cmd.Err(); err != nil && err != redis.Nil { // redis.Nil means the field is absent, i.e. the task never ran; MemoryTaskStore reports that as (nil, nil) too 201 | e = err 202 | return 203 | } 204 | var raw = cmd.Val() 205 | if raw != "" { 206 | t, err := time.Parse(LayoutISO, raw) 207 | if err != nil { 208 | e = err 209 | return 210 | } 211 | r = &t 212 | } 213 | }) 214 | return 215 | } 216 | 217 | func (s *RedisTaskStore) GetAllLastRunTimes() (r map[string]*time.Time, e error) { 218 | s.redis.execute(func(redis *redis.Client) { 219 | var key = s.keyFor("lastruns") 220 | var cmd = redis.HGetAll(key) 221 | if cmd.Err() != nil { 222 | e = cmd.Err() 223 | return 224 | } 225 | r = map[string]*time.Time{} 226 | for name, raw := range cmd.Val() { 227 | t, err := time.Parse(LayoutISO, raw) 228 | if err != nil { 229 | e = err 230 | r = nil 231 | return 232 | } 233 | r[name] = &t 234 | } 235 | }) 236 | return 237 | } 238 | -------------------------------------------------------------------------------- /store_test.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/go-redis/redis" 7 | ) 8 | 9 | func TestMemoryTaskStore(t *testing.T) { 10 | var store = NewMemoryTaskStore() 11 | var triggers = map[string]Trigger{} 12 | triggers["once1"] = OnceOfDelay(5) 13 | triggers["period2"] = PeriodOfDelay(5, 10) 14 | triggers["cron3"] = CronOfDays(2, 12, 30) 15 | var triggersRaw = map[string]string{} 16 | for name, trigger := range triggers { 17 | triggersRaw[name] = SerializeTrigger(trigger) 18 | } 19 | err := store.SaveAllTriggers(1024, triggersRaw) 20 |
if err != nil { 21 | t.Errorf("memory save all trigger error %s", err) 22 | } 23 | version, err := store.GetRemoteVersion() 24 | if err != nil { 25 | t.Errorf("memory get remote version error %s", err) 26 | } 27 | if version != 1024 { 28 | t.Errorf("memory get remote version mismatch") 29 | } 30 | for _, name := range []string{"once1", "once1", "once1"} { 31 | r, err := store.GrabTask(name) 32 | if err != nil { 33 | t.Errorf("memory grab task error %s", err) 34 | } 35 | if !r { 36 | t.Errorf("memory grab task failed") 37 | } 38 | } 39 | } 40 | 41 | func TestRedisTaskStore(t *testing.T) { 42 | var client = redis.NewClient(&redis.Options{ 43 | Addr: "localhost:6379", 44 | DB: 0, 45 | }) 46 | var store = NewRedisTaskStore(NewRedisStore(client), "test", 5) 47 | var triggers = map[string]Trigger{} 48 | triggers["once1"] = OnceOfDelay(5) 49 | triggers["period2"] = PeriodOfDelay(5, 10) 50 | triggers["cron3"] = CronOfDays(2, 12, 30) 51 | var triggersRaw = map[string]string{} 52 | for name, trigger := range triggers { 53 | triggersRaw[name] = SerializeTrigger(trigger) 54 | } 55 | err := store.SaveAllTriggers(1024, triggersRaw) 56 | if err != nil { 57 | t.Errorf("redis save all trigger error %s", err) 58 | } 59 | version, err := store.GetRemoteVersion() 60 | if err != nil { 61 | t.Errorf("redis get remote version error %s", err) 62 | } 63 | if version != 1024 { 64 | t.Errorf("redis get remote version mismatch") 65 | } 66 | r, err := store.GrabTask("once1") 67 | if err != nil { 68 | t.Errorf("redis grab task error %s", err) 69 | } 70 | if !r { 71 | t.Errorf("redis grab task failed") 72 | } 73 | r, err = store.GrabTask("once1") 74 | if err != nil { 75 | t.Errorf("redis grab task error %s", err) 76 | } 77 | if r { 78 | t.Errorf("redis grab task should not be ok here") 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /task.go: -------------------------------------------------------------------------------- 1 | package taskino 2 
| 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type Runner func() 9 | 10 | type TaskCallback func(*TaskContext) 11 | 12 | type Task struct { 13 | Name string 14 | Concurrent bool 15 | runner Runner 16 | callback TaskCallback 17 | } 18 | 19 | func NewTask(name string, concurrent bool, runner Runner) *Task { 20 | return &Task{ 21 | Name: name, 22 | Concurrent: concurrent, 23 | runner: runner, 24 | } 25 | } 26 | 27 | func TaskOf(name string, runner Runner) *Task { 28 | return NewTask(name, false, runner) 29 | } 30 | 31 | func ConcurrentTask(name string, runner Runner) *Task { 32 | return NewTask(name, true, runner) 33 | } 34 | 35 | // run invokes the runner, records its duration and any panic in a TaskContext, and hands that context to the callback when one is set. 36 | func (t *Task) run() { 37 | startTime := time.Now() 38 | err := func() (e error) { 39 | defer func() { 40 | if r := recover(); r != nil { 41 | // Panic values need not be errors (e.g. panic("boom")); the old unchecked r.(error) 42 | // re-panicked inside this defer, so the callback and listeners never saw the failure. 43 | if re, ok := r.(error); ok { 44 | e = re 45 | } else { 46 | e = fmt.Errorf("task %s panic: %v", t.Name, r) 47 | } 48 | } 49 | }() 50 | t.runner() 51 | return e 52 | }() 53 | endTime := time.Now() 54 | cost := endTime.Sub(startTime).Nanoseconds() / 1000000 55 | ctx := NewTaskContext(t, cost, err == nil, err) 56 | if t.callback != nil { 57 | t.callback(ctx) 58 | } 59 | } 60 | 61 | type TaskContext struct { 62 | Task *Task 63 | CostInMillis int64 64 | Ok bool 65 | Err error 66 | } 67 | 68 | func NewTaskContext(task *Task, costInMillis int64, ok bool, e error) *TaskContext { 69 | return &TaskContext{ 70 | Task: task, 71 | CostInMillis: costInMillis, 72 | Ok: ok, 73 | Err: e, 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /trigger.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | type TriggerType string 11 | 12 | const ( 13 | ONCE TriggerType = "ONCE" 14 | PERIOD TriggerType = "PERIOD" 15 | CRON TriggerType = "CRON" 16 | ) 17 | 18 | const LayoutISO = "2006-01-02T15:04:05-0700" 19 | 20 | type Trigger interface { 21 | kind() TriggerType 22 | parse(s string) 23 | serialize() string 24 | equals(other Trigger) bool 25 | cancel() 26 |
schedule(executor *Executor, grabber taskGrabber, task *Task) bool 27 | } 28 | 29 | type OnceTrigger struct { 30 | startTime time.Time 31 | timer *time.Timer 32 | stop chan bool 33 | } 34 | 35 | func NewOnceTrigger() *OnceTrigger { 36 | return &OnceTrigger{} 37 | } 38 | 39 | func (t *OnceTrigger) kind() TriggerType { 40 | return ONCE 41 | } 42 | 43 | func (t *OnceTrigger) serialize() string { 44 | return t.startTime.Format(LayoutISO) 45 | } 46 | 47 | func (t *OnceTrigger) parse(s string) { 48 | startTime, err := time.Parse(LayoutISO, s) 49 | // should not happen 50 | if err != nil { 51 | panic(err) 52 | } 53 | t.startTime = startTime 54 | } 55 | 56 | func (t *OnceTrigger) equals(other Trigger) bool { 57 | oo, ok := other.(*OnceTrigger) 58 | if !ok { 59 | return false 60 | } 61 | return t.startTime.Unix() == oo.startTime.Unix() 62 | } 63 | 64 | func (t *OnceTrigger) cancel() { 65 | if t.stop != nil { 66 | close(t.stop) 67 | } 68 | } 69 | 70 | func (t *OnceTrigger) schedule(executor *Executor, grabber taskGrabber, task *Task) bool { 71 | var gap = t.startTime.Sub(time.Now()) 72 | if gap < 0 { 73 | return false 74 | } 75 | t.stop = make(chan bool) 76 | t.timer = time.NewTimer(gap) 77 | go func() { 78 | select { 79 | case <-t.timer.C: 80 | t.timer.Stop() 81 | executor.submit(func() { 82 | if grabber(task) { 83 | task.run() 84 | } 85 | }) 86 | break 87 | case <-t.stop: 88 | t.timer.Stop() 89 | break 90 | } 91 | }() 92 | return true 93 | } 94 | 95 | func OnceOf(startTime time.Time) *OnceTrigger { 96 | var trigger = NewOnceTrigger() 97 | trigger.startTime = startTime 98 | return trigger 99 | } 100 | 101 | func OnceOfDelay(seconds int) *OnceTrigger { 102 | var startTime = time.Now().Add(time.Duration(seconds) * time.Second) 103 | return OnceOf(startTime) 104 | } 105 | 106 | type PeriodTrigger struct { 107 | startTime time.Time 108 | endTime time.Time 109 | period int 110 | delayTimer *time.Timer 111 | periodTicker *time.Ticker 112 | stop chan bool 113 | } 114 | 115 | func 
NewPeriodTrigger() *PeriodTrigger { 116 | return &PeriodTrigger{} 117 | } 118 | 119 | func (t *PeriodTrigger) kind() TriggerType { 120 | return PERIOD 121 | } 122 | 123 | func (t *PeriodTrigger) serialize() string { 124 | var starts = t.startTime.Format(LayoutISO) 125 | var ends = t.endTime.Format(LayoutISO) 126 | return fmt.Sprintf("%s|%s|%d", starts, ends, t.period) 127 | } 128 | 129 | func (t *PeriodTrigger) parse(s string) { 130 | var parts = strings.Split(s, "|") 131 | startTime, err := time.Parse(LayoutISO, parts[0]) 132 | // should not happen 133 | if err != nil { 134 | panic(err) 135 | } 136 | endTime, err := time.Parse(LayoutISO, parts[1]) 137 | // should not happen 138 | if err != nil { 139 | panic(err) 140 | } 141 | t.startTime = startTime 142 | t.endTime = endTime 143 | t.period, err = strconv.Atoi(parts[2]) 144 | // should not happend 145 | if err != nil { 146 | panic(err) 147 | } 148 | } 149 | 150 | func (t *PeriodTrigger) equals(other Trigger) bool { 151 | oo, ok := other.(*PeriodTrigger) 152 | if !ok { 153 | return false 154 | } 155 | if t.startTime.Unix() != oo.startTime.Unix() { 156 | return false 157 | } 158 | if t.endTime.Unix() != oo.endTime.Unix() { 159 | return false 160 | } 161 | return t.period == oo.period 162 | } 163 | 164 | func (t *PeriodTrigger) cancel() { 165 | if t.stop != nil { 166 | close(t.stop) 167 | } 168 | } 169 | 170 | func (t *PeriodTrigger) schedule(executor *Executor, grabber taskGrabber, task *Task) bool { 171 | var now = time.Now() 172 | if t.endTime.Before(now) { 173 | return false 174 | } 175 | var delay time.Duration 176 | if t.startTime.After(now) { 177 | delay = t.startTime.Sub(now) 178 | } else { 179 | elapsed := now.Sub(t.startTime).Nanoseconds() % int64(t.period*1000000000) 180 | if elapsed > 0 { 181 | delay = time.Duration(t.period)*time.Second - time.Duration(elapsed)*time.Nanosecond 182 | } 183 | } 184 | t.stop = make(chan bool) 185 | if delay > 0 { 186 | t.delayTimer = time.NewTimer(delay) 187 | go 
t.delayTask(executor, grabber, task) 188 | } else { 189 | go t.tickTask(executor, grabber, task) 190 | } 191 | return true 192 | } 193 | 194 | func (t *PeriodTrigger) delayTask(executor *Executor, grabber taskGrabber, task *Task) { 195 | stop := false 196 | select { 197 | case <-t.delayTimer.C: 198 | break 199 | case <-t.stop: 200 | stop = true 201 | break 202 | } 203 | t.delayTimer.Stop() 204 | if stop { 205 | return 206 | } 207 | t.tickTask(executor, grabber, task) 208 | } 209 | 210 | func (t *PeriodTrigger) tickTask(executor *Executor, grabber taskGrabber, task *Task) { 211 | t.periodTicker = time.NewTicker(time.Duration(t.period) * time.Second) 212 | for { 213 | if time.Now().Sub(t.endTime) >= 0 { 214 | return 215 | } 216 | executor.submit(func() { 217 | if grabber(task) { 218 | task.run() 219 | } 220 | }) 221 | select { 222 | case <-t.periodTicker.C: 223 | break 224 | case <-t.stop: 225 | t.periodTicker.Stop() 226 | return 227 | } 228 | } 229 | } 230 | 231 | func PeriodOf(startTime time.Time, endTime time.Time, period int) *PeriodTrigger { 232 | var trigger = NewPeriodTrigger() 233 | trigger.startTime = startTime 234 | trigger.endTime = endTime 235 | trigger.period = period 236 | return trigger 237 | } 238 | 239 | func PeriodOfStart(startTime time.Time, period int) *PeriodTrigger { 240 | maxTime := time.Date(2048, time.April, 0, 0, 0, 0, 0, time.UTC) 241 | return PeriodOf(startTime, maxTime, period) 242 | } 243 | 244 | func PeriodOfDelay(delay int, period int) *PeriodTrigger { 245 | var startTime = time.Now().Add(time.Duration(delay) * time.Second) 246 | return PeriodOfStart(startTime, period) 247 | } 248 | 249 | type CronTrigger struct { 250 | expr string 251 | delayTimer *time.Timer 252 | periodTicker *time.Ticker 253 | stop chan bool 254 | } 255 | 256 | func NewCronTrigger() *CronTrigger { 257 | return &CronTrigger{} 258 | } 259 | 260 | func (t *CronTrigger) kind() TriggerType { 261 | return CRON 262 | } 263 | 264 | func (t *CronTrigger) serialize() string 
{ 265 | return t.expr 266 | } 267 | 268 | func (t *CronTrigger) parse(s string) { 269 | t.expr = s 270 | } 271 | 272 | func (t *CronTrigger) equals(other Trigger) bool { 273 | oo, ok := other.(*CronTrigger) 274 | if !ok { 275 | return false 276 | } 277 | return t.expr == oo.expr 278 | } 279 | 280 | func (t *CronTrigger) cancel() { 281 | if t.stop != nil { 282 | close(t.stop) 283 | } 284 | } 285 | 286 | func (t *CronTrigger) schedule(executor *Executor, grabber taskGrabber, task *Task) bool { 287 | now := time.Now() 288 | snow := now.Add(-time.Duration(now.Nanosecond())) 289 | if snow.Second() != 0 { 290 | snow.Add(-time.Duration(snow.Second()) * time.Second) 291 | snow.Add(time.Minute) 292 | } 293 | delay := snow.Sub(now).Nanoseconds() 294 | if delay < 0 { 295 | delay = 0 296 | } 297 | if delay > 0 { 298 | t.delayTimer = time.NewTimer(time.Duration(delay) * time.Nanosecond) 299 | go t.delayTask(executor, grabber, task) 300 | } else { 301 | go t.tickTask(executor, grabber, task) 302 | } 303 | return true 304 | } 305 | 306 | func (t *CronTrigger) delayTask(executor *Executor, grabber taskGrabber, task *Task) { 307 | select { 308 | case <-t.delayTimer.C: 309 | t.delayTimer.Stop() 310 | t.tickTask(executor, grabber, task) 311 | break 312 | case <-t.stop: 313 | t.delayTimer.Stop() 314 | break 315 | } 316 | } 317 | 318 | func (t *CronTrigger) tickTask(executor *Executor, grabber taskGrabber, task *Task) { 319 | t.periodTicker = time.NewTicker(time.Minute) 320 | p, _ := NewCronPattern(t.expr) 321 | for { 322 | select { 323 | case <-t.periodTicker.C: 324 | if p.Matches(time.Now().UnixNano() / 1000000) { 325 | executor.submit(func() { 326 | if grabber(task) { 327 | task.run() 328 | } 329 | }) 330 | } 331 | break 332 | case <-t.stop: 333 | t.periodTicker.Stop() 334 | return 335 | } 336 | } 337 | } 338 | 339 | func CronOf(expr string) *CronTrigger { 340 | var trigger = NewCronTrigger() 341 | trigger.expr = expr 342 | return trigger 343 | } 344 | 345 | func 
CronOfMinutes(minutes int) *CronTrigger { 346 | return CronOf(fmt.Sprintf("*/%d * * * *", minutes)) 347 | } 348 | 349 | func CronOfHours(hours int, minuteOffset int) *CronTrigger { 350 | return CronOf(fmt.Sprintf("%d */%d * * *", minuteOffset, hours)) 351 | } 352 | 353 | func CronOfDays(days int, hourOffset int, minuteOffset int) *CronTrigger { 354 | return CronOf(fmt.Sprintf("%d %d */%d * *", minuteOffset, hourOffset, days)) 355 | } 356 | 357 | func ParseTrigger(s string) Trigger { 358 | var parts = strings.SplitN(s, "@", 2) 359 | var typ = TriggerType(parts[0]) 360 | var trigger Trigger = nil 361 | switch typ { 362 | case ONCE: 363 | trigger = NewOnceTrigger() 364 | break 365 | case PERIOD: 366 | trigger = NewPeriodTrigger() 367 | break 368 | case CRON: 369 | trigger = NewCronTrigger() 370 | break 371 | default: 372 | panic("illegal trigger string") 373 | } 374 | trigger.parse(parts[1]) 375 | return trigger 376 | } 377 | 378 | func SerializeTrigger(trigger Trigger) string { 379 | return fmt.Sprintf("%s@%s", trigger.kind(), trigger.serialize()) 380 | } 381 | -------------------------------------------------------------------------------- /trigger_test.go: -------------------------------------------------------------------------------- 1 | package taskino 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestTriggerSerialize(t *testing.T) { 8 | var s1 = OnceOfDelay(5) 9 | var t1 = ParseTrigger(SerializeTrigger(s1)) 10 | if !s1.equals(t1) { 11 | t.Error("once trigger serialize error") 12 | return 13 | } 14 | var s2 = PeriodOfDelay(5, 10) 15 | var t2 = ParseTrigger(SerializeTrigger(s2)) 16 | if !s2.equals(t2) { 17 | t.Error("period trigger serialize error") 18 | } 19 | var s3 = CronOfDays(2, 12, 30) 20 | var t3 = ParseTrigger(SerializeTrigger(s3)) 21 | if !s3.equals(t3) { 22 | t.Error("cron trigger serialize error") 23 | } 24 | } 25 | --------------------------------------------------------------------------------