├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── checkpointer.go ├── checkpointer_integration_test.go ├── checkpointer_test.go ├── consumer.go ├── consumer_test.go ├── docker-compose.yml ├── example_checkpoint_test.go ├── example_test.go ├── go.mod ├── go.sum ├── helper_test.go ├── integration_test.go ├── monitoring.go └── monitoring_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - "1.11" 4 | - "1.12" 5 | - "1.13" 6 | services: 7 | - docker 8 | env: 9 | - KINESIS_ENDPOINT="http://localhost:4567" DYNAMODB_ENDPOINT="http://localhost:8000" AWS_DEFAULT_REGION="ap-nil-1" AWS_ACCESS_KEY="AKAILKAJDFLKADJFL" AWS_SECRET_KEY="90uda098fjdsoifjsdaoifjpisjf" GO111MODULE=on 10 | before_install: 11 | - docker pull deangiberson/aws-dynamodb-local 12 | - docker pull dlsniper/kinesalite 13 | install: make get 14 | script: make docker-integration 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 Patrick robinson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TIMEOUT = 30 2 | 3 | get: 4 | @go get ./... 5 | 6 | check test tests: 7 | @go test -short -timeout $(TIMEOUT)s ./... 8 | 9 | integration: get 10 | @go test -timeout 30s -tags=integration 11 | @go test -run TestRebalance -count 25 -tags=integration 12 | 13 | docker-integration: 14 | @docker-compose run gokini make integration 15 | @docker-compose down 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gokini 2 | 3 | [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://godoc.org/github.com/patrobinson/gokini) 4 | [![Build 5 | Status](https://travis-ci.org/golang/gddo.svg?branch=master)](https://travis-ci.org/patrobinson/gokini) 6 | 7 | A Golang Kinesis Consumer Library with minimal dependencies. This library does not depend on the Java MultiLangDaemon but does use the AWS SDK. 
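A minimal usage sketch, adapted from `example_test.go` further down in this repository (the stream and checkpoint-table names below are placeholders):

```go
package main

import (
	"fmt"
	"time"

	"github.com/patrobinson/gokini"
)

// printConsumer implements gokini.RecordConsumer and prints every record it receives.
type printConsumer struct {
	shardID string
}

func (p *printConsumer) Init(shardID string) error {
	p.shardID = shardID
	return nil
}

func (p *printConsumer) ProcessRecords(records []*gokini.Records, consumer *gokini.KinesisConsumer) {
	for _, r := range records {
		fmt.Printf("shard %s: %s\n", r.ShardID, r.Data)
	}
}

func (p *printConsumer) Shutdown() {}

func main() {
	kc := &gokini.KinesisConsumer{
		StreamName:           "MY_KINESIS_STREAM", // placeholder stream name
		ShardIteratorType:    "TRIM_HORIZON",      // start from the oldest available record
		RecordConsumer:       &printConsumer{},
		TableName:            "my_checkpoints", // placeholder DynamoDB table for checkpoints and leases
		EmptyRecordBackoffMs: 1000,
	}

	if err := kc.StartConsumer(); err != nil {
		fmt.Printf("failed to start consumer: %s\n", err)
		return
	}

	// Let the consumer run, then shut down gracefully.
	time.Sleep(time.Minute)
	kc.Shutdown()
}
```

`StartConsumer` creates the DynamoDB checkpoint table if it does not already exist and begins polling shards in the background; `Shutdown` stops consumption gracefully.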
8 | 9 | ## Project Goals 10 | 11 | This project aims to provide feature parity with the [Kinesis Client Library](https://github.com/awslabs/amazon-kinesis-client) including: 12 | 13 | - [x] Enumerates shards 14 | 15 | - [x] Coordinates shard associations with other workers 16 | 17 | - [x] Instantiates a record processor for every shard it manages 18 | 19 | - [x] Checkpoints processed records 20 | 21 | - [x] Balances shard-worker associations when the worker instance count changes 22 | 23 | - [x] Balances shard-worker associations when shards are split or merged 24 | 25 | - [x] Instrumentation that supports CloudWatch (partial support) 26 | 27 | - [ ] Support enhanced fan-out consumers 28 | 29 | - [ ] Support aggregated records from Kinesis Producer library 30 | 31 | ## Development Status 32 | 33 | Beta - Ready to be used in non-critical Production environments. 34 | 35 | Actively used (via a fork) by [VMWare](https://github.com/vmware/vmware-go-kcl) 36 | 37 | ## Testing 38 | 39 | Unit tests can be run with: 40 | ``` 41 | go test consumer_test.go consumer.go checkpointer_test.go checkpointer.go monitoring.go monitoring_test.go 42 | ``` 43 | 44 | Integration tests can be run in docker with: 45 | ``` 46 | make docker-integration 47 | ``` 48 | -------------------------------------------------------------------------------- /checkpointer.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "time" 7 | 8 | "github.com/aws/aws-sdk-go/aws" 9 | "github.com/aws/aws-sdk-go/aws/awserr" 10 | "github.com/aws/aws-sdk-go/aws/session" 11 | "github.com/aws/aws-sdk-go/service/dynamodb" 12 | "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" 13 | "github.com/aws/aws-sdk-go/service/dynamodb/expression" 14 | log "github.com/sirupsen/logrus" 15 | ) 16 | 17 | // TODO: We have to deal with possibly empty attributes in a lot of places in here 18 | // This is error prone and likely to create race conditions 19 | // Create a method that will update all attributes, in a conditional update, based on all attributes in shard 20 | 21 | const ( 22 | defaultLeaseDuration = 30000 23 | // ErrLeaseNotAquired is returned when we failed to get a lock on the shard 24 | ErrLeaseNotAquired = "Lease is already held by another node" 25 | // ErrInvalidDynamoDBSchema is returned when there are one or more fields missing from the table 26 | ErrInvalidDynamoDBSchema = "The DynamoDB schema is invalid and may need to be re-created" 27 | ) 28 | 29 | // Checkpointer handles checkpointing when a record has been processed 30 | type Checkpointer interface { 31 | Init() error 32 | GetLease(*shardStatus, string) error 33 | CheckpointSequence(*shardStatus) error 34 | FetchCheckpoint(*shardStatus) error 35 | ListActiveWorkers() (map[string][]string, error) 36 | ClaimShard(*shardStatus, string) error 37 | } 38 | 39 | // ErrSequenceIDNotFound is returned by FetchCheckpoint when no SequenceID is found 40 | var ErrSequenceIDNotFound = errors.New("SequenceIDNotFoundForShard") 41 | 42 | // DynamoCheckpoint implements the Checkpoint interface using DynamoDB as a backend 43 | type DynamoCheckpoint struct { 44 | TableName string 45 | LeaseDuration int 46 | Retries int 47 | ReadCapacityUnits *int64 48 | WriteCapacityUnits *int64 49 | BillingMode *string 50 | Session *session.Session 51 | svc dynamodbiface.DynamoDBAPI 52 | skipTableCheck bool 53 | } 54 | 55 | // Init initialises the DynamoDB Checkpoint 56 | func (c *DynamoCheckpoint) Init() error { 57 | if 
endpoint := os.Getenv("DYNAMODB_ENDPOINT"); endpoint != "" { 58 | log.Infof("Using dynamodb endpoint from environment %s", endpoint) 59 | c.Session.Config.Endpoint = &endpoint 60 | } 61 | 62 | if c.svc == nil { 63 | c.svc = dynamodb.New(c.Session) 64 | } 65 | 66 | if c.LeaseDuration == 0 { 67 | c.LeaseDuration = defaultLeaseDuration 68 | } 69 | 70 | if c.BillingMode == nil { 71 | c.BillingMode = aws.String("PAY_PER_REQUEST") 72 | } 73 | 74 | if !c.skipTableCheck && !c.doesTableExist() { 75 | return c.createTable() 76 | } 77 | return nil 78 | } 79 | 80 | // GetLease attempts to gain a lock on the given shard 81 | func (c *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo string) error { 82 | newLeaseTimeout := time.Now().Add(time.Duration(c.LeaseDuration) * time.Millisecond).UTC() 83 | newLeaseTimeoutString := newLeaseTimeout.Format(time.RFC3339Nano) 84 | currentCheckpoint, err := c.getItem(shard.ID) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | assignedVar, assignedToOk := currentCheckpoint["AssignedTo"] 90 | leaseVar, leaseTimeoutOk := currentCheckpoint["LeaseTimeout"] 91 | 92 | var cond expression.ConditionBuilder 93 | 94 | if !leaseTimeoutOk || !assignedToOk { 95 | cond = expression.Name("AssignedTo").AttributeNotExists() 96 | } else { 97 | assignedTo := *assignedVar.S 98 | leaseTimeout := *leaseVar.S 99 | 100 | currentLeaseTimeout, err := time.Parse(time.RFC3339Nano, leaseTimeout) 101 | if err != nil { 102 | return err 103 | } 104 | if !time.Now().UTC().After(currentLeaseTimeout) && assignedTo != newAssignTo { 105 | return errors.New(ErrLeaseNotAquired) 106 | } 107 | cond = expression.Name("AssignedTo").Equal(expression.Value(assignedTo)) 108 | } 109 | if shard.Checkpoint != "" { 110 | cond = cond.And( 111 | expression.Name("SequenceID").Equal(expression.Value(shard.Checkpoint)), 112 | ) 113 | } 114 | 115 | update := expression.Set( 116 | expression.Name("AssignedTo"), 117 | expression.Value(newAssignTo), 118 | ).Set( 119 | expression.Name("LeaseTimeout"), 120 | expression.Value(newLeaseTimeoutString), 121 | ) 122 | if shard.ParentShardID != nil { 123 | update.Set( 124 | expression.Name("ParentShardID"), 125 | expression.Value(shard.ParentShardID), 126 | ) 127 | } 128 | 129 | expr, err := expression.NewBuilder(). 130 | WithUpdate(update). 131 | WithCondition(cond). 
132 | Build() 133 | if err != nil { 134 | return err 135 | } 136 | 137 | i := &dynamodb.UpdateItemInput{ 138 | ExpressionAttributeNames: expr.Names(), 139 | ExpressionAttributeValues: expr.Values(), 140 | ConditionExpression: expr.Condition(), 141 | TableName: &c.TableName, 142 | Key: map[string]*dynamodb.AttributeValue{ 143 | "ShardID": { 144 | S: &shard.ID, 145 | }, 146 | }, 147 | UpdateExpression: expr.Update(), 148 | } 149 | log.Traceln(i) 150 | _, err = c.svc.UpdateItem(i) 151 | if err != nil { 152 | if awsErr, ok := err.(awserr.Error); ok { 153 | if awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException { 154 | log.Traceln("Condition failed", err) 155 | return errors.New(ErrLeaseNotAquired) 156 | } 157 | } 158 | return err 159 | } 160 | 161 | shard.Lock() 162 | shard.AssignedTo = newAssignTo 163 | shard.LeaseTimeout = newLeaseTimeout 164 | shard.ClaimRequest = nil 165 | shard.Unlock() 166 | 167 | return nil 168 | } 169 | 170 | // CheckpointSequence writes a checkpoint at the designated sequence ID 171 | func (c *DynamoCheckpoint) CheckpointSequence(shard *shardStatus) error { 172 | update := expression.Set( 173 | expression.Name("SequenceID"), 174 | expression.Value(shard.Checkpoint), 175 | ).Set( 176 | expression.Name("Closed"), 177 | expression.Value(shard.Closed), 178 | ) 179 | cond := expression.Name("ClaimRequest").AttributeNotExists().And( 180 | expression.Name("AssignedTo").Equal(expression.Value(shard.AssignedTo)), 181 | ) 182 | expr, err := expression.NewBuilder(). 183 | WithUpdate(update). 184 | WithCondition(cond). 185 | Build() 186 | if err != nil { 187 | return err 188 | } 189 | _, err = c.svc.UpdateItem(&dynamodb.UpdateItemInput{ 190 | ExpressionAttributeNames: expr.Names(), 191 | ExpressionAttributeValues: expr.Values(), 192 | ConditionExpression: expr.Condition(), 193 | TableName: aws.String(c.TableName), 194 | Key: map[string]*dynamodb.AttributeValue{ 195 | "ShardID": { 196 | S: &shard.ID, 197 | }, 198 | }, 199 | UpdateExpression: expr.Update(), 200 | }) 201 | return err 202 | } 203 | 204 | // FetchCheckpoint retrieves the checkpoint for the given shard 205 | func (c *DynamoCheckpoint) FetchCheckpoint(shard *shardStatus) error { 206 | checkpoint, err := c.getItem(shard.ID) 207 | if err != nil { 208 | return err 209 | } 210 | 211 | var sequenceID string 212 | if s, ok := checkpoint["SequenceID"]; ok { 213 | // Why do we thrown an error here??? 
214 | //return ErrSequenceIDNotFound 215 | sequenceID = *s.S 216 | } 217 | log.Debugf("Retrieved Shard Iterator %s", sequenceID) 218 | shard.Lock() 219 | defer shard.Unlock() 220 | shard.Checkpoint = sequenceID 221 | 222 | if assignedTo, ok := checkpoint["AssignedTo"]; ok { 223 | shard.AssignedTo = *assignedTo.S 224 | } 225 | 226 | if parent, ok := checkpoint["ParentShardID"]; ok && parent.S != nil { 227 | shard.ParentShardID = aws.String(*parent.S) 228 | } 229 | 230 | if claim, ok := checkpoint["ClaimRequest"]; ok && claim.S != nil { 231 | shard.ClaimRequest = aws.String(*claim.S) 232 | } 233 | 234 | if lease, ok := checkpoint["LeaseTimeout"]; ok && lease.S != nil { 235 | currentLeaseTimeout, err := time.Parse(time.RFC3339Nano, *lease.S) 236 | if err != nil { 237 | return err 238 | } 239 | shard.LeaseTimeout = currentLeaseTimeout 240 | } 241 | log.Debugln("Shard updated", *shard) 242 | return nil 243 | } 244 | 245 | type Worker struct { 246 | UUID string 247 | Shards []string 248 | } 249 | 250 | func (c *DynamoCheckpoint) ListActiveWorkers() (map[string][]string, error) { 251 | items, err := c.svc.Scan(&dynamodb.ScanInput{ 252 | TableName: aws.String(c.TableName), 253 | }) 254 | if err != nil { 255 | return nil, err 256 | } 257 | workers := make(map[string][]string) 258 | for _, i := range items.Items { 259 | // Ignore closed shards, only return active shards 260 | if closed, ok := i["Closed"]; ok && closed.S != nil { 261 | if *closed.BOOL { 262 | continue 263 | } 264 | } 265 | var workUUID, shardID string 266 | if u, ok := i["AssignedTo"]; !ok || u.S == nil { 267 | return nil, errors.New("invalid value found in DynamoDB table") 268 | } 269 | workUUID = *i["AssignedTo"].S 270 | if s, ok := i["ShardID"]; !ok || s.S == nil { 271 | return nil, errors.New("invalid value found in DynamoDB table") 272 | } 273 | shardID = *i["ShardID"].S 274 | if w, ok := workers[workUUID]; ok { 275 | workers[workUUID] = append(w, shardID) 276 | } else { 277 | workers[workUUID] = []string{shardID} 278 | } 279 | } 280 | return workers, nil 281 | } 282 | 283 | func (c *DynamoCheckpoint) ClaimShard(shard *shardStatus, claimID string) error { 284 | err := c.FetchCheckpoint(shard) 285 | if err != nil { 286 | return err 287 | } 288 | update := expression.Set( 289 | expression.Name("ClaimRequest"), 290 | expression.Value(claimID), 291 | ) 292 | cond := expression.Name("ClaimRequest").AttributeNotExists().And( 293 | expression.Name("AssignedTo").Equal(expression.Value(shard.AssignedTo)), 294 | ) 295 | expr, err := expression.NewBuilder(). 296 | WithUpdate(update). 297 | WithCondition(cond). 
298 | Build() 299 | if err != nil { 300 | return err 301 | } 302 | _, err = c.svc.UpdateItem(&dynamodb.UpdateItemInput{ 303 | ExpressionAttributeNames: expr.Names(), 304 | ExpressionAttributeValues: expr.Values(), 305 | ConditionExpression: expr.Condition(), 306 | TableName: aws.String(c.TableName), 307 | Key: map[string]*dynamodb.AttributeValue{ 308 | "ShardID": { 309 | S: &shard.ID, 310 | }, 311 | }, 312 | UpdateExpression: expr.Update(), 313 | }) 314 | return err 315 | } 316 | 317 | func (c *DynamoCheckpoint) createTable() error { 318 | input := &dynamodb.CreateTableInput{ 319 | AttributeDefinitions: []*dynamodb.AttributeDefinition{ 320 | { 321 | AttributeName: aws.String("ShardID"), 322 | AttributeType: aws.String("S"), 323 | }, 324 | }, 325 | BillingMode: c.BillingMode, 326 | KeySchema: []*dynamodb.KeySchemaElement{ 327 | { 328 | AttributeName: aws.String("ShardID"), 329 | KeyType: aws.String("HASH"), 330 | }, 331 | }, 332 | TableName: aws.String(c.TableName), 333 | } 334 | if *c.BillingMode == "PROVISIONED" { 335 | input.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ 336 | ReadCapacityUnits: c.ReadCapacityUnits, 337 | WriteCapacityUnits: c.WriteCapacityUnits, 338 | } 339 | } 340 | _, err := c.svc.CreateTable(input) 341 | return err 342 | } 343 | 344 | func (c *DynamoCheckpoint) doesTableExist() bool { 345 | input := &dynamodb.DescribeTableInput{ 346 | TableName: aws.String(c.TableName), 347 | } 348 | _, err := c.svc.DescribeTable(input) 349 | return (err == nil) 350 | } 351 | 352 | func (c *DynamoCheckpoint) getItem(shardID string) (map[string]*dynamodb.AttributeValue, error) { 353 | item, err := c.svc.GetItem(&dynamodb.GetItemInput{ 354 | TableName: aws.String(c.TableName), 355 | Key: map[string]*dynamodb.AttributeValue{ 356 | "ShardID": { 357 | S: aws.String(shardID), 358 | }, 359 | }, 360 | }) 361 | return item.Item, err 362 | } 363 | -------------------------------------------------------------------------------- /checkpointer_integration_test.go: -------------------------------------------------------------------------------- 1 | //+build integration 2 | 3 | package gokini 4 | 5 | import ( 6 | "testing" 7 | "time" 8 | 9 | "github.com/aws/aws-sdk-go/aws" 10 | "github.com/aws/aws-sdk-go/aws/session" 11 | "github.com/aws/aws-sdk-go/service/dynamodb" 12 | ) 13 | 14 | func TestGetLeaseNotAquired(t *testing.T) { 15 | checkpoint := &DynamoCheckpoint{ 16 | TableName: "TableName", 17 | Session: session.New(), 18 | } 19 | checkpoint.Init() 20 | defer checkpoint.svc.DeleteTable(&dynamodb.DeleteTableInput{TableName: aws.String("TableName")}) 21 | err := checkpoint.GetLease(&shardStatus{ 22 | ID: "0001", 23 | Checkpoint: "", 24 | }, "abcd-efgh") 25 | if err != nil { 26 | t.Errorf("Error getting lease %s", err) 27 | } 28 | 29 | err = checkpoint.GetLease(&shardStatus{ 30 | ID: "0001", 31 | Checkpoint: "", 32 | }, "ijkl-mnop") 33 | if err == nil || err.Error() != ErrLeaseNotAquired { 34 | t.Errorf("Got a lease when it was already held by abcd-efgh: %s", err) 35 | } 36 | } 37 | 38 | func TestGetLeaseAquired(t *testing.T) { 39 | checkpoint := &DynamoCheckpoint{ 40 | TableName: "TableName", 41 | Session: session.New(), 42 | } 43 | checkpoint.Init() 44 | defer checkpoint.svc.DeleteTable(&dynamodb.DeleteTableInput{TableName: aws.String("TableName")}) 45 | marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ 46 | "ShardID": { 47 | S: aws.String("0001"), 48 | }, 49 | "AssignedTo": { 50 | S: aws.String("abcd-efgh"), 51 | }, 52 | "LeaseTimeout": { 53 | S: 
aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)), 54 | }, 55 | "SequenceID": { 56 | S: aws.String("deadbeef"), 57 | }, 58 | } 59 | input := &dynamodb.PutItemInput{ 60 | TableName: aws.String("TableName"), 61 | Item: marshalledCheckpoint, 62 | } 63 | checkpoint.svc.PutItem(input) 64 | shard := &shardStatus{ 65 | ID: "0001", 66 | Checkpoint: "deadbeef", 67 | } 68 | err := checkpoint.GetLease(shard, "ijkl-mnop") 69 | 70 | if err != nil { 71 | t.Errorf("Lease not aquired after timeout %s", err) 72 | t.Log(checkpoint.svc.GetItem(&dynamodb.GetItemInput{TableName: aws.String("TableName"), 73 | Key: map[string]*dynamodb.AttributeValue{ 74 | "ShardID": { 75 | S: aws.String("0001"), 76 | }, 77 | }})) 78 | } 79 | 80 | checkpoint.FetchCheckpoint(shard) 81 | 82 | if shard.Checkpoint != "deadbeef" { 83 | t.Errorf("Expected SequenceID to be deadbeef. Got '%s'", shard.Checkpoint) 84 | } 85 | } 86 | 87 | func TestRaceCondGetLeaseTimeout(t *testing.T) { 88 | checkpoint := &DynamoCheckpoint{ 89 | TableName: "TableName", 90 | Session: session.New(), 91 | } 92 | checkpoint.Init() 93 | defer checkpoint.svc.DeleteTable(&dynamodb.DeleteTableInput{TableName: aws.String("TableName")}) 94 | marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ 95 | "ShardID": { 96 | S: aws.String("0001"), 97 | }, 98 | "AssignedTo": { 99 | S: aws.String("abcd-efgh"), 100 | }, 101 | "LeaseTimeout": { 102 | S: aws.String(time.Now().AddDate(0, 0, 1).UTC().Format(time.RFC3339)), 103 | }, 104 | "SequenceID": { 105 | S: aws.String("deadbeef"), 106 | }, 107 | } 108 | input := &dynamodb.PutItemInput{ 109 | TableName: aws.String("TableName"), 110 | Item: marshalledCheckpoint, 111 | } 112 | _, err := checkpoint.svc.PutItem(input) 113 | if err != nil { 114 | t.Fatalf("Error writing to dynamo %s", err) 115 | } 116 | shard := &shardStatus{ 117 | ID: "0001", 118 | Checkpoint: "TestRaceCondGetLeaseTimeout", 119 | } 120 | err = checkpoint.GetLease(shard, "ijkl-mnop") 121 | 122 | if err == nil || err.Error() != ErrLeaseNotAquired { 123 | t.Error("Got a lease when shard was assigned.") 124 | } 125 | } 126 | func TestRaceCondGetLeaseNoAssignee(t *testing.T) { 127 | checkpoint := &DynamoCheckpoint{ 128 | TableName: "TableName", 129 | Session: session.New(), 130 | } 131 | checkpoint.Init() 132 | defer checkpoint.svc.DeleteTable(&dynamodb.DeleteTableInput{TableName: aws.String("TableName")}) 133 | marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ 134 | "ShardID": { 135 | S: aws.String("0001"), 136 | }, 137 | "SequenceID": { 138 | S: aws.String("deadbeef"), 139 | }, 140 | } 141 | input := &dynamodb.PutItemInput{ 142 | TableName: aws.String("TableName"), 143 | Item: marshalledCheckpoint, 144 | } 145 | _, err := checkpoint.svc.PutItem(input) 146 | if err != nil { 147 | t.Fatalf("Error writing to dynamo %s", err) 148 | } 149 | shard := &shardStatus{ 150 | ID: "0001", 151 | Checkpoint: "TestRaceCondGetLeaseNoAssignee", 152 | } 153 | err = checkpoint.GetLease(shard, "ijkl-mnop") 154 | 155 | if err == nil || err.Error() != ErrLeaseNotAquired { 156 | t.Errorf("Got a lease when checkpoints didn't match. 
Potentially we stomped on the checkpoint %s", err) 157 | } 158 | } 159 | 160 | func TestGetLeaseRenewed(t *testing.T) { 161 | checkpoint := &DynamoCheckpoint{ 162 | TableName: "TableName", 163 | Session: session.New(), 164 | LeaseDuration: 0, 165 | } 166 | checkpoint.Init() 167 | defer checkpoint.svc.DeleteTable(&dynamodb.DeleteTableInput{TableName: aws.String("TableName")}) 168 | err := checkpoint.GetLease(&shardStatus{ 169 | ID: "0001", 170 | Checkpoint: "", 171 | }, "abcd-efgh") 172 | if err != nil { 173 | t.Errorf("Error getting lease %s", err) 174 | } 175 | 176 | err = checkpoint.GetLease(&shardStatus{ 177 | ID: "0001", 178 | Checkpoint: "", 179 | }, "abcd-efgh") 180 | if err != nil { 181 | t.Errorf("Error renewing lease %s", err) 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /checkpointer_test.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/aws/aws-sdk-go/aws/awserr" 8 | "github.com/aws/aws-sdk-go/aws/session" 9 | "github.com/aws/aws-sdk-go/service/dynamodb" 10 | "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" 11 | ) 12 | 13 | type mockDynamoDB struct { 14 | dynamodbiface.DynamoDBAPI 15 | tableExist bool 16 | item map[string]*dynamodb.AttributeValue 17 | } 18 | 19 | func (m *mockDynamoDB) DescribeTable(*dynamodb.DescribeTableInput) (*dynamodb.DescribeTableOutput, error) { 20 | if !m.tableExist { 21 | return &dynamodb.DescribeTableOutput{}, awserr.New(dynamodb.ErrCodeResourceNotFoundException, "doesNotExist", errors.New("")) 22 | } 23 | return &dynamodb.DescribeTableOutput{}, nil 24 | } 25 | 26 | func (m *mockDynamoDB) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { 27 | m.item = input.Item 28 | return nil, nil 29 | } 30 | 31 | // To hard to implement this 32 | func (m *mockDynamoDB) UpdateItem(input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) { 33 | return nil, nil 34 | } 35 | 36 | func (m *mockDynamoDB) GetItem(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) { 37 | return &dynamodb.GetItemOutput{ 38 | Item: m.item, 39 | }, nil 40 | } 41 | 42 | func (m *mockDynamoDB) CreateTable(input *dynamodb.CreateTableInput) (*dynamodb.CreateTableOutput, error) { 43 | return &dynamodb.CreateTableOutput{}, nil 44 | } 45 | func TestDoesTableExist(t *testing.T) { 46 | svc := &mockDynamoDB{tableExist: true} 47 | checkpoint := &DynamoCheckpoint{ 48 | TableName: "TableName", 49 | Session: session.New(), 50 | svc: svc, 51 | } 52 | if !checkpoint.doesTableExist() { 53 | t.Error("Table exists but returned false") 54 | } 55 | 56 | svc = &mockDynamoDB{tableExist: false} 57 | checkpoint.svc = svc 58 | if checkpoint.doesTableExist() { 59 | t.Error("Table does not exist but returned true") 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /consumer.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "os" 8 | "sync" 9 | "time" 10 | 11 | "github.com/pkg/errors" 12 | 13 | "github.com/aws/aws-sdk-go/aws" 14 | "github.com/aws/aws-sdk-go/aws/awserr" 15 | "github.com/aws/aws-sdk-go/aws/client" 16 | "github.com/aws/aws-sdk-go/aws/session" 17 | "github.com/aws/aws-sdk-go/service/kinesis" 18 | "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" 19 | "github.com/google/uuid" 20 | log "github.com/sirupsen/logrus" 21 | 
) 22 | 23 | const ( 24 | defaultEmptyRecordBackoffMs = 500 25 | defaultMillisecondsBackoffClaim = 30000 26 | defaultEventLoopSleepMs = 1000 27 | // ErrCodeKMSThrottlingException is defined in the API Reference https://docs.aws.amazon.com/sdk-for-go/api/service/kinesis/#Kinesis.GetRecords 28 | // But it's not a constant? 29 | ErrCodeKMSThrottlingException = "KMSThrottlingException" 30 | ) 31 | 32 | // RecordConsumer is the interface consumers will implement 33 | type RecordConsumer interface { 34 | Init(string) error 35 | ProcessRecords([]*Records, *KinesisConsumer) 36 | Shutdown() 37 | } 38 | 39 | // Records is structure for Kinesis Records 40 | type Records struct { 41 | Data []byte `json:"data"` 42 | PartitionKey string `json:"partitionKey"` 43 | SequenceNumber string `json:"sequenceNumber"` 44 | ShardID string `json:"shardID"` 45 | } 46 | 47 | type shardStatus struct { 48 | ID string 49 | Checkpoint string 50 | AssignedTo string 51 | LeaseTimeout time.Time 52 | ParentShardID *string 53 | Closed bool 54 | ClaimRequest *string 55 | readyToBeClosed bool 56 | sync.Mutex 57 | } 58 | 59 | // KinesisConsumer contains all the configuration and functions necessary to start the Kinesis Consumer 60 | type KinesisConsumer struct { 61 | StreamName string 62 | ShardIteratorType string 63 | RecordConsumer RecordConsumer 64 | EmptyRecordBackoffMs int 65 | LeaseDuration int 66 | Monitoring MonitoringConfiguration 67 | DisableAutomaticCheckpoints bool 68 | Retries *int 69 | IgnoreShardOrdering bool 70 | TableName string 71 | DynamoReadCapacityUnits *int64 72 | DynamoWriteCapacityUnits *int64 73 | DynamoBillingMode *string 74 | Session *session.Session // Setting session means Retries is ignored 75 | millisecondsBackoffClaim int 76 | eventLoopSleepMs int 77 | svc kinesisiface.KinesisAPI 78 | checkpointer Checkpointer 79 | stop *chan struct{} 80 | shardStatus map[string]*shardStatus 81 | consumerID string 82 | mService monitoringService 83 | shardStealInProgress bool 84 | sync.WaitGroup 85 | } 86 | 87 | var defaultRetries = 5 88 | 89 | // StartConsumer starts the RecordConsumer, calls Init and starts sending records to ProcessRecords 90 | func (kc *KinesisConsumer) StartConsumer() error { 91 | rand.Seed(time.Now().UnixNano()) 92 | 93 | // Set Defaults 94 | if kc.EmptyRecordBackoffMs == 0 { 95 | kc.EmptyRecordBackoffMs = defaultEmptyRecordBackoffMs 96 | } 97 | 98 | kc.consumerID = uuid.New().String() 99 | 100 | err := kc.Monitoring.init(kc.StreamName, kc.consumerID, kc.Session) 101 | if err != nil { 102 | log.Errorf("Failed to start monitoring service: %s", err) 103 | } 104 | kc.mService = kc.Monitoring.service 105 | 106 | if kc.millisecondsBackoffClaim == 0 { 107 | kc.millisecondsBackoffClaim = defaultMillisecondsBackoffClaim 108 | } 109 | 110 | if kc.eventLoopSleepMs == 0 { 111 | kc.eventLoopSleepMs = defaultEventLoopSleepMs 112 | } 113 | 114 | retries := defaultRetries 115 | if kc.Retries != nil { 116 | retries = *kc.Retries 117 | } 118 | 119 | if kc.Session == nil { 120 | log.Debugln("Creating AWS Session", kc.consumerID) 121 | kc.Session, err = session.NewSessionWithOptions( 122 | session.Options{ 123 | Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: retries}}, 124 | SharedConfigState: session.SharedConfigEnable, 125 | }, 126 | ) 127 | if err != nil { 128 | return err 129 | } 130 | } 131 | 132 | if kc.svc == nil && kc.checkpointer == nil { 133 | if endpoint := os.Getenv("KINESIS_ENDPOINT"); endpoint != "" { 134 | kc.Session.Config.Endpoint = aws.String(endpoint) 135 | } 136 | kc.svc = 
kinesis.New(kc.Session) 137 | kc.checkpointer = &DynamoCheckpoint{ 138 | ReadCapacityUnits: kc.DynamoReadCapacityUnits, 139 | WriteCapacityUnits: kc.DynamoWriteCapacityUnits, 140 | BillingMode: kc.DynamoBillingMode, 141 | TableName: kc.TableName, 142 | Retries: retries, 143 | LeaseDuration: kc.LeaseDuration, 144 | Session: kc.Session, 145 | } 146 | } 147 | 148 | log.Debugf("Initializing Checkpointer") 149 | if err := kc.checkpointer.Init(); err != nil { 150 | return errors.Wrapf(err, "Failed to start Checkpointer") 151 | } 152 | 153 | kc.shardStatus = make(map[string]*shardStatus) 154 | 155 | stopChan := make(chan struct{}) 156 | kc.stop = &stopChan 157 | 158 | err = kc.getShardIDs("") 159 | if err != nil { 160 | log.Errorf("Error getting Kinesis shards: %s", err) 161 | return err 162 | } 163 | go kc.eventLoop() 164 | 165 | return nil 166 | } 167 | 168 | func (kc *KinesisConsumer) eventLoop() { 169 | for { 170 | log.Debug("Getting shards") 171 | err := kc.getShardIDs("") 172 | if err != nil { 173 | log.Errorf("Error getting Kinesis shards: %s", err) 174 | // Back-off? 175 | time.Sleep(500 * time.Millisecond) 176 | } 177 | log.Debugf("Found %d shards", len(kc.shardStatus)) 178 | 179 | for _, shard := range kc.shardStatus { 180 | // We already own this shard so carry on 181 | if shard.AssignedTo == kc.consumerID { 182 | continue 183 | } 184 | 185 | err := kc.checkpointer.FetchCheckpoint(shard) 186 | if err != nil { 187 | if err != ErrSequenceIDNotFound { 188 | log.Error(err) 189 | continue 190 | } 191 | } 192 | 193 | var stealShard bool 194 | if shard.ClaimRequest != nil { 195 | if shard.LeaseTimeout.Before(time.Now().Add(time.Millisecond * time.Duration(kc.millisecondsBackoffClaim))) { 196 | if *shard.ClaimRequest != kc.consumerID { 197 | log.Debugln("Shard being stolen", shard.ID) 198 | continue 199 | } else { 200 | stealShard = true 201 | log.Debugln("Stealing shard", shard.ID) 202 | } 203 | } 204 | } 205 | 206 | err = kc.checkpointer.GetLease(shard, kc.consumerID) 207 | if err != nil { 208 | if err.Error() != ErrLeaseNotAquired { 209 | log.Error(err) 210 | } 211 | continue 212 | } 213 | if stealShard { 214 | log.Debugln("Successfully stole shard", shard.ID) 215 | kc.shardStealInProgress = false 216 | } 217 | 218 | kc.mService.leaseGained(shard.ID) 219 | 220 | kc.RecordConsumer.Init(shard.ID) 221 | log.Debugf("Starting consumer for shard %s on %s", shard.ID, shard.AssignedTo) 222 | kc.Add(1) 223 | go kc.getRecords(shard.ID) 224 | } 225 | 226 | err = kc.rebalance() 227 | if err != nil { 228 | log.Warn(err) 229 | } 230 | 231 | select { 232 | case <-*kc.stop: 233 | log.Info("Shutting down") 234 | return 235 | case <-time.After(time.Duration(kc.eventLoopSleepMs) * time.Millisecond): 236 | } 237 | } 238 | } 239 | 240 | // Shutdown stops consuming records gracefully 241 | func (kc *KinesisConsumer) Shutdown() { 242 | close(*kc.stop) 243 | kc.Wait() 244 | kc.RecordConsumer.Shutdown() 245 | } 246 | 247 | func (kc *KinesisConsumer) getShardIDs(startShardID string) error { 248 | args := &kinesis.DescribeStreamInput{ 249 | StreamName: aws.String(kc.StreamName), 250 | } 251 | if startShardID != "" { 252 | args.ExclusiveStartShardId = aws.String(startShardID) 253 | } 254 | 255 | streamDesc, err := kc.svc.DescribeStream(args) 256 | if err != nil { 257 | return err 258 | } 259 | 260 | if *streamDesc.StreamDescription.StreamStatus != "ACTIVE" { 261 | return errors.New("Stream not active") 262 | } 263 | 264 | var lastShardID string 265 | for _, s := range streamDesc.StreamDescription.Shards { 266 | if _, 
ok := kc.shardStatus[*s.ShardId]; !ok { 267 | log.Debugf("Found shard with id %s", *s.ShardId) 268 | kc.shardStatus[*s.ShardId] = &shardStatus{ 269 | ID: *s.ShardId, 270 | ParentShardID: s.ParentShardId, 271 | } 272 | } 273 | lastShardID = *s.ShardId 274 | } 275 | 276 | if *streamDesc.StreamDescription.HasMoreShards { 277 | err := kc.getShardIDs(lastShardID) 278 | if err != nil { 279 | return err 280 | } 281 | } 282 | 283 | return nil 284 | } 285 | 286 | func (kc *KinesisConsumer) getShardIterator(shard *shardStatus) (*string, error) { 287 | err := kc.checkpointer.FetchCheckpoint(shard) 288 | if err != nil && err != ErrSequenceIDNotFound { 289 | return nil, err 290 | } 291 | 292 | if shard.Checkpoint == "" { 293 | shardIterArgs := &kinesis.GetShardIteratorInput{ 294 | ShardId: &shard.ID, 295 | ShardIteratorType: &kc.ShardIteratorType, 296 | StreamName: &kc.StreamName, 297 | } 298 | iterResp, err := kc.svc.GetShardIterator(shardIterArgs) 299 | if err != nil { 300 | return nil, err 301 | } 302 | return iterResp.ShardIterator, nil 303 | } 304 | 305 | shardIterArgs := &kinesis.GetShardIteratorInput{ 306 | ShardId: &shard.ID, 307 | ShardIteratorType: aws.String("AFTER_SEQUENCE_NUMBER"), 308 | StartingSequenceNumber: &shard.Checkpoint, 309 | StreamName: &kc.StreamName, 310 | } 311 | iterResp, err := kc.svc.GetShardIterator(shardIterArgs) 312 | if err != nil { 313 | return nil, err 314 | } 315 | return iterResp.ShardIterator, nil 316 | } 317 | 318 | func (kc *KinesisConsumer) getRecords(shardID string) { 319 | defer kc.Done() 320 | 321 | shard := kc.shardStatus[shardID] 322 | shardIterator, err := kc.getShardIterator(shard) 323 | if err != nil { 324 | kc.RecordConsumer.Shutdown() 325 | log.Errorf("Unable to get shard iterator for %s: %s", shardID, err) 326 | return 327 | } 328 | 329 | var retriedErrors int 330 | 331 | for { 332 | getRecordsStartTime := time.Now() 333 | err := kc.checkpointer.FetchCheckpoint(shard) 334 | if err != nil && err != ErrSequenceIDNotFound { 335 | log.Errorln("Error fetching checkpoint", err) 336 | time.Sleep(time.Duration(defaultEventLoopSleepMs) * time.Second) 337 | continue 338 | } 339 | if shard.ClaimRequest != nil { 340 | // Claim request means another worker wants to steal our shard. 
So let the lease lapse 341 | log.Infof("Shard %s has been stolen from us", shardID) 342 | return 343 | } 344 | if time.Now().UTC().After(shard.LeaseTimeout.Add(-5 * time.Second)) { 345 | err = kc.checkpointer.GetLease(shard, kc.consumerID) 346 | if err != nil { 347 | if err.Error() == ErrLeaseNotAquired { 348 | shard.Lock() 349 | defer shard.Unlock() 350 | shard.AssignedTo = "" 351 | kc.mService.leaseLost(shard.ID) 352 | log.Debugln("Lease lost for shard", shard.ID, kc.consumerID) 353 | return 354 | } 355 | log.Warnln("Error renewing lease", err) 356 | time.Sleep(time.Duration(1) * time.Second) 357 | continue 358 | } 359 | } 360 | 361 | getRecordsArgs := &kinesis.GetRecordsInput{ 362 | ShardIterator: shardIterator, 363 | } 364 | getResp, err := kc.svc.GetRecords(getRecordsArgs) 365 | if err != nil { 366 | if awsErr, ok := err.(awserr.Error); ok { 367 | if awsErr.Code() == kinesis.ErrCodeProvisionedThroughputExceededException || awsErr.Code() == ErrCodeKMSThrottlingException { 368 | log.Errorf("Error getting records from shard %v: %v", shardID, err) 369 | retriedErrors++ 370 | time.Sleep(time.Duration(2^retriedErrors*100) * time.Millisecond) 371 | continue 372 | } 373 | } 374 | // This is an exception we cannot handle and therefore we exit 375 | panic(fmt.Sprintf("Error getting records from Kinesis that cannot be retried: %s\nRequest: %s", err, getRecordsArgs)) 376 | } 377 | retriedErrors = 0 378 | 379 | var records []*Records 380 | var recordBytes int64 381 | for _, r := range getResp.Records { 382 | record := &Records{ 383 | Data: r.Data, 384 | PartitionKey: *r.PartitionKey, 385 | SequenceNumber: *r.SequenceNumber, 386 | ShardID: shardID, 387 | } 388 | records = append(records, record) 389 | recordBytes += int64(len(record.Data)) 390 | log.Tracef("Processing record %s", *r.SequenceNumber) 391 | } 392 | processRecordsStartTime := time.Now() 393 | kc.RecordConsumer.ProcessRecords(records, kc) 394 | 395 | // Convert from nanoseconds to milliseconds 396 | processedRecordsTiming := time.Since(processRecordsStartTime) / 1000000 397 | kc.mService.recordProcessRecordsTime(shard.ID, float64(processedRecordsTiming)) 398 | 399 | if len(records) == 0 { 400 | time.Sleep(time.Duration(kc.EmptyRecordBackoffMs) * time.Millisecond) 401 | } else if !kc.DisableAutomaticCheckpoints { 402 | kc.Checkpoint(shardID, *getResp.Records[len(getResp.Records)-1].SequenceNumber) 403 | } 404 | 405 | kc.mService.incrRecordsProcessed(shard.ID, len(records)) 406 | kc.mService.incrBytesProcessed(shard.ID, recordBytes) 407 | kc.mService.millisBehindLatest(shard.ID, float64(*getResp.MillisBehindLatest)) 408 | 409 | // Convert from nanoseconds to milliseconds 410 | getRecordsTime := time.Since(getRecordsStartTime) / 1000000 411 | kc.mService.recordGetRecordsTime(shard.ID, float64(getRecordsTime)) 412 | 413 | // The shard has been closed, so no new records can be read from it 414 | if getResp.NextShardIterator == nil { 415 | log.Debugf("Shard %s closed", shardID) 416 | shard := kc.shardStatus[shardID] 417 | shard.Lock() 418 | shard.readyToBeClosed = true 419 | shard.Unlock() 420 | if !kc.DisableAutomaticCheckpoints { 421 | kc.Checkpoint(shardID, *getResp.Records[len(getResp.Records)-1].SequenceNumber) 422 | } 423 | return 424 | } 425 | shardIterator = getResp.NextShardIterator 426 | 427 | select { 428 | case <-*kc.stop: 429 | log.Infoln("Received stop signal, stopping record consumer for", shardID) 430 | return 431 | case <-time.After(1 * time.Nanosecond): 432 | } 433 | } 434 | } 435 | 436 | // Checkpoint records the sequence 
number for the given shard ID as being processed 437 | func (kc *KinesisConsumer) Checkpoint(shardID string, sequenceNumber string) error { 438 | shard := kc.shardStatus[shardID] 439 | shard.Lock() 440 | shard.Checkpoint = sequenceNumber 441 | shard.Unlock() 442 | // If shard is closed and we've read all records from the shard, mark the shard as closed 443 | if shard.readyToBeClosed { 444 | var err error 445 | shard.Closed, err = kc.shardIsEmpty(shard) 446 | if err != nil { 447 | return err 448 | } 449 | } 450 | return kc.checkpointer.CheckpointSequence(shard) 451 | } 452 | 453 | func (kc *KinesisConsumer) shardIsEmpty(shard *shardStatus) (empty bool, err error) { 454 | iterResp, err := kc.svc.GetShardIterator(&kinesis.GetShardIteratorInput{ 455 | ShardId: &shard.ID, 456 | ShardIteratorType: aws.String("AFTER_SEQUENCE_NUMBER"), 457 | StartingSequenceNumber: &shard.Checkpoint, 458 | StreamName: &kc.StreamName, 459 | }) 460 | if err != nil { 461 | return 462 | } 463 | recordsResp, err := kc.svc.GetRecords(&kinesis.GetRecordsInput{ 464 | ShardIterator: iterResp.ShardIterator, 465 | }) 466 | if err != nil { 467 | return 468 | } 469 | if len(recordsResp.Records) == 0 { 470 | empty = true 471 | } 472 | return 473 | } 474 | 475 | func (kc *KinesisConsumer) rebalance() error { 476 | workers, err := kc.checkpointer.ListActiveWorkers() 477 | if err != nil { 478 | log.Debugln("Error listing workings", kc.consumerID, err) 479 | return err 480 | } 481 | 482 | // Only attempt to steal one shard at at time, to allow for linear convergence 483 | if kc.shardStealInProgress { 484 | err := kc.getShardIDs("") 485 | if err != nil { 486 | return err 487 | } 488 | for _, shard := range kc.shardStatus { 489 | if shard.ClaimRequest != nil && *shard.ClaimRequest == kc.consumerID { 490 | log.Debugln("Steal in progress", kc.consumerID) 491 | return nil 492 | } 493 | // Our shard steal was stomped on by a Checkpoint. 
494 | // We could deal with that, but instead just try again 495 | kc.shardStealInProgress = false 496 | } 497 | } 498 | 499 | var numShards float64 500 | for _, shards := range workers { 501 | numShards += float64(len(shards)) 502 | } 503 | numWorkers := float64(len(workers)) 504 | 505 | // 1:1 shards to workers is optimal, so we cannot possibly rebalance 506 | if numWorkers >= numShards { 507 | log.Debugln("Optimal shard allocation, not stealing any shards", numWorkers, ">", numShards, kc.consumerID) 508 | return nil 509 | } 510 | 511 | currentShards, ok := workers[kc.consumerID] 512 | var numCurrentShards float64 513 | if !ok { 514 | numCurrentShards = 0 515 | numWorkers++ 516 | } else { 517 | numCurrentShards = float64(len(currentShards)) 518 | } 519 | 520 | optimalShards := math.Floor(numShards / numWorkers) 521 | log.Debugln("Number of shards", numShards) 522 | log.Debugln("Number of workers", numWorkers) 523 | log.Debugln("Optimal shards", optimalShards) 524 | log.Debugln("Current shards", numCurrentShards) 525 | // We have more than or equal optimal shards, so no rebalancing can take place 526 | if numCurrentShards >= optimalShards { 527 | log.Debugln("We have enough shards, not attempting to steal any", kc.consumerID) 528 | return nil 529 | } 530 | maxShards := int(optimalShards) 531 | var workerSteal *string 532 | for w, shards := range workers { 533 | if len(shards) > maxShards { 534 | workerSteal = &w 535 | maxShards = len(shards) 536 | } 537 | } 538 | // Not all shards are allocated so fallback to default shard allocation mechanisms 539 | if workerSteal == nil { 540 | log.Debugln("Not all shards are allocated, not stealing any", kc.consumerID) 541 | return nil 542 | } 543 | 544 | // Steal a random shard from the worker with the most shards 545 | kc.shardStealInProgress = true 546 | randIndex := rand.Perm(len(workers[*workerSteal]))[0] 547 | shardToSteal := workers[*workerSteal][randIndex] 548 | log.Debugln("Stealing shard", shardToSteal, "from", *workerSteal) 549 | err = kc.checkpointer.ClaimShard(&shardStatus{ 550 | ID: shardToSteal, 551 | }, kc.consumerID) 552 | if err != nil { 553 | kc.shardStealInProgress = false 554 | } 555 | return err 556 | } 557 | -------------------------------------------------------------------------------- /consumer_test.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/aws/aws-sdk-go/aws" 11 | "github.com/aws/aws-sdk-go/service/kinesis" 12 | "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" 13 | ) 14 | 15 | type testConsumer struct { 16 | ShardID string 17 | Records []*Records 18 | IsShutdown bool 19 | } 20 | 21 | func (tc *testConsumer) Init(shardID string) error { 22 | tc.ShardID = shardID 23 | return nil 24 | } 25 | 26 | func (tc *testConsumer) ProcessRecords(records []*Records, consumer *KinesisConsumer) { 27 | tc.Records = append(tc.Records, records...) 
28 | } 29 | 30 | func (tc *testConsumer) Shutdown() { 31 | tc.IsShutdown = true 32 | return 33 | } 34 | 35 | type mockKinesisClient struct { 36 | kinesisiface.KinesisAPI 37 | NumberRecordsBeforeClosing int 38 | numberRecordsSent int 39 | getShardIteratorCalled bool 40 | RecordData []byte 41 | numShards int 42 | } 43 | 44 | func (k *mockKinesisClient) GetShardIterator(args *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { 45 | k.getShardIteratorCalled = true 46 | return &kinesis.GetShardIteratorOutput{ 47 | ShardIterator: aws.String("0123456789ABCDEF"), 48 | }, nil 49 | } 50 | 51 | func (k *mockKinesisClient) DescribeStream(args *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { 52 | shards := []*kinesis.Shard{} 53 | for i := 0; i < k.numShards; i++ { 54 | shards = append(shards, &kinesis.Shard{ 55 | ShardId: aws.String(fmt.Sprintf("0000000%d", i)), 56 | }) 57 | } 58 | return &kinesis.DescribeStreamOutput{ 59 | StreamDescription: &kinesis.StreamDescription{ 60 | StreamStatus: aws.String("ACTIVE"), 61 | Shards: shards, 62 | HasMoreShards: aws.Bool(false), 63 | }, 64 | }, nil 65 | } 66 | 67 | type mockCheckpointer struct { 68 | checkpointFound bool 69 | checkpoint map[string]*shardStatus 70 | checkpointerCalled bool 71 | sync.Mutex 72 | } 73 | 74 | func (c *mockCheckpointer) Init() error { 75 | c.checkpoint = make(map[string]*shardStatus) 76 | return nil 77 | } 78 | 79 | func (c *mockCheckpointer) GetLease(shard *shardStatus, assignTo string) error { 80 | shard.Lock() 81 | shard.AssignedTo = assignTo 82 | shard.LeaseTimeout = time.Now() 83 | shard.Unlock() 84 | return nil 85 | } 86 | 87 | func (c *mockCheckpointer) CheckpointSequence(shard *shardStatus) error { 88 | c.Lock() 89 | defer c.Unlock() 90 | c.checkpoint[shard.ID] = shard 91 | c.checkpointerCalled = true 92 | return nil 93 | } 94 | func (c *mockCheckpointer) FetchCheckpoint(shard *shardStatus) error { 95 | if c.checkpointFound { 96 | if checkpointShard, ok := c.checkpoint[shard.ID]; ok { 97 | shard.Checkpoint = checkpointShard.Checkpoint 98 | shard.AssignedTo = checkpointShard.AssignedTo 99 | } else { 100 | shard.Checkpoint = "ABCD124" 101 | shard.AssignedTo = "abcdef-1234567" 102 | } 103 | } 104 | return nil 105 | } 106 | 107 | func (c *mockCheckpointer) ListActiveWorkers() (map[string][]string, error) { 108 | return nil, nil 109 | } 110 | 111 | func (c *mockCheckpointer) ClaimShard(*shardStatus, string) error { 112 | return nil 113 | } 114 | 115 | func (k *mockKinesisClient) GetRecords(args *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { 116 | k.numberRecordsSent++ 117 | var nextShardIterator *string 118 | if k.NumberRecordsBeforeClosing == 0 || k.numberRecordsSent < k.NumberRecordsBeforeClosing { 119 | nextShardIterator = aws.String("ABCD1234") 120 | } 121 | return &kinesis.GetRecordsOutput{ 122 | MillisBehindLatest: aws.Int64(0), 123 | NextShardIterator: nextShardIterator, 124 | Records: []*kinesis.Record{ 125 | &kinesis.Record{ 126 | Data: k.RecordData, 127 | PartitionKey: aws.String("abcdefg"), 128 | SequenceNumber: aws.String(strings.Join([]string{"0000", string(k.numberRecordsSent)}, "")), 129 | }, 130 | }, 131 | }, nil 132 | } 133 | 134 | func createConsumer(t *testing.T, numRecords int, checkpointFound bool, shutdown bool) (consumer *testConsumer, kinesisSvc *mockKinesisClient, checkpointer *mockCheckpointer) { 135 | consumer = &testConsumer{} 136 | kinesisSvc = &mockKinesisClient{ 137 | NumberRecordsBeforeClosing: numRecords, 138 | RecordData: []byte("Hello 
World"), 139 | numShards: 1, 140 | } 141 | checkpointer = &mockCheckpointer{ 142 | checkpointFound: checkpointFound, 143 | } 144 | kc := &KinesisConsumer{ 145 | StreamName: "FOO", 146 | ShardIteratorType: "TRIM_HORIZON", 147 | RecordConsumer: consumer, 148 | checkpointer: checkpointer, 149 | svc: kinesisSvc, 150 | eventLoopSleepMs: 1, 151 | } 152 | 153 | err := kc.StartConsumer() 154 | if err != nil { 155 | t.Fatalf("Got unexpected error from StartConsumer: %s", err) 156 | } 157 | time.Sleep(200 * time.Millisecond) 158 | if shutdown { 159 | kc.Shutdown() 160 | } 161 | return 162 | } 163 | 164 | func TestStartConsumer(t *testing.T) { 165 | consumer, kinesisSvc, _ := createConsumer(t, 1, false, true) 166 | 167 | if consumer.ShardID != "00000000" { 168 | t.Errorf("Expected shardId to be set to 00000000, but got: %s", consumer.ShardID) 169 | } 170 | 171 | if len(consumer.Records) != 1 { 172 | t.Fatalf("Expected there to be one record from Kinesis, got %d", len(consumer.Records)) 173 | } 174 | 175 | if string(consumer.Records[0].Data) != "Hello World" { 176 | t.Errorf("Expected record to be \"Hello World\", got %s", consumer.Records[0].Data) 177 | } 178 | 179 | if string(consumer.Records[0].ShardID) != "00000000" { 180 | t.Errorf("Expected Shard ID to be \"00000000\", got %s", consumer.Records[0].ShardID) 181 | } 182 | 183 | if !consumer.IsShutdown { 184 | t.Errorf("Expected consumer to be shutdown but it was not") 185 | } 186 | 187 | consumer, kinesisSvc, _ = createConsumer(t, 2, true, true) 188 | if len(consumer.Records) != 2 { 189 | t.Errorf("Expected there to be two records from Kinesis, got %v", consumer.Records) 190 | } 191 | 192 | if !kinesisSvc.getShardIteratorCalled { 193 | t.Errorf("Expected shard iterator to be called, but it was not") 194 | } 195 | 196 | if !consumer.IsShutdown { 197 | t.Errorf("Expected consumer to be shutdown but it was not") 198 | } 199 | 200 | consumer, kinesisSvc, _ = createConsumer(t, 1, true, true) 201 | if !kinesisSvc.getShardIteratorCalled { 202 | t.Errorf("Expected shard iterator not to be called, but it was") 203 | } 204 | } 205 | 206 | func TestScaleDownShards(t *testing.T) { 207 | consumer, kinesisSvc, _ := createConsumer(t, 0, false, false) 208 | kinesisSvc.numShards = 2 209 | time.Sleep(10 * time.Millisecond) 210 | // Shards don't just "dissapear" they rather just return a nil nextIteratorShard 211 | kinesisSvc.NumberRecordsBeforeClosing = 1 212 | time.Sleep(10 * time.Millisecond) 213 | if consumer.IsShutdown { 214 | t.Errorf("Expected consumer to not be shutdown but it was") 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.4" 2 | services: 3 | kinesis: 4 | image: dlsniper/kinesalite 5 | ports: 6 | - 4567:4567 7 | expose: 8 | - 4567 9 | dynamodb: 10 | image: amazon/dynamodb-local 11 | ports: 12 | - 8000:8000 13 | expose: 14 | - 8000 15 | gokini: 16 | image: golang:1.10 17 | volumes: 18 | - .:/go/src/github.com/patrobinson/gokini 19 | working_dir: /go/src/github.com/patrobinson/gokini 20 | depends_on: 21 | - kinesis 22 | - dynamodb 23 | links: 24 | - kinesis 25 | - dynamodb 26 | environment: 27 | - KINESIS_ENDPOINT=http://kinesis:4567 28 | - DYNAMODB_ENDPOINT=http://dynamodb:8000 29 | - AWS_DEFAULT_REGION=ap-nil-1 30 | - AWS_REGION=ap-nil-1 31 | - AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE 32 | - AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY 33 | - GO111MODULE=on 34 | 
-------------------------------------------------------------------------------- /example_checkpoint_test.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type CheckpointRecordConsumer struct { 9 | shardID string 10 | } 11 | 12 | func (p *CheckpointRecordConsumer) Init(shardID string) error { 13 | fmt.Printf("Checkpointer initializing\n") 14 | p.shardID = shardID 15 | return nil 16 | } 17 | 18 | func (p *CheckpointRecordConsumer) ProcessRecords(records []*Records, consumer *KinesisConsumer) { 19 | if len(records) > 0 { 20 | fmt.Printf("%s\n", records[0].Data) 21 | } 22 | consumer.Checkpoint(p.shardID, records[len(records)-1].SequenceNumber) 23 | } 24 | 25 | func (p *CheckpointRecordConsumer) Shutdown() { 26 | fmt.Print("PrintRecordConsumer Shutdown\n") 27 | } 28 | 29 | func ExampleCheckpointRecordConsumer() { 30 | // An implementation of the RecordConsumer interface that prints out records and checkpoints at the end 31 | rc := &PrintRecordConsumer{} 32 | kc := &KinesisConsumer{ 33 | StreamName: "KINESIS_STREAM_2", 34 | ShardIteratorType: "TRIM_HORIZON", 35 | RecordConsumer: rc, 36 | TableName: "gokini_2", 37 | EmptyRecordBackoffMs: 1000, 38 | } 39 | 40 | // Send records to our kinesis stream so we have something to process 41 | pushRecordToKinesis("KINESIS_STREAM_2", []byte("example_checkpoint_record_consumer"), true) 42 | defer deleteStream("KINESIS_STREAM_2") 43 | defer deleteTable("gokini_2") 44 | 45 | err := kc.StartConsumer() 46 | if err != nil { 47 | fmt.Printf("Failed to start consumer: %s", err) 48 | } 49 | 50 | // Wait for it to do it's thing 51 | time.Sleep(200 * time.Millisecond) 52 | kc.Shutdown() 53 | time.Sleep(200 * time.Millisecond) 54 | 55 | // Output: 56 | // Checkpointer initializing 57 | // example_checkpoint_record_consumer 58 | // PrintRecordConsumer Shutdown 59 | } 60 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type PrintRecordConsumer struct { 9 | shardID string 10 | } 11 | 12 | func (p *PrintRecordConsumer) Init(shardID string) error { 13 | fmt.Printf("Checkpointer initializing\n") 14 | p.shardID = shardID 15 | return nil 16 | } 17 | 18 | func (p *PrintRecordConsumer) ProcessRecords(records []*Records, consumer *KinesisConsumer) { 19 | if len(records) > 0 { 20 | fmt.Printf("%s\n", records[0].Data) 21 | } 22 | } 23 | 24 | func (p *PrintRecordConsumer) Shutdown() { 25 | fmt.Print("PrintRecordConsumer Shutdown\n") 26 | } 27 | 28 | func ExampleRecordConsumer() { 29 | // An implementation of the RecordConsumer interface that prints out records 30 | rc := &PrintRecordConsumer{} 31 | kc := &KinesisConsumer{ 32 | StreamName: "KINESIS_STREAM", 33 | ShardIteratorType: "TRIM_HORIZON", 34 | RecordConsumer: rc, 35 | TableName: "gokini", 36 | EmptyRecordBackoffMs: 1000, 37 | } 38 | 39 | // Send records to our kinesis stream so we have something to process 40 | pushRecordToKinesis("KINESIS_STREAM", []byte("foo"), true) 41 | defer deleteStream("KINESIS_STREAM") 42 | defer deleteTable("gokini") 43 | 44 | err := kc.StartConsumer() 45 | if err != nil { 46 | fmt.Printf("Failed to start consumer: %s", err) 47 | } 48 | 49 | // Wait for it to do it's thing 50 | time.Sleep(200 * time.Millisecond) 51 | kc.Shutdown() 52 | 53 | // Output: 54 | // Checkpointer initializing 55 | // 
foo 56 | // PrintRecordConsumer Shutdown 57 | } 58 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/patrobinson/gokini 2 | 3 | require ( 4 | github.com/aws/aws-sdk-go v1.19.38 5 | github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect 6 | github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927 // indirect 7 | github.com/gogo/protobuf v1.1.1 // indirect 8 | github.com/golang/protobuf v1.2.0 // indirect 9 | github.com/google/uuid v1.0.0 10 | github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2 11 | github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect 12 | github.com/pkg/errors v0.9.1 13 | github.com/prometheus/client_golang v0.9.1 14 | github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect 15 | github.com/prometheus/common v0.0.0-20181116084131-1f2c4f3cd6db 16 | github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect 17 | github.com/sirupsen/logrus v1.2.0 18 | golang.org/x/net v0.0.0-20181114220301-adae6a3d119a // indirect 19 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f // indirect 20 | golang.org/x/text v0.3.0 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/aws/aws-sdk-go v1.15.78 h1:LaXy6lWR0YK7LKyuU0QWy2ws/LWTPfYV/UgfiBu4tvY= 2 | github.com/aws/aws-sdk-go v1.15.78/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3ATZkfNZeM= 3 | github.com/aws/aws-sdk-go v1.19.38 h1:WKjobgPO4Ua1ww2NJJl2/zQNreUZxvqmEzwMlRjjm9g= 4 | github.com/aws/aws-sdk-go v1.19.38/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= 5 | github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= 6 | github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= 7 | github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927 h1:SKI1/fuSdodxmNNyVBR8d7X/HuLnRpvvFO0AgyQk764= 8 | github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927/go.mod h1:h/aW8ynjgkuj+NQRlZcDbAbM1ORAbXjXX77sX7T289U= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= 12 | github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= 13 | github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= 14 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 15 | github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= 16 | github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 17 | github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE= 18 | github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= 19 | github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= 20 | github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= 21 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 
h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= 22 | github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 23 | github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2 h1:JAEbJn3j/FrhdWA9jW8B5ajsLIjeuEHLi8xE4fk997o= 24 | github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2/go.mod h1:0KeJpeMD6o+O4hW7qJOT7vyQPKrWmj26uf5wMc/IiIs= 25 | github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= 26 | github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= 27 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 28 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 29 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 30 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 31 | github.com/prometheus/client_golang v0.9.1 h1:K47Rk0v/fkEfwfQet2KWhscE0cJzjgCCDBG2KHZoVno= 32 | github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= 33 | github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= 34 | github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= 35 | github.com/prometheus/common v0.0.0-20181116084131-1f2c4f3cd6db h1:ckMAAQJ96ZKwKyiGamJdsinLn3D9+daeRlvvmYo9tkI= 36 | github.com/prometheus/common v0.0.0-20181116084131-1f2c4f3cd6db/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= 37 | github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ= 38 | github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= 39 | github.com/sirupsen/logrus v1.2.0 h1:juTguoYk5qI21pwyTXY3B3Y5cOTH3ZUyZCg1v/mihuo= 40 | github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= 41 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 42 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 43 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 44 | golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= 45 | golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 46 | golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U= 47 | golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 48 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= 49 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 50 | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= 51 | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 52 | golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= 53 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 54 | -------------------------------------------------------------------------------- /helper_test.go: 
-------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "os" 5 | "time" 6 | "fmt" 7 | 8 | "github.com/aws/aws-sdk-go/aws" 9 | awsclient "github.com/aws/aws-sdk-go/aws/client" 10 | "github.com/aws/aws-sdk-go/aws/session" 11 | "github.com/aws/aws-sdk-go/service/dynamodb" 12 | "github.com/aws/aws-sdk-go/service/kinesis" 13 | 14 | log "github.com/sirupsen/logrus" 15 | ) 16 | 17 | func createStream(streamName string, shards int64) error { 18 | session, err := session.NewSessionWithOptions( 19 | session.Options{ 20 | SharedConfigState: session.SharedConfigEnable, 21 | Config: aws.Config{ 22 | CredentialsChainVerboseErrors: aws.Bool(true), 23 | Endpoint: aws.String(os.Getenv("KINESIS_ENDPOINT")), 24 | Retryer: awsclient.DefaultRetryer{NumMaxRetries: 1}, 25 | }, 26 | }, 27 | ) 28 | if err != nil { 29 | return fmt.Errorf("Error starting kinesis client %s", err) 30 | } 31 | svc := kinesis.New(session) 32 | _, err = svc.CreateStream(&kinesis.CreateStreamInput{ 33 | ShardCount: aws.Int64(shards), 34 | StreamName: aws.String(streamName), 35 | }) 36 | if err != nil { 37 | return err 38 | } 39 | time.Sleep(500 * time.Millisecond) 40 | return nil 41 | } 42 | 43 | func pushRecordToKinesis(streamName string, record []byte, create bool) error { 44 | session, err := session.NewSessionWithOptions( 45 | session.Options{ 46 | SharedConfigState: session.SharedConfigEnable, 47 | Config: aws.Config{ 48 | CredentialsChainVerboseErrors: aws.Bool(true), 49 | Endpoint: aws.String(os.Getenv("KINESIS_ENDPOINT")), 50 | Retryer: awsclient.DefaultRetryer{NumMaxRetries: 1}, 51 | }, 52 | }, 53 | ) 54 | if err != nil { 55 | log.Errorf("Error starting kinesis client %s", err) 56 | return err 57 | } 58 | svc := kinesis.New(session) 59 | if create { 60 | if err := createStream(streamName, 1); err != nil { 61 | return err 62 | } 63 | } 64 | _, err = svc.PutRecord(&kinesis.PutRecordInput{ 65 | Data: record, 66 | PartitionKey: aws.String("abc123"), 67 | StreamName: &streamName, 68 | }) 69 | if err != nil { 70 | log.Errorf("Error sending data to kinesis %s", err) 71 | } 72 | return err 73 | } 74 | 75 | func deleteStream(streamName string) { 76 | session, _ := session.NewSessionWithOptions( 77 | session.Options{ 78 | SharedConfigState: session.SharedConfigEnable, 79 | Config: aws.Config{ 80 | CredentialsChainVerboseErrors: aws.Bool(true), 81 | Endpoint: aws.String(os.Getenv("KINESIS_ENDPOINT")), 82 | }, 83 | }, 84 | ) 85 | svc := kinesis.New(session) 86 | _, err := svc.DeleteStream(&kinesis.DeleteStreamInput{ 87 | StreamName: &streamName, 88 | }) 89 | if err != nil { 90 | log.Errorln(err) 91 | } 92 | } 93 | 94 | func deleteTable(tableName string) { 95 | session, _ := session.NewSessionWithOptions( 96 | session.Options{ 97 | Config: aws.Config{ 98 | Endpoint: aws.String(os.Getenv("DYNAMODB_ENDPOINT")), 99 | }, 100 | SharedConfigState: session.SharedConfigEnable, 101 | }, 102 | ) 103 | svc := dynamodb.New(session) 104 | _, err := svc.DeleteTable(&dynamodb.DeleteTableInput{ 105 | TableName: &tableName, 106 | }) 107 | if err != nil { 108 | log.Errorln(err) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /integration_test.go: -------------------------------------------------------------------------------- 1 | //+build integration 2 | 3 | package gokini 4 | 5 | import ( 6 | "fmt" 7 | "net/http" 8 | "testing" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | 13 | dto "github.com/prometheus/client_model/go" 14 | 
"github.com/prometheus/common/expfmt" 15 | ) 16 | 17 | type IntegrationRecordConsumer struct { 18 | shardID string 19 | processedRecords map[string]int 20 | } 21 | 22 | func (p *IntegrationRecordConsumer) Init(shardID string) error { 23 | p.shardID = shardID 24 | return nil 25 | } 26 | 27 | func (p *IntegrationRecordConsumer) ProcessRecords(records []*Records, consumer *KinesisConsumer) { 28 | if len(records) > 0 { 29 | for _, record := range records { 30 | p.processedRecords[record.SequenceNumber] += 1 31 | } 32 | } 33 | } 34 | 35 | func (p *IntegrationRecordConsumer) Shutdown() {} 36 | 37 | func TestCheckpointRecovery(t *testing.T) { 38 | rc := &IntegrationRecordConsumer{ 39 | processedRecords: make(map[string]int), 40 | } 41 | kc := &KinesisConsumer{ 42 | StreamName: "checkpoint_recovery", 43 | ShardIteratorType: "TRIM_HORIZON", 44 | RecordConsumer: rc, 45 | TableName: "checkpoint_recovery", 46 | EmptyRecordBackoffMs: 2000, 47 | LeaseDuration: 1, 48 | eventLoopSleepMs: 1, 49 | } 50 | pushRecordToKinesis("checkpoint_recovery", []byte("abcd"), true) 51 | defer deleteStream("checkpoint_recovery") 52 | defer deleteTable("checkpoint_recovery") 53 | 54 | err := kc.StartConsumer() 55 | if err != nil { 56 | t.Errorf("Error starting consumer %s", err) 57 | } 58 | 59 | time.Sleep(200 * time.Millisecond) 60 | kc.Shutdown() 61 | 62 | kc = &KinesisConsumer{ 63 | StreamName: "checkpoint_recovery", 64 | ShardIteratorType: "TRIM_HORIZON", 65 | RecordConsumer: rc, 66 | TableName: "checkpoint_recovery", 67 | LeaseDuration: 1, 68 | } 69 | 70 | err = kc.StartConsumer() 71 | if err != nil { 72 | t.Errorf("Error starting consumer %s", err) 73 | } 74 | time.Sleep(200 * time.Millisecond) 75 | for sequenceID, timesSequenceProcessed := range rc.processedRecords { 76 | fmt.Printf("seqenceID: %s, processed %d time(s)\n", sequenceID, timesSequenceProcessed) 77 | if timesSequenceProcessed > 1 { 78 | t.Errorf("Sequence number %s was processed more than once", sequenceID) 79 | } 80 | } 81 | kc.Shutdown() 82 | } 83 | 84 | func TestCheckpointGainLock(t *testing.T) { 85 | rc := &IntegrationRecordConsumer{ 86 | processedRecords: make(map[string]int), 87 | } 88 | kc := &KinesisConsumer{ 89 | StreamName: "checkpoint_gain_lock", 90 | ShardIteratorType: "TRIM_HORIZON", 91 | RecordConsumer: rc, 92 | TableName: "checkpoint_gain_lock", 93 | EmptyRecordBackoffMs: 2000, 94 | LeaseDuration: 100, 95 | } 96 | pushRecordToKinesis("checkpoint_gain_lock", []byte("abcd"), true) 97 | defer deleteStream("checkpoint_gain_lock") 98 | defer deleteTable("checkpoint_gain_lock") 99 | 100 | err := kc.StartConsumer() 101 | if err != nil { 102 | t.Errorf("Error starting consumer %s", err) 103 | } 104 | 105 | time.Sleep(200 * time.Millisecond) 106 | kc.Shutdown() 107 | 108 | kc = &KinesisConsumer{ 109 | StreamName: "checkpoint_gain_lock", 110 | ShardIteratorType: "TRIM_HORIZON", 111 | RecordConsumer: rc, 112 | TableName: "checkpoint_gain_lock", 113 | LeaseDuration: 100, 114 | } 115 | 116 | err = kc.StartConsumer() 117 | if err != nil { 118 | t.Errorf("Error starting consumer %s", err) 119 | } 120 | pushRecordToKinesis("checkpoint_gain_lock", []byte("abcd"), false) 121 | time.Sleep(200 * time.Millisecond) 122 | if len(rc.processedRecords) != 2 { 123 | t.Errorf("Expected to have processed 2 records") 124 | for sequenceId, timesProcessed := range rc.processedRecords { 125 | fmt.Println("Processed", sequenceId, timesProcessed, "time(s)") 126 | } 127 | } 128 | kc.Shutdown() 129 | } 130 | 131 | func TestPrometheusMonitoring(t *testing.T) { 132 | rc := 
&IntegrationRecordConsumer{ 133 | processedRecords: make(map[string]int), 134 | } 135 | kc := &KinesisConsumer{ 136 | StreamName: "prometheus_monitoring", 137 | ShardIteratorType: "TRIM_HORIZON", 138 | RecordConsumer: rc, 139 | TableName: "prometheus_monitoring", 140 | EmptyRecordBackoffMs: 2000, 141 | LeaseDuration: 1, 142 | Monitoring: MonitoringConfiguration{ 143 | MonitoringService: "prometheus", 144 | Prometheus: prometheusMonitoringService{ 145 | ListenAddress: ":8080", 146 | }, 147 | }, 148 | } 149 | pushRecordToKinesis("prometheus_monitoring", []byte("abcd"), true) 150 | defer deleteStream("prometheus_monitoring") 151 | defer deleteTable("prometheus_monitoring") 152 | 153 | err := kc.StartConsumer() 154 | if err != nil { 155 | t.Errorf("Error starting consumer %s", err) 156 | } 157 | time.Sleep(1200 * time.Millisecond) 158 | 159 | res, err := http.Get("http://localhost:8080/metrics") 160 | if err != nil { 161 | t.Fatalf("Error scraping Prometheus endpoint %s", err) 162 | } 163 | kc.Shutdown() 164 | 165 | var parser expfmt.TextParser 166 | parsed, err := parser.TextToMetricFamilies(res.Body) 167 | res.Body.Close() 168 | if err != nil { 169 | t.Errorf("Error reading monitoring response %s", err) 170 | } 171 | 172 | if parsed["gokini_processed_bytes"] == nil || parsed["gokini_processed_records"] == nil || parsed["gokini_leases_held"] == nil { 173 | t.Fatalf("Missing metrics %s", keys(parsed)) 174 | } 175 | 176 | if *parsed["gokini_processed_bytes"].Metric[0].Counter.Value != float64(4) { 177 | t.Errorf("Expected to have read 4 bytes, got %d", int(*parsed["gokini_processed_bytes"].Metric[0].Counter.Value)) 178 | } 179 | 180 | if *parsed["gokini_processed_records"].Metric[0].Counter.Value != float64(1) { 181 | t.Errorf("Expected to have read 1 records, got %d", int(*parsed["gokini_processed_records"].Metric[0].Counter.Value)) 182 | } 183 | 184 | if *parsed["gokini_leases_held"].Metric[0].Gauge.Value != float64(1) { 185 | t.Errorf("Expected to have 1 lease held, got %d", int(*parsed["gokini_leases_held"].Metric[0].Counter.Value)) 186 | } 187 | } 188 | 189 | func setupConsumer(name string, t *testing.T) *KinesisConsumer { 190 | rc := &IntegrationRecordConsumer{ 191 | processedRecords: make(map[string]int), 192 | } 193 | kc := &KinesisConsumer{ 194 | StreamName: name, 195 | ShardIteratorType: "TRIM_HORIZON", 196 | RecordConsumer: rc, 197 | TableName: name, 198 | EmptyRecordBackoffMs: 50, 199 | LeaseDuration: 200, 200 | eventLoopSleepMs: 100, 201 | millisecondsBackoffClaim: 200, 202 | } 203 | err := kc.StartConsumer() 204 | if err != nil { 205 | t.Fatalf("Error starting consumer %s", err) 206 | } 207 | return kc 208 | } 209 | 210 | func TestRebalance(t *testing.T) { 211 | uuid, _ := uuid.NewUUID() 212 | name := uuid.String() 213 | err := createStream(name, 2) 214 | if err != nil { 215 | t.Fatalf("Error creating stream %s", err) 216 | } 217 | kc := setupConsumer(name, t) 218 | time.Sleep(200 * time.Millisecond) 219 | secondKc := setupConsumer(name, t) 220 | defer deleteStream(name) 221 | defer deleteTable(name) 222 | time.Sleep(1000 * time.Millisecond) 223 | workers, err := kc.checkpointer.ListActiveWorkers() 224 | if err != nil { 225 | t.Fatalf("Error getting workers %s", err) 226 | } 227 | if len(workers[kc.consumerID]) != 1 { 228 | t.Errorf("Expected consumer to have 1 shard, it has %d", len(workers[kc.consumerID])) 229 | } 230 | if len(workers[secondKc.consumerID]) != 1 { 231 | t.Errorf("Expected consumer to have 1 shard, it has %d", len(workers[secondKc.consumerID])) 232 | } 233 | 
kc.Shutdown() 234 | secondKc.Shutdown() 235 | } 236 | 237 | func keys(m map[string]*dto.MetricFamily) []string { 238 | var keys []string 239 | for k := range m { 240 | keys = append(keys, k) 241 | } 242 | return keys 243 | } 244 | -------------------------------------------------------------------------------- /monitoring.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync" 7 | "time" 8 | 9 | "github.com/aws/aws-sdk-go/aws" 10 | "github.com/aws/aws-sdk-go/aws/session" 11 | "github.com/aws/aws-sdk-go/service/cloudwatch" 12 | "github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface" 13 | "github.com/prometheus/client_golang/prometheus" 14 | "github.com/prometheus/client_golang/prometheus/promhttp" 15 | log "github.com/sirupsen/logrus" 16 | ) 17 | 18 | // MonitoringConfiguration allows you to configure how record processing metrics are exposed 19 | type MonitoringConfiguration struct { 20 | MonitoringService string // Type of monitoring to expose. Supported types are "prometheus" 21 | Prometheus prometheusMonitoringService 22 | CloudWatch cloudWatchMonitoringService 23 | service monitoringService 24 | } 25 | 26 | type monitoringService interface { 27 | init() error 28 | incrRecordsProcessed(string, int) 29 | incrBytesProcessed(string, int64) 30 | millisBehindLatest(string, float64) 31 | leaseGained(string) 32 | leaseLost(string) 33 | leaseRenewed(string) 34 | recordGetRecordsTime(string, float64) 35 | recordProcessRecordsTime(string, float64) 36 | } 37 | 38 | func (m *MonitoringConfiguration) init(streamName string, workerID string, sess *session.Session) error { 39 | if m.MonitoringService == "" { 40 | m.service = &noopMonitoringService{} 41 | return nil 42 | } 43 | 44 | switch m.MonitoringService { 45 | case "prometheus": 46 | m.Prometheus.KinesisStream = streamName 47 | m.Prometheus.WorkerID = workerID 48 | m.service = &m.Prometheus 49 | case "cloudwatch": 50 | m.CloudWatch.Session = sess 51 | m.CloudWatch.KinesisStream = streamName 52 | m.CloudWatch.WorkerID = workerID 53 | m.service = &m.CloudWatch 54 | default: 55 | return fmt.Errorf("Invalid monitoring service type %s", m.MonitoringService) 56 | } 57 | return m.service.init() 58 | } 59 | 60 | type prometheusMonitoringService struct { 61 | ListenAddress string 62 | Namespace string 63 | KinesisStream string 64 | WorkerID string 65 | processedRecords *prometheus.CounterVec 66 | processedBytes *prometheus.CounterVec 67 | behindLatestMillis *prometheus.GaugeVec 68 | leasesHeld *prometheus.GaugeVec 69 | leaseRenewals *prometheus.CounterVec 70 | getRecordsTime *prometheus.HistogramVec 71 | processRecordsTime *prometheus.HistogramVec 72 | } 73 | 74 | const defaultNamespace = "gokini" 75 | 76 | func (p *prometheusMonitoringService) init() error { 77 | if p.Namespace == "" { 78 | p.Namespace = defaultNamespace 79 | } 80 | p.processedBytes = prometheus.NewCounterVec(prometheus.CounterOpts{ 81 | Name: p.Namespace + `_processed_bytes`, 82 | Help: "Number of bytes processed", 83 | }, []string{"kinesisStream", "shard"}) 84 | p.processedRecords = prometheus.NewCounterVec(prometheus.CounterOpts{ 85 | Name: p.Namespace + `_processed_records`, 86 | Help: "Number of records processed", 87 | }, []string{"kinesisStream", "shard"}) 88 | p.behindLatestMillis = prometheus.NewGaugeVec(prometheus.GaugeOpts{ 89 | Name: p.Namespace + `_behind_latest_millis`, 90 | Help: "The amount of milliseconds processing is behind", 91 | }, []string{"kinesisStream", 
"shard"}) 92 | p.leasesHeld = prometheus.NewGaugeVec(prometheus.GaugeOpts{ 93 | Name: p.Namespace + `_leases_held`, 94 | Help: "The number of leases held by the worker", 95 | }, []string{"kinesisStream", "shard", "workerID"}) 96 | p.leaseRenewals = prometheus.NewCounterVec(prometheus.CounterOpts{ 97 | Name: p.Namespace + `_lease_renewals`, 98 | Help: "The number of successful lease renewals", 99 | }, []string{"kinesisStream", "shard", "workerID"}) 100 | p.getRecordsTime = prometheus.NewHistogramVec(prometheus.HistogramOpts{ 101 | Name: p.Namespace + `_get_records_duration_milliseconds`, 102 | Help: "The time taken to fetch records and process them", 103 | }, []string{"kinesisStream", "shard"}) 104 | p.processRecordsTime = prometheus.NewHistogramVec(prometheus.HistogramOpts{ 105 | Name: p.Namespace + `_process_records_duration_milliseconds`, 106 | Help: "The time taken to process records", 107 | }, []string{"kinesisStream", "shard"}) 108 | 109 | metrics := []prometheus.Collector{ 110 | p.processedBytes, 111 | p.processedRecords, 112 | p.behindLatestMillis, 113 | p.leasesHeld, 114 | p.leaseRenewals, 115 | p.getRecordsTime, 116 | p.processRecordsTime, 117 | } 118 | for _, metric := range metrics { 119 | err := prometheus.Register(metric) 120 | if err != nil { 121 | return err 122 | } 123 | } 124 | 125 | http.Handle("/metrics", promhttp.Handler()) 126 | go func() { 127 | log.Debugf("Starting Prometheus listener on %s", p.ListenAddress) 128 | err := http.ListenAndServe(p.ListenAddress, nil) 129 | if err != nil { 130 | log.Errorln("Error starting Prometheus metrics endpoint", err) 131 | } 132 | }() 133 | return nil 134 | } 135 | 136 | func (p *prometheusMonitoringService) incrRecordsProcessed(shard string, count int) { 137 | p.processedRecords.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream}).Add(float64(count)) 138 | } 139 | 140 | func (p *prometheusMonitoringService) incrBytesProcessed(shard string, count int64) { 141 | p.processedBytes.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream}).Add(float64(count)) 142 | } 143 | 144 | func (p *prometheusMonitoringService) millisBehindLatest(shard string, millSeconds float64) { 145 | p.behindLatestMillis.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream}).Set(millSeconds) 146 | } 147 | 148 | func (p *prometheusMonitoringService) leaseGained(shard string) { 149 | p.leasesHeld.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream, "workerID": p.WorkerID}).Inc() 150 | } 151 | 152 | func (p *prometheusMonitoringService) leaseLost(shard string) { 153 | p.leasesHeld.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream, "workerID": p.WorkerID}).Dec() 154 | } 155 | 156 | func (p *prometheusMonitoringService) leaseRenewed(shard string) { 157 | p.leaseRenewals.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream, "workerID": p.WorkerID}).Inc() 158 | } 159 | 160 | func (p *prometheusMonitoringService) recordGetRecordsTime(shard string, time float64) { 161 | p.getRecordsTime.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream}).Observe(time) 162 | } 163 | 164 | func (p *prometheusMonitoringService) recordProcessRecordsTime(shard string, time float64) { 165 | p.processRecordsTime.With(prometheus.Labels{"shard": shard, "kinesisStream": p.KinesisStream}).Observe(time) 166 | } 167 | 168 | type noopMonitoringService struct{} 169 | 170 | func (n *noopMonitoringService) init() error { 171 | return nil 172 | } 173 | 174 | func (n 
*noopMonitoringService) incrRecordsProcessed(shard string, count int) {}
175 | func (n *noopMonitoringService) incrBytesProcessed(shard string, count int64) {}
176 | func (n *noopMonitoringService) millisBehindLatest(shard string, millSeconds float64) {}
177 | func (n *noopMonitoringService) leaseGained(shard string) {}
178 | func (n *noopMonitoringService) leaseLost(shard string) {}
179 | func (n *noopMonitoringService) leaseRenewed(shard string) {}
180 | func (n *noopMonitoringService) recordGetRecordsTime(shard string, time float64) {}
181 | func (n *noopMonitoringService) recordProcessRecordsTime(shard string, time float64) {}
182 | 
183 | type cloudWatchMonitoringService struct {
184 | 	Namespace     string
185 | 	KinesisStream string
186 | 	WorkerID      string
187 | 	// What granularity we should send metrics to CW at. Note setting this to 1 will cost quite a bit of money
188 | 	// At the time of writing (March 2018) about US$200 per month
189 | 	ResolutionSec int
190 | 	Session       *session.Session
191 | 	svc           cloudwatchiface.CloudWatchAPI
192 | 	shardMetrics  map[string]*cloudWatchMetrics
193 | }
194 | 
195 | type cloudWatchMetrics struct {
196 | 	processedRecords   int64
197 | 	processedBytes     int64
198 | 	behindLatestMillis []float64
199 | 	leasesHeld         int64
200 | 	leaseRenewals      int64
201 | 	getRecordsTime     []float64
202 | 	processRecordsTime []float64
203 | 	sync.Mutex
204 | }
205 | 
206 | func (cw *cloudWatchMonitoringService) init() error {
207 | 	if cw.ResolutionSec == 0 {
208 | 		cw.ResolutionSec = 60
209 | 	}
210 | 
211 | 	cw.svc = cloudwatch.New(cw.Session)
212 | 	cw.shardMetrics = make(map[string]*cloudWatchMetrics)
213 | 	return nil
214 | }
215 | 
216 | func (cw *cloudWatchMonitoringService) flushDaemon() {
217 | 	previousFlushTime := time.Now()
218 | 	resolutionDuration := time.Duration(cw.ResolutionSec) * time.Second
219 | 	for {
220 | 		time.Sleep(resolutionDuration - time.Now().Sub(previousFlushTime))
221 | 		err := cw.flush()
222 | 		if err != nil {
223 | 			log.Errorln("Error sending metrics to CloudWatch", err)
224 | 		}
225 | 		previousFlushTime = time.Now()
226 | 	}
227 | }
228 | 
229 | func (cw *cloudWatchMonitoringService) flush() error {
230 | 	for shard, metric := range cw.shardMetrics {
231 | 		metric.Lock()
232 | 		defaultDimensions := []*cloudwatch.Dimension{
233 | 			&cloudwatch.Dimension{
234 | 				Name:  aws.String("shard"),
235 | 				Value: &shard,
236 | 			},
237 | 			&cloudwatch.Dimension{
238 | 				Name:  aws.String("KinesisStreamName"),
239 | 				Value: &cw.KinesisStream,
240 | 			},
241 | 		}
242 | 		leaseDimensions := make([]*cloudwatch.Dimension, len(defaultDimensions))
243 | 		copy(leaseDimensions, defaultDimensions)
244 | 		leaseDimensions = append(leaseDimensions, &cloudwatch.Dimension{
245 | 			Name:  aws.String("WorkerID"),
246 | 			Value: &cw.WorkerID,
247 | 		})
248 | 		metricTimestamp := time.Now()
249 | 		_, err := cw.svc.PutMetricData(&cloudwatch.PutMetricDataInput{
250 | 			Namespace: aws.String(cw.Namespace),
251 | 			MetricData: []*cloudwatch.MetricDatum{
252 | 				&cloudwatch.MetricDatum{
253 | 					Dimensions: defaultDimensions,
254 | 					MetricName: aws.String("RecordsProcessed"),
255 | 					Unit:       aws.String("Count"),
256 | 					Timestamp:  &metricTimestamp,
257 | 					Value:      aws.Float64(float64(metric.processedRecords)),
258 | 				},
259 | 				&cloudwatch.MetricDatum{
260 | 					Dimensions: defaultDimensions,
261 | 					MetricName: aws.String("DataBytesProcessed"),
262 | 					Unit:       aws.String("Byte"),
263 | 					Timestamp:  &metricTimestamp,
264 | 					Value:      aws.Float64(float64(metric.processedBytes)),
265 | 				},
266 | 				&cloudwatch.MetricDatum{
267 | 					Dimensions: defaultDimensions,
268 | 					MetricName: 
aws.String("MillisBehindLatest"), 269 | Unit: aws.String("Milliseconds"), 270 | Timestamp: &metricTimestamp, 271 | StatisticValues: &cloudwatch.StatisticSet{ 272 | SampleCount: aws.Float64(float64(len(metric.behindLatestMillis))), 273 | Sum: sumFloat64(metric.behindLatestMillis), 274 | Maximum: maxFloat64(metric.behindLatestMillis), 275 | Minimum: minFloat64(metric.behindLatestMillis), 276 | }, 277 | }, 278 | &cloudwatch.MetricDatum{ 279 | Dimensions: defaultDimensions, 280 | MetricName: aws.String("KinesisDataFetcher.getRecords.Time"), 281 | Unit: aws.String("Milliseconds"), 282 | Timestamp: &metricTimestamp, 283 | StatisticValues: &cloudwatch.StatisticSet{ 284 | SampleCount: aws.Float64(float64(len(metric.getRecordsTime))), 285 | Sum: sumFloat64(metric.getRecordsTime), 286 | Maximum: maxFloat64(metric.getRecordsTime), 287 | Minimum: minFloat64(metric.getRecordsTime), 288 | }, 289 | }, 290 | &cloudwatch.MetricDatum{ 291 | Dimensions: defaultDimensions, 292 | MetricName: aws.String("RecordProcessor.processRecords.Time"), 293 | Unit: aws.String("Milliseconds"), 294 | Timestamp: &metricTimestamp, 295 | StatisticValues: &cloudwatch.StatisticSet{ 296 | SampleCount: aws.Float64(float64(len(metric.processRecordsTime))), 297 | Sum: sumFloat64(metric.processRecordsTime), 298 | Maximum: maxFloat64(metric.processRecordsTime), 299 | Minimum: minFloat64(metric.processRecordsTime), 300 | }, 301 | }, 302 | &cloudwatch.MetricDatum{ 303 | Dimensions: leaseDimensions, 304 | MetricName: aws.String("RenewLease.Success"), 305 | Unit: aws.String("Count"), 306 | Timestamp: &metricTimestamp, 307 | Value: aws.Float64(float64(metric.leaseRenewals)), 308 | }, 309 | &cloudwatch.MetricDatum{ 310 | Dimensions: leaseDimensions, 311 | MetricName: aws.String("CurrentLeases"), 312 | Unit: aws.String("Count"), 313 | Timestamp: &metricTimestamp, 314 | Value: aws.Float64(float64(metric.leasesHeld)), 315 | }, 316 | }, 317 | }) 318 | if err == nil { 319 | metric.processedRecords = 0 320 | metric.processedBytes = 0 321 | metric.behindLatestMillis = []float64{} 322 | metric.leaseRenewals = 0 323 | metric.getRecordsTime = []float64{} 324 | metric.processRecordsTime = []float64{} 325 | } 326 | metric.Unlock() 327 | return err 328 | } 329 | return nil 330 | } 331 | 332 | func (cw *cloudWatchMonitoringService) incrRecordsProcessed(shard string, count int) { 333 | if _, ok := cw.shardMetrics[shard]; !ok { 334 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 335 | } 336 | cw.shardMetrics[shard].Lock() 337 | defer cw.shardMetrics[shard].Unlock() 338 | cw.shardMetrics[shard].processedRecords += int64(count) 339 | } 340 | 341 | func (cw *cloudWatchMonitoringService) incrBytesProcessed(shard string, count int64) { 342 | if _, ok := cw.shardMetrics[shard]; !ok { 343 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 344 | } 345 | cw.shardMetrics[shard].Lock() 346 | defer cw.shardMetrics[shard].Unlock() 347 | cw.shardMetrics[shard].processedBytes += count 348 | } 349 | 350 | func (cw *cloudWatchMonitoringService) millisBehindLatest(shard string, millSeconds float64) { 351 | if _, ok := cw.shardMetrics[shard]; !ok { 352 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 353 | } 354 | cw.shardMetrics[shard].Lock() 355 | defer cw.shardMetrics[shard].Unlock() 356 | cw.shardMetrics[shard].behindLatestMillis = append(cw.shardMetrics[shard].behindLatestMillis, millSeconds) 357 | } 358 | 359 | func (cw *cloudWatchMonitoringService) leaseGained(shard string) { 360 | if _, ok := cw.shardMetrics[shard]; !ok { 361 | cw.shardMetrics[shard] = 
&cloudWatchMetrics{} 362 | } 363 | cw.shardMetrics[shard].Lock() 364 | defer cw.shardMetrics[shard].Unlock() 365 | cw.shardMetrics[shard].leasesHeld++ 366 | } 367 | 368 | func (cw *cloudWatchMonitoringService) leaseLost(shard string) { 369 | if _, ok := cw.shardMetrics[shard]; !ok { 370 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 371 | } 372 | cw.shardMetrics[shard].Lock() 373 | defer cw.shardMetrics[shard].Unlock() 374 | cw.shardMetrics[shard].leasesHeld-- 375 | } 376 | 377 | func (cw *cloudWatchMonitoringService) leaseRenewed(shard string) { 378 | if _, ok := cw.shardMetrics[shard]; !ok { 379 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 380 | } 381 | cw.shardMetrics[shard].Lock() 382 | defer cw.shardMetrics[shard].Unlock() 383 | cw.shardMetrics[shard].leaseRenewals++ 384 | } 385 | 386 | func (cw *cloudWatchMonitoringService) recordGetRecordsTime(shard string, time float64) { 387 | if _, ok := cw.shardMetrics[shard]; !ok { 388 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 389 | } 390 | cw.shardMetrics[shard].Lock() 391 | defer cw.shardMetrics[shard].Unlock() 392 | cw.shardMetrics[shard].getRecordsTime = append(cw.shardMetrics[shard].getRecordsTime, time) 393 | } 394 | func (cw *cloudWatchMonitoringService) recordProcessRecordsTime(shard string, time float64) { 395 | if _, ok := cw.shardMetrics[shard]; !ok { 396 | cw.shardMetrics[shard] = &cloudWatchMetrics{} 397 | } 398 | cw.shardMetrics[shard].Lock() 399 | defer cw.shardMetrics[shard].Unlock() 400 | cw.shardMetrics[shard].processRecordsTime = append(cw.shardMetrics[shard].processRecordsTime, time) 401 | } 402 | 403 | func sumFloat64(slice []float64) *float64 { 404 | sum := float64(0) 405 | for _, num := range slice { 406 | sum += num 407 | } 408 | return &sum 409 | } 410 | 411 | func maxFloat64(slice []float64) *float64 { 412 | if len(slice) < 1 { 413 | return aws.Float64(0) 414 | } 415 | max := slice[0] 416 | for _, num := range slice { 417 | if num > max { 418 | max = num 419 | } 420 | } 421 | return &max 422 | } 423 | 424 | func minFloat64(slice []float64) *float64 { 425 | if len(slice) < 1 { 426 | return aws.Float64(0) 427 | } 428 | min := slice[0] 429 | for _, num := range slice { 430 | if num < min { 431 | min = num 432 | } 433 | } 434 | return &min 435 | } 436 | -------------------------------------------------------------------------------- /monitoring_test.go: -------------------------------------------------------------------------------- 1 | package gokini 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/aws/aws-sdk-go/aws/session" 7 | "github.com/aws/aws-sdk-go/service/cloudwatch" 8 | "github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface" 9 | ) 10 | 11 | type mockCloudWatch struct { 12 | cloudwatchiface.CloudWatchAPI 13 | metricData []*cloudwatch.MetricDatum 14 | } 15 | 16 | func (m *mockCloudWatch) PutMetricData(input *cloudwatch.PutMetricDataInput) (*cloudwatch.PutMetricDataOutput, error) { 17 | m.metricData = append(m.metricData, input.MetricData...) 
18 | return &cloudwatch.PutMetricDataOutput{}, nil 19 | } 20 | 21 | func TestCloudWatchMonitoring(t *testing.T) { 22 | mockCW := &mockCloudWatch{} 23 | cwService := &cloudWatchMonitoringService{ 24 | Namespace: "testCloudWatchMonitoring", 25 | KinesisStream: "cloudwatch_monitoring", 26 | WorkerID: "abc123", 27 | ResolutionSec: 1, 28 | Session: session.New(), 29 | svc: mockCW, 30 | shardMetrics: map[string]*cloudWatchMetrics{}, 31 | } 32 | cwService.incrRecordsProcessed("00001", 10) 33 | err := cwService.flush() 34 | if err != nil { 35 | t.Errorf("Received error sending data to cloudwatch %s", err) 36 | } 37 | if len(mockCW.metricData) < 1 { 38 | t.Fatal("Expected at least one metric to be sent to cloudwatch") 39 | } 40 | 41 | if *mockCW.metricData[0].Value != float64(10) { 42 | t.Errorf("Expected metric value to be 10.0, got %f", *mockCW.metricData[0].Value) 43 | } 44 | } 45 | --------------------------------------------------------------------------------
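
Note: the integration tests above only exercise the `prometheus` monitoring path. The sketch below is not a file in this repository; it shows how a test inside package `gokini` might select the CloudWatch service instead, using the exported fields and the `"cloudwatch"` service string from `monitoring.go`. The stream/table names, namespace value, and sleep duration are placeholder assumptions, and because no local CloudWatch mock is wired up it would likely talk to the real CloudWatch API, so treat it as illustrative only.

```
//+build integration

package gokini

import (
	"testing"
	"time"
)

// Hypothetical sketch: mirrors TestPrometheusMonitoring but configures the
// CloudWatch monitoring service instead of Prometheus.
func TestCloudWatchMonitoringExample(t *testing.T) {
	rc := &IntegrationRecordConsumer{processedRecords: make(map[string]int)}
	kc := &KinesisConsumer{
		StreamName:        "cloudwatch_example", // placeholder stream name
		ShardIteratorType: "TRIM_HORIZON",
		RecordConsumer:    rc,
		TableName:         "cloudwatch_example", // placeholder checkpoint table
		LeaseDuration:     1,
		Monitoring: MonitoringConfiguration{
			MonitoringService: "cloudwatch",
			CloudWatch: cloudWatchMonitoringService{
				Namespace:     "gokini", // placeholder namespace
				ResolutionSec: 60,       // per the comment in monitoring.go, 1s resolution is costly
			},
		},
	}
	pushRecordToKinesis("cloudwatch_example", []byte("abcd"), true)
	defer deleteStream("cloudwatch_example")
	defer deleteTable("cloudwatch_example")

	if err := kc.StartConsumer(); err != nil {
		t.Fatalf("Error starting consumer %s", err)
	}
	time.Sleep(200 * time.Millisecond)
	kc.Shutdown()
}
```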