├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── badge_number.go ├── badge_number_test.go ├── connection.go ├── connection_test.go ├── feedback_service.go ├── feedback_service_test.go ├── payload.go └── payload_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | .godeps 20 | 21 | _testmain.go 22 | 23 | *.exe 24 | *.test 25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Karl Kirch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | go-libapns 2 | ========== 3 | 4 | APNS library for Go 5 | 6 | The idea here is to be a simple, low-level library that will handle establishing a connection and sending push notifications via Apple's APNS service, with an eye towards throughput and performance. 7 | 8 | Handles the latest Apple push notification guidelines at https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/APNSOverview.html 9 | 10 | Specifically, it implements the binary framed format by batching push notifications. Each batch will be flushed either every 10ms (by default, see FramingTimeout) or when a frame is full. A frame is full when the framed format cannot fit any more data into a TCP packet (65535 bytes). Due to this framing, when finished with the APNS connection, one should call Disconnect() to flush any remaining messages out the door.
11 | 12 | ## Godoc 13 | Located here [![GoDoc](https://godoc.org/github.com/joekarl/go-libapns?status.svg)](https://godoc.org/github.com/joekarl/go-libapns) 14 | 15 | ## Installation 16 | 17 | ```bash 18 | > go get github.com/joekarl/go-libapns 19 | ``` 20 | 21 | ## Basic Usage 22 | ```go 23 | package main 24 | 25 | import ( 26 | apns "github.com/joekarl/go-libapns" 27 | "io/ioutil" 28 | ) 29 | 30 | func main() { 31 | certPem, err := ioutil.ReadFile("../certs/PushTestCert.pem") 32 | if err != nil { 33 | panic(err) 34 | } 35 | keyPem, err := ioutil.ReadFile("../certs/PushTestKey-noenc.pem") 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | apnsConnection, _ := apns.NewAPNSConnection(&apns.APNSConfig{ 41 | CertificateBytes: certPem, 42 | KeyBytes: keyPem, 43 | }) 44 | 45 | payload := &apns.Payload { 46 | Token: "2ed202ac08ea9...cf8d55910df290567037dcc4", 47 | AlertText: "This is a push notification!", 48 | } 49 | 50 | apnsConnection.SendChannel <- payload 51 | apnsConnection.Disconnect() 52 | } 53 | ``` 54 | **Note** This example doesn't take into account essential error handling. See below for error handling details. 55 | 56 | **Payload.Badge Need to Know** Apple specifies that one should set the badge key to 0 to clear the badge number. This unfortunately has the side effect of causing the Go JSON serializer to omit the badge field. Luckily Apple uses negative badge numbers to clear the badge as well. So for our purposes, a badge > 0 will set the badge number, a badge < 0 will clear the badge number, and a badge == 0 will leave the badge number as is. 57 | 58 | ## Creating an APNS connection 59 | Creating a connection consists of a couple of steps. They are: 60 | 61 | * Creating a TCP socket to Apple's servers 62 | * Initiating a TLS session using your PEM certs 63 | 64 | There are two ways to do this using go-libapns.
They are: 65 | 66 | * NewAPNSConnection(*APNSConfig) 67 | * SocketAPNSConnection(net.Conn, *APNSConfig) 68 | 69 | The `NewAPNSConnection` function will validate your config, create the TCP connection, initiate a TLS session, and return a new APNSConnection. This is how you will almost always create a connection. 70 | 71 | If you are on a platform that needs to create a custom socket (like Google App Engine), you can use the `SocketAPNSConnection` function. This takes a `net.Conn` (should be a TCP socket), validates your config, initializes a TLS session, and returns a new APNSConnection. 72 | 73 | ## Pem Certs 74 | You should provide your APNS certificate as separate cert/key PEM files. Currently Go doesn't support password-protected PEM files (https://github.com/golang/go/issues/6722), so you'll need to remove the password from your key PEM. 75 | 76 | #### Separate pem files from p12 77 | ```sh 78 | openssl pkcs12 -clcerts -nokeys -out cert.pem -in cert.p12 79 | 80 | openssl pkcs12 -nocerts -out key.pem -in key.p12 81 | ``` 82 | 83 | #### Remove password from pem file 84 | ```sh 85 | openssl rsa -in key.pem -out key-noenc.pem 86 | ``` 87 | 88 | ## Error Handling 89 | As per Apple's guidelines, when a connection is closed due to error, the id of the message which caused the error will be transmitted back over the connection. In this case, multiple push notifications may have followed the bad message. These push notifications will be supplied on the connection's `CloseChannel` (in `ConnectionClose.UnsentPayloads`) **as well as any other unsent messages** and will then be available to re-process. Also, when writing to the send channel, you should wrap the send in a select that cases on both the send channel and the connection close channel. This will allow you to correctly handle the async nature of Apple's error handling scheme. See this gist (https://gist.github.com/joekarl/86d9bdb8f9af044710b7) for a full-featured example of how to integrate go-libapns with proper shutdown handling and looped connection handling.
90 | 91 | ## Persistent Connection 92 | go-libapns will use a persistent TCP connection (dialed by go-libapns or supplied by the user) to connect to Apple's APNS gateway. This allows for the greatest throughput to Apple's servers. On close or error, this connection will be killed and all unsent push notifications will be supplied for re-processing. **Note** Unlike most other APNS libraries, go-libapns will NOT attempt to re-transmit your unsent payloads. Because it is trivial to write this retry logic, go-libapns leaves that to the user to implement, as not everyone needs or wants this behavior (e.g. you may want to put the messages that need to be resent into a queue or store them for later). 93 | 94 | ## Feedback Service 95 | Apple specifies that you should connect to the feedback service gateway regularly to keep track of devices that no longer have your application installed. go-libapns provides a simple interface to the feedback service. Simply create an `APNSFeedbackServiceConfig` object and then call `ConnectToFeedbackService`. This will return a list of device tokens that you should keep track of and not send push notifications to again (specifically this will return a List of `*FeedbackResponse`). 96 | 97 | ## Push Notification Length 98 | Apple places a strict limit on push notification length (currently at 2048 bytes). go-libapns will attempt to fit your push notification into that size limit by first applying all of your supplied custom fields and applying as much of your alert text as possible. This truncation is not without cost as it takes almost twice the time to fix a message that is too long. So if possible, try to find a sweet spot that won't cause truncation to occur. If unable to truncate the message, go-libapns will print the marshalling error and drop that payload without sending it; the connection stays open (you've been warned). This limit is configurable in the APNSConfig object. 99 | 100 | _Note: Prior to iOS 8, the limit was 256 bytes.
APNS will accept and deliver up to 2048 bytes to devices 101 | running iOS 8 as well as those running on older versions of iOS._ 102 | 103 | ##TCP Framing 104 | Most APNS libraries rely on the OS Nagling to buffer data into the socket. go-libapns does not rely on Nagling but does do what it can to optimize the number of bytes sent per TCP frame. The two relevant config options that control this behavior are: 105 | 106 | * MaxOutboundTCPFrameSize - (default TCP_FRAME_MAX) Max number of bytes to send per TCP frame 107 | * FramingTimeout - (default 10ms) Max time between TCP flushes 108 | 109 | TCP_NODELAY can be turned on with this setup by setting the FramingTimeout to anything less than 0 (like -1). In practice you want this buffering to occur, so best to leave defaults. If you're concerned about a (max) 10ms delay between your push notifications being sent onto the socket be aware that this is much much much shorter than the default linux Nagle timeout of 1 second. 110 | 111 | ##What's with using channels for writing to the connection? 112 | Basically, this makes it easier to synchronize error handling and socket errors. Not sure if this is the best idea, but definitely works. 113 | 114 | ##APNSConfig 115 | The only required fields are the CertificateBytes and KeyBytes. 
116 | The other fields all have sane defaults 117 | 118 | ```go 119 | InFlightPayloadBufferSize int //number of payloads to keep for error purposes, defaults to 10000 120 | FramingTimeout int //number of milliseconds between frame flushes, defaults to 10ms 121 | MaxPayloadSize int //max number of bytes allowed in payload, defaults to 2048 122 | CertificateBytes []byte //bytes for cert.pem : required 123 | KeyBytes []byte //bytes for key.pem : required 124 | GatewayHost string //apple gateway, defaults to "gateway.push.apple.com" 125 | GatewayPort string //apple gateway port, defaults to "2195" 126 | MaxOutboundTCPFrameSize int //max number of bytes to frame data to, defaults to TCP_FRAME_MAX 127 | //generally best to NOT set this and use the default 128 | SocketTimeout int //number of seconds to wait before bailing on a socket connection, defaults to no timeout 129 | TlsTimeout int //number of seconds to wait before bailing on a tls handshake, defaults to 5 sec 130 | ``` 131 | 132 | #License 133 | The MIT License (MIT) 134 | 135 | Copyright (c) 2014 Karl Kirch 136 | 137 | Permission is hereby granted, free of charge, to any person obtaining a copy 138 | of this software and associated documentation files (the "Software"), to deal 139 | in the Software without restriction, including without limitation the rights 140 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 141 | copies of the Software, and to permit persons to whom the Software is 142 | furnished to do so, subject to the following conditions: 143 | 144 | The above copyright notice and this permission notice shall be included in all 145 | copies or substantial portions of the Software. 146 | 147 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 148 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 149 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 150 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 151 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 152 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 153 | SOFTWARE. 154 | -------------------------------------------------------------------------------- /badge_number.go: -------------------------------------------------------------------------------- 1 | package apns 2 | 3 | import ( 4 | "errors" 5 | "strconv" 6 | ) 7 | 8 | // Struct representing the badge number over 9 | // the app icon on iOS 10 | type BadgeNumber struct { 11 | number int 12 | set bool 13 | } 14 | 15 | // Returns the set badge number 16 | func (b *BadgeNumber) Number() int { 17 | return b.number 18 | } 19 | 20 | // Returns whether or not this BadgeNumber 21 | // is set and should be sent in the APNS payload 22 | func (b *BadgeNumber) IsSet() bool { 23 | return b.set 24 | } 25 | 26 | // Resets the BadgeNumber to 0 and 27 | // removes it from the APNS payload 28 | func (b *BadgeNumber) UnSet() { 29 | b.number = 0 30 | b.set = false 31 | } 32 | 33 | // Sets the badge number and includes it in the 34 | // payload to APNS. 
call .Set(0) to have the badge 35 | // number cleared from the app icon 36 | func (b *BadgeNumber) Set(number int) error { 37 | if number < 0 { 38 | return errors.New("Number must be >= 0") 39 | } 40 | 41 | b.number = number 42 | b.set = true 43 | return nil 44 | } 45 | 46 | func (b BadgeNumber) MarshalJSON() ([]byte, error) { 47 | return []byte(strconv.Itoa(b.number)), nil 48 | } 49 | 50 | func (b *BadgeNumber) UnmarshalJSON(data []byte) error { 51 | val, err := strconv.ParseInt(string(data), 10, 32) 52 | if err != nil { 53 | return errors.New("Error unmarshalling BadgeNumber, cannot convert []byte to int32") 54 | } 55 | 56 | // Since the point of this type is to 57 | // allow proper inclusion of 0 for int 58 | // types while respecting omitempty, 59 | // assume that set==true if there is 60 | // a value to unmarshal 61 | *b = BadgeNumber{ 62 | number: int(val), 63 | set: true, 64 | } 65 | return nil 66 | } 67 | 68 | // Get a new badge number, set to the initial 69 | // number, and included in the payload 70 | func NewBadgeNumber(number int) BadgeNumber { 71 | return BadgeNumber{ 72 | number: number, 73 | set: true, 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /badge_number_test.go: -------------------------------------------------------------------------------- 1 | package apns 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | ) 7 | 8 | func TestBadgeNumberDefaults(t *testing.T) { 9 | b := BadgeNumber{} 10 | 11 | if b.IsSet() { 12 | t.Error("BadgeNumber should not be set by default") 13 | } 14 | if b.Number() != 0 { 15 | t.Error("Badge number should be 0 by default") 16 | } 17 | } 18 | 19 | func TestBadgeNumberNew(t *testing.T) { 20 | b := NewBadgeNumber(5) 21 | 22 | if !b.IsSet() { 23 | t.Error("NewBadgeNumber should return set BadgeNumber") 24 | } 25 | if b.Number() != 5 { 26 | t.Error("Resulting badge number should be 5") 27 | } 28 | } 29 | 30 | func TestBadgeNumberUnset(t *testing.T) { 31 | b := 
NewBadgeNumber(5) 32 | 33 | if !b.IsSet() { 34 | t.Error("NewBadgeNumber should return set BadgeNumber") 35 | } 36 | 37 | b.UnSet() 38 | 39 | if b.IsSet() { 40 | t.Error("UnSet should unset BadgeNumber") 41 | } 42 | if b.Number() != 0 { 43 | t.Error("UnSet should set number to 0") 44 | } 45 | } 46 | 47 | func TestBadgeNumberMarshalJSON(t *testing.T) { 48 | b := NewBadgeNumber(11) 49 | m := map[string]BadgeNumber{ 50 | "number": b, 51 | } 52 | 53 | jsonData, err := json.Marshal(m) 54 | if err != nil { 55 | t.Errorf("Error marshalling BadgeNumber: %s", err.Error()) 56 | } 57 | 58 | expected := "{\"number\":11}" 59 | if string(jsonData) != expected { 60 | t.Errorf( 61 | "JSON output\n%s\ndoes not match\n%s", 62 | string(jsonData), 63 | expected, 64 | ) 65 | } 66 | } 67 | 68 | func TestBadgeNumberUnmarshalJSON(t *testing.T) { 69 | type TestStruct struct { 70 | Number BadgeNumber 71 | } 72 | 73 | var ts TestStruct 74 | jsonStr := "{\"number\":11}" 75 | err := json.Unmarshal([]byte(jsonStr), &ts) 76 | if err != nil { 77 | t.Errorf("Error unmarshalling to BadgeNumber: %s", err.Error()) 78 | } 79 | 80 | if !ts.Number.IsSet() { 81 | t.Error("Resulting BadgeNumber should be set") 82 | } 83 | if ts.Number.Number() != 11 { 84 | t.Error("Expected number to be 11, got %d", ts.Number.Number()) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | //Package for creating a connection to Apple's APNS gateway and facilitating 2 | //sending push notifications via that gateway 3 | package apns 4 | 5 | import ( 6 | "bytes" 7 | "container/list" 8 | "crypto/tls" 9 | "encoding/binary" 10 | "encoding/hex" 11 | "errors" 12 | "fmt" 13 | "net" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | //Config for creating an APNS Connection 19 | type APNSConfig struct { 20 | //number of payloads to keep for error purposes, defaults to 10000 21 | 
InFlightPayloadBufferSize int 22 | //number of milliseconds between frame flushes, defaults to 10 23 | FramingTimeout int 24 | //max number of bytes allowed in payload, defaults to 2048 25 | MaxPayloadSize int 26 | //bytes for cert.pem : required 27 | CertificateBytes []byte 28 | //bytes for key.pem : required 29 | KeyBytes []byte 30 | //apple gateway, defaults to "gateway.push.apple.com" 31 | GatewayHost string 32 | //apple gateway port, defaults to "2195" 33 | GatewayPort string 34 | //max number of bytes to frame data to, defaults to TCP_FRAME_MAX 35 | //generally best to NOT set this and use the default 36 | MaxOutboundTCPFrameSize int 37 | //number of seconds to wait for connection before bailing, defaults to no timeout 38 | SocketTimeout int 39 | //number of seconds to wait for Tls handshake to complete before bailing, defaults to no timeout 40 | TlsTimeout int 41 | } 42 | 43 | //Object returned on a connection close or connection error 44 | type ConnectionClose struct { 45 | //Any payload objects that weren't sent after a connection close 46 | UnsentPayloads *list.List 47 | //The error details returned from Apple 48 | Error *AppleError 49 | //The payload object that caused the error 50 | ErrorPayload *Payload 51 | //True if error payload wasn't found indicating some unsent payloads were lost 52 | UnsentPayloadBufferOverflow bool 53 | } 54 | 55 | //Details from Apple regarding a connection close 56 | type AppleError struct { 57 | //Internal ID of the message that caused the error 58 | MessageID uint32 59 | //Error code returned by Apple (see APPLE_PUSH_RESPONSES) 60 | ErrorCode uint8 61 | //String name of error code 62 | ErrorString string 63 | } 64 | 65 | //APNS Connection state 66 | type APNSConnection struct { 67 | //Channel to send payloads on 68 | SendChannel chan *Payload 69 | //Channel that connection close is received on 70 | CloseChannel chan *ConnectionClose 71 | //raw socket connection 72 | socket net.Conn 73 | //config 74 | config *APNSConfig 75 | 
//Buffer to hold payloads for replay 76 | inFlightPayloadBuffer *list.List 77 | //Stateful buffer to hold framed byte data 78 | inFlightFrameByteBuffer *bytes.Buffer 79 | //Stateful buffer to hold data while generating item bytes 80 | inFlightItemByteBuffer *bytes.Buffer 81 | //Mutex to sync access to Frame byte buffer 82 | inFlightBufferLock *sync.Mutex 83 | //Stateful counter to identify payloads for replay 84 | payloadIdCounter uint32 85 | // Mutex to sync during disconnect 86 | disconnectLock *sync.Mutex 87 | // Boolean saying we're disconnecting 88 | disconnecting bool 89 | } 90 | 91 | //Wrapper for associating an ID with a Payload object 92 | type idPayload struct { 93 | //The Payload object 94 | Payload *Payload 95 | //The numerical id (from payloadIdCounter) for replay identification 96 | ID uint32 97 | } 98 | 99 | const ( 100 | //Max number of bytes in a TCP frame 101 | TCP_FRAME_MAX = 65535 102 | //Number of bytes used in the Apple Notification Header 103 | //command is 1 byte, frame length is 4 bytes 104 | NOTIFICATION_HEADER_SIZE = 5 105 | //Size of token 106 | APNS_TOKEN_SIZE = 32 107 | // client shutdown via disconnect error code 108 | CONNECTION_CLOSED_DISCONNECT = 250 109 | // client shutdown via unknown error code 110 | CONNECTION_CLOSED_UNKNOWN = 251 111 | ) 112 | 113 | // This enumerates the response codes that Apple defines 114 | // for push notification attempts. 
115 | var APPLE_PUSH_RESPONSES = map[uint8]string{ 116 | 0: "NO_ERRORS", 117 | 1: "PROCESSING_ERROR", 118 | 2: "MISSING_DEVICE_TOKEN", 119 | 3: "MISSING_TOPIC", 120 | 4: "MISSING_PAYLOAD", 121 | 5: "INVALID_TOKEN_SIZE", 122 | 6: "INVALID_TOPIC_SIZE", 123 | 7: "INVALID_PAYLOAD_SIZE", 124 | 8: "INVALID_TOKEN", 125 | 10: "SHUTDOWN", // apple shutdown connection 126 | 128: "INVALID_FRAME_ITEM_ID", //this is not documented, but ran across it in testing 127 | CONNECTION_CLOSED_DISCONNECT: "CONNECTION CLOSED DISCONNECT", // client disconnect (not apple, used internally) 128 | CONNECTION_CLOSED_UNKNOWN: "CONNECTION CLOSED UNKNOWN", // client unknown connection error (not apple, used internally) 129 | 255: "UNKNOWN", 130 | } 131 | 132 | func (e *AppleError) Error() string { 133 | return e.ErrorString 134 | } 135 | 136 | // Apply config defaults to given Config 137 | func applyConfigDefaults(config *APNSConfig) error { 138 | errorStrs := "" 139 | 140 | if config.CertificateBytes == nil || config.KeyBytes == nil { 141 | errorStrs += "Invalid Key/Certificate bytes\n" 142 | } 143 | if config.InFlightPayloadBufferSize < 0 { 144 | errorStrs += "Invalid InFlightPayloadBufferSize. Should be > 0 (and probably around 10000)\n" 145 | } 146 | if config.MaxOutboundTCPFrameSize < 0 || config.MaxOutboundTCPFrameSize > TCP_FRAME_MAX { 147 | errorStrs += "Invalid MaxOutboundTCPFrameSize. Should be between 0 and TCP_FRAME_MAX (and probably above 2048)\n" 148 | } 149 | if config.MaxPayloadSize < 0 { 150 | errorStrs += "Invalid MaxPayloadSize. 
Should be greater than 0.\n" 151 | } 152 | 153 | if errorStrs != "" { 154 | return errors.New(errorStrs) 155 | } 156 | 157 | if config.InFlightPayloadBufferSize == 0 { 158 | config.InFlightPayloadBufferSize = 10000 159 | } 160 | if config.MaxOutboundTCPFrameSize == 0 { 161 | config.MaxOutboundTCPFrameSize = TCP_FRAME_MAX 162 | } 163 | if config.FramingTimeout == 0 { 164 | config.FramingTimeout = 10 165 | } 166 | if config.GatewayPort == "" { 167 | config.GatewayPort = "2195" 168 | } 169 | if config.GatewayHost == "" { 170 | config.GatewayHost = "gateway.push.apple.com" 171 | } 172 | if config.MaxPayloadSize == 0 { 173 | config.MaxPayloadSize = 2048 174 | } 175 | if config.TlsTimeout == 0 { 176 | config.TlsTimeout = 5 177 | } 178 | return nil 179 | } 180 | 181 | //Create a new apns connection with supplied config 182 | //If invalid config an error will be returned 183 | //See APNSConfig object for defaults 184 | func NewAPNSConnection(config *APNSConfig) (*APNSConnection, error) { 185 | err := applyConfigDefaults(config) 186 | 187 | if err != nil { 188 | return nil, err 189 | } 190 | 191 | tcpSocket, err := net.DialTimeout("tcp", 192 | config.GatewayHost+":"+config.GatewayPort, 193 | time.Duration(config.SocketTimeout)*time.Second) 194 | if err != nil { 195 | //failed to connect to gateway 196 | return nil, err 197 | } 198 | 199 | tlsSocket, err := createTLSClient(tcpSocket, config) 200 | 201 | if err != nil { 202 | return nil, err 203 | } 204 | 205 | return socketAPNSConnection(tlsSocket, config), nil 206 | } 207 | 208 | //Create APNS connection from raw socket 209 | func SocketAPNSConnection(socket net.Conn, config *APNSConfig) (*APNSConnection, error) { 210 | err := applyConfigDefaults(config) 211 | 212 | if err != nil { 213 | return nil, err 214 | } 215 | 216 | tlsSocket, err := createTLSClient(socket, config) 217 | 218 | if err != nil { 219 | return nil, err 220 | } 221 | 222 | return socketAPNSConnection(tlsSocket, config), nil 223 | } 224 | 225 | func 
createTLSClient(socket net.Conn, config *APNSConfig) (net.Conn, error) { 226 | x509Cert, err := tls.X509KeyPair(config.CertificateBytes, config.KeyBytes) 227 | if err != nil { 228 | //failed to validate key pair 229 | return nil, err 230 | } 231 | 232 | tlsConf := &tls.Config{ 233 | Certificates: []tls.Certificate{x509Cert}, 234 | ServerName: config.GatewayHost, 235 | } 236 | 237 | tlsSocket := tls.Client(socket, tlsConf) 238 | tlsSocket.SetDeadline(time.Now().Add(time.Duration(config.TlsTimeout) * time.Second)) 239 | err = tlsSocket.Handshake() 240 | if err != nil { 241 | //failed to handshake with tls information 242 | return nil, err 243 | } 244 | 245 | //hooray! we're connected 246 | //reset the deadline so it doesn't fail subsequent writes 247 | tlsSocket.SetDeadline(time.Time{}) 248 | 249 | return tlsSocket, nil 250 | } 251 | 252 | //Starts connection close and send listeners 253 | func socketAPNSConnection(socket net.Conn, config *APNSConfig) *APNSConnection { 254 | 255 | c := new(APNSConnection) 256 | //TODO(karl): maybe should copy the config to prevent tampering? 
257 | c.config = config 258 | c.inFlightPayloadBuffer = list.New() 259 | c.socket = socket 260 | c.SendChannel = make(chan *Payload) 261 | c.CloseChannel = make(chan *ConnectionClose) 262 | c.inFlightFrameByteBuffer = new(bytes.Buffer) 263 | c.inFlightItemByteBuffer = new(bytes.Buffer) 264 | c.inFlightBufferLock = new(sync.Mutex) 265 | c.disconnectLock = new(sync.Mutex) 266 | c.payloadIdCounter = 1 267 | errCloseChannel := make(chan *AppleError) 268 | 269 | go c.closeListener(errCloseChannel) 270 | go c.sendListener(errCloseChannel) 271 | 272 | return c 273 | } 274 | 275 | //Disconnect from the Apns Gateway 276 | //Flushes any currently unsent messages before disconnecting from the socket 277 | func (c *APNSConnection) Disconnect() { 278 | c.disconnectLock.Lock() 279 | c.disconnecting = true 280 | c.disconnectLock.Unlock() 281 | //flush on disconnect 282 | c.inFlightBufferLock.Lock() 283 | c.flushBufferToSocket() 284 | c.inFlightBufferLock.Unlock() 285 | c.noFlushDisconnect() 286 | } 287 | 288 | //internal close socket 289 | func (c *APNSConnection) noFlushDisconnect() { 290 | c.socket.Close() 291 | } 292 | 293 | //go-routine to listen for socket closes or apple response information 294 | func (c *APNSConnection) closeListener(errCloseChannel chan *AppleError) { 295 | buffer := make([]byte, 6, 6) 296 | _, err := c.socket.Read(buffer) 297 | if err != nil { 298 | c.disconnectLock.Lock() 299 | if c.disconnecting { 300 | errCloseChannel <- &AppleError{ 301 | ErrorCode: CONNECTION_CLOSED_DISCONNECT, // closed due to disconnect 302 | ErrorString: err.Error(), 303 | MessageID: 0, 304 | } 305 | } else { 306 | errCloseChannel <- &AppleError{ 307 | ErrorCode: CONNECTION_CLOSED_UNKNOWN, // don't know why we closed 308 | ErrorString: err.Error(), 309 | MessageID: 0, 310 | } 311 | } 312 | c.disconnectLock.Unlock() 313 | } else { 314 | messageId := binary.BigEndian.Uint32(buffer[2:]) 315 | errCloseChannel <- &AppleError{ 316 | ErrorString: APPLE_PUSH_RESPONSES[uint8(buffer[1])], 
317 | ErrorCode: uint8(buffer[1]), 318 | MessageID: messageId, 319 | } 320 | } 321 | } 322 | 323 | //go-routine to listen for Payloads which should be sent 324 | func (c *APNSConnection) sendListener(errCloseChannel chan *AppleError) { 325 | var appleError *AppleError 326 | 327 | longTimeoutDuration := 5 * time.Minute 328 | shortTimeoutDuration := time.Duration(c.config.FramingTimeout) * time.Millisecond 329 | zeroTimeoutDuration := 0 * time.Millisecond 330 | timeoutTimer := time.NewTimer(longTimeoutDuration) 331 | 332 | for { 333 | if appleError != nil { 334 | break 335 | } 336 | select { 337 | case sendPayload := <-c.SendChannel: 338 | if sendPayload == nil { 339 | //channel was closed 340 | return 341 | } 342 | idPayloadObj := &idPayload{ 343 | Payload: sendPayload, 344 | ID: c.payloadIdCounter, 345 | } 346 | 347 | // increment payload id counter but don't allow 348 | // 0 as valid id as it is the null value 349 | // only a problem if we overflow a uint32 350 | c.payloadIdCounter++ 351 | 352 | if c.payloadIdCounter == 0 { 353 | c.payloadIdCounter = 1 354 | } 355 | 356 | err := c.bufferPayload(idPayloadObj) 357 | if err != nil { 358 | fmt.Print(err) 359 | break 360 | } 361 | 362 | if shortTimeoutDuration > zeroTimeoutDuration { 363 | //schedule short timeout 364 | timeoutTimer.Reset(shortTimeoutDuration) 365 | } else { 366 | //flush buffer to socket 367 | c.inFlightBufferLock.Lock() 368 | c.flushBufferToSocket() 369 | c.inFlightBufferLock.Unlock() 370 | timeoutTimer.Reset(longTimeoutDuration) 371 | } 372 | break 373 | case <-timeoutTimer.C: 374 | //flush buffer to socket 375 | c.inFlightBufferLock.Lock() 376 | c.flushBufferToSocket() 377 | c.inFlightBufferLock.Unlock() 378 | timeoutTimer.Reset(longTimeoutDuration) 379 | break 380 | case appleError = <-errCloseChannel: 381 | break 382 | } 383 | } 384 | 385 | // gather unsent payload objs 386 | unsentPayloads := list.New() 387 | var errorPayload *Payload 388 | // only calculate unsent payloads if messageId is not 
empty 389 | if appleError.ErrorCode != 0 && 390 | appleError.ErrorCode != CONNECTION_CLOSED_DISCONNECT && 391 | appleError.MessageID != 0 { 392 | for e := c.inFlightPayloadBuffer.Front(); e != nil; e = e.Next() { 393 | idPayloadObj := e.Value.(*idPayload) 394 | if idPayloadObj.ID == appleError.MessageID { 395 | //found error payload, keep track of it and remove from send buffer 396 | errorPayload = idPayloadObj.Payload 397 | break 398 | } 399 | unsentPayloads.PushFront(idPayloadObj.Payload) 400 | } 401 | } 402 | 403 | // clear error information if we closed the connection 404 | if appleError.ErrorCode == CONNECTION_CLOSED_DISCONNECT { 405 | appleError = nil 406 | errorPayload = nil 407 | } 408 | 409 | //connection close channel write and close 410 | go func() { 411 | c.CloseChannel <- &ConnectionClose{ 412 | Error: appleError, 413 | UnsentPayloads: unsentPayloads, 414 | ErrorPayload: errorPayload, 415 | UnsentPayloadBufferOverflow: (unsentPayloads.Len() > 0 && errorPayload == nil), 416 | } 417 | 418 | close(c.CloseChannel) 419 | }() 420 | } 421 | 422 | //Write buffer payload to tcp frame buffer and flush if tcp frame buffer full 423 | //THREADSAFE (with regard to interaction with the frameBuffer using frameBufferLock) 424 | func (c *APNSConnection) bufferPayload(idPayloadObj *idPayload) error { 425 | token, err := hex.DecodeString(idPayloadObj.Payload.Token) 426 | if err != nil { 427 | return fmt.Errorf("Error decoding token for payload %+v : %v\n", idPayloadObj.Payload, err) 428 | } 429 | 430 | if len(token) != APNS_TOKEN_SIZE { 431 | return fmt.Errorf("Invalid token length. 
Was %v bytes but should have been %v bytes\n", len(token), APNS_TOKEN_SIZE) 432 | } 433 | 434 | payloadBytes, err := idPayloadObj.Payload.Marshal(c.config.MaxPayloadSize) 435 | if err != nil { 436 | return fmt.Errorf("Error marshalling payload %+v : %v\n", idPayloadObj.Payload, err) 437 | } 438 | 439 | c.inFlightPayloadBuffer.PushFront(idPayloadObj) 440 | //check to see if we've overrun our buffer 441 | //if so, remove one from the buffer 442 | if c.inFlightPayloadBuffer.Len() > c.config.InFlightPayloadBufferSize { 443 | c.inFlightPayloadBuffer.Remove(c.inFlightPayloadBuffer.Back()) 444 | } 445 | 446 | //acquire lock to tcp buffer to do length checking, buffer writing, 447 | //and potentially flush buffer 448 | c.inFlightBufferLock.Lock() 449 | defer c.inFlightBufferLock.Unlock() 450 | 451 | //write token 452 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint8(1)) 453 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint16(APNS_TOKEN_SIZE)) 454 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, token) 455 | 456 | //write payload 457 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint8(2)) 458 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint16(len(payloadBytes))) 459 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, payloadBytes) 460 | 461 | //write id 462 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint8(3)) 463 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint16(4)) 464 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, idPayloadObj.ID) 465 | 466 | //write expire date if set 467 | if idPayloadObj.Payload.ExpirationTime != 0 { 468 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint8(4)) 469 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint16(4)) 470 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, idPayloadObj.Payload.ExpirationTime) 471 | } 472 | 473 | //write priority if set correctly 474 | if idPayloadObj.Payload.Priority == 
10 || idPayloadObj.Payload.Priority == 5 { 475 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint8(5)) 476 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, uint16(4)) 477 | binary.Write(c.inFlightItemByteBuffer, binary.BigEndian, idPayloadObj.Payload.Priority) 478 | } 479 | 480 | //check to see if we should flush inFlightFrameByteBuffer 481 | if c.inFlightFrameByteBuffer.Len()+c.inFlightItemByteBuffer.Len()+NOTIFICATION_HEADER_SIZE > TCP_FRAME_MAX { 482 | c.flushBufferToSocket() 483 | } 484 | 485 | //write header info and item info 486 | binary.Write(c.inFlightFrameByteBuffer, binary.BigEndian, uint8(2)) 487 | binary.Write(c.inFlightFrameByteBuffer, binary.BigEndian, uint32(c.inFlightItemByteBuffer.Len())) 488 | c.inFlightItemByteBuffer.WriteTo(c.inFlightFrameByteBuffer) 489 | 490 | c.inFlightItemByteBuffer.Reset() 491 | 492 | return nil 493 | } 494 | 495 | //NOT THREADSAFE (need to acquire inFlightBufferLock before calling) 496 | //Write tcp frame buffer to socket and reset when done 497 | //Close on error 498 | func (c *APNSConnection) flushBufferToSocket() { 499 | //if buffer not created, or zero length, do nothing 500 | if c.inFlightFrameByteBuffer == nil || c.inFlightFrameByteBuffer.Len() == 0 { 501 | return 502 | } 503 | 504 | bufBytes := c.inFlightFrameByteBuffer.Bytes() 505 | 506 | //write to socket 507 | _, writeErr := c.socket.Write(bufBytes) 508 | if writeErr != nil { 509 | fmt.Printf("Error while writing to socket \n%v\n", writeErr) 510 | defer c.noFlushDisconnect() 511 | } 512 | c.inFlightFrameByteBuffer.Reset() 513 | } 514 | -------------------------------------------------------------------------------- /connection_test.go: -------------------------------------------------------------------------------- 1 | package apns 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "net" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | /** 14 | * Tests related to connection write errors 15 | */ 16 | type 
MockConnErrorOnWrite struct { 17 | WrittenBytes *bytes.Buffer 18 | CloseChannel chan bool 19 | } 20 | 21 | func (conn MockConnErrorOnWrite) Read(b []byte) (n int, err error) { 22 | <-conn.CloseChannel 23 | return -1, errors.New("Socket Closed") 24 | } 25 | func (conn MockConnErrorOnWrite) Write(b []byte) (n int, err error) { 26 | conn.WrittenBytes.Write(b) 27 | return len(b), errors.New("Socket Closed") 28 | } 29 | func (conn MockConnErrorOnWrite) Close() error { 30 | defer func() { conn.CloseChannel <- true }() 31 | return nil 32 | } 33 | func (conn MockConnErrorOnWrite) LocalAddr() net.Addr { 34 | return nil 35 | } 36 | func (conn MockConnErrorOnWrite) RemoteAddr() net.Addr { 37 | return nil 38 | } 39 | func (conn MockConnErrorOnWrite) SetDeadline(t time.Time) error { 40 | return nil 41 | } 42 | func (conn MockConnErrorOnWrite) SetReadDeadline(t time.Time) error { 43 | return nil 44 | } 45 | func (conn MockConnErrorOnWrite) SetWriteDeadline(t time.Time) error { 46 | return nil 47 | } 48 | 49 | func TestConnectionShouldCloseOnWriteError(t *testing.T) { 50 | socket := MockConnErrorOnWrite{ 51 | WrittenBytes: new(bytes.Buffer), 52 | CloseChannel: make(chan bool), 53 | } 54 | 55 | apn := socketAPNSConnection(socket, 56 | &APNSConfig{ 57 | InFlightPayloadBufferSize: 10000, 58 | FramingTimeout: 10, 59 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 60 | MaxPayloadSize: 2048, 61 | }) 62 | 63 | payload := &Payload{ 64 | AlertText: "Testing", 65 | Token: "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f", 66 | } 67 | 68 | apn.SendChannel <- payload 69 | connectionClose := <-apn.CloseChannel 70 | 71 | if connectionClose.Error.ErrorCode != CONNECTION_CLOSED_UNKNOWN { 72 | fmt.Printf("Should have received error CONNECTION_CLOSED_UNKNOWN for closed socket but received %v\n", connectionClose.Error) 73 | t.FailNow() 74 | } 75 | } 76 | 77 | /** 78 | * Tests related to connection write errors 79 | */ 80 | type MockConnErrorOnWrite2 struct { 81 | WrittenBytes 
*bytes.Buffer 82 | CloseChannel chan bool 83 | } 84 | 85 | func (conn MockConnErrorOnWrite2) Read(b []byte) (n int, err error) { 86 | fmt.Printf("Read %v\n", conn.CloseChannel) 87 | return -1, errors.New("Socket Closed on Read") 88 | } 89 | func (conn MockConnErrorOnWrite2) Write(b []byte) (n int, err error) { 90 | <-conn.CloseChannel 91 | return len(b), errors.New("Socket Closed on Write") 92 | } 93 | func (conn MockConnErrorOnWrite2) Close() error { 94 | fmt.Printf("conn close called %v\n", conn.CloseChannel) 95 | defer func() { conn.CloseChannel <- true }() 96 | return nil 97 | } 98 | func (conn MockConnErrorOnWrite2) LocalAddr() net.Addr { 99 | return nil 100 | } 101 | func (conn MockConnErrorOnWrite2) RemoteAddr() net.Addr { 102 | return nil 103 | } 104 | func (conn MockConnErrorOnWrite2) SetDeadline(t time.Time) error { 105 | return nil 106 | } 107 | func (conn MockConnErrorOnWrite2) SetReadDeadline(t time.Time) error { 108 | return nil 109 | } 110 | func (conn MockConnErrorOnWrite2) SetWriteDeadline(t time.Time) error { 111 | return nil 112 | } 113 | 114 | func TestConnectionShouldCloseOnReadError(t *testing.T) { 115 | 116 | extSendChannel := make(chan *Payload) 117 | syncChan := make(chan bool) 118 | 119 | go func() { 120 | socket := MockConnErrorOnWrite2{ 121 | WrittenBytes: new(bytes.Buffer), 122 | CloseChannel: make(chan bool), 123 | } 124 | 125 | apn := socketAPNSConnection(socket, 126 | &APNSConfig{ 127 | InFlightPayloadBufferSize: 10000, 128 | FramingTimeout: 10, 129 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 130 | MaxPayloadSize: 2048, 131 | }) 132 | 133 | for { 134 | select { 135 | case p := <-extSendChannel: 136 | apn.SendChannel <- p 137 | break 138 | case connectionClose := <-apn.CloseChannel: 139 | if connectionClose.Error.ErrorCode != CONNECTION_CLOSED_UNKNOWN { 140 | fmt.Printf("Should have received error CONNECTION_CLOSED_UNKNOWN for closed socket but received %v\n", connectionClose.Error) 141 | syncChan <- true 142 | t.FailNow() 143 | } 144 | 
syncChan <- true 145 | return 146 | } 147 | } 148 | }() 149 | 150 | <-syncChan 151 | } 152 | 153 | /** 154 | * Tests related to apple returned errors 155 | */ 156 | type MockConnErrorOnToken struct { 157 | WrittenBytes *bytes.Buffer 158 | CloseChannel chan uint32 159 | } 160 | 161 | func (conn MockConnErrorOnToken) Read(b []byte) (n int, err error) { 162 | errorId := <-conn.CloseChannel 163 | b[0] = uint8(8) //command 164 | b[1] = uint8(8) //invalid token 165 | //write error id in big endian 166 | b[2] = byte(errorId >> 24) 167 | b[3] = byte(errorId >> 16) 168 | b[4] = byte(errorId >> 8) 169 | b[5] = byte(errorId) 170 | return 6, nil 171 | } 172 | func (conn MockConnErrorOnToken) Write(b []byte) (n int, err error) { 173 | conn.WrittenBytes.Write(b) 174 | defer func() { conn.CloseChannel <- 1 }() 175 | return len(b), nil 176 | } 177 | func (conn MockConnErrorOnToken) Close() error { 178 | return nil 179 | } 180 | func (conn MockConnErrorOnToken) LocalAddr() net.Addr { 181 | return nil 182 | } 183 | func (conn MockConnErrorOnToken) RemoteAddr() net.Addr { 184 | return nil 185 | } 186 | func (conn MockConnErrorOnToken) SetDeadline(t time.Time) error { 187 | return nil 188 | } 189 | func (conn MockConnErrorOnToken) SetReadDeadline(t time.Time) error { 190 | return nil 191 | } 192 | func (conn MockConnErrorOnToken) SetWriteDeadline(t time.Time) error { 193 | return nil 194 | } 195 | 196 | func TestConnectionShouldCloseOnAppleResponse(t *testing.T) { 197 | socket := MockConnErrorOnToken{ 198 | WrittenBytes: new(bytes.Buffer), 199 | CloseChannel: make(chan uint32), 200 | } 201 | 202 | apn := socketAPNSConnection(socket, 203 | &APNSConfig{ 204 | InFlightPayloadBufferSize: 10000, 205 | FramingTimeout: 10, 206 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 207 | MaxPayloadSize: 2048, 208 | }) 209 | 210 | token := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 211 | 212 | payload := &Payload{ 213 | AlertText: "Testing", 214 | Token: token, 215 | } 216 | 217 | 
apn.SendChannel <- payload 218 | 219 | connectionClose := <-apn.CloseChannel 220 | 221 | if connectionClose.Error.ErrorCode != 8 { 222 | fmt.Printf("Should have received error 8 for closed socket but received %v\n", connectionClose.Error) 223 | t.FailNow() 224 | } 225 | 226 | if connectionClose.ErrorPayload == nil || 227 | connectionClose.ErrorPayload.Token != token { 228 | fmt.Printf("Should have returned payload object but received %v\n", connectionClose.ErrorPayload) 229 | t.FailNow() 230 | } 231 | } 232 | 233 | type MockConnErrorOnToken2 struct { 234 | WrittenBytes *bytes.Buffer 235 | CloseChannel chan uint32 236 | } 237 | 238 | func (conn MockConnErrorOnToken2) Read(b []byte) (n int, err error) { 239 | errorId := <-conn.CloseChannel 240 | b[0] = uint8(8) //command 241 | b[1] = uint8(8) //invalid token 242 | //write error id in big endian 243 | b[2] = byte(errorId >> 24) 244 | b[3] = byte(errorId >> 16) 245 | b[4] = byte(errorId >> 8) 246 | b[5] = byte(errorId) 247 | return 6, nil 248 | } 249 | func (conn MockConnErrorOnToken2) Write(b []byte) (n int, err error) { 250 | conn.WrittenBytes.Write(b) 251 | defer func() { conn.CloseChannel <- 2 }() 252 | return len(b), nil 253 | } 254 | func (conn MockConnErrorOnToken2) Close() error { 255 | return nil 256 | } 257 | func (conn MockConnErrorOnToken2) LocalAddr() net.Addr { 258 | return nil 259 | } 260 | func (conn MockConnErrorOnToken2) RemoteAddr() net.Addr { 261 | return nil 262 | } 263 | func (conn MockConnErrorOnToken2) SetDeadline(t time.Time) error { 264 | return nil 265 | } 266 | func (conn MockConnErrorOnToken2) SetReadDeadline(t time.Time) error { 267 | return nil 268 | } 269 | func (conn MockConnErrorOnToken2) SetWriteDeadline(t time.Time) error { 270 | return nil 271 | } 272 | 273 | func TestConnectionShouldCloseAndReturnUnsentOnAppleResponse(t *testing.T) { 274 | 275 | extSendChannel := make(chan *Payload) 276 | syncChan := make(chan bool) 277 | 278 | token := 
"4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 279 | token2 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8e" 280 | token3 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8d" 281 | token4 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8c" 282 | 283 | go func() { 284 | 285 | payload := &Payload{ 286 | AlertText: "Testing", 287 | Token: token, 288 | } 289 | payload2 := &Payload{ 290 | AlertText: "Testing2", 291 | Token: token2, 292 | } 293 | payload3 := &Payload{ 294 | AlertText: "Testing3", 295 | Token: token3, 296 | } 297 | payload4 := &Payload{ 298 | AlertText: "Testing4", 299 | Token: token4, 300 | } 301 | 302 | extSendChannel <- payload 303 | extSendChannel <- payload2 304 | extSendChannel <- payload3 305 | extSendChannel <- payload4 306 | }() 307 | 308 | go func() { 309 | socket := MockConnErrorOnToken2{ 310 | WrittenBytes: new(bytes.Buffer), 311 | CloseChannel: make(chan uint32), 312 | } 313 | 314 | apn := socketAPNSConnection(socket, 315 | &APNSConfig{ 316 | InFlightPayloadBufferSize: 10000, 317 | FramingTimeout: 10, 318 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 319 | MaxPayloadSize: 2048, 320 | }) 321 | 322 | for { 323 | select { 324 | case p := <-extSendChannel: 325 | if p == nil { 326 | return 327 | } 328 | apn.SendChannel <- p 329 | break 330 | case connectionClose := <-apn.CloseChannel: 331 | if connectionClose.Error.ErrorCode != 8 { 332 | fmt.Printf("Should have received error 8 for closed socket but received %v\n", connectionClose.Error) 333 | syncChan <- true 334 | t.FailNow() 335 | } 336 | 337 | if connectionClose.ErrorPayload == nil || 338 | connectionClose.ErrorPayload.Token != token2 { 339 | fmt.Printf("Should have returned payload object but received %v\n", connectionClose.ErrorPayload) 340 | syncChan <- true 341 | t.FailNow() 342 | } 343 | 344 | for e := connectionClose.UnsentPayloads.Front(); e != nil; e = e.Next() { 345 | fmt.Printf("Unsent payload %v\n", 
e.Value.(*Payload)) 346 | } 347 | 348 | if connectionClose.UnsentPayloads == nil || 349 | connectionClose.UnsentPayloads.Len() != 2 { 350 | fmt.Printf("Should have returned 2 unsent payload objects but received %v len %v\n", connectionClose.UnsentPayloads, connectionClose.UnsentPayloads.Len()) 351 | syncChan <- true 352 | t.FailNow() 353 | } 354 | 355 | if connectionClose.UnsentPayloads.Front().Value.(*Payload).Token != token3 || //fail if either unsent payload is not the expected one 356 | connectionClose.UnsentPayloads.Back().Value.(*Payload).Token != token4 { 357 | fmt.Printf("Expected to receive specific unsent payloads but received %v len %v\n", connectionClose.UnsentPayloads, connectionClose.UnsentPayloads.Len()) 358 | syncChan <- true 359 | t.FailNow() 360 | } 361 | 362 | if connectionClose.UnsentPayloadBufferOverflow { 363 | fmt.Printf("Expected to NOT get buffer overflow indication but did\n") 364 | syncChan <- true 365 | t.FailNow() 366 | } 367 | 368 | syncChan <- true 369 | return 370 | } 371 | } 372 | }() 373 | 374 | <-syncChan 375 | } 376 | 377 | type MockConnErrorOnToken3 struct { 378 | WrittenBytes *bytes.Buffer 379 | CloseChannel chan uint32 380 | } 381 | 382 | func (conn MockConnErrorOnToken3) Read(b []byte) (n int, err error) { 383 | errorId := <-conn.CloseChannel 384 | b[0] = uint8(8) //command 385 | b[1] = uint8(8) //invalid token 386 | //write error id in big endian 387 | b[2] = byte(errorId >> 24) 388 | b[3] = byte(errorId >> 16) 389 | b[4] = byte(errorId >> 8) 390 | b[5] = byte(errorId) 391 | return 6, nil 392 | } 393 | func (conn MockConnErrorOnToken3) Write(b []byte) (n int, err error) { 394 | conn.WrittenBytes.Write(b) 395 | dataSize := binary.BigEndian.Uint16(b[1:3]) 396 | idStart := 3 + uint64(dataSize) - 4 - 4 - 1 397 | id := binary.BigEndian.Uint32(b[idStart : idStart+4]) 398 | //after #4 written, say an error happened on id 2 399 | if id == 4 { 400 | defer func() { conn.CloseChannel <- 2 }() 401 | } 402 | return len(b), nil 403 | } 404 | func (conn MockConnErrorOnToken3) Close() error {
405 | return nil 406 | } 407 | func (conn MockConnErrorOnToken3) LocalAddr() net.Addr { 408 | return nil 409 | } 410 | func (conn MockConnErrorOnToken3) RemoteAddr() net.Addr { 411 | return nil 412 | } 413 | func (conn MockConnErrorOnToken3) SetDeadline(t time.Time) error { 414 | return nil 415 | } 416 | func (conn MockConnErrorOnToken3) SetReadDeadline(t time.Time) error { 417 | return nil 418 | } 419 | func (conn MockConnErrorOnToken3) SetWriteDeadline(t time.Time) error { 420 | return nil 421 | } 422 | 423 | func TestConnectionShouldCloseAndReturnUnsentUpToBufferSizeOnAppleResponse(t *testing.T) { 424 | 425 | extSendChannel := make(chan *Payload) 426 | syncChan := make(chan bool) 427 | 428 | token := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 429 | token2 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8e" 430 | token3 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8d" 431 | token4 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8c" 432 | 433 | go func() { 434 | 435 | payload := &Payload{ 436 | AlertText: "Testing", 437 | Token: token, 438 | } 439 | payload2 := &Payload{ 440 | AlertText: "Testing2", 441 | Token: token2, 442 | } 443 | payload3 := &Payload{ 444 | AlertText: "Testing3", 445 | Token: token3, 446 | } 447 | payload4 := &Payload{ 448 | AlertText: "Testing4", 449 | Token: token4, 450 | } 451 | 452 | extSendChannel <- payload 453 | extSendChannel <- payload2 454 | extSendChannel <- payload3 455 | extSendChannel <- payload4 456 | }() 457 | 458 | go func() { 459 | socket := MockConnErrorOnToken2{ 460 | WrittenBytes: new(bytes.Buffer), 461 | CloseChannel: make(chan uint32), 462 | } 463 | 464 | apn := socketAPNSConnection(socket, 465 | &APNSConfig{ 466 | InFlightPayloadBufferSize: 1, 467 | FramingTimeout: 10, 468 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 469 | MaxPayloadSize: 2048, 470 | }) 471 | 472 | for { 473 | select { 474 | case p := <-extSendChannel: 475 | if p == nil { 476 
| return 477 | } 478 | apn.SendChannel <- p 479 | break 480 | case connectionClose := <-apn.CloseChannel: 481 | if connectionClose.Error.ErrorCode != 8 { 482 | fmt.Printf("Should have received error 8 for closed socket but received %v\n", connectionClose.Error) 483 | syncChan <- true 484 | t.FailNow() 485 | } 486 | 487 | if connectionClose.ErrorPayload != nil { 488 | fmt.Printf("Should have returned payload object but received %v\n", connectionClose.ErrorPayload) 489 | syncChan <- true 490 | t.FailNow() 491 | } 492 | 493 | for e := connectionClose.UnsentPayloads.Front(); e != nil; e = e.Next() { 494 | fmt.Printf("Unsent payload %v\n", e.Value.(*Payload)) 495 | } 496 | 497 | if connectionClose.UnsentPayloads == nil || 498 | connectionClose.UnsentPayloads.Len() != 1 { 499 | fmt.Printf("Should have returned 1 unsent payload objects but received %v len %v\n", connectionClose.UnsentPayloads, connectionClose.UnsentPayloads.Len()) 500 | syncChan <- true 501 | t.FailNow() 502 | } 503 | 504 | if connectionClose.UnsentPayloads.Front().Value.(*Payload).Token != token4 { 505 | fmt.Printf("Expected to receive specific unsent payloads but received %v len %v\n", connectionClose.UnsentPayloads, connectionClose.UnsentPayloads.Len()) 506 | syncChan <- true 507 | t.FailNow() 508 | } 509 | 510 | if !connectionClose.UnsentPayloadBufferOverflow { 511 | fmt.Printf("Expected to get buffer overflow indication but didn't\n") 512 | syncChan <- true 513 | t.FailNow() 514 | } 515 | 516 | syncChan <- true 517 | return 518 | } 519 | } 520 | }() 521 | 522 | <-syncChan 523 | } 524 | 525 | type MockConnErrorOnToken4 struct { 526 | WrittenBytes *bytes.Buffer 527 | DisconnectChannel chan bool 528 | } 529 | 530 | func (conn MockConnErrorOnToken4) Read(b []byte) (n int, err error) { 531 | <-conn.DisconnectChannel 532 | return 0, fmt.Errorf("Connection closed") 533 | } 534 | func (conn MockConnErrorOnToken4) Write(b []byte) (n int, err error) { 535 | conn.WrittenBytes.Write(b) 536 | return len(b), nil 
537 | } 538 | func (conn MockConnErrorOnToken4) Close() error { 539 | return nil 540 | } 541 | func (conn MockConnErrorOnToken4) LocalAddr() net.Addr { 542 | return nil 543 | } 544 | func (conn MockConnErrorOnToken4) RemoteAddr() net.Addr { 545 | return nil 546 | } 547 | func (conn MockConnErrorOnToken4) SetDeadline(t time.Time) error { 548 | return nil 549 | } 550 | func (conn MockConnErrorOnToken4) SetReadDeadline(t time.Time) error { 551 | return nil 552 | } 553 | func (conn MockConnErrorOnToken4) SetWriteDeadline(t time.Time) error { 554 | return nil 555 | } 556 | 557 | func TestConnectionShouldNotReturnErrorOnDisconnect(t *testing.T) { 558 | socket := MockConnErrorOnToken4{ 559 | WrittenBytes: new(bytes.Buffer), 560 | DisconnectChannel: make(chan bool), 561 | } 562 | 563 | apn := socketAPNSConnection(socket, 564 | &APNSConfig{ 565 | InFlightPayloadBufferSize: 10000, 566 | FramingTimeout: 10, 567 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 568 | MaxPayloadSize: 2048, 569 | }) 570 | 571 | token := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 572 | 573 | payload := &Payload{ 574 | AlertText: "Testing", 575 | Token: token, 576 | } 577 | 578 | apn.SendChannel <- payload 579 | 580 | apn.Disconnect() 581 | 582 | socket.DisconnectChannel <- true 583 | 584 | connectionClose := <-apn.CloseChannel 585 | 586 | if connectionClose.Error != nil { 587 | fmt.Printf("Should NOT have received error but received %v\n", connectionClose.Error) 588 | t.FailNow() 589 | } 590 | 591 | if connectionClose.ErrorPayload != nil { 592 | fmt.Printf("Should NOT have received error payload but received %v\n", connectionClose.ErrorPayload) 593 | t.FailNow() 594 | } 595 | } 596 | 597 | type MockConnErrorOnToken5 struct { 598 | WrittenBytes *bytes.Buffer 599 | DisconnectChannel chan bool 600 | } 601 | 602 | func (conn MockConnErrorOnToken5) Read(b []byte) (n int, err error) { 603 | <-conn.DisconnectChannel 604 | return 0, fmt.Errorf("Connection closed") 605 | } 606 | func (conn 
MockConnErrorOnToken5) Write(b []byte) (n int, err error) { 607 | conn.WrittenBytes.Write(b) 608 | return len(b), nil 609 | } 610 | func (conn MockConnErrorOnToken5) Close() error { 611 | return nil 612 | } 613 | func (conn MockConnErrorOnToken5) LocalAddr() net.Addr { 614 | return nil 615 | } 616 | func (conn MockConnErrorOnToken5) RemoteAddr() net.Addr { 617 | return nil 618 | } 619 | func (conn MockConnErrorOnToken5) SetDeadline(t time.Time) error { 620 | return nil 621 | } 622 | func (conn MockConnErrorOnToken5) SetReadDeadline(t time.Time) error { 623 | return nil 624 | } 625 | func (conn MockConnErrorOnToken5) SetWriteDeadline(t time.Time) error { 626 | return nil 627 | } 628 | 629 | func TestConnectionShouldReturnErrorOnRandomSocketClose(t *testing.T) { 630 | socket := MockConnErrorOnToken5{ 631 | WrittenBytes: new(bytes.Buffer), 632 | DisconnectChannel: make(chan bool), 633 | } 634 | 635 | apn := socketAPNSConnection(socket, 636 | &APNSConfig{ 637 | InFlightPayloadBufferSize: 10000, 638 | FramingTimeout: 10, 639 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 640 | MaxPayloadSize: 2048, 641 | }) 642 | 643 | token := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 644 | 645 | payload := &Payload{ 646 | AlertText: "Testing", 647 | Token: token, 648 | } 649 | 650 | apn.SendChannel <- payload 651 | 652 | socket.DisconnectChannel <- true 653 | 654 | connectionClose := <-apn.CloseChannel 655 | 656 | if connectionClose.Error == nil { 657 | fmt.Printf("Should have received error but received %v\n", connectionClose.Error) 658 | t.FailNow() 659 | } 660 | 661 | if connectionClose.Error.ErrorCode != CONNECTION_CLOSED_UNKNOWN { 662 | fmt.Printf("Should have received error code CONNECTION_CLOSED_UNKNOWN but received %v\n", connectionClose.Error.ErrorCode) 663 | t.FailNow() 664 | } 665 | 666 | if connectionClose.ErrorPayload != nil { 667 | fmt.Printf("Should NOT have received error payload but received %v\n", connectionClose.ErrorPayload) 668 | t.FailNow() 
669 | } 670 | } 671 | 672 | func TestShouldNotWritePayloadOnBadToken(t *testing.T) { 673 | socket := MockConnErrorOnWrite{ 674 | WrittenBytes: new(bytes.Buffer), 675 | CloseChannel: make(chan bool), 676 | } 677 | 678 | apn := socketAPNSConnection(socket, 679 | &APNSConfig{ 680 | InFlightPayloadBufferSize: 10000, 681 | FramingTimeout: 10, 682 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 683 | MaxPayloadSize: 2048, 684 | }) 685 | 686 | payload := &Payload{ 687 | AlertText: "Testing", 688 | Token: "4ec500", 689 | } 690 | 691 | apn.SendChannel <- payload 692 | 693 | apn.Disconnect() 694 | 695 | if socket.WrittenBytes.Len() != 0 { 696 | fmt.Printf("Expected no bytes to be written but bytes were written\n") 697 | t.FailNow() 698 | } 699 | } 700 | 701 | func TestShouldNotWritePayloadIfUnableToMarshall(t *testing.T) { 702 | socket := MockConnErrorOnWrite{ 703 | WrittenBytes: new(bytes.Buffer), 704 | CloseChannel: make(chan bool), 705 | } 706 | 707 | apn := socketAPNSConnection(socket, 708 | &APNSConfig{ 709 | InFlightPayloadBufferSize: 10000, 710 | FramingTimeout: 10, 711 | MaxOutboundTCPFrameSize: TCP_FRAME_MAX, 712 | MaxPayloadSize: 10, 713 | }) 714 | 715 | payload := &Payload{ 716 | AlertText: "Testing this payload", 717 | Token: "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8c", 718 | } 719 | 720 | apn.SendChannel <- payload 721 | 722 | apn.Disconnect() 723 | 724 | if socket.WrittenBytes.Len() != 0 { 725 | fmt.Printf("Expected no bytes to be written but bytes were written\n") 726 | t.FailNow() 727 | } 728 | } 729 | -------------------------------------------------------------------------------- /feedback_service.go: -------------------------------------------------------------------------------- 1 | //Package for creating a connection to Apple's APNS gateway and facilitating 2 | //sending push notifications via that gateway 3 | package apns 4 | 5 | import ( 6 | "container/list" 7 | "crypto/tls" 8 | "encoding/binary" 9 | "encoding/hex" 10 | "errors" 11 | 
"fmt" 12 | "io" 13 | "net" 14 | "time" 15 | ) 16 | 17 | //Config for creating an APNS Feedback Service Connection 18 | type APNSFeedbackServiceConfig struct { 19 | //bytes for cert.pem : required 20 | CertificateBytes []byte 21 | //bytes for key.pem : required 22 | KeyBytes []byte 23 | //apple gateway, defaults to "feedback.push.apple.com" 24 | GatewayHost string 25 | //apple gateway port, defaults to "2196" 26 | GatewayPort string 27 | //number of seconds to wait for connection before bailing, defaults to 5 seconds 28 | SocketTimeout int 29 | //number of seconds to wait for Tls handshake to complete before bailing, defaults to 5 seconds 30 | TlsTimeout int 31 | } 32 | 33 | //Feedback Response 34 | type FeedbackResponse struct { 35 | //A timestamp indicating when APNs 36 | //determined that the app no longer exists on the device. 37 | //This value represents the seconds since 12:00 midnight on January 1, 1970 UTC. 38 | Timestamp uint32 39 | //Device push token 40 | Token string 41 | } 42 | 43 | const ( 44 | //Size of feedback header frame 45 | FEEDBACK_RESPONSE_HEADER_FRAME_SIZE = 6 46 | ) 47 | 48 | //Create a new apns feedback service connection with supplied config 49 | //If invalid config an error will be returned 50 | //Also if unable to create a connection an error will be returned 51 | //Will return a list of *FeedbackResponse or error 52 | func ConnectToFeedbackService(config *APNSFeedbackServiceConfig) (*list.List, error) { 53 | errorStrs := "" 54 | 55 | if config.CertificateBytes == nil || config.KeyBytes == nil { 56 | errorStrs += "Invalid Key/Certificate bytes\n" 57 | } 58 | 59 | if errorStrs != "" { 60 | return nil, errors.New(errorStrs) 61 | } 62 | 63 | if config.GatewayPort == "" { 64 | config.GatewayPort = "2196" 65 | } 66 | if config.GatewayHost == "" { 67 | config.GatewayHost = "feedback.push.apple.com" 68 | } 69 | if config.SocketTimeout == 0 { 70 | config.SocketTimeout = 5 71 | } 72 | if config.TlsTimeout == 0 { 73 | config.TlsTimeout = 5 74 | } 
75 | 76 | x509Cert, err := tls.X509KeyPair(config.CertificateBytes, config.KeyBytes) 77 | if err != nil { 78 | //failed to validate key pair 79 | return nil, err 80 | } 81 | 82 | tlsConf := &tls.Config{ 83 | Certificates: []tls.Certificate{x509Cert}, 84 | ServerName: config.GatewayHost, 85 | } 86 | 87 | tcpSocket, err := net.DialTimeout("tcp", 88 | config.GatewayHost+":"+config.GatewayPort, 89 | time.Duration(config.SocketTimeout)*time.Second) 90 | if err != nil { 91 | //failed to connect to gateway 92 | return nil, err 93 | } 94 | 95 | tlsSocket := tls.Client(tcpSocket, tlsConf) 96 | tlsSocket.SetReadDeadline(time.Now().Add(time.Duration(config.TlsTimeout) * time.Second)) 97 | err = tlsSocket.Handshake() 98 | if err != nil { 99 | //failed to handshake with tls information 100 | return nil, err 101 | } 102 | 103 | //hooray! we're connected 104 | 105 | //let socket close itself when we're finished 106 | defer tlsSocket.Close() 107 | 108 | return readFromFeedbackService(tlsSocket) 109 | } 110 | 111 | //Read from the socket until there is no more to be read or an error occurs 112 | //Then close the socket 113 | //On error some responses may be returned so one should check that the list 114 | //returned doesn't have anything in it 115 | func readFromFeedbackService(socket net.Conn) (*list.List, error) { 116 | 117 | headerBuffer := make([]byte, FEEDBACK_RESPONSE_HEADER_FRAME_SIZE) 118 | responses := list.New() 119 | 120 | for { 121 | bytesRead, err := socket.Read(headerBuffer) 122 | if err != nil { 123 | if err == io.EOF { 124 | //we're good, just reached the end of the socket 125 | return responses, nil 126 | } else { 127 | //this is a legit error, return it 128 | return responses, err 129 | } 130 | } 131 | 132 | if bytesRead != FEEDBACK_RESPONSE_HEADER_FRAME_SIZE { 133 | //? should always be this size... 
134 | return responses, 135 | errors.New(fmt.Sprintf("Should have read %v header bytes but read %v bytes", 136 | FEEDBACK_RESPONSE_HEADER_FRAME_SIZE, bytesRead)) 137 | } 138 | 139 | tokenSize := int(binary.BigEndian.Uint16(headerBuffer[4:6])) 140 | 141 | tokenBuffer := make([]byte, tokenSize) 142 | 143 | bytesRead, err = socket.Read(tokenBuffer) 144 | if err != nil { 145 | if err == io.EOF { 146 | //we're good, just reached the end of the socket 147 | return responses, nil 148 | } else { 149 | //this is a legit error, return it 150 | return responses, err 151 | } 152 | } 153 | 154 | if bytesRead != tokenSize { 155 | //? should always be this size... 156 | return responses, 157 | errors.New(fmt.Sprintf("Should have read %v token bytes but read %v bytes", 158 | tokenSize, bytesRead)) 159 | } 160 | 161 | response := new(FeedbackResponse) 162 | response.Timestamp = binary.BigEndian.Uint32(headerBuffer[0:4]) 163 | response.Token = hex.EncodeToString(tokenBuffer) 164 | responses.PushBack(response) 165 | } 166 | 167 | return responses, nil 168 | } 169 | -------------------------------------------------------------------------------- /feedback_service_test.go: -------------------------------------------------------------------------------- 1 | package apns 2 | 3 | import ( 4 | "encoding/hex" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func writeToken(b []byte, token string) { 14 | tokenBytes, _ := hex.DecodeString(token) 15 | for i := 0; i < 32; i++ { 16 | b[i] = tokenBytes[i] 17 | } 18 | } 19 | 20 | type MockConnTokens struct { 21 | WriteHeaderState *bool 22 | CurrentResponse **FeedbackResponse 23 | ResponseChannel chan *FeedbackResponse 24 | CloseChannel chan bool 25 | } 26 | 27 | func (conn MockConnTokens) Read(b []byte) (n int, err error) { 28 | if !(*conn.WriteHeaderState) { 29 | //write token 30 | writeToken(b, (*conn.CurrentResponse).Token) 31 | 32 | (*conn.WriteHeaderState) = true 33 | 34 | return 32, nil 35 | } else { 36 | 
select { 37 | case r := <-conn.ResponseChannel: 38 | (*conn.CurrentResponse) = r 39 | 40 | //write time in big endian 41 | b[0] = uint8(r.Timestamp >> 24) 42 | b[1] = uint8(r.Timestamp >> 16) 43 | b[2] = uint8(r.Timestamp >> 8) 44 | b[3] = uint8(r.Timestamp) 45 | 46 | //write token size 47 | b[4] = uint8(0) 48 | b[5] = uint8(32) 49 | 50 | (*conn.WriteHeaderState) = false 51 | 52 | return 6, nil 53 | case <-conn.CloseChannel: 54 | return 0, io.EOF 55 | } 56 | } 57 | } 58 | func (conn MockConnTokens) Write(b []byte) (n int, err error) { 59 | return 0, nil 60 | } 61 | func (conn MockConnTokens) Close() error { 62 | return nil 63 | } 64 | func (conn MockConnTokens) LocalAddr() net.Addr { 65 | return nil 66 | } 67 | func (conn MockConnTokens) RemoteAddr() net.Addr { 68 | return nil 69 | } 70 | func (conn MockConnTokens) SetDeadline(t time.Time) error { 71 | return nil 72 | } 73 | func (conn MockConnTokens) SetReadDeadline(t time.Time) error { 74 | return nil 75 | } 76 | func (conn MockConnTokens) SetWriteDeadline(t time.Time) error { 77 | return nil 78 | } 79 | 80 | func TestFeedbackServiceReadShouldReturnTokens(t *testing.T) { 81 | var writeHeaderState = true 82 | var feedbackResponse = &FeedbackResponse{} 83 | socket := MockConnTokens{ 84 | CurrentResponse: &feedbackResponse, 85 | WriteHeaderState: &writeHeaderState, 86 | ResponseChannel: make(chan *FeedbackResponse), 87 | CloseChannel: make(chan bool), 88 | } 89 | 90 | token := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 91 | token2 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8e" 92 | token3 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8d" 93 | token4 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8c" 94 | 95 | timestamp := uint32(837431) 96 | 97 | go func() { 98 | r1 := &FeedbackResponse{ 99 | Timestamp: timestamp, 100 | Token: token, 101 | } 102 | r2 := &FeedbackResponse{ 103 | Timestamp: uint32(837432), 104 | Token: token2, 105 | 
} 106 | r3 := &FeedbackResponse{ 107 | Timestamp: uint32(837433), 108 | Token: token3, 109 | } 110 | r4 := &FeedbackResponse{ 111 | Timestamp: uint32(837434), 112 | Token: token4, 113 | } 114 | 115 | socket.ResponseChannel <- r1 116 | socket.ResponseChannel <- r2 117 | socket.ResponseChannel <- r3 118 | socket.ResponseChannel <- r4 119 | socket.CloseChannel <- true 120 | }() 121 | 122 | responses, err := readFromFeedbackService(socket) 123 | if err != nil { 124 | fmt.Printf("Shouldn't have received an error but got %v\n", err) 125 | t.FailNow() 126 | } 127 | 128 | if responses.Len() != 4 { 129 | fmt.Printf("Should've received 4 tokens\n") 130 | t.FailNow() 131 | } 132 | 133 | var res = responses.Front() 134 | recToken := res.Value.(*FeedbackResponse).Token 135 | recTime := res.Value.(*FeedbackResponse).Timestamp 136 | res = res.Next() 137 | recToken2 := res.Value.(*FeedbackResponse).Token 138 | res = res.Next() 139 | recToken3 := res.Value.(*FeedbackResponse).Token 140 | res = res.Next() 141 | recToken4 := res.Value.(*FeedbackResponse).Token 142 | 143 | if recToken != token { 144 | fmt.Printf("Should've received token %v but got %v\n", token, recToken) 145 | t.FailNow() 146 | } 147 | 148 | if recToken2 != token2 { 149 | fmt.Printf("Should've received token2 %v but got %v\n", token2, recToken2) 150 | t.FailNow() 151 | } 152 | 153 | if recToken3 != token3 { 154 | fmt.Printf("Should've received token3 %v but got %v\n", token3, recToken3) 155 | t.FailNow() 156 | } 157 | 158 | if recToken4 != token4 { 159 | fmt.Printf("Should've received token4 %v but got %v\n", token4, recToken4) 160 | t.FailNow() 161 | } 162 | 163 | if recTime != timestamp { 164 | fmt.Printf("Should've received timestamp %v but got %v\n", timestamp, recTime) 165 | t.FailNow() 166 | } 167 | } 168 | 169 | type MockConnTokensAndErr struct { 170 | WriteHeaderState *bool 171 | CurrentResponse **FeedbackResponse 172 | ResponseChannel chan *FeedbackResponse 173 | CloseChannel chan bool 174 | } 175 | 176 | 
func (conn MockConnTokensAndErr) Read(b []byte) (n int, err error) { 177 | if !(*conn.WriteHeaderState) { 178 | //write token 179 | writeToken(b, (*conn.CurrentResponse).Token) 180 | 181 | (*conn.WriteHeaderState) = true 182 | 183 | return 32, nil 184 | } else { 185 | select { 186 | case r := <-conn.ResponseChannel: 187 | (*conn.CurrentResponse) = r 188 | 189 | //write time in big endian 190 | b[0] = uint8(r.Timestamp >> 24) 191 | b[1] = uint8(r.Timestamp >> 16) 192 | b[2] = uint8(r.Timestamp >> 8) 193 | b[3] = uint8(r.Timestamp) 194 | 195 | //write token size 196 | b[4] = uint8(0) 197 | b[5] = uint8(32) 198 | 199 | (*conn.WriteHeaderState) = false 200 | 201 | return 6, nil 202 | case <-conn.CloseChannel: 203 | return 0, errors.New("Some random error") 204 | } 205 | } 206 | } 207 | func (conn MockConnTokensAndErr) Write(b []byte) (n int, err error) { 208 | return 0, nil 209 | } 210 | func (conn MockConnTokensAndErr) Close() error { 211 | return nil 212 | } 213 | func (conn MockConnTokensAndErr) LocalAddr() net.Addr { 214 | return nil 215 | } 216 | func (conn MockConnTokensAndErr) RemoteAddr() net.Addr { 217 | return nil 218 | } 219 | func (conn MockConnTokensAndErr) SetDeadline(t time.Time) error { 220 | return nil 221 | } 222 | func (conn MockConnTokensAndErr) SetReadDeadline(t time.Time) error { 223 | return nil 224 | } 225 | func (conn MockConnTokensAndErr) SetWriteDeadline(t time.Time) error { 226 | return nil 227 | } 228 | 229 | func TestFeedbackServiceReadShouldReturnTokensAndError(t *testing.T) { 230 | var writeHeaderState = true 231 | var feedbackResponse = &FeedbackResponse{} 232 | socket := MockConnTokensAndErr{ 233 | CurrentResponse: &feedbackResponse, 234 | WriteHeaderState: &writeHeaderState, 235 | ResponseChannel: make(chan *FeedbackResponse), 236 | CloseChannel: make(chan bool), 237 | } 238 | 239 | token := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8f" 240 | token2 := 
"4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8e" 241 | token3 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8d" 242 | token4 := "4ec500020d8350072d2417ba566feda10b2b266558371a65ba67fede21393c8c" 243 | 244 | go func() { 245 | r1 := &FeedbackResponse{ 246 | Timestamp: uint32(837431), 247 | Token: token, 248 | } 249 | r2 := &FeedbackResponse{ 250 | Timestamp: uint32(837432), 251 | Token: token2, 252 | } 253 | r3 := &FeedbackResponse{ 254 | Timestamp: uint32(837433), 255 | Token: token3, 256 | } 257 | r4 := &FeedbackResponse{ 258 | Timestamp: uint32(837434), 259 | Token: token4, 260 | } 261 | 262 | socket.ResponseChannel <- r1 263 | socket.ResponseChannel <- r2 264 | socket.ResponseChannel <- r3 265 | socket.ResponseChannel <- r4 266 | socket.CloseChannel <- true 267 | }() 268 | 269 | responses, err := readFromFeedbackService(socket) 270 | if err == nil { 271 | fmt.Printf("Should have received an error\n") 272 | t.FailNow() 273 | } 274 | 275 | if responses.Len() != 4 { 276 | fmt.Printf("Should've received 4 tokens\n") 277 | t.FailNow() 278 | } 279 | 280 | var res = responses.Front() 281 | recToken := res.Value.(*FeedbackResponse).Token 282 | res = res.Next() 283 | recToken2 := res.Value.(*FeedbackResponse).Token 284 | res = res.Next() 285 | recToken3 := res.Value.(*FeedbackResponse).Token 286 | res = res.Next() 287 | recToken4 := res.Value.(*FeedbackResponse).Token 288 | 289 | if recToken != token { 290 | fmt.Printf("Should've received token %v but got %v\n", token, recToken) 291 | t.FailNow() 292 | } 293 | 294 | if recToken2 != token2 { 295 | fmt.Printf("Should've received token2 %v but got %v\n", token2, recToken2) 296 | t.FailNow() 297 | } 298 | 299 | if recToken3 != token3 { 300 | fmt.Printf("Should've received token3 %v but got %v\n", token3, recToken3) 301 | t.FailNow() 302 | } 303 | 304 | if recToken4 != token4 { 305 | fmt.Printf("Should've received token4 %v but got %v\n", token4, recToken4) 306 | t.FailNow() 307 | } 308 | 
} 309 | -------------------------------------------------------------------------------- /payload.go: -------------------------------------------------------------------------------- 1 | package apns 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | //Object describing a push notification payload 10 | type Payload struct { 11 | // Basic alert structure 12 | AlertText string 13 | Badge BadgeNumber 14 | Sound string 15 | ContentAvailable int 16 | Category string 17 | 18 | // If this is an enhanced message, use 19 | // an APSAlertBody instead of .Alert 20 | AlertBody APSAlertBody 21 | 22 | // Any custom fields to be added to the apns payload 23 | // These exist outside of the `aps` namespace 24 | CustomFields map[string]interface{} 25 | 26 | // Payload server fields 27 | // UNIX time in seconds when the payload is invalid 28 | ExpirationTime uint32 29 | // Must be either 5 or 10, if not one of these two values will default to 5 30 | Priority uint8 31 | 32 | // Device push token, should contain no spaces 33 | Token string 34 | 35 | // Any extra data to be associated with this payload, 36 | // Will not be sent to apple but will be held onto for error cases 37 | ExtraData interface{} 38 | } 39 | 40 | type APSAlertBody struct { 41 | // Text of the alert 42 | Body string `json:"body,omitempty"` 43 | 44 | // Other alert options 45 | ActionLocKey string `json:"action-loc-key,omitempty"` 46 | LocKey string `json:"loc-key,omitempty"` 47 | LocArgs []string `json:"loc-args,omitempty"` 48 | LaunchImage string `json:"launch-image,omitempty"` 49 | 50 | // New Title fields and localizations. 
>= iOS 8.2 51 | Title string `json:"title,omitempty"` 52 | TitleLocKey string `json:"title-loc-key,omitempty"` 53 | TitleLocArgs []string `json:"title-loc-args,omitempty"` 54 | } 55 | 56 | type alertBodyAps struct { 57 | Alert APSAlertBody 58 | Badge BadgeNumber 59 | Sound string 60 | Category string 61 | ContentAvailable int 62 | } 63 | 64 | type simpleAps struct { 65 | Alert string 66 | Badge BadgeNumber 67 | Sound string 68 | Category string 69 | ContentAvailable int 70 | } 71 | 72 | // Convert a Payload into a json object and then converted to a byte array 73 | // If the number of converted bytes is greater than the maxPayloadSize 74 | // an attempt will be made to truncate the AlertText 75 | // If this cannot be done, then an error will be returned 76 | func (p *Payload) Marshal(maxPayloadSize int) ([]byte, error) { 77 | if p.isSimple() { 78 | return p.marshalSimplePayload(maxPayloadSize) 79 | } else { 80 | return p.marshalAlertBodyPayload(maxPayloadSize) 81 | } 82 | } 83 | 84 | //Whether or not to use simple aps format or not 85 | func (p *Payload) isSimple() bool { 86 | return p.AlertBody.Body == "" 87 | } 88 | 89 | //Helper method to generate a json compatible map with aps key + custom fields 90 | //will return error if custom field named aps supplied 91 | func constructFullPayload(aps interface{}, customFields map[string]interface{}) (map[string]interface{}, error) { 92 | var fullPayload = make(map[string]interface{}) 93 | fullPayload["aps"] = aps 94 | for key, value := range customFields { 95 | if key == "aps" { 96 | return nil, errors.New("Cannot have a custom field named aps") 97 | } 98 | fullPayload[key] = value 99 | } 100 | return fullPayload, nil 101 | } 102 | 103 | //Handle simple payload case with just text alert 104 | //Handle truncating of alert text if too long for maxPayloadSize 105 | func (p *Payload) marshalSimplePayload(maxPayloadSize int) ([]byte, error) { 106 | var jsonStr []byte 107 | 108 | //use simple payload 109 | aps := simpleAps{ 110 
| Alert: p.AlertText, 111 | Badge: p.Badge, 112 | Sound: p.Sound, 113 | Category: p.Category, 114 | ContentAvailable: p.ContentAvailable, 115 | } 116 | 117 | fullPayload, err := constructFullPayload(aps, p.CustomFields) 118 | if err != nil { 119 | return nil, err 120 | } 121 | 122 | jsonStr, err = json.Marshal(fullPayload) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | payloadLen := len(jsonStr) 128 | 129 | if payloadLen > maxPayloadSize { 130 | clipSize := payloadLen - (maxPayloadSize) + 3 //need extra characters for ellipse 131 | if clipSize > len(p.AlertText) { 132 | return nil, errors.New(fmt.Sprintf("Payload was too long to successfully marshall to less than %v", maxPayloadSize)) 133 | } 134 | aps.Alert = aps.Alert[:len(aps.Alert)-clipSize] + "..." 135 | fullPayload["aps"] = aps 136 | if err != nil { 137 | return nil, err 138 | } 139 | 140 | jsonStr, err = json.Marshal(fullPayload) 141 | if err != nil { 142 | return nil, err 143 | } 144 | } 145 | 146 | return jsonStr, nil 147 | } 148 | 149 | //Handle complet payload case with alert object 150 | //Handle truncating of alert text if too long for maxPayloadSize 151 | func (p *Payload) marshalAlertBodyPayload(maxPayloadSize int) ([]byte, error) { 152 | var jsonStr []byte 153 | 154 | // Use APSAlertBody payload 155 | aps := alertBodyAps{ 156 | Alert: p.AlertBody, 157 | Badge: p.Badge, 158 | Sound: p.Sound, 159 | Category: p.Category, 160 | ContentAvailable: p.ContentAvailable, 161 | } 162 | 163 | fullPayload, err := constructFullPayload(aps, p.CustomFields) 164 | if err != nil { 165 | return nil, err 166 | } 167 | 168 | jsonStr, err = json.Marshal(fullPayload) 169 | if err != nil { 170 | return nil, err 171 | } 172 | 173 | payloadLen := len(jsonStr) 174 | 175 | if payloadLen > maxPayloadSize { 176 | clipSize := payloadLen - (maxPayloadSize) + 3 //need extra characters for ellipse 177 | if clipSize > len(p.AlertBody.Body) { 178 | return nil, errors.New(fmt.Sprintf("Payload was too long to 
successfully marshall %v or less bytes", maxPayloadSize)) 179 | } 180 | aps.Alert.Body = aps.Alert.Body[:len(aps.Alert.Body)-clipSize] + "..." 181 | fullPayload["aps"] = aps 182 | if err != nil { 183 | return nil, err 184 | } 185 | 186 | jsonStr, err = json.Marshal(fullPayload) 187 | if err != nil { 188 | return nil, err 189 | } 190 | } 191 | 192 | return jsonStr, nil 193 | } 194 | 195 | func (s simpleAps) MarshalJSON() ([]byte, error) { 196 | toMarshal := make(map[string]interface{}) 197 | 198 | if s.Alert != "" { 199 | toMarshal["alert"] = s.Alert 200 | } 201 | if s.Badge.IsSet() { 202 | toMarshal["badge"] = s.Badge 203 | } 204 | if s.Sound != "" { 205 | toMarshal["sound"] = s.Sound 206 | } 207 | if s.Category != "" { 208 | toMarshal["category"] = s.Category 209 | } 210 | if s.ContentAvailable != 0 { 211 | toMarshal["content-available"] = s.ContentAvailable 212 | } 213 | 214 | return json.Marshal(toMarshal) 215 | } 216 | 217 | func (a alertBodyAps) MarshalJSON() ([]byte, error) { 218 | toMarshal := make(map[string]interface{}) 219 | toMarshal["alert"] = a.Alert 220 | 221 | if a.Badge.IsSet() { 222 | toMarshal["badge"] = a.Badge 223 | } 224 | if a.Sound != "" { 225 | toMarshal["sound"] = a.Sound 226 | } 227 | if a.Category != "" { 228 | toMarshal["category"] = a.Category 229 | } 230 | if a.ContentAvailable != 0 { 231 | toMarshal["content-available"] = a.ContentAvailable 232 | } 233 | 234 | return json.Marshal(toMarshal) 235 | } 236 | -------------------------------------------------------------------------------- /payload_test.go: -------------------------------------------------------------------------------- 1 | package apns 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestSimpleMarshal(t *testing.T) { 9 | p := Payload{ 10 | AlertText: "Testing this payload", 11 | Badge: NewBadgeNumber(2), 12 | ContentAvailable: 1, 13 | Sound: "test.aiff", 14 | Category: "TEST_CATEGORY", 15 | } 16 | 17 | payloadSize := 256 18 | 19 | json, err := 
p.Marshal(payloadSize) 20 | if err != nil { 21 | t.Error(err) 22 | } 23 | 24 | if len(json) > payloadSize { 25 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 26 | } 27 | 28 | expectedJson := "{\"aps\":{\"alert\":\"Testing this payload\",\"badge\":2,\"category\":\"TEST_CATEGORY\",\"content-available\":1,\"sound\":\"test.aiff\"}}" 29 | if string(json) != expectedJson { 30 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 31 | } 32 | } 33 | 34 | func TestBadge0ShouldOmitBadge(t *testing.T) { 35 | p := Payload{ 36 | AlertText: "Testing this payload", 37 | ContentAvailable: 1, 38 | Sound: "test.aiff", 39 | Category: "TEST_CATEGORY", 40 | } 41 | 42 | payloadSize := 256 43 | 44 | json, err := p.Marshal(payloadSize) 45 | if err != nil { 46 | t.Error(err) 47 | } 48 | 49 | if len(json) > payloadSize { 50 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 51 | } 52 | 53 | expectedJson := "{\"aps\":{\"alert\":\"Testing this payload\",\"category\":\"TEST_CATEGORY\",\"content-available\":1,\"sound\":\"test.aiff\"}}" 54 | if string(json) != expectedJson { 55 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 56 | } 57 | } 58 | 59 | func TestSimpleMarshalWithCustomFields(t *testing.T) { 60 | customFields := map[string]interface{}{ 61 | "num": 55, 62 | "str": "string", 63 | "arr": []interface{}{"a", 2}, 64 | "obj": map[string]string{ 65 | "obja": "a", 66 | "objb": "b", 67 | }, 68 | } 69 | 70 | p := Payload{ 71 | AlertText: "Testing this payload", 72 | Badge: NewBadgeNumber(2), 73 | ContentAvailable: 1, 74 | Sound: "test.aiff", 75 | CustomFields: customFields, 76 | } 77 | 78 | payloadSize := 256 79 | 80 | json, err := p.Marshal(payloadSize) 81 | if err != nil { 82 | t.Error(err) 83 | } 84 | 85 | if len(json) > payloadSize { 86 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 87 | } 88 | 89 
| expectedJson := "{\"aps\":{\"alert\":\"Testing this payload\",\"badge\":2,\"content-available\":1,\"sound\":\"test.aiff\"},\"arr\":[\"a\",2],\"num\":55,\"obj\":{\"obja\":\"a\",\"objb\":\"b\"},\"str\":\"string\"}" 90 | if string(json) != expectedJson { 91 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, json)) 92 | } 93 | } 94 | 95 | func TestSimpleMarshalTruncate(t *testing.T) { 96 | p := Payload{ 97 | AlertText: "Testing this payload with a really long message that should " + 98 | "cause the payload to be truncated yay and stuff blah blah blah blah blah blah " + 99 | "and some more text to really make this much bigger and stuff", 100 | Badge: NewBadgeNumber(2), 101 | ContentAvailable: 1, 102 | Sound: "test.aiff", 103 | } 104 | 105 | payloadSize := 256 106 | 107 | json, err := p.Marshal(payloadSize) 108 | if err != nil { 109 | t.Error(err) 110 | } 111 | 112 | if len(json) > payloadSize { 113 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 114 | } 115 | 116 | expectedJson := "{\"aps\":{\"alert\":\"Testing this payload with a really long message that should cause the payload to be truncated yay and stuff blah blah blah blah blah blah and some more text to really make this much...\",\"badge\":2,\"content-available\":1,\"sound\":\"test.aiff\"}}" 117 | if string(json) != expectedJson { 118 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, json)) 119 | } 120 | } 121 | 122 | func TestSimpleMarshalTruncateWithCustomFields(t *testing.T) { 123 | customFields := map[string]interface{}{ 124 | "num": 55, 125 | "str": "string", 126 | "arr": []interface{}{"a", 2}, 127 | "obj": map[string]string{ 128 | "obja": "a", 129 | "objb": "b", 130 | }, 131 | } 132 | 133 | p := Payload{ 134 | AlertText: "Testing this payload with a bunch of text that should get truncated " + 135 | "so truncate this already please yes thank you blah blah blah blah blah blah " + 136 | "plus some more text", 137 | Badge: 
NewBadgeNumber(2), 138 | ContentAvailable: 1, 139 | Sound: "test.aiff", 140 | CustomFields: customFields, 141 | } 142 | 143 | payloadSize := 256 144 | 145 | json, err := p.Marshal(payloadSize) 146 | if err != nil { 147 | t.Error(err) 148 | } 149 | 150 | if len(json) > payloadSize { 151 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 152 | } 153 | 154 | expectedJson := "{\"aps\":{\"alert\":\"Testing this payload with a bunch of text that should get truncated " + 155 | "so truncate this already please yes thank you...\",\"badge\":2,\"content-available\":1,\"sound\":\"test.aiff\"}," + 156 | "\"arr\":[\"a\",2],\"num\":55,\"obj\":{\"obja\":\"a\",\"objb\":\"b\"},\"str\":\"string\"}" 157 | if string(json) != expectedJson { 158 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, json)) 159 | } 160 | } 161 | 162 | func TestSimpleMarshalThrowErrorIfPayloadTooBigWithCustomFields(t *testing.T) { 163 | //lots of custom fields to force failure 164 | customFields := map[string]interface{}{ 165 | "num": 55, 166 | "str": "string", 167 | "arr": []interface{}{"a", 2}, 168 | "obj": map[string]string{ 169 | "obja": "a", 170 | "objb": "b", 171 | }, 172 | "obj2": map[string]string{ 173 | "obja": "a", 174 | "objb": "b", 175 | }, 176 | "obj3": map[string]string{ 177 | "obja": "a", 178 | "objb": "b", 179 | }, 180 | "obj4": map[string]string{ 181 | "obja": "a", 182 | "objb": "b", 183 | }, 184 | "obj5": map[string]string{ 185 | "obja": "a", 186 | "objb": "b", 187 | }, 188 | } 189 | 190 | p := Payload{ 191 | AlertText: "Testing this payload", 192 | Badge: NewBadgeNumber(2), 193 | ContentAvailable: 1, 194 | Sound: "test.aiff", 195 | CustomFields: customFields, 196 | } 197 | 198 | payloadSize := 256 199 | 200 | _, err := p.Marshal(payloadSize) 201 | if err == nil { 202 | t.Error("Should have thrown marshaling error") 203 | } 204 | } 205 | 206 | func TestAlertBodyMarshal(t *testing.T) { 207 | p := Payload{ 208 | Badge: NewBadgeNumber(2), 
209 | ContentAvailable: 1, 210 | Category: "TEST_CATEGORY", 211 | Sound: "test.aiff", 212 | AlertBody: APSAlertBody{ 213 | Body: "Testing this payload", 214 | ActionLocKey: "act-loc-key", 215 | LocKey: "loc-key", 216 | LocArgs: []string{"arg1", "arg2"}, 217 | LaunchImage: "launch.png", 218 | }, 219 | } 220 | 221 | payloadSize := 256 222 | 223 | json, err := p.Marshal(payloadSize) 224 | if err != nil { 225 | t.Error(err) 226 | } 227 | 228 | if len(json) > payloadSize { 229 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 230 | } 231 | 232 | expectedJson := "{\"aps\":{\"alert\":{\"body\":\"Testing this payload\",\"action-loc-key\":\"act-loc-key\",\"loc-key\":\"loc-key\",\"loc-args\":[\"arg1\",\"arg2\"],\"launch-image\":\"launch.png\"},\"badge\":2,\"category\":\"TEST_CATEGORY\",\"content-available\":1,\"sound\":\"test.aiff\"}}" 233 | if string(json) != expectedJson { 234 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 235 | } 236 | } 237 | 238 | func TestAlertBodyMarshalWithCustomFields(t *testing.T) { 239 | customFields := map[string]interface{}{ 240 | "num": 55, 241 | "str": "string", 242 | "arr": []interface{}{"a", 2}, 243 | "obj": map[string]string{ 244 | "obja": "a", 245 | "objb": "b", 246 | }, 247 | } 248 | 249 | p := Payload{ 250 | Badge: NewBadgeNumber(2), 251 | ContentAvailable: 1, 252 | Sound: "test.aiff", 253 | CustomFields: customFields, 254 | AlertBody: APSAlertBody{ 255 | Body: "Testing this payload", 256 | ActionLocKey: "act-loc-key", 257 | LocKey: "loc-key", 258 | LaunchImage: "launch.png", 259 | }, 260 | } 261 | 262 | payloadSize := 256 263 | 264 | json, err := p.Marshal(payloadSize) 265 | if err != nil { 266 | t.Error(err) 267 | } 268 | 269 | if len(json) > payloadSize { 270 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 271 | } 272 | 273 | expectedJson := "{\"aps\":{\"alert\":{\"body\":\"Testing this 
payload\",\"action-loc-key\":\"act-loc-key\",\"loc-key\":\"loc-key\"," + 274 | "\"launch-image\":\"launch.png\"}," + 275 | "\"badge\":2,\"content-available\":1,\"sound\":\"test.aiff\"},\"arr\":[\"a\",2]," + 276 | "\"num\":55,\"obj\":{\"obja\":\"a\",\"objb\":\"b\"},\"str\":\"string\"}" 277 | 278 | if string(json) != expectedJson { 279 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 280 | } 281 | } 282 | 283 | func TestAlertBodyMarshalTruncate(t *testing.T) { 284 | p := Payload{ 285 | Badge: NewBadgeNumber(2), 286 | ContentAvailable: 1, 287 | Sound: "test.aiff", 288 | AlertBody: APSAlertBody{ 289 | Body: "Testing this payload with a really long message that should " + 290 | "cause the payload to be truncated yay and stuff blah blah blah blah blah blah " + 291 | "and some more text to really make this much bigger and stuff", 292 | LaunchImage: "launch.png", 293 | }, 294 | } 295 | 296 | payloadSize := 256 297 | 298 | json, err := p.Marshal(payloadSize) 299 | if err != nil { 300 | t.Error(err) 301 | } 302 | 303 | if len(json) > payloadSize { 304 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 305 | } 306 | 307 | expectedJson := "{\"aps\":{\"alert\":{\"body\":\"Testing this payload with a really long message that should cause the payload to be truncated yay and stuff blah blah blah blah blah blah and so...\",\"launch-image\":\"launch.png\"},\"badge\":2,\"content-available\":1,\"sound\":\"test.aiff\"}}" 308 | if string(json) != expectedJson { 309 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 310 | } 311 | } 312 | 313 | func TestAlertBodyMarshalTruncateWithCustomFields(t *testing.T) { 314 | customFields := map[string]interface{}{ 315 | "num": 55, 316 | "str": "string", 317 | "arr": []interface{}{"a", 2}, 318 | "arr2": []interface{}{"a", 2}, 319 | } 320 | 321 | p := Payload{ 322 | Badge: NewBadgeNumber(2), 323 | ContentAvailable: 1, 324 | Sound: "test.aiff", 325 | 
CustomFields: customFields, 326 | AlertBody: APSAlertBody{ 327 | Body: "Testing this payload with a bunch of text that should get truncated " + 328 | "so truncate this already please yes thank you blah blah blah blah blah blah " + 329 | "plus some more text", 330 | ActionLocKey: "act-loc-key", 331 | LocKey: "loc-key", 332 | LocArgs: []string{"arg1", "arg2"}, 333 | LaunchImage: "launch.png", 334 | }, 335 | } 336 | 337 | payloadSize := 256 338 | 339 | json, err := p.Marshal(payloadSize) 340 | if err != nil { 341 | t.Error(err) 342 | } 343 | 344 | if len(json) > payloadSize { 345 | t.Error(fmt.Sprintf("Expected payload to be less than %v but was %v", payloadSize, len(json))) 346 | } 347 | 348 | expectedJson := "{\"aps\":{\"alert\":{\"body\":\"Testing this ...\",\"action-loc-key\":\"act-loc-key\",\"loc-key\":\"loc-key\"," + 349 | "\"loc-args\":[\"arg1\",\"arg2\"],\"launch-image\":\"launch.png\"},\"badge\":2,\"content-available\":1,\"sound\":\"test.aiff\"}," + 350 | "\"arr\":[\"a\",2],\"arr2\":[\"a\",2],\"num\":55,\"str\":\"string\"}" 351 | if string(json) != expectedJson { 352 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 353 | } 354 | } 355 | 356 | func TestAlertBodyMarshalThrowErrorIfPayloadTooBigWithCustomFields(t *testing.T) { 357 | //lots of custom fields to force failure 358 | customFields := map[string]interface{}{ 359 | "num": 55, 360 | "str": "string", 361 | "arr": []interface{}{"a", 2}, 362 | "obj": map[string]string{ 363 | "obja": "a", 364 | "objb": "b", 365 | }, 366 | "obj2": map[string]string{ 367 | "obja": "a", 368 | "objb": "b", 369 | }, 370 | "obj3": map[string]string{ 371 | "obja": "a", 372 | "objb": "b", 373 | }, 374 | "obj4": map[string]string{ 375 | "obja": "a", 376 | "objb": "b", 377 | }, 378 | "obj5": map[string]string{ 379 | "obja": "a", 380 | "objb": "b", 381 | }, 382 | } 383 | 384 | p := Payload{ 385 | Badge: NewBadgeNumber(2), 386 | ContentAvailable: 1, 387 | Sound: "test.aiff", 388 | CustomFields: customFields, 
389 | AlertBody: APSAlertBody{ 390 | Body: "Testing this payload", 391 | LaunchImage: "launch.png", 392 | }, 393 | } 394 | 395 | payloadSize := 256 396 | 397 | _, err := p.Marshal(payloadSize) 398 | if err == nil { 399 | t.Error("Should have thrown marshaling error") 400 | } 401 | } 402 | 403 | func TestBadgeOnlyMarshal(t *testing.T) { 404 | p := Payload{ 405 | Badge: NewBadgeNumber(2), 406 | } 407 | 408 | payloadSize := 256 409 | 410 | json, err := p.Marshal(payloadSize) 411 | if err != nil { 412 | t.Error(err) 413 | } 414 | 415 | expectedJson := "{\"aps\":{\"badge\":2}}" 416 | if string(json) != expectedJson { 417 | t.Error(fmt.Sprintf("Expected %v but got %v", expectedJson, string(json))) 418 | } 419 | } 420 | 421 | func BenchmarkSimpleMarshalTruncate256WithCustomFields(b *testing.B) { 422 | customFields := map[string]interface{}{ 423 | "num": 55, 424 | "str": "string", 425 | "arr": []interface{}{"a", 2}, 426 | "obj": map[string]string{ 427 | "obja": "a", 428 | "objb": "b", 429 | }, 430 | } 431 | 432 | p := Payload{ 433 | AlertText: "Testing this payload with a bunch of text that should get truncated " + 434 | "so truncate this already please yes thank you blah blah blah blah blah blah " + 435 | "plus some more text", 436 | Badge: NewBadgeNumber(2), 437 | ContentAvailable: 1, 438 | Sound: "test.aiff", 439 | CustomFields: customFields, 440 | } 441 | 442 | b.ResetTimer() 443 | for i := 0; i < b.N; i++ { 444 | p.Marshal(256) 445 | } 446 | } 447 | 448 | func BenchmarkSimpleMarshalTruncate1024WithCustomFields(b *testing.B) { 449 | customFields := map[string]interface{}{ 450 | "num": 55, 451 | "str": "string", 452 | "arr": []interface{}{"a", 2}, 453 | "obj": map[string]string{ 454 | "obja": "a", 455 | "objb": "b", 456 | }, 457 | } 458 | 459 | p := Payload{ 460 | AlertText: "Testing this payload with a bunch of text that should get truncated " + 461 | "so truncate this already please yes thank you blah blah blah blah blah blah " + 462 | "plus some more text", 463 | 
Badge: NewBadgeNumber(2), 464 | ContentAvailable: 1, 465 | Sound: "test.aiff", 466 | CustomFields: customFields, 467 | } 468 | 469 | b.ResetTimer() 470 | for i := 0; i < b.N; i++ { 471 | p.Marshal(1024) 472 | } 473 | } 474 | 475 | func BenchmarkAlertBodyMarshalTruncate256WithCustomFields(b *testing.B) { 476 | customFields := map[string]interface{}{ 477 | "num": 55, 478 | "str": "string", 479 | "arr": []interface{}{"a", 2}, 480 | "obj": map[string]string{ 481 | "obja": "a", 482 | "objb": "b", 483 | }, 484 | } 485 | 486 | p := Payload{ 487 | Badge: NewBadgeNumber(2), 488 | ContentAvailable: 1, 489 | Sound: "test.aiff", 490 | CustomFields: customFields, 491 | AlertBody: APSAlertBody{ 492 | Body: "Testing this payload with a bunch of text that should get truncated " + 493 | "so truncate this already please yes thank you blah blah blah blah blah blah " + 494 | "plus some more text", 495 | LaunchImage: "launch.png", 496 | }, 497 | } 498 | 499 | b.ResetTimer() 500 | for i := 0; i < b.N; i++ { 501 | p.Marshal(256) 502 | } 503 | } 504 | 505 | func BenchmarkAlertBodyMarshalTruncate1024WithCustomFields(b *testing.B) { 506 | customFields := map[string]interface{}{ 507 | "num": 55, 508 | "str": "string", 509 | "arr": []interface{}{"a", 2}, 510 | "obj": map[string]string{ 511 | "obja": "a", 512 | "objb": "b", 513 | }, 514 | } 515 | 516 | p := Payload{ 517 | Badge: NewBadgeNumber(2), 518 | ContentAvailable: 1, 519 | Sound: "test.aiff", 520 | CustomFields: customFields, 521 | AlertBody: APSAlertBody{ 522 | Body: "Testing this payload with a bunch of text that should get truncated " + 523 | "so truncate this already please yes thank you blah blah blah blah blah blah " + 524 | "plus some more text", 525 | ActionLocKey: "act-loc-key", 526 | LocKey: "loc-key", 527 | LocArgs: []string{"arg1", "arg2"}, 528 | LaunchImage: "launch.png", 529 | }, 530 | } 531 | 532 | b.ResetTimer() 533 | for i := 0; i < b.N; i++ { 534 | p.Marshal(1024) 535 | } 536 | } 537 | 
--------------------------------------------------------------------------------