├── doc ├── scale.odc ├── template.odg ├── flow-phase1.odg └── scale.svg ├── http2x ├── doc.go ├── reflect_test.go └── reflect.go ├── cryptox ├── doc.go ├── test_data │ └── pk_valid.p8 ├── crypto_test.go └── crypto.go ├── syncx ├── doc.go ├── counter_test.go ├── ticktock_test.go ├── counter.go └── ticktock.go ├── funit ├── funit.go ├── time.go ├── time_test.go ├── prefix.go └── size.go ├── .gitignore ├── .travis.yml ├── apns2 ├── comm_test.go ├── http_test.go ├── dispatch_test.go ├── result.go ├── backoff.go ├── request.go ├── client_test.go ├── log.go ├── test_harness.go ├── payload.go ├── comm.go ├── auth_test.go ├── auth.go ├── backoff_test.go ├── response.go ├── notification.go ├── http.go ├── streamer.go ├── client.go └── dispatch.go ├── LICENSE ├── example └── fire-and-forget │ └── main.go ├── scale ├── scale.go └── scale_test.go └── README.md /doc/scale.odc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baobabus/go-apns/HEAD/doc/scale.odc -------------------------------------------------------------------------------- /doc/template.odg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baobabus/go-apns/HEAD/doc/template.odg -------------------------------------------------------------------------------- /doc/flow-phase1.odg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baobabus/go-apns/HEAD/doc/flow-phase1.odg -------------------------------------------------------------------------------- /http2x/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | // Package http2x provides extensions to standard http2 functionality. 
4 | package http2x
5 |
--------------------------------------------------------------------------------
/cryptox/doc.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | // Package cryptox provides extensions to standard crypto functionality.
4 | package cryptox
5 |
--------------------------------------------------------------------------------
/syncx/doc.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | // Package syncx provides extensions to standard synchronization primitives.
4 | package syncx
5 |
--------------------------------------------------------------------------------
/funit/funit.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | // Package funit provides an idiomatic way of explicitly expressing
4 | // various measures in terms of their units.
5 | package funit
6 |
7 | // Measure is a quantity expressed as a multiple of a base unit, such as
8 | // Second for time or Bit for data size. Build one by multiplying a number
9 | // by a unit constant, e.g. 90 * Minute.
10 | type Measure float64
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 |
6 | # Folders
7 | _obj
8 | _test
9 |
10 | *.cgo1.go
11 | *.cgo2.c
12 | _cgo_defun.c
13 | _cgo_gotypes.go
14 | _cgo_export.*
15 |
16 | *.exe
17 | *.test
18 | *.prof
19 |
--------------------------------------------------------------------------------
/cryptox/test_data/pk_valid.p8:
--------------------------------------------------------------------------------
1 | -----BEGIN PRIVATE KEY-----
2 | MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgEbVzfPnZPxfAyxqE
3 | ZV05laAoJAl+/6Xt2O4mOB611sOhRANCAASgFTKjwJAAU95g++/vzKWHkzAVmNMI
4 | tB5vTjZOOIwnEb70MsWZFIyUFD1P9Gwstz4+akHX7vI8BH6hHmBmfeQl
5 | -----END PRIVATE KEY-----
6 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - 1.7.x
5 | - 1.8.x
6 | - 1.9.x
7 |
8 | before_install:
9 | - go get golang.org/x/crypto/pkcs12
10 | - go get golang.org/x/net/http2
11 | - go get golang.org/x/net/idna
12 | - go get github.com/dgrijalva/jwt-go
13 | - go get github.com/stretchr/testify/assert
14 | - go get github.com/baobabus/go-apnsmock/apns2mock
15 |
16 | os:
17 | - linux
18 |
--------------------------------------------------------------------------------
/cryptox/crypto_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 | 3 | package cryptox 4 | 5 | import ( 6 | "crypto/ecdsa" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestPKCS8PrivateKeyFromFile(t *testing.T) { 13 | s, err := PKCS8PrivateKeyFromFile("test_data/pk_valid.p8") 14 | assert.NoError(t, err) 15 | assert.IsType(t, &ecdsa.PrivateKey{}, s) 16 | } 17 | -------------------------------------------------------------------------------- /syncx/counter_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package syncx 4 | 5 | import ( 6 | "testing" 7 | ) 8 | 9 | // Basic non-contention tests for Counter primitive. 10 | 11 | func TestCounter(t *testing.T) { 12 | var subj Counter 13 | if subj != 0 { 14 | t.Fatalf("Bad zero value %v", subj) 15 | } 16 | subj.Add(1) 17 | if subj != 1 { 18 | t.Fatalf("Bad tick %v", subj) 19 | } 20 | subj.Add(9) 21 | if subj != 10 { 22 | t.Fatalf("Bad tock %v", subj) 23 | } 24 | subj.Add(1) 25 | i := subj.Draw() 26 | if subj != 0 || i != 11 { 27 | t.Fatalf("Bad draw %v %v %v", subj, i) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /funit/time.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 
2 | 3 | package funit 4 | 5 | import "time" 6 | 7 | const ( 8 | Second Measure = 1.0 9 | Minute = 60.0 * Second 10 | Hour = 60.0 * Minute 11 | Sec = Second 12 | Min = Minute 13 | Hr = Hour 14 | Millisecond = Milli * Second 15 | Microsecond = Micro * Second 16 | Nanosecond = Nano * Second 17 | Picosecond = Pico * Second 18 | Femtosecond = Femto * Second 19 | ) 20 | 21 | func (m Measure) AsDuration() time.Duration { 22 | return time.Duration(1000000000.0 * m) 23 | } 24 | -------------------------------------------------------------------------------- /http2x/reflect_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package http2x 4 | 5 | import ( 6 | "net/http" 7 | "testing" 8 | 9 | "golang.org/x/net/http2" 10 | ) 11 | 12 | func TestGetClientConnPool(t *testing.T) { 13 | res, err := GetClientConnPool(nil) 14 | if res != nil || err == nil { 15 | t.Fatal("Should have failed to get connection") 16 | } 17 | if err != ErrUnsupportedTransport { 18 | t.Fatal("Wrong error: ", err) 19 | } 20 | res, err = GetClientConnPool(http.DefaultTransport) 21 | if res != nil || err == nil { 22 | t.Fatal("Should have failed to get connection") 23 | } 24 | if err != ErrUnsupportedTransport { 25 | t.Fatal("Wrong error: ", err) 26 | } 27 | tr := &http2.Transport{} 28 | res, err = GetClientConnPool(tr) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | if res == nil { 33 | t.Fatal("Should have gotten connection pool") 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /funit/time_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 
2 | 3 | package funit 4 | 5 | import ( 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestTimeAsDuration(t *testing.T) { 13 | assert.Exactly(t, time.Second, Second.AsDuration()) 14 | assert.Exactly(t, time.Second, (1000 * Millisecond).AsDuration()) 15 | assert.Exactly(t, time.Minute, Minute.AsDuration()) 16 | assert.Exactly(t, time.Minute, (60 * Second).AsDuration()) 17 | assert.Exactly(t, time.Hour, Hour.AsDuration()) 18 | assert.Exactly(t, time.Hour, (60 * Minute).AsDuration()) 19 | assert.Exactly(t, time.Hour, (3600 * Second).AsDuration()) 20 | assert.Exactly(t, time.Hour+30*time.Minute, (90 * Minute).AsDuration()) 21 | assert.Exactly(t, time.Millisecond, Millisecond.AsDuration()) 22 | assert.Exactly(t, time.Microsecond, Microsecond.AsDuration()) 23 | assert.Exactly(t, time.Nanosecond, Nanosecond.AsDuration()) 24 | } 25 | -------------------------------------------------------------------------------- /apns2/comm_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 
2 | 3 | package apns2 4 | 5 | import ( 6 | "testing" 7 | 8 | "golang.org/x/net/http2" 9 | ) 10 | 11 | func TestDialOk(t *testing.T) { 12 | s := mustNewMockServer(t) 13 | defer s.Close() 14 | d := makeDialer(commsTest_Fast) 15 | tc := s.Client().Transport.(*http2.Transport).TLSClientConfig 16 | c, err := d("tcp", s.URL[8:], tc) 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | if c == nil { 21 | t.Fatal("Should have connected") 22 | } 23 | } 24 | 25 | func TestDialTimeout(t *testing.T) { 26 | s := mustNewMockServerWithCfg(t, apnsMockComms_30ms) 27 | defer s.Close() 28 | d := makeDialer(commsTest_Fast) 29 | tc := s.Client().Transport.(*http2.Transport).TLSClientConfig 30 | c, err := d("tcp", s.URL[8:], tc) 31 | if err == nil || err.Error() != "tls: DialWithDialer timed out" { 32 | t.Fatal("Should have gotten error tls: DialWithDialer timed out") 33 | } 34 | if c == nil { 35 | t.Fatal("Should not have connected") 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Aleksey Blinov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /apns2/http_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestGetClientConnNoHTTP2Incursion(t *testing.T) { 12 | s := mustNewMockServer(t) 13 | defer s.Close() 14 | c := mustNewHTTPClient(t, s) 15 | cc, err := c.getClientConn() 16 | if err != nil { 17 | t.Fatal(err) 18 | } 19 | if cc != nil { 20 | t.Fatal("Should not have gotten a connection") 21 | } 22 | } 23 | 24 | func TestGetClientConn(t *testing.T) { 25 | s := mustNewMockServer(t) 26 | defer s.Close() 27 | c := mustNewHTTPClient(t, s) 28 | c.precise = true 29 | cc, err := c.getClientConn() 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | if cc == nil { 34 | t.Fatal("Should have gotten a connection") 35 | } 36 | } 37 | 38 | func TestReservedStreamNoContention(t *testing.T) { 39 | s := mustNewMockServer(t) 40 | defer s.Close() 41 | c := mustNewHTTPClient(t, s) 42 | st, err := c.ReservedStream(nil) 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | assert.Equal(t, uint32(1), c.cnt) 47 | st.Close() 48 | assert.Equal(t, uint32(0), c.cnt) 49 | } 50 | -------------------------------------------------------------------------------- /syncx/ticktock_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package syncx 4 | 5 | import ( 6 | "testing" 7 | ) 8 | 9 | // Basic non-contention tests for TickTock primitives. 10 | 11 | func TestTickTockCounter(t *testing.T) { 12 | var subj TickTockCounter 13 | if subj != 0 { 14 | t.Fatalf("Bad zero value %v", subj) 15 | } 16 | subj.Tick() 17 | if subj != 0x0100000000 { 18 | t.Fatalf("Bad tick %v", subj) 19 | } 20 | subj.Tock() 21 | if subj != 0x0100000001 { 22 | t.Fatalf("Bad tock %v", subj) 23 | } 24 | subj.Tick() 25 | i, o := subj.Fold() 26 | if subj != 0x0100000000 || i != 2 || o != 1 { 27 | t.Fatalf("Bad fold %v %v %v", subj, i, o) 28 | } 29 | } 30 | 31 | func TestTickTockFolder(t *testing.T) { 32 | var subj TickTockFolder 33 | if subj != 0 { 34 | t.Fatalf("Bad zero value %v", subj) 35 | } 36 | subj.Tick() 37 | if subj != 1 { 38 | t.Fatalf("Bad tick %v", subj) 39 | } 40 | subj.Tock() 41 | if subj != 0x0100000000 { 42 | t.Fatalf("Bad tock %v", subj) 43 | } 44 | subj.Tick() 45 | if subj != 0x0100000001 { 46 | t.Fatalf("Bad second tick %v", subj) 47 | } 48 | subj.Tick() 49 | c, p := subj.Draw() 50 | if subj != 2 || c != 1 || p != 2 { 51 | t.Fatalf("Bad draw %v %v %v", subj, c, p) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /funit/prefix.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package funit 4 | 5 | // Base-10 suffixes. 
6 | //
7 | // Kilo is declared as Measure, so Mega through Yotta (derived from it) are
8 | // Measure too. Milli through Femto are built only from untyped literals, so
9 | // they are untyped floating-point constants rather than Measure values.
10 | // NOTE(review): confirm the untyped small prefixes are intentional.
11 | const (
12 | 	Kilo  Measure = 1000.0
13 | 	Mega          = 1000.0 * Kilo
14 | 	Giga          = 1000.0 * Mega
15 | 	Tera          = 1000.0 * Giga
16 | 	Peta          = 1000.0 * Tera
17 | 	Exa           = 1000.0 * Peta
18 | 	Zetta         = 1000.0 * Exa
19 | 	Yotta         = 1000.0 * Zetta
20 | 	Milli         = 1.0 / 1000.0
21 | 	Micro         = Milli / 1000.0
22 | 	Nano          = Micro / 1000.0
23 | 	Pico          = Nano / 1000.0
24 | 	Femto         = Pico / 1000.0
25 | )
26 |
27 | // Base-2 (IEC) prefixes, each with a long and a short name. Kibi is declared
28 | // as Measure and every other constant here derives from it, so all of them
29 | // are Measure values.
30 | const (
31 | 	Kibi Measure = 1024.0
32 | 	Mebi         = 1024.0 * Kibi
33 | 	Gibi         = 1024.0 * Mebi
34 | 	Tebi         = 1024.0 * Gibi
35 | 	Pebi         = 1024.0 * Tebi
36 | 	Exbi         = 1024.0 * Pebi
37 | 	Zebi         = 1024.0 * Exbi
38 | 	Yobi         = 1024.0 * Zebi
39 | 	Ki           = Kibi
40 | 	Mi           = Mebi
41 | 	Gi           = Gibi
42 | 	Ti           = Tebi
43 | 	Pi           = Pebi
44 | 	Ei           = Exbi
45 | 	Zi           = Zebi
46 | 	Yi           = Yobi
47 | )
48 |
49 | // Ratios. Percent (and its alias Pct) is a Measure. BasisPoint and BP are
50 | // untyped floating-point constants, because BasisPoint is declared from a
51 | // bare literal.
52 | const (
53 | 	Percent    Measure = 0.01
54 | 	BasisPoint         = 0.0001
55 | 	Pct                = Percent
56 | 	BP                 = BasisPoint
57 | )
58 |
--------------------------------------------------------------------------------
/apns2/dispatch_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | package apns2
4 |
5 | import (
6 | 	"testing"
7 |
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 |
11 | // TestMovingAcc checks that newMovingAcc rejects non-positive window sizes
12 | // and that accumulate keeps sum and pos consistent for windows of 1 and 2.
13 | func TestMovingAcc(t *testing.T) {
14 | 	var v uint64
15 | 	var s *movingAcc
16 | 	// -1 samples
17 | 	s = newMovingAcc(-1)
18 | 	assert.Nil(t, s)
19 | 	// 0 samples
20 | 	s = newMovingAcc(0)
21 | 	assert.Nil(t, s)
22 | 	// 1 sample: each new value replaces the previous one in the sum.
23 | 	s = newMovingAcc(1)
24 | 	assert.Equal(t, 1, len(s.samples))
25 | 	assert.Equal(t, uint64(0), s.sum)
26 | 	assert.Equal(t, 0, s.pos)
27 | 	v = s.accumulate(2)
28 | 	assert.Equal(t, uint64(2), s.sum)
29 | 	assert.Equal(t, 0, s.pos)
30 | 	assert.Equal(t, uint64(2), v)
31 | 	v = s.accumulate(4)
32 | 	assert.Equal(t, uint64(4), s.sum)
33 | 	assert.Equal(t, 0, s.pos)
34 | 	assert.Equal(t, uint64(4), v)
35 | 	// 2 samples: once the window is full, the oldest value drops out.
36 | 	s = newMovingAcc(2)
37 | 	assert.Equal(t, 2, len(s.samples))
38 | 	assert.Equal(t, uint64(0), s.sum)
39 | 	assert.Equal(t, 0, s.pos)
40 | 	v = s.accumulate(2)
41 | 	assert.Equal(t, uint64(2), s.sum)
42 | 	assert.Equal(t, 1, s.pos)
43 | 	assert.Equal(t, uint64(2), v)
44 | 	v = s.accumulate(4)
45 | 	assert.Equal(t, uint64(6), s.sum)
46 | 	assert.Equal(t, 0, s.pos)
47 | 	assert.Equal(t, uint64(6), v)
48 | 	v = s.accumulate(6)
49 | 	assert.Equal(t, uint64(10), s.sum)
50 | 	assert.Equal(t, 1, s.pos)
51 | 	assert.Equal(t, uint64(10), v)
52 | }
53 |
--------------------------------------------------------------------------------
/apns2/result.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | package apns2
4 |
5 | import (
6 | 	"context"
7 | )
8 |
9 | // Result represents the outcome of an asynchronous push operation.
10 | // The original notification is included along with any optional
11 | // arguments supplied to the push request.
12 | type Result struct {
13 |
14 | 	// Notification is the original notification.
15 | 	Notification *Notification
16 |
17 | 	// Signer is the one-off signer that was supplied in the push request.
18 | 	Signer RequestSigner
19 |
20 | 	// Context is the cancellation context instance passed to the original
21 | 	// push request.
22 | 	Context context.Context
23 |
24 | 	// Response represents a result from the APN service. If a push operation
25 | 	// fails prior to communicating with APN servers, Response will be nil and
26 | 	// Err field will have a non-nil value.
27 | 	Response *Response
28 |
29 | 	// Err, if not nil, is an error encountered while attempting a push.
30 | 	// Note that nil Err does not necessarily indicate a successful attempt.
31 | 	// You must also examine Response for additional status details.
32 | 	Err error
33 | }
34 |
35 | // IsAccepted reports whether the notification was accepted by APN service:
36 | // Err is nil, a Response is present, and its StatusCode is StatusAcccepted.
37 | func (r *Result) IsAccepted() bool {
38 | 	return r.Err == nil && r.Response != nil && r.Response.StatusCode == StatusAcccepted
39 | }
40 |
--------------------------------------------------------------------------------
/syncx/counter.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | package syncx
4 |
5 | import (
6 | 	"sync/atomic"
7 | )
8 |
9 | // Counter is a convenience uint64 value counter with atomic operations.
10 | // It is optimized for concurrent use by multiple incrementers,
11 | // but is restricted to a single concurrent consumer. Protect access
12 | // to the counter with a mutex if concurrent Draw attempts are anticipated.
13 | type Counter uint64
14 |
15 | // Add atomically adds the supplied value to the counter.
16 | //
17 | // This method is safe for use in concurrent goroutines.
18 | func (f *Counter) Add(v uint64) {
19 | 	atomic.AddUint64((*uint64)(f), v)
20 | }
21 |
22 | // Draw returns the counter's current value and subtracts exactly that amount,
23 | // so increments that race with Draw are kept for the next Draw, not lost.
24 | //
25 | // This method is not safe for use in concurrent goroutines.
It is however safe 26 | // for use concurrently with Add method. 27 | // If concurrent calls to Draw are anticipated they must be protected 28 | // by a mutex. 29 | func (f *Counter) Draw() uint64 { 30 | res := atomic.LoadUint64((*uint64)(f)) 31 | // It's possible for the count to have increased by this point, 32 | // but we are only subtracting the value previously read. 33 | // This is safe as long as we are not calling Draw concurrently from more 34 | // than one goroutine. 35 | atomic.AddUint64((*uint64)(f), ^(res - 1)) 36 | return res 37 | } 38 | -------------------------------------------------------------------------------- /apns2/backoff.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "math/rand" 7 | "time" 8 | 9 | "github.com/baobabus/go-apns/funit" 10 | ) 11 | 12 | type backOffTracker struct { 13 | initial time.Duration 14 | max time.Duration 15 | jitter funit.Measure 16 | current time.Duration 17 | end time.Time 18 | } 19 | 20 | func (t *backOffTracker) update(status error) { 21 | if status != nil { 22 | if now := time.Now(); now.After(t.end) { 23 | // Ignore any failures before end time as they may be coming 24 | // from a concurrent attempt. 25 | if t.current == 0 { 26 | t.current = t.initial 27 | } 28 | d := t.current 29 | if t.jitter > 0 { 30 | jtr := rand.Int63n(int64(funit.Measure(d) * t.jitter)) 31 | d += time.Duration(jtr) 32 | } 33 | if t.max > 0 && d > t.max { 34 | d = t.max 35 | } 36 | t.end = now.Add(d) 37 | t.current = t.current << 1 38 | if t.max > 0 && t.current > t.max { 39 | t.current = t.max 40 | } 41 | logTrace(1, "backoff", "backing off for %v until %v", d, t.end) 42 | } 43 | } else { 44 | if now := time.Now(); now.After(t.end) { 45 | // Ignore any success before end time as it may be coming 46 | // from a concurrent attempt. 
47 | t.current = t.initial 48 | logTrace(1, "backoff", "resetting to &v", t.current) 49 | } 50 | } 51 | } 52 | 53 | func (t *backOffTracker) blackoutEnd() time.Time { 54 | return t.end 55 | } 56 | -------------------------------------------------------------------------------- /apns2/request.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "context" 7 | ) 8 | 9 | // Request holds all necessary information needed to submit a notification 10 | // to APN service. Requests can be directly submitted to Client's Queue. 11 | type Request struct { 12 | 13 | // Notification is the notification to push to APN service 14 | Notification *Notification 15 | 16 | // Signer, if not nil, is used to sign the request before submitting it 17 | // to APN service. If Signer is nil, but client's signer was configured 18 | // at the initialization time, the client's signer will sign the request. 19 | Signer RequestSigner 20 | 21 | // Context carries a deadline and a cancellation signal and allows you 22 | // to close long running requests when the context timeout is exceeded. 23 | // Context can be nil, for backwards compatibility. 24 | Context context.Context 25 | 26 | // Callback, if not nil, specifies the channel to which the outcome of 27 | // the push execution should be delivered. If Callback is nil and client's 28 | // Callback was configured at the initialization time, the result 29 | // will be delivered to client's Callback. 30 | Callback chan<- *Result 31 | 32 | attemptCnt int 33 | } 34 | 35 | // HasSigner returns true if the request has a custom signer supplied or if 36 | // no signing should be performed for this request. 37 | func (r *Request) HasSigner() bool { 38 | return r.Signer != DefaultSigner 39 | } 40 | 41 | // RequestError indicates a request-level error. 
This helps distinguish
42 | // errors that are only scoped to a single request from those related to wider
43 | // scope, such as transport layer errors.
44 | type RequestError struct {
45 | 	error
46 | }
47 |
--------------------------------------------------------------------------------
/example/fire-and-forget/main.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | package main
4 |
5 | import (
6 | 	"log"
7 |
8 | 	"github.com/baobabus/go-apns/apns2"
9 | 	"github.com/baobabus/go-apns/cryptox"
10 | )
11 |
12 | // main loads a token signing key, starts a client against the production
13 | // gateway, pushes one alert to each sample recipient and then stops the
14 | // client, letting processing complete. Push results are not inspected.
15 | func main() {
16 |
17 | 	// Load and parse our token signing key
18 | 	signingKey, err := cryptox.PKCS8PrivateKeyFromFile("token_signing_pk.p8")
19 | 	if err != nil {
20 | 		log.Fatal("Token signing key error: ", err)
21 | 	}
22 |
23 | 	// Set up our client
24 | 	client := &apns2.Client{
25 | 		Gateway: apns2.Gateway.Production,
26 | 		Signer: &apns2.JWTSigner{
27 | 			KeyID:      "ABC123DEFG", // Your key ID
28 | 			TeamID:     "DEF123GHIJ", // Your team ID
29 | 			SigningKey: signingKey,
30 | 		},
31 | 		CommsCfg: apns2.CommsFast,
32 | 		ProcCfg:  apns2.UnlimitedProcConfig,
33 | 	}
34 |
35 | 	// Start processing
36 | 	err = client.Start(nil)
37 | 	if err != nil {
38 | 		log.Fatal("Client start error: ", err)
39 | 	}
40 |
41 | 	// Mock notification and recipients
42 | 	header := &apns2.Header{Topic: "com.example.Alert"}
43 | 	payload := &apns2.Payload{APS: &apns2.APS{Alert: "Ping!"}}
44 | 	recipients := []string{
45 | 		"00fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0",
46 | 		"10fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0",
47 | 		"20fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0",
48 | 	}
49 |
50 | 	// Push to all recipients
51 | 	for _, rcpt := range recipients {
52 | 		notif := &apns2.Notification{
53 | 			Recipient: rcpt,
54 | 			Header:    header,
55 | 			Payload:   payload,
56 | 		}
57 | 		err := client.Push(notif, apns2.DefaultSigner, apns2.NoContext, apns2.DefaultCallback)
58 | 		if err != nil {
59 | 			log.Fatal("Push error: ", err)
60 | 		}
61 | 	}
62 |
63 | 	// Perform soft shutdown allowing the processing to complete.
64 | 	client.Stop()
65 | }
66 |
--------------------------------------------------------------------------------
/apns2/client_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | package apns2
4 |
5 | import (
6 | 	"testing"
7 |
8 | 	"github.com/baobabus/go-apns/cryptox"
9 | 	"github.com/baobabus/go-apnsmock/apns2mock"
10 | 	"github.com/stretchr/testify/assert"
11 | )
12 |
13 | // mustNewClient_Signer_Good builds an unstarted Client that targets the mock
14 | // server s and signs requests with a JWTSigner keyed by testTokenKey_Good.
15 | // It fails t if the key cannot be parsed.
16 | func mustNewClient_Signer_Good(t tester, s *apns2mock.Server) *Client {
17 | 	//t.Helper()
18 | 	tsk, err := cryptox.PKCS8PrivateKeyFromBytes([]byte(testTokenKey_Good))
19 | 	if err != nil {
20 | 		t.Fatal(err)
21 | 	}
22 | 	res := &Client{
23 | 		Gateway: s.URL,
24 | 		RootCA:  s.RootCertificate,
25 | 		Signer: &JWTSigner{
26 | 			KeyID:      "ABC123DEFG",
27 | 			TeamID:     "DEF123GHIJ",
28 | 			SigningKey: tsk,
29 | 		},
30 | 		CommsCfg: commsTest_Fast,
31 | 		ProcCfg:  MinBlockingProcConfig,
32 | 		Callback: NoCallback,
33 | 	}
34 | 	return res
35 | }
36 |
37 | // TestClient_Signer_Good_1 pushes a good and a bad-device notification
38 | // through a started client. For each push it checks the status code,
39 | // rejection reason and error delivered on that push's callback channel.
40 | func TestClient_Signer_Good_1(t *testing.T) {
41 | 	s := mustNewMockServer(t)
42 | 	defer s.Close()
43 | 	c := mustNewClient_Signer_Good(t, s)
44 | 	err := c.Start(nil)
45 | 	if err != nil {
46 | 		t.Fatal(err)
47 | 	}
48 | 	defer c.Stop()
49 | 	tcs := []struct {
50 | 		ntf *Notification
51 | 		exp *Result
52 | 		cb  chan *Result
53 | 	}{
54 | 		{
55 | 			testNotif_Good,
56 | 			&Result{
57 | 				Response: &Response{
58 | 					StatusCode:      200,
59 | 					RejectionReason: "",
60 | 				},
61 | 				Err: nil,
62 | 			},
63 | 			make(chan *Result, 1),
64 | 		},
65 | 		{
66 | 			testNotif_BadDevice,
67 | 			&Result{
68 | 				Response: &Response{
69 | 					StatusCode:      400,
70 | 					RejectionReason: ReasonBadDeviceToken,
71 | 				},
72 | 				Err: nil,
73 | 			},
74 | 			make(chan *Result, 1),
75 | 		},
76 | 	}
77 | 	for _, tc := range tcs {
78 | 		err = c.Push(tc.ntf, DefaultSigner, NoContext, tc.cb)
79 | 		if err != nil {
80 | 			t.Fatal(err)
81 | 		}
82 | 		r := <-tc.cb
83 | 		// NOTE(review): the two guards below only stop the test when the
84 | 		// expectation is non-nil. With a nil expectation, a nil r or r.Response
85 | 		// would be dereferenced further down and panic. All current cases set both.
86 | 		if r == nil && tc.exp != nil {
87 | 			t.Fatal("Should have gotten a result")
88 | 		}
89 | 		if r.Response == nil && tc.exp.Response != nil {
90 | 			t.Fatal("Should have gotten a response")
91 | 		}
92 | 		assert.Equal(t, tc.exp.Response.StatusCode, r.Response.StatusCode)
93 | 		assert.Equal(t, tc.exp.Response.RejectionReason, r.Response.RejectionReason)
94 | 		if r.Err != nil && tc.exp.Err == nil {
95 | 			t.Fatal("Error in result:", r.Err)
96 | 		}
97 | 	}
98 | }
99 |
--------------------------------------------------------------------------------
/funit/size.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Aleksey Blinov. All rights reserved.
2 |
3 | package funit
4 |
5 | // Data size units, with Bit as the base unit.
6 | //
7 | // NOTE(review): only Bit is declared as Measure. Byte, Word, DWord and QWord
8 | // are given bare literals, so they are untyped floating-point constants
9 | // (8, 16, 32 and 64 bits respectively). Confirm this is intentional.
10 | const (
11 | 	Bit   Measure = 1.0
12 | 	Byte          = 8.0
13 | 	Word          = 16.0
14 | 	DWord         = 32.0
15 | 	QWord         = 64.0
16 | )
17 |
18 | // Decimal (SI) multiples of Bit, with long and short names.
19 | const (
20 | 	Kilobit  Measure = Kilo * Bit
21 | 	Megabit          = Mega * Bit
22 | 	Gigabit          = Giga * Bit
23 | 	Terabit          = Tera * Bit
24 | 	Petabit          = Peta * Bit
25 | 	Exabit           = Exa * Bit
26 | 	Zettabit         = Zetta * Bit
27 | 	Yottabit         = Yotta * Bit
28 | 	Kb               = Kilobit
29 | 	Mb               = Megabit
30 | 	Gb               = Gigabit
31 | 	Tb               = Terabit
32 | 	Pb               = Petabit
33 | 	Eb               = Exabit
34 | 	Zb               = Zettabit
35 | 	Yb               = Yottabit
36 | )
37 |
38 | // Binary (IEC) multiples of Bit, with long and short names.
39 | const (
40 | 	Kibibit Measure = Kibi * Bit
41 | 	Mebibit         = Mebi * Bit
42 | 	Gibibit         = Gibi * Bit
43 | 	Tebibit         = Tebi * Bit
44 | 	Pebibit         = Pebi * Bit
45 | 	Exbibit         = Exbi * Bit
46 | 	Zebibit         = Zebi * Bit
47 | 	Yobibit         = Yobi * Bit
48 | 	Kib             = Kibibit
49 | 	Mib             = Mebibit
50 | 	Gib             = Gibibit
51 | 	Tib             = Tebibit
52 | 	Pib             = Pebibit
53 | 	Eib             = Exbibit
54 | 	Zib             = Zebibit
55 | 	Yib             = Yobibit
56 | )
57 |
58 | // Decimal (SI) multiples of Byte, with long and short names.
59 | const (
60 | 	Kilobyte  Measure = Kilo * Byte
61 | 	Megabyte          = Mega * Byte
62 | 	Gigabyte          = Giga * Byte
63 | 	Terabyte          = Tera * Byte
64 | 	Petabyte          = Peta * Byte
65 | 	Exabyte           = Exa * Byte
66 | 	Zettabyte         = Zetta * Byte
67 | 	Yottabyte         = Yotta * Byte
68 | 	KB                = Kilobyte
69 | 	MB                = Megabyte
70 | 	GB                = Gigabyte
71 | 	TB                = Terabyte
72 | 	PB                = Petabyte
73 | 	EB                = Exabyte
74 | 	ZB                = Zettabyte
75 | 	YB                = Yottabyte
76 | )
77 |
78 | // Binary (IEC) multiples of Byte, with long and short names.
79 | const (
80 | 	Kibibyte Measure = Kibi * Byte
81 | 	Mebibyte         = Mebi * Byte
82 | 	Gibibyte         = Gibi * Byte
83 | 	Tebibyte         = Tebi * Byte
84 | 	Pebibyte         = Pebi * Byte
85 | 	Exbibyte         = Exbi * Byte
86 | 	Zebibyte         = Zebi * Byte
87 | 	Yobibyte         = Yobi * Byte
88 | 	KiB              = Kibibyte
89 | 	MiB              = Mebibyte
90 | 	GiB              = Gibibyte
91 | 	TiB              = Tebibyte
92 | 	PiB              = Pebibyte
93 | 	EiB              = Exbibyte
94 | 	ZiB              = Zebibyte
95 | 	YiB              = Yobibyte
96 | )
97 |
--------------------------------------------------------------------------------
/apns2/log.go:
--------------------------------------------------------------------------------
1 | package apns2
2 |
3 | import (
4 | 	"io"
5 | 	"log"
6 | 	"os"
7 | )
8 |
9 | // Logger interface is extracted from log.Logger to aid in configuring
10 | // custom loggers for use in this package.
11 | type Logger interface {
12 | 	Fatal(v ...interface{})
13 | 	Fatalf(format string, v ...interface{})
14 | 	Fatalln(v ...interface{})
15 | 	Flags() int
16 | 	Output(calldepth int, s string) error
17 | 	Panic(v ...interface{})
18 | 	Panicf(format string, v ...interface{})
19 | 	Panicln(v ...interface{})
20 | 	Prefix() string
21 | 	Print(v ...interface{})
22 | 	Printf(format string, v ...interface{})
23 | 	Println(v ...interface{})
24 | 	SetFlags(flag int)
25 | 	SetOutput(w io.Writer)
26 | 	SetPrefix(prefix string)
27 | }
28 |
29 | // Log is a runtime-wide logger used by this package. You are allowed to set it
30 | // to a different facility as needed.
31 | var Log Logger = log.New(os.Stderr, "apns2: ", log.LstdFlags)
32 |
33 | // Severity represents a log entry severity; larger values are more verbose.
34 | type Severity int
35 |
36 | const (
37 | 	LogError Severity = iota
38 | 	LogWarn
39 | 	LogNotice
40 | 	LogInfo
41 | )
42 |
43 | // LogLevel is the most verbose Severity that gets logged: entries whose
44 | // severity is greater than LogLevel are discarded.
45 | var LogLevel = LogNotice 46 | 47 | var severityStrs = map[Severity]string{ 48 | LogError: "ERROR ", 49 | LogWarn: "WARNING ", 50 | LogNotice: "NOTICE ", 51 | LogInfo: "INFO ", 52 | LogInfo + 1: "TRACE ", 53 | } 54 | 55 | // Bounds returns a severity value that is clamped between LogError and 56 | // LogInfo + 1, latter being indicative of trace level logging. 57 | func (t Severity) Bound() Severity { 58 | switch { 59 | case t < LogError: 60 | return LogError 61 | case t > LogInfo: 62 | return LogInfo + 1 63 | } 64 | return t 65 | } 66 | 67 | // String returns name associated with given Severity value. 68 | func (t Severity) String() string { 69 | return severityStrs[t.Bound()] 70 | } 71 | 72 | // LogTrace returns a Severity value corresponding to the spcified trace level. 73 | func LogTrace(traceLevel uint) Severity { 74 | return LogInfo + Severity(traceLevel+1) 75 | } 76 | 77 | func logWarn(id string, format string, v ...interface{}) { 78 | logTag(id, LogWarn, format, v...) 79 | } 80 | 81 | func logInfo(id string, format string, v ...interface{}) { 82 | logTag(id, LogInfo, format, v...) 83 | } 84 | 85 | func logTrace(level uint, id string, format string, v ...interface{}) { 86 | logTag(id, LogInfo+Severity(level+1), format, v...) 87 | } 88 | 89 | func logTag(id string, tag Severity, format string, v ...interface{}) { 90 | if tag > LogLevel { 91 | return 92 | } 93 | format = tag.String() + format 94 | if len(id) > 0 { 95 | format = id + ": " + format 96 | } 97 | Log.Printf(format, v...) 98 | } 99 | -------------------------------------------------------------------------------- /scale/scale.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package scale 4 | 5 | // Scale must be implemented by scale-up and wind-down calculators. 6 | // Three scale calculators come predefined: Incremental, Exponential 7 | // and Constant. 
type Scale interface {
	// IsValid reports whether the calculator's configuration is usable.
	IsValid() bool
	// Apply returns the scaled-up value for n.
	Apply(n uint32) uint32
	// ApplyInverse returns the wound-down value for n.
	ApplyInverse(n uint32) uint32
}

// constant is the type of the Constant scaler; it does not allow scaling.
type constant struct{}

// IsValid always returns true.
func (s constant) IsValid() bool {
	return true
}

// Apply returns supplied value unmodified.
func (s constant) Apply(n uint32) uint32 {
	return n
}

// ApplyInverse returns supplied value unmodified.
func (s constant) ApplyInverse(n uint32) uint32 {
	return n
}

// Constant scaler that does not allow scaling.
var Constant constant

// Incremental scaling mode specifies the number of new instances to be added
// during each scaling attempt. Must be 1 or greater.
type Incremental uint32

// IsValid reports whether s is at least 1.
func (s Incremental) IsValid() bool {
	return s >= 1
}

// Apply adds s to n and returns the sum. The sum saturates at the maximum
// uint32 value instead of wrapping around to a small number, which would
// turn a scale-up into a drastic scale-down.
func (s Incremental) Apply(n uint32) uint32 {
	const top = ^uint32(0)
	if uint32(s) > top-n {
		return top
	}
	return n + uint32(s)
}

// ApplyInverse subtracts s from n and returns the difference. If s is
// greater than n, 0 is returned instead.
func (s Incremental) ApplyInverse(n uint32) uint32 {
	if uint32(s) > n {
		return 0
	}
	return n - uint32(s)
}

// Exponential scaling mode specifies the factor by which the number of
// instances should be increased during each scaling attempt. Must be greater
// than 1.0.
type Exponential float32

// IsValid reports whether s is greater than 1.
func (s Exponential) IsValid() bool {
	return s > 1.0
}

// Apply scales the supplied value by its factor and returns the result.
// The result is guaranteed to be greater than the input by at least 1.
71 | func (s Exponential) Apply(n uint32) uint32 { 72 | res := uint32(float32(s) * float32(n)) 73 | // We must increase by at least 1. 74 | if res <= n { 75 | res = n + 1 76 | } 77 | return res 78 | } 79 | 80 | // Apply scales the supplied value by its inverse factor and returns the result. 81 | // The result is guaranteed to be 0 or to be less than the nonzero input 82 | // by at least 1. 83 | func (s Exponential) ApplyInverse(n uint32) uint32 { 84 | res := uint32(float32(n) / float32(s)) 85 | // We must decrease by at least 1, but not go below 0. 86 | if res >= n && n > 0 { 87 | res = n - 1 88 | } 89 | return res 90 | } 91 | -------------------------------------------------------------------------------- /scale/scale_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package scale 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestConstant(t *testing.T) { 12 | assert.True(t, Constant.IsValid()) 13 | assert.Exactly(t, uint32(0), Constant.Apply(0)) 14 | assert.Exactly(t, uint32(1), Constant.Apply(1)) 15 | assert.Exactly(t, uint32(2), Constant.Apply(2)) 16 | assert.Exactly(t, uint32(0), Constant.ApplyInverse(0)) 17 | assert.Exactly(t, uint32(1), Constant.ApplyInverse(1)) 18 | assert.Exactly(t, uint32(2), Constant.ApplyInverse(2)) 19 | } 20 | 21 | func TestIncremental(t *testing.T) { 22 | var s Incremental 23 | s = Incremental(0) 24 | assert.False(t, s.IsValid()) 25 | s = Incremental(1) 26 | assert.True(t, s.IsValid()) 27 | assert.Exactly(t, uint32(1), s.Apply(0)) 28 | assert.Exactly(t, uint32(2), s.Apply(1)) 29 | assert.Exactly(t, uint32(3), s.Apply(2)) 30 | assert.Exactly(t, uint32(0), s.ApplyInverse(0)) 31 | assert.Exactly(t, uint32(0), s.ApplyInverse(1)) 32 | assert.Exactly(t, uint32(1), s.ApplyInverse(2)) 33 | s = Incremental(10) 34 | assert.True(t, s.IsValid()) 35 | assert.Exactly(t, uint32(10), s.Apply(0)) 
36 | assert.Exactly(t, uint32(11), s.Apply(1)) 37 | assert.Exactly(t, uint32(12), s.Apply(2)) 38 | assert.Exactly(t, uint32(0), s.ApplyInverse(0)) 39 | assert.Exactly(t, uint32(0), s.ApplyInverse(9)) 40 | assert.Exactly(t, uint32(0), s.ApplyInverse(10)) 41 | assert.Exactly(t, uint32(1), s.ApplyInverse(11)) 42 | } 43 | 44 | func TestExponential(t *testing.T) { 45 | var s Exponential 46 | s = Exponential(0) 47 | assert.False(t, s.IsValid()) 48 | s = Exponential(1) 49 | assert.False(t, s.IsValid()) 50 | s = Exponential(2) 51 | assert.True(t, s.IsValid()) 52 | assert.Exactly(t, uint32(1), s.Apply(0)) 53 | assert.Exactly(t, uint32(2), s.Apply(1)) 54 | assert.Exactly(t, uint32(4), s.Apply(2)) 55 | assert.Exactly(t, uint32(0), s.ApplyInverse(0)) 56 | assert.Exactly(t, uint32(0), s.ApplyInverse(1)) 57 | assert.Exactly(t, uint32(1), s.ApplyInverse(2)) 58 | assert.Exactly(t, uint32(1), s.ApplyInverse(3)) 59 | assert.Exactly(t, uint32(2), s.ApplyInverse(4)) 60 | s = Exponential(1.25) 61 | assert.True(t, s.IsValid()) 62 | assert.Exactly(t, uint32(1), s.Apply(0)) 63 | assert.Exactly(t, uint32(2), s.Apply(1)) 64 | assert.Exactly(t, uint32(3), s.Apply(2)) 65 | assert.Exactly(t, uint32(12), s.Apply(10)) 66 | assert.Exactly(t, uint32(0), s.ApplyInverse(0)) 67 | assert.Exactly(t, uint32(0), s.ApplyInverse(1)) 68 | assert.Exactly(t, uint32(1), s.ApplyInverse(2)) 69 | assert.Exactly(t, uint32(2), s.ApplyInverse(3)) 70 | assert.Exactly(t, uint32(3), s.ApplyInverse(4)) 71 | assert.Exactly(t, uint32(4), s.ApplyInverse(5)) 72 | assert.Exactly(t, uint32(4), s.ApplyInverse(6)) 73 | assert.Exactly(t, uint32(9), s.ApplyInverse(12)) 74 | assert.Exactly(t, uint32(10), s.ApplyInverse(13)) 75 | } 76 | -------------------------------------------------------------------------------- /apns2/test_harness.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 
2 | 3 | package apns2 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/baobabus/go-apns/funit" 9 | "github.com/baobabus/go-apnsmock/apns2mock" 10 | ) 11 | 12 | var ( 13 | apnsMockComms_Typical = apns2mock.CommsCfg{ 14 | MaxConcurrentStreams: 500, 15 | MaxConns: 1000, 16 | ConnectionDelay: 1 * time.Second, 17 | ResponseTime: 20 * time.Millisecond, 18 | } 19 | apnsMockComms_30ms = apns2mock.CommsCfg{ 20 | MaxConcurrentStreams: 500, 21 | MaxConns: 1000, 22 | ConnectionDelay: 30 * time.Millisecond, 23 | ResponseTime: 30 * time.Millisecond, 24 | } 25 | apnsMockComms_NoDelay = apns2mock.CommsCfg{ 26 | MaxConcurrentStreams: 500, 27 | MaxConns: 1000, 28 | ConnectionDelay: 0, 29 | ResponseTime: 0, 30 | } 31 | commsTest_Fast = CommsCfg{ 32 | DialTimeout: 20 * time.Millisecond, 33 | MinDialBackOff: 100 * time.Millisecond, 34 | MaxDialBackOff: 500 * time.Millisecond, 35 | DialBackOffJitter: 10 * funit.Percent, 36 | RequestTimeout: 30 * time.Millisecond, 37 | KeepAlive: 100 * time.Millisecond, 38 | MaxConcurrentStreams: 500, 39 | } 40 | ) 41 | 42 | const testTokenKey_Good = ` 43 | -----BEGIN PRIVATE KEY----- 44 | MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgEbVzfPnZPxfAyxqE 45 | ZV05laAoJAl+/6Xt2O4mOB611sOhRANCAASgFTKjwJAAU95g++/vzKWHkzAVmNMI 46 | tB5vTjZOOIwnEb70MsWZFIyUFD1P9Gwstz4+akHX7vI8BH6hHmBmfeQl 47 | -----END PRIVATE KEY----- 48 | ` 49 | 50 | var ( 51 | testNotif_Good = &Notification{ 52 | Recipient: "00fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0", 53 | Header: &Header{Topic: "com.example.Alert"}, 54 | Payload: &Payload{APS: &APS{Alert: "Ping!"}}, 55 | } 56 | testNotif_BadDevice = &Notification{ 57 | Recipient: "10fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0", 58 | Header: &Header{Topic: "com.example.Alert"}, 59 | Payload: &Payload{APS: &APS{Alert: "Ping!"}}, 60 | } 61 | ) 62 | 63 | type tester interface { 64 | //Helper() 65 | Fatal(args ...interface{}) 66 | Fatalf(format string, args ...interface{}) 67 | } 68 | 69 | func 
mustNewMockServer(t tester) *apns2mock.Server { 70 | //t.Helper() 71 | res, err := apns2mock.NewServer( 72 | apnsMockComms_NoDelay, 73 | apns2mock.DefaultHandler, 74 | apns2mock.AutoCert, 75 | apns2mock.AutoKey, 76 | ) 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | return res 81 | } 82 | 83 | func mustNewMockServerWithCfg(t tester, cfg apns2mock.CommsCfg) *apns2mock.Server { 84 | //t.Helper() 85 | res, err := apns2mock.NewServer( 86 | cfg, 87 | apns2mock.AllOkayHandler, 88 | apns2mock.AutoCert, 89 | apns2mock.AutoKey, 90 | ) 91 | if err != nil { 92 | t.Fatal(err) 93 | } 94 | return res 95 | } 96 | 97 | func mustNewHTTPClient(t tester, s *apns2mock.Server) *HTTPClient { 98 | //t.Helper() 99 | res, err := NewHTTPClient(s.URL, CommsFast, nil, s.RootCertificate) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | return res 104 | } 105 | -------------------------------------------------------------------------------- /apns2/payload.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "encoding/json" 7 | "sync/atomic" 8 | ) 9 | 10 | // Payload is the container for the actual data to be delivered 11 | // to the notification recipient. 12 | // How a payload is utilized is not constrained, but the intent is 13 | // to never modify it once created and assigned. 14 | // The same payload can then be sent to any number of recipients. 
15 | type Payload struct { 16 | APS *APS 17 | Raw map[string]interface{} 18 | json atomic.Value 19 | } 20 | 21 | type APS struct { 22 | Alert interface{} 23 | Badge interface{} 24 | Category string 25 | ContentAvailable bool 26 | MutableContent bool 27 | Sound string 28 | ThreadID string 29 | URLArgs []string 30 | } 31 | 32 | type Alert struct { 33 | Action string `json:"action,omitempty"` 34 | ActionLocKey string `json:"action-loc-key,omitempty"` 35 | Body string `json:"body,omitempty"` 36 | LaunchImage string `json:"launch-image,omitempty"` 37 | LocArgs []string `json:"loc-args,omitempty"` 38 | LocKey string `json:"loc-key,omitempty"` 39 | Title string `json:"title,omitempty"` 40 | Subtitle string `json:"subtitle,omitempty"` 41 | TitleLocArgs []string `json:"title-loc-args,omitempty"` 42 | TitleLocKey string `json:"title-loc-key,omitempty"` 43 | } 44 | 45 | func (p *Payload) MarshalJSON() ([]byte, error) { 46 | res := p.json.Load() 47 | if res != nil { 48 | return res.([]byte), nil 49 | } 50 | // We could protect this with a Mutex, but for improved throughput 51 | // it is probably better to avoid resource contention here and just 52 | // duplicate the work in case we have concurrent calls. 53 | m := p.mergedMap() 54 | j, err := json.Marshal(m) 55 | if err != nil { 56 | return nil, err 57 | } 58 | p.json.Store(j) 59 | return j, nil 60 | } 61 | 62 | func (p *Payload) mergedMap() map[string]interface{} { 63 | if p.APS == nil { 64 | return p.Raw 65 | } 66 | res := make(map[string]interface{}) 67 | // 1. Shallow copy the original raw map 68 | for k, v := range p.Raw { 69 | res[k] = v 70 | } 71 | // 2. 
Overwrite APS fields 72 | if aps, ok := res["aps"]; !ok { 73 | res["aps"] = make(map[string]interface{}) 74 | } else if _, ok := aps.(map[string]interface{}); !ok { 75 | res["aps"] = make(map[string]interface{}) 76 | } 77 | p.APS.addToMap(res["aps"].(map[string]interface{})) 78 | return res 79 | } 80 | 81 | func (a APS) addToMap(m map[string]interface{}) { 82 | if a.Alert != nil { 83 | m["alert"] = a.Alert 84 | } 85 | if a.Badge != nil { 86 | m["badge"] = a.Badge 87 | } 88 | if a.Category != "" { 89 | m["category"] = a.Category 90 | } 91 | if a.ContentAvailable { 92 | m["content-available"] = 1 93 | } 94 | if a.MutableContent { 95 | m["mutable-content"] = 1 96 | } 97 | if a.Sound != "" { 98 | m["sound"] = a.Sound 99 | } 100 | if a.ThreadID != "" { 101 | m["thread-id"] = a.ThreadID 102 | } 103 | if len(a.URLArgs) > 0 { 104 | m["url-args"] = a.URLArgs 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /apns2/comm.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "crypto/tls" 7 | "net" 8 | "time" 9 | 10 | "github.com/baobabus/go-apns/funit" 11 | ) 12 | 13 | // CommsCfg is a set of parameters that govern communications with APN servers. 14 | // Two baseline configuration sets are predefined by CommsFast and CommsSlow 15 | // global variables. You may define your own sets as needed to address 16 | // any specific requirements of your particular setup. 17 | type CommsCfg struct { 18 | 19 | // DialTimeout is the maximum amount of time a dial will wait for a connect 20 | // to complete. 21 | DialTimeout time.Duration 22 | 23 | // MinDialBackOff is the minimum amount of time by which dial attempts 24 | // should be delayed after encountering a refused connection. 25 | // Actual back-off time will grow exponentially until a connection attempt 26 | // is successful. 
27 | MinDialBackOff time.Duration 28 | 29 | // MaxDialBackOff is the maximum amount of time by which dial attempts 30 | // should be delayed after encountering a refused connection. 31 | MaxDialBackOff time.Duration 32 | 33 | // DialBackOffJitter is used to calculate the ramdom amount to appy to each 34 | // back-off time calculation. 35 | DialBackOffJitter funit.Measure 36 | 37 | // RequestTimeout specifies a time limit for requests made by the 38 | // HTTPClient. The timeout includes connection time, any redirects, 39 | // and reading the response body. 40 | RequestTimeout time.Duration 41 | 42 | // KeepAlive specifies the keep-alive period for an active network 43 | // connection. If zero, keep-alives are not enabled. 44 | // Apple recommends not closing connections to APN service at all, 45 | // but a sinsibly long duration is acceptable. 46 | KeepAlive time.Duration 47 | 48 | // MaxConcurrentStreams is the maximum allowed number of concurrent streams 49 | // per HTTP/2 connection. If connection's MAX_CONCURRENT_STREAMS option 50 | // is invoked by the remote side with a lower value, the remote request 51 | // will be honored if possible. 52 | MaxConcurrentStreams uint32 53 | } 54 | 55 | // CommsFast is a baseline set of communication settings for situations where 56 | // long delays cannot be tolerated. 57 | var CommsFast = CommsCfg{ 58 | DialTimeout: 20 * time.Second, 59 | MinDialBackOff: 4 * time.Second, 60 | MaxDialBackOff: 10 * time.Minute, 61 | DialBackOffJitter: 10 * funit.Percent, 62 | RequestTimeout: 30 * time.Second, 63 | KeepAlive: 10 * time.Hour, 64 | MaxConcurrentStreams: 500, 65 | } 66 | 67 | // CommsSlow is a baseline set of communication settings accommodating 68 | // wider range of network performance and APN service responsiveness scenarios. 
69 | var CommsSlow = CommsCfg{ 70 | DialTimeout: 40 * time.Second, 71 | MinDialBackOff: 10 * time.Second, 72 | MaxDialBackOff: 10 * time.Minute, 73 | DialBackOffJitter: 10 * funit.Percent, 74 | RequestTimeout: 60 * time.Second, 75 | KeepAlive: 10 * time.Hour, 76 | MaxConcurrentStreams: 500, 77 | } 78 | 79 | // CommsDefault is the set of communication settings that is used when 80 | // you do not supply an explicit comms configuration where one is needed. 81 | var CommsDefault = CommsSlow 82 | 83 | func makeDialer(commsCfg CommsCfg) func(network, addr string, cfg *tls.Config) (net.Conn, error) { 84 | return func(network, addr string, cfg *tls.Config) (net.Conn, error) { 85 | dialer := &net.Dialer{ 86 | Timeout: commsCfg.DialTimeout, 87 | KeepAlive: commsCfg.KeepAlive, 88 | } 89 | return tls.DialWithDialer(dialer, network, addr, cfg) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /apns2/auth_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 
2 | 3 | package apns2 4 | 5 | import ( 6 | "net/http" 7 | "regexp" 8 | "testing" 9 | "time" 10 | 11 | "github.com/baobabus/go-apns/cryptox" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | var ( 16 | auth_test_jwtAsHeader = regexp.MustCompile("bearer [a-zA-Z0-9\\-_]+\\.[a-zA-Z0-9\\-_]+\\.[a-zA-Z0-9\\-_]+") 17 | ) 18 | 19 | func TestJWTSignerDefaults(t *testing.T) { 20 | signingKey, err := cryptox.PKCS8PrivateKeyFromFile("../cryptox/test_data/pk_valid.p8") 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | s := &JWTSigner{ 25 | KeyID: "ABC123DEFG", 26 | TeamID: "DEF123GHIJ", 27 | SigningKey: signingKey, 28 | } 29 | now := time.Now() 30 | tk, err := s.GetToken() 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | assert.Exactly(t, DefaultJWTSigningMethod, s.signingMethod) 35 | assert.Exactly(t, DefaultTokenLifeSpan, s.tokenLifeSpan) 36 | assert.True(t, tk.IssuedAt.Unix()-now.Unix() < 1) 37 | assert.Exactly(t, tk.IssuedAt.Add(DefaultTokenLifeSpan).Unix(), tk.ExpiresAt.Unix()) 38 | assert.True(t, auth_test_jwtAsHeader.MatchString(tk.AsHeader)) 39 | } 40 | 41 | func TestJWTSignerCustom(t *testing.T) { 42 | signingKey, err := cryptox.PKCS8PrivateKeyFromFile("../cryptox/test_data/pk_valid.p8") 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | lifespan := time.Minute 47 | s := &JWTSigner{ 48 | KeyID: "ABC123DEFG", 49 | TeamID: "DEF123GHIJ", 50 | SigningKey: signingKey, 51 | TokenLifeSpan: lifespan, 52 | } 53 | now := time.Now() 54 | tk, err := s.GetToken() 55 | if err != nil { 56 | t.Fatal(err) 57 | } 58 | assert.Exactly(t, DefaultJWTSigningMethod, s.signingMethod) 59 | assert.Exactly(t, lifespan, s.tokenLifeSpan) 60 | assert.True(t, tk.IssuedAt.Unix()-now.Unix() < 1) 61 | assert.Exactly(t, tk.IssuedAt.Add(lifespan).Unix(), tk.ExpiresAt.Unix()) 62 | assert.True(t, auth_test_jwtAsHeader.MatchString(tk.AsHeader)) 63 | } 64 | 65 | func TestJWTSignerRefresh(t *testing.T) { 66 | signingKey, err := cryptox.PKCS8PrivateKeyFromFile("../cryptox/test_data/pk_valid.p8") 67 
| if err != nil { 68 | t.Fatal(err) 69 | } 70 | lifespan := 750 * time.Microsecond 71 | s := &JWTSigner{ 72 | KeyID: "ABC123DEFG", 73 | TeamID: "DEF123GHIJ", 74 | SigningKey: signingKey, 75 | TokenLifeSpan: lifespan, 76 | } 77 | tk1, err := s.GetToken() 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | tk2, err := s.GetToken() 82 | if err != nil { 83 | t.Fatal(err) 84 | } 85 | assert.Equal(t, tk1, tk2) 86 | time.Sleep(lifespan) 87 | tk3, err := s.GetToken() 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | assert.NotEqual(t, tk1, tk3) 92 | assert.True(t, tk1.IssuedAt.Before(tk3.IssuedAt)) 93 | assert.True(t, tk1.ExpiresAt.Before(tk3.IssuedAt)) 94 | } 95 | 96 | func TestJWTSignerSignRequest(t *testing.T) { 97 | signingKey, err := cryptox.PKCS8PrivateKeyFromFile("../cryptox/test_data/pk_valid.p8") 98 | if err != nil { 99 | t.Fatal(err) 100 | } 101 | s := &JWTSigner{ 102 | KeyID: "ABC123DEFG", 103 | TeamID: "DEF123GHIJ", 104 | SigningKey: signingKey, 105 | } 106 | req, err := http.NewRequest("POST", "", nil) 107 | if err != nil { 108 | t.Fatal(err) 109 | } 110 | assert.Equal(t, 0, len(req.Header)) 111 | err = s.SignRequest(req) 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | assert.Equal(t, 1, len(req.Header)) 116 | assert.Equal(t, 1, len(req.Header["Authorization"])) 117 | h := req.Header.Get("Authorization") 118 | assert.True(t, auth_test_jwtAsHeader.MatchString(h)) 119 | } 120 | 121 | func TestNoSignerSignRequest(t *testing.T) { 122 | s := NoSigner 123 | req, err := http.NewRequest("POST", "", nil) 124 | if err != nil { 125 | t.Fatal(err) 126 | } 127 | assert.Equal(t, 0, len(req.Header)) 128 | err = s.SignRequest(req) 129 | if err != nil { 130 | t.Fatal(err) 131 | } 132 | assert.Equal(t, 0, len(req.Header)) 133 | } 134 | -------------------------------------------------------------------------------- /syncx/ticktock.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. 
All rights reserved. 2 | 3 | package syncx 4 | 5 | import ( 6 | "sync/atomic" 7 | ) 8 | 9 | // TickTockCounter counts the number of Tick and Tock calls that it receives. 10 | // Balancing ticks and tocks can then be "folded", such that both counts are 11 | // reduced by the number of "tocks". I.e. 4 calls of Tick and 3 calls of Tock 12 | // would result in the counts to be reduced to 1 and 0 respectively. 13 | // 14 | // TickTockCounter places following constraints on its use: 15 | // 16 | // 1. A call to Tick must be guaranteed to have completed before calling its 17 | // balancing Tock. Typically such balancing calls would be made serially 18 | // from a single goroutine. 19 | // 20 | // 2. Concurrent calls to Fold are not allowed. Folding would typically be done 21 | // in a single goroutine in a serial manner. 22 | // 23 | // TickTockCounter is optimized for use with multiple concurrent "tikers" and 24 | // a single folder. If concurrent calls to Fold are anticipated, they must be 25 | // guarded by a mutex. 26 | type TickTockCounter uint64 27 | 28 | // Tick atomically increments "tick" counter. It is safe for use in concurrent 29 | // gorotines as long as the caller ensures that corresponding Tock call is only 30 | // made after Tick call. 31 | func (c *TickTockCounter) Tick() { 32 | atomic.AddUint64((*uint64)(c), 1<<32) 33 | } 34 | 35 | // Tick atomically increments "tock" counter. It is safe for use in concurrent 36 | // gorotines as long as the caller ensures that corresponding Tick call has 37 | // already been made. 38 | func (c *TickTockCounter) Tock() { 39 | atomic.AddUint64((*uint64)(c), 1) 40 | } 41 | 42 | // Fold collapses balancing "ticks" and "tocks" by reducing the counts by the 43 | // number of "tocks" and returns pre-folded counts. Folding is done atomically, 44 | // such that concurrent calls to Tick and Tock do not result in an imbalance 45 | // as well as ensuring that no "ticks" or "tocks" are dropped or double-counted. 
46 | // 47 | // For performance reasons this method is not safe for use in concurrent 48 | // gorotines. It is however safe for use concurrently with Tick and Tock calls. 49 | // If concurrent calls to Draw are anticipated they must be protected 50 | // by a mutex. 51 | func (c *TickTockCounter) Fold() (ticks uint32, tocks uint32) { 52 | cntr := atomic.LoadUint64((*uint64)(c)) 53 | tocks = uint32(cntr) 54 | ticks = uint32(cntr >> 32) 55 | // It is possible for the counts to have increased since the load call 56 | // as Tick and Tock are called from concurrent goroutines. 57 | // Atomically subtracting previously read tock count from both counters 58 | // is still safe as the counts would have never decreased (as long as Fold 59 | // is not called concurrently for another goroutine). 60 | // We may end up with a non-zero tock count at the end of the subtraction, 61 | // but it is not wrong. These "extra" counts will be picked up by the 62 | // subsequent call to Fold. No counts are dropped or double-counted. 63 | atomic.AddUint64((*uint64)(c), ^((uint64(tocks) << 32) + uint64(tocks) - 1)) 64 | return 65 | } 66 | 67 | // TickTockFolder counts the number of balanced Tick/Tock calls. 68 | // Conceptually it comprises two counters, one for the number of complete 69 | // pairs and one the number of pending ones, with each being a uint32. 70 | // 71 | // TickTockFolder is optimized for concurrent use by multiple "tickers", 72 | // but is restricted to a single concurrent consumer. Protect access 73 | // to the counter with a mutex if concurrent Draw attempts are anticipated. 74 | type TickTockFolder uint64 75 | 76 | // Tick atomically increments pending cycle counter. 77 | // 78 | // This method is safe for use in concurrent gorotines. 79 | func (f *TickTockFolder) Tick() { 80 | atomic.AddUint64((*uint64)(f), 1) 81 | } 82 | 83 | // Tock atomically decrements pending cycle counter and increments 84 | // complete cycle counter. 
85 | // 86 | // This method is safe for use in concurrent gorotines. 87 | func (f *TickTockFolder) Tock() { 88 | atomic.AddUint64((*uint64)(f), uint64(^uint32(0))) 89 | } 90 | 91 | // Draw atomically draws complete cycle counter. The counter's complete cycles 92 | // count is set to 0 and the previous value is returned. Pending cycle count 93 | // is also returned. 94 | // 95 | // This method is not safe for use in concurrent gorotines. It is however safe 96 | // for use concurrently with Tick and Tock methods. 97 | // If concurrent calls to Draw are anticipated they must be protected 98 | // by a mutex. 99 | func (f *TickTockFolder) Draw() (complete uint32, pending uint32) { 100 | cntr := atomic.LoadUint64((*uint64)(f)) 101 | pending = uint32(cntr) 102 | cntr = cntr >> 32 103 | complete = uint32(cntr) 104 | // It's possible for the complete count to have increased by this point, 105 | // but we are only subtracting the value previously read. 106 | // This is safe as long as we are not calling Draw concurrently from more 107 | // than one goroutine. 108 | atomic.AddUint64((*uint64)(f), ^((cntr << 32) - 1)) 109 | return 110 | } 111 | -------------------------------------------------------------------------------- /apns2/auth.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "crypto/ecdsa" 7 | "fmt" 8 | "net/http" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | 13 | jwt "github.com/dgrijalva/jwt-go" 14 | ) 15 | 16 | // RequestSigner must be implemented by all APN service request signers. 17 | // Provider token signing allows authenticating with APN service on per request 18 | // basis, if needed. 19 | type RequestSigner interface { 20 | 21 | // SignRequest gives the signer a chance to sign the request. 22 | // Any headers and the request body is guaranteed to have been 23 | // set up at this point. 
24 | SignRequest(r *http.Request) error 25 | } 26 | 27 | // DefaultTokenLifeSpan specifies the time duration for which 28 | // provier tokens are considered to be valid. At present APN service 29 | // stops honoring authentication tokens that are older than 1 hour. 30 | // Initial global default value allows 10 minutes of safety margin. 31 | // If changed, any provider token authenticators created thereafter 32 | // will use the new value. 33 | var DefaultTokenLifeSpan = 50 * time.Minute 34 | 35 | // DefaultJWTSigningMethod method for APN requests is ES256. 36 | var DefaultJWTSigningMethod = jwt.SigningMethodES256 37 | 38 | // Provider token-based signer that uses JSON Web Tokens to sign individual 39 | // requests to APN service. It is safe to use in concurrent goroutines. 40 | type JWTSigner struct { 41 | // A 10-character key identifier, obtained from Apple developer account. 42 | KeyID string 43 | 44 | // A 10-character Team ID, obtained from Apple developer account. 45 | TeamID string 46 | 47 | // Private key for signing generated tokens. 48 | SigningKey *ecdsa.PrivateKey 49 | 50 | // Method to use for signing generated tokens. 51 | SigningMethod *jwt.SigningMethodECDSA 52 | 53 | // The duration for which generated tokens are considered valid by apns2. 54 | // This is currently required to not exceed one hour. 55 | TokenLifeSpan time.Duration 56 | 57 | mu sync.Mutex 58 | // Last generated token. This should not be accessed directly. 59 | // Use GetToken() method, which may generated a new token 60 | // before returing it if needed. 61 | currentToken atomic.Value 62 | // SigningMethod or, if nil, DefaultSigningMethod 63 | signingMethod *jwt.SigningMethodECDSA 64 | tokenLifeSpan time.Duration 65 | } 66 | 67 | // JWT is an implementation of provider token in the form of 68 | // Javascript Web Token, that can be written to HTTP authorization header. 69 | // It is intended to remain immutable once created, and is safe to use 70 | // in concurrent goroutines. 
71 | type JWT struct { 72 | IssuedAt time.Time 73 | ExpiresAt time.Time 74 | JwtToken *jwt.Token 75 | AsHeader string 76 | } 77 | 78 | // SignRequest adds Authorization header to the supplied request. 79 | // The header is an encrypted JSON Web Token containing signer's credentials. 80 | // The token is guaranteed to be valid at the time of the call. 81 | func (s *JWTSigner) SignRequest(r *http.Request) error { 82 | t, err := s.GetToken() 83 | if err != nil { 84 | return err 85 | } 86 | r.Header.Set("Authorization", t.AsHeader) 87 | return nil 88 | } 89 | 90 | // GetToken returns provider authentication token that is guaranteed 91 | // to be valid at the time of the call. 92 | func (s *JWTSigner) GetToken() (*JWT, error) { 93 | now := time.Now() 94 | // This is very heavy on read and atomics are said to be much faster 95 | // than RWMutex. Not that it is important in this case, though. 96 | res := s.currentToken.Load() 97 | if res != nil && res.(*JWT).ExpiresAt.After(now) { 98 | return res.(*JWT), nil 99 | } 100 | // We could safely forgo a mutex here and generate more than one 101 | // new token concurrently, let them all get used and then overwritten, 102 | // but lets do it cleanly and not annoy APN servers. 103 | s.mu.Lock() 104 | defer s.mu.Unlock() 105 | // Check again in case someone else got here first. 
106 | res = s.currentToken.Load() 107 | if res != nil && res.(*JWT).ExpiresAt.After(now) { 108 | return res.(*JWT), nil 109 | } 110 | if s.signingMethod == nil { 111 | if s.SigningMethod == nil { 112 | s.signingMethod = DefaultJWTSigningMethod 113 | } else { 114 | s.signingMethod = s.SigningMethod 115 | } 116 | } 117 | if s.tokenLifeSpan == 0 { 118 | if s.TokenLifeSpan > 0 { 119 | s.tokenLifeSpan = s.TokenLifeSpan 120 | } else { 121 | s.tokenLifeSpan = DefaultTokenLifeSpan 122 | } 123 | } 124 | t := &jwt.Token{ 125 | Header: map[string]interface{}{ 126 | "alg": s.signingMethod.Name, 127 | "kid": s.KeyID, 128 | }, 129 | Claims: jwt.MapClaims{ 130 | "iss": s.TeamID, 131 | "iat": now.Unix(), 132 | }, 133 | Method: s.signingMethod, 134 | } 135 | ss, err := t.SignedString(s.SigningKey) 136 | if err != nil { 137 | return nil, err 138 | } 139 | tkn := &JWT{ 140 | IssuedAt: now, 141 | ExpiresAt: now.Add(s.tokenLifeSpan), 142 | JwtToken: t, 143 | AsHeader: fmt.Sprintf("bearer %v", ss), 144 | } 145 | s.currentToken.Store(tkn) 146 | return tkn, nil 147 | } 148 | 149 | type noSigner struct{} 150 | 151 | func (s noSigner) SignRequest(r *http.Request) error { 152 | return nil 153 | } 154 | -------------------------------------------------------------------------------- /http2x/reflect.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package http2x 4 | 5 | import ( 6 | "errors" 7 | "net/http" 8 | "reflect" 9 | "sync" 10 | "unsafe" 11 | 12 | "golang.org/x/net/http2" 13 | ) 14 | 15 | var ( 16 | ErrIncompatibleHTTP2Layer = errors.New("http2x: incompatible http2 client library") 17 | ErrUnsupportedTransport = errors.New("http2x: unsupported transport layer") 18 | ) 19 | 20 | // GetMaxConcurrentStreams returns the value of maxConcurrentStreams 21 | // private field of c using reflection. It properly guards its read 22 | // with c's mutex. 23 | // 24 | // If c is nil, 1 is returned. 
If c is closed, 0 is returned. 25 | // Otherwise, if maxConcurrentStreams cannot be determined 26 | // due to http2.ClientConn incompatibility, maximum uint32 value is returned. 27 | func GetMaxConcurrentStreams(c *http2.ClientConn) uint32 { 28 | if c == nil { 29 | return 1 30 | } 31 | rc := reflect.Indirect(reflect.ValueOf(c)) 32 | if !c.CanTakeNewRequest() { 33 | return 0 34 | } 35 | if !http2Compat { 36 | return ^uint32(0) 37 | } 38 | // This is all of the data we are currently interested in. 39 | // If the need for other fields arises, we should ensure 40 | // to retrive them all together while holding the c's Mutex. 41 | mu := (*sync.Mutex)(ptrToFieldValue(rc, clientConn.mu)) 42 | mu.Lock() 43 | defer mu.Unlock() 44 | // This may be better than c.CanTakeNewRequest() as it is being guarded. 45 | // closed := (*bool)(ptrToFieldValue(rc, clientConn.closed)) 46 | // goAway := (**http2.GoAwayFrame)(ptrToFieldValue(rc, clientConn.goAway)) 47 | // if *closed || *goAway != nil { 48 | // return 0 49 | // } 50 | res := (*uint32)(ptrToFieldValue(rc, clientConn.maxConcurrentStreams)) 51 | return *res 52 | } 53 | 54 | var dummyReq http.Request 55 | 56 | // GetClientConnPool returns http2.Transport t's ClientConnPool. If t is not a 57 | // *http2.Transport, ErrUnsupportedTransport error is returned. 58 | // 59 | // GetClientConnPool must be used with extreme caution. It relies on the side 60 | // effect of http2.Transport.CloseIdleConnections to ensure t's connection pool 61 | // is initialized before trying to access it. The only safe time for this call 62 | // is before t has had a chance to open its first connection. 
63 | func GetClientConnPool(t http.RoundTripper) (http2.ClientConnPool, error) { 64 | if !http2Compat { 65 | return nil, ErrIncompatibleHTTP2Layer 66 | } 67 | t2, ok := t.(*http2.Transport) 68 | if !ok || t2 == nil { 69 | return nil, ErrUnsupportedTransport 70 | } 71 | // Hack 72 | // This has a side effect of the transport initializing its connection pool 73 | t2.CloseIdleConnections() 74 | rt := reflect.Indirect(reflect.ValueOf(t2)) 75 | res := (*(*http2.ClientConnPool)(ptrToFieldValue(rt, transport.connPoolOrDef))) 76 | return res, nil 77 | } 78 | 79 | // GetClientConn returns http2.ClientConn from the pool that can be used to 80 | // communicate with the endpoint specified by the addr. Note that the pool 81 | // may initiate a dial operation at this time if no active connection 82 | // to addr exists. 83 | func GetClientConn(pool http2.ClientConnPool, addr string) (*http2.ClientConn, error) { 84 | return pool.GetClientConn(&dummyReq, addr) 85 | } 86 | 87 | func ptrToFieldValue(v reflect.Value, fieldIndex []int) unsafe.Pointer { 88 | return unsafe.Pointer(v.FieldByIndex(fieldIndex).UnsafeAddr()) 89 | } 90 | 91 | // True if it is confirmed that http2.ClientConn structure is 92 | // as expected and can be used with our reflection code. 
93 | var http2Compat = true 94 | 95 | var clientConn struct { 96 | mu []int 97 | maxConcurrentStreams []int 98 | closed []int 99 | goAway []int 100 | } 101 | 102 | var transport struct { 103 | connPoolOrDef []int 104 | } 105 | 106 | func init() { 107 | // Validate http2.ClientConn structure 108 | c := reflect.TypeOf(&http2.ClientConn{}).Elem() 109 | if f, ok := c.FieldByName("mu"); ok { 110 | if f.Type.AssignableTo(reflect.TypeOf(sync.Mutex{})) { 111 | clientConn.mu = f.Index 112 | } else { 113 | http2Compat = false 114 | } 115 | } else { 116 | http2Compat = false 117 | } 118 | if f, ok := c.FieldByName("maxConcurrentStreams"); ok { 119 | if f.Type.Kind() == reflect.Uint32 { 120 | clientConn.maxConcurrentStreams = f.Index 121 | } else { 122 | http2Compat = false 123 | } 124 | } else { 125 | http2Compat = false 126 | } 127 | if f, ok := c.FieldByName("closed"); ok { 128 | if f.Type.Kind() == reflect.Bool { 129 | clientConn.closed = f.Index 130 | } else { 131 | http2Compat = false 132 | } 133 | } else { 134 | http2Compat = false 135 | } 136 | if f, ok := c.FieldByName("goAway"); ok { 137 | if f.Type.AssignableTo(reflect.TypeOf(&http2.GoAwayFrame{})) { 138 | clientConn.goAway = f.Index 139 | } else { 140 | http2Compat = false 141 | } 142 | } else { 143 | http2Compat = false 144 | } 145 | // Validate http2.Transport structure 146 | t := reflect.TypeOf(&http2.Transport{}).Elem() 147 | if f, ok := t.FieldByName("connPoolOrDef"); ok { 148 | if f.Type.AssignableTo(reflect.TypeOf((*http2.ClientConnPool)(nil)).Elem()) { 149 | transport.connPoolOrDef = f.Index 150 | } else { 151 | http2Compat = false 152 | } 153 | } else { 154 | http2Compat = false 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /apns2/backoff_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 
2 | 3 | package apns2 4 | 5 | import ( 6 | "errors" 7 | "testing" 8 | "time" 9 | 10 | "github.com/baobabus/go-apns/funit" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var backOffTesterErr = errors.New("") 15 | 16 | const backOffTesterTimeDelta float64 = 10000000 // 10 millisecond 17 | 18 | func TestZeroBackOffTracker(t *testing.T) { 19 | // Failure first 20 | s := backOffTracker{} 21 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 22 | s.update(backOffTesterErr) 23 | assert.InDelta(t, time.Now().UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 24 | s.update(backOffTesterErr) 25 | assert.InDelta(t, time.Now().UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 26 | s.update(nil) 27 | assert.InDelta(t, time.Now().UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 28 | s.update(nil) 29 | assert.InDelta(t, time.Now().UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 30 | // Success first 31 | s = backOffTracker{} 32 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 33 | s.update(nil) 34 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 35 | s.update(nil) 36 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 37 | s.update(backOffTesterErr) 38 | assert.InDelta(t, time.Now().UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 39 | s.update(backOffTesterErr) 40 | assert.InDelta(t, time.Now().UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 41 | } 42 | 43 | func TestNoJitterBackOffTracker(t *testing.T) { 44 | // Failure first 45 | d := time.Millisecond 46 | s := backOffTracker{initial: d, jitter: 0 * funit.Percent} 47 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 48 | s.update(backOffTesterErr) 49 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 50 | s.update(backOffTesterErr) 51 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 52 | time.Sleep(d) 53 | 
s.update(backOffTesterErr) 54 | d = d << 1 55 | last := time.Now().Add(d).UnixNano() 56 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 57 | time.Sleep(d) 58 | s.update(nil) 59 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 60 | s.update(backOffTesterErr) 61 | d = d >> 1 62 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 63 | // Success first 64 | d = time.Millisecond 65 | s = backOffTracker{initial: d, jitter: 0 * funit.Percent} 66 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 67 | s.update(nil) 68 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 69 | s.update(backOffTesterErr) 70 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 71 | s.update(backOffTesterErr) 72 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 73 | time.Sleep(d) 74 | s.update(backOffTesterErr) 75 | d = d << 1 76 | last = time.Now().Add(d).UnixNano() 77 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 78 | time.Sleep(d) 79 | s.update(nil) 80 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 81 | s.update(backOffTesterErr) 82 | d = d >> 1 83 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 84 | } 85 | 86 | func TestNoJitterCappedBackOffTracker(t *testing.T) { 87 | // Failure first 88 | d := time.Millisecond 89 | max := 3 * time.Millisecond 90 | s := backOffTracker{initial: d, max: max, jitter: 0 * funit.Percent} 91 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 92 | s.update(backOffTesterErr) 93 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 94 | s.update(backOffTesterErr) 95 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 96 | time.Sleep(d) 97 | 
s.update(backOffTesterErr) 98 | d = max 99 | last := time.Now().Add(d).UnixNano() 100 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 101 | time.Sleep(d) 102 | s.update(nil) 103 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 104 | s.update(backOffTesterErr) 105 | d = time.Millisecond 106 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 107 | // Success first 108 | d = time.Millisecond 109 | s = backOffTracker{initial: d, jitter: 0 * funit.Percent} 110 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 111 | s.update(nil) 112 | assert.Exactly(t, time.Time{}, s.blackoutEnd()) 113 | s.update(backOffTesterErr) 114 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 115 | s.update(backOffTesterErr) 116 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 117 | time.Sleep(d) 118 | s.update(backOffTesterErr) 119 | d = max 120 | last = time.Now().Add(d).UnixNano() 121 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 122 | time.Sleep(d) 123 | s.update(nil) 124 | assert.InDelta(t, last, s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 125 | s.update(backOffTesterErr) 126 | d = time.Millisecond 127 | assert.InDelta(t, time.Now().Add(d).UnixNano(), s.blackoutEnd().UnixNano(), backOffTesterTimeDelta) 128 | } 129 | -------------------------------------------------------------------------------- /apns2/response.go: -------------------------------------------------------------------------------- 1 | // The MIT License (MIT) 2 | // 3 | // Copyright (c) 2016 Adam Jones 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | // 23 | // Modifications copyright 2017 Aleksey Blinov. All rights reserved. 24 | 25 | package apns2 26 | 27 | import ( 28 | "net/http" 29 | "strconv" 30 | "time" 31 | ) 32 | 33 | // StatusAcccepted is a 200 response. 34 | const StatusAcccepted = http.StatusOK 35 | 36 | // The possible Reason error codes returned from apns2. 37 | // From table 8-6 in the Apple Local and Remote Notification Programming Guide. 38 | const ( 39 | // 400 The collapse identifier exceeds the maximum allowed size 40 | ReasonBadCollapseID = "BadCollapseId" 41 | 42 | // 400 The specified device token was bad. Verify that the request contains a 43 | // valid token and that the token matches the environment. 44 | ReasonBadDeviceToken = "BadDeviceToken" 45 | 46 | // 400 The apns-expiration value is bad. 47 | ReasonBadExpirationDate = "BadExpirationDate" 48 | 49 | // 400 The apns-id value is bad. 50 | ReasonBadMessageID = "BadMessageId" 51 | 52 | // 400 The apns-priority value is bad. 53 | ReasonBadPriority = "BadPriority" 54 | 55 | // 400 The apns-topic was invalid. 
56 | ReasonBadTopic = "BadTopic" 57 | 58 | // 400 The device token does not match the specified topic. 59 | ReasonDeviceTokenNotForTopic = "DeviceTokenNotForTopic" 60 | 61 | // 400 One or more headers were repeated. 62 | ReasonDuplicateHeaders = "DuplicateHeaders" 63 | 64 | // 400 Idle time out. 65 | ReasonIdleTimeout = "IdleTimeout" 66 | 67 | // 400 The device token is not specified in the request :path. Verify that the 68 | // :path header contains the device token. 69 | ReasonMissingDeviceToken = "MissingDeviceToken" 70 | 71 | // 400 The apns-topic header of the request was not specified and was 72 | // required. The apns-topic header is mandatory when the client is connected 73 | // using a certificate that supports multiple topics. 74 | ReasonMissingTopic = "MissingTopic" 75 | 76 | // 400 The message payload was empty. 77 | ReasonPayloadEmpty = "PayloadEmpty" 78 | 79 | // 400 Pushing to this topic is not allowed. 80 | ReasonTopicDisallowed = "TopicDisallowed" 81 | 82 | // 403 The certificate was bad. 83 | ReasonBadCertificate = "BadCertificate" 84 | 85 | // 403 The client certificate was for the wrong environment. 86 | ReasonBadCertificateEnvironment = "BadCertificateEnvironment" 87 | 88 | // 403 The provider token is stale and a new token should be generated. 89 | ReasonExpiredProviderToken = "ExpiredProviderToken" 90 | 91 | // 403 The specified action is not allowed. 92 | ReasonForbidden = "Forbidden" 93 | 94 | // 403 The provider token is not valid or the token signature could not be 95 | // verified. 96 | ReasonInvalidProviderToken = "InvalidProviderToken" 97 | 98 | // 403 No provider certificate was used to connect to APNs and Authorization 99 | // header was missing or no provider token was specified. 100 | ReasonMissingProviderToken = "MissingProviderToken" 101 | 102 | // 404 The request contained a bad :path value. 103 | ReasonBadPath = "BadPath" 104 | 105 | // 405 The specified :method was not POST. 
106 | ReasonMethodNotAllowed = "MethodNotAllowed" 107 | 108 | // 410 The device token is inactive for the specified topic. 109 | ReasonUnregistered = "Unregistered" 110 | 111 | // 413 The message payload was too large. See Creating the Remote Notification 112 | // Payload in the Apple Local and Remote Notification Programming Guide for 113 | // details on maximum payload size. 114 | ReasonPayloadTooLarge = "PayloadTooLarge" 115 | 116 | // 429 The provider token is being updated too often. 117 | ReasonTooManyProviderTokenUpdates = "TooManyProviderTokenUpdates" 118 | 119 | // 429 Too many requests were made consecutively to the same device token. 120 | ReasonTooManyRequests = "TooManyRequests" 121 | 122 | // 500 An internal server error occurred. 123 | ReasonInternalServerError = "InternalServerError" 124 | 125 | // 503 The service is unavailable. 126 | ReasonServiceUnavailable = "ServiceUnavailable" 127 | 128 | // 503 The server is shutting down. 129 | ReasonShutdown = "Shutdown" 130 | ) 131 | 132 | // Response represents a result from the APN service indicating whether a 133 | // notification was accepted or rejected and (if applicable) any accompanying 134 | // data. 135 | type Response struct { 136 | 137 | // The ApnsID value from the Notification. If you didn't set an ApnsID in the 138 | // Notification, this will be a new unique UUID which has been created by apns2. 139 | ApnsID string 140 | 141 | // StatusCode is the HTTP status code returned by apns2. 142 | // A 200 value indicates that the notification was successfully sent. 143 | // For a list of other possible status codes, see table 6-4 in the Apple Local 144 | // and Remote Notification Programming Guide. 145 | StatusCode int 146 | 147 | // RejectionReason is the APNs error string indicating the reason 148 | // for the push failure (if any). The error code is specified as a string. 149 | // For a list of possible values, see the Reason constants above. 
150 | // If the notification was accepted, this value will be "". 151 | RejectionReason string `json:"reason"` 152 | 153 | // If the value of StatusCode is 410, this is the last time at which APNs 154 | // confirmed that the device token was no longer valid for the topic. 155 | // TODO Make Response.UnsubscribedAt a time.Time and handle unmarshalling better 156 | UnsubscribedAt Time `json:"timestamp"` 157 | } 158 | 159 | // IsAccepted returns whether or not the notification was accepted by APN service. 160 | // This is the same as checking if the StatusCode == 200. 161 | func (c *Response) IsAccepted() bool { 162 | return c.StatusCode == StatusAcccepted 163 | } 164 | 165 | // Time represents a device uninstall time 166 | type Time struct { 167 | time.Time 168 | } 169 | 170 | // UnmarshalJSON converts an epoch date in milliseconds into a Time struct. 171 | func (t *Time) UnmarshalJSON(b []byte) error { 172 | ts, err := strconv.ParseInt(string(b), 10, 64) 173 | if err != nil { 174 | return err 175 | } 176 | t.Time = time.Unix(ts/1000, 1000000*(ts%1000)) 177 | return nil 178 | } 179 | -------------------------------------------------------------------------------- /apns2/notification.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "sync/atomic" 11 | "time" 12 | ) 13 | 14 | // Priority is the priority of the notification. 15 | // Allowable values are defined by APNs and are listed below. 16 | type Priority int 17 | 18 | const ( 19 | // PriorityLow instructs APNs to send the push message at a time 20 | // that takes into account power considerations for the device. 21 | // Notifications with this priority might be grouped and delivered 22 | //in bursts. They are throttled, and in some cases are not delivered. 
23 | PriorityLow Priority = 5 24 | 25 | // PriorityHigh instructs APNs to send the push message immediately. 26 | // Notifications with this priority must trigger an alert, sound, 27 | // or badge on the target device. 28 | // It is an error to use this priority for a push notification 29 | // that contains only the content-available key. 30 | PriorityHigh = 10 31 | ) 32 | 33 | // Notification holds the data that is to be pushed to the recipient 34 | // as well as any routing information required to deliver it. 35 | // Routing headers and the notification payload are meant to remain immutable 36 | // and are intended to be shared accross multiple notifications if needed. 37 | // This is usefull when the same message needs to be deliverd to 38 | // many recipients. 39 | type Notification struct { 40 | // ApnsID is a canonical UUID that identifies the notification. 41 | // If there is an error sending the notification, APNs uses this value 42 | // to identify the notification in its response. 43 | // The canonical form is 32 lowercase hexadecimal digits, 44 | // displayed in five groups separated by hyphens in the form 8-4-4-4-12. 45 | // An example ApnsID is as follows: 123e4567-e89b-12d3-a456-42665544000 46 | // If omitted, a new ApnsID is created by APNs and returned in the response. 47 | ApnsID string 48 | 49 | // Recipient is the device token of the notification target. 50 | Recipient string 51 | 52 | // Header is a reference to a structure containing routing information. 53 | Header *Header 54 | 55 | // Payload is the notification data that is passed to the recipient. 56 | // Payload can be of any type that can be marshalled into a valid 57 | // JSON dictionary, a string representation of such dictionaty or 58 | // a slice of bytes of JSON encoding of such dictionary. 59 | Payload interface{} 60 | } 61 | 62 | // Header is a container for the routing information. 
63 | // How a header is constructed and utilized is not constrained, 64 | // but the intent is to never modify it once created and assigned. 65 | // The same header can then be used for routing of any number of notifications. 66 | type Header struct { 67 | // The topic of the remote notification, which is typically the bundle ID 68 | // for your app. The certificate you create in your developer account 69 | // must include the capability for this topic. 70 | // If your certificate includes multiple topics, you must specify a value for this header. 71 | // If you omit this request header and your APNs certificate does not specify 72 | // multiple topics, the APNs server uses the certificate’s Subject as the default topic. 73 | // If you are using a provider token instead of a certificate, you must specify a value 74 | // for this request header. The topic you provide should be provisioned for the your team 75 | // named in your developer account. 76 | Topic string 77 | 78 | // CollapseID, if set, allows grouping of multiple notifications by apns2. 79 | // Multiple notifications with the same collapse identifier are displayed 80 | // to the user as a single notification. 81 | // The value of this field must not exceed 64 bytes. 82 | CollapseID string 83 | 84 | // Priority is the priority of the notification. 85 | // Specify ether apns2.PriorityHigh (10) or apns2.PriorityLow (5) 86 | // If you don't set this, the APNs server will set the priority to 10. 87 | Priority Priority 88 | 89 | // Expiration identifies the date when the notification is no longer valid 90 | // and can be discarded. 91 | // If this value is nonzero, APNs stores the notification 92 | // and tries to deliver it at least once, repeating the attempt as needed 93 | // if it is unable to deliver the notification the first time. 94 | // If the value is 0, APNs treats the notification as if it expires immediately 95 | // and does not store the notification or attempt to redeliver it. 
96 | Expiration time.Time 97 | 98 | httpHeaders atomic.Value 99 | } 100 | 101 | func (n *Notification) write(r *http.Request) error { 102 | r.Header.Set("Content-Type", "application/json; charset=utf-8") 103 | if n.ApnsID != "" { 104 | r.Header.Set("apns-id", n.ApnsID) 105 | } 106 | n.Header.write(r) 107 | body, err := n.newPayloadReader() 108 | if err != nil { 109 | return err 110 | } 111 | r.Body = body 112 | r.ContentLength = body.Len() 113 | // TODO Move to a separate func congitional on go1.8 114 | // r.GetBody = func() (io.ReadCloser, error) { 115 | // return body.ResetClone(), nil 116 | // } 117 | return nil 118 | } 119 | 120 | func (n *Notification) newPayloadReader() (*sliceReader, error) { 121 | var buf []byte 122 | switch n.Payload.(type) { 123 | case []byte: 124 | buf = n.Payload.([]byte) 125 | case string: 126 | buf = []byte(n.Payload.(string)) 127 | default: 128 | var err error 129 | buf, err = json.Marshal(n.Payload) 130 | if err != nil { 131 | return nil, err 132 | } 133 | } 134 | return newSliceReader(buf), nil 135 | } 136 | 137 | func (h *Header) getHTTPHeaders() [][2]string { 138 | res := h.httpHeaders.Load() 139 | if res != nil { 140 | return res.([][2]string) 141 | } 142 | // We could protect this with a Mutex, but for improved throughput 143 | // it is probably better to avoid resource contention here and just 144 | // duplicate the work in case we have concurrent calls. 
145 | hdrs := make([][2]string, 0, 4) 146 | if h.Topic != "" { 147 | hdrs = append(hdrs, [...]string{"apns-topic", h.Topic}) 148 | } 149 | if h.CollapseID != "" { 150 | hdrs = append(hdrs, [...]string{"apns-collapse-id", h.CollapseID}) 151 | } 152 | if h.Priority > 0 { 153 | hdrs = append(hdrs, [...]string{"apns-priority", fmt.Sprintf("%v", h.Priority)}) 154 | } 155 | if !h.Expiration.IsZero() { 156 | hdrs = append(hdrs, [...]string{"apns-expiration", fmt.Sprintf("%v", h.Expiration.Unix())}) 157 | } 158 | h.httpHeaders.Store(hdrs) 159 | return hdrs 160 | } 161 | 162 | func (h *Header) write(r *http.Request) error { 163 | for _, h := range h.getHTTPHeaders() { 164 | r.Header.Set(h[0], h[1]) 165 | } 166 | return nil 167 | } 168 | 169 | // sliceReader is ReaderCloser that doesn't take ownership of the slice. 170 | type sliceReader struct { 171 | buf []byte 172 | off int 173 | } 174 | 175 | func newSliceReader(buf []byte) *sliceReader { 176 | return &sliceReader{buf: buf} 177 | } 178 | 179 | // Read reads the next len(p) bytes from the buffer or until the buffer 180 | // is drained. The return value n is the number of bytes read. If the 181 | // buffer has no data to return, err is io.EOF (unless len(p) is zero); 182 | // otherwise it is nil. 183 | func (r *sliceReader) Read(p []byte) (n int, err error) { 184 | if r.off >= len(r.buf) { 185 | if len(p) == 0 { 186 | return 187 | } 188 | return 0, io.EOF 189 | } 190 | n = copy(p, r.buf[r.off:]) 191 | r.off += n 192 | return 193 | } 194 | 195 | func (r *sliceReader) Close() error { 196 | return nil 197 | } 198 | 199 | func (r *sliceReader) Len() int64 { 200 | return int64(len(r.buf)) 201 | } 202 | 203 | func (r *sliceReader) ResetClone() *sliceReader { 204 | return &sliceReader{buf: r.buf} 205 | } 206 | -------------------------------------------------------------------------------- /apns2/http.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. 
All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "crypto/tls" 7 | "crypto/x509" 8 | "errors" 9 | "net" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "github.com/baobabus/go-apns/http2x" 17 | "golang.org/x/net/http2" 18 | "golang.org/x/net/idna" 19 | ) 20 | 21 | // HTTP layer-related errors. 22 | var ( 23 | ErrHTTPClientClosed = errors.New("HTTPClient: attempt to close already closed client") 24 | ErrNoConnectionPool = errors.New("HTTPClient: no connection pool") 25 | ) 26 | 27 | // HTTPClient wraps http.Client and augments it with HTTP/2 stream 28 | // reservation facility. 29 | // 30 | // Due to current limitations of Go http2 client implementation, only a single 31 | // underlying http2.ClientConn is intended to be supported by the client. This 32 | // means that correct communication behavior is limited to a single HTTP/2 33 | // server endpoint. Note, however, that no attempt is made to restrict the way 34 | // in which the client is used, including handling of any encountered redirect 35 | // responses. 36 | type HTTPClient struct { 37 | http.Client 38 | 39 | addr string 40 | precise bool 41 | pollInt time.Duration 42 | cfgCap uint32 43 | 44 | mu sync.Mutex 45 | cond *sync.Cond 46 | connPool http2.ClientConnPool 47 | actCap uint32 48 | effCap uint32 49 | cnt uint32 50 | closed bool 51 | 52 | tkr *time.Ticker 53 | ctl chan struct{} 54 | 55 | initOnce sync.Once 56 | } 57 | 58 | // NewHTTPClient creates a new HTTPClient for handling HTTP requests 59 | // to a single specified gateway. 60 | // TLS client certificate cCert and custom root certificate authority rootCA 61 | // certificate are optional and can be nil. 
62 | func NewHTTPClient(gateway string, commsCfg CommsCfg, cCert *tls.Certificate, rootCA *tls.Certificate) (*HTTPClient, error) { 63 | t := &http2.Transport{ 64 | DialTLS: makeDialer(commsCfg), 65 | DisableCompression: true, // As per Apple spec 66 | } 67 | tlsConfig := t.TLSClientConfig 68 | if cCert != nil { 69 | if tlsConfig == nil { 70 | tlsConfig = &tls.Config{} 71 | } 72 | tlsConfig.Certificates = []tls.Certificate{*cCert} 73 | if len(cCert.Certificate) > 0 { 74 | tlsConfig.BuildNameToCertificate() 75 | } 76 | } 77 | if rootCA != nil && len(rootCA.Certificate[0]) > 0 { 78 | if tlsConfig == nil { 79 | tlsConfig = &tls.Config{} 80 | } 81 | rCert, err := x509.ParseCertificate(rootCA.Certificate[0]) 82 | if err != nil { 83 | return nil, err 84 | } 85 | certpool := x509.NewCertPool() 86 | certpool.AddCert(rCert) 87 | tlsConfig.RootCAs = certpool 88 | } 89 | t.TLSClientConfig = tlsConfig 90 | url, _ := url.ParseRequestURI(gateway) 91 | res := &HTTPClient{ 92 | Client: http.Client{ 93 | Transport: t, 94 | Timeout: commsCfg.RequestTimeout, 95 | }, 96 | addr: authorityAddr(url.Scheme, url.Host), 97 | precise: false, 98 | pollInt: 0, 99 | cfgCap: 1, 100 | } 101 | return res, nil 102 | } 103 | 104 | func (c *HTTPClient) init() { 105 | c.cond = sync.NewCond(&c.mu) 106 | c.effCap = 1 // assume just 1 until connection is open 107 | if c.precise || c.pollInt > 0 { 108 | c.connPool, _ = http2x.GetClientConnPool(c.Client.Transport) 109 | c.refreshCap() 110 | } 111 | if c.connPool != nil && c.pollInt > 0 { 112 | c.tkr = time.NewTicker(c.pollInt) 113 | c.ctl = make(chan struct{}) 114 | go func() { 115 | select { 116 | case <-c.tkr.C: 117 | c.refreshCap() 118 | case <-c.ctl: 119 | return 120 | } 121 | }() 122 | } 123 | } 124 | 125 | // getClientConn returns http2.ClientConn from HTTPClient's connection pool. 
126 | func (c *HTTPClient) getClientConn() (*http2.ClientConn, error) { 127 | c.initOnce.Do(c.init) 128 | if c.connPool == nil { 129 | // http2 incursion is disabled, so this it not an error 130 | return nil, nil 131 | } 132 | return http2x.GetClientConn(c.connPool, c.addr) 133 | } 134 | 135 | // ReservedStream returns a reserved HTTP2Stream in the client's 136 | // HTTP/2 connection, or a non-nil error. 137 | func (c *HTTPClient) ReservedStream(cancel func(<-chan struct{}) error) (*HTTP2Stream, error) { 138 | c.initOnce.Do(c.init) 139 | c.mu.Lock() 140 | defer c.mu.Unlock() 141 | if c.precise { 142 | c.refreshCapLocked() 143 | } 144 | var cerr error 145 | for cnlLaunched := false; c.effCap > 0 && c.cnt >= c.effCap && cerr == nil; { 146 | if !cnlLaunched && cancel != nil { 147 | done := make(chan struct{}) 148 | defer close(done) 149 | go func() { 150 | if err := cancel(done); err != nil { 151 | // Must guard access to cerr. 152 | // Atomic store and load could be more efficient. 153 | c.mu.Lock() 154 | cerr = err 155 | c.mu.Unlock() 156 | c.cond.Broadcast() 157 | } 158 | }() 159 | cnlLaunched = true 160 | } 161 | c.cond.Wait() 162 | } 163 | if cerr != nil { 164 | return nil, cerr 165 | } 166 | // This may need to be its own error 167 | // if c.effCap == 0 { 168 | // return nil, ErrZeroCapacity 169 | // } 170 | c.cnt++ 171 | // TODO Consider using sync.Pool for HTTP2Stream instances. 172 | return &HTTP2Stream{client: c}, nil 173 | } 174 | 175 | func (c *HTTPClient) Close() error { 176 | c.initOnce.Do(c.init) 177 | if c.closed { 178 | return ErrHTTPClientClosed 179 | } 180 | c.mu.Lock() 181 | defer c.mu.Unlock() 182 | if c.tkr != nil { 183 | c.tkr.Stop() 184 | close(c.ctl) 185 | } 186 | // Client and everything underneath should be GC'd soon 187 | // and that should take care of closing any open connections. 188 | // Not sure if present http2 state of client can be trusted, though. 
189 | if c.Client.Transport != nil { 190 | if t2, ok := c.Client.Transport.(*http2.Transport); ok { 191 | // All streams must be closed for the below to have any effect. 192 | t2.CloseIdleConnections() 193 | } 194 | } 195 | c.closed = true 196 | return nil 197 | } 198 | 199 | func (c *HTTPClient) release() { 200 | c.mu.Lock() 201 | defer c.mu.Unlock() 202 | if c.cnt > 0 { 203 | c.cnt-- 204 | if c.cnt < c.effCap { 205 | c.cond.Broadcast() 206 | } 207 | } 208 | } 209 | 210 | func (c *HTTPClient) refreshCap() { 211 | c.mu.Lock() 212 | defer c.mu.Unlock() 213 | c.refreshCapLocked() 214 | } 215 | 216 | func (c *HTTPClient) refreshCapLocked() { 217 | if c.connPool == nil { 218 | return 219 | } 220 | conn, err := http2x.GetClientConn(c.connPool, c.addr) 221 | if err != nil { 222 | return 223 | } 224 | c.actCap = http2x.GetMaxConcurrentStreams(conn) 225 | logTrace(0, "HTTClient", "Max streams = %d\n", c.actCap) 226 | v := c.actCap 227 | if v > c.cfgCap { 228 | v = c.cfgCap 229 | } 230 | notif := c.effCap < v || (c.effCap > 0 && v == 0) 231 | c.effCap = v 232 | if notif { 233 | c.cond.Broadcast() 234 | } 235 | } 236 | 237 | // HTTP2Stream is a token indicating a stream reservation in one 238 | // of the HTTPClient's HTTP/2 connections. 239 | type HTTP2Stream struct { 240 | client *HTTPClient 241 | } 242 | 243 | // Close releases an HTTP/2 stream reservation. 244 | func (s *HTTP2Stream) Close() { 245 | s.client.release() 246 | } 247 | 248 | // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) 249 | // and returns a host:port. The port 443 is added if needed. 
250 | func authorityAddr(scheme string, authority string) string { 251 | host, port, err := net.SplitHostPort(authority) 252 | if err != nil { // authority didn't have a port 253 | port = "443" 254 | if scheme == "http" { 255 | port = "80" 256 | } 257 | host = authority 258 | } 259 | if a, err := idna.ToASCII(host); err == nil { 260 | host = a 261 | } 262 | // IPv6 address literal, without a port: 263 | if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { 264 | return host + ":" + port 265 | } 266 | return net.JoinHostPort(host, port) 267 | } 268 | -------------------------------------------------------------------------------- /cryptox/crypto.go: -------------------------------------------------------------------------------- 1 | // The MIT License (MIT) 2 | // 3 | // Copyright (c) 2016 Adam Jones 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 
22 | // 23 | // Modifications copyright 2017 Aleksey Blinov. All rights reserved. 24 | 25 | package cryptox 26 | 27 | import ( 28 | "crypto" 29 | "crypto/ecdsa" 30 | "crypto/tls" 31 | "crypto/x509" 32 | "encoding/pem" 33 | "errors" 34 | "io/ioutil" 35 | 36 | "golang.org/x/crypto/pkcs12" 37 | ) 38 | 39 | const ( 40 | PEM_X509 string = "CERTIFICATE" 41 | PEM_RSA = "RSA PRIVATE KEY" 42 | PEM_PKCS8 = "ENCRYPTED PRIVATE KEY" 43 | PEM_PKCS8INF = "PRIVATE KEY" 44 | ) 45 | 46 | var ( 47 | ErrPKCS8NotPem = errors.New("PKCS8PrivateKey: invalid .p8 PEM file") 48 | ErrPKCS8NotECDSA = errors.New("PKCS8PrivateKey: key must be of type ecdsa.PrivateKey") 49 | ErrPEMMissingPrivateKey = errors.New("PEM: private key not found") 50 | ErrPEMMissingCertificate = errors.New("PEM: certificate not found") 51 | ) 52 | 53 | // PKCS8PrivateKeyFromFile loads a .p8 certificate from a local file and returns a 54 | // *ecdsa.PrivateKey. 55 | func PKCS8PrivateKeyFromFile(filename string) (*ecdsa.PrivateKey, error) { 56 | bytes, err := ioutil.ReadFile(filename) 57 | if err != nil { 58 | return nil, err 59 | } 60 | return PKCS8PrivateKeyFromBytes(bytes) 61 | } 62 | 63 | // PKCS8PrivateKeyFromBytes decodes a .p8 certificate from an in memory byte slice and 64 | // returns an *ecdsa.PrivateKey. 65 | func PKCS8PrivateKeyFromBytes(bytes []byte) (*ecdsa.PrivateKey, error) { 66 | block, _ := pem.Decode(bytes) 67 | if block == nil { 68 | return nil, ErrPKCS8NotPem 69 | } 70 | key, err := x509.ParsePKCS8PrivateKey(block.Bytes) 71 | if err != nil { 72 | return nil, err 73 | } 74 | switch pk := key.(type) { 75 | case *ecdsa.PrivateKey: 76 | return pk, nil 77 | default: 78 | return nil, ErrPKCS8NotECDSA 79 | } 80 | } 81 | 82 | // ClientCertFromP12File loads a PKCS#12 certificate from a local file and returns a 83 | // tls.Certificate. 84 | // 85 | // Use "" as the password argument if the PKCS#12 certificate is not password 86 | // protected. 
87 | func ClientCertFromP12File(filename string, password string) (tls.Certificate, error) { 88 | p12bytes, err := ioutil.ReadFile(filename) 89 | if err != nil { 90 | return tls.Certificate{}, err 91 | } 92 | return ClientCertFromP12Bytes(p12bytes, password) 93 | } 94 | 95 | // ClientCertFromP12Bytes loads a PKCS#12 certificate from an in memory byte array and 96 | // returns a tls.Certificate. 97 | // 98 | // Use "" as the password argument if the PKCS#12 certificate is not password 99 | // protected. 100 | func ClientCertFromP12Bytes(bytes []byte, password string) (tls.Certificate, error) { 101 | key, cert, err := pkcs12.Decode(bytes, password) 102 | if err != nil { 103 | return tls.Certificate{}, err 104 | } 105 | return tls.Certificate{ 106 | Certificate: [][]byte{cert.Raw}, 107 | PrivateKey: key, 108 | Leaf: cert, 109 | }, nil 110 | } 111 | 112 | // ClientCertFromPemFile loads a PEM certificate from a local file and returns a 113 | // tls.Certificate. This function is similar to the crypto/tls LoadX509KeyPair 114 | // function, however it supports PEM files with the cert and key combined 115 | // in the same file, as well as password protected key files which are both 116 | // common with APNs certificates. 117 | // 118 | // Use "" as the password argument if the PEM certificate is not password 119 | // protected. 120 | func ClientCertFromPemFile(filename string, password string) (tls.Certificate, error) { 121 | bytes, err := ioutil.ReadFile(filename) 122 | if err != nil { 123 | return tls.Certificate{}, err 124 | } 125 | return ClientCertFromPemBytes(bytes, password) 126 | } 127 | 128 | // ClientCertFromPemBytes loads a PEM certificate from an in memory byte array and 129 | // returns a tls.Certificate. This function is similar to the crypto/tls 130 | // X509KeyPair function, however it supports PEM files with the cert and 131 | // key combined, as well as password protected keys which are both common with 132 | // APNs certificates. 
133 | // 134 | // Use "" as the password argument if the PEM certificate is not password 135 | // protected. 136 | func ClientCertFromPemBytes(bytes []byte, password string) (tls.Certificate, error) { 137 | var cert tls.Certificate 138 | var block *pem.Block 139 | for { 140 | block, bytes = pem.Decode(bytes) 141 | if block == nil { 142 | break 143 | } 144 | var isRSA bool 145 | switch block.Type { 146 | case PEM_X509: 147 | cert.Certificate = append(cert.Certificate, block.Bytes) 148 | case PEM_RSA: 149 | isRSA = true 150 | fallthrough 151 | case PEM_PKCS8, PEM_PKCS8INF: 152 | key, err := decryptPrivateKey(block, password, isRSA) 153 | if err != nil { 154 | return tls.Certificate{}, err 155 | } 156 | cert.PrivateKey = key 157 | } 158 | } 159 | if len(cert.Certificate) == 0 { 160 | return tls.Certificate{}, ErrPEMMissingCertificate 161 | } 162 | if cert.PrivateKey == nil { 163 | return tls.Certificate{}, ErrPEMMissingPrivateKey 164 | } 165 | if c, err := x509.ParseCertificate(cert.Certificate[0]); err == nil { 166 | cert.Leaf = c 167 | } 168 | return cert, nil 169 | } 170 | 171 | // RootCAFromPemFile loads a PEM certificate from a local file and returns a 172 | // tls.Certificate. 173 | func RootCAFromPemFile(filename string) (tls.Certificate, error) { 174 | bytes, err := ioutil.ReadFile(filename) 175 | if err != nil { 176 | return tls.Certificate{}, err 177 | } 178 | return RootCAFromPemBytes(bytes) 179 | } 180 | 181 | // RootCAFromPemBytes loads a PEM certificate from an in memory byte array and 182 | // returns a tls.Certificate. 
183 | func RootCAFromPemBytes(bytes []byte) (tls.Certificate, error) { 184 | var cert tls.Certificate 185 | var block *pem.Block 186 | for { 187 | block, bytes = pem.Decode(bytes) 188 | if block == nil { 189 | break 190 | } 191 | if block.Type == PEM_X509 { 192 | cert.Certificate = append(cert.Certificate, block.Bytes) 193 | } 194 | } 195 | if len(cert.Certificate) == 0 { 196 | return tls.Certificate{}, ErrPEMMissingCertificate 197 | } 198 | // This should not be needed: 199 | // if c, err := x509.ParseCertificate(cert.Certificate[0]); err == nil { 200 | // cert.Leaf = c 201 | // } 202 | return cert, nil 203 | } 204 | 205 | func decryptPrivateKey(block *pem.Block, password string, isRSA bool) (crypto.PrivateKey, error) { 206 | bytes := block.Bytes 207 | if x509.IsEncryptedPEMBlock(block) { 208 | var err error 209 | bytes, err = x509.DecryptPEMBlock(block, []byte(password)) 210 | if err != nil { 211 | return nil, err 212 | } 213 | } 214 | return parsePrivateKey(bytes, isRSA) 215 | } 216 | 217 | func parsePrivateKey(bytes []byte, isRSA bool) (res crypto.PrivateKey, err error) { 218 | // Thanks to @jameshfisher for brining up PKCS8 case 219 | if isRSA { 220 | res, err = x509.ParsePKCS1PrivateKey(bytes) 221 | } else { 222 | res, err = x509.ParsePKCS8PrivateKey(bytes) 223 | } 224 | return res, err 225 | } 226 | -------------------------------------------------------------------------------- /apns2/streamer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "strings" 10 | "sync" 11 | "time" 12 | 13 | "github.com/baobabus/go-apns/syncx" 14 | ) 15 | 16 | // Each streamer "owns" a single HTTPClient on top of an HTTP/2 transport. 17 | // It streams requests to and handles responses from APN servers while 18 | // coordinating HTTP/2 stream utilization. 
19 | type streamer struct { 20 | id string 21 | c *Client 22 | gov *governor 23 | in <-chan *Request 24 | out chan<- *Result 25 | ctl chan struct{} 26 | done chan<- *streamer 27 | 28 | warmStart bool 29 | 30 | startOnce sync.Once 31 | startErr error 32 | 33 | httpClient *HTTPClient 34 | 35 | // counter for waits on outbound channel 36 | waitCtr syncx.TickTockCounter 37 | // cumulative request sizes in bytes 38 | sizeCtr syncx.Counter 39 | 40 | // wait group for spawned HTTP/2 roundrips 41 | wg sync.WaitGroup 42 | 43 | didQuit bool 44 | inClosed bool 45 | } 46 | 47 | func (s *streamer) start(wg *sync.WaitGroup) error { 48 | s.startOnce.Do(func() { 49 | logInfo(s.id, "Starting.") 50 | s.httpClient, s.startErr = NewHTTPClient(s.c.Gateway, s.c.CommsCfg, s.c.Certificate, s.c.RootCA) 51 | if s.startErr != nil { 52 | return 53 | } 54 | var pollInt time.Duration 55 | if s.gov.cfg.AllowHTTP2Incursion && !s.gov.cfg.UsePreciseHTTP2Metrics { 56 | pollInt = s.gov.cfg.HTTP2MetricsRefreshPeriod 57 | } 58 | s.httpClient.precise = s.gov.cfg.AllowHTTP2Incursion && s.gov.cfg.UsePreciseHTTP2Metrics 59 | s.httpClient.pollInt = pollInt 60 | s.httpClient.cfgCap = s.c.CommsCfg.MaxConcurrentStreams 61 | if s.warmStart { 62 | // This can also be accomplished by sending a malformed http.Request. 63 | // No reflection is required, but it's still a kludge and results 64 | // in error being logged. 
65 | _, s.startErr = s.httpClient.getClientConn() 66 | } 67 | if s.startErr != nil { 68 | return 69 | } 70 | if wg != nil { 71 | wg.Add(1) 72 | } 73 | go s.run(wg) 74 | }) 75 | return s.startErr 76 | } 77 | 78 | func (s *streamer) run(wg *sync.WaitGroup) { 79 | logInfo(s.id, "Running.") 80 | for done := false; !done; { 81 | select { 82 | case req, ok := <-s.in: 83 | if !ok { 84 | // soft shutdown - wait for pending roundtrips to complete 85 | logInfo(s.id, "Stopping.") 86 | // TODO Switch from WaitGroup to channel signal 87 | s.wg.Wait() 88 | done = true 89 | s.inClosed = true 90 | break 91 | } 92 | s.exec(req) 93 | case _, ok := <-s.ctl: 94 | if ok { 95 | // unusable connection 96 | s.didQuit = true 97 | logInfo(s.id, "Quitting.") 98 | } else { 99 | // hard shutdown - do not wait for pending roundtrips to complete 100 | logInfo(s.id, "Terminating.") 101 | } 102 | // TODO Cancel pending roundtrips' contexts. 103 | done = true 104 | } 105 | } 106 | // This will only have effect if all roundtrips are finished. 
107 | s.httpClient.Close() 108 | // read from ctl prevents blocking on done if the governor 109 | // was commanded to terminate in the meantime 110 | select { 111 | case s.done <- s: 112 | case <-s.ctl: 113 | } 114 | if wg != nil { 115 | wg.Done() 116 | } 117 | logInfo(s.id, "Stopped.") 118 | } 119 | 120 | func (s *streamer) exec(req *Request) { 121 | logTrace(0, s.id, "Serving %v.", req) 122 | if s.c.Certificate == nil && (req.Signer == NoSigner || !s.c.HasSigner() && !req.HasSigner()) { 123 | s.callBack(req, nil, ErrMissingAuth) 124 | return 125 | } 126 | hasCtx := req.Context != NoContext 127 | canceled := false 128 | // TODO Move the below to HTTP/2 stream wait code 129 | if hasCtx { 130 | select { 131 | case <-req.Context.Done(): 132 | canceled = true 133 | default: 134 | } 135 | } 136 | if canceled { 137 | s.callBack(req, nil, ErrCanceled) 138 | return 139 | } 140 | var cancel func(done <-chan struct{}) error 141 | if hasCtx { 142 | // Waits for the user to cancel a request's context. 143 | cancel = func(done <-chan struct{}) error { 144 | ctx := req.Context 145 | if ctx.Done() == nil { 146 | return nil 147 | } 148 | select { 149 | case <-ctx.Done(): 150 | return ctx.Err() 151 | case <-done: 152 | return nil 153 | } 154 | } 155 | } 156 | // 1. Acquire HTTP/2 stream 157 | // This can block and is the primary source of back pressure. 158 | st, err := s.httpClient.ReservedStream(cancel) 159 | if err != nil { 160 | s.callBack(req, nil, err) 161 | return 162 | } 163 | // 2. go submit() 164 | s.wg.Add(1) 165 | go func() { 166 | defer st.Close() 167 | defer s.wg.Done() 168 | resp, err := s.submit(req) 169 | if err != nil && uint32(req.attemptCnt) < s.gov.cfg.MaxRetries && s.isRetriable(resp, err) { 170 | req.attemptCnt++ 171 | // Retry is serviced in a timely manner, so no need to worry about blocking. 172 | // There's just a potential issue with retry forwarder stopping reads 173 | // due to a signal on its ctl channel with streamers still running. 
174 | // Forwarder's ctl channel shoulnd't be shared with governor. 175 | s.gov.retry <- req 176 | return 177 | } 178 | s.callBack(req, resp, err) 179 | if !s.isConnUsable(resp, err) { 180 | // Each worker is given its own ctl channel, but we cannot close it here. 181 | // Writing to it accomplishes the same thing. Just do not block. 182 | var v struct{} 183 | select { 184 | case s.ctl <- v: 185 | default: 186 | } 187 | } 188 | }() 189 | } 190 | 191 | // Submits request to APN service and returns APN response or an error. 192 | func (s *streamer) submit(req *Request) (*Response, error) { 193 | url := s.c.Gateway + RequestRoot + req.Notification.Recipient 194 | httpReq, err := http.NewRequest("POST", url, nil) 195 | if err != nil { 196 | return nil, &RequestError{err} 197 | } 198 | if err := req.Notification.write(httpReq); err != nil { 199 | return nil, &RequestError{err} 200 | } 201 | signer := req.Signer 202 | if signer == nil { 203 | signer = s.c.Signer 204 | } 205 | if signer != nil { 206 | if err := signer.SignRequest(httpReq); err != nil { 207 | return nil, &RequestError{err} 208 | } 209 | } 210 | if req.Context != NoContext { 211 | httpReq = httpReq.WithContext(req.Context) 212 | } 213 | logTrace(2, s.id, "http.Request: %v\n", httpReq) 214 | httpResp, err := s.httpClient.Do(httpReq) 215 | if err != nil { 216 | return nil, err 217 | } 218 | s.sizeCtr.Add(uint64(estimatedRequestWireSize(httpReq))) 219 | logTrace(2, s.id, "http.Response: %v\n", httpResp) 220 | defer httpResp.Body.Close() 221 | res := &Response{ 222 | StatusCode: httpResp.StatusCode, 223 | ApnsID: httpResp.Header.Get("apns-id"), 224 | } 225 | decoder := json.NewDecoder(httpResp.Body) 226 | if err := decoder.Decode(&res); err != nil && err != io.EOF { 227 | return &Response{}, &RequestError{err} 228 | } 229 | return res, nil 230 | } 231 | 232 | func (s *streamer) callBack(req *Request, resp *Response, err error) { 233 | res := &Result{ 234 | Notification: req.Notification, 235 | Signer: 
req.Signer, 236 | Context: req.Context, 237 | Response: resp, 238 | Err: err, 239 | } 240 | if req.Callback == NoCallback { 241 | return 242 | } 243 | tgt := s.out 244 | if req.Callback != nil { 245 | tgt = req.Callback 246 | } 247 | if tgt != nil && tgt != NoCallback { 248 | isBlocked := false 249 | select { 250 | case tgt <- res: 251 | default: 252 | isBlocked = true 253 | } 254 | if !isBlocked { 255 | return 256 | } 257 | s.waitCtr.Tick() 258 | select { 259 | case tgt <- res: 260 | case <-s.ctl: 261 | } 262 | s.waitCtr.Tock() 263 | } 264 | } 265 | 266 | func (s *streamer) isRetriable(resp *Response, err error) bool { 267 | if resp == nil && err != nil { 268 | return false 269 | } 270 | if s.gov.cfg.RetryEval != nil { 271 | return s.gov.cfg.RetryEval(resp, err) 272 | } 273 | return false 274 | } 275 | 276 | func (s *streamer) isConnUsable(resp *Response, err error) bool { 277 | if resp == nil && err != nil { 278 | switch err.(type) { 279 | case *RequestError: 280 | // Request-level error 281 | return true 282 | default: 283 | // Error from http.Client.Do() 284 | // "Invalid method" is our fault and not recoverable. 285 | if strings.HasPrefix(err.Error(), "net/http: ") { 286 | return false 287 | } 288 | // TODO Consider other possibilities 289 | return false 290 | } 291 | } 292 | if resp != nil { 293 | switch resp.StatusCode { 294 | case http.StatusServiceUnavailable, 295 | http.StatusMethodNotAllowed: 296 | return true 297 | case http.StatusBadRequest: 298 | return resp.RejectionReason != ReasonIdleTimeout 299 | case http.StatusForbidden: 300 | return resp.RejectionReason != ReasonBadCertificate && resp.RejectionReason != ReasonBadCertificateEnvironment 301 | } 302 | } 303 | return true 304 | } 305 | 306 | var baseReqWireSizeSize = uint64(5 + len(RequestRoot)) 307 | 308 | // Only an estimate and only based on the fields we use. I.e. cookie sizes 309 | // are not included. 
310 | func estimatedRequestWireSize(req *http.Request) (res int) { 311 | res = len(req.Host) + // this needs to be counted in addition 312 | len(req.URL.RawPath) + // not .EscapedPath() as no escaping is needed in our case 313 | int(req.ContentLength) + // We know we set it 314 | 14 + // for "POST " and " HTTP/2.0" 315 | estimatedHeaderWireSize(req.Header) 316 | return res 317 | } 318 | 319 | // Only an estimate and under the assupmtion that no duplicates are present 320 | func estimatedHeaderWireSize(hs http.Header) (res int) { 321 | for h, vs := range hs { 322 | res += len(h) + 4 // account for ": " and "\r\n" 323 | for _, v := range vs { 324 | res += len(v) 325 | break // no duplicates allowed 326 | } 327 | } 328 | return res 329 | } 330 | -------------------------------------------------------------------------------- /apns2/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "context" 7 | "crypto/tls" 8 | "errors" 9 | "sync" 10 | 11 | "github.com/baobabus/go-apns/syncx" 12 | ) 13 | 14 | // Gateway holds APN service's Development & Production urls. 15 | // These use default HTTPS port 443. According to Apple you can 16 | // alternatively use port 2197 if needed. 17 | var Gateway = struct { 18 | Development string 19 | Production string 20 | }{ 21 | Development: "https://api.development.push.apple.com", 22 | Production: "https://api.push.apple.com", 23 | } 24 | 25 | // APNS default root URL path. 
26 | const RequestRoot = "/3/device/" 27 | 28 | var ( 29 | ErrMissingAuth = errors.New("apns2: authentication is not possible with no client certificate and no signer") 30 | ErrClientNotRunning = errors.New("apns2: client processing pipeline not running") 31 | ErrClientAlreadyStarted = errors.New("apns2: client processing pipeline already started") 32 | ErrClientAlreadyClosed = errors.New("apns2: client processing pipeline already closed") 33 | ErrPushInterrupted = errors.New("apns2: push request interrupted") 34 | ErrCanceled = errors.New("apns2: push request canceled") 35 | ) 36 | 37 | // NoSigner can be used where a RequestSigner is required when a push request 38 | // need not be signed. 39 | var NoSigner RequestSigner 40 | 41 | // DefaultSigner can be used instead of nil value where a RequestSigner 42 | // is required to indicate that a push request should be signed with client's 43 | // default signer. 44 | var DefaultSigner RequestSigner 45 | 46 | // NoContext can be used instead of nil value to indicate no cancellation 47 | // context. 48 | var NoContext context.Context 49 | 50 | // NoCallback is used to indicate that results of push notification requests 51 | // should be silently discarded. 52 | var NoCallback chan *Result 53 | 54 | // DefaultCallback can be used instead of nil value to idicate that client's 55 | // default callback channel should be used to communicate back the result 56 | // of a push request. 57 | var DefaultCallback chan<- *Result 58 | 59 | // Client provides the means for asynchronous communication with APN service. 60 | // It is safe to use one client in concurrent goroutines and issue concurrent 61 | // push requests. 62 | // 63 | // As per APN service guidelines, you should keep a handle on this client 64 | // so that you can keep your connections with APN servers open. 65 | // Repeatedly opening and closing connections in rapid succession is 66 | // treated by Apple as a denial-of-service attack. 
67 | type Client struct { 68 | 69 | // Id identifies client in log entries. 70 | Id string 71 | 72 | // Gateway is the APN service connection endpoint. 73 | // Apple publishes two public endpoints: production and development. 74 | // They are preconfigured in Gateway.Production and Gateway.Development. 75 | Gateway string 76 | 77 | // CommsCfg contains communication settings to be used by the client. 78 | // See CommsCfg type declaration for additional details. 79 | CommsCfg CommsCfg 80 | 81 | // ProcCfg contains autoscaling settings. 82 | // See ProcCfg type declaration for additional details. 83 | ProcCfg ProcCfg 84 | 85 | // Certificate, if not nil, is used in the client side configuration 86 | // of the TLS connections to APN servers. 87 | // This is one of the authentication methods supported by APN service. 88 | Certificate *tls.Certificate 89 | 90 | // RootCA, if not nil, can be used to specify an alternative root 91 | // certificate authority. This should only be needed in testing, or 92 | // if you system's root certificate authorities are not set up. 93 | RootCA *tls.Certificate 94 | 95 | // Signer, if not nil, is used to sign individual requests to APN service. 96 | Signer RequestSigner 97 | 98 | // Queue for submitting push requests. 99 | // 100 | // You can use it directly in your code, especially in select statements 101 | // when coordination with other channels is desired. 102 | // Alternatively client's Push method can be used. 103 | Queue <-chan *Request 104 | 105 | // Callback, if not nil, specifies the channel to which the outcome of 106 | // the push request executions should be delivered. 107 | // If Callback is nil and a request doesn't specify an alternative callback, 108 | // requests execution result is silently dropped. 
109 | Callback chan<- *Result 110 | 111 | retry chan *Request 112 | 113 | out chan *Request 114 | gov *governor 115 | 116 | mu sync.RWMutex 117 | state uint 118 | wg sync.WaitGroup 119 | ctl chan struct{} // our control channel 120 | cctl chan struct{} // submitter control channel 121 | gctl chan struct{} // governor control channel 122 | cdone chan struct{} // pipeline done processing signal 123 | 124 | // counter for waits on outbound channel 125 | waitCtr syncx.TickTockCounter 126 | // counter of processed requests 127 | rateCtr syncx.Counter 128 | } 129 | 130 | const ( 131 | stateInitial uint = iota 132 | stateStarting 133 | stateRunning 134 | stateStopping 135 | stateTerminating 136 | stateClosed 137 | ) 138 | 139 | // Start starts Client processing pipeline. If the client has already 140 | // been started, ErrClientAlreadyStarted error is returned. 141 | func (c *Client) Start(wg *sync.WaitGroup) error { 142 | c.mu.Lock() 143 | defer c.mu.Unlock() 144 | if len(c.Id) == 0 { 145 | c.Id = "Client" 146 | } 147 | if c.state >= stateStarting { 148 | return ErrClientAlreadyStarted 149 | } 150 | c.state = stateStarting 151 | logInfo(c.Id, "Starting.") 152 | if wg != nil { 153 | wg.Add(1) 154 | } 155 | c.wg.Add(1) 156 | c.ctl = make(chan struct{}) 157 | c.cctl = make(chan struct{}) 158 | c.gctl = make(chan struct{}) 159 | c.cdone = make(chan struct{}) 160 | c.out = make(chan *Request) 161 | c.retry = make(chan *Request) 162 | c.gov = &governor{ 163 | id: c.Id + "-Governor", 164 | c: c, 165 | ctl: c.gctl, 166 | done: c.cdone, 167 | cfg: c.ProcCfg, 168 | minSust: c.ProcCfg.minSustainPollPeriods(), 169 | } 170 | // TODO Figure out coordination of governor and retrier shutdowns. 171 | go c.gov.run() 172 | go c.runSubmitter(wg) 173 | return nil 174 | } 175 | 176 | // Stop performs soft shutdown of the Client. All inflight requests are 177 | // given the chance to be executed. 
func (c *Client) Stop() error {
	// NOTE(review): the channels closed below are created in Start, so
	// calling Stop on a client that was never started closes a nil channel
	// and panics. Also, runSubmitter sets stateClosed when Queue is closed;
	// Stop then returns ErrClientAlreadyClosed without closing c.out or
	// Callback.
	c.mu.Lock()
	if c.state >= stateStopping {
		c.mu.Unlock()
		return ErrClientAlreadyClosed
	}
	c.state = stateStopping
	logInfo(c.Id, "Stopping.")
	close(c.cctl) // stop submitter
	c.mu.Unlock()
	// Wait for runSubmitter to exit; it calls c.wg.Done() on its way out.
	c.wg.Wait()
	// NOTE(review): Push reads c.state under RLock and then calls c.submit
	// without holding c.mu. A Push that passed its state check before Stop
	// may still be sending on c.out here, and a send on a closed channel
	// panics. Confirm that callers never Push concurrently with Stop.
	close(c.out)
	// Block until all processing is complete
	// or we are signaled to terminate.
	select {
	case <-c.cdone:
	case <-c.ctl:
	}
	// The client's default callback channel is closed on soft shutdown, so
	// callers must not close it themselves.
	if c.Callback != nil && c.Callback != NoCallback {
		close(c.Callback)
	}
	logInfo(c.Id, "Stopped.")
	return nil
}

// Kill performs hard shutdown of the Client without waiting for the processing
// pipeline to unwind. Inflight requests are discarded.
func (c *Client) Kill() error {
	// NOTE(review): as with Stop, calling Kill before Start closes nil
	// channels. Once runSubmitter has set stateClosed, Kill returns
	// ErrClientAlreadyClosed and never closes c.gctl, so the governor is
	// not signaled.
	c.mu.Lock()
	if c.state >= stateTerminating {
		c.mu.Unlock()
		return ErrClientAlreadyClosed
	}
	wasStopping := c.state == stateStopping
	c.state = stateTerminating
	logInfo(c.Id, "Terminating.")
	// Stop has already closed cctl; closing it twice would panic.
	if !wasStopping {
		close(c.cctl)
	}
	close(c.gctl)
	close(c.ctl) // unblock pending Stop() if there's one
	c.mu.Unlock()
	logInfo(c.Id, "Terminated.")
	return nil
}

// Push asynchronously sends a Notification to the APN service.
// Context carries a deadline and a cancellation signal and allows you to close
// long running requests when the context timeout is exceeded.
// Context can be nil or NoContext if no cancellation functionality
// is desired.
//
// If not nil, the supplied signer is asked to sign the request before
// submitting it to APN service. If the supplied signer is nil, but client's
// signer was configured at the initialization time, the client's signer will
// sign the request. NoSigner can be specified if the request must not be signed.
//
// Push returns ErrClientNotRunning if the client is not started or is
// shutting down, and ErrMissingAuth if the request cannot be authenticated.
//
// This method will block if downstream capacity is exceeded.
For non-blocking 236 | // behavior or to allow coordination with activity on other channels consider 237 | // creating a Request instance and writing it to client's Queue directly. 238 | func (c *Client) Push(n *Notification, signer RequestSigner, ctx context.Context, callback chan<- *Result) error { 239 | c.mu.RLock() 240 | state := c.state 241 | c.mu.RUnlock() 242 | if state < stateStarting || state > stateRunning { 243 | return ErrClientNotRunning 244 | } 245 | // Ensure that authentication is possible 246 | if c.Certificate == nil && (signer == NoSigner || !c.HasSigner() && signer == DefaultSigner) { 247 | return ErrMissingAuth 248 | } 249 | // Everything else is done asynchronously 250 | req := &Request{ 251 | Notification: n, 252 | Signer: signer, 253 | Context: ctx, 254 | Callback: callback, 255 | } 256 | err := c.submit(req) 257 | return err 258 | } 259 | 260 | // HasSigner returns `true` if there is a non-default signer configured 261 | // for signing push requests. 262 | func (c *Client) HasSigner() bool { 263 | return c.Signer != DefaultSigner 264 | } 265 | 266 | // TODO Separate submitter out 267 | func (c *Client) runSubmitter(wg *sync.WaitGroup) { 268 | done := false 269 | c.mu.Lock() 270 | if c.state != stateStarting { 271 | done = true 272 | } else { 273 | c.state = stateRunning 274 | } 275 | c.mu.Unlock() 276 | if !done { 277 | logInfo(c.Id+"-Submitter", "Running.") 278 | } 279 | for !done { 280 | select { 281 | case req, _ := <-c.retry: 282 | c.submit(req) 283 | case req, ok := <-c.Queue: 284 | if !ok { 285 | // Queue is closed and we must do s soft shutdown. 286 | // TODO Rework soft shutdown to account for retries. 
287 | done = true 288 | break 289 | } 290 | c.submit(req) 291 | case <-c.cctl: 292 | done = true 293 | } 294 | } 295 | c.mu.Lock() 296 | c.state = stateClosed 297 | c.mu.Unlock() 298 | logInfo(c.Id+"-Submitter", "Stopped.") 299 | c.wg.Done() 300 | if wg != nil { 301 | wg.Done() 302 | } 303 | } 304 | 305 | func (c *Client) submit(req *Request) (rerr error) { 306 | c.rateCtr.Add(1) 307 | // TODO implement ctx timing out and cancellation checks 308 | isBlocked := false 309 | select { 310 | case c.out <- req: 311 | default: 312 | isBlocked = true 313 | } 314 | if !isBlocked { 315 | return 316 | } 317 | c.waitCtr.Tick() 318 | select { 319 | case c.out <- req: 320 | case <-c.cctl: 321 | rerr = ErrPushInterrupted 322 | } 323 | c.waitCtr.Tock() 324 | return 325 | } 326 | 327 | func init() { 328 | NoSigner = noSigner{} 329 | NoCallback = make(chan *Result) 330 | } 331 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go APNS 2 | 3 | :warning: **WORK IN PROGRESS** :warning: 4 | 5 | Go APNS is a client for Apple Push Notification service over HTTP/2 protocol done "the Go way". 
6 | 7 | [](https://travis-ci.org/baobabus/go-apns) 8 | [](https://godoc.org/github.com/baobabus/go-apns/apns2) 9 | 10 | ## Features 11 | 12 | - Designed to communicate with Apple Push Notification service over HTTP/2 protocol 13 | - Processes push requests asynchronously: just write your requests to a channel 14 | - Notifies about push results asynchronously: just receive on the callback channel 15 | - Automatically scales up processing pipeline as the load increases 16 | - Allows full control of the scaling process and connection handling 17 | - Effects back pressure as needed to ensure full awareness by the up-stream 18 | - Supports Go 1.7 and later 19 | 20 | ## Processing Flow 21 | 22 | Each APNS Client handles all aspects of communication with APN service, including 23 | management of HTTP/2 connections and controlling concurrent HTTP/2 streams. 24 | 25 |  26 | 27 |
(Control channels and service goroutines are not shown)
28 | 29 | 1. `Submitter` picks up push requests from the processing queue 30 | 2. `Submitter` forwards requests to internal dispatch channel 31 | 3. Each `streamer` maintains a single HTTP/2 connection to APN service 32 | 4. One of the `streamers` picks up a push request from the dispatch queue 33 | 5. The `streamer` allocates a stream in its HTTP/2 connection 34 | 6. The `streamer` spins up a single-flight `round-tripper` goroutine 35 | 7. The `round-tripper` synchronously POSTs a request to APN service over 36 | its streamer's HTTP/2 connection 37 | 8. APN server's response is written to the callback channel 38 | 9. `Governor` collects metrics for dispatch and callback channel blockages, 39 | evaluates processing throughput and spins up new streamers as needed 40 | 41 | ## Scaling up 42 | 43 | During each poll interval governor collects stats on inbound and outbound channel blockages, 44 | on the number of notifications pushed and the size of data sent out. It then evalutes the stats 45 | against its scaling configuration and spins up new streamers if and when is appropriate. 46 | 47 | In the following illustrative scenario: 48 | 49 | - PollInterval = 0.2sec 50 | - MinSustain = 1sec 51 | - SettlePeriod = 2sec 52 | 53 |  54 |(Processing rate and bandwidth stats are not shown)
55 | 56 | 1. Blockages on the outbound channel prevent blockages on the inbound channel from being counted 57 | 2. The minimum time of sustained blockage on the inbound channel is reached 58 | - new streamers are spun up asynchronously 59 | - sustained blockages on the inbound channel have no effect while new streamers are starting 60 | 3. All new streamers have completed their initialization 61 | - settle period begins 62 | - sustained blockages on the inbound channel have no effect during the settle period 63 | 4. Settle period ends 64 | - since the minimum time of sustained blockage on the inbound channel has been reached, more streamers are spun up 65 | - sustained blockages on the inbound channel have no effect while new streamers are starting 66 | 5. All new streamers have completed their initialization 67 | - settle period begins 68 | - sustained blockages on the inbound channel have no effect during the settle period 69 | 6. Blockages on the inbound channel end - no more scaling up is needed. 70 | 71 | ## Example 72 | 73 | The fire-and-forget example sends a notification to three recipients. It uses 74 | the provider token authentication method and does not check the outcome, just 75 | making sure the push is complete.
76 | 77 | ```go 78 | package main 79 | 80 | import ( 81 | "log" 82 | 83 | "github.com/baobabus/go-apns/apns2" 84 | "github.com/baobabus/go-apns/cryptox" 85 | ) 86 | 87 | func main() { 88 | 89 | // Load and parse out token signing key 90 | signingKey, err := cryptox.PKCS8PrivateKeyFromFile("token_signing_pk.p8") 91 | if err != nil { 92 | log.Fatal("Token signing key error: ", err) 93 | } 94 | 95 | // Set up our client 96 | client := &apns2.Client{ 97 | Gateway: apns2.Gateway.Production, 98 | Signer: &apns2.JWTSigner{ 99 | KeyID: "ABC123DEFG", // Your key ID 100 | TeamID: "DEF123GHIJ", // Your team ID 101 | SigningKey: signingKey, 102 | }, 103 | CommsCfg: apns2.CommsFast, 104 | ProcCfg: apns2.UnlimitedProcConfig, 105 | } 106 | 107 | // Start processing 108 | err = client.Start(nil) 109 | if err != nil { 110 | log.Fatal("Client start error: ", err) 111 | } 112 | 113 | // Mock motification and recipients 114 | header := &apns2.Header{ Topic: "com.example.Alert" } 115 | payload := &apns2.Payload{ APS: &apns2.APS{Alert: "Ping!"} } 116 | recipients := []string{ 117 | "00fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0", 118 | "10fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0", 119 | "20fc13adff785122b4ad28809a3420982341241421348097878e577c991de8f0", 120 | } 121 | 122 | // Push to all recipients 123 | for _, rcpt := range recipients { 124 | notif := &apns2.Notification{ 125 | Recipient: rcpt, 126 | Header: header, 127 | Payload: payload, 128 | } 129 | err := client.Push(notif, apns2.DefaultSigner, apns2.NoContext, apns2.DefaultCallback) 130 | if err != nil { 131 | log.Fatal("Push error: ", err) 132 | } 133 | } 134 | 135 | // Perform soft shutdown allowing the processing to complete. 
136 | client.Stop() 137 | } 138 | ``` 139 | 140 | ## Configuration Settings and Customization 141 | 142 | ### Communication Settings 143 | 144 | The following communication settings are supported: 145 | 146 | ##### DialTimeout 147 | 148 | DialTimeout is the maximum amount of time a dial will wait for a connect 149 | to complete. 150 | 151 | ##### MinDialBackOff 152 | 153 | MinDialBackOff is the minimum amount of time by which dial attempts 154 | should be delayed after encountering a refused connection. 155 | Actual back-off time will grow exponentially until a connection attempt 156 | is successful. 157 | 158 | ##### MaxDialBackOff 159 | 160 | MaxDialBackOff is the maximum amount of time by which dial attempts 161 | should be delayed after encountering a refused connection. 162 | 163 | ##### DialBackOffJitter 164 | 165 | DialBackOffJitter is used to calculate the random amount to apply to each 166 | back-off time calculation. 167 | 168 | ##### RequestTimeout 169 | 170 | RequestTimeout specifies a time limit for requests made by the 171 | HTTPClient. The timeout includes connection time, any redirects, 172 | and reading the response body. 173 | 174 | ##### KeepAlive 175 | 176 | KeepAlive specifies the keep-alive period for an active network 177 | connection. If zero, keep-alives are not enabled. 178 | Apple recommends not closing connections to the APN service at all, 179 | but a sensibly long duration is acceptable. 180 | 181 | ##### MaxConcurrentStreams 182 | 183 | MaxConcurrentStreams is the maximum allowed number of concurrent streams 184 | per HTTP/2 connection. If the connection's MAX_CONCURRENT_STREAMS option 185 | is invoked by the remote side with a lower value, the remote request 186 | will be honored if possible. (See AllowHTTP2Incursion processing option.)
187 | 188 | 189 | CommsCfg example: 190 | 191 | ```go 192 | CommsCfg{ 193 | DialTimeout: 1 * time.Second, 194 | MinDialBackOff: 4 * time.Second, 195 | MaxDialBackOff: 10 * time.Minute, 196 | DialBackOffJitter: 10 * funit.Percent, 197 | RequestTimeout: 2 * time.Second, 198 | KeepAlive: 10 * time.Hour, 199 | MaxConcurrentStreams: 500, 200 | } 201 | ``` 202 | 203 | ### Processing Settings 204 | 205 | The following processing settings are supported: 206 | 207 | ##### MaxRetries 208 | MaxRetries is the maximum number of times a failed notification push 209 | should be reattempted. This only applies to "retriable" failures. 210 | 211 | ##### RetryEval 212 | RetryEval is the function that is called when a push attempt fails 213 | and retry eligibility needs to be determined. 214 | 215 | ##### MinConns 216 | MinConns is the minimum number of concurrent connections to APN servers 217 | that should be kept open. When a client is started it immediately attempts 218 | to open the specified number of connections. 219 | 220 | ##### MaxConns 221 | MaxConns is the maximum allowed number of concurrent connections 222 | to the APN service. 223 | 224 | ##### MaxRate 225 | MaxRate is the throughput cap specified in notifications per second. 226 | It is not strictly enforced as would be the case with a true rate 227 | limiter. Instead it only prevents additional scaling from taking place 228 | once the specified rate is reached.
241 | 242 | For clarity it is best expressed in an idiomatic way: 243 | 244 | ```go 245 | MaxBandwidth = 10 * funit.Kilobit / funit.Second 246 | ``` 247 | 248 | ##### Scale 249 | Scale specifies the manner of scaling up and winding down. 250 | Three scaling modes come predefined: Incremental, Exponential and Constant. 251 | 252 | ```go 253 | Scale = scale.Incremental(2) // Add two new connections each time 254 | ``` 255 | 256 | ##### MinSustain 257 | MinSustain is the minimum duration of time over which the processing 258 | has to experience blocking before a scale-up attempt is made. It is also 259 | the minimum amount of time over which non-blocking processing has to 260 | take place before a wind-down attempt is made. 261 | 262 | ##### PollInterval 263 | PollInterval is the time between performance metrics sampling attempts. 264 | 265 | ##### SettlePeriod 266 | SettlePeriod is the amount of time given to the processing for it to 267 | settle down at the new rate after a successful scaling up or 268 | winding down attempt. Sustained performance analysis is ignored during 269 | this time and no new scaling attempt is made. 270 | 271 | ##### AllowHTTP2Incursion 272 | AllowHTTP2Incursion controls whether it is OK to perform reflection-based 273 | probing of the HTTP/2 layer. When enabled, the scaler may access certain private 274 | properties in the x/net/http2 package if needed for more precise performance 275 | analysis. 276 | 277 | ##### UsePreciseHTTP2Metrics 278 | UsePreciseHTTP2Metrics, if set to true, instructs the scaler to query 279 | HTTP/2 layer parameters on every call that requires the data. 280 | Set this to false if you wish to eliminate any additional overhead that 281 | this may introduce. 282 | 283 | ##### HTTP2MetricsRefreshPeriod 284 | HTTP2MetricsRefreshPeriod, if set to a positive value, controls 285 | the frequency of "imprecise" metrics updates.
Under this approach any 286 | relevant fields that are private to the x/net/http2 package are only 287 | queried periodically. 288 | This reduces the overhead of any required reflection calls, but it also 289 | introduces the risk of potentially relying on some stale metrics. 290 | In most realistic situations, however, this can be easily tolerated 291 | given a frequent enough refresh period. 292 | 293 | HTTP2MetricsRefreshPeriod value is ignored and periodic updates 294 | are turned off if UsePreciseHTTP2Metrics is set to true. 295 | Setting HTTP2MetricsRefreshPeriod to 0 or a negative value disables 296 | metrics refresh even if UsePreciseHTTP2Metrics is false. 297 | 298 | ProcCfg example: 299 | 300 | ```go 301 | ProcCfg{ 302 | MaxRetries: 0, 303 | RetryEval: nil, 304 | MinConns: 1, 305 | MaxConns: 100, 306 | MaxRate: 100000 / funit.Second, 307 | MaxBandwidth: 10 * funit.Kilobit / funit.Second, 308 | Scale: scale.Incremental(2), 309 | MinSustain: 2 * time.Second, 310 | PollInterval: 200 * time.Millisecond, 311 | SettlePeriod: 5 * time.Second, 312 | AllowHTTP2Incursion: true, 313 | UsePreciseHTTP2Metrics: false, 314 | HTTP2MetricsRefreshPeriod: 200 * time.Millisecond, 315 | } 316 | ``` 317 | 318 | ## License 319 | 320 | The MIT License (MIT) 321 | 322 | Copyright (c) 2017 Aleksey Blinov 323 | 324 | Permission is hereby granted, free of charge, to any person obtaining a copy 325 | of this software and associated documentation files (the "Software"), to deal 326 | in the Software without restriction, including without limitation the rights 327 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 328 | copies of the Software, and to permit persons to whom the Software is 329 | furnished to do so, subject to the following conditions: 330 | 331 | The above copyright notice and this permission notice shall be included in all 332 | copies or substantial portions of the Software.
333 | 334 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 335 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 336 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 337 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 338 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 339 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 340 | SOFTWARE. 341 | -------------------------------------------------------------------------------- /apns2/dispatch.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Aleksey Blinov. All rights reserved. 2 | 3 | package apns2 4 | 5 | import ( 6 | "fmt" 7 | "time" 8 | 9 | "github.com/baobabus/go-apns/funit" 10 | "github.com/baobabus/go-apns/scale" 11 | ) 12 | 13 | // ProcCfg is a set of parameters that govern request processing flow 14 | // including automatic scaling of the processing pipeline. 15 | type ProcCfg struct { 16 | 17 | // MaxRetries is the maximum number of times a failed notification push 18 | // should be reattempted. This only applies to "retriable" failures. 19 | MaxRetries uint32 20 | 21 | // RetryEval is the function that is called when a push attempt fails 22 | // and retry eligibility needs to be determined. 23 | RetryEval func(*Response, error) bool 24 | 25 | // MinConns is minimum number of concurrent connections to APN servers 26 | // that should be kept open. 27 | MinConns uint32 28 | 29 | // MaxConns is maximum allowed number of concurrent connections 30 | // to APN service. 31 | MaxConns uint32 32 | 33 | // MaxRate is the throughput cap specified in notifications per second. 34 | // It is not strictly enforced as would be the case with a true rate 35 | // limiter. Instead it only prevents additional scaling from taking place 36 | // once the specified rate is reached. 
37 | MaxRate funit.Measure 38 | 39 | // MaxBandwidth is the throughput cap specified in bits per second. 40 | // It is not strictly enforced as would be the case with a true rate 41 | // limiter. Instead it only prevents additional scaling from taking place 42 | // once the specified rate is reached. 43 | MaxBandwidth funit.Measure 44 | 45 | // Scale specifies the manner of scaling up and winding down. 46 | // Three scaling modes come prefefined: Incremental, Exponential and Constant. 47 | // See below for more detail. 48 | Scale scale.Scale 49 | 50 | // MinSustain is the minimum duration of time over which the processing 51 | // has to experience blocking before a scale-up attemp is made. It is also 52 | // the minimum amount of time over which non-blocking processing has to 53 | // take place before a wind-down attemp is made. 54 | MinSustain time.Duration 55 | 56 | // PollInterval is the time between performance metrics sampling attempts. 57 | PollInterval time.Duration 58 | 59 | // SettlePeriod is the amount of time given to the processing for it to 60 | // settle down at the new rate after successful scaling up or 61 | // winding down attempt. Sustained performance analysis is ignored during 62 | // this time and no new scaling attempt is made. 63 | SettlePeriod time.Duration 64 | 65 | // AllowHTTP2Incursion controls whether it is OK to perform reflection-based 66 | // probing of HTTP/2 layer. When enabled, scaler may access certain private 67 | // properties in x/net/http2 package if needed for more precise performance 68 | // analysis. 69 | AllowHTTP2Incursion bool 70 | 71 | // UsePreciseHTTP2Metrics, if set to true, instructs the scaler to query 72 | // HTTP/2 layer parameters on every call that requires the data. 73 | // Set this to false if you wish to eliminate any additional overhead that 74 | // this may introduce. 
75 | UsePreciseHTTP2Metrics bool 76 | 77 | // HTTP2MetricsRefreshPeriod, if set to a positive value, controls 78 | // the frequency of "imprecise" metrics updates. Under this approach any 79 | // relevant fields that are private to x/net/http2 packaged are only 80 | // queried periodically. 81 | // This reduces the overhead of any required reflection calls, but it also 82 | // introduces the risk of potentially relying on some stale metrics. 83 | // In most realistic situations, however, this can be easily tolerated 84 | // given frequent enough refresh period. 85 | // 86 | // HTTP2MetricsRefreshPeriod value is ignored and periodic updates 87 | // are turned off if UsePreciseHTTP2Metrics is set to true. 88 | // Setting HTTP2MetricsRefreshPeriod to 0 or negative value disables 89 | // metrics refresh even if UsePreciseMetrics is false. 90 | HTTP2MetricsRefreshPeriod time.Duration 91 | } 92 | 93 | // MinBlockingProcConfig is a configuration with absolute mimimal processing 94 | // settings. It only allows a single connection to APN service with no scaling. 95 | // HTTP/2 layer metrics refresh is set to 500ms to allow proper handling 96 | // of HTTP/2 streams concurrency without introducing any noticeable overhead. 97 | var MinBlockingProcConfig = ProcCfg{ 98 | MinConns: 1, 99 | MaxConns: 1, 100 | MaxRate: 1000 / funit.Second, 101 | MaxBandwidth: 10 * funit.Gigabit / funit.Second, 102 | Scale: scale.Constant, 103 | AllowHTTP2Incursion: true, 104 | HTTP2MetricsRefreshPeriod: 500 * time.Millisecond, 105 | } 106 | 107 | // UnlimitedProcConfig is a configuration with virtually no limit on processing 108 | // speed and unlimited base 2 exponential scaling. 
109 | var UnlimitedProcConfig = ProcCfg{ 110 | MinConns: 1, 111 | MaxConns: ^uint32(0), 112 | MaxRate: 10000000 / funit.Second, 113 | MaxBandwidth: 1 * funit.Terabit / funit.Second, 114 | Scale: scale.Exponential(2), 115 | AllowHTTP2Incursion: true, 116 | HTTP2MetricsRefreshPeriod: 500 * time.Millisecond, 117 | } 118 | 119 | // minSustainPollPeriods returns the number of PollInterval periods per 120 | // MinSustain time interval. If PollInterval is not a whole divisor of 121 | // MinSustain, the result is rounded up. 122 | // If either PollInterval or MinSustain is not a valid time interval, 123 | // max uint32 is returned. 124 | func (c *ProcCfg) minSustainPollPeriods() uint32 { 125 | if c.MinSustain == 0 { 126 | return 0 127 | } 128 | if c.MinSustain < 0 || c.PollInterval <= 0 { 129 | return ^uint32(0) 130 | } 131 | res := c.MinSustain / c.PollInterval 132 | if c.MinSustain%c.PollInterval > 0 { 133 | res++ 134 | } 135 | return uint32(res) 136 | } 137 | 138 | // rateAsCount returns MaxRate expressed as number of counts per adjusted 139 | // MinSustain period. A rate of 1000/sec with MinSustain interval of 11 seconds 140 | // and PollInterval of 2 seconds is 12000 counts (6 poll intervals are needed 141 | // to make up at least 11 seconds, resulting in 12 seconds in adjusted 142 | // sustain period). 143 | func (c *ProcCfg) rateAsCount() uint64 { 144 | if c.MinSustain <= 0 || c.PollInterval <= 0 || c.MaxRate <= 0 { 145 | return 0 146 | } 147 | n := float64(c.minSustainPollPeriods()) 148 | return uint64(float64(c.MaxRate)*n*float64(c.PollInterval)) / uint64(funit.Second.AsDuration()) 149 | } 150 | 151 | // bandwidthAsSize returns MaxBandwidth expressed in bytes per adjusted 152 | // MinSustain period. A bandwidth of 1000/sec with MinSustain interval of 11 seconds 153 | // and PollInterval of 2 seconds is 12000 counts (6 poll intervals are needed 154 | // to make up at least 11 seconds, resulting in 12 seconds in adjusted 155 | // sustain period). 
156 | func (c *ProcCfg) bandwidthAsSize() uint64 { 157 | if c.MinSustain <= 0 || c.PollInterval <= 0 || c.MaxBandwidth <= 0 { 158 | return 0 159 | } 160 | n := float64(c.minSustainPollPeriods()) 161 | return uint64(float64(c.MaxBandwidth/funit.Byte)*n*float64(c.PollInterval)) / uint64(funit.Second.AsDuration()) 162 | } 163 | 164 | type governor struct { 165 | id string 166 | c *Client 167 | ctl <-chan struct{} 168 | done chan<- struct{} 169 | 170 | cfg ProcCfg 171 | 172 | // minimun number of continuous sampling periods of performance 173 | // evaluation need to have an effect on scaling decision 174 | minSust uint32 175 | 176 | // counters of continuous periods with waits and no waits 177 | // on inbound and oubound channels 178 | inCtr waitCounter 179 | outCtr waitCounter 180 | 181 | // processing rate and bandwidth accumulators 182 | countAcc *movingAcc 183 | sizeAcc *movingAcc 184 | maxCount uint64 // derived from cfg.MaxRate and minSust 185 | maxSize uint64 // derived from cfg.MaxBandwidth and minSust 186 | 187 | retry chan *Request 188 | 189 | // active streamers and pending launchers 190 | streamers map[*streamer]chan struct{} 191 | launchers map[*launcher]chan struct{} 192 | nextWId uint 193 | 194 | // "callback" channels streamers and launchers 195 | // to annouce their completion 196 | wExits chan *streamer 197 | lExits chan *launcher 198 | 199 | // time of last up- or down-scaling completion 200 | lastScale time.Time 201 | 202 | // tracker of blackout time due to back-off after failed connects 203 | backOffTracker backOffTracker 204 | 205 | isClosing bool 206 | } 207 | 208 | type waitCounter struct { 209 | waits uint32 210 | noWaits uint32 211 | } 212 | 213 | func (c *waitCounter) acc(val uint32) { 214 | if val > 0 { 215 | c.waits++ 216 | c.noWaits = 0 217 | } else { 218 | c.waits = 0 219 | c.noWaits++ 220 | } 221 | } 222 | 223 | // Must be called exactly once 224 | func (g *governor) run() { 225 | logInfo(g.id, "Starting.") 226 | if g.cfg.MaxRate > 0 && 
g.minSust > 0 { 227 | g.countAcc = newMovingAcc(int(g.minSust)) 228 | g.maxCount = g.cfg.rateAsCount() 229 | } 230 | if g.cfg.MaxBandwidth > 0 && g.minSust > 0 { 231 | g.sizeAcc = newMovingAcc(int(g.minSust)) 232 | g.maxSize = g.cfg.bandwidthAsSize() 233 | } 234 | g.wExits = make(chan *streamer) 235 | g.lExits = make(chan *launcher) 236 | g.streamers = make(map[*streamer]chan struct{}) 237 | g.launchers = make(map[*launcher]chan struct{}) 238 | g.backOffTracker.initial = 4 * time.Second 239 | if g.c.CommsCfg.MinDialBackOff > 0 { 240 | g.backOffTracker.initial = g.c.CommsCfg.MinDialBackOff 241 | } 242 | g.backOffTracker.max = g.c.CommsCfg.MaxDialBackOff 243 | g.backOffTracker.jitter = g.c.CommsCfg.DialBackOffJitter 244 | go g.runRetryForwarder() 245 | // Launch first MinConns streamers 246 | g.tryScaleUp() 247 | var tkrChan <-chan time.Time 248 | if g.cfg.PollInterval > 0 { 249 | tkr := time.NewTicker(g.cfg.PollInterval) 250 | defer tkr.Stop() 251 | tkrChan = tkr.C 252 | } 253 | logInfo(g.id, "Running.") 254 | for done := false; !done; { 255 | select { 256 | case l := <-g.lExits: 257 | // launcher finished 258 | delete(g.launchers, l) 259 | g.backOffTracker.update(l.err) 260 | if w := l.worker; w != nil { 261 | g.streamers[w] = w.ctl 262 | } else if l.err != nil { 263 | logWarn(g.id, "Error starting streamer: %v", l.err) 264 | } 265 | if len(g.launchers) == 0 { 266 | g.lastScale = time.Now() 267 | } 268 | // TODO Handle failed launches 269 | case w := <-g.wExits: 270 | // worker finished 271 | if w.inClosed && !g.isClosing { 272 | // Soft stop: Client closed main channel. We are closing, too. 
273 | logInfo(g.id, "Stopping.") 274 | g.isClosing = true 275 | } 276 | delete(g.streamers, w) 277 | if w.didQuit { 278 | // This needs to be on exponential back-off 279 | g.launchStreamer() 280 | } 281 | case <-tkrChan: 282 | if g.isClosing { 283 | break 284 | } 285 | s := g.updateCountersAndEvalScaling() 286 | if s > 0 { 287 | g.tryScaleUp() 288 | } else if s < 0 { 289 | g.tryWindDown() 290 | } 291 | case <-g.ctl: 292 | // Hard stop command 293 | logInfo(g.id, "Terminating.") 294 | done = true 295 | } 296 | if !done && g.isClosing { 297 | done = len(g.streamers) == 0 && len(g.launchers) == 0 298 | } 299 | } 300 | // signal launchers and streamers 301 | logInfo(g.id, "Terminating launchers and streamers.") 302 | for i, _ := range g.launchers { 303 | close(i.ctl) 304 | } 305 | for i, _ := range g.streamers { 306 | close(i.ctl) 307 | } 308 | // TODO Signal forwarder to stop 309 | logInfo(g.id, "Stopped.") 310 | // Signal parent 311 | close(g.done) 312 | } 313 | 314 | func (g *governor) updateCountersAndEvalScaling() int { 315 | shouldCount := g.cfg.MaxRate > 0 && g.minSust > 0 316 | shouldSize := g.cfg.MaxBandwidth > 0 && g.minSust > 0 317 | ics, _ := g.c.waitCtr.Fold() 318 | cnt := g.c.rateCtr.Draw() 319 | var ocs uint32 320 | var osz uint64 321 | // It is ok for the calls to Fold and Draw to not be fully synchronized. 322 | // We are only roughly estimating the disparity. 323 | for s, _ := range g.streamers { 324 | oc, _ := s.waitCtr.Fold() 325 | ocs += oc 326 | if shouldSize { 327 | osz += s.sizeCtr.Draw() 328 | } 329 | } 330 | g.inCtr.acc(ics) 331 | g.outCtr.acc(ocs) 332 | if shouldCount { 333 | cnt = g.countAcc.accumulate(cnt) 334 | } 335 | if shouldSize { 336 | osz = g.sizeAcc.accumulate(osz) 337 | } 338 | if g.inCtr.waits >= g.minSust && g.outCtr.noWaits >= g.minSust { 339 | // We've been experiencing blocking long enough, 340 | // but we must also not exceed allowed performance limits. 
341 | if shouldCount && cnt > g.maxCount { 342 | return 0 343 | } 344 | if shouldSize && osz > g.maxSize { 345 | return 0 346 | } 347 | return 1 348 | } else if g.inCtr.noWaits >= g.minSust { 349 | return -1 350 | } 351 | return 0 352 | } 353 | 354 | const ( 355 | forScaleUp = true 356 | forWindDown = false 357 | ) 358 | 359 | func (g *governor) tryScaleUp() { 360 | delta := g.allowedScaleDelta(forScaleUp) 361 | logTrace(2, g.id, "tryScaleUp delta = %d", delta) 362 | if delta <= 0 { 363 | return 364 | } 365 | for i := 0; i < delta; i++ { 366 | g.launchStreamer() 367 | } 368 | } 369 | 370 | func (g *governor) tryWindDown() { 371 | // TODO Implement winding down 372 | } 373 | 374 | func (g *governor) launchStreamer() { 375 | wid := fmt.Sprintf(g.id+"-Streamer-%d", g.nextWId) 376 | l := &launcher{gov: g, id: wid, done: g.lExits, ctl: make(chan struct{})} 377 | g.nextWId++ 378 | g.launchers[l] = l.ctl 379 | go l.launch() 380 | } 381 | 382 | func (g *governor) allowedScaleDelta(forScaleUp bool) int { 383 | if g.isClosing || len(g.launchers) > 0 { 384 | return 0 385 | } 386 | now := time.Now() 387 | switch { 388 | case g.lastScale.Add(g.cfg.SettlePeriod).After(now): 389 | return 0 390 | case g.backOffTracker.blackoutEnd().After(now): 391 | return 0 392 | } 393 | prov := uint32(len(g.streamers) + len(g.launchers)) 394 | req := uint32(0) 395 | if forScaleUp { 396 | if prov >= g.cfg.MaxConns { 397 | return 0 398 | } 399 | req = g.cfg.Scale.Apply(prov) 400 | } else { 401 | if prov <= g.cfg.MinConns { 402 | return 0 403 | } 404 | req = g.cfg.Scale.ApplyInverse(prov) 405 | } 406 | if req < g.cfg.MinConns { 407 | req = g.cfg.MinConns 408 | } 409 | if req > g.cfg.MaxConns { 410 | req = g.cfg.MaxConns 411 | } 412 | return int(req) - int(prov) 413 | } 414 | 415 | type launcher struct { 416 | gov *governor 417 | id string 418 | done chan<- *launcher 419 | ctl chan struct{} 420 | err error 421 | worker *streamer 422 | } 423 | 424 | func (l *launcher) launch() { 425 | w := &streamer{ 
426 | id: l.id, 427 | c: l.gov.c, 428 | gov: l.gov, 429 | in: l.gov.c.out, 430 | out: l.gov.c.Callback, 431 | warmStart: true, 432 | ctl: make(chan struct{}), 433 | done: l.gov.wExits, 434 | } 435 | if l.err = w.start(nil); l.err == nil { 436 | l.worker = w 437 | } 438 | // read from ctl prevents blocking on done if the governor 439 | // was commanded to terminate in the meantime 440 | select { 441 | case l.done <- l: 442 | case <-l.ctl: 443 | } 444 | } 445 | 446 | // TODO Rework forwarder and streamers so that inbound channel can be closed 447 | // by the client to indicate end of input, while allowing any retry requests 448 | // to finish. 449 | func (g *governor) runRetryForwarder() { 450 | if g.cfg.MaxRetries == 0 { 451 | return 452 | } 453 | // Retry requests will be re-queued with the Client. We need to ensure 454 | // that any blocking on the Client inbound channel is dealt with in a way 455 | // that doesn't block our streamers. 456 | // Rather than spinning goroutines for every retry send, we buffer 457 | // the sends. 100 buffered forwarders with buffers of 500 requests each 458 | // is more efficient than 50000 individual sender goroutines. 
459 | var buf chan *Request 460 | bufSize := 500 461 | cnt := 0 462 | // slight buffering on the inbound channel to improve performance 463 | g.retry = make(chan *Request, 100) 464 | logInfo(g.id+"-RetryForwarder", "Running.") 465 | for done := false; !done; { 466 | select { 467 | case req := <-g.retry: 468 | if buf == nil || cnt >= bufSize { 469 | if buf != nil { 470 | // signal bufferedForwarder to return 471 | close(buf) 472 | } 473 | buf = make(chan *Request, bufSize) 474 | go bufferedForwarder(buf, g.c, g.ctl) 475 | cnt = 0 476 | } 477 | buf <- req 478 | case <-g.ctl: 479 | done = true 480 | } 481 | } 482 | logInfo(g.id+"-RetryForwarder", "Stopped.") 483 | } 484 | 485 | func bufferedForwarder(in <-chan *Request, client *Client, ctl <-chan struct{}) { 486 | for done := false; !done; { 487 | select { 488 | case req, ok := <-in: 489 | if !ok { 490 | done = true 491 | break 492 | } 493 | select { 494 | case client.retry <- req: 495 | case <-ctl: 496 | done = true 497 | } 498 | case <-ctl: 499 | done = true 500 | } 501 | } 502 | } 503 | 504 | type movingAcc struct { 505 | samples []uint64 506 | sum uint64 507 | pos int 508 | } 509 | 510 | func newMovingAcc(windowSize int) *movingAcc { 511 | if windowSize <= 0 { 512 | return nil 513 | } 514 | return &movingAcc{samples: make([]uint64, windowSize)} 515 | } 516 | 517 | func (a *movingAcc) accumulate(v uint64) uint64 { 518 | a.sum += v - a.samples[a.pos] 519 | a.samples[a.pos] = v 520 | a.pos = (a.pos + 1) % len(a.samples) 521 | return a.sum 522 | } 523 | -------------------------------------------------------------------------------- /doc/scale.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | --------------------------------------------------------------------------------