// AWSSecurityTokenHeader is the request header used to pass an AWS
// session token along with temporary credentials.
const AWSSecurityTokenHeader = "X-Amz-Security-Token"

// Auth supplies the credential material used to sign requests.
type Auth interface {
	// KeyForSigning returns an access key / secret / token appropriate
	// for signing at time now, which as the name suggests, is usually now.
	KeyForSigning(now time.Time) (*SigningKey, error)
}

// SigningKey holds the credential triple needed to sign a request.
type SigningKey struct {
	AccessKeyId     string
	SecretAccessKey string
	SessionToken    string
}
6 | 7 | ## Documentation 8 | 9 | * [Core API](http://godoc.org/github.com/sendgridlabs/go-kinesis) 10 | * [Batch Producer API](http://godoc.org/github.com/sendgridlabs/go-kinesis/batchproducer) 11 | 12 | ## Example 13 | 14 | Example you can find in folder `examples`. 15 | 16 | ## Command line interface 17 | 18 | You can find a tool for interacting with kinesis from the command line in folder `kinesis-cli`. 19 | 20 | ## Testing 21 | 22 | ### Local Kinesis Server 23 | 24 | The tests require a local Kinesis server such as [Kinesalite](https://github.com/mhart/kinesalite) 25 | to be running and reachable at `http://127.0.0.1:4567`. 26 | 27 | To make the tests complete faster, you might want to have Kinesalite perform stream creation and 28 | deletion faster than the default of 500ms, like so: 29 | 30 | kinesalite --createStreamMs 5 --deleteStreamMs 5 & 31 | 32 | The `&` runs Kinesalite in the background, which is probably what you want. 33 | 34 | ### go test 35 | 36 | Some of the tests are marked as safe to be run in parallel, so to speed up test execution you might 37 | want to run `go test` with [the `-parallel n` flag](https://golang.org/cmd/go/#hdr-Description_of_testing_flags). 38 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // Client is like http.Client, but signs all requests using Auth. 8 | type Client struct { 9 | // Auth holds the credentials for this client instance 10 | auth Auth 11 | // The http client to make requests with. If nil, http.DefaultClient is used. 12 | client *http.Client 13 | } 14 | 15 | // NewClient creates a new Client that uses the credentials in the specified 16 | // Auth object. 17 | // 18 | // This function assumes the Auth object has been sanely initialized. 
If you 19 | // wish to infer auth credentials from the environment, refer to NewAuth 20 | func NewClient(auth Auth) *Client { 21 | return &Client{auth: auth, client: http.DefaultClient} 22 | } 23 | 24 | // NewClientWithHTTPClient creates a client with a non-default http client 25 | // ie. a timeout could be set on the HTTP client to timeout if Kinesis doesn't 26 | // response in a timely manner like after the 5 minute mark where the current 27 | // shard iterator expires 28 | func NewClientWithHTTPClient(auth Auth, httpClient *http.Client) *Client { 29 | return &Client{auth: auth, client: httpClient} 30 | } 31 | 32 | // Do some request, but sign it before sending 33 | func (c *Client) Do(req *http.Request) (*http.Response, error) { 34 | err := Sign(c.auth, req) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | return c.client.Do(req) 40 | } 41 | -------------------------------------------------------------------------------- /auth_cachedmutexedwarmedup.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // newCachedMutexedWarmedUpAuth wraps another auth object 9 | // with a cache that is thread-safe, and will always attempt 10 | // to fetch credentials when initialised. 11 | // The underlying Auth object will only be called if the time is 12 | // past the last returned expiration time. 13 | func newCachedMutexedWarmedUpAuth(underlying temporaryCredentialGenerator) (Auth, error) { 14 | rv := &cachedMutexedAuth{ 15 | underlying: underlying, 16 | } 17 | _, err := rv.KeyForSigning(time.Now()) 18 | if err != nil { 19 | return nil, err 20 | } 21 | return rv, nil 22 | } 23 | 24 | // Auth interface for authentication credentials and information 25 | type temporaryCredentialGenerator interface { 26 | // KeyForSigning return an access key / secret / token appropriate for signing at time now, 27 | // which as the name suggests, is usually now. 
28 | // Additionally returns the expriration time of these credentials. 29 | ExpiringKeyForSigning(now time.Time) (*SigningKey, time.Time, error) 30 | } 31 | 32 | type cachedMutexedAuth struct { 33 | mu sync.Mutex 34 | current *SigningKey 35 | expiration time.Time 36 | underlying temporaryCredentialGenerator 37 | } 38 | 39 | func (cmuxa *cachedMutexedAuth) KeyForSigning(now time.Time) (*SigningKey, error) { 40 | cmuxa.mu.Lock() 41 | defer cmuxa.mu.Unlock() 42 | 43 | if cmuxa.current == nil || !cmuxa.expiration.After(now) { 44 | newCurrent, newExpiration, err := cmuxa.underlying.ExpiringKeyForSigning(now) 45 | if err != nil { 46 | return nil, err 47 | } 48 | cmuxa.current = newCurrent 49 | cmuxa.expiration = newExpiration 50 | } 51 | 52 | return cmuxa.current, nil 53 | } 54 | -------------------------------------------------------------------------------- /auth_static.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "time" 7 | ) 8 | 9 | const ( 10 | AccessEnvKey = "AWS_ACCESS_KEY" 11 | AccessEnvKeyId = "AWS_ACCESS_KEY_ID" 12 | SecretEnvKey = "AWS_SECRET_KEY" 13 | SecretEnvAccessKey = "AWS_SECRET_ACCESS_KEY" 14 | SecurityTokenEnvKey = "AWS_SECURITY_TOKEN" 15 | ) 16 | 17 | // NewAuth creates return an auth object that uses static 18 | // credentials which do not automatically renew. 
19 | func NewAuth(accessKey, secretKey, token string) Auth { 20 | return &staticAuth{ 21 | staticCreds: &SigningKey{ 22 | AccessKeyId: accessKey, 23 | SecretAccessKey: secretKey, 24 | SessionToken: token, 25 | }, 26 | } 27 | } 28 | 29 | // NewAuthFromEnv retrieves auth credentials from environment vars 30 | func NewAuthFromEnv() (Auth, error) { 31 | accessKey := os.Getenv(AccessEnvKey) 32 | if accessKey == "" { 33 | accessKey = os.Getenv(AccessEnvKeyId) 34 | } 35 | 36 | secretKey := os.Getenv(SecretEnvKey) 37 | if secretKey == "" { 38 | secretKey = os.Getenv(SecretEnvAccessKey) 39 | } 40 | 41 | token := os.Getenv(SecurityTokenEnvKey) 42 | 43 | if accessKey == "" && secretKey == "" && token == "" { 44 | return nil, fmt.Errorf("No access key (%s or %s), secret key (%s or %s), or security token (%s) env variables were set", AccessEnvKey, AccessEnvKeyId, SecretEnvKey, SecretEnvAccessKey, SecurityTokenEnvKey) 45 | } 46 | if accessKey == "" { 47 | return nil, fmt.Errorf("Unable to retrieve access key from %s or %s env variables", AccessEnvKey, AccessEnvKeyId) 48 | } 49 | if secretKey == "" { 50 | return nil, fmt.Errorf("Unable to retrieve secret key from %s or %s env variables", SecretEnvKey, SecretEnvAccessKey) 51 | } 52 | 53 | return NewAuth(accessKey, secretKey, token), nil 54 | } 55 | 56 | type staticAuth struct { 57 | staticCreds *SigningKey 58 | } 59 | 60 | func (sc *staticAuth) KeyForSigning(now time.Time) (*SigningKey, error) { 61 | return sc.staticCreds, nil 62 | } 63 | -------------------------------------------------------------------------------- /sign_test.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | var testSignFactoryData = []struct { 10 | AWS_KEY string 11 | AWS_SECRET string 12 | TOKEN string 13 | DateHeader string 14 | AuthHeader string 15 | }{ 16 | {"ASWKEY", "AWSSECRET", "TOKEN1", "Thu, 28 Nov 2013 15:04:05 GMT", 
"AWS4-HMAC-SHA256 Credential=ASWKEY/20131128/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;date;host;user-agent;x-amz-target, Signature=6c21aca39f1d4afd383fbc45dd3a580192036162f74bf9fda6cad6c6fb7cde2f"}, 17 | {"ASWKEY2", "AWSSECRET2", "TOKEN2", "Thu, 28 Nov 2013 15:04:05 GMT", "AWS4-HMAC-SHA256 Credential=ASWKEY2/20131128/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;date;host;user-agent;x-amz-target, Signature=488ee09d2d56e747beb5653064d7976cb67136a2afa6013d82ff36d6ae95d263"}, 18 | {"ASWNEWKEY", "AWSSECRET", "TOKEN3", "Thu, 28 Nov 2013 15:04:05 GMT", "AWS4-HMAC-SHA256 Credential=ASWNEWKEY/20131128/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;date;host;user-agent;x-amz-target, Signature=6c21aca39f1d4afd383fbc45dd3a580192036162f74bf9fda6cad6c6fb7cde2f"}, 19 | {"ASWKEY", "AWSSECRET", "TOKEN4", "Mon, 25 Nov 2013 15:04:05 GMT", "AWS4-HMAC-SHA256 Credential=ASWKEY/20131125/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;date;host;user-agent;x-amz-target, Signature=cec25de1e72db69dd48ff4895dc4022e31dc5933209d5bce61286779d49a95e5"}, 20 | } 21 | 22 | func TestSign(t *testing.T) { 23 | for _, data := range testSignFactoryData { 24 | request, err := http.NewRequest("POST", "https://kinesis.us-east-1.amazonaws.com", strings.NewReader("{}")) 25 | if err != nil { 26 | t.Errorf("NewRequest Error %v", err) 27 | } 28 | 29 | request.Header.Set("Content-Type", "application/x-amz-json-1.1") 30 | request.Header.Set("X-Amz-Target", "") 31 | request.Header.Set("User-Agent", "Golang Kinesis") 32 | 33 | request.Header.Set("Date", data.DateHeader) 34 | err = Sign(NewAuth(data.AWS_KEY, data.AWS_SECRET, data.TOKEN), request) 35 | if err != nil { 36 | t.Errorf("Error on sign (%v)", err) 37 | continue 38 | } 39 | if request.Header.Get("Authorization") != data.AuthHeader { 40 | t.Errorf("Get this header (%v), but expect this (%v)", request.Header.Get("Authorization"), data.AuthHeader) 41 | } 42 | } 43 | } 44 | 
-------------------------------------------------------------------------------- /auth_assumerole.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "bytes" 5 | "encoding/xml" 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "net/url" 10 | "time" 11 | ) 12 | 13 | // NewAuthWithAssumedRole will call STS in a given region to assume a role 14 | // stsAuth object is used to authenticate to STS to fetch temporary credentials 15 | // for the desired role. 16 | func NewAuthWithAssumedRole(roleArn, sessionName, region string, stsAuth Auth) (Auth, error) { 17 | return newCachedMutexedWarmedUpAuth(&stsCreds{ 18 | RoleARN: roleArn, 19 | SessionName: sessionName, 20 | Region: region, 21 | STSAuth: stsAuth, 22 | }) 23 | } 24 | 25 | type stsCreds struct { 26 | RoleARN string 27 | SessionName string 28 | Region string 29 | STSAuth Auth 30 | } 31 | 32 | func (sts *stsCreds) ExpiringKeyForSigning(now time.Time) (*SigningKey, time.Time, error) { 33 | r, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://sts.%s.amazonaws.com/?%s", sts.Region, (url.Values{ 34 | "Version": []string{"2011-06-15"}, 35 | "Action": []string{"AssumeRole"}, 36 | "RoleSessionName": []string{sts.SessionName}, 37 | "RoleArn": []string{sts.RoleARN}, 38 | }).Encode()), bytes.NewReader([]byte{})) 39 | if err != nil { 40 | return nil, time.Time{}, err 41 | } 42 | 43 | err = (&Service{ 44 | Name: "sts", 45 | Region: sts.Region, 46 | }).Sign(sts.STSAuth, r) 47 | if err != nil { 48 | return nil, time.Time{}, err 49 | } 50 | 51 | resp, err := http.DefaultClient.Do(r) 52 | if err != nil { 53 | return nil, time.Time{}, err 54 | } 55 | defer resp.Body.Close() 56 | 57 | if resp.StatusCode != http.StatusOK { 58 | return nil, time.Time{}, errors.New("bad status code") 59 | } 60 | 61 | var wrapper struct { 62 | AssumeRoleResult struct { 63 | Credentials struct { 64 | AccessKeyId string 65 | SecretAccessKey string 66 | SessionToken string 67 | 
Expiration time.Time 68 | } 69 | } 70 | } 71 | err = xml.NewDecoder(resp.Body).Decode(&wrapper) 72 | if err != nil { 73 | return nil, time.Time{}, err 74 | } 75 | 76 | // sanity check at least 1 field 77 | if wrapper.AssumeRoleResult.Credentials.SecretAccessKey == "" { 78 | return nil, time.Time{}, errors.New("bad data back") 79 | } 80 | 81 | return &SigningKey{ 82 | AccessKeyId: wrapper.AssumeRoleResult.Credentials.AccessKeyId, 83 | SecretAccessKey: wrapper.AssumeRoleResult.Credentials.SecretAccessKey, 84 | SessionToken: wrapper.AssumeRoleResult.Credentials.SessionToken, 85 | }, wrapper.AssumeRoleResult.Credentials.Expiration, nil 86 | } 87 | -------------------------------------------------------------------------------- /firehose.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | // PutRecordBatchResp stores the information that provides by PutRecordBatch API call 4 | type PutRecordBatchResp struct { 5 | FailedPutCount int 6 | RequestResponses []PutRecordBatchResponses 7 | } 8 | 9 | // RecordBatchResponses stores individual Record information provided by PutRecordBatch API call 10 | type PutRecordBatchResponses struct { 11 | ErrorCode string 12 | ErrorMessage string 13 | RecordId string 14 | } 15 | 16 | type S3DestinationDescriptionResp struct { 17 | BucketARN string 18 | BufferingHints struct { 19 | IntervalInSeconds int 20 | SizeInMBs int 21 | } 22 | CompressionFormat string 23 | EncryptionConfiguration struct { 24 | KMSEncryptionConfig struct { 25 | AWSKMSKeyARN string 26 | } 27 | NoEncryptionConfig string 28 | } 29 | Prefix string 30 | RoleARN string 31 | } 32 | 33 | type RedshiftDestinationDescriptionResp struct { 34 | ClusterJDBCURL string 35 | CopyCommand struct { 36 | CopyOptions string 37 | DataTableColumns string 38 | DataTableName string 39 | } 40 | RoleARN string 41 | S3DestinationDescription S3DestinationDescriptionResp 42 | Username string 43 | } 44 | 45 | type DestinationsResp struct { 
46 | DestinationId string 47 | RedshiftDestinationDescription RedshiftDestinationDescriptionResp 48 | S3DestinationDescription S3DestinationDescriptionResp 49 | } 50 | 51 | // DescribeDeliveryStreamResp stores the information that provides by the Firehose DescribeDeliveryStream API call 52 | type DescribeDeliveryStreamResp struct { 53 | DeliveryStreamDescription struct { 54 | CreateTimestamp float32 55 | DeliveryStreamARN string 56 | DeliveryStreamName string 57 | DeliveryStreamStatus string 58 | Destinations []DestinationsResp 59 | HasMoreDestinations bool 60 | LastUpdatedTimestamp int 61 | VersionId string 62 | } 63 | } 64 | 65 | // http://docs.aws.amazon.com/firehose/latest/APIReference/API_DescribeDeliveryStream.html 66 | func (kinesis *Kinesis) DescribeDeliveryStream(args *RequestArgs) (resp *DescribeDeliveryStreamResp, err error) { 67 | kinesis.Firehose() 68 | params := makeParams("DescribeDeliveryStream") 69 | resp = &DescribeDeliveryStreamResp{} 70 | err = kinesis.query(params, args.params, resp) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return 75 | } 76 | 77 | // http://docs.aws.amazon.com/firehose/latest/APIReference/API_PutRecordBatch.html 78 | func (kinesis *Kinesis) PutRecordBatch(args *RequestArgs) (resp *PutRecordBatchResp, err error) { 79 | kinesis.Firehose() 80 | 81 | params := makeParams("PutRecordBatch") 82 | resp = &PutRecordBatchResp{} 83 | args.Add("Records", args.Records) 84 | err = kinesis.query(params, args.params, resp) 85 | 86 | if err != nil { 87 | return nil, err 88 | } 89 | return 90 | } 91 | -------------------------------------------------------------------------------- /auth_metadata.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | const ( 14 | AWSMetadataServer = "169.254.169.254" 15 | AWSIAMCredsPath = 
"/latest/meta-data/iam/security-credentials" 16 | AWSIAMCredsURL = "http://" + AWSMetadataServer + "/" + AWSIAMCredsPath 17 | ) 18 | 19 | // NewAuthFromMetadata retrieves auth credentials from the metadata 20 | // server. If an IAM role is associated with the instance we are running on, the 21 | // metadata server will expose credentials for that role under a known endpoint. 22 | // 23 | // TODO: specify custom network (connect, read) timeouts, else this will block 24 | // for the default timeout durations. 25 | func NewAuthFromMetadata() (Auth, error) { 26 | return newCachedMutexedWarmedUpAuth(&metadataCreds{}) 27 | } 28 | 29 | type metadataCreds struct{} 30 | 31 | func (mc *metadataCreds) ExpiringKeyForSigning(now time.Time) (*SigningKey, time.Time, error) { 32 | role, err := retrieveIAMRole() 33 | if err != nil { 34 | return nil, time.Time{}, err 35 | } 36 | 37 | data, err := retrieveAWSCredentials(role) 38 | if err != nil { 39 | return nil, time.Time{}, err 40 | } 41 | 42 | expiry, err := time.Parse(time.RFC3339, data["Expiration"]) 43 | if err != nil { 44 | return nil, time.Time{}, err 45 | } 46 | 47 | return &SigningKey{ 48 | AccessKeyId: data["AccessKeyId"], 49 | SecretAccessKey: data["SecretAccessKey"], 50 | SessionToken: data["Token"], 51 | }, expiry, nil 52 | } 53 | 54 | func retrieveAWSCredentials(role string) (map[string]string, error) { 55 | var bodybytes []byte 56 | 57 | client := http.Client{ 58 | Timeout: time.Duration(10 * time.Second), 59 | } 60 | 61 | // Retrieve the json for this role 62 | resp, err := client.Get(fmt.Sprintf("%s/%s", AWSIAMCredsURL, role)) 63 | if err != nil || resp.StatusCode != http.StatusOK { 64 | return nil, err 65 | } 66 | defer resp.Body.Close() 67 | 68 | bodybytes, err = ioutil.ReadAll(resp.Body) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | jsondata := make(map[string]string) 74 | err = json.Unmarshal(bodybytes, &jsondata) 75 | if err != nil { 76 | return nil, err 77 | } 78 | 79 | return jsondata, nil 80 | } 
81 | 82 | func retrieveIAMRole() (string, error) { 83 | var bodybytes []byte 84 | 85 | client := http.Client{ 86 | Timeout: time.Duration(10 * time.Second), 87 | } 88 | 89 | resp, err := client.Get(AWSIAMCredsURL) 90 | if err != nil || resp.StatusCode != http.StatusOK { 91 | return "", err 92 | } 93 | defer resp.Body.Close() 94 | 95 | bodybytes, err = ioutil.ReadAll(resp.Body) 96 | if err != nil { 97 | return "", err 98 | } 99 | 100 | // pick the first IAM role 101 | role := strings.Split(string(bodybytes), "\n")[0] 102 | if len(role) == 0 { 103 | return "", errors.New("Unable to retrieve IAM role") 104 | } 105 | 106 | return role, nil 107 | } 108 | -------------------------------------------------------------------------------- /auth_test.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestGetSecretKey(t *testing.T) { 10 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 11 | sk, _ := auth.KeyForSigning(time.Now()) 12 | if sk.AccessKeyId != "BAD_ACCESS_KEY" { 13 | t.Error("incorrect value for auth#accessKey") 14 | } 15 | } 16 | 17 | func TestGetAccessKey(t *testing.T) { 18 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 19 | sk, _ := auth.KeyForSigning(time.Now()) 20 | if sk.SecretAccessKey != "BAD_SECRET_KEY" { 21 | t.Error("incorrect value for auth#secretKey") 22 | } 23 | } 24 | 25 | func TestGetToken(t *testing.T) { 26 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 27 | sk, _ := auth.KeyForSigning(time.Now()) 28 | if sk.SessionToken != "BAD_SECURITY_TOKEN" { 29 | t.Error("incorrect value for auth#token") 30 | } 31 | } 32 | 33 | func TestNewAuthFromEnv(t *testing.T) { 34 | os.Setenv(AccessEnvKey, "asdf") 35 | os.Setenv(SecretEnvKey, "asdf2") 36 | os.Setenv(SecurityTokenEnvKey, "dummy_token") 37 | // Validate that the fallback environment variables will also 
work 38 | defer os.Unsetenv(AccessEnvKey) 39 | defer os.Unsetenv(SecretEnvKey) 40 | defer os.Unsetenv(SecurityTokenEnvKey) 41 | 42 | auth, _ := NewAuthFromEnv() 43 | sk, _ := auth.KeyForSigning(time.Now()) 44 | 45 | if sk.AccessKeyId != "asdf" { 46 | t.Error("Expected AccessKey to be inferred as \"asdf\"") 47 | } 48 | 49 | if sk.SecretAccessKey != "asdf2" { 50 | t.Error("Expected SecretKey to be inferred as \"asdf2\"") 51 | } 52 | 53 | if sk.SessionToken != "dummy_token" { 54 | t.Error("Expected SecurityToken to be inferred as \"dummy_token\"") 55 | } 56 | } 57 | 58 | func TestNewAuthFromEnvWithoutSecurityToken(t *testing.T) { 59 | os.Setenv(AccessEnvKey, "asdf") 60 | os.Setenv(SecretEnvKey, "asdf2") 61 | os.Unsetenv(SecurityTokenEnvKey) 62 | // Validate that the fallback environment variables will also work 63 | defer os.Unsetenv(AccessEnvKey) 64 | defer os.Unsetenv(SecretEnvKey) 65 | 66 | auth, _ := NewAuthFromEnv() 67 | sk, _ := auth.KeyForSigning(time.Now()) 68 | 69 | if sk.AccessKeyId != "asdf" { 70 | t.Error("Expected AccessKey to be inferred as \"asdf\"") 71 | } 72 | 73 | if sk.SecretAccessKey != "asdf2" { 74 | t.Error("Expected SecretKey to be inferred as \"asdf2\"") 75 | } 76 | 77 | if sk.SessionToken != "" { 78 | t.Error("Expected SecurityToken to be an empty string") 79 | } 80 | } 81 | 82 | func TestNewAuthFromEnvWithoutVars(t *testing.T) { 83 | os.Unsetenv(AccessEnvKey) 84 | os.Unsetenv(SecretEnvKey) 85 | os.Unsetenv(SecurityTokenEnvKey) 86 | 87 | auth, err := NewAuthFromEnv() 88 | 89 | if auth != nil { 90 | t.Error("Expected auth instance to be nil but was non-nil") 91 | } 92 | 93 | if err == nil { 94 | t.Error("Expected error to be non-nil but was nil") 95 | } 96 | } 97 | func TestNewAuthFromEnvWithFallbackVars(t *testing.T) { 98 | os.Setenv(AccessEnvKeyId, "asdf") 99 | os.Setenv(SecretEnvAccessKey, "asdf2") 100 | os.Setenv(SecurityTokenEnvKey, "dummy_token") 101 | defer os.Unsetenv(AccessEnvKey) 102 | defer os.Unsetenv(SecretEnvKey) 103 | defer 
os.Unsetenv(SecurityTokenEnvKey) 104 | 105 | auth, _ := NewAuthFromEnv() 106 | sk, _ := auth.KeyForSigning(time.Now()) 107 | 108 | if sk.AccessKeyId != "asdf" { 109 | t.Error("Expected AccessKey to be inferred as \"asdf\"") 110 | } 111 | 112 | if sk.SecretAccessKey != "asdf2" { 113 | t.Error("Expected SecretKey to be inferred as \"asdf2\"") 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /kinesis-cli/README.md: -------------------------------------------------------------------------------- 1 | # kinesis-cli 2 | 3 | kinesis-cli is a tool for interacting with kinesis from the command line. 4 | 5 | ## Setup 6 | 7 | You can either install the kinesis-cli using `go get` or `go install`: 8 | (TODO: verify this works) 9 | 10 | $ go install github.com/sendgridlabs/go-kinesis/kinesis-cli 11 | 12 | or build it and run it from the kinesis-cli folder: 13 | 14 | ``` 15 | $ go get github.com/sendgridlabs/go-kinesis/kinesis-cli 16 | $ cd $GOPATH/src/github.com/sendgridlabs/go-kinesis/kinesis-cli 17 | $ go build 18 | $ ./kinesis-cli 19 | Usage: ./kinesis-cli [, ...] 20 | (Note: expects $AWS_ACCESS_KEY and $AWS_SECRET_KEY to be set) 21 | Commands: 22 | create [] 23 | delete 24 | describe [ ] 25 | split [] 26 | merge 27 | ``` 28 | 29 | Note that you'll need to store your access/secret key in the proper env vars: 30 | 31 | $ export AWS_ACCESS_KEY=123myaccesskey456; export AWS_SECRET_KEY=789myVerySecretKey432 32 | 33 | ## Usage 34 | 35 | For all commands except `describe`, you will be prompted for confirmation before the aws request is sent. 
36 | 37 | ##### Create a new stream: (only a single shard is created if num shards is not specified) 38 | 39 | $ ./kinesis-cli create somestream 2 40 | 41 | ##### Delete an existing stream: 42 | 43 | $ ./kinesis-cli delete somestream 44 | 45 | ##### Describe a stream: 46 | 47 | ``` 48 | $ ./kinesis-cli describe somestream 49 | { 50 | "StreamDescription": { 51 | "HasMoreShards": false, 52 | "Shards": [ 53 | { 54 | "AdjacentParentShardId": "", 55 | "HashKeyRange": { 56 | "EndingHashKey": "170141183460469231731687303715884105727", 57 | "StartingHashKey": "0" 58 | }, 59 | "ParentShardId": "", 60 | "SequenceNumberRange": { 61 | "EndingSequenceNumber": "", 62 | "StartingSequenceNumber": "49540491727041816751370913972624375777284624614827229185" 63 | }, 64 | "ShardId": "shardId-000000000000" 65 | }, 66 | { 67 | "AdjacentParentShardId": "", 68 | "HashKeyRange": { 69 | "EndingHashKey": "340282366920938463463374607431768211455", 70 | "StartingHashKey": "170141183460469231731687303715884105728" 71 | }, 72 | "ParentShardId": "", 73 | "SequenceNumberRange": { 74 | "EndingSequenceNumber": "", 75 | "StartingSequenceNumber": "49540491727064117496569444595765911495557272976333209617" 76 | }, 77 | "ShardId": "shardId-000000000001" 78 | } 79 | ], 80 | "StreamARN": "arn:aws:kinesis:us-east-1:123456789:stream/somestream", 81 | "StreamName": "somestream", 82 | "StreamStatus": "ACTIVE" 83 | } 84 | } 85 | 86 | ``` 87 | 88 | ##### Split a shard: (it will suggest a new hash key that evenly splits the shard) 89 | 90 | ``` 91 | $ ./kinesis-cli split somestream shardId-000000000000 92 | Shard's current hash key range (0 - 170141183460469231731687303715884105727) 93 | Default (even split) key: 85070591730234615865843651857942052863 94 | Type new key or press [enter] to choose default: 95 | Are you sure you want to split shard shardId-000000000000 at hash key 85070591730234615865843651857942052863? 
96 | (y/N): y 97 | ``` 98 | 99 | ##### Merge two adjacent shards: (must be specified in low->high order) 100 | 101 | $ go build && ./kinesis-cli merge somestream shardId-000000000003 shardId-000000000001 102 | -------------------------------------------------------------------------------- /examples/example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "time" 7 | 8 | // kinesis "github.com/sendgridlabs/go-kinesis" 9 | kinesis "github.com/sendgridlabs/go-kinesis" 10 | ) 11 | 12 | func getRecords(ksis kinesis.KinesisClient, streamName, ShardId string) { 13 | args := kinesis.NewArgs() 14 | args.Add("StreamName", streamName) 15 | args.Add("ShardId", ShardId) 16 | args.Add("ShardIteratorType", "TRIM_HORIZON") 17 | resp10, _ := ksis.GetShardIterator(args) 18 | 19 | shardIterator := resp10.ShardIterator 20 | 21 | for { 22 | args = kinesis.NewArgs() 23 | args.Add("ShardIterator", shardIterator) 24 | resp11, err := ksis.GetRecords(args) 25 | if err != nil { 26 | time.Sleep(1000 * time.Millisecond) 27 | continue 28 | } 29 | 30 | if len(resp11.Records) > 0 { 31 | fmt.Printf("GetRecords Data BEGIN\n") 32 | for _, d := range resp11.Records { 33 | fmt.Printf("GetRecords Data: %v\n", string(d.GetData())) 34 | } 35 | fmt.Printf("GetRecords Data END\n") 36 | } else if resp11.NextShardIterator == "" || shardIterator == resp11.NextShardIterator || err != nil { 37 | fmt.Printf("GetRecords ERROR: %v\n", err) 38 | break 39 | } 40 | 41 | shardIterator = resp11.NextShardIterator 42 | time.Sleep(1000 * time.Millisecond) 43 | } 44 | } 45 | 46 | func main() { 47 | fmt.Println("Begin") 48 | var ( 49 | err error 50 | auth kinesis.Auth 51 | ) 52 | 53 | streamName := "test" 54 | // set env variables AWS_ACCESS_KEY and AWS_SECRET_KEY AWS_REGION_NAME 55 | auth, err = kinesis.NewAuthFromEnv() 56 | if err != nil { 57 | fmt.Printf("Unable to retrieve authentication credentials from the environment: %v", 
err) 58 | os.Exit(1) 59 | } 60 | region := os.Getenv("AWS_REGION_NAME") 61 | ksis := kinesis.New(auth, region) 62 | 63 | err = ksis.CreateStream(streamName, 2) 64 | if err != nil { 65 | fmt.Printf("CreateStream ERROR: %v\n", err) 66 | } 67 | 68 | args := kinesis.NewArgs() 69 | resp2, _ := ksis.ListStreams(args) 70 | fmt.Printf("ListStreams: %v\n", resp2) 71 | 72 | resp3 := &kinesis.DescribeStreamResp{} 73 | 74 | timeout := make(chan bool, 30) 75 | for { 76 | 77 | args = kinesis.NewArgs() 78 | args.Add("StreamName", streamName) 79 | resp3, _ = ksis.DescribeStream(args) 80 | fmt.Printf("DescribeStream: %v\n", resp3) 81 | 82 | if resp3.StreamDescription.StreamStatus != "ACTIVE" { 83 | time.Sleep(4 * time.Second) 84 | timeout <- true 85 | } else { 86 | break 87 | } 88 | 89 | } 90 | 91 | // Put records individually 92 | for i := 0; i < 10; i++ { 93 | args = kinesis.NewArgs() 94 | args.Add("StreamName", streamName) 95 | data := []byte(fmt.Sprintf("Hello AWS Kinesis %d", i)) 96 | partitionKey := fmt.Sprintf("partitionKey-%d", i) 97 | args.AddRecord(data, partitionKey) 98 | resp4, err := ksis.PutRecord(args) 99 | if err != nil { 100 | fmt.Printf("PutRecord err: %v\n", err) 101 | } else { 102 | fmt.Printf("PutRecord: %v\n", resp4) 103 | } 104 | } 105 | 106 | for _, shard := range resp3.StreamDescription.Shards { 107 | go getRecords(ksis, streamName, shard.ShardId) 108 | } 109 | 110 | // Put records in batch 111 | args = kinesis.NewArgs() 112 | args.Add("StreamName", streamName) 113 | 114 | for i := 0; i < 10; i++ { 115 | args.AddRecord( 116 | []byte(fmt.Sprintf("Hello AWS Kinesis %d", i)), 117 | fmt.Sprintf("partitionKey-%d", i), 118 | ) 119 | } 120 | 121 | resp4, err := ksis.PutRecords(args) 122 | if err != nil { 123 | fmt.Printf("PutRecords err: %v\n", err) 124 | } else { 125 | fmt.Printf("PutRecords: %v\n", resp4) 126 | } 127 | 128 | // Wait for user input 129 | var inputGuess string 130 | fmt.Scanf("%s\n", &inputGuess) 131 | 132 | // Delete the stream 133 | err1 := 
ksis.DeleteStream("test") 134 | if err1 != nil { 135 | fmt.Printf("DeleteStream ERROR: %v\n", err1) 136 | } 137 | 138 | fmt.Println("End") 139 | } 140 | -------------------------------------------------------------------------------- /sign.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "bytes" 5 | "crypto/hmac" 6 | "crypto/sha256" 7 | "fmt" 8 | "io" 9 | "io/ioutil" 10 | "net/http" 11 | "net/url" 12 | "path" 13 | "regexp" 14 | "sort" 15 | "strings" 16 | "time" 17 | ) 18 | 19 | const ( 20 | iSO8601BasicFormat = "20060102T150405Z" 21 | iSO8601BasicFormatShort = "20060102" 22 | AWS4_URL = "aws4_request" 23 | ) 24 | 25 | var lf = []byte{'\n'} 26 | var awsKinesisRegexp = regexp.MustCompile(`kinesis.(.*).amazonaws.com`) 27 | 28 | // Service represents an AWS-compatible service. 29 | type Service struct { 30 | // Name is the name of the service being used (i.e. iam, etc) 31 | Name string 32 | 33 | // Region is the region you want to communicate with the service through. (i.e. us-east-1) 34 | Region string 35 | } 36 | 37 | // Sign signs a request with a Service derived from r.Host 38 | func Sign(authKeys Auth, r *http.Request) error { 39 | sv := new(Service) 40 | if awsKinesisRegexp.MatchString(r.Host) { 41 | parts := strings.Split(r.Host, ".") 42 | sv.Name = parts[0] 43 | sv.Region = parts[1] 44 | } 45 | return sv.Sign(authKeys, r) 46 | } 47 | 48 | // Sign signs an HTTP request with the given AWS keys for use on service s. 
49 | func (s *Service) Sign(authKeys Auth, r *http.Request) error { 50 | date := r.Header.Get("Date") 51 | t := time.Now().UTC() 52 | if date != "" { 53 | var err error 54 | t, err = time.Parse(http.TimeFormat, date) 55 | if err != nil { 56 | return err 57 | } 58 | } 59 | r.Header.Set("Date", t.Format(iSO8601BasicFormat)) 60 | 61 | sk, err := authKeys.KeyForSigning(t) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | k := ghmac([]byte("AWS4"+sk.SecretAccessKey), []byte(t.Format(iSO8601BasicFormatShort))) 67 | k = ghmac(k, []byte(s.Region)) 68 | k = ghmac(k, []byte(s.Name)) 69 | k = ghmac(k, []byte(AWS4_URL)) 70 | 71 | h := hmac.New(sha256.New, k) 72 | s.writeStringToSign(h, t, r) 73 | 74 | auth := bytes.NewBufferString("AWS4-HMAC-SHA256 ") 75 | auth.Write([]byte("Credential=" + sk.AccessKeyId + "/" + s.creds(t))) 76 | auth.Write([]byte{',', ' '}) 77 | auth.Write([]byte("SignedHeaders=")) 78 | s.writeHeaderList(auth, r) 79 | auth.Write([]byte{',', ' '}) 80 | auth.Write([]byte("Signature=" + fmt.Sprintf("%x", h.Sum(nil)))) 81 | 82 | r.Header.Set("Authorization", auth.String()) 83 | 84 | if sk.SessionToken != "" { 85 | r.Header.Add(AWSSecurityTokenHeader, sk.SessionToken) 86 | } 87 | 88 | return nil 89 | } 90 | 91 | func (s *Service) writeQuery(w io.Writer, r *http.Request) { 92 | var a []string 93 | for k, vs := range r.URL.Query() { 94 | k = url.QueryEscape(k) 95 | for _, v := range vs { 96 | if v == "" { 97 | a = append(a, k) 98 | } else { 99 | v = url.QueryEscape(v) 100 | a = append(a, k+"="+v) 101 | } 102 | } 103 | } 104 | sort.Strings(a) 105 | for i, s := range a { 106 | if i > 0 { 107 | w.Write([]byte{'&'}) 108 | } 109 | w.Write([]byte(s)) 110 | } 111 | } 112 | 113 | func (s *Service) writeHeader(w io.Writer, r *http.Request) { 114 | i, a := 0, make([]string, len(r.Header)) 115 | for k, v := range r.Header { 116 | sort.Strings(v) 117 | a[i] = strings.ToLower(k) + ":" + strings.Join(v, ",") 118 | i++ 119 | } 120 | sort.Strings(a) 121 | for i, s := range a { 
122 | if i > 0 { 123 | w.Write(lf) 124 | } 125 | io.WriteString(w, s) 126 | } 127 | } 128 | 129 | func (s *Service) writeHeaderList(w io.Writer, r *http.Request) { 130 | i, a := 0, make([]string, len(r.Header)) 131 | for k, _ := range r.Header { 132 | a[i] = strings.ToLower(k) 133 | i++ 134 | } 135 | sort.Strings(a) 136 | for i, s := range a { 137 | if i > 0 { 138 | w.Write([]byte{';'}) 139 | } 140 | w.Write([]byte(s)) 141 | } 142 | } 143 | 144 | func (s *Service) writeBody(w io.Writer, r *http.Request) { 145 | b, err := ioutil.ReadAll(r.Body) 146 | if err != nil { 147 | panic(err) 148 | } 149 | r.Body = ioutil.NopCloser(bytes.NewBuffer(b)) 150 | 151 | h := sha256.New() 152 | h.Write(b) 153 | fmt.Fprintf(w, "%x", h.Sum(nil)) 154 | } 155 | 156 | func (s *Service) writeURI(w io.Writer, r *http.Request) { 157 | ruri := r.URL.RequestURI() 158 | if r.URL.RawQuery != "" { 159 | ruri = ruri[:len(ruri)-len(r.URL.RawQuery)-1] 160 | } 161 | slash := strings.HasSuffix(ruri, "/") 162 | ruri = path.Clean(ruri) 163 | if ruri != "/" && slash { 164 | ruri += "/" 165 | } 166 | w.Write([]byte(ruri)) 167 | } 168 | 169 | func (s *Service) writeRequest(w io.Writer, r *http.Request) { 170 | r.Header.Set("host", r.Host) 171 | 172 | w.Write([]byte(r.Method)) 173 | w.Write(lf) 174 | s.writeURI(w, r) 175 | w.Write(lf) 176 | s.writeQuery(w, r) 177 | w.Write(lf) 178 | s.writeHeader(w, r) 179 | w.Write(lf) 180 | w.Write(lf) 181 | s.writeHeaderList(w, r) 182 | w.Write(lf) 183 | s.writeBody(w, r) 184 | } 185 | 186 | func (s *Service) writeStringToSign(w io.Writer, t time.Time, r *http.Request) { 187 | w.Write([]byte("AWS4-HMAC-SHA256")) 188 | w.Write(lf) 189 | w.Write([]byte(t.Format(iSO8601BasicFormat))) 190 | w.Write(lf) 191 | 192 | w.Write([]byte(s.creds(t))) 193 | w.Write(lf) 194 | 195 | h := sha256.New() 196 | s.writeRequest(h, r) 197 | fmt.Fprintf(w, "%x", h.Sum(nil)) 198 | } 199 | 200 | func (s *Service) creds(t time.Time) string { 201 | return fmt.Sprintf("%s/%s/%s/%s", 
t.Format(iSO8601BasicFormatShort), s.Region, s.Name, AWS4_URL) 202 | } 203 | 204 | func ghmac(key, data []byte) []byte { 205 | h := hmac.New(sha256.New, key) 206 | h.Write(data) 207 | return h.Sum(nil) 208 | } 209 | -------------------------------------------------------------------------------- /kinesis_test.go: -------------------------------------------------------------------------------- 1 | package kinesis 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "strings" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | const localEndpoint = "http://127.0.0.1:4567" 13 | 14 | func TestKinesisClientInterfaceIsImplemented(t *testing.T) { 15 | var client KinesisClient = &Kinesis{} 16 | if client == nil { 17 | t.Error("Invalid nil kinesis client") 18 | } 19 | } 20 | 21 | func TestRegions(t *testing.T) { 22 | os.Setenv(RegionEnvName, "REGION_TEST") 23 | 24 | if NewRegionFromEnv() != "REGION_TEST" { 25 | t.Errorf("Invalid value read from the %s environment variable", RegionEnvName) 26 | } 27 | os.Setenv(RegionEnvName, "") 28 | } 29 | 30 | func TestAddRecord(t *testing.T) { 31 | args := NewArgs() 32 | 33 | args.AddRecord( 34 | []byte("data"), 35 | "partition_key", 36 | ) 37 | 38 | if len(args.Records) != 1 { 39 | t.Errorf("%q != %q", len(args.Records), 1) 40 | } 41 | } 42 | 43 | func TestListStreams(t *testing.T) { 44 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 45 | client := NewWithEndpoint(auth, USEast1, localEndpoint) 46 | resp, err := client.ListStreams(NewArgs()) 47 | if resp == nil { 48 | t.Error("resp == nil") 49 | } 50 | if err != nil { 51 | t.Errorf("%q != nil", err) 52 | } 53 | } 54 | 55 | func TestCreateStream(t *testing.T) { 56 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 57 | client := NewWithEndpoint(auth, USEast1, localEndpoint) 58 | 59 | streamName := "test2" 60 | 61 | err := client.CreateStream(streamName, 1) 62 | if err != nil { 63 | t.Errorf("%q != nil", err) 64 | } 65 | 66 | err = 
waitForStreamStatus(client, streamName, "ACTIVE") 67 | if err != nil { 68 | t.Errorf("%q != nil", err) 69 | } 70 | 71 | client.DeleteStream(streamName) 72 | err = waitForStreamDeletion(client, streamName) 73 | if err != nil { 74 | t.Errorf("%q != nil", err) 75 | } 76 | } 77 | 78 | // Older, lower-level way to use PutRecord 79 | func TestPutRecordWithAddData(t *testing.T) { 80 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 81 | client := NewWithEndpoint(auth, USEast1, localEndpoint) 82 | 83 | streamName := "pizza" 84 | err := createStream(client, streamName, 1) 85 | 86 | if err != nil { 87 | t.Errorf("%q != nil", err) 88 | } 89 | 90 | args := NewArgs() 91 | args.Add("StreamName", streamName) 92 | args.AddData([]byte("The cheese is old and moldy, where is the bathroom?")) 93 | args.Add("PartitionKey", "key-1") 94 | 95 | resp, err := client.PutRecord(args) 96 | if resp == nil { 97 | t.Error("resp == nil") 98 | } 99 | if err != nil { 100 | t.Errorf("%q != nil", err) 101 | } 102 | 103 | client.DeleteStream(streamName) 104 | err = waitForStreamDeletion(client, streamName) 105 | if err != nil { 106 | t.Errorf("%q != nil", err) 107 | } 108 | } 109 | 110 | // Newer, higher-level way to use PutRecord 111 | func TestPutRecordWithAddRecord(t *testing.T) { 112 | auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY", "BAD_SECURITY_TOKEN") 113 | client := NewWithEndpoint(auth, USEast1, localEndpoint) 114 | 115 | streamName := "pizza" 116 | 117 | err := createStream(client, streamName, 1) 118 | if err != nil { 119 | t.Errorf("%q != nil", err) 120 | } 121 | 122 | args := NewArgs() 123 | args.Add("StreamName", streamName) 124 | args.AddRecord([]byte("The cheese is old and moldy, where is the bathroom?"), "key-1") 125 | resp, err := client.PutRecord(args) 126 | 127 | if resp == nil { 128 | t.Error("resp == nil") 129 | } 130 | if err != nil { 131 | t.Errorf("%q != nil", err) 132 | } 133 | 134 | client.DeleteStream(streamName) 135 | err = 
waitForStreamDeletion(client, streamName) 136 | if err != nil { 137 | t.Errorf("%q != nil", err) 138 | } 139 | } 140 | 141 | // waitForStreamStatus will poll for a stream status repeatedly, once every MS, for up to 1000 MS, 142 | // blocking until the stream has the desired status. It will return an error if the stream never 143 | // achieves the desired status. If a stream doesn’t exist then an error will be returned. 144 | func waitForStreamStatus(client KinesisClient, streamName string, statusToAwait string) error { 145 | args := NewArgs() 146 | args.Add("StreamName", streamName) 147 | var resp3 *DescribeStreamResp 148 | var err error 149 | 150 | for i := 1; i < 1000; i++ { 151 | resp3, err = client.DescribeStream(args) 152 | if err != nil { 153 | return err 154 | } 155 | 156 | if resp3.StreamDescription.StreamStatus == statusToAwait { 157 | break 158 | } else { 159 | time.Sleep(1 * time.Millisecond) 160 | } 161 | } 162 | 163 | if resp3 == nil { 164 | return errors.New("Could not get Stream Description") 165 | } 166 | 167 | if resp3.StreamDescription.StreamStatus != statusToAwait { 168 | return errors.New(fmt.Sprintf("Timed out waiting for stream to enter status %v; last status was %v.", statusToAwait, resp3.StreamDescription.StreamStatus)) 169 | } 170 | 171 | return nil 172 | } 173 | 174 | // waitForStreamDeletion will poll for a stream status repeatedly, once every MS, for up to 1000 MS, 175 | // blocking until the stream has been deleted. It will return an error if the stream is never deleted 176 | // or some other error occurs. If it succeeds then the return value will be nil. 
177 | func waitForStreamDeletion(client KinesisClient, streamName string) error { 178 | err := waitForStreamStatus(client, streamName, "FOO") 179 | if !strings.Contains(err.Error(), "not found") { 180 | return err 181 | } 182 | return nil 183 | } 184 | 185 | // helper 186 | func createStream(client KinesisClient, streamName string, partitions int) error { 187 | err := client.CreateStream(streamName, partitions) 188 | if err != nil { 189 | return err 190 | } 191 | 192 | err = waitForStreamStatus(client, streamName, "ACTIVE") 193 | if err != nil { 194 | return err 195 | } 196 | 197 | return nil 198 | } 199 | -------------------------------------------------------------------------------- /kinesis-cli/kinesis-cli.go: -------------------------------------------------------------------------------- 1 | /* 2 | kinesis-cli is a command line interface tool for interacting with AWS kinesis. 3 | 4 | To install: 5 | go get github.com/sendgridlabs/go-kinesis/kinesis-cli 6 | 7 | To build: 8 | cd $GOPATH/src/github.com/sendgridlabs/go-kinesis/kinesis-cli; go build 9 | 10 | To use: 11 | run ./kinesis-cli to see the usage. 12 | */ 13 | 14 | package main 15 | 16 | import ( 17 | "bufio" 18 | "encoding/json" 19 | "fmt" 20 | "math/big" 21 | "os" 22 | "strconv" 23 | "strings" 24 | 25 | // "github.com/sendgridlabs/go-kinesis" 26 | "github.com/sendgridlabs/go-kinesis" 27 | ) 28 | 29 | const HELP = `Usage: ./kinesis-cli [, ...] 
30 | (Note: expects $AWS_ACCESS_KEY, $AWS_SECRET_KEY and $AWS_REGION_NAME to be set) 31 | Commands: 32 | create [<# shards>] 33 | delete 34 | describe [ ] 35 | split [] 36 | merge 37 | 38 | ` 39 | 40 | var EMPTY_STRING = "" 41 | var EMPTY_INT = -1 42 | var DEFAULT_NUM_SHARDS = 1 43 | 44 | func create(args []string) { 45 | streamName := getArg(args, 0, "stream name", nil) 46 | numShards := getIntArg(args, 1, "stream name", &DEFAULT_NUM_SHARDS) 47 | if !confirm(fmt.Sprintf("create stream '%s' with %d shard(s)", streamName, numShards)) { 48 | fmt.Println("Create canceled.") 49 | return 50 | } 51 | if err := newClient().CreateStream(streamName, numShards); err != nil { 52 | die(false, "Error creating shard: %s", err) 53 | } 54 | } 55 | 56 | func delete(args []string) { 57 | streamName := getArg(args, 0, "stream name", nil) 58 | if !confirm("delete stream '" + streamName + "'") { 59 | fmt.Println("Delete canceled.") 60 | return 61 | } 62 | if err := newClient().DeleteStream(streamName); err != nil { 63 | die(false, "Error deleting shard: %s", err) 64 | } 65 | } 66 | 67 | func describe(args []string) { 68 | streamName := getArg(args, 0, "stream name", nil) 69 | exclusiveStartShardId := getArg(args, 1, "exclusive start shard id", &EMPTY_STRING) 70 | limit := getIntArg(args, 2, "limit", &EMPTY_INT) 71 | streamDesc := describeStream(streamName, exclusiveStartShardId, limit) 72 | 73 | prettyBytes, err := json.MarshalIndent(streamDesc, "", " ") 74 | if err != nil { 75 | die(false, "Error marshaling response: %s", err) 76 | } 77 | fmt.Println(string(prettyBytes)) 78 | } 79 | 80 | func split(args []string) { 81 | streamName := getArg(args, 0, "stream name", nil) 82 | shardId := getArg(args, 1, "shard id", nil) 83 | newStartHash := getArg(args, 2, "starting hash", &EMPTY_STRING) 84 | if newStartHash == "" { 85 | newStartHash = askForShardStartHash(streamName, shardId) 86 | } 87 | if !confirm(fmt.Sprintf("split shard %s at hash key %s", shardId, newStartHash)) { 88 | 
fmt.Println("Split canceled.") 89 | return 90 | } 91 | requestArgs := kinesis.NewArgs() 92 | requestArgs.Add("StreamName", streamName) 93 | requestArgs.Add("ShardToSplit", shardId) 94 | requestArgs.Add("NewStartingHashKey", newStartHash) 95 | if err := newClient().SplitShard(requestArgs); err != nil { 96 | die(false, "Error splitting shard: %s", err) 97 | } 98 | } 99 | 100 | func merge(args []string) { 101 | streamName := getArg(args, 0, "stream name", nil) 102 | shardId := getArg(args, 1, "shard id", nil) 103 | adjacentShardId := getArg(args, 2, "adjacent shard id", nil) 104 | requestArgs := kinesis.NewArgs() 105 | requestArgs.Add("StreamName", streamName) 106 | requestArgs.Add("ShardToMerge", shardId) 107 | requestArgs.Add("AdjacentShardToMerge", adjacentShardId) 108 | if !confirm(fmt.Sprintf("merge shards %s and %s", shardId, adjacentShardId)) { 109 | fmt.Println("Merge canceled.") 110 | return 111 | } 112 | if err := newClient().MergeShards(requestArgs); err != nil { 113 | die(false, "Error merging shards: %s", err) 114 | } 115 | } 116 | 117 | func main() { 118 | if len(os.Args) < 2 { 119 | die(true, "Error: no command specified.") 120 | } 121 | if os.Getenv(kinesis.AccessEnvKey) == "" || 122 | os.Getenv(kinesis.SecretEnvKey) == "" { 123 | fmt.Printf("WARNING: %s and/or %s environment variables not set. 
Will "+ 124 | "attempt to fetch credentials from metadata server.\n", 125 | kinesis.AccessEnvKey, kinesis.SecretEnvKey) 126 | } 127 | if os.Getenv(kinesis.RegionEnvName) == "" { 128 | fmt.Printf("WARNING: %s not set.\n", kinesis.RegionEnvName) 129 | } 130 | switch os.Args[1] { 131 | case "create": 132 | create(os.Args[2:]) 133 | case "delete": 134 | delete(os.Args[2:]) 135 | case "describe": 136 | describe(os.Args[2:]) 137 | case "split": 138 | split(os.Args[2:]) 139 | case "merge": 140 | merge(os.Args[2:]) 141 | default: 142 | die(true, "Error: unknown command '%s'", os.Args[1]) 143 | } 144 | } 145 | 146 | // 147 | // Command line helper functions 148 | // 149 | 150 | func die(printHelp bool, format string, args ...interface{}) { 151 | if printHelp { 152 | fmt.Print(HELP) 153 | } 154 | fmt.Printf(format, args...) 155 | fmt.Println("") 156 | os.Exit(1) 157 | } 158 | 159 | func confirm(action string) bool { 160 | prompt := fmt.Sprintf("Are you sure you want to %s?\n[y/N]: ", action) 161 | s := readString(prompt, "") 162 | return strings.ToLower(s) == "y" 163 | } 164 | 165 | func readString(prompt string, defaultStr string) string { 166 | fmt.Print(prompt) 167 | reader := bufio.NewReader(os.Stdin) 168 | result, err := reader.ReadString('\n') 169 | if err != nil { 170 | die(false, "Error reading input: %s", err) 171 | } 172 | if result = strings.TrimSpace(result); result == "" { 173 | return defaultStr 174 | } 175 | return result 176 | } 177 | 178 | func getArg(specified []string, index int, name string, def *string) string { 179 | if index < len(specified) { 180 | return specified[index] 181 | } 182 | if def == nil { 183 | die(true, "Error: %s is required.", name) 184 | } 185 | return *def 186 | } 187 | 188 | func getIntArg(specified []string, index int, name string, def *int) int { 189 | var argStr string 190 | if def != nil { 191 | defStr := strconv.Itoa(*def) 192 | argStr = getArg(specified, index, name, &defStr) 193 | } else { 194 | argStr = getArg(specified, 
index, name, nil) 195 | } 196 | intArg, err := strconv.Atoi(argStr) 197 | if err != nil { 198 | die(false, "Error parsing %s as integer: %s\n%s", name, argStr, err) 199 | } 200 | return intArg 201 | } 202 | 203 | // 204 | // Big int (for 128-bit start/end hash keys) helper functions 205 | // 206 | 207 | // Takes two large base 10 numeric strings a and b, and returns (a + b)/2 208 | func getMiddle(lowStr, highStr string) *big.Int { 209 | low := bigIntFromStr(lowStr, 10) 210 | high := bigIntFromStr(highStr, 10) 211 | if low.Cmp(high) != -1 { 212 | die(false, "Error: %s is not smaller than %s", lowStr, highStr) 213 | } 214 | middle := new(big.Int) 215 | middle = middle.Div(middle.Add(low, high), big.NewInt(2)) 216 | return middle 217 | } 218 | 219 | // Takes two large base 10 numeric strings low and high and returns low < x < high. 220 | func isBetween(lowStr, highStr, xStr string) bool { 221 | low := bigIntFromStr(lowStr, 10) 222 | high := bigIntFromStr(highStr, 10) 223 | x := bigIntFromStr(xStr, 10) 224 | return x.Cmp(low) == 1 && x.Cmp(high) == -1 225 | } 226 | 227 | func bigIntFromStr(s string, base int) *big.Int { 228 | result := new(big.Int) 229 | result, success := result.SetString(s, 10) 230 | if !success { 231 | die(false, "Error: cannot create big int from string '%s'", s) 232 | } 233 | return result 234 | } 235 | 236 | // 237 | // Kinesis helper functions 238 | // 239 | 240 | func newClient() kinesis.KinesisClient { 241 | auth, _ := kinesis.NewAuthFromEnv() 242 | return kinesis.New(auth, kinesis.NewRegionFromEnv()) 243 | } 244 | 245 | func askForShardStartHash(streamName, shardId string) string { 246 | // Figure out a sensible default value for a split hash key. 
247 | shardDesc := describeShard(streamName, shardId) 248 | if shardDesc == nil { 249 | die(false, "Error: No shard found with id %s", shardId) 250 | } 251 | existingStart, existingEnd := shardDesc.HashKeyRange.StartingHashKey, shardDesc.HashKeyRange.EndingHashKey 252 | newStartHash := getMiddle(existingStart, existingEnd).String() 253 | 254 | prompt := fmt.Sprintf("Shard's current hash key range (%s - %s)\nDefault (even split) key: %s\nType new key or press [enter] to choose default: ", 255 | existingStart, existingEnd, newStartHash) 256 | newStartHash = readString(prompt, newStartHash) 257 | if !isBetween(existingStart, existingEnd, newStartHash) { 258 | die(false, "New starting hash '%s' is not within shard's current range.", newStartHash) 259 | } 260 | return newStartHash 261 | } 262 | 263 | func describeShard(streamName, shardId string) *kinesis.DescribeStreamShards { 264 | describeResponse := describeStream(streamName, "", -1) 265 | for _, shard := range describeResponse.StreamDescription.Shards { 266 | if shard.ShardId == shardId { 267 | return &shard 268 | } 269 | } 270 | return nil 271 | } 272 | 273 | func describeStream(streamName, exclusiveStartShardId string, limit int) *kinesis.DescribeStreamResp { 274 | var response *kinesis.DescribeStreamResp 275 | done := false 276 | for !done { 277 | requestArgs := kinesis.NewArgs() 278 | requestArgs.Add("StreamName", streamName) 279 | if exclusiveStartShardId != "" { 280 | requestArgs.Add("ExclusiveStartShardId", exclusiveStartShardId) 281 | } 282 | if limit > 0 { 283 | requestArgs.Add("Limit", limit) 284 | } 285 | curResponse, err := newClient().DescribeStream(requestArgs) 286 | if err != nil { 287 | die(false, "Error describing stream: %s", err) 288 | } 289 | if response == nil { 290 | response = curResponse 291 | } else { 292 | shards := response.StreamDescription.Shards 293 | for _, shard := range curResponse.StreamDescription.Shards { 294 | shards = append(shards, shard) 295 | if len(shards) >= limit { 296 | 
done = true 297 | break 298 | } 299 | exclusiveStartShardId = shard.ShardId 300 | limit-- 301 | } 302 | } 303 | done = done || !response.StreamDescription.HasMoreShards 304 | } 305 | return response 306 | } 307 | -------------------------------------------------------------------------------- /kinesis.go: -------------------------------------------------------------------------------- 1 | // Package kinesis provide GOlang API for http://aws.amazon.com/kinesis/ 2 | package kinesis 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io/ioutil" 10 | "net/http" 11 | "os" 12 | "sync" 13 | ) 14 | 15 | const ( 16 | ActionKey = "Action" 17 | RegionEnvName = "AWS_REGION_NAME" 18 | 19 | // Regions 20 | USEast1 = "us-east-1" 21 | USWest2 = "us-west-2" 22 | EUWest1 = "eu-west-1" 23 | EUCentral1 = "eu-central-1" 24 | APSouthEast1 = "ap-southeast-1" 25 | APSouthEast2 = "ap-southeast-2" 26 | APNortheast1 = "ap-northeast-1" 27 | 28 | KinesisVersion = "20131202" 29 | FirehoseVersion = "20150804" 30 | 31 | kinesisURL = "https://kinesis.%s.amazonaws.com" 32 | firehoseURL = "https://firehose.%s.amazonaws.com" 33 | ) 34 | 35 | // NewRegionFromEnv creates a region from the an expected environment variable 36 | func NewRegionFromEnv() string { 37 | return os.Getenv(RegionEnvName) 38 | } 39 | 40 | // Structure for kinesis client 41 | type Kinesis struct { 42 | client *Client 43 | endpoint string 44 | region string 45 | version string 46 | streamType string 47 | 48 | typeMu sync.Mutex 49 | versionMu sync.Mutex 50 | endpointMu sync.Mutex 51 | } 52 | 53 | // KinesisClient interface implemented by Kinesis 54 | type KinesisClient interface { 55 | CreateStream(StreamName string, ShardCount int) error 56 | DeleteStream(StreamName string) error 57 | DescribeStream(args *RequestArgs) (resp *DescribeStreamResp, err error) 58 | DescribeDeliveryStream(args *RequestArgs) (resp *DescribeDeliveryStreamResp, err error) 59 | GetRecords(args *RequestArgs) (resp *GetRecordsResp, err 
error) 60 | GetShardIterator(args *RequestArgs) (resp *GetShardIteratorResp, err error) 61 | ListStreams(args *RequestArgs) (resp *ListStreamsResp, err error) 62 | MergeShards(args *RequestArgs) error 63 | PutRecord(args *RequestArgs) (resp *PutRecordResp, err error) 64 | PutRecords(args *RequestArgs) (resp *PutRecordsResp, err error) 65 | PutRecordBatch(args *RequestArgs) (resp *PutRecordBatchResp, err error) 66 | SplitShard(args *RequestArgs) error 67 | } 68 | 69 | // New returns an initialized AWS Kinesis client using the canonical live “production” endpoint 70 | // for AWS Kinesis, i.e. https://kinesis.{region}.amazonaws.com 71 | func New(auth Auth, region string) *Kinesis { 72 | endpoint := fmt.Sprintf(kinesisURL, region) 73 | return NewWithEndpoint(auth, region, endpoint) 74 | } 75 | 76 | // NewWithClient returns an initialized AWS Kinesis client using the canonical live “production” endpoint 77 | // for AWS Kinesis, i.e. https://kinesis.{region}.amazonaws.com but with the ability to create a custom client 78 | // with specific configurations like a timeout 79 | func NewWithClient(region string, client *Client) *Kinesis { 80 | endpoint := fmt.Sprintf(kinesisURL, region) 81 | return &Kinesis{client: client, version: KinesisVersion, region: region, endpoint: endpoint, streamType: "Kinesis"} 82 | } 83 | 84 | // NewWithEndpoint returns an initialized AWS Kinesis client using the specified endpoint. 85 | // This is generally useful for testing, so a local Kinesis server can be used. 86 | func NewWithEndpoint(auth Auth, region, endpoint string) *Kinesis { 87 | // TODO: remove trailing slash on endpoint if there is one? does it matter? 88 | // TODO: validate endpoint somehow? 
89 | return &Kinesis{client: NewClient(auth), version: KinesisVersion, region: region, endpoint: endpoint, streamType: "Kinesis"} 90 | } 91 | 92 | // Create params object for request 93 | func makeParams(action string) map[string]string { 94 | params := make(map[string]string) 95 | params[ActionKey] = action 96 | return params 97 | } 98 | 99 | // RequestArgs store params for request 100 | type RequestArgs struct { 101 | params map[string]interface{} 102 | Records []Record 103 | } 104 | 105 | // NewArgs creates a new Filter. 106 | func NewArgs() *RequestArgs { 107 | return &RequestArgs{ 108 | params: make(map[string]interface{}), 109 | } 110 | } 111 | 112 | // Add appends a filtering parameter with the given name and value(s). 113 | func (f *RequestArgs) Add(name string, value interface{}) { 114 | f.params[name] = value 115 | } 116 | 117 | func (f *RequestArgs) AddData(value []byte) { 118 | f.params["Data"] = value 119 | } 120 | 121 | // Error represent error from Kinesis API 122 | type Error struct { 123 | // HTTP status code (200, 403, ...) 124 | StatusCode int 125 | // error code ("UnsupportedOperation", ...) 126 | Code string 127 | // The human-oriented error message 128 | Message string 129 | RequestId string 130 | } 131 | 132 | // Error returns error message from error object 133 | func (err *Error) Error() string { 134 | if err.Code == "" { 135 | return err.Message 136 | } 137 | return fmt.Sprintf("%s (%s)", err.Message, err.Code) 138 | } 139 | 140 | type jsonErrors struct { 141 | Code string `json:"__type"` 142 | Message string 143 | } 144 | 145 | func buildError(r *http.Response) error { 146 | // Reading the body into a []byte because we might need to put it into an error 147 | // message after having the JSON decoding fail to produce a message. 
148 | body, ioerr := ioutil.ReadAll(r.Body) 149 | if ioerr != nil { 150 | return fmt.Errorf("Could not read response body: %s", ioerr) 151 | } 152 | 153 | errors := jsonErrors{} 154 | json.NewDecoder(bytes.NewReader(body)).Decode(&errors) 155 | 156 | var err Error 157 | err.Message = errors.Message 158 | err.Code = errors.Code 159 | err.StatusCode = r.StatusCode 160 | if err.Message == "" { 161 | err.Message = fmt.Sprintf("%s: %s", r.Status, body) 162 | } 163 | return &err 164 | } 165 | 166 | func (k *Kinesis) getStreamType() string { 167 | k.typeMu.Lock() 168 | defer k.typeMu.Unlock() 169 | return k.streamType 170 | } 171 | 172 | func (k *Kinesis) setStreamType(streamType string) { 173 | k.typeMu.Lock() 174 | k.streamType = streamType 175 | k.typeMu.Unlock() 176 | } 177 | 178 | func (k *Kinesis) getVersion() string { 179 | k.versionMu.Lock() 180 | defer k.versionMu.Unlock() 181 | return k.version 182 | } 183 | 184 | func (k *Kinesis) setVersion(version string) { 185 | k.versionMu.Lock() 186 | k.version = version 187 | k.versionMu.Unlock() 188 | } 189 | 190 | func (k *Kinesis) getEndpoint() string { 191 | k.endpointMu.Lock() 192 | defer k.endpointMu.Unlock() 193 | return k.endpoint 194 | } 195 | 196 | func (k *Kinesis) setEndpoint(endpoint string) { 197 | k.endpointMu.Lock() 198 | k.endpoint = endpoint 199 | k.endpointMu.Unlock() 200 | } 201 | 202 | func (k *Kinesis) Firehose() { 203 | k.setStreamType("Firehose") 204 | k.setVersion(FirehoseVersion) 205 | k.setEndpoint(fmt.Sprintf(firehoseURL, k.region)) 206 | } 207 | 208 | // Query by AWS API 209 | func (kinesis *Kinesis) query(params map[string]string, data interface{}, resp interface{}) error { 210 | jsonData, err := json.Marshal(data) 211 | if err != nil { 212 | return err 213 | } 214 | 215 | // request 216 | request, err := http.NewRequest( 217 | "POST", 218 | kinesis.getEndpoint(), 219 | bytes.NewReader(jsonData), 220 | ) 221 | 222 | if err != nil { 223 | return err 224 | } 225 | 226 | // headers 227 | 
request.Header.Set("Content-Type", "application/x-amz-json-1.1") 228 | request.Header.Set("X-Amz-Target", fmt.Sprintf("%s_%s.%s", kinesis.getStreamType(), kinesis.getVersion(), params[ActionKey])) 229 | request.Header.Set("User-Agent", "Golang Kinesis") 230 | 231 | // response 232 | response, err := kinesis.client.Do(request) 233 | if err != nil { 234 | return err 235 | } 236 | defer response.Body.Close() 237 | 238 | if response.StatusCode != 200 { 239 | return buildError(response) 240 | } 241 | 242 | if resp == nil { 243 | return nil 244 | } 245 | 246 | return json.NewDecoder(response.Body).Decode(resp) 247 | } 248 | 249 | // CreateStream adds a new Amazon Kinesis stream to your AWS account 250 | // StreamName is a name of stream, ShardCount is number of shards 251 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_CreateStream.html 252 | func (kinesis *Kinesis) CreateStream(StreamName string, ShardCount int) error { 253 | params := makeParams("CreateStream") 254 | requestParams := struct { 255 | StreamName string 256 | ShardCount int 257 | }{ 258 | StreamName, 259 | ShardCount, 260 | } 261 | err := kinesis.query(params, requestParams, nil) 262 | if err != nil { 263 | return err 264 | } 265 | return nil 266 | } 267 | 268 | // DeleteStream deletes a stream and all of its shards and data from your AWS account 269 | // StreamName is a name of stream 270 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_DeleteStream.html 271 | func (kinesis *Kinesis) DeleteStream(StreamName string) error { 272 | params := makeParams("DeleteStream") 273 | requestParams := struct { 274 | StreamName string 275 | }{ 276 | StreamName, 277 | } 278 | err := kinesis.query(params, requestParams, nil) 279 | if err != nil { 280 | return err 281 | } 282 | return nil 283 | } 284 | 285 | // MergeShards merges two adjacent shards in a stream and combines them into a single shard to reduce the stream's capacity to ingest and transport data 286 | // more 
info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_MergeShards.html 287 | func (kinesis *Kinesis) MergeShards(args *RequestArgs) error { 288 | params := makeParams("MergeShards") 289 | err := kinesis.query(params, args.params, nil) 290 | if err != nil { 291 | return err 292 | } 293 | return nil 294 | } 295 | 296 | // SplitShard splits a shard into two new shards in the stream, to increase the stream's capacity to ingest and transport data 297 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_SplitShard.html 298 | func (kinesis *Kinesis) SplitShard(args *RequestArgs) error { 299 | params := makeParams("SplitShard") 300 | err := kinesis.query(params, args.params, nil) 301 | if err != nil { 302 | return err 303 | } 304 | return nil 305 | } 306 | 307 | // ListStreamsResp stores the information that provides by ListStreams API call 308 | type ListStreamsResp struct { 309 | HasMoreStreams bool 310 | StreamNames []string 311 | } 312 | 313 | // ListStreams returns an array of the names of all the streams that are associated with the AWS account making the ListStreams request 314 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_ListStreams.html 315 | func (kinesis *Kinesis) ListStreams(args *RequestArgs) (resp *ListStreamsResp, err error) { 316 | params := makeParams("ListStreams") 317 | resp = &ListStreamsResp{} 318 | err = kinesis.query(params, args.params, resp) 319 | if err != nil { 320 | return nil, err 321 | } 322 | return 323 | } 324 | 325 | // DescribeStreamShards stores the information about list of shards inside DescribeStreamResp 326 | type DescribeStreamShards struct { 327 | AdjacentParentShardId string 328 | HashKeyRange struct { 329 | EndingHashKey string 330 | StartingHashKey string 331 | } 332 | ParentShardId string 333 | SequenceNumberRange struct { 334 | EndingSequenceNumber string 335 | StartingSequenceNumber string 336 | } 337 | ShardId string 338 | } 339 | 340 | // DescribeStreamResp stores 
the information that provides by DescribeStream API call 341 | type DescribeStreamResp struct { 342 | StreamDescription struct { 343 | HasMoreShards bool 344 | Shards []DescribeStreamShards 345 | StreamARN string 346 | StreamName string 347 | StreamStatus string 348 | } 349 | } 350 | 351 | // DescribeStream returns the following information about the stream: the current status of the stream, 352 | // the stream Amazon Resource Name (ARN), and an array of shard objects that comprise the stream. 353 | // For each shard object there is information about the hash key and sequence number ranges that 354 | // the shard spans, and the IDs of any earlier shards that played in a role in a MergeShards or 355 | // SplitShard operation that created the shard 356 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_DescribeStream.html 357 | func (kinesis *Kinesis) DescribeStream(args *RequestArgs) (resp *DescribeStreamResp, err error) { 358 | params := makeParams("DescribeStream") 359 | resp = &DescribeStreamResp{} 360 | err = kinesis.query(params, args.params, resp) 361 | if err != nil { 362 | return nil, err 363 | } 364 | return 365 | } 366 | 367 | // GetShardIteratorResp stores the information that provides by GetShardIterator API call 368 | type GetShardIteratorResp struct { 369 | ShardIterator string 370 | } 371 | 372 | // GetShardIterator returns a shard iterator 373 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_GetShardIterator.html 374 | func (kinesis *Kinesis) GetShardIterator(args *RequestArgs) (resp *GetShardIteratorResp, err error) { 375 | params := makeParams("GetShardIterator") 376 | resp = &GetShardIteratorResp{} 377 | err = kinesis.query(params, args.params, resp) 378 | if err != nil { 379 | return nil, err 380 | } 381 | return 382 | } 383 | 384 | // GetNextRecordsRecords stores the information that provides by GetNextRecordsResp 385 | type GetRecordsRecords struct { 386 | ApproximateArrivalTimestamp float64 387 | 
Data            []byte
	PartitionKey    string
	SequenceNumber  string
}

// GetData returns the raw record payload.
func (r GetRecordsRecords) GetData() []byte {
	return r.Data
}

// GetRecordsResp stores the information provided by the GetRecords API call.
type GetRecordsResp struct {
	MillisBehindLatest int64
	NextShardIterator  string
	Records            []GetRecordsRecords
}

// GetRecords returns one or more data records from a shard.
// More info: http://docs.aws.amazon.com/kinesis/latest/APIReference/API_GetRecords.html
func (kinesis *Kinesis) GetRecords(args *RequestArgs) (resp *GetRecordsResp, err error) {
	params := makeParams("GetRecords")
	resp = &GetRecordsResp{}
	if err = kinesis.query(params, args.params, resp); err != nil {
		return nil, err
	}
	return resp, nil
}

// PutRecordResp stores the information provided by the PutRecord API call.
type PutRecordResp struct {
	SequenceNumber string
	ShardId        string
}

// PutRecord puts a data record into an Amazon Kinesis stream from a producer.
// args must contain a single record added with AddRecord.
423 | // More info: http://docs.aws.amazon.com/kinesis/latest/APIReference/API_PutRecord.html 424 | func (kinesis *Kinesis) PutRecord(args *RequestArgs) (resp *PutRecordResp, err error) { 425 | params := makeParams("PutRecord") 426 | 427 | if _, ok := args.params["Data"]; !ok && len(args.Records) == 0 { 428 | return nil, errors.New("PutRecord requires its args param to contain a record added with either AddRecord or AddData.") 429 | } else if ok && len(args.Records) > 0 { 430 | return nil, errors.New("PutRecord requires its args param to contain a record added with either AddRecord or AddData but not both.") 431 | } else if len(args.Records) > 1 { 432 | return nil, errors.New("PutRecord does not support more than one record.") 433 | } 434 | 435 | if len(args.Records) > 0 { 436 | args.AddData(args.Records[0].Data) 437 | args.Add("PartitionKey", args.Records[0].PartitionKey) 438 | } 439 | 440 | resp = &PutRecordResp{} 441 | err = kinesis.query(params, args.params, resp) 442 | if err != nil { 443 | return nil, err 444 | } 445 | return 446 | } 447 | 448 | // PutRecords puts multiple data records into an Amazon Kinesis stream from a producer 449 | // more info http://docs.aws.amazon.com/kinesis/latest/APIReference/API_PutRecords.html 450 | func (kinesis *Kinesis) PutRecords(args *RequestArgs) (resp *PutRecordsResp, err error) { 451 | params := makeParams("PutRecords") 452 | resp = &PutRecordsResp{} 453 | args.Add("Records", args.Records) 454 | err = kinesis.query(params, args.params, resp) 455 | 456 | if err != nil { 457 | return nil, err 458 | } 459 | return 460 | } 461 | 462 | // PutRecordsResp stores the information that provides by PutRecord API call 463 | type PutRecordsResp struct { 464 | FailedRecordCount int 465 | Records []PutRecordsRespRecord 466 | } 467 | 468 | // RecordResp stores individual Record information provided by PutRecords API call 469 | type PutRecordsRespRecord struct { 470 | ErrorCode string 471 | ErrorMessage string 472 | SequenceNumber string 
473 | ShardId string 474 | } 475 | 476 | // AddRecord adds data and partition for sending multiple Records to Kinesis in one API call 477 | func (f *RequestArgs) AddRecord(value []byte, partitionKey string) { 478 | r := Record{ 479 | Data: value, 480 | PartitionKey: partitionKey, 481 | } 482 | f.Records = append(f.Records, r) 483 | } 484 | 485 | // Record stores the Data and PartitionKey for PutRecord or PutRecords calls to Kinesis API 486 | type Record struct { 487 | Data []byte 488 | PartitionKey string 489 | } 490 | -------------------------------------------------------------------------------- /batchproducer/batchproducer.go: -------------------------------------------------------------------------------- 1 | package batchproducer 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "os" 7 | "sync" 8 | "time" 9 | 10 | "github.com/sendgridlabs/go-kinesis" 11 | ) 12 | 13 | // MaxKinesisBatchSize is the maximum number of records that Kinesis accepts in a request 14 | const MaxKinesisBatchSize = 500 15 | 16 | // Producer collects records individually and then sends them to Kinesis in 17 | // batches in the background using PutRecords, with retries. 18 | // A Producer will do nothing until Start is called. 19 | type Producer interface { 20 | // Start starts the main goroutine. No need to call it using `go`. 21 | Start() error 22 | 23 | // Stop signals the main goroutine to finish. Once this is called, Add will immediately start 24 | // returning errors (unless and until Start is called again). 25 | Stop() error 26 | 27 | // Add might block if the BatchProducer has a buffer and the buffer is full. 28 | // In order to prevent filling the buffer and eventually blocking indefinitely, 29 | // Add will fail and return an error if the BatchProducer is stopped or stopping. Note 30 | // that it’s critical to check the return value because the BatchProducer could have 31 | // died in the background due to a panic (or something). 
32 | Add(data []byte, partitionKey string) error 33 | 34 | // Flush stops the Producer using Stop and attempts to send all buffered records to Kinesis as 35 | // fast as possible with batches of size 500 (the maximum). It blocks until either all records 36 | // are sent or the timeout expires. It returns the number of records still remaining in the 37 | // buffer or (possibly) an error. (It doesn’t currently return errors but that is in the 38 | // signature for future-proofing.) A timeout value of 0 means no timeout. 39 | // If Flush finishes sending all records without timing out, and sendStats is true, it will 40 | // cause a single final StatsBatch to be sent to the StatsReceiver in Config, if set. 41 | Flush(timeout time.Duration, sendStats bool) (sent int, remaining int, err error) 42 | } 43 | 44 | // StatReceiver defines an object that can accept stats. 45 | type StatReceiver interface { 46 | // Receive will be called by the main Producer goroutine so it will block all batches from being 47 | // sent, so make sure it is either very fast or never blocks at all! 48 | Receive(StatsBatch) 49 | } 50 | 51 | // StatsBatch is a kind of a snapshot of activity and happenings. Some of its fields represent 52 | // "moment-in-time" values e.g. BufferSize is the size of the buffer at the moment the StatsBatch 53 | // is sent. Other fields are cumulative since the last StatsBatch, i.e. ErrorsSinceLastStat. 54 | type StatsBatch struct { 55 | // Moment-in-time stats 56 | BufferSize int 57 | 58 | // Cumulative stats 59 | KinesisErrorsSinceLastStat int 60 | RecordsSentSuccessfullySinceLastStat int 61 | RecordsDroppedSinceLastStat int 62 | } 63 | 64 | // BatchingKinesisClient is a subset of KinesisClient to ease mocking. 
type BatchingKinesisClient interface {
	PutRecords(args *kinesis.RequestArgs) (resp *kinesis.PutRecordsResp, err error)
}

// BatchProducerLogger is the minimal logging interface the Producer needs;
// *log.Logger satisfies it (see DefaultConfig).
type BatchProducerLogger interface {
	Printf(format string, args ...interface{})
}

// Config is a collection of config values for a Producer
type Config struct {
	// AddBlocksWhenBufferFull controls the behavior of Add when the buffer is full. If true, Add
	// will block. If false, Add will return an error. This enables integrating applications to
	// decide how they want to handle a full buffer e.g. so they can discard records if there’s
	// a problem.
	AddBlocksWhenBufferFull bool

	// BatchSize controls the maximum size of the batches sent to Kinesis. If the number of records
	// in the buffer hits this size, a batch of this size will be sent at that time, regardless of
	// whether FlushInterval has a value or not.
	BatchSize int

	// BufferSize is the size of the buffer that stores records before they are sent to the Kinesis
	// stream. If when Add is called the number of records in the buffer is >= bufferSize then
	// Add will either block or return an error, depending on the value of AddBlocksWhenBufferFull.
	BufferSize int

	// FlushInterval controls how often the buffer is flushed to Kinesis. If nonzero, then every
	// time this interval occurs, if there are any records in the buffer, they will be flushed,
	// no matter how few there are. The size of the batch that’s flushed may be as small as 1 but
	// will be no larger than BatchSize.
	FlushInterval time.Duration

	// Logger is the logger used by the Producer.
	Logger BatchProducerLogger

	// MaxAttemptsPerRecord defines how many attempts should be made for each record before it is
	// dropped. You probably want this higher than the init default of 0.
	MaxAttemptsPerRecord int

	// StatInterval will be used to make a *best effort* attempt to send stats *approximately*
	// when this interval elapses. There’s no guarantee, however, since the main goroutine is
	// used to send the stats and therefore there may be some skew.
	StatInterval time.Duration

	// StatReceiver will have its Receive method called approximately every StatInterval.
	StatReceiver StatReceiver
}

// DefaultConfig is provided for convenience; if you have no specific preferences on how you’d
// like to configure your Producer you can pass this into New. The default value of Logger is
// the same as the standard logger in "log" : `log.New(os.Stderr, "", log.LstdFlags)`.
var DefaultConfig = Config{
	AddBlocksWhenBufferFull: false,
	BufferSize:              10000,
	FlushInterval:           1 * time.Second,
	BatchSize:               10,
	MaxAttemptsPerRecord:    10,
	StatInterval:            1 * time.Second,
	Logger:                  log.New(os.Stderr, "", log.LstdFlags),
}

var (
	// ErrAlreadyStarted is returned by Start if the Producer is already started.
	ErrAlreadyStarted = errors.New("already started")

	// ErrAlreadyStopped is returned by Stop if the Producer is already stopped.
	ErrAlreadyStopped = errors.New("already stopped")
)

// New creates and returns a BatchProducer that will do nothing until its Start method is called.
// Once it is started, it will flush a batch to Kinesis whenever either
// the flushInterval occurs (if flushInterval > 0) or the batchSize is reached,
// whichever happens first.
138 | func New( 139 | client BatchingKinesisClient, 140 | streamName string, 141 | config Config, 142 | ) (Producer, error) { 143 | if config.BatchSize < 1 || config.BatchSize > MaxKinesisBatchSize { 144 | return nil, errors.New("BatchSize must be between 1 and 500 inclusive") 145 | } 146 | 147 | if config.BufferSize < config.BatchSize && config.FlushInterval <= 0 { 148 | return nil, errors.New("if BufferSize < BatchSize && FlushInterval <= 0 then the buffer will eventually fill up and Add will block forever") 149 | } 150 | 151 | if config.FlushInterval > 0 && config.FlushInterval < 50*time.Millisecond { 152 | return nil, errors.New("are you crazy") 153 | } 154 | 155 | batchProducer := batchProducer{ 156 | client: client, 157 | streamName: streamName, 158 | config: config, 159 | logger: config.Logger, 160 | currentStat: new(StatsBatch), 161 | records: make(chan batchRecord, config.BufferSize), 162 | start: make(chan interface{}), 163 | stop: make(chan interface{}), 164 | } 165 | 166 | return &batchProducer, nil 167 | } 168 | 169 | type batchProducer struct { 170 | client BatchingKinesisClient 171 | streamName string 172 | config Config 173 | logger BatchProducerLogger 174 | running bool 175 | runningMu sync.RWMutex 176 | consecutiveErrors int 177 | currentDelay time.Duration 178 | currentStat *StatsBatch 179 | records chan batchRecord 180 | 181 | // start and stop will be unbuffered and will be used to send signals to start/stop and 182 | // response signals that indicate that the respective operations have completed. 
start chan interface{}
	stop  chan interface{}
}

// batchRecord couples a record's payload and partition key with the number of
// send attempts made so far (compared against Config.MaxAttemptsPerRecord when
// deciding whether to retry or drop a failed record).
type batchRecord struct {
	data         []byte
	partitionKey string
	sendAttempts int
}

// from/for interface Producer
// Add enqueues one record; it refuses to enqueue while the producer is stopped
// so a caller can never block forever on a buffer nobody is draining.
func (b *batchProducer) Add(data []byte, partitionKey string) error {
	if !b.isRunning() {
		return errors.New("Cannot call Add when BatchProducer is not running (to prevent the buffer filling up and Add blocking indefinitely).")
	}
	// When AddBlocksWhenBufferFull is false, fail fast instead of blocking on the
	// channel send below.
	if b.isBufferFull() && !b.config.AddBlocksWhenBufferFull {
		return errors.New("Buffer is full")
	}
	b.records <- batchRecord{data: data, partitionKey: partitionKey}
	return nil
}

// from/for interface Producer
// Start launches the background goroutine and blocks until it has entered its
// main loop, so callers can rely on the producer being live on return.
func (b *batchProducer) Start() error {
	b.runningMu.Lock()
	defer b.runningMu.Unlock()

	if b.running {
		return ErrAlreadyStarted
	}

	go b.run()

	// We want run to run in the background (in a goroutine) but we don’t want to return until that
	// goroutine has actually entered its main loop. So we read from this non-buffered channel, which
	// will block until run writes a value to it.
	<-b.start

	b.running = true

	return nil
}

// run is the producer's main loop: it flushes on the flush ticker, emits stats
// on the stat ticker, and drains full batches eagerly in the default branch.
func (b *batchProducer) run() {
	// A zero-value Ticker has a nil C channel, so when the interval is disabled
	// its select case simply never fires.
	flushTicker := &time.Ticker{}
	if b.config.FlushInterval > 0 {
		flushTicker = time.NewTicker(b.config.FlushInterval)
		defer flushTicker.Stop()
	}

	statTicker := &time.Ticker{}
	if b.config.StatReceiver != nil && b.config.StatInterval > 0 {
		statTicker = time.NewTicker(b.config.StatInterval)
		defer statTicker.Stop()
	}

	// used to signal Start that we are now running (entering the main loop)
	b.start <- true

	for {
		select {
		case <-flushTicker.C:
			b.sendBatch(b.config.BatchSize)
		case <-statTicker.C:
			b.sendStats()
		case <-b.stop:
			// Final stats, then echo on the stop channel so Stop can unblock.
			b.sendStats()
			b.stop <- true
			return
		default:
			if len(b.records) >= b.config.BatchSize {
				b.sendBatch(b.config.BatchSize)
			} else {
				// Nothing to do; sleep briefly so this loop doesn't spin hot.
				time.Sleep(1 * time.Millisecond)
			}
		}
	}
}

// from/for interface Producer
// Stop performs a handshake with run: send a stop request, then wait for run's
// acknowledgement before marking the producer stopped.
func (b *batchProducer) Stop() error {
	b.runningMu.Lock()
	defer b.runningMu.Unlock()

	if !b.running {
		return ErrAlreadyStopped
	}

	// request the main goroutine to stop
	b.stop <- true

	// block until the main goroutine returns a value indicating that it has stopped
	<-b.stop

	b.running = false

	return nil
}

// from/for interface Producer
// Flush stops the producer, then drains the buffer in maximum-size batches
// until empty or the timeout fires. A timeout of 0 means "no timeout".
// TODO: send all batches in parallel, will require broader refactoring
func (b *batchProducer) Flush(timeout time.Duration, sendStats bool) (int, int, error) {
	b.Stop()

	timer := time.NewTimer(timeout)
	if timeout == 0 {
		// Stopped timer => its channel never fires, so the loop runs until drained.
		timer.Stop()
	}

	timedOut := false
	sent := 0

loop:
	for len(b.records) > 0 {
		select {
		case <-timer.C:
			timedOut = true
			break loop
		default:
			sent +=
b.sendBatch(MaxKinesisBatchSize) 303 | } 304 | } 305 | 306 | if !timedOut && sendStats { 307 | b.sendStats() 308 | } 309 | 310 | return sent, len(b.records), nil 311 | } 312 | 313 | func (b *batchProducer) isRunning() bool { 314 | b.runningMu.RLock() 315 | defer b.runningMu.RUnlock() 316 | return b.running 317 | } 318 | 319 | // Sends batches of records to Kinesis, possibly re-enqueing them if there are any errors or failed 320 | // records. Returns the number of records successfully sent, if any. 321 | func (b *batchProducer) sendBatch(batchSize int) int { 322 | if len(b.records) == 0 { 323 | return 0 324 | } 325 | 326 | // In the future, maybe this could be a RetryPolicy or something 327 | if b.consecutiveErrors == 1 { 328 | b.currentDelay = 50 * time.Millisecond 329 | } else if b.consecutiveErrors > 1 { 330 | b.currentDelay *= 2 331 | } 332 | 333 | if b.currentDelay > 0 { 334 | b.logger.Printf("Delaying the batch by %v because of %v consecutive errors", b.currentDelay, b.consecutiveErrors) 335 | time.Sleep(b.currentDelay) 336 | } 337 | 338 | records := b.takeRecordsFromBuffer(batchSize) 339 | res, err := b.client.PutRecords(b.recordsToArgs(records)) 340 | 341 | if err != nil { 342 | b.consecutiveErrors++ 343 | b.currentStat.KinesisErrorsSinceLastStat++ 344 | b.logger.Printf("Error occurred when sending PutRecords request to Kinesis stream %v: %v", b.streamName, err) 345 | 346 | if b.consecutiveErrors >= 5 && b.isBufferFullOrNearlyFull() { 347 | // In order to prevent Add from hanging indefinitely, we start dropping records 348 | b.logger.Printf("DROPPING %v records because buffer is full or nearly full and there have been %v consecutive errors from Kinesis", len(records), b.consecutiveErrors) 349 | } else { 350 | b.logger.Printf("Returning %v records to buffer (%v consecutive errors)", len(records), b.consecutiveErrors) 351 | // returnRecordsToBuffer can block if the buffer (channel) if full so we’ll 352 | // call it in a goroutine. 
This might be problematic WRT ordering. TODO: revisit this. 353 | go b.returnRecordsToBuffer(records) 354 | } 355 | 356 | return 0 357 | } 358 | 359 | b.consecutiveErrors = 0 360 | b.currentDelay = 0 361 | succeeded := len(records) - res.FailedRecordCount 362 | 363 | b.currentStat.RecordsSentSuccessfullySinceLastStat += succeeded 364 | 365 | if res.FailedRecordCount == 0 { 366 | b.logger.Printf("PutRecords request succeeded: sent %v records to Kinesis stream %v", succeeded, b.streamName) 367 | } else { 368 | b.logger.Printf("Partial success when sending a PutRecords request to Kinesis stream %v: %v succeeded, %v failed. Re-enqueueing failed records.", b.streamName, succeeded, res.FailedRecordCount) 369 | // returnSomeFailedRecordsToBuffer can block if the buffer (channel) if full so we’ll 370 | // call it in a goroutine. This might be problematic WRT ordering. TODO: revisit this. 371 | go b.returnSomeFailedRecordsToBuffer(res, records) 372 | } 373 | 374 | return succeeded 375 | } 376 | 377 | func (b *batchProducer) isBufferFullOrNearlyFull() bool { 378 | return float32(len(b.records))/float32(cap(b.records)) >= 0.95 379 | } 380 | 381 | func (b *batchProducer) isBufferFull() bool { 382 | // Treating 99% as full because IIRC, len(chan) has a margin of error 383 | return float32(len(b.records))/float32(cap(b.records)) >= 0.99 384 | } 385 | 386 | func (b *batchProducer) takeRecordsFromBuffer(batchSize int) []batchRecord { 387 | var size int 388 | bufferLen := len(b.records) 389 | if bufferLen >= batchSize { 390 | size = batchSize 391 | } else { 392 | size = bufferLen 393 | } 394 | 395 | result := make([]batchRecord, size) 396 | for i := 0; i < size; i++ { 397 | result[i] = <-b.records 398 | } 399 | return result 400 | } 401 | 402 | func (b *batchProducer) recordsToArgs(records []batchRecord) *kinesis.RequestArgs { 403 | args := kinesis.NewArgs() 404 | args.Add("StreamName", b.streamName) 405 | for _, record := range records { 406 | args.AddRecord(record.data, 
record.partitionKey)
	}
	return args
}

// returnRecordsToBuffer can block if the buffer (channel) is full, so you might want to
// call it in a goroutine.
// TODO: we should probably use a deque internally as the buffer so we can return records to
// the front of the queue, so as to preserve order, which is important.
func (b *batchProducer) returnRecordsToBuffer(records []batchRecord) {
	for _, record := range records {
		// Not using b.Add because we want to preserve the value of record.sendAttempts.
		b.records <- record
	}
}

// returnSomeFailedRecordsToBuffer can block if the buffer (channel) is full, so you might want to
// call it in a goroutine. It walks the per-record results of a partially-failed
// PutRecords call, re-enqueueing records that still have attempts left and
// dropping (and counting) those that have hit MaxAttemptsPerRecord.
// TODO: we should probably use a deque internally as the buffer so we can return records to
// the front of the queue, so as to preserve order, which is important.
// NOTE(review): this relies on res.Records being index-aligned with the records
// slice that was sent — presumably guaranteed by the PutRecords API; confirm.
func (b *batchProducer) returnSomeFailedRecordsToBuffer(res *kinesis.PutRecordsResp, records []batchRecord) {
	for i, result := range res.Records {
		record := records[i]
		// A non-empty ErrorCode marks this individual record as failed.
		if result.ErrorCode != "" {
			record.sendAttempts++

			if record.sendAttempts < b.config.MaxAttemptsPerRecord {
				b.logger.Printf("Re-enqueueing failed record to buffer for retry. Error code was: '%v' and message was '%v'", result.ErrorCode, result.ErrorMessage)
				// Not using b.Add because we want to preserve the value of record.sendAttempts.
				b.records <- record
			} else {
				b.currentStat.RecordsDroppedSinceLastStat++
				msg := "Dropping failed record; it has hit %v attempts " +
					"which is the maximum. Error code was: '%v' and message was '%v'."
440 | b.logger.Printf(msg, record.sendAttempts, result.ErrorCode, result.ErrorMessage) 441 | } 442 | } 443 | } 444 | } 445 | 446 | func (b *batchProducer) sendStats() { 447 | if b.config.StatReceiver == nil { 448 | return 449 | } 450 | 451 | b.currentStat.BufferSize = len(b.records) 452 | 453 | // I considered running this as a goroutine, but I’m concerned about leaks. So instead, for now, 454 | // the provider of the BatchStatReceiver must ensure that it is either very fast or non-blocking. 455 | b.config.StatReceiver.Receive(*b.currentStat) 456 | 457 | b.currentStat = new(StatsBatch) 458 | } 459 | -------------------------------------------------------------------------------- /batchproducer/batchproducer_test.go: -------------------------------------------------------------------------------- 1 | package batchproducer 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | "strings" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/sendgridlabs/go-kinesis" 15 | ) 16 | 17 | var ( 18 | discardLogger = log.New(ioutil.Discard, "", 0) 19 | stdoutLogger = log.New(os.Stdout, "", 0) 20 | ) 21 | 22 | func TestNewBatchProducerWithGoodValues(t *testing.T) { 23 | t.Parallel() 24 | config := Config{ 25 | BufferSize: 10, 26 | FlushInterval: 0, 27 | BatchSize: 10, 28 | } 29 | b, err := New(&mockBatchingClient{}, "foo", config) 30 | if b == nil { 31 | t.Error("b == nil") 32 | } 33 | if err != nil { 34 | t.Errorf("%q != nil", err) 35 | } 36 | } 37 | 38 | func TestNewBatchProducerWithBadBatchSize(t *testing.T) { 39 | t.Parallel() 40 | config := Config{ 41 | BufferSize: 10000, 42 | FlushInterval: 0, 43 | BatchSize: 1000, 44 | } 45 | b, err := New(&mockBatchingClient{}, "foo", config) 46 | if b != nil { 47 | t.Errorf("%q != nil", b) 48 | } 49 | if err == nil { 50 | t.Error("err == nil") 51 | } 52 | if !strings.Contains(err.Error(), "between 1 and 500") { 53 | t.Errorf("%q does not contain 'between 1 and 500'", err) 54 | } 55 | } 56 | 57 | func 
TestNewBatchProducerWithBadValues(t *testing.T) { 58 | t.Parallel() 59 | config := Config{ 60 | BufferSize: 10, 61 | FlushInterval: 0, 62 | BatchSize: 500, 63 | } 64 | b, err := New(&mockBatchingClient{}, "foo", config) 65 | if b != nil { 66 | t.Errorf("%q != nil", b) 67 | } 68 | if err == nil { 69 | t.Fatalf("err == nil") 70 | } 71 | if !strings.Contains(err.Error(), "Add will block forever") { 72 | t.Errorf("%q does not contain 'Add will block forever'", err) 73 | } 74 | } 75 | 76 | func TestStart(t *testing.T) { 77 | t.Parallel() 78 | 79 | b := newProducer(&mockBatchingClient{}, 10, 0, 10) 80 | 81 | if b.isRunning() { 82 | t.Error("b should not be running") 83 | } 84 | 85 | err := b.Start() 86 | defer b.Stop() 87 | 88 | if err != nil { 89 | t.Errorf("%v != nil", err) 90 | } 91 | 92 | if !b.isRunning() { 93 | t.Error("b should be running") 94 | } 95 | } 96 | 97 | func TestStop(t *testing.T) { 98 | t.Parallel() 99 | 100 | b := newProducer(&mockBatchingClient{}, 10, 0, 10) 101 | 102 | if b.isRunning() { 103 | t.Error("b should not be running") 104 | } 105 | 106 | b.Start() 107 | err := b.Stop() 108 | 109 | if err != nil { 110 | t.Errorf("%v != nil", err) 111 | } 112 | 113 | if b.isRunning() { 114 | t.Error("b should NOT be running") 115 | } 116 | } 117 | 118 | func TestStartWhenStarted(t *testing.T) { 119 | t.Parallel() 120 | config := Config{ 121 | BufferSize: 100, 122 | FlushInterval: 0, 123 | BatchSize: 10, 124 | } 125 | b, err := New(&mockBatchingClient{}, "foo", config) 126 | if err != nil { 127 | t.Fatalf("%v != nil", err) 128 | } 129 | 130 | b.Start() 131 | defer b.Stop() 132 | 133 | err = b.Start() 134 | if err == nil { 135 | t.Errorf("%v == nil", err) 136 | } 137 | } 138 | 139 | func TestStopWhenStopped(t *testing.T) { 140 | t.Parallel() 141 | config := Config{ 142 | BufferSize: 100, 143 | FlushInterval: 0, 144 | BatchSize: 10, 145 | } 146 | b, err := New(&mockBatchingClient{}, "foo", config) 147 | if err != nil { 148 | t.Fatalf("%v != nil", err) 149 | } 
150 | 151 | err = b.Stop() 152 | if err == nil { 153 | t.Errorf("%v == nil", err) 154 | } 155 | } 156 | 157 | func TestSuccessiveStartsAndStops(t *testing.T) { 158 | t.Parallel() 159 | config := Config{ 160 | BufferSize: 100, 161 | FlushInterval: 0, 162 | BatchSize: 10, 163 | } 164 | b, err := New(&mockBatchingClient{}, "foo", config) 165 | if err != nil { 166 | t.Fatalf("%v != nil", err) 167 | } 168 | 169 | for i := 0; i < 10; i++ { 170 | err = b.Start() 171 | if err != nil { 172 | t.Errorf("%v != nil", err) 173 | } 174 | 175 | err = b.Stop() 176 | if err != nil { 177 | t.Errorf("%v != nil", err) 178 | } 179 | } 180 | } 181 | 182 | func TestAddRecordWhenStarted(t *testing.T) { 183 | t.Parallel() 184 | config := Config{ 185 | BufferSize: 100, 186 | FlushInterval: 0, 187 | BatchSize: 10, 188 | } 189 | b, err := New(&mockBatchingClient{}, "foo", config) 190 | if err != nil { 191 | t.Fatalf("%v != nil", err) 192 | } 193 | 194 | b.Start() 195 | defer b.Stop() 196 | 197 | err = b.Add([]byte("foo"), "bar") 198 | if err != nil { 199 | t.Errorf("%v != nil", err) 200 | } 201 | } 202 | 203 | func TestAddRecordWhenStopped(t *testing.T) { 204 | t.Parallel() 205 | config := Config{ 206 | BufferSize: 100, 207 | FlushInterval: 0, 208 | BatchSize: 10, 209 | } 210 | b, err := New(&mockBatchingClient{}, "foo", config) 211 | if err != nil { 212 | t.Fatalf("%v != nil", err) 213 | } 214 | 215 | err = b.Add([]byte("foo"), "bar") 216 | if err == nil { 217 | t.Errorf("%v == nil", err) 218 | } 219 | } 220 | 221 | func TestFlushInterval(t *testing.T) { 222 | t.Parallel() 223 | c := &mockBatchingClient{} 224 | b := newProducer(c, 100, 2*time.Millisecond, 10) 225 | b.Start() 226 | defer b.Stop() 227 | 228 | b.addRecordsAndWait(10, 0) 229 | if len(b.records) != 10 { 230 | t.Errorf("%v != 10", len(b.records)) 231 | } 232 | if c.calls != 0 { 233 | t.Errorf("%v != 0", c.calls) 234 | } 235 | 236 | time.Sleep(3 * time.Millisecond) 237 | if len(b.records) != 0 { 238 | t.Errorf("%v != 0", 
len(b.records)) 239 | } 240 | if c.calls != 1 { 241 | t.Errorf("%v != 1", c.calls) 242 | } 243 | 244 | // 20 more records should result in two more batches being sent 245 | b.addRecordsAndWait(20, 8) 246 | if len(b.records) != 0 { 247 | t.Errorf("%v != 0", len(b.records)) 248 | } 249 | if c.calls != 3 { 250 | t.Errorf("%v != 3", c.calls) 251 | } 252 | } 253 | 254 | func TestBatchSize(t *testing.T) { 255 | t.Parallel() 256 | c := &mockBatchingClient{} 257 | b := newProducer(c, 100, 0, 5) 258 | b.Start() 259 | defer b.Stop() 260 | 261 | b.addRecordsAndWait(4, 2) 262 | if len(b.records) != 4 { 263 | t.Errorf("%v != 4", len(b.records)) 264 | } 265 | if c.calls != 0 { 266 | t.Errorf("%v != 0", c.calls) 267 | } 268 | 269 | b.addRecordsAndWait(1, 2) 270 | if len(b.records) != 0 { 271 | t.Errorf("%v != 0", len(b.records)) 272 | } 273 | if c.calls != 1 { 274 | t.Errorf("%v != 1", c.calls) 275 | } 276 | 277 | b.addRecordsAndWait(6, 2) 278 | if len(b.records) != 1 { 279 | t.Errorf("%v != 1", len(b.records)) 280 | } 281 | if c.calls != 2 { 282 | t.Errorf("%v != 2", c.calls) 283 | } 284 | 285 | b.addRecordsAndWait(19, 2) 286 | if len(b.records) != 0 { 287 | t.Errorf("%v != 0", len(b.records)) 288 | } 289 | if c.calls != 6 { 290 | t.Errorf("%v != 6", c.calls) 291 | } 292 | } 293 | 294 | func TestBatchError(t *testing.T) { 295 | t.Parallel() 296 | c := &mockBatchingClient{shouldErr: true} 297 | b := newProducer(c, 100, 0, 5) 298 | b.Start() 299 | defer b.Stop() 300 | 301 | b.addRecordsAndWait(5, 5) 302 | if b.consecutiveErrors != 1 { 303 | t.Errorf("%v != 1", b.consecutiveErrors) 304 | } 305 | if len(b.records) != 5 { 306 | t.Errorf("%v != 5", len(b.records)) 307 | } 308 | 309 | // Wait another 55 ms and another error should have occurred 310 | time.Sleep(55 * time.Millisecond) 311 | if b.consecutiveErrors != 2 { 312 | t.Errorf("%v != 2", b.consecutiveErrors) 313 | } 314 | if len(b.records) != 5 { 315 | t.Errorf("%v != 5", len(b.records)) 316 | } 317 | 318 | b.Stop() 319 | 
b.client = &mockBatchingClient{shouldErr: false} 320 | b.Start() 321 | 322 | time.Sleep(205 * time.Millisecond) 323 | if b.consecutiveErrors != 0 { 324 | t.Errorf("%v != 0", b.consecutiveErrors) 325 | } 326 | if len(b.records) != 0 { 327 | t.Errorf("%v != 0", len(b.records)) 328 | } 329 | 330 | // This next batch should succeed immediately 331 | b.addRecordsAndWait(5, 1) 332 | if b.consecutiveErrors != 0 { 333 | t.Errorf("%v != 0", b.consecutiveErrors) 334 | } 335 | if len(b.records) != 0 { 336 | t.Errorf("%v != 0", len(b.records)) 337 | } 338 | } 339 | 340 | func TestBatchPartialFailure(t *testing.T) { 341 | t.Parallel() 342 | b := newProducer(&mockBatchingClient{}, 100, 0, 20) 343 | b.config.MaxAttemptsPerRecord = 2 344 | b.Start() 345 | defer b.Stop() 346 | 347 | b.addRecordsAndWait(19, 0) 348 | 349 | // Add a single record that will fail. partitionKey is (mis)used to specify that the record 350 | // should fail. 351 | b.Add([]byte("foo"), "fail") 352 | 353 | // First attempt 354 | time.Sleep(5 * time.Millisecond) 355 | if len(b.records) != 1 { 356 | t.Errorf("%v != 1", len(b.records)) 357 | } 358 | 359 | // Second attempt 360 | b.addRecordsAndWait(19, 1) 361 | // The failing record should be thrown away at this point 362 | if len(b.records) != 0 { 363 | t.Errorf("%v != 0", len(b.records)) 364 | } 365 | } 366 | 367 | func TestBufferSizeStat(t *testing.T) { 368 | t.Parallel() 369 | 370 | sr := &statReceiver{} 371 | 372 | b := newProducer(&mockBatchingClient{}, 100, 0, 20) 373 | b.config.StatReceiver = sr 374 | b.config.StatInterval = 1 * time.Millisecond 375 | b.Start() 376 | defer b.Stop() 377 | 378 | // Adding 10 will not trigger a batch 379 | b.addRecordsAndWait(10, 2) 380 | 381 | if len(sr.stats) == 0 { 382 | // More than one might have been sent, which is fine. We just need at least one. 
		t.Fatalf("%v == 0", len(sr.stats))
	}

	lastStat := sr.stats[len(sr.stats)-1]
	if lastStat.BufferSize != 10 {
		t.Errorf("%v != 10", lastStat.BufferSize)
	}

	// Adding another 10 **will** trigger a batch
	b.addRecordsAndWait(10, 2)

	if len(sr.stats) < 2 {
		t.Fatalf("%v < 2", len(sr.stats))
	}

	lastStat = sr.stats[len(sr.stats)-1]
	if lastStat.BufferSize != 0 {
		t.Errorf("%v != 0", lastStat.BufferSize)
	}
}

// TestSuccessfulRecordsStat verifies that RecordsSentSuccessfullySinceLastStat
// stays 0 while records are only buffered, and that the receiver's running
// total reaches 20 once the full batch has been sent.
func TestSuccessfulRecordsStat(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{}, 100, 0, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	// b.logger = stdoutLogger // TEMP TEMP TEMP
	b.Start()
	defer b.Stop()

	// Adding 10 will not trigger a batch
	b.addRecordsAndWait(10, 2)

	if len(sr.stats) == 0 {
		// More than one might have been sent, which is fine. We just need at least one.
		t.Fatalf("%v == 0", len(sr.stats))
	}

	lastStat := sr.stats[len(sr.stats)-1]
	if lastStat.RecordsSentSuccessfullySinceLastStat != 0 {
		t.Errorf("%v != 0", lastStat.RecordsSentSuccessfullySinceLastStat)
	}

	// Adding another 10 **will** trigger a batch
	b.addRecordsAndWait(10, 2)

	if len(sr.stats) < 2 {
		t.Fatalf("%v < 2", len(sr.stats))
	}

	if sr.totalRecordsSentSuccessfully != 20 {
		t.Errorf("%v != 20", sr.totalRecordsSentSuccessfully)
	}
}

// TestSuccessfulRecordsStatWhenSomeRecordsFail verifies that a record which
// fails within a batch is not counted as sent successfully.
func TestSuccessfulRecordsStatWhenSomeRecordsFail(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{}, 100, 0, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	b.config.MaxAttemptsPerRecord = 2
	b.Start()
	defer b.Stop()

	b.addRecordsAndWait(19, 0)

	// Add a single record that will fail. partitionKey is (mis)used to specify that the record
	// should fail.
	b.Add([]byte("foo"), "fail")

	// Sleep long enough for multiple attempts to be tried
	time.Sleep(3 * time.Millisecond)

	// Should be 19 because one of the 20 records failed
	if sr.totalRecordsSentSuccessfully != 19 {
		t.Errorf("%v != 19", sr.totalRecordsSentSuccessfully)
	}
}

// TestRecordsDroppedStatWhenSomeRecordsFail verifies that records which
// exhaust MaxAttemptsPerRecord (1 here, so a single failure drops them)
// show up in RecordsDroppedSinceLastStat.
func TestRecordsDroppedStatWhenSomeRecordsFail(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{}, 100, 0, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	b.config.MaxAttemptsPerRecord = 1
	b.Start()
	defer b.Stop()

	b.addRecordsAndWait(18, 0)

	// Add two records that will fail. partitionKey is (mis)used to specify that the record
	// should fail.
	b.Add([]byte("foo"), "fail")
	b.Add([]byte("foo"), "fail")

	// Sleep long enough for an attempt to be tried and the stat to be received
	time.Sleep(5 * time.Millisecond)

	if sr.totalRecordsDroppedSinceLastStat != 2 {
		t.Errorf("%v != 2", sr.totalRecordsDroppedSinceLastStat)
	}
}

// TestSuccessfulRecordsStatWhenKinesisReturnsError verifies that no records
// are counted as sent when every PutRecords call errors out.
func TestSuccessfulRecordsStatWhenKinesisReturnsError(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{shouldErr: true}, 100, 0, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	b.Start()
	defer b.Stop()

	// Adding 20 **will** trigger a batch
	b.addRecordsAndWait(20, 50)

	if len(sr.stats) < 1 {
		t.Fatalf("%v < 1", len(sr.stats))
	}

	// Should be 0 because Kinesis is just returning errors
	if sr.totalRecordsSentSuccessfully != 0 {
		t.Errorf("%v != 0", sr.totalRecordsSentSuccessfully)
	}
}

// TestKinesisErrorsStatWhenKinesisSucceeds verifies that no Kinesis errors
// are reported when the client always succeeds.
func TestKinesisErrorsStatWhenKinesisSucceeds(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{shouldErr: false}, 100, 0, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	b.Start()
	defer b.Stop()

	// Adding 20 **will** trigger a batch
	b.addRecordsAndWait(20, 2)

	if len(sr.stats) < 1 {
		t.Fatalf("%v < 1", len(sr.stats))
	}

	// Should be 0 because Kinesis is succeeding
	if sr.totalKinesisErrorsSinceLastStat != 0 {
		t.Errorf("%v != 0", sr.totalKinesisErrorsSinceLastStat)
	}
}

// TestKinesisErrorsStatWhenKinesisReturnsError verifies that failed
// PutRecords calls are surfaced through KinesisErrorsSinceLastStat.
// NOTE(review): the expected count of 2 presumably reflects the initial
// attempt plus one retry within the 5ms window — confirm against the
// producer's retry schedule.
func TestKinesisErrorsStatWhenKinesisReturnsError(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{shouldErr: true}, 100, 0, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	b.Start()
	defer b.Stop()

	b.addRecordsAndWait(20, 5)
	b.Stop()

	if sr.totalKinesisErrorsSinceLastStat != 2 {
		t.Errorf("%v != 2", sr.totalKinesisErrorsSinceLastStat)
	}
}

// TestLogMessageWhenKinesisSucceeds verifies the success log line by
// capturing logger output in a buffer.
func TestLogMessageWhenKinesisSucceeds(t *testing.T) {
	t.Parallel()

	b := newProducer(&mockBatchingClient{shouldErr: false}, 100, 0, 20)
	loggerBuffer, logger := newBufferedLogger()
	b.logger = logger
	b.Start()
	defer b.Stop()

	// Adding 20 **will** trigger a batch
	b.addRecordsAndWait(20, 2)

	loggerString := loggerBuffer.String()
	requiredString := "PutRecords request succeeded: sent 20 records to Kinesis stream"
	if !strings.Contains(loggerString, requiredString) {
		t.Errorf("%s does not contain %s", loggerString, requiredString)
	}
}

// TestLogMessageWhenKinesisReturnsError verifies the error log line when the
// PutRecords call itself fails.
func TestLogMessageWhenKinesisReturnsError(t *testing.T) {
	t.Parallel()

	b := newProducer(&mockBatchingClient{shouldErr: true}, 100, 0, 20)
	loggerBuffer, logger := newBufferedLogger()
	b.logger = logger
	b.Start()
	defer b.Stop()

	// Adding 20 **will** trigger a batch
	b.addRecordsAndWait(20, 2)

	loggerString := loggerBuffer.String()
	requiredString := "Error occurred when sending PutRecords request"
	if !strings.Contains(loggerString, requiredString) {
		t.Errorf("%s does not contain %s", loggerString, requiredString)
	}
}

// TestLogMessageWhenSomeRecordsFail verifies the three log lines emitted for
// a partial failure: the partial-success notice, the re-enqueue message, and
// the final drop message once MaxAttemptsPerRecord (2) is exhausted.
func TestLogMessageWhenSomeRecordsFail(t *testing.T) {
	t.Parallel()

	sr := &statReceiver{}
	b := newProducer(&mockBatchingClient{}, 100, 2*time.Millisecond, 20)
	b.config.StatReceiver = sr
	b.config.StatInterval = 1 * time.Millisecond
	b.config.MaxAttemptsPerRecord = 2
	loggerBuffer, logger := newBufferedLogger()
	b.logger = logger
	b.Start()
	defer b.Stop()

	b.addRecordsAndWait(18, 0)

	// Add two records that will fail. partitionKey is (mis)used to specify that the record
	// should fail.
	b.Add([]byte("foo"), "fail")
	b.Add([]byte("foo"), "fail")

	// Sleep long enough for a few attempts to be tried and the failing records to be re-enqueued
	// and then dropped
	time.Sleep(5 * time.Millisecond)

	loggerString := loggerBuffer.String()

	requiredString := "Partial success when sending a PutRecords request"
	if !strings.Contains(loggerString, requiredString) {
		t.Errorf("%s does not contain %s", loggerString, requiredString)
	}

	requiredString = "Re-enqueueing failed record to buffer for retry. Error code was: 'foo'"
	if !strings.Contains(loggerString, requiredString) {
		t.Errorf("%s does not contain %s", loggerString, requiredString)
	}

	requiredString = "Dropping failed record; it has hit 2 attempts which is the maximum"
	if !strings.Contains(loggerString, requiredString) {
		t.Errorf("%s does not contain %s", loggerString, requiredString)
	}
}

// TestAddBlocksFalse verifies that Add returns an error (rather than
// blocking) when the buffer is full and AddBlocksWhenBufferFull is false
// (the default from newProducer).
func TestAddBlocksFalse(t *testing.T) {
	t.Parallel()

	b := newProducer(&mockBatchingClient{}, 10, 0, 20)
	b.Start()
	defer b.Stop()

	// Adding 10 will fill up the buffer and not trigger a batch
	b.addRecordsAndWait(10, 2)

	data := []byte("The cheese is old and moldy, where is the bathroom?")
	partitionKey := "foo"
	err := b.Add(data, partitionKey)

	if err == nil {
		t.Errorf("%s == nil", err)
	}
}

// TestAddBlocksTrue verifies that Add blocks on a full buffer when
// AddBlocksWhenBufferFull is true: the goroutine's Add never returns, so the
// buffer still holds exactly 10 records afterwards.
func TestAddBlocksTrue(t *testing.T) {
	t.Parallel()

	b := newProducer(&mockBatchingClient{}, 10, 0, 20)
	b.config.AddBlocksWhenBufferFull = true
	b.Start()
	defer b.Stop()

	// Adding 10 will fill up the buffer and not trigger a batch
	b.addRecordsAndWait(10, 2)

	// This should block so we need to run this in a goroutine
	// NOTE(review): t.Fatal below is called from a non-test goroutine, which
	// the testing package documents as invalid (FailNow must run in the test
	// goroutine); t.Error would be the safe call here — worth fixing.
	go func() {
		data := []byte("The cheese is old and moldy, where is the bathroom?")
		partitionKey := "foo"
		b.Add(data, partitionKey)
		t.Fatal("We should never have gotten here.")
	}()

	time.Sleep(1 * time.Millisecond)

	if len(b.records) != 10 {
		t.Errorf("%v != 10", len(b.records))
	}
}

// TestFlush verifies that Flush sends all buffered records within the
// timeout, leaves the buffer empty, and that the producer is not running
// afterwards (second argument false presumably means "don't restart" —
// confirm against Flush's definition in batchproducer.go).
func TestFlush(t *testing.T) {
	t.Parallel()

	b := newProducer(&mockBatchingClient{}, 20, 0, 20)
	b.Start()
	defer b.Stop()

	// Adding 10 will not trigger a batch
	b.addRecordsAndWait(10, 2)

	timeout := 20 * time.Second
	sent, remaining, err := b.Flush(timeout, false)
	if err != nil {
		t.Errorf("%s != nil", err)
	}

	if sent != 10 {
		t.Errorf("%v != 10", sent)
	}
	if remaining > 0 {
		t.Errorf("%v > 0", remaining)
	}
	if len(b.records) > 0 {
		t.Errorf("%v > 0", len(b.records))
	}
	if b.isRunning() {
		t.Errorf("b.running != false")
	}
}

// TestFlushWithTimeout verifies that Flush stops after the timeout expires:
// with each mocked PutRecords taking 6ms and a 5ms timeout, only the first
// batch of 500 goes out and 100 records remain buffered.
func TestFlushWithTimeout(t *testing.T) {
	t.Parallel()

	c := &mockBatchingClient{
		sleepFor: 6 * time.Millisecond,
	}
	b := newProducer(c, 1000, 0, 10)

	// set running to true so Add will succeed
	b.running = true

	// Adding 600 will enqueue 2 batches
	b.addRecordsAndWait(600, 0)

	// back to normal
	b.running = false

	// This should lead to only 1 batch of 500 being sent by Flush
	timeout := 5 * time.Millisecond

	start := time.Now()
	sent, remaining, err := b.Flush(timeout, false)
	duration := time.Since(start)
	if err != nil {
		t.Errorf("%s != nil", err)
	}

	if sent != 500 {
		t.Errorf("%v != 500", sent)
	}
	if remaining != 100 {
		t.Errorf("%v != 100", remaining)
	}
	if len(b.records) != 100 {
		t.Errorf("%v != 100", len(b.records))
	}
	// NOTE(review): a 2ms-wide wall-clock window is flaky on loaded CI hosts.
	if duration < 6*time.Millisecond || duration > 8*time.Millisecond {
		t.Errorf("%v seems off", duration)
	}
}

// TestFlushWithoutTimeout verifies that a zero timeout means "no timeout":
// both batches (500 + 100) are sent and the buffer fully drains, taking
// roughly two 6ms mock calls.
func TestFlushWithoutTimeout(t *testing.T) {
	t.Parallel()

	c := &mockBatchingClient{
		sleepFor: 6 * time.Millisecond,
	}
	b := newProducer(c, 1000, 0, 10)

	// set running to true so Add will succeed
	b.running = true

	// Adding 600 will enqueue 2 batches
	b.addRecordsAndWait(600, 0)

	// back to normal
	b.running = false

	// This should lead to batches of 500 and 100 being sent by Flush
	timeout := 0 * time.Millisecond

	start := time.Now()
	sent, remaining, err := b.Flush(timeout, false)
	duration := time.Since(start)
	if err != nil {
		t.Errorf("%s != nil", err)
	}

	if sent != 600 {
		t.Errorf("%v != 600", sent)
	}
	if remaining != 0 {
		t.Errorf("%v != 0", remaining)
	}
	if len(b.records) != 0 {
		t.Errorf("%v != 0", len(b.records))
	}
	if duration < 12*time.Millisecond || duration > 16*time.Millisecond {
		t.Errorf("%v seems off", duration)
	}
}

// mockBatchingClient is a test double for the Kinesis client: it counts
// PutRecords calls, optionally fails every call (shouldErr), and optionally
// sleeps per call (sleepFor) to simulate latency.
// NOTE(review): numToFail appears unused by the code visible in this file —
// confirm before removing.
type mockBatchingClient struct {
	calls     int
	callsMu   sync.Mutex
	shouldErr bool
	numToFail int
	sleepFor  time.Duration
}

// PutRecords simulates the Kinesis PutRecords API: any record whose
// PartitionKey is "fail" gets an error entry (code "foo") in the response;
// all others get a dummy sequence number and shard id.
func (s *mockBatchingClient) PutRecords(args *kinesis.RequestArgs) (resp *kinesis.PutRecordsResp, err error) {
	s.callsMu.Lock()
	defer s.callsMu.Unlock()
	s.calls++

	if s.shouldErr {
		return nil, errors.New("Oh Noes!")
	}

	// NOTE(review): the sleep happens while holding callsMu, which
	// serializes concurrent calls — presumably intentional for test
	// determinism, but worth confirming.
	time.Sleep(s.sleepFor)

	res := kinesis.PutRecordsResp{Records: make([]kinesis.PutRecordsRespRecord, len(args.Records))}

	for i, record := range args.Records {
		if record.PartitionKey == "fail" {
			res.FailedRecordCount++
			res.Records[i] = kinesis.PutRecordsRespRecord{ErrorCode: "foo", ErrorMessage: "bar"}
		} else {
			res.Records[i] = kinesis.PutRecordsRespRecord{SequenceNumber: "001", ShardId: "001"}
		}
	}

	return &res, nil
}

// newProducer builds a *batchProducer for tests, panicking on any setup
// failure. FlushInterval is passed through New with a placeholder value and
// then overridden, because New rejects the intervals some tests need (e.g. 0).
func newProducer(client *mockBatchingClient, bufferSize int, flushInterval time.Duration, batchSize int) *batchProducer {
	config := Config{
		BufferSize: bufferSize,
		// Set FlushInterval to an interval that will be acceptable to New; we’ll override it below
		// after calling New.
		FlushInterval:        50 * time.Millisecond,
		BatchSize:            batchSize,
		Logger:               discardLogger,
		MaxAttemptsPerRecord: 2,
	}

	producer, err := New(client, "foo", config)
	if err != nil {
		panic(err)
	}

	bp, ok := producer.(*batchProducer)
	if !ok {
		panic("producer is not a *batchProducer!")
	}

	bp.config.FlushInterval = flushInterval

	return bp
}

// There are some cases wherein immediately after adding the records we want to sleep for some
// amount of time in order to allow for the batchProducer’s goroutine to do stuff.
// A possible alternative approach might be to run with multiple CPUs... but that would probably
// still require waiting for at least some small amount of time. And in fact it would be way
// less deterministic and less predictable.
857 | func (b *batchProducer) addRecordsAndWait(numRecords int, millisToWait int) { 858 | data := []byte("The cheese is old and moldy, where is the bathroom?") 859 | partitionKey := "foo" 860 | for i := 0; i < numRecords; i++ { 861 | err := b.Add(data, partitionKey) 862 | if err != nil { 863 | panic(err) 864 | } 865 | } 866 | 867 | if millisToWait > 0 { 868 | time.Sleep(time.Duration(millisToWait) * time.Millisecond) 869 | } 870 | } 871 | 872 | type statReceiver struct { 873 | stats []StatsBatch 874 | totalKinesisErrorsSinceLastStat int 875 | totalRecordsSentSuccessfully int 876 | totalRecordsDroppedSinceLastStat int 877 | } 878 | 879 | func (s *statReceiver) Receive(sf StatsBatch) { 880 | s.stats = append(s.stats, sf) 881 | s.totalKinesisErrorsSinceLastStat += sf.KinesisErrorsSinceLastStat 882 | s.totalRecordsSentSuccessfully += sf.RecordsSentSuccessfullySinceLastStat 883 | s.totalRecordsDroppedSinceLastStat += sf.RecordsDroppedSinceLastStat 884 | } 885 | 886 | func newBufferedLogger() (*bytes.Buffer, *log.Logger) { 887 | buf := new(bytes.Buffer) 888 | logger := log.New(buf, "", 0) 889 | return buf, logger 890 | } 891 | --------------------------------------------------------------------------------