├── testutil ├── badkey ├── cmix.rip.crt └── cmix.rip.key ├── .gitignore ├── main.go ├── Makefile ├── notifications ├── providers │ ├── provider.go │ ├── fcm.go │ └── apns.go ├── params.go ├── receive_test.go ├── receive.go ├── impl_test.go ├── ndf_test.go ├── send_test.go ├── ephemeral_test.go ├── send.go ├── ndf.go ├── legacyRegistration.go ├── ephemeral.go ├── legacyRegistration_test.go ├── impl.go ├── registration.go └── registration_test.go ├── constants └── constants.go ├── LICENSE ├── README.md ├── cmd ├── version.go ├── version_vars.go └── root.go ├── io └── poll.go ├── storage ├── buffer_test.go ├── buffer.go ├── database_test.go ├── database.go ├── storage.go ├── storage_test.go ├── databaseImpl.go └── databaseImpl_test.go ├── .gitlab-ci.yml └── go.mod /testutil/badkey: -------------------------------------------------------------------------------- 1 | asdf123 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | # Ignore glide files/folders 3 | glide.lock 4 | vendor/ 5 | # Ignore Jetbrains IDE folder 6 | .idea/* 7 | # Ignore vim .swp buffers for open files 8 | .*.swp 9 | .*.swo 10 | # Ignore local development scripts 11 | localdev_* 12 | # Ignore logs 13 | *.log 14 | creds/* 15 | 16 | testutil/ 17 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package main 9 | 10 | import ( 11 | "gitlab.com/elixxir/notifications-bot/cmd" 12 | ) 13 | // main delegates all work to cmd.Execute, which runs the command-line interface. 14 | func main() { 15 | cmd.Execute() 16 | } 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: update master release update_master update_release build clean version 2 | 3 | version: 4 | go run main.go generate 5 | mv version_vars.go cmd/version_vars.go 6 | 7 | clean: 8 | rm -rf vendor/ 9 | go mod vendor 10 | 11 | update: 12 | -GOFLAGS="" go get all 13 | 14 | build: 15 | go build ./... 16 | go mod tidy 17 | 18 | update_release: 19 | GOFLAGS="" go get gitlab.com/xx_network/primitives@release 20 | GOFLAGS="" go get gitlab.com/elixxir/primitives@release 21 | GOFLAGS="" go get gitlab.com/xx_network/comms@release 22 | GOFLAGS="" go get gitlab.com/elixxir/comms@release 23 | 24 | update_master: 25 | GOFLAGS="" go get gitlab.com/xx_network/primitives@master 26 | GOFLAGS="" go get gitlab.com/elixxir/primitives@master 27 | GOFLAGS="" go get gitlab.com/xx_network/comms@master 28 | GOFLAGS="" go get gitlab.com/elixxir/comms@master 29 | 30 | master: update_master clean build version 31 | 32 | release: update_release clean build version 33 | -------------------------------------------------------------------------------- /notifications/providers/provider.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file.
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | 9 | // Package providers contains logic for interacting with external notification providers such as APNS and Firebase Cloud Messaging (FCM). 10 | package providers 11 | 12 | import "gitlab.com/elixxir/notifications-bot/storage" 13 | 14 | // Provider represents an external notification provider (such as FCM or APNS), 15 | // exposing a single Notify method for the rest of the repo to call. 16 | type Provider interface { 17 | // Notify sends csv as the notification payload to target. The bool reports whether target's token is still considered valid (the FCM implementation returns false when the provider rejects the token); the error reports any send failure. 18 | Notify(csv string, target storage.GTNResult) (bool, error) 19 | } 20 | -------------------------------------------------------------------------------- /notifications/params.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file.
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import "gitlab.com/elixxir/notifications-bot/notifications/providers" 11 | 12 | // Params holds the configuration passed to StartNotifications. 13 | type Params struct { 14 | Address string // listen address, e.g. "0.0.0.0:42010" in the tests 15 | CertPath string // path to this server's certificate file 16 | KeyPath string // path to this server's private key file 17 | NotificationsPerBatch int // presumably the README's notificationsPerBatch — confirm 18 | MaxNotificationPayload int // presumably a payload size limit — confirm units 19 | NotificationRate int // presumably the README's notificationRate (seconds) — confirm 20 | FBCreds string // Firebase credentials file, presumably for XX Messenger ("" in tests) 21 | APNS providers.APNSParams // APNS settings, presumably for XX Messenger 22 | HavenFBCreds string // Firebase credentials file for the Haven app 23 | HavenAPNS providers.APNSParams // APNS settings for the Haven app 24 | HttpsCertPath string // presumably the certificate for an HTTPS listener — confirm 25 | HttpsKeyPath string // presumably the key for an HTTPS listener — confirm 26 | } 27 | -------------------------------------------------------------------------------- /notifications/receive_test.go: -------------------------------------------------------------------------------- 1 | package notifications 2 | 3 | import ( 4 | pb "gitlab.com/elixxir/comms/mixmessages" 5 | "gitlab.com/elixxir/notifications-bot/storage" 6 | "gitlab.com/xx_network/comms/connect" 7 | "sync" 8 | "testing" 9 | ) 10 | 11 | // Happy path: an authenticated batch is accepted and its notification is added to the notification buffer.
12 | func TestImpl_ReceiveNotificationBatch(t *testing.T) { 13 | s, err := storage.NewStorage("", "", "", "", "") 14 | if err != nil { 15 | t.Fatalf("Failed to make new storage: %+v", err) 16 | } 17 | impl := &Impl{ 18 | Storage: s, 19 | roundStore: sync.Map{}, 20 | maxNotifications: 0, 21 | maxPayloadBytes: 0, 22 | } 23 | 24 | notifBatch := &pb.NotificationBatch{ 25 | RoundID: 42, 26 | Notifications: []*pb.NotificationData{ 27 | { 28 | EphemeralID: 5, 29 | IdentityFP: []byte("IdentityFP"), 30 | MessageHash: []byte("MessageHash"), 31 | }, 32 | }, 33 | } 34 | 35 | auth := &connect.Auth{ 36 | IsAuthenticated: true, 37 | } 38 | 39 | err = impl.ReceiveNotificationBatch(notifBatch, auth) 40 | if err != nil { 41 | t.Fatalf("ReceiveNotificationBatch() returned an error: %+v", err) 42 | } 43 | 44 | // The batch holds exactly one notification, for ephemeral ID 5. 45 | nbm := impl.Storage.GetNotificationBuffer().Swap() 46 | if len(nbm[5]) != 1 { 47 | t.Errorf("Expected 1 notification for ephemeral ID 5 in notification buffer, got %d: %+v", len(nbm[5]), nbm[5]) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /constants/constants.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package constants 9 | 10 | const NotificationsTag = "notificationData" 11 | const NotificationTitle = "Privacy: protected!"
12 | const NotificationBody = "Some notifications are not for you to ensure privacy; we hope to remove this notification soon" 13 | 14 | // App identifies a client application and platform combination. 15 | type App uint8 16 | 17 | const ( 18 | MessengerIOS App = iota 19 | MessengerAndroid 20 | HavenIOS 21 | HavenAndroid 22 | ) 23 | 24 | // appNames holds the string form of each known App, indexed by its value. 25 | var appNames = [...]string{ 26 | MessengerIOS: "messengerIOS", 27 | MessengerAndroid: "messengerAndroid", 28 | HavenIOS: "havenIOS", 29 | HavenAndroid: "havenAndroid", 30 | } 31 | 32 | // String returns the identifier for a, or "unknown" for values outside the known set. 33 | func (a App) String() string { 34 | if int(a) >= len(appNames) { 35 | return "unknown" 36 | } 37 | return appNames[a] 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, xx network SEZC 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 2. Redistributions in binary form must reproduce the above copyright notice, 9 | this list of conditions and the following disclaimer in the documentation and/or 10 | other materials provided with the distribution. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 13 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 14 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 15 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 16 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 17 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 18 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 19 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 20 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 21 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # notifications-bot 2 | 3 | Notifications bot sends push notifications to users' devices via Firebase Cloud Messaging (FCM) and the Apple Push Notification service (APNS). It receives batches of notification data from gateways, buffers them, and sends the buffered notifications out in periodic batches. 4 | 5 | # Config File 6 | 7 | ```yaml 8 | # ================================== 9 | # Notification Server Configuration 10 | # ================================== 11 | 12 | # START YAML === 13 | # Verbose logging 14 | logLevel: "${verbose}" 15 | # Path to log file 16 | log: "${log_path}" 17 | 18 | # Database connection information 19 | dbUsername: "${db_username}" 20 | dbPassword: "${db_password}" 21 | dbName: "${db_name}" 22 | dbAddress: "${db_address}" 23 | 24 | # Path to this server's private key file 25 | keyPath: "${key_path}" 26 | # Path to this server's certificate file 27 | certPath: "${cert_path}" 28 | # The listening port of this server 29 | port: ${port} 30 | 31 | # Paths to the Firebase credentials files for the XX Messenger and Haven apps 32 | firebaseCredentialsPath: "${fb_creds_path}" 33 | havenFirebaseCredentialsPath: "${haven_fb_creds_path}" 34 | 35 | # Path to the permissioning server certificate file 36 | permissioningCertPath: "${permissioning_cert_path}" 37 | # Address:port of the permissioning server 38 | permissioningAddress: "${permissioning_address}:${port}" 39 | 40 | # XX
Messenger APNS parameters 41 | apnsKeyPath: "" 42 | apnsKeyID: "" 43 | apnsIssuer: "" 44 | apnsBundleID: "" 45 | apnsDev: true 46 | 47 | # Haven APNS parameters 48 | havenApnsKeyPath: "" 49 | havenApnsKeyID: "" 50 | havenApnsIssuer: "" 51 | havenApnsBundleID: "" 52 | havenApnsDev: true 53 | 54 | # Notification params 55 | notificationRate: 30 # Duration in seconds 56 | notificationsPerBatch: 20 57 | # === END YAML 58 | ``` 59 | -------------------------------------------------------------------------------- /cmd/version.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Handles command-line version functionality 9 | 10 | package cmd 11 | 12 | import ( 13 | "fmt" 14 | "github.com/spf13/cobra" 15 | "gitlab.com/xx_network/primitives/utils" 16 | ) 17 | 18 | // Change this value to set the version for this build. 19 | const currentVersion = "3.0.0" 20 | 21 | // printVersion prints the version (SEMVER), git version (GITVERSION) and 22 | // dependency list (DEPENDENCIES) to stdout. 23 | func printVersion() { 24 | fmt.Printf("Elixxir Notifications Server v%s -- %s\n\n", SEMVER, GITVERSION) 25 | fmt.Printf("Dependencies:\n\n%s\n", DEPENDENCIES) 26 | } 27 | 28 | // init registers the version and generate subcommands on rootCmd. 29 | func init() { 30 | rootCmd.AddCommand(versionCmd) 31 | rootCmd.AddCommand(generateCmd) 32 | } 33 | 34 | // versionCmd prints version and dependency information. 35 | var versionCmd = &cobra.Command{ 36 | Use: "version", 37 | Short: "Print the version and dependency information for the Elixxir binary", 38 | Long: `Print the version and dependency information for the Elixxir binary`, 39 | Run: func(cmd *cobra.Command, args []string) { 40 | printVersion() 41 | }, 42 | } 43 | 44 | // generateCmd calls utils.GenerateVersionFile for currentVersion; the Makefile's 45 | // version target expects it to produce version_vars.go and moves that into cmd/. 46 | var generateCmd = &cobra.Command{ 47 | Use: "generate", 48 | Short: "Generates version and dependency information for the Elixxir binary", 49 | Long: `Generates
version and dependency information for the Elixxir binary`, 50 | Run: func(cmd *cobra.Command, args []string) { 51 | utils.GenerateVersionFile(currentVersion) 52 | }, 53 | } 54 | -------------------------------------------------------------------------------- /io/poll.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | 9 | // Package io polls the permissioning server for the NDF. Create a poller with: 10 | // poller := NewNdfPoller(comms, permissioningHost) 11 | // and call poller.PollNdf(ndfHash) to request a copy of the NDF. 12 | // 13 | // Depend on the PollingConn interface rather than NdfPoller so the NDF source 14 | // can be mocked in tests. 15 | package io 16 | 17 | import ( 18 | pb "gitlab.com/elixxir/comms/mixmessages" 19 | "gitlab.com/elixxir/comms/notificationBot" 20 | "gitlab.com/xx_network/comms/connect" 21 | ) 22 | 23 | // PollingConn is implemented by anything that can poll for the NDF; it lets 24 | // the NDF source be mocked for testing. 25 | type PollingConn interface { 26 | PollNdf(ndfHash []byte) (*pb.NDF, error) 27 | } 28 | 29 | // NdfPoller polls the permissioning server (permHost) for the NDF through the 30 | // notificationBot comms object pc. 31 | type NdfPoller struct { 32 | permHost *connect.Host 33 | pc *notificationBot.Comms 34 | } 35 | 36 | // NewNdfPoller returns an NdfPoller that polls pHost using the comms object pc. 37 | func NewNdfPoller(pc *notificationBot.Comms, pHost *connect.Host) NdfPoller { 38 | return NdfPoller{ 39 | pc: pc, 40 | permHost: pHost, 41 | } 42 | } 43 | 44 | // PollNdf requests the NDF from the permissioning server, forwarding ndfHash to pc.PollNdf.
45 | func (p NdfPoller) PollNdf(ndfHash []byte) (*pb.NDF, error) { 46 | return p.pc.PollNdf(p.permHost, ndfHash) 47 | } 48 | 49 | // Compile-time check that NdfPoller satisfies PollingConn. 50 | var _ PollingConn = NdfPoller{} 51 | -------------------------------------------------------------------------------- /notifications/receive.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | jww "github.com/spf13/jwalterweatherman" 12 | pb "gitlab.com/elixxir/comms/mixmessages" 13 | "gitlab.com/elixxir/primitives/notifications" 14 | "gitlab.com/xx_network/comms/connect" 15 | "gitlab.com/xx_network/primitives/id" 16 | "time" 17 | ) 18 | 19 | // ReceiveNotificationBatch receives a batch of notification data from a gateway and adds it to the notification buffer, dropping batches for rounds it has already seen.
20 | func (nb *Impl) ReceiveNotificationBatch(notifBatch *pb.NotificationBatch, auth *connect.Auth) error { 21 | rid := notifBatch.RoundID 22 | 23 | _, loaded := nb.roundStore.LoadOrStore(rid, time.Now()) 24 | if loaded { 25 | jww.DEBUG.Printf("Dropping duplicate notification batch for round %+v", notifBatch.RoundID) 26 | return nil 27 | } 28 | 29 | jww.INFO.Printf("Received notification batch for round %+v", notifBatch.RoundID) 30 | 31 | buffer := nb.Storage.GetNotificationBuffer() 32 | data := processNotificationBatch(notifBatch) 33 | buffer.Add(id.Round(notifBatch.RoundID), data) 34 | 35 | return nil 36 | } 37 | 38 | func processNotificationBatch(l *pb.NotificationBatch) []*notifications.Data { 39 | var res []*notifications.Data 40 | for _, item := range l.Notifications { 41 | res = append(res, ¬ifications.Data{ 42 | EphemeralID: item.EphemeralID, 43 | RoundID: l.RoundID, 44 | IdentityFP: item.IdentityFP, 45 | MessageHash: item.MessageHash, 46 | }) 47 | } 48 | return res 49 | } 50 | -------------------------------------------------------------------------------- /testutil/cmix.rip.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIGDTCCA/WgAwIBAgIUNdjL0qGKH2CyhoIsuY7biOomzJswDQYJKoZIhvcNAQEL 3 | BQAwgYoxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJQ2xhcmVt 4 | b250MRAwDgYDVQQKDAdFbGl4eGlyMRQwEgYDVQQLDAtEZXZlbG9wbWVudDERMA8G 5 | A1UEAwwIY21peC5yaXAxHzAdBgkqhkiG9w0BCQEWEGFkbWluQGVsaXh4aXIuaW8w 6 | HhcNMTkwODE1MjIxNjA1WhcNMjAwODE0MjIxNjA1WjCBijELMAkGA1UEBhMCVVMx 7 | CzAJBgNVBAgMAkNBMRIwEAYDVQQHDAlDbGFyZW1vbnQxEDAOBgNVBAoMB0VsaXh4 8 | aXIxFDASBgNVBAsMC0RldmVsb3BtZW50MREwDwYDVQQDDAhjbWl4LnJpcDEfMB0G 9 | CSqGSIb3DQEJARYQYWRtaW5AZWxpeHhpci5pbzCCAiIwDQYJKoZIhvcNAQEBBQAD 10 | ggIPADCCAgoCggIBAJl196YgyCyvQjAFSEOD43otMlXXFeegpXc0itYD/z3eD5kM 11 | gyV1Jpauqv9yoF/MR6iqIQZa2NHaybS9Y54hJHjR2OUlXmH3LTGYEQ+dMY8DMvOJ 12 | SyFb1HuKnTg8hmPsaEAASWbtqzFUB7s9j1UpluEA3V1EaqQUYHVr/gG/Bmu6xxyk 13 | 
P+iKDAArLbo3n5FZLeUshSTgYcl37XVEh7ZBswcZjqacj4+Fh+X3vYkZjMkItCHy 14 | WV0z2S/hkjYRIm265OXYuJMTSBYIeR6cfcbSsw+XZgpKtYLd4kTUX7wFxuySLhk6 15 | FWX8U4QfSwTX+X7SdJ2MEXyroAfZW7ESQLLsgeSoFb6b61rJWUhrVT517QboQbII 16 | 3zAIrOAoE5PWPSfE2jSBb5lIryabcpUsgnUvwckRoo5VYnqVvNdOyFsyt8Ui0FG3 17 | FWt+t6twfEqyBYE3pKuE9vavOTybGBgzWjCIVQ3EyQGyS2GIqMKdxPMeuKbB7QCh 18 | JLeCmCvVPVAwdsmdO8DE6Bx3MuUhcbb0t9KXp5disEOaNBxINHgh65pADpKiit1d 19 | wuFOoX3GPm1DJr75UdaGeohE4IG8vbfdy8baeSYh+KsM/yRMGxU/dfAXHSDc9SYj 20 | zlekLqvUW66sNNN275f+NU3OTDGA9mcTeYKTxGWKQRRbKBJ/pVb7dktfXCj/AgMB 21 | AAGjaTBnMB0GA1UdDgQWBBR2r2q3XONumeNgSq0VIPG771pJyDAfBgNVHSMEGDAW 22 | gBR2r2q3XONumeNgSq0VIPG771pJyDAPBgNVHRMBAf8EBTADAQH/MBQGA1UdEQQN 23 | MAuCCWZvby5jby51azANBgkqhkiG9w0BAQsFAAOCAgEABLGi8l6odqCTZ/leDIC0 24 | DobMvPtzgTYU3hQC+dOY39+h5fM1db511vpfZJ1Wt4dJmPb1Fl+PolyD572zla+K 25 | ioCKIpk4MnbLln6ohgeywwwq2kB26BoEhZ7oh97ou/uSjN6Kn0TuduEs9PLnrUGx 26 | FnkeqxHmuR1YVpHfgclev6AUYZamC/uEPv/+raHEyAmNhsVLMQvlZ4KUeQ4r4PLX 27 | hPmawEwNrRTKCTIPxRjJJFHALtR3/LueCwAi35CWqTxceetglZLvyQMeuVX2GDOL 28 | 7hgajcIPF8PEnsGGOmIBAMDcqlxoFEiJwAStwRwgKMwqOG7kGbvIm8lox6kYMVVO 29 | wkoC5b7AXZ46kxwrNrwbEVxm4uklJ0k3lQdfr9Mc45EsYfwGBRj4tUoT+WWmvn8o 30 | kqZ6pgJ1DS/GwwqSQcksXtGbgdQli+quOFZABp3+VA6gEUe1U1Tqf+6dkdCVhD+q 31 | 6N+0YSy7qkFxQWMqVJmAhIFP1xUhNUS/WjH3CUT/h6AWOm9Zx6UnopTfafkqVEKq 32 | LrrphmPSSr34HfnlAbrCe8Gc8hVlXeKpor65c4o2g8Kzc8Y6Jxc7FjqLkCU7sgfe 33 | uAS3kGPDjPhyjvzYZYo9I6Q2GcQ5oH9sgZmt10mfZK6qdvF25fL7i1n9WBIz6f0g 34 | 5xPCKQLdA4PCHxGG0uDlo6w= 35 | -----END CERTIFICATE----- 36 | -------------------------------------------------------------------------------- /notifications/impl_test.go: -------------------------------------------------------------------------------- 1 | package notifications 2 | 3 | import ( 4 | "fmt" 5 | "gitlab.com/elixxir/notifications-bot/notifications/providers" 6 | "gitlab.com/elixxir/notifications-bot/storage" 7 | "os" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | var port = 4200 // next listen port handed out by getNewImpl 13 | // MockProvider is a provider stub that forwards every csv payload it is asked to send to donech. 14 | type MockProvider struct {
15 | donech chan string 16 | } 17 | // Notify sends csv on donech (blocking until received if donech is unbuffered) and always reports the token as valid. 18 | func (mp *MockProvider) Notify(csv string, target storage.GTNResult) (bool, error) { 19 | mp.donech <- csv 20 | return true, nil 21 | } 22 | 23 | // TestStartNotifications checks that StartNotifications fails without a 24 | // certificate path and succeeds once one is provided. 25 | func TestStartNotifications(t *testing.T) { 26 | wd, err := os.Getwd() 27 | if err != nil { 28 | t.Fatalf("Failed to get working dir: %+v", err) 29 | } 30 | 31 | params := Params{ 32 | Address: "0.0.0.0:42010", 33 | NotificationsPerBatch: 20, 34 | NotificationRate: 30, 35 | APNS: providers.APNSParams{ 36 | KeyPath: "", 37 | KeyID: "WQT68265C5", 38 | Issuer: "S6JDM2WW29", 39 | BundleID: "io.xxlabs.messenger", 40 | }, 41 | } 42 | 43 | params.KeyPath = wd + "/../testutil/cmix.rip.key" 44 | _, err = StartNotifications(params, false, true) 45 | if err == nil || !strings.Contains(err.Error(), "failed to read certificate at") { 46 | t.Errorf("Should have thrown an error for no cert path") 47 | } 48 | 49 | params.CertPath = wd + "/../testutil/cmix.rip.crt" 50 | _, err = StartNotifications(params, false, true) 51 | if err != nil { 52 | t.Errorf("Failed to start notifications successfully: %+v", err) 53 | } 54 | } 55 | 56 | // TestNewImplementation checks that NewImplementation populates the 57 | // RegisterForNotifications and UnregisterForNotifications handlers. 58 | func TestNewImplementation(t *testing.T) { 59 | instance := getNewImpl() 60 | 61 | impl := NewImplementation(instance) 62 | if impl.Functions.RegisterForNotifications == nil || impl.Functions.UnregisterForNotifications == nil { 63 | t.Errorf("Functions were not properly set") 64 | } 65 | } 66 | 67 | // getNewImpl starts an Impl on the next test port using the testutil key and 68 | // certificate, then attaches storage created with empty connection parameters. 69 | // It panics on any setup failure instead of returning a nil or half-built 70 | // *Impl that would later fail with an opaque nil-pointer dereference. 71 | func getNewImpl() *Impl { 72 | wd, err := os.Getwd() 73 | if err != nil { 74 | panic(fmt.Sprintf("getNewImpl: failed to get working dir: %+v", err)) 75 | } 76 | params := Params{ 77 | NotificationsPerBatch: 20, 78 | NotificationRate: 30, 79 | Address: fmt.Sprintf("0.0.0.0:%d", port), 80 | KeyPath: wd + "/../testutil/cmix.rip.key", 81 | CertPath: wd + "/../testutil/cmix.rip.crt", 82 | FBCreds: "", 83 | } 84 | port++ 85 | instance, err :=
StartNotifications(params, false, true) 86 | if err != nil { 87 | panic(fmt.Sprintf("getNewImpl: failed to start notifications: %+v", err)) 88 | } 89 | instance.Storage, err = storage.NewStorage("", "", "", "", "") 90 | if err != nil { 91 | panic(fmt.Sprintf("getNewImpl: failed to create storage: %+v", err)) 92 | } 93 | return instance 94 | } 95 | -------------------------------------------------------------------------------- /notifications/ndf_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "bytes" 12 | pb "gitlab.com/elixxir/comms/mixmessages" 13 | "gitlab.com/elixxir/comms/testutils" 14 | "sync" 15 | "testing" 16 | "time" 17 | ) 18 | // MockPoller is a test double for the NDF poller passed to trackNdf. The 19 | // embedded mutex guards ndf, which tests replace via UpdateNdf. 20 | type MockPoller struct { 21 | ndf pb.NDF 22 | sync.Mutex 23 | } 24 | 25 | // PollNdf returns a copy of the current NDF and ignores ndfHash. Returning a 26 | // copy rather than &m.ndf keeps callers from aliasing state that UpdateNdf 27 | // overwrites after the lock has been released. 28 | func (m *MockPoller) PollNdf(ndfHash []byte) (*pb.NDF, error) { 29 | m.Lock() 30 | defer m.Unlock() 31 | ndf := m.ndf 32 | return &ndf, nil 33 | } 34 | 35 | // UpdateNdf replaces the NDF returned by subsequent PollNdf calls. 36 | func (m *MockPoller) UpdateNdf(newNDF pb.NDF) { 37 | m.Lock() 38 | defer m.Unlock() 39 | m.ndf = newNDF 40 | } 41 | 42 | // TestTrackNdf performs a basic test of the trackNdf function: after the 43 | // starting NDF it changes the polled NDF twice and expects three updates. 44 | func TestTrackNdf(t *testing.T) { 45 | // Sending on quitCh stops the trackNdf goroutine. 46 | quitCh := make(chan bool) 47 | 48 | startNDF := pb.NDF{Ndf: make([]byte, 10)} 49 | copy(startNDF.Ndf, testutils.ExampleNDF[0:10]) 50 | 51 | newNDF := pb.NDF{Ndf: make([]byte, 10)} 52 | copy(newNDF.Ndf, testutils.ExampleNDF[0:10]) 53 | 54 | poller := &MockPoller{ 55 | ndf: startNDF, 56 | } 57 | 58 | gwUpdates := 0 59 | lastNdf := make([]byte, 10) 60 | gatewayEventHandler := func(ndf pb.NDF) ([]byte, error) { 61 | t.Logf("Updating Gateways with new NDF") 62 | t.Logf("%v == %v?", ndf.Ndf, lastNdf) 63 | if !bytes.Equal(lastNdf, ndf.Ndf) { 64 | t.Logf("Incrementing counter") 65 | copy(lastNdf, ndf.Ndf) 66 | gwUpdates++ 67 | } 68 | // We control the
hash, so we control the update calls... 69 | ndfHash := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, byte(gwUpdates % 255)} 70 | return ndfHash, nil 71 | } 72 | 73 | go trackNdf(poller, quitCh, gatewayEventHandler) 74 | 75 | // 1st change: the starting NDF differs from the zeroed lastNdf. 76 | time.Sleep(100 * time.Millisecond) 77 | 78 | // 2nd change: startNDF -> newNDF 79 | newNDF.Ndf[5] = byte('a') 80 | poller.UpdateNdf(newNDF) 81 | time.Sleep(1100 * time.Millisecond) 82 | 83 | // 3rd change: newNDF -> startNDF 84 | poller.UpdateNdf(startNDF) 85 | time.Sleep(1100 * time.Millisecond) 86 | 87 | select { 88 | case quitCh <- true: 89 | case <-time.After(2 * time.Second): 90 | t.Errorf("Could not stop NDF Tracking Thread") 91 | } 92 | 93 | if gwUpdates != 3 { 94 | t.Errorf("updates not detected, expected 3 got: %d", gwUpdates) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /storage/buffer_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file.
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | "gitlab.com/elixxir/primitives/notifications" 12 | "gitlab.com/xx_network/primitives/id" 13 | "gitlab.com/xx_network/primitives/id/ephemeral" 14 | "math/rand" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | func TestNotificationBuffer_Sorting(t *testing.T) { 20 | nb := NewNotificationBuffer() 21 | uid1 := id.NewIdFromString("zezima", id.User, t) 22 | eid1, _, _, _ := ephemeral.GetId(uid1, 16, time.Now().UnixNano()) 23 | uid2 := id.NewIdFromString("escaline", id.User, t) 24 | eid2, _, _, _ := ephemeral.GetId(uid2, 16, time.Now().UnixNano()) 25 | 26 | eid1count := 0 27 | eid2count := 0 28 | for i := 0; i <= 5; i++ { 29 | nd := []*notifications.Data{} 30 | rand.Seed(time.Now().UnixNano()) 31 | min := 2 32 | max := 5 33 | numNotifs := rand.Intn(max-min+1) + min 34 | rid := rand.Intn(500) + 1 35 | for j := 0; j <= numNotifs; j++ { 36 | msgHash := make([]byte, 32) 37 | ifp := make([]byte, 25) 38 | rand.Read(msgHash) 39 | rand.Read(ifp) 40 | var eid int64 41 | if rid%2 == 0 { 42 | eid = eid1.Int64() 43 | eid1count++ 44 | } else { 45 | eid = eid2.Int64() 46 | eid2count++ 47 | } 48 | nd = append(nd, ¬ifications.Data{ 49 | EphemeralID: eid, 50 | RoundID: uint64(rid), 51 | IdentityFP: ifp, 52 | MessageHash: msgHash, 53 | }) 54 | } 55 | nb.Add(id.Round(rid), nd) 56 | } 57 | 58 | sorted := nb.Swap() 59 | 60 | if nl, ok := sorted[eid1.Int64()]; ok { 61 | if len(nl) != eid1count { 62 | t.Errorf("Did not find expected number of notifications for eid1. Expected: %d, received: %d", eid1count, len(nl)) 63 | } 64 | var last uint64 65 | for _, n := range nl { 66 | if n.RoundID < last { 67 | t.Error("Ordering was incorrect") 68 | } 69 | last = n.RoundID 70 | } 71 | } 72 | if nl, ok := sorted[eid2.Int64()]; ok { 73 | if len(nl) != eid2count { 74 | t.Errorf("Did not find expected number of notifications for eid1. 
Expected: %d, received: %d", eid2count, len(nl)) 75 | } 76 | var last uint64 77 | for _, n := range nl { 78 | if n.RoundID < last { 79 | t.Error("Ordering was incorrect") 80 | } 81 | last = n.RoundID 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /notifications/providers/fcm.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package providers 9 | 10 | import ( 11 | "context" 12 | firebase "firebase.google.com/go" 13 | "firebase.google.com/go/messaging" 14 | "github.com/pkg/errors" 15 | jww "github.com/spf13/jwalterweatherman" 16 | "gitlab.com/elixxir/notifications-bot/storage" 17 | "google.golang.org/api/option" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | // fcm struct representing Firebase cloud messaging providers 23 | type fcm struct { 24 | client *messaging.Client 25 | } 26 | 27 | // NewFCM returns an FCM-backed provider interface. 28 | func NewFCM(serviceKeyPath string) (Provider, error) { 29 | ctx := context.Background() 30 | opt := option.WithCredentialsFile(serviceKeyPath) 31 | app, err := firebase.NewApp(context.Background(), nil, opt) 32 | if err != nil { 33 | return nil, errors.Errorf("Error initializing app: %v", err) 34 | } 35 | 36 | cl, err := app.Messaging(ctx) 37 | if err != nil { 38 | return nil, errors.Errorf("Error getting Messaging app: %+v", err) 39 | } 40 | 41 | return &fcm{ 42 | client: cl, 43 | }, nil 44 | } 45 | 46 | // Notify implements the Provider interface for FCM, sending the notifications to the provider. 
47 | func (f *fcm) Notify(csv string, target storage.GTNResult) (bool, error) { 48 | ctx := context.Background() 49 | ttl := 7 * 24 * time.Hour 50 | message := &messaging.Message{ 51 | Data: map[string]string{ 52 | // NOTE(review): constants.NotificationsTag currently holds "notificationData", not "notificationsTag"; confirm which key clients read before acting on the TODO below. 53 | "notificationsTag": csv, // TODO: swap to notificationsTag constant from notifications package (move to avoid circular dep) 54 | }, 55 | Android: &messaging.AndroidConfig{ 56 | Priority: "high", 57 | TTL: &ttl, 58 | }, 59 | Token: target.Token, 60 | } 61 | 62 | resp, err := f.client.Send(ctx, message) 63 | if err != nil { 64 | // A 404, or a 400 mentioning "Invalid registration", is treated as a rejected token. 65 | errStr := err.Error() 66 | invalidToken := strings.Contains(errStr, "404") || 67 | (strings.Contains(errStr, "400") && strings.Contains(errStr, "Invalid registration")) 68 | if invalidToken { 69 | return false, errors.WithMessagef(err, "Failed to notify user with Transmission RSA hash %+v due to invalid token", target.TransmissionRSAHash) 70 | } 71 | return true, errors.WithMessagef(err, "Failed to notify user with Transmission RSA hash %+v", target.TransmissionRSAHash) 72 | } 73 | jww.DEBUG.Printf("Notified ephemeral ID %+v [%+v] via fcm and received response %+v", target.EphemeralId, target.Token, resp) 74 | return true, nil 75 | } 76 | -------------------------------------------------------------------------------- /notifications/send_test.go: -------------------------------------------------------------------------------- 1 | package notifications 2 | 3 | import ( 4 | "gitlab.com/elixxir/notifications-bot/constants" 5 | "gitlab.com/elixxir/notifications-bot/notifications/providers" 6 | "gitlab.com/elixxir/notifications-bot/storage" 7 | "gitlab.com/elixxir/primitives/notifications" 8 | "gitlab.com/xx_network/primitives/id" 9 | "gitlab.com/xx_network/primitives/id/ephemeral" 10 | "sync" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | // TestImpl_SendBatch registers one Android user backed by a MockProvider and checks SendBatch with an empty batch, 16 | // with zero notification/payload limits (the notification is expected back as unsent), and with real limits 17 | // (nothing unsent and the mock provider receives a payload). 18 | func TestImpl_SendBatch(t *testing.T) { 19 | // Init storage 20 | s, err := storage.NewStorage("", "",
"", "", "") 21 | if err != nil { 22 | t.Fatalf("Failed to make new storage: %+v", err) 23 | } 24 | 25 | // Both mock providers report each payload they are asked to send on dchan. 26 | dchan := make(chan string, 10) 27 | 28 | // Create impl 29 | i := Impl{ 30 | providers: map[string]providers.Provider{}, 31 | Storage: s, 32 | 33 | roundStore: sync.Map{}, 34 | maxNotifications: 0, 35 | maxPayloadBytes: 0, 36 | } 37 | 38 | i.providers[constants.MessengerAndroid.String()] = &MockProvider{donech: dchan} 39 | i.providers[constants.MessengerIOS.String()] = &MockProvider{donech: dchan} 40 | 41 | // Identity setup 42 | uid := id.NewIdFromString("zezima", id.User, t) 43 | iid, err := ephemeral.GetIntermediaryId(uid) 44 | if err != nil { 45 | t.Fatalf("Failed to create iid: %+v", err) 46 | } 47 | _, epoch := ephemeral.HandleQuantization(time.Now()) 48 | _, err = s.RegisterForNotifications(iid, []byte("rsacert"), "fcm:token", constants.MessengerAndroid.String(), epoch, 16) 49 | if err != nil { 50 | t.Fatalf("Failed to add fake user: %+v", err) 51 | } 52 | eph, err := s.GetLatestEphemeral() 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | _, err = i.SendBatch(map[int64][]*notifications.Data{}) 57 | if err != nil { 58 | t.Errorf("Error on sending empty batch: %+v", err) 59 | } 60 | 61 | // With maxNotifications and maxPayloadBytes at 0 the notification is expected back as unsent. 62 | unsent, err := i.SendBatch(map[int64][]*notifications.Data{ 63 | eph.EphemeralId: {{EphemeralID: eph.EphemeralId, RoundID: 3, MessageHash: []byte("hello"), IdentityFP: []byte("identity")}}, 64 | }) 65 | if err != nil { 66 | t.Errorf("Error on sending small batch: %+v", err) 67 | } 68 | if len(unsent) < 1 { 69 | t.Errorf("Should have received notification back as unsent, instead got %+v", unsent) 70 | } 71 | 72 | i.maxPayloadBytes = 4096 73 | i.maxNotifications = 20 74 | unsent, err =
i.SendBatch(map[int64][]*notifications.Data{ 75 | 1: {{EphemeralID: eph.EphemeralId, RoundID: 3, MessageHash: []byte("hello"), IdentityFP: []byte("identity")}}, 76 | }) 77 | if err != nil { 78 | t.Errorf("Error on sending small batch again: %+v", err) 79 | } 80 | if len(unsent) > 0 { 81 | t.Errorf("Should not have received notification back as unsent, instead got %+v", unsent) 82 | } 83 | 84 | select { 85 | case <-dchan: 86 | t.Logf("Received on data chan!") 87 | case <-time.After(3 * time.Second): 88 | t.Errorf("Did not receive data before timeout") 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | before_script: 2 | - go version || echo "Go executable not found." 3 | - echo $CI_BUILD_REF 4 | - echo $CI_PROJECT_DIR 5 | - echo $PWD 6 | - eval $(ssh-agent -s) 7 | - echo "$SSH_PRIVATE_KEY" | tr -d '\r' | ssh-add - > /dev/null 8 | - mkdir -p ~/.ssh 9 | - chmod 700 ~/.ssh 10 | - ssh-keyscan -t rsa $GITLAB_SERVER > ~/.ssh/known_hosts 11 | - git config --global url."git@$GITLAB_SERVER:".insteadOf "https://gitlab.com/" 12 | - git config --global url."git@$GITLAB_SERVER:".insteadOf "https://git.xx.network/" --add 13 | - export PATH=$HOME/go/bin:$PATH 14 | 15 | stages: 16 | - build 17 | - trigger_integration 18 | 19 | build: 20 | stage: build 21 | image: $DOCKER_IMAGE 22 | except: 23 | - tags 24 | script: 25 | - git clean -ffdx 26 | - go mod vendor -v 27 | - go build ./... 28 | - go mod tidy 29 | - apt-get update 30 | - apt-get install bc -y 31 | 32 | - mkdir -p testdata 33 | # Test coverage 34 | - go-acc --covermode atomic --output testdata/coverage.out ./...
-- -v 35 | # Exclude cmd from test coverage as it is command line related tooling 36 | - cat testdata/coverage.out | grep -v cmd | grep -v main.go > testdata/coverage-real.out 37 | - go tool cover -func=testdata/coverage-real.out 38 | - go tool cover -html=testdata/coverage-real.out -o testdata/coverage.html 39 | 40 | # Test Coverage Check 41 | - go tool cover -func=testdata/coverage-real.out | grep "total:" | awk '{print $3}' | sed 's/\%//g' > testdata/coverage-percentage.txt 42 | - export CODE_CHECK=$(echo "$(cat testdata/coverage-percentage.txt) >= $MIN_CODE_COVERAGE" | bc -l) 43 | - (if [ "$CODE_CHECK" == "1" ]; then echo "Minimum coverage of $MIN_CODE_COVERAGE succeeded"; else echo "Minimum coverage of $MIN_CODE_COVERAGE failed"; exit 1; fi); 44 | 45 | - mkdir -p release 46 | - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' ./... 47 | - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/notifications.linux64 main.go 48 | # - GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/notifications.win64 main.go 49 | # - GOOS=windows GOARCH=386 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/notifications.win32 main.go 50 | - GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/notifications.darwin64 main.go 51 | 52 | - /hash-file.sh release/notifications.linux64 53 | artifacts: 54 | paths: 55 | - vendor/ 56 | - testdata/ 57 | - release/ 58 | 59 | tag: 60 | stage: trigger_integration 61 | only: 62 | - master 63 | image: $DOCKER_IMAGE 64 | script: 65 | - git remote add origin_tags git@git.xx.network:elixxir/notifications-bot.git || true 66 | - git remote set-url origin_tags git@git.xx.network:elixxir/notifications-bot.git || true 67 | - git tag $(./release/notifications.linux64 version | grep "Elixxir Notifications Server v"| cut -d ' ' -f4) -f 68 | - git push origin_tags -f --tags 69 | 70 | trigger-integration: 71 | stage: trigger_integration 72 | trigger: 73 | project: 
elixxir/integration 74 | branch: $CI_COMMIT_REF_NAME 75 | only: 76 | - master 77 | - release 78 | -------------------------------------------------------------------------------- /testutil/cmix.rip.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCZdfemIMgsr0Iw 3 | BUhDg+N6LTJV1xXnoKV3NIrWA/893g+ZDIMldSaWrqr/cqBfzEeoqiEGWtjR2sm0 4 | vWOeISR40djlJV5h9y0xmBEPnTGPAzLziUshW9R7ip04PIZj7GhAAElm7asxVAe7 5 | PY9VKZbhAN1dRGqkFGB1a/4BvwZrusccpD/oigwAKy26N5+RWS3lLIUk4GHJd+11 6 | RIe2QbMHGY6mnI+PhYfl972JGYzJCLQh8lldM9kv4ZI2ESJtuuTl2LiTE0gWCHke 7 | nH3G0rMPl2YKSrWC3eJE1F+8Bcbski4ZOhVl/FOEH0sE1/l+0nSdjBF8q6AH2Vux 8 | EkCy7IHkqBW+m+tayVlIa1U+de0G6EGyCN8wCKzgKBOT1j0nxNo0gW+ZSK8mm3KV 9 | LIJ1L8HJEaKOVWJ6lbzXTshbMrfFItBRtxVrfrercHxKsgWBN6SrhPb2rzk8mxgY 10 | M1owiFUNxMkBskthiKjCncTzHrimwe0AoSS3gpgr1T1QMHbJnTvAxOgcdzLlIXG2 11 | 9LfSl6eXYrBDmjQcSDR4IeuaQA6SoordXcLhTqF9xj5tQya++VHWhnqIROCBvL23 12 | 3cvG2nkmIfirDP8kTBsVP3XwFx0g3PUmI85XpC6r1FuurDTTdu+X/jVNzkwxgPZn 13 | E3mCk8RlikEUWygSf6VW+3ZLX1wo/wIDAQABAoICADbY2bUPXFvUG6TMDoLK3X9q 14 | LeZOJC8P1HOhXMmWzh/PgOWjei/mCe+q58S6tCTo/ueCPqFl0L5YIuUtFzCKCd4A 15 | qjNjwrLiw81I2zgoZ3EEpK4z8J0wk+W/qedSgnmuIahWMeXOpfPQY58BJnw6jAlI 16 | 5NUTwcV43uy4tyTqoCHatJVBAvJafGWHCSXYAXjSVbvlyIRgibKW4VCbOKHkI1cz 17 | RC+6HvkdsW94ts4MSqwDJ1ZGprfP8xzQ0w/t2c88D9kyNu2h4460Yo2VQsLPxV2a 18 | L8cX4n2uTdNpz/mEWw9sQ3uSrdkwcKz0/jvx8OLp1vIEAK/9KwoDHmlP673HF7kS 19 | uyjQ+Lc3qKJBloisl9q93187D/Q9fEundBXCLRI2rNbqa0PdTCE5uONNS5jqlD9V 20 | 17had/N2+NqqITDxiroya2BxGO25svRHW1zGsvR+22ijqC5T0rOIDj8mFnMkIVl3 21 | E3fI7DvjC8ML1xAZIciE5trLDwuBxEwrrDvHZ7lXfjZC2KB/Ic/xGr0URY10yR6x 22 | l7skwnV7QVkSYBTwpFB2umkDEEqZwOJdkRquAyB9cREUaFQyVdQGzJ9UQ7uCyCcE 23 | bpGkMTU/OVC9U1HcaCuwrqBUObuZac/ANAHcldGoq2P3D1JRrHDgxu+qRN4Tn1N3 24 | 6Zmvd6wMT16/YWOrbDmBAoIBAQDLfqgegWHxAxjehanpWO9HzvMI+E5XaAMWrg9N 25 | Rt0knKRzmFRJmHBzf14T0F38qZp0nnrfxho0bwO9Hi2lovOaESux4mO8xvX5cal0 26 | 
K7toZptqHF0SnNtcGMohh3RYR2yydTr3pW+3g9GxNXBOEjLgcmmmQjy7ddW+nnTm 27 | vaSG8psOmIoDi5zryr0FAXlhujPXMw/ANVN47J6Z3T23g8ZSsGoCZTb3m+0Yr5FO 28 | BPeZaqxUOV175J0E7lVbpmu1mePaKhxmP4HyyTmEzKB03f9QiDtNEqAGrfeiKu5N 29 | YVGZKfD+D1ESXBYQspn9SzgSpaosmVUxwZlT3N3UGj+90+TBAoIBAQDBDnExGU+Q 30 | 5IJ66Sb68CWvWs88aoX0+1ySyvv1QLHeO0s6h4Oudl8BpCP6mNz6Ah3TI0nPNPwd 31 | RphagWkNx2iAMY8tlygfqeLqxvQQa4O13YdduNxvKn3JFQT95zmsTfbXUet4NUDb 32 | BHU1ELKQRvSDPcKI5HwD088cORDeNHbhaf4w0exyg27HywqaiEiugWEF+HTjICvS 33 | j7UHu2SJZx+PmFouZh4VA3y83XM1aPe1kpvsOam2U8BLVuAh6WC2YRd/D0UC8KBB 34 | tL+I/1Wl3/Te4nD8/E/ktBee86qJseWJj295YqnNOxwdIHCItFj9zzrxRHJC8wrA 35 | Tzi3cNJc4r2/AoIBAQCIDBKL6oY3YpGy23wARPQcdxBm89M0fpZqCE60pGbevlb3 36 | 7WSUm3S60vFrn8fmTuGzi+ysRL5qRbojKtTrwyjH3SjwUWHK6N20OjhHMaGmp8rJ 37 | w1K6SeIiDXS9u9id2IqkONoGjTVGZLvBlO+TUnlvMy7M98WwspQHT9rqFOm1Suc9 38 | d9/1hNaRkcaTXSvwmVNlUX77SuuWkeNrDM3hLhleWRFWrqJ0Imv+MAqeNZXvnLC/ 39 | 0mZcmTgc/tZUYsvp6ou55KN9/IF9duicj03LNEwoZBv5aDVSoeZIJhmR5Dlwg+jZ 40 | ghX6h1Q9L5riC/LeDKHcFVsu27cNqUEpN69b0xlBAoIBADo+XS/u7u5LwoHKbZQv 41 | d28b+oHDsX5jh15SFwm65u6g/OU/lR7BX5BjMOedzq0ujkjw0IfO+HDsp3JGsKcT 42 | jhd+3C9o9xX2bxtdwqxhg28O0pQX/YkcTK6pxMPFSsUNEHeNo3i0uEhbY/EKhJS6 43 | k3I048fhBvkwob9mCAzBz0vaanHYI3m45Wcpfp14mFTte5QNjVYokpAIAxm+E2rQ 44 | zdjIogx1ioXUc4GXXfazIGiLPrdZ/jWfttgD8cLJYgAj8q7GsI9egTrRiSePwQs2 45 | Me48atIoXQImwymVYdIA9bs2pu78MTZVqvmum8TihCauqp23hLFmGcxDGl1dkFmt 46 | nokCggEAAyrr2t5WTz/mmVIOTl8PIq265m1ReULQkrheOdNq81tOohb/RAGo2uvO 47 | eL4gPPF4nhLA1ucp9N8T5CjBhhv5sVBMBE4Nza2eXUNM5YXidPZXOah3CBtQclSd 48 | 2VKFporZ9yK8d9o+iWNKhF4lMwpBj1S0c04HekxWcMXtJAhuOyv80MY1cgd/Y/qg 49 | S6kO/ueQOam3J0mANKVkko7wj74RnXeg/Mc81ndS1r/eMIh52OZA/rI/HdbS1TLl 50 | WnkZul1kaZ2xlgZTtPytDwFcF3deP6WPqly5y758YkW++HlO/aCKjcqAVvwOSbaV 51 | 7pmdM9jmyD/ahamb2nOk1CzpWmz0Ww== 52 | -----END PRIVATE KEY----- 53 | -------------------------------------------------------------------------------- /notifications/ephemeral_test.go: 
-------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "fmt" 12 | jww "github.com/spf13/jwalterweatherman" 13 | "gitlab.com/elixxir/notifications-bot/constants" 14 | "gitlab.com/elixxir/notifications-bot/storage" 15 | "gitlab.com/xx_network/comms/connect" 16 | "gitlab.com/xx_network/primitives/id" 17 | "gitlab.com/xx_network/primitives/id/ephemeral" 18 | "os" 19 | "testing" 20 | "time" 21 | ) 22 | 23 | func TestMain(m *testing.M) { 24 | jww.SetStdoutThreshold(jww.LevelTrace) 25 | connect.TestingOnlyDisableTLS = true 26 | os.Exit(m.Run()) 27 | } 28 | 29 | func TestImpl_InitDeleter(t *testing.T) { 30 | s, err := storage.NewStorage("", "", "", "", "") 31 | if err != nil { 32 | t.Errorf("Failed to init storage: %+v", err) 33 | } 34 | impl := &Impl{ 35 | Storage: s, 36 | } 37 | uid := id.NewIdFromString("deleter_zezima", id.User, t) 38 | iid, err := ephemeral.GetIntermediaryId(uid) 39 | if err != nil { 40 | t.Errorf("Failed to get intermediary ephemeral id: %+v", err) 41 | } 42 | if err != nil { 43 | t.Fatalf("Could not parse precanned time: %v", err.Error()) 44 | } 45 | _, epoch := ephemeral.HandleQuantization(time.Now().Add(-30 * time.Hour)) 46 | _, err = s.RegisterForNotifications(iid, []byte("trsa"), "token", constants.MessengerIOS.String(), epoch, 16) 47 | if err != nil { 48 | t.Errorf("Failed to add user to storage: %+v", err) 49 | } 50 | 51 | e, err := s.GetLatestEphemeral() 52 | elist, err := s.GetEphemeral(e.EphemeralId) 53 | if err != nil { 54 | t.Errorf("Failed to get latest ephemeral for user: %+v", err) 55 | } 56 | if elist == nil { 57 | t.Error("Did not 
receive ephemeral for user") 58 | } 59 | impl.initDeleter() 60 | time.Sleep(time.Second * 10) 61 | elist, err = s.GetEphemeral(e.EphemeralId) 62 | if err == nil { 63 | t.Errorf("Ephemeral should have been deleted, did not receive error: %+v", e) 64 | } 65 | } 66 | 67 | func TestImpl_InitCreator(t *testing.T) { 68 | s, err := storage.NewStorage("", "", "", "", "") 69 | if err != nil { 70 | t.Errorf("Failed to init storage: %+v", err) 71 | t.FailNow() 72 | } 73 | impl, err := StartNotifications(Params{ 74 | NotificationsPerBatch: 20, 75 | NotificationRate: 30, 76 | Address: "", 77 | CertPath: "", 78 | KeyPath: "", 79 | FBCreds: "", 80 | }, true, true) 81 | if err != nil { 82 | t.Errorf("Failed to create impl: %+v", err) 83 | t.FailNow() 84 | } 85 | impl.Storage = s 86 | uid := id.NewIdFromString("zezima", id.User, t) 87 | iid, err := ephemeral.GetIntermediaryId(uid) 88 | if err != nil { 89 | t.Errorf("Failed to get intermediary ephemeral id: %+v", err) 90 | } 91 | if err != nil { 92 | t.Errorf("Could not parse precanned time: %v", err.Error()) 93 | } 94 | _, epoch := ephemeral.HandleQuantization(time.Now()) 95 | u, err := s.RegisterForNotifications(iid, []byte("trsa"), "token", constants.MessengerIOS.String(), epoch, 16) 96 | if err != nil { 97 | t.Errorf("Failed to add user to storage: %+v", err) 98 | } 99 | 100 | u, err = s.GetUser(u.TransmissionRSAHash) 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | fmt.Println(u.Identities[0].OffsetNum) 105 | impl.initCreator() 106 | e, err := s.GetLatestEphemeral() 107 | if err != nil { 108 | t.Errorf("Failed to get latest ephemeral: %+v", err) 109 | } 110 | if e == nil { 111 | t.Error("Did not receive ephemeral for user") 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /notifications/providers/apns.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 
2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package providers 9 | 10 | import ( 11 | "encoding/base64" 12 | "github.com/pkg/errors" 13 | "github.com/sideshow/apns2" 14 | "github.com/sideshow/apns2/payload" 15 | apnstoken "github.com/sideshow/apns2/token" 16 | jww "github.com/spf13/jwalterweatherman" 17 | "gitlab.com/elixxir/notifications-bot/constants" 18 | "gitlab.com/elixxir/notifications-bot/storage" 19 | "time" 20 | ) 21 | 22 | // APNSParams holds config info specific to apple's push notification service 23 | type APNSParams struct { 24 | KeyPath string 25 | KeyID string 26 | Issuer string 27 | BundleID string 28 | Dev bool 29 | } 30 | 31 | // apns struct represents an APNS provider 32 | type apns struct { 33 | *apns2.Client 34 | topic string 35 | } 36 | 37 | // NewApns returns an APNS-backed provider interface. 
38 | func NewApns(params APNSParams) (Provider, error) { 39 | var apnsClient *apns2.Client 40 | if params.KeyID == "" || params.Issuer == "" || params.BundleID == "" { 41 | return nil, errors.Errorf("APNS not properly configured: %+v", params) 42 | } 43 | 44 | jww.INFO.Printf("Initializing APNS provider for %s (%s) with key ID %s", params.BundleID, params.Issuer, params.KeyID) 45 | if params.Dev { 46 | jww.WARN.Printf("APNS provider for %s running in dev mode", params.BundleID) 47 | } 48 | 49 | authKey, err := apnstoken.AuthKeyFromFile(params.KeyPath) 50 | if err != nil { 51 | return nil, errors.WithMessage(err, "Failed to load auth key from file") 52 | } 53 | token := &apnstoken.Token{ 54 | AuthKey: authKey, 55 | // KeyID from developer account (Certificates, Identifiers & Profiles -> Keys) 56 | KeyID: params.KeyID, 57 | // TeamID from developer account (View Account -> Membership) 58 | TeamID: params.Issuer, 59 | } 60 | apnsClient = apns2.NewTokenClient(token) 61 | if params.Dev { 62 | jww.INFO.Printf("Running with dev apns gateway") 63 | apnsClient.Development() 64 | } else { 65 | apnsClient.Production() 66 | } 67 | 68 | return &apns{ 69 | Client: apnsClient, 70 | topic: params.BundleID, 71 | }, nil 72 | } 73 | 74 | // Notify implements the Provider interface for APNS, sending the notifications to the provider. 
75 | func (a *apns) Notify(csv string, target storage.GTNResult) (bool, error) { 76 | notifPayload := payload.NewPayload().AlertTitle(constants.NotificationTitle).AlertBody( 77 | constants.NotificationBody).MutableContent().Custom( 78 | constants.NotificationsTag, csv) 79 | notif := &apns2.Notification{ 80 | CollapseID: base64.StdEncoding.EncodeToString(target.TransmissionRSAHash), 81 | DeviceToken: target.Token, 82 | Expiration: time.Now().Add(time.Hour * 24 * 7), 83 | Priority: apns2.PriorityHigh, 84 | Payload: notifPayload, 85 | PushType: apns2.PushTypeAlert, 86 | Topic: a.topic, 87 | } 88 | resp, err := a.Client.Push(notif) 89 | if err != nil { 90 | return true, errors.WithMessagef(err, "Failed to send notification via APNS: %+v", resp) 91 | // TODO : Should be re-enabled for specific error cases? deep dive on apns docs may be helpful 92 | //err := db.DeleteUserByHash(u.TransmissionRSAHash) 93 | //if err != nil { 94 | // return errors.WithMessagef(err, "Failed to remove user registration tRSA hash: %+v", u.TransmissionRSAHash) 95 | //} 96 | } 97 | jww.DEBUG.Printf("Notified ephemeral ID %+v [%+v] via APNS and received response %+v", target.EphemeralId, target.Token, resp) 98 | return true, nil 99 | } 100 | 101 | func (a *apns) GetTopic() string { 102 | return a.topic 103 | } 104 | -------------------------------------------------------------------------------- /storage/buffer.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | jww "github.com/spf13/jwalterweatherman" 12 | "gitlab.com/elixxir/primitives/notifications" 13 | "gitlab.com/xx_network/primitives/id" 14 | "sync" 15 | "sync/atomic" 16 | ) 17 | 18 | // NotificationBuffer struct holds notifications received by the bot that have yet to be sent 19 | // IT uses a sync.Map with a RWMutex to allow swapping maps for faster concurrent read/write access 20 | // Stores lowest and highest rounds to provide ordering when queried 21 | type NotificationBuffer struct { 22 | lock sync.RWMutex 23 | gr *uint64 24 | lr *uint64 25 | bufMap *sync.Map 26 | } 27 | 28 | // NewNotificationBuffer is the constructor for NotificationBuffers. Initializes maps & sets initial atomic values 29 | func NewNotificationBuffer() *NotificationBuffer { 30 | gr, lr := uint64(0), uint64(0) 31 | 32 | nb := &NotificationBuffer{ 33 | bufMap: &sync.Map{}, 34 | gr: &gr, 35 | lr: &lr, 36 | } 37 | return nb 38 | } 39 | 40 | // Swap takes the write lock, replaces the current sync.Map and round values with new ones, and 41 | // (outside the lock) sorts the old map into a map[ephID][]*notifications.Data, where each ephID list is sorted by RID 42 | // NOTE THAT ANY UNSENT NOTIFICATIONS FROM SWAP MUST BE RE-ADDED TO THE BUFFER 43 | func (bnm *NotificationBuffer) Swap() map[int64][]*notifications.Data { 44 | bnm.lock.Lock() 45 | 46 | // Swap map & reset greatest and least rounds 47 | var m *sync.Map 48 | m, bnm.bufMap = bnm.bufMap, &sync.Map{} 49 | lr := atomic.SwapUint64(bnm.lr, 0) 50 | gr := atomic.SwapUint64(bnm.gr, 0) 51 | 52 | bnm.lock.Unlock() 53 | 54 | outMap := make(map[int64][]*notifications.Data) 55 | 56 | // Function originally used to range over sync.Map, now called in order by RID on entries using gr and lr 57 | f := func(key, value interface{}) bool { 58 | l := value.([]*notifications.Data) 59 | for _, n := range l { 60 | nSlice, exists 
:= outMap[n.EphemeralID] 61 | if exists { 62 | nSlice = append(nSlice, n) 63 | } else { 64 | nSlice = []*notifications.Data{n} 65 | } 66 | outMap[n.EphemeralID] = nSlice 67 | } 68 | return true 69 | } 70 | 71 | // Iterate through seen rounds from least to greatest 72 | for i := lr; i <= gr; i++ { 73 | rid := id.Round(i) 74 | nlist, ok := m.Load(rid) 75 | if !ok { 76 | jww.DEBUG.Printf("No notification data for round %+v", rid) 77 | continue 78 | } 79 | f(rid, nlist) 80 | } 81 | 82 | return outMap 83 | } 84 | 85 | // Add accepts a list of notification data and an associated round ID 86 | // The list will be inserted to the current sync.Map under the given round ID 87 | // NOTE: THIS WILL OVERWRITE, SHOULD BE CALLED ONCE PER ROUND, OR AGAIN TO REPLACE OVERFLOW NOTIFICATIONS 88 | func (bnm *NotificationBuffer) Add(rid id.Round, l []*notifications.Data) { 89 | bnm.lock.RLock() 90 | defer bnm.lock.RUnlock() 91 | 92 | // Update stored round IDs 93 | bnm.updateRIDs(rid) 94 | 95 | // Store data for round 96 | bnm.bufMap.Store(rid, l) 97 | } 98 | 99 | func (bnm *NotificationBuffer) updateRIDs(rid id.Round) { 100 | flop := false 101 | for flop == false { 102 | gr := atomic.LoadUint64(bnm.gr) 103 | if gr < uint64(rid) || gr == 0 { 104 | flop = atomic.CompareAndSwapUint64(bnm.gr, gr, uint64(rid)) 105 | } else { 106 | flop = true 107 | } 108 | } 109 | flop = false 110 | 111 | for flop == false { 112 | lr := atomic.LoadUint64(bnm.lr) 113 | if lr > uint64(rid) || lr == 0 { 114 | flop = atomic.CompareAndSwapUint64(bnm.lr, lr, uint64(rid)) 115 | } else { 116 | flop = true 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /notifications/send.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can 
be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "github.com/pkg/errors" 12 | jww "github.com/spf13/jwalterweatherman" 13 | "gitlab.com/elixxir/notifications-bot/storage" 14 | "gitlab.com/elixxir/primitives/notifications" 15 | "gitlab.com/xx_network/primitives/id" 16 | "time" 17 | ) 18 | 19 | const notificationsTag = "notificationData" 20 | 21 | // Sender is a long-running thread which sends out received notifications to 22 | // the appropriate providers every sendFreq seconds. 23 | func (nb *Impl) Sender(sendFreq int) { 24 | sendTicker := time.NewTicker(time.Duration(sendFreq) * time.Second) 25 | for { 26 | select { 27 | case <-sendTicker.C: 28 | go func() { 29 | // Retreive & swap notification buffer 30 | notifBuf := nb.Storage.GetNotificationBuffer() 31 | notifMap := notifBuf.Swap() 32 | 33 | if len(notifMap) == 0 { 34 | return 35 | } 36 | 37 | unsent := map[uint64][]*notifications.Data{} 38 | rest, err := nb.SendBatch(notifMap) 39 | if err != nil { 40 | jww.ERROR.Printf("Failed to send notification batch: %+v", err) 41 | // If we fail to run SendBatch, put everything back in unsent 42 | for _, elist := range notifMap { 43 | for _, n := range elist { 44 | unsent[n.RoundID] = append(unsent[n.RoundID], n) 45 | } 46 | } 47 | } else { 48 | // Loop through rest and add to unsent map 49 | for _, n := range rest { 50 | unsent[n.RoundID] = append(unsent[n.RoundID], n) 51 | } 52 | } 53 | // Re-add unsent notifications to the buffer 54 | for rid, nd := range unsent { 55 | notifBuf.Add(id.Round(rid), nd) 56 | } 57 | }() 58 | } 59 | } 60 | } 61 | 62 | // SendBatch accepts the map of ephemeralID:list[notifications.Data] 63 | // It handles logic for building the CSV & sending to devices 64 | func (nb *Impl) SendBatch(data map[int64][]*notifications.Data) ([]*notifications.Data, error) { 65 | csvs := map[int64]string{} 66 | var ephemerals []int64 
67 | var unsent []*notifications.Data 68 | jww.INFO.Printf("data: %+v", data) 69 | for i, ilist := range data { 70 | var overflow, toSend []*notifications.Data 71 | if len(ilist) > nb.maxNotifications { 72 | overflow = ilist[nb.maxNotifications:] 73 | toSend = ilist[:nb.maxNotifications] 74 | } else { 75 | toSend = ilist[:] 76 | } 77 | 78 | notifs, rest := notifications.BuildNotificationCSV(toSend, nb.maxPayloadBytes-len([]byte(notificationsTag))) 79 | overflow = append(overflow, rest...) 80 | csvs[i] = string(notifs) 81 | ephemerals = append(ephemerals, i) 82 | unsent = append(unsent, overflow...) 83 | } 84 | toNotify, err := nb.Storage.GetToNotify(ephemerals) 85 | if err != nil { 86 | return nil, errors.WithMessage(err, "Failed to get list of tokens to notify") 87 | } 88 | for i := range toNotify { 89 | go func(res storage.GTNResult) { 90 | nb.notify(csvs[res.EphemeralId], res) 91 | }(toNotify[i]) 92 | } 93 | return unsent, nil 94 | } 95 | 96 | // notify is a helper function which handles sending notifications to either APNS or firebase 97 | func (nb *Impl) notify(csv string, toNotify storage.GTNResult) { 98 | provider, ok := nb.providers[toNotify.App] 99 | if !ok { 100 | jww.ERROR.Printf("Could not find provider for app %s", toNotify.App) 101 | return 102 | } 103 | tokenValid, err := provider.Notify(csv, toNotify) 104 | if err != nil { 105 | jww.ERROR.Println(err) 106 | if !tokenValid { 107 | jww.DEBUG.Printf("User with tRSA hash %+v has invalid token [%+v] for app %s - attempting to remove", toNotify.TransmissionRSAHash, toNotify.Token, toNotify.App) 108 | err := nb.Storage.DeleteToken(toNotify.Token) 109 | if err != nil { 110 | jww.ERROR.Printf("Failed to remove %s token registration tRSA hash %+v: %+v", toNotify.App, toNotify.TransmissionRSAHash, err) 111 | } 112 | } 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 
| module gitlab.com/elixxir/notifications-bot 2 | 3 | go 1.19 4 | 5 | require ( 6 | firebase.google.com/go v3.12.0+incompatible 7 | github.com/pkg/errors v0.9.1 8 | github.com/sideshow/apns2 v0.20.0 9 | github.com/spf13/cobra v1.5.0 10 | github.com/spf13/jwalterweatherman v1.1.0 11 | github.com/spf13/viper v1.12.0 12 | gitlab.com/elixxir/comms v0.0.4-0.20230608201134-3cac2b04fb52 13 | gitlab.com/elixxir/crypto v0.0.7-0.20230519213156-886b0387c218 14 | gitlab.com/elixxir/primitives v0.0.3-0.20230214180039-9a25e2d3969c 15 | gitlab.com/xx_network/comms v0.0.4-0.20230214180029-5387fb85736d 16 | gitlab.com/xx_network/crypto v0.0.5-0.20230214003943-8a09396e95dd 17 | gitlab.com/xx_network/primitives v0.0.4-0.20230310205521-c440e68e34c4 18 | google.golang.org/api v0.103.0 19 | gorm.io/driver/postgres v1.5.0 20 | gorm.io/driver/sqlite v1.4.4 21 | gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11 22 | ) 23 | 24 | require ( 25 | cloud.google.com/go v0.107.0 // indirect 26 | cloud.google.com/go/compute v1.13.0 // indirect 27 | cloud.google.com/go/compute/metadata v0.2.1 // indirect 28 | cloud.google.com/go/firestore v1.9.0 // indirect 29 | cloud.google.com/go/iam v0.10.0 // indirect 30 | cloud.google.com/go/longrunning v0.3.0 // indirect 31 | cloud.google.com/go/storage v1.27.0 // indirect 32 | git.xx.network/elixxir/grpc-web-go-client v0.0.0-20230214175953-5b5a8c33d28a // indirect 33 | github.com/cenkalti/backoff/v4 v4.1.3 // indirect 34 | github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect 35 | github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect 36 | github.com/elliotchance/orderedmap v1.4.0 // indirect 37 | github.com/fsnotify/fsnotify v1.5.4 // indirect 38 | github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 // indirect 39 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 40 | github.com/golang/protobuf v1.5.2 // indirect 41 | github.com/google/go-cmp v0.5.9 // indirect 42 | 
github.com/google/uuid v1.3.0 // indirect 43 | github.com/googleapis/enterprise-certificate-proxy v0.2.0 // indirect 44 | github.com/googleapis/gax-go/v2 v2.7.0 // indirect 45 | github.com/gorilla/websocket v1.5.0 // indirect 46 | github.com/hashicorp/hcl v1.0.0 // indirect 47 | github.com/improbable-eng/grpc-web v0.15.0 // indirect 48 | github.com/inconshreveable/mousetrap v1.0.0 // indirect 49 | github.com/jackc/pgpassfile v1.0.0 // indirect 50 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 51 | github.com/jackc/pgx/v5 v5.3.0 // indirect 52 | github.com/jinzhu/inflection v1.0.0 // indirect 53 | github.com/jinzhu/now v1.1.5 // indirect 54 | github.com/klauspost/compress v1.15.9 // indirect 55 | github.com/klauspost/cpuid/v2 v2.1.0 // indirect 56 | github.com/magiconair/properties v1.8.6 // indirect 57 | github.com/mattn/go-sqlite3 v1.14.15 // indirect 58 | github.com/mitchellh/go-homedir v1.1.0 // indirect 59 | github.com/mitchellh/mapstructure v1.5.0 // indirect 60 | github.com/pelletier/go-toml v1.9.5 // indirect 61 | github.com/pelletier/go-toml/v2 v2.0.2 // indirect 62 | github.com/rs/cors v1.8.2 // indirect 63 | github.com/soheilhy/cmux v0.1.5 // indirect 64 | github.com/spf13/afero v1.9.2 // indirect 65 | github.com/spf13/cast v1.5.0 // indirect 66 | github.com/spf13/pflag v1.0.5 // indirect 67 | github.com/subosito/gotenv v1.4.0 // indirect 68 | github.com/zeebo/blake3 v0.2.3 // indirect 69 | gitlab.com/xx_network/ring v0.0.3-0.20220902183151-a7d3b15bc981 // indirect 70 | go.opencensus.io v0.24.0 // indirect 71 | go.uber.org/atomic v1.10.0 // indirect 72 | golang.org/x/crypto v0.6.0 // indirect 73 | golang.org/x/net v0.6.0 // indirect 74 | golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect 75 | golang.org/x/sync v0.1.0 // indirect 76 | golang.org/x/sys v0.5.0 // indirect 77 | golang.org/x/text v0.7.0 // indirect 78 | golang.org/x/time v0.1.0 // indirect 79 | golang.org/x/xerrors 
v0.0.0-20220907171357-04be3eba64a2 // indirect 80 | google.golang.org/appengine v1.6.7 // indirect 81 | google.golang.org/genproto v0.0.0-20221205194025-8222ab48f5fc // indirect 82 | google.golang.org/grpc v1.51.0 // indirect 83 | google.golang.org/protobuf v1.28.1 // indirect 84 | gopkg.in/ini.v1 v1.66.6 // indirect 85 | gopkg.in/yaml.v2 v2.4.0 // indirect 86 | gopkg.in/yaml.v3 v3.0.1 // indirect 87 | nhooyr.io/websocket v1.8.7 // indirect 88 | src.agwa.name/tlshacks v0.0.0-20220518131152-d2c6f4e2b780 // indirect 89 | ) 90 | -------------------------------------------------------------------------------- /notifications/ndf.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // ndf controls gateway updates from the permissioning server 9 | 10 | package notifications 11 | 12 | import ( 13 | //"github.com/pkg/errors" 14 | "bytes" 15 | "github.com/pkg/errors" 16 | jww "github.com/spf13/jwalterweatherman" 17 | pb "gitlab.com/elixxir/comms/mixmessages" 18 | "gitlab.com/elixxir/notifications-bot/io" 19 | "sync/atomic" 20 | "time" 21 | ) 22 | 23 | // Stopper function that stops the thread on a timeout 24 | type Stopper func(timeout time.Duration) bool 25 | 26 | // GatewaysChanged function processes the gateways changed event when detected 27 | // in the NDF 28 | type GatewaysChanged func(ndf pb.NDF) ([]byte, error) 29 | 30 | // TrackNdf kicks off the ndf tracking thread 31 | func (nb *Impl) TrackNdf() { 32 | // Handler function for the gateways changed event 33 | gatewayEventHandler := func(ndf pb.NDF) ([]byte, error) { 34 | jww.DEBUG.Printf("Updating Gateways with new NDF") 35 | // TODO: If this returns an 
error, print that error if it occurs 36 | err := nb.inst.UpdatePartialNdf(&ndf) 37 | if err != nil { 38 | return nil, errors.WithMessage(err, "Failed to update partial NDF") 39 | } 40 | err = nb.inst.UpdateGatewayConnections() 41 | if err != nil { 42 | return nil, errors.WithMessage(err, "Failed to update gateway connections") 43 | } 44 | atomic.SwapUint32(nb.receivedNdf, 1) 45 | return nb.inst.GetPartialNdf().GetHash(), nil 46 | } 47 | 48 | // Stopping function for the thread 49 | quitCh := make(chan bool) 50 | nb.ndfStopper = func(timeout time.Duration) bool { 51 | select { 52 | case quitCh <- true: 53 | return true 54 | case <-time.After(timeout): 55 | jww.ERROR.Printf("Could not stop NDF Tracking Thread") 56 | return false 57 | } 58 | } 59 | 60 | // Polling object 61 | permHost, _ := nb.Comms.GetHost(nb.inst.GetPermissioningId()) 62 | poller := io.NewNdfPoller(nb.Comms, permHost) 63 | 64 | go trackNdf(poller, quitCh, gatewayEventHandler) 65 | } 66 | 67 | func trackNdf(poller io.PollingConn, quitCh chan bool, gwEvt GatewaysChanged) { 68 | pollDelay := 1 * time.Second 69 | hashCh := make(chan []byte, 1) 70 | lastNdf := pb.NDF{Ndf: []byte{}} 71 | lastNdfHash := []byte{} 72 | for { 73 | jww.TRACE.Printf("Polling for NDF") 74 | 75 | // FIXME: This is mildly hacky because we rely on the call back 76 | // to return the ndf hash right now. 77 | select { 78 | case newHash := <-hashCh: 79 | lastNdfHash = newHash 80 | default: 81 | break 82 | } 83 | 84 | ndf, err := poller.PollNdf(lastNdfHash) 85 | if err != nil { 86 | jww.ERROR.Printf("polling ndf: %+v", err) 87 | ndf = nil 88 | } 89 | 90 | // If the cur differs from the last one, trigger the update 91 | // event 92 | // TODO: Improve this to only trigger when gatways are updated 93 | // this isn't useful right now because gw event handlers 94 | // actually update the full ndf each time, so it's a 95 | // choice between comparing the full hash or additional 96 | // network traffic given the current state of API. 
97 | if ndf != nil && len(ndf.Ndf) > 0 && !bytes.Equal(ndf.Ndf, lastNdf.Ndf) { 98 | // FIXME: we should be able to get hash from the ndf 99 | // object, but we can't. 100 | go func() { 101 | h, err := gwEvt(*ndf) 102 | if err != nil { 103 | jww.ERROR.Println(err) 104 | return 105 | } 106 | hashCh <- h 107 | }() 108 | lastNdf = *ndf 109 | } 110 | 111 | select { 112 | case <-quitCh: 113 | jww.DEBUG.Printf("Exiting trackNDF thread...") 114 | return 115 | case <-time.After(pollDelay): 116 | continue 117 | } 118 | } 119 | } 120 | 121 | func (nb *Impl) ReceivedNdf() *uint32 { 122 | return nb.receivedNdf 123 | } 124 | 125 | func (nb *Impl) Cleaner() { 126 | cleanF := func(key, val interface{}) bool { 127 | t := val.(time.Time) 128 | if time.Since(t) > (5 * time.Minute) { 129 | nb.roundStore.Delete(key) 130 | } 131 | return true 132 | } 133 | 134 | cleanTicker := time.NewTicker(time.Minute * 10) 135 | 136 | for { 137 | select { 138 | case <-cleanTicker.C: 139 | nb.roundStore.Range(cleanF) 140 | } 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /notifications/legacyRegistration.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "github.com/pkg/errors" 12 | pb "gitlab.com/elixxir/comms/mixmessages" 13 | "gitlab.com/elixxir/crypto/hash" 14 | "gitlab.com/elixxir/crypto/registration" 15 | "gitlab.com/elixxir/notifications-bot/constants" 16 | "gitlab.com/xx_network/crypto/signature/rsa" 17 | "gitlab.com/xx_network/primitives/id" 18 | "gitlab.com/xx_network/primitives/id/ephemeral" 19 | "strings" 20 | "time" 21 | ) 22 | 23 | // RegisterForNotifications is called by the client, and adds a user registration to our database. It verifies the permissioning signature over the transmission RSA key and the client's signature over the intermediary ID before storing anything. 24 | func (nb *Impl) RegisterForNotifications(request *pb.NotificationRegisterRequest) error { 25 | var err error 26 | // Check auth & inputs 27 | if request.Token == "" { 28 | return errors.New("Cannot register for notifications with empty client token") 29 | } 30 | 31 | // Verify permissioning RSA signature 32 | permHost, ok := nb.Comms.GetHost(&id.Permissioning) 33 | if !ok { 34 | return errors.New("Could not find permissioning host to verify client signature") 35 | } 36 | err = registration.VerifyWithTimestamp(permHost.GetPubKey(), request.RegistrationTimestamp, 37 | string(request.TransmissionRsa), request.TransmissionRsaSig) 38 | if err != nil { 39 | return errors.WithMessage(err, "Failed to verify perm sig with timestamp") 40 | } 41 | 42 | // Verify IID transmission RSA signature 43 | h, err := hash.NewCMixHash() 44 | if err != nil { 45 | return errors.WithMessage(err, "Failed to create cmix hash") 46 | } 47 | _, err = h.Write(request.IntermediaryId) 48 | if err != nil { 49 | return errors.Wrap(err, "Failed to write intermediary id to hash") 50 | } 51 | pub, err := rsa.LoadPublicKeyFromPem(request.TransmissionRsa) 52 | if err != nil { 53 | return errors.WithMessage(err, "Failed to load public key from bytes") 54 | } 55 | err = rsa.Verify(pub, hash.CMixHash, h.Sum(nil), request.IIDTransmissionRsaSig, nil) 56 | if err != nil {
57 | return errors.Wrap(err, "Failed to verify IID signature from client") 58 | } 59 | 60 | // Add the user to storage 61 | _, epoch := ephemeral.HandleQuantization(time.Now()) 62 | // The app is inferred from the token: tokens containing ':' are registered as Android, all others as iOS 63 | var app string 64 | if strings.Contains(request.Token, ":") { 65 | app = constants.MessengerAndroid.String() 66 | } else { 67 | app = constants.MessengerIOS.String() 68 | } 69 | 70 | _, err = nb.Storage.RegisterForNotifications(request.IntermediaryId, request.TransmissionRsa, request.Token, app, epoch, nb.inst.GetPartialNdf().Get().AddressSpace[0].Size) 71 | if err != nil { 72 | return errors.Wrap(err, "Failed to register user with notifications") 73 | } 74 | 75 | return nil 76 | } 77 | 78 | // UnregisterForNotifications is called by the client, and removes a user registration from our database. The request's IID signature must verify against the transmission RSA key of the single user registered under that IID. 79 | func (nb *Impl) UnregisterForNotifications(request *pb.NotificationUnregisterRequest) error { 80 | h, err := hash.NewCMixHash() 81 | if err != nil { 82 | return errors.WithMessage(err, "Failed to create cmix hash") 83 | } 84 | _, err = h.Write(request.IntermediaryId) 85 | if err != nil { 86 | return errors.WithMessage(err, "Failed to write intermediary id to hash") 87 | } 88 | 89 | ident, err := nb.Storage.GetIdentity(request.IntermediaryId) 90 | if err != nil { 91 | return errors.WithMessagef(err, "Failed to find user with intermediary ID %+v", request.IntermediaryId) 92 | } 93 | 94 | // Get the user by identity 95 | // Legacy unregistration only supports identities with exactly one registered user 96 | if len(ident.Users) != 1 { 97 | return errors.Errorf("Cannot legacy unregister an IID with %d active users, exactly one is required", len(ident.Users)) 98 | } 99 | u := ident.Users[0] 100 | 101 | pub, err := rsa.LoadPublicKeyFromPem(u.TransmissionRSA) 102 | if err != nil { 103 | return errors.WithMessage(err, "Failed to load public key from database") 104 | } 105 | err = rsa.Verify(pub, hash.CMixHash, h.Sum(nil), request.IIDTransmissionRsaSig, nil) 106 | if err != nil { 107 | return errors.Wrap(err, "Failed to verify IID signature from
client") 108 | } 109 | err = nb.Storage.LegacyUnregister(request.IntermediaryId) 110 | if err != nil { 111 | return errors.Wrap(err, "Failed to unregister user with notifications") 112 | } 113 | return nil 114 | } 115 | -------------------------------------------------------------------------------- /cmd/version_vars.go: -------------------------------------------------------------------------------- 1 | // Code generated by go generate; DO NOT EDIT. 2 | // This file was generated by robots at 3 | // 2023-07-26 15:03:38.041301298 -0700 PDT m=+0.007681005 4 | 5 | package cmd 6 | 7 | const GITVERSION = `ef4a53c Merge branch 'release' of git.xx.network:elixxir/notifications-bot into project/HavenBeta` 8 | const SEMVER = "3.0.0" 9 | const DEPENDENCIES = `module gitlab.com/elixxir/notifications-bot 10 | 11 | go 1.19 12 | 13 | require ( 14 | firebase.google.com/go v3.12.0+incompatible 15 | github.com/pkg/errors v0.9.1 16 | github.com/sideshow/apns2 v0.20.0 17 | github.com/spf13/cobra v1.5.0 18 | github.com/spf13/jwalterweatherman v1.1.0 19 | github.com/spf13/viper v1.12.0 20 | gitlab.com/elixxir/comms v0.0.4-0.20230608201134-3cac2b04fb52 21 | gitlab.com/elixxir/crypto v0.0.7-0.20230519213156-886b0387c218 22 | gitlab.com/elixxir/primitives v0.0.3-0.20230214180039-9a25e2d3969c 23 | gitlab.com/xx_network/comms v0.0.4-0.20230214180029-5387fb85736d 24 | gitlab.com/xx_network/crypto v0.0.5-0.20230214003943-8a09396e95dd 25 | gitlab.com/xx_network/primitives v0.0.4-0.20230310205521-c440e68e34c4 26 | google.golang.org/api v0.103.0 27 | gorm.io/driver/postgres v1.5.0 28 | gorm.io/driver/sqlite v1.4.4 29 | gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11 30 | ) 31 | 32 | require ( 33 | cloud.google.com/go v0.107.0 // indirect 34 | cloud.google.com/go/compute v1.13.0 // indirect 35 | cloud.google.com/go/compute/metadata v0.2.1 // indirect 36 | cloud.google.com/go/firestore v1.9.0 // indirect 37 | cloud.google.com/go/iam v0.10.0 // indirect 38 | cloud.google.com/go/longrunning
v0.3.0 // indirect 39 | cloud.google.com/go/storage v1.27.0 // indirect 40 | git.xx.network/elixxir/grpc-web-go-client v0.0.0-20230214175953-5b5a8c33d28a // indirect 41 | github.com/cenkalti/backoff/v4 v4.1.3 // indirect 42 | github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect 43 | github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect 44 | github.com/elliotchance/orderedmap v1.4.0 // indirect 45 | github.com/fsnotify/fsnotify v1.5.4 // indirect 46 | github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 // indirect 47 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 48 | github.com/golang/protobuf v1.5.2 // indirect 49 | github.com/google/go-cmp v0.5.9 // indirect 50 | github.com/google/uuid v1.3.0 // indirect 51 | github.com/googleapis/enterprise-certificate-proxy v0.2.0 // indirect 52 | github.com/googleapis/gax-go/v2 v2.7.0 // indirect 53 | github.com/gorilla/websocket v1.5.0 // indirect 54 | github.com/hashicorp/hcl v1.0.0 // indirect 55 | github.com/improbable-eng/grpc-web v0.15.0 // indirect 56 | github.com/inconshreveable/mousetrap v1.0.0 // indirect 57 | github.com/jackc/pgpassfile v1.0.0 // indirect 58 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 59 | github.com/jackc/pgx/v5 v5.3.0 // indirect 60 | github.com/jinzhu/inflection v1.0.0 // indirect 61 | github.com/jinzhu/now v1.1.5 // indirect 62 | github.com/klauspost/compress v1.15.9 // indirect 63 | github.com/klauspost/cpuid/v2 v2.1.0 // indirect 64 | github.com/magiconair/properties v1.8.6 // indirect 65 | github.com/mattn/go-sqlite3 v1.14.15 // indirect 66 | github.com/mitchellh/go-homedir v1.1.0 // indirect 67 | github.com/mitchellh/mapstructure v1.5.0 // indirect 68 | github.com/pelletier/go-toml v1.9.5 // indirect 69 | github.com/pelletier/go-toml/v2 v2.0.2 // indirect 70 | github.com/rs/cors v1.8.2 // indirect 71 | github.com/soheilhy/cmux v0.1.5 // indirect 72 | github.com/spf13/afero
v1.9.2 // indirect 73 | github.com/spf13/cast v1.5.0 // indirect 74 | github.com/spf13/pflag v1.0.5 // indirect 75 | github.com/subosito/gotenv v1.4.0 // indirect 76 | github.com/zeebo/blake3 v0.2.3 // indirect 77 | gitlab.com/xx_network/ring v0.0.3-0.20220902183151-a7d3b15bc981 // indirect 78 | go.opencensus.io v0.24.0 // indirect 79 | go.uber.org/atomic v1.10.0 // indirect 80 | golang.org/x/crypto v0.6.0 // indirect 81 | golang.org/x/net v0.6.0 // indirect 82 | golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect 83 | golang.org/x/sync v0.1.0 // indirect 84 | golang.org/x/sys v0.5.0 // indirect 85 | golang.org/x/text v0.7.0 // indirect 86 | golang.org/x/time v0.1.0 // indirect 87 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect 88 | google.golang.org/appengine v1.6.7 // indirect 89 | google.golang.org/genproto v0.0.0-20221205194025-8222ab48f5fc // indirect 90 | google.golang.org/grpc v1.51.0 // indirect 91 | google.golang.org/protobuf v1.28.1 // indirect 92 | gopkg.in/ini.v1 v1.66.6 // indirect 93 | gopkg.in/yaml.v2 v2.4.0 // indirect 94 | gopkg.in/yaml.v3 v3.0.1 // indirect 95 | nhooyr.io/websocket v1.8.7 // indirect 96 | src.agwa.name/tlshacks v0.0.0-20220518131152-d2c6f4e2b780 // indirect 97 | ) 98 | ` 99 | -------------------------------------------------------------------------------- /notifications/ephemeral.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file.
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "fmt" 12 | jww "github.com/spf13/jwalterweatherman" 13 | "gitlab.com/elixxir/notifications-bot/storage" 14 | "gitlab.com/xx_network/primitives/id/ephemeral" 15 | "strconv" 16 | "time" 17 | ) 18 | 19 | const offsetPhase = ephemeral.Period / ephemeral.NumOffsets 20 | const creationLead = 5 * time.Minute 21 | const deletionDelay = -(time.Duration(ephemeral.Period) + creationLead) 22 | const ephemeralStateKey = "lastEphemeralOffset" 23 | 24 | // EphIdCreator runs as a thread to track ephemeral IDs for users who registered to receive push notifications 25 | func (nb *Impl) EphIdCreator() { 26 | nb.initCreator() 27 | ticker := time.NewTicker(time.Duration(offsetPhase)) 28 | go nb.addEphemerals(time.Now().Add(creationLead)) 29 | //handle all future epochs 30 | for { 31 | <-ticker.C 32 | go nb.addEphemerals(time.Now().Add(creationLead)) 33 | } 34 | } 35 | 36 | func (nb *Impl) initCreator() { 37 | // Retrieve most recent ephemeral from storage 38 | var lastEpochTime time.Time 39 | lastEphEpoch, err := nb.Storage.GetStateValue(ephemeralStateKey) 40 | if err != nil { 41 | jww.WARN.Printf("Failed to get latest ephemeral: %+v", err) 42 | lastEpochTime = time.Now().Add(-time.Duration(ephemeral.Period)) 43 | } else { 44 | lastEpochInt, err := strconv.Atoi(lastEphEpoch) 45 | if err != nil { 46 | jww.FATAL.Printf("Failed to convert last epoch to int: %+v", err) 47 | } 48 | lastEpochTime = time.Unix(0, int64(lastEpochInt)*offsetPhase) // Epoch time of last ephemeral ID 49 | // If the last epoch is further back than the ephemeral ID period, only go back one period for generation 50 | if lastEpochTime.Before(time.Now().Add(-time.Duration(ephemeral.Period))) { 51 | lastEpochTime = time.Now().Add(-time.Duration(ephemeral.Period)) 52 | } 53 | } 54 | // Add all missed ephemeral IDs 55 | // increment by offsetPhase up to 5 minutes from now making
ephemerals 56 | for endTime := time.Now().Add(creationLead); lastEpochTime.Before(endTime); lastEpochTime = lastEpochTime.Add(time.Duration(offsetPhase)) { 57 | nb.addEphemerals(lastEpochTime) 58 | } 59 | // handle the next epoch 60 | _, epoch := ephemeral.HandleQuantization(lastEpochTime) 61 | nextTrigger := time.Unix(0, int64(epoch)*offsetPhase) 62 | 63 | // Check for users with no associated ephemerals, add them if found (this should not happen unless there were issues) 64 | orphaned, err := nb.Storage.GetOrphanedIdentities() 65 | if err != nil { 66 | jww.FATAL.Panicf("Failed to retrieve orphaned users: %+v", err) 67 | } 68 | if len(orphaned) > 0 { 69 | jww.WARN.Printf("Found %d orphaned users in database", len(orphaned)) 70 | } 71 | for _, i := range orphaned { 72 | _, err := nb.Storage.AddLatestEphemeral(i, epoch, uint(nb.inst.GetPartialNdf().Get().AddressSpace[0].Size)) // TODO: is this the correct epoch? Should we do the previous one as well? 73 | if err != nil { 74 | jww.WARN.Printf("Failed to add latest ephemeral for orphaned identity %+v: %+v", i.IntermediaryId, err) 75 | } 76 | } 77 | 78 | jww.INFO.Println(fmt.Sprintf("Sleeping until next trigger at %+v", nextTrigger)) 79 | time.Sleep(time.Until(nextTrigger)) 80 | } 81 | 82 | // addEphemerals adds ephemeral IDs for the offset and epoch containing start and, only if that succeeds, stores the epoch under ephemeralStateKey. 83 | func (nb *Impl) addEphemerals(start time.Time) { 84 | currentOffset, epoch := ephemeral.HandleQuantization(start) 85 | def := nb.inst.GetPartialNdf() 86 | // FIXME: Does the address space need more logic here?
87 | err := nb.Storage.AddEphemeralsForOffset(currentOffset, epoch, uint(def.Get().AddressSpace[0].Size), start) 88 | if err != nil { 89 | jww.WARN.Printf("failed to update ephemerals: %+v", err) 90 | return // don't advance the stored epoch past one whose ephemerals failed to be added 91 | } 92 | err = nb.Storage.UpsertState(&storage.State{ 93 | Key: ephemeralStateKey, 94 | Value: strconv.Itoa(int(epoch)), 95 | }) 96 | if err != nil { 97 | jww.WARN.Printf("failed to update ephemeral state: %+v", err) 98 | } 99 | } 100 | 101 | // EphIdDeleter runs forever: after initDeleter aligns it with the next epoch boundary, it deletes old ephemerals once per offset phase. 102 | func (nb *Impl) EphIdDeleter() { 103 | nb.initDeleter() 104 | ticker := time.NewTicker(time.Duration(offsetPhase)) 105 | //handle all future epochs 106 | for { 107 | <-ticker.C 108 | go nb.deleteEphemerals(time.Now().Add(deletionDelay)) 109 | } 110 | } 111 | 112 | // initDeleter sleeps until the start of the next epoch, then launches one deletion pass. 113 | func (nb *Impl) initDeleter() { 114 | //handle the next epoch 115 | _, epoch := ephemeral.HandleQuantization(time.Now()) 116 | nextTrigger := time.Unix(0, int64(epoch+1)*offsetPhase) 117 | // Bring us into phase with ephemeral identity creation 118 | time.Sleep(time.Until(nextTrigger)) 119 | go nb.deleteEphemerals(time.Now().Add(deletionDelay)) 120 | } 121 | 122 | // deleteEphemerals asks storage to delete old ephemerals relative to the epoch containing start. 123 | func (nb *Impl) deleteEphemerals(start time.Time) { 124 | jww.DEBUG.Println("deleteEphemerals") 125 | _, currentEpoch := ephemeral.HandleQuantization(start) 126 | err := nb.Storage.DeleteOldEphemerals(currentEpoch) 127 | if err != nil { 128 | jww.WARN.Printf("failed to delete ephemerals: %+v", err) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /storage/database_test.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "fmt" 5 | "gitlab.com/elixxir/notifications-bot/constants" 6 | "gitlab.com/xx_network/primitives/id" 7 | "gitlab.com/xx_network/primitives/id/ephemeral" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestDatabase(t *testing.T) { 13 | s, err := NewStorage("", "", "cmix", "", "") 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | 18 | addressSpace := uint8(16) 19 | 20 | var toNotify []int64 21 | _, epoch := ephemeral.HandleQuantization(time.Now()) 22 | 23 | id1 :=
id.NewIdFromString("mr_peanutbutter", id.User, t) 24 | iid1, err := ephemeral.GetIntermediaryId(id1) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | eph, _, _, err := ephemeral.GetId(id1, uint(addressSpace), time.Now().UnixNano()) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | toNotify = append(toNotify, eph.Int64()) 33 | 34 | trsa := []byte("trsa") 35 | token1 := "apnstoken01" 36 | token2 := "fcm:token02" 37 | token3 := "apnstoken03" 38 | 39 | id2 := id.NewIdFromString("lex_luthor", id.User, t) 40 | iid2, err := ephemeral.GetIntermediaryId(id2) 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | eph, _, _, err = ephemeral.GetId(id2, uint(addressSpace), time.Now().UnixNano()) 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | toNotify = append(toNotify, eph.Int64()) 49 | 50 | id3 := id.NewIdFromString("spooderman", id.User, t) 51 | iid3, err := ephemeral.GetIntermediaryId(id3) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | eph, _, _, err = ephemeral.GetId(id3, uint(addressSpace), time.Now().UnixNano()) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | toNotify = append(toNotify, eph.Int64()) 60 | 61 | trsa2 := []byte("trsa2") 62 | id4 := id.NewIdFromString("mr.
morales", id.User, t) 63 | iid4, err := ephemeral.GetIntermediaryId(id4) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | eph, _, _, err = ephemeral.GetId(id4, uint(addressSpace), time.Now().UnixNano()) 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | toNotify = append(toNotify, eph.Int64()) 72 | token4 := "fcm:token04" 73 | 74 | // Register user 1 with token 1 and identity 1 75 | _, err = s.RegisterForNotifications(iid1, trsa, token1, constants.MessengerIOS.String(), epoch, addressSpace) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | // User1: 81 | // Tokens: 1 82 | // Identities: 1 83 | gtnList, err := s.GetToNotify(toNotify) 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | if len(gtnList) != 1 { 88 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 89 | } 90 | 91 | // Register user 2 with token 4 and identity 1 92 | _, err = s.RegisterForNotifications(iid1, trsa2, token4, constants.MessengerAndroid.String(), epoch, addressSpace) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | 97 | // User1: 98 | // Tokens: 1 99 | // Identities: 1 100 | // User2: 101 | // Tokens: 4 102 | // Identities: 1 103 | gtnList, err = s.GetToNotify(toNotify) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | if len(gtnList) != 2 { 108 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 109 | } 110 | 111 | // Call identical registration on user 1 (no change) 112 | _, err = s.RegisterForNotifications(iid1, trsa, token1, constants.MessengerIOS.String(), epoch, addressSpace) 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | 117 | // User1: 118 | // Tokens: 1 119 | // Identities: 1 120 | // User2: 121 | // Tokens: 4 122 | // Identities: 1 123 | gtnList, err = s.GetToNotify(toNotify) 124 | if err != nil { 125 | t.Fatal(err) 126 | } 127 | t.Log(gtnList) 128 | if len(gtnList) != 2 { 129 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 130 | } 131 | 132 | // Register user 1 with identity 2 (still on token 1) 133 | _, err = s.RegisterForNotifications(iid2, trsa, token1, constants.MessengerIOS.String(), epoch,
addressSpace) 134 | if err != nil { 135 | t.Fatal(err) 136 | } 137 | 138 | // User1: 139 | // Tokens: 1 140 | // Identities: 1, 2 141 | // User2: 142 | // Tokens: 4 143 | // Identities: 1 144 | gtnList, err = s.GetToNotify(toNotify) 145 | if err != nil { 146 | t.Fatal(err) 147 | } 148 | t.Log(gtnList) 149 | if len(gtnList) != 3 { 150 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 151 | } 152 | 153 | // Register user 1 with token 2 154 | _, err = s.RegisterForNotifications(iid2, trsa, token2, constants.MessengerAndroid.String(), epoch, addressSpace) 155 | if err != nil { 156 | t.Fatal(err) 157 | } 158 | 159 | // User1: 160 | // Tokens: 1, 2 161 | // Identities: 1, 2 162 | // User2: 163 | // Tokens: 4 164 | // Identities: 1 165 | gtnList, err = s.GetToNotify(toNotify) 166 | if err != nil { 167 | t.Fatal(err) 168 | } 169 | if len(gtnList) != 5 { 170 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 171 | } 172 | 173 | // Register user 1 with token3 and identity3 174 | _, err = s.RegisterForNotifications(iid3, trsa, token3, constants.MessengerIOS.String(), epoch, addressSpace) 175 | if err != nil { 176 | t.Fatal(err) 177 | } 178 | 179 | // User1: 180 | // Tokens: 1, 2, 3 181 | // Identities: 1, 2, 3 182 | // User2: 183 | // Tokens: 4 184 | // Identities: 1 185 | gtnList, err = s.GetToNotify(toNotify) 186 | if err != nil { 187 | t.Fatal(err) 188 | } 189 | if len(gtnList) != 10 { 190 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 191 | } 192 | 193 | // Register user 2 with identity 4 194 | _, err = s.RegisterForNotifications(iid4, trsa2, token4, constants.MessengerAndroid.String(), epoch, addressSpace) 195 | if err != nil { 196 | t.Fatal(err) 197 | } 198 | 199 | // User1: 200 | // Tokens: 1, 2, 3 201 | // Identities: 1, 2, 3 202 | // User2: 203 | // Tokens: 4 204 | // Identities: 1, 4 205 | gtnList, err = s.GetToNotify(toNotify) 206 | if err != nil { 207 | t.Fatal(err) 208 | } 209 | if len(gtnList) != 11 { 210 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 211 | } 212 | 213 | gtnList, err =
s.GetToNotify([]int64{toNotify[0]}) 214 | if err != nil { 215 | t.Fatal(err) 216 | } 217 | if len(gtnList) != 4 { 218 | fmt.Println(toNotify[0]) 219 | t.Log(len(gtnList)) 220 | t.Fatalf("Got wrong gtnlist: %+v", gtnList) 221 | } 222 | 223 | } 224 | -------------------------------------------------------------------------------- /notifications/legacyRegistration_test.go: -------------------------------------------------------------------------------- 1 | package notifications 2 | 3 | import ( 4 | pb "gitlab.com/elixxir/comms/mixmessages" 5 | "gitlab.com/elixxir/crypto/hash" 6 | "gitlab.com/elixxir/crypto/registration" 7 | "gitlab.com/elixxir/notifications-bot/storage" 8 | "gitlab.com/xx_network/comms/connect" 9 | "gitlab.com/xx_network/crypto/csprng" 10 | "gitlab.com/xx_network/crypto/signature/rsa" 11 | "gitlab.com/xx_network/primitives/id" 12 | "gitlab.com/xx_network/primitives/id/ephemeral" 13 | "gitlab.com/xx_network/primitives/utils" 14 | "os" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | // Unit test for RegisterForNotifications 20 | func TestImpl_RegisterForNotifications(t *testing.T) { 21 | impl := getNewImpl() 22 | var err error 23 | impl.Storage, err = storage.NewStorage("", "", "", "", "") 24 | if err != nil { 25 | t.Fatalf("Failed to create storage: %+v", err) 26 | } 27 | wd, err := os.Getwd() 28 | if err != nil { 29 | t.Fatalf("Failed to get working dir: %+v", err) 30 | } 31 | permCert, err := utils.ReadFile(wd + "/../testutil/cmix.rip.crt") 32 | if err != nil { 33 | t.Fatalf("Failed to read test cert file: %+v", err) 34 | } 35 | permKey, err := utils.ReadFile(wd + "/../testutil/cmix.rip.key") 36 | if err != nil { 37 | t.Fatalf("Failed to read test key file: %+v", err) 38 | } 39 | private, err := rsa.GenerateKey(csprng.NewSystemRNG(), 4096) 40 | if err != nil { 41 | t.Fatalf("Failed to create private key: %+v", err) 42 | } 43 | public := private.GetPublic() 44 | key := rsa.CreatePrivateKeyPem(private) 45 | crt := rsa.CreatePublicKeyPem(public) 46 | uid := id.NewIdFromString("zezima", id.User, t) 47 |
iid, err := ephemeral.GetIntermediaryId(uid) 48 | if err != nil { 49 | t.Fatalf("Failed to make iid: %+v", err) 50 | } 51 | h, err := hash.NewCMixHash() 52 | if err != nil { 53 | t.Fatalf("Failed to make cmix hash: %+v", err) 54 | } 55 | _, err = h.Write(iid) 56 | if err != nil { 57 | t.Fatalf("Failed to write to hash: %+v", err) 58 | } 59 | pk, err := rsa.LoadPrivateKeyFromPem(key) 60 | if err != nil { 61 | t.Fatalf("Failed to load pk from pem: %+v", err) 62 | } 63 | sig, err := rsa.Sign(csprng.NewSystemRNG(), pk, hash.CMixHash, h.Sum(nil), nil) 64 | if err != nil { 65 | t.Fatalf("Failed to sign: %+v", err) 66 | } 67 | _, err = impl.Comms.AddHost(&id.Permissioning, "0.0.0.0", permCert, connect.GetDefaultHostParams()) 68 | if err != nil { 69 | t.Fatalf("Failed to add host: %+v", err) 70 | } 71 | loadedPermKey, err := rsa.LoadPrivateKeyFromPem(permKey) 72 | if err != nil { 73 | t.Fatalf("Failed to load perm key from bytes: %+v", err) 74 | } 75 | ts := time.Now().UnixNano() 76 | psig, err := registration.SignWithTimestamp(csprng.NewSystemRNG(), loadedPermKey, ts, string(crt)) 77 | if err != nil { 78 | t.Fatalf("Failed to sign with timestamp: %+v", err) 79 | } 80 | 81 | err = impl.RegisterForNotifications(&pb.NotificationRegisterRequest{ 82 | Token: "token", 83 | IntermediaryId: iid, 84 | TransmissionRsa: crt, 85 | TransmissionSalt: []byte("salt"), 86 | TransmissionRsaSig: psig, 87 | IIDTransmissionRsaSig: sig, 88 | RegistrationTimestamp: ts, 89 | }) 90 | if err != nil { 91 | t.Errorf("Failed to register for notifications: %+v", err) 92 | } 93 | } 94 | 95 | // Unit test for UnregisterForNotifications 96 | func TestImpl_UnregisterForNotifications(t *testing.T) { 97 | impl := getNewImpl() 98 | var err error 99 | impl.Storage, err = storage.NewStorage("", "", "", "", "") 100 | if err != nil { 101 | t.Fatalf("Failed to create storage: %+v", err) 102 | } 103 | wd, err := os.Getwd() 104 | if err != nil { 105 | t.Fatalf("Failed to get working dir: %+v", err) 106 | } 107 | permCert, err := utils.ReadFile(wd + "/../testutil/cmix.rip.crt") 108 | if err != nil {
109 | t.Fatalf("Failed to read test cert file: %+v", err) 110 | } 111 | permKey, err := utils.ReadFile(wd + "/../testutil/cmix.rip.key") 112 | if err != nil { 113 | t.Fatalf("Failed to read test key file: %+v", err) 114 | } 115 | private, err := rsa.GenerateKey(csprng.NewSystemRNG(), 4096) 116 | if err != nil { 117 | t.Fatalf("Failed to create private key: %+v", err) 118 | } 119 | public := private.GetPublic() 120 | key := rsa.CreatePrivateKeyPem(private) 121 | crt := rsa.CreatePublicKeyPem(public) 122 | uid := id.NewIdFromString("zezima", id.User, t) 123 | iid, err := ephemeral.GetIntermediaryId(uid) 124 | if err != nil { 125 | t.Fatalf("Failed to get intermediary ID: %+v", err) 126 | } 127 | h, err := hash.NewCMixHash() 128 | if err != nil { 129 | t.Fatalf("Failed to make cmix hash: %+v", err) 130 | } 131 | _, err = h.Write(iid) 132 | if err != nil { 133 | t.Fatalf("Failed to write to hash: %+v", err) 134 | } 135 | pk, err := rsa.LoadPrivateKeyFromPem(key) 136 | if err != nil { 137 | t.Fatalf("Failed to load pk from pem: %+v", err) 138 | } 139 | sig, err := rsa.Sign(csprng.NewSystemRNG(), pk, hash.CMixHash, h.Sum(nil), nil) 140 | if err != nil { 141 | t.Fatalf("Failed to sign: %+v", err) 142 | } 143 | _, err = impl.Comms.AddHost(&id.Permissioning, "0.0.0.0", permCert, connect.GetDefaultHostParams()) 144 | if err != nil { 145 | t.Fatalf("Failed to add host: %+v", err) 146 | } 147 | loadedPermKey, err := rsa.LoadPrivateKeyFromPem(permKey) 148 | if err != nil { 149 | t.Fatalf("Failed to load perm key from bytes: %+v", err) 150 | } 151 | ts := time.Now().UnixNano() 152 | psig, err := registration.SignWithTimestamp(csprng.NewSystemRNG(), loadedPermKey, ts, string(crt)) 153 | if err != nil { 154 | t.Fatalf("Failed to sign with timestamp: %+v", err) 155 | } 156 | 157 | err = impl.RegisterForNotifications(&pb.NotificationRegisterRequest{ 158 | Token: "token", 159 | IntermediaryId: iid, 160 | TransmissionRsa: crt, 161 |
TransmissionSalt: []byte("salt"), 162 | TransmissionRsaSig: psig, 163 | IIDTransmissionRsaSig: sig, 164 | RegistrationTimestamp: ts, 165 | }) 166 | if err != nil { 167 | t.Fatalf("Failed to register for notifications: %+v", err) 168 | } 169 | 170 | err = impl.UnregisterForNotifications(&pb.NotificationUnregisterRequest{ 171 | IntermediaryId: iid, 172 | IIDTransmissionRsaSig: sig, 173 | }) 174 | if err != nil { 175 | t.Errorf("Failed to unregister for notifications: %+v", err) 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /notifications/impl.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Package notifications contains the core logic for interacting with the notifications bot. 9 | // 10 | // This includes registering users, receiving notifications, and sending to providers.
11 | 12 | package notifications 13 | 14 | import ( 15 | "crypto/tls" 16 | "github.com/pkg/errors" 17 | jww "github.com/spf13/jwalterweatherman" 18 | pb "gitlab.com/elixxir/comms/mixmessages" 19 | "gitlab.com/elixxir/comms/network" 20 | "gitlab.com/elixxir/comms/notificationBot" 21 | "gitlab.com/elixxir/notifications-bot/constants" 22 | "gitlab.com/elixxir/notifications-bot/notifications/providers" 23 | "gitlab.com/elixxir/notifications-bot/storage" 24 | "gitlab.com/xx_network/comms/connect" 25 | "gitlab.com/xx_network/primitives/id" 26 | "gitlab.com/xx_network/primitives/ndf" 27 | "gitlab.com/xx_network/primitives/netTime" 28 | "gitlab.com/xx_network/primitives/utils" 29 | "sync" 30 | ) 31 | 32 | // Impl for notifications; holds comms, storage object, creds and main functions 33 | type Impl struct { 34 | Comms *notificationBot.Comms 35 | Storage *storage.Storage 36 | inst *network.Instance 37 | receivedNdf *uint32 // set to 1 once an NDF update has been applied 38 | roundStore sync.Map // values are time.Time; stale entries are pruned by Cleaner 39 | maxNotifications int 40 | maxPayloadBytes int 41 | 42 | providers map[string]providers.Provider 43 | 44 | ndfStopper Stopper 45 | } 46 | 47 | // StartNotifications creates an Impl from the information passed in, starts the notification comms server, and launches the Cleaner, Sender and (when HTTPS paths are set) HTTPS-serving goroutines. Push providers that fail to start are logged and left out of impl.providers. 48 | func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) { 49 | var cert, key []byte 50 | var err error 51 | 52 | // Read in private key 53 | if params.KeyPath != "" { 54 | key, err = utils.ReadFile(params.KeyPath) 55 | if err != nil { 56 | return nil, errors.Wrapf(err, "failed to read key at %+v", params.KeyPath) 57 | } 58 | } else { 59 | jww.WARN.Println("Running without key...") 60 | } 61 | 62 | if !noTLS { 63 | // Read in TLS keys from files 64 | cert, err = utils.ReadFile(params.CertPath) 65 | if err != nil { 66 | return nil, errors.Wrapf(err, "failed to read certificate at %+v", params.CertPath) 67 | } 68 | } 69 | 70 | receivedNdf := uint32(0) 71 | 72 | impl := &Impl{ 73 | providers: map[string]providers.Provider{}, 74 | receivedNdf: &receivedNdf, 75 | maxNotifications:
params.NotificationsPerBatch, 76 | maxPayloadBytes: params.MaxNotificationPayload, 77 | } 78 | 79 | // Set up firebase messaging clients; a provider is only registered if it started successfully 80 | if !noFirebase { 81 | if fcm, err := providers.NewFCM(params.FBCreds); err != nil { 82 | jww.WARN.Printf("Failed to start firebase provider for %s: %+v", constants.MessengerAndroid, err) 83 | } else { 84 | impl.providers[constants.MessengerAndroid.String()] = fcm 85 | } 86 | if params.HavenFBCreds != "" { 87 | if fcm, err := providers.NewFCM(params.HavenFBCreds); err != nil { 88 | jww.WARN.Printf("Failed to start firebase provider for %s: %+v", constants.HavenAndroid, err) 89 | } else { 90 | impl.providers[constants.HavenAndroid.String()] = fcm 91 | } 92 | } 93 | } 94 | 95 | // NOTE(review): the messenger APNS provider is gated on params.KeyPath (the comms private key path), 96 | // whereas the Haven provider is gated on params.HavenAPNS.KeyPath; confirm params.APNS.KeyPath was not intended. 97 | if params.KeyPath == "" { 98 | jww.WARN.Println("WARNING: RUNNING WITHOUT APNS") 99 | } else if apns, err := providers.NewApns(params.APNS); err != nil { 100 | jww.WARN.Printf("Failed to start APNS provider for %s: %+v", constants.MessengerIOS, err) 101 | } else { 102 | impl.providers[constants.MessengerIOS.String()] = apns 103 | } 104 | 105 | if params.HavenAPNS.KeyPath == "" { 106 | jww.WARN.Println("WARNING: RUNNING WITHOUT HAVEN APNS") 107 | } else if apns, err := providers.NewApns(params.HavenAPNS); err != nil { 108 | jww.WARN.Printf("Failed to start APNS provider for %s: %+v", constants.HavenIOS, err) 109 | } else { 110 | impl.providers[constants.HavenIOS.String()] = apns 111 | } 112 | // Start notification comms server 113 | handler := NewImplementation(impl) 114 | comms := notificationBot.StartNotificationBot(&id.NotificationBot, params.Address, handler, cert, key) 115 | impl.Comms = comms 116 | i, err := network.NewInstance(impl.Comms.ProtoComms, &ndf.NetworkDefinition{AddressSpace: []ndf.AddressSpace{{Size: 16, Timestamp: netTime.Now()}}}, nil, nil, network.None, false) 117 | if err != nil { 118 | return nil, errors.WithMessage(err, "Failed to start instance") 119 | } 120 | i.SetGatewayAuthentication() 121 | impl.inst = i 122 | 123 | go impl.Cleaner() 124 | go impl.Sender(params.NotificationRate) 125 | 126 | go func() { 127 | if params.HttpsKeyPath == ""
|| params.HttpsCertPath == "" { 128 | jww.WARN.Println("Running without HTTPS") 129 | return 130 | } 131 | httpsCertificate, err := tls.LoadX509KeyPair(params.HttpsCertPath, params.HttpsKeyPath) 132 | if err != nil { 133 | jww.ERROR.Printf("Failed to load https certificate: %+v", err) 134 | return 135 | } 136 | err = comms.ServeHttps(httpsCertificate) 137 | if err != nil { 138 | jww.ERROR.Printf("Failed to serve HTTPS: %+v", err) 139 | } 140 | }() 141 | 142 | return impl, nil 143 | } 144 | 145 | // NewImplementation builds the comms handler, routing each notification-bot RPC to the matching method on instance; errors from the token and tracked-ID handlers are also logged. 146 | func NewImplementation(instance *Impl) *notificationBot.Implementation { 147 | impl := notificationBot.NewImplementation() 148 | 149 | impl.Functions.RegisterForNotifications = func(request *pb.NotificationRegisterRequest) error { 150 | return instance.RegisterForNotifications(request) 151 | } 152 | 153 | impl.Functions.UnregisterForNotifications = func(request *pb.NotificationUnregisterRequest) error { 154 | return instance.UnregisterForNotifications(request) 155 | } 156 | 157 | impl.Functions.ReceiveNotificationBatch = func(data *pb.NotificationBatch, auth *connect.Auth) error { 158 | return instance.ReceiveNotificationBatch(data, auth) 159 | } 160 | impl.Functions.RegisterToken = func(msg *pb.RegisterTokenRequest) error { 161 | err := instance.RegisterToken(msg) 162 | if err != nil { 163 | jww.ERROR.Printf("Failed to RegisterToken: %+v", err) 164 | } 165 | return err 166 | } 167 | impl.Functions.RegisterTrackedID = func(msg *pb.RegisterTrackedIdRequest) error { 168 | err := instance.RegisterTrackedID(msg) 169 | if err != nil { 170 | jww.ERROR.Printf("Failed to RegisterTrackedID: %+v", err) 171 | } 172 | return err 173 | } 174 | impl.Functions.UnregisterToken = func(msg *pb.UnregisterTokenRequest) error { 175 | err := instance.UnregisterToken(msg) 176 | if err != nil { 177 | jww.ERROR.Printf("Failed to UnregisterToken: %+v", err) 178 | } 179 | return err 180 | } 181 | impl.Functions.UnregisterTrackedID = func(msg
*pb.UnregisterTrackedIdRequest) error { 182 | err := instance.UnregisterTrackedID(msg.Request) 183 | if err != nil { 184 | jww.ERROR.Printf("Failed to UnregisterTrackedID: %+v", err) 185 | } 186 | return err 187 | } 188 | 189 | return impl 190 | } 191 | -------------------------------------------------------------------------------- /notifications/registration.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "encoding/base64" 12 | "github.com/pkg/errors" 13 | jww "github.com/spf13/jwalterweatherman" 14 | pb "gitlab.com/elixxir/comms/mixmessages" 15 | "gitlab.com/elixxir/crypto/notifications" 16 | "gitlab.com/elixxir/crypto/registration" 17 | "gitlab.com/elixxir/crypto/rsa" 18 | "gitlab.com/xx_network/primitives/id" 19 | "gitlab.com/xx_network/primitives/id/ephemeral" 20 | "time" 21 | ) 22 | 23 | var timestampError = "Timestamp of request must be within last 5 seconds. Request timestamp: %s, current time: %s" 24 | 25 | // RegisterToken registers the given token. It verifies that the TransmissionRsaRegistrarSig is 26 | // correct. The RSA->PEM relationship is one to many. It will succeed if the token is already 27 | // registered.
28 | func (nb *Impl) RegisterToken(msg *pb.RegisterTokenRequest) error { 29 | jww.INFO.Println("RegisterToken") 30 | requestTimestamp := time.Unix(0, msg.RequestTimestamp) 31 | if time.Now().Sub(requestTimestamp) > time.Second*5 { 32 | return errors.Errorf(timestampError, requestTimestamp.String(), time.Now().String()) 33 | } 34 | // Verify permissioning RSA signature 35 | permHost, ok := nb.Comms.GetHost(&id.Permissioning) 36 | if !ok { 37 | return errors.New("Could not find permissioning host to verify client signature") 38 | } 39 | jww.INFO.Printf("Verifying perm sig with params:\n\tPubKey: %s\n\tTimestamp: %d\n\tTRSA: %s\n\tSIG: %s\n", base64.StdEncoding.EncodeToString(permHost.GetPubKey().Bytes()), msg.RegistrationTimestamp, base64.StdEncoding.EncodeToString(msg.TransmissionRsaPem), base64.StdEncoding.EncodeToString(msg.TransmissionRsaRegistrarSig)) 40 | err := registration.VerifyWithTimestamp(permHost.GetPubKey(), msg.RegistrationTimestamp, 41 | string(msg.TransmissionRsaPem), msg.TransmissionRsaRegistrarSig) 42 | if err != nil { 43 | return errors.WithMessage(err, "Failed to verify permissioning signature") 44 | } 45 | 46 | // Verify token signature 47 | pub, err := rsa.GetScheme().UnmarshalPublicKeyPEM(msg.TransmissionRsaPem) 48 | if err != nil { 49 | return errors.WithMessage(err, "Failed to unmarshal public key") 50 | } 51 | err = notifications.VerifyToken(pub, msg.Token, msg.App, requestTimestamp, notifications.RegisterTokenTag, msg.TokenSignature) 52 | if err != nil { 53 | return errors.WithMessage(err, "Failed to verify token signature") 54 | } 55 | 56 | return nb.Storage.RegisterToken(msg.Token, msg.App, msg.TransmissionRsaPem) 57 | } 58 | 59 | // RegisterTrackedID registers the given ID to be tracked. The request is signed 60 | // Returns an error if TransmissionRSA is not registered with a valid token. 61 | // The actual ID is not revealed, instead an intermediary value is sent which cannot 62 | // be revered to get the ID, but is repeatable. 
So it can be rainbow-tabled. 63 | func (nb *Impl) RegisterTrackedID(msg *pb.RegisterTrackedIdRequest) error { 64 | jww.INFO.Println("RegisterTrackedID") 65 | requestTimestamp := time.Unix(0, msg.Request.RequestTimestamp) 66 | if time.Now().Sub(requestTimestamp) > time.Second*5 { 67 | return errors.Errorf(timestampError, requestTimestamp.String(), time.Now().String()) 68 | } 69 | 70 | // Verify permissioning RSA signature 71 | permHost, ok := nb.Comms.GetHost(&id.Permissioning) 72 | if !ok { 73 | return errors.New("Could not find permissioning host to verify client signature") 74 | } 75 | jww.INFO.Printf("Verifying perm sig with params:\n\tPubKey: %s\n\tTimestamp: %d\n\tTRSA: %s\n\tSIG: %s\n", base64.StdEncoding.EncodeToString(permHost.GetPubKey().Bytes()), msg.RegistrationTimestamp, base64.StdEncoding.EncodeToString(msg.Request.TransmissionRsaPem), base64.StdEncoding.EncodeToString(msg.TransmissionRsaRegistrarSig)) 76 | err := registration.VerifyWithTimestamp(permHost.GetPubKey(), msg.RegistrationTimestamp, 77 | string(msg.Request.TransmissionRsaPem), msg.TransmissionRsaRegistrarSig) 78 | if err != nil { 79 | return errors.WithMessage(err, "Failed to verify permissioning signature") 80 | } 81 | 82 | pub, err := rsa.GetScheme().UnmarshalPublicKeyPEM(msg.Request.TransmissionRsaPem) 83 | if err != nil { 84 | return errors.WithMessage(err, "Failed to unmarshal public key") 85 | } 86 | 87 | err = notifications.VerifyIdentity(pub, msg.Request.TrackedIntermediaryID, requestTimestamp, notifications.RegisterTrackedIDTag, msg.Request.Signature) 88 | if err != nil { 89 | return errors.WithMessage(err, "Failed to verify identity signature") 90 | } 91 | _, epoch := ephemeral.HandleQuantization(time.Now()) 92 | 93 | return nb.Storage.RegisterTrackedID(msg.Request.TrackedIntermediaryID, msg.Request.TransmissionRsaPem, epoch, nb.inst.GetPartialNdf().Get().AddressSpace[0].Size) 94 | } 95 | 96 | // UnregisterToken unregisters the given device token. The request is signed. 
97 | // Does not return an error if the token cannot be found 98 | func (nb *Impl) UnregisterToken(msg *pb.UnregisterTokenRequest) error { 99 | jww.INFO.Println("UnregisterToken") 100 | requestTimestamp := time.Unix(0, msg.RequestTimestamp) 101 | if time.Now().Sub(requestTimestamp) > time.Second*5 { 102 | return errors.Errorf(timestampError, requestTimestamp.String(), time.Now().String()) 103 | } 104 | 105 | pub, err := rsa.GetScheme().UnmarshalPublicKeyPEM(msg.TransmissionRsaPem) 106 | if err != nil { 107 | return errors.WithMessage(err, "Failed to unmarshal public key") 108 | } 109 | 110 | err = notifications.VerifyToken(pub, msg.Token, msg.App, requestTimestamp, notifications.UnregisterTokenTag, msg.TokenSignature) 111 | if err != nil { 112 | return errors.WithMessage(err, "Failed to verify token signature") 113 | } 114 | 115 | return nb.Storage.UnregisterToken(msg.Token, msg.TransmissionRsaPem) 116 | } 117 | 118 | // UnregisterTrackedID unregisters the given tracked ID. The request is signed. 
119 | // Does not return an error if the ID cannot be found 120 | func (nb *Impl) UnregisterTrackedID(msg *pb.TrackedIntermediaryIdRequest) error { 121 | jww.INFO.Println("UnregisterTrackedID") 122 | requestTimestamp := time.Unix(0, msg.RequestTimestamp) 123 | if time.Now().Sub(requestTimestamp) > time.Second*5 { 124 | return errors.Errorf(timestampError, requestTimestamp.String(), time.Now().String()) 125 | } 126 | 127 | pub, err := rsa.GetScheme().UnmarshalPublicKeyPEM(msg.TransmissionRsaPem) 128 | if err != nil { 129 | return errors.WithMessage(err, "Failed to unmarshal public key") 130 | } 131 | 132 | err = notifications.VerifyIdentity(pub, msg.TrackedIntermediaryID, requestTimestamp, notifications.UnregisterTrackedIDTag, msg.Signature) 133 | if err != nil { 134 | return errors.WithMessage(err, "Failed to verify identity signature") 135 | } 136 | 137 | return nb.Storage.UnregisterTrackedIDs(msg.TrackedIntermediaryID, msg.TransmissionRsaPem) 138 | } 139 | -------------------------------------------------------------------------------- /storage/database.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | "fmt" 12 | "github.com/pkg/errors" 13 | jww "github.com/spf13/jwalterweatherman" 14 | "gorm.io/driver/postgres" 15 | "gorm.io/driver/sqlite" 16 | "gorm.io/gorm" 17 | "gorm.io/gorm/logger" 18 | "time" 19 | ) 20 | 21 | const postgresConnectString = "host=%s port=%s user=%s dbname=%s sslmode=disable" 22 | const sqliteDatabasePath = "file:%s?mode=memory&cache=shared" 23 | 24 | // interface declaration for storage methods 25 | type database interface { 26 | UpsertState(state *State) error 27 | GetStateValue(key string) (string, error) 28 | 29 | insertUser(user *User) error 30 | GetUser(transmissionRsaHash []byte) (*User, error) 31 | deleteUser(transmissionRsaHash []byte) error 32 | GetAllUsers() ([]*User, error) 33 | 34 | registerTrackedIdentity(user User, identity Identity) error 35 | registerTrackedIdentities(user User, ids []Identity) error 36 | 37 | GetIdentity(iid []byte) (*Identity, error) 38 | insertIdentity(identity *Identity) error 39 | getIdentitiesByOffset(offset int64) ([]*Identity, error) 40 | GetOrphanedIdentities() ([]*Identity, error) 41 | 42 | insertEphemeral(ephemeral *Ephemeral) error 43 | GetEphemeral(ephemeralId int64) ([]*Ephemeral, error) 44 | GetLatestEphemeral() (*Ephemeral, error) 45 | DeleteOldEphemerals(currentEpoch int32) error 46 | GetToNotify(ephemeralIds []int64) ([]GTNResult, error) 47 | 48 | insertToken(token Token) error 49 | DeleteToken(token string) error 50 | 51 | unregisterIdentities(u *User, iids []Identity) error 52 | unregisterTokens(u *User, tokens []Token) error 53 | registerForNotifications(u *User, identity Identity, token Token) error 54 | LegacyUnregister(iid []byte) error 55 | } 56 | 57 | // DatabaseImpl is a struct which implements database on an underlying gorm.DB 58 | type DatabaseImpl struct { 59 | db *gorm.DB // Stored database connection 60 | } 61 | 62 | // State table 63 | type 
State struct { 64 | Key string `gorm:"primary_key"` 65 | Value string `gorm:"NOT NULL"` 66 | } 67 | 68 | type UserV1 struct { 69 | TransmissionRSAHash []byte `gorm:"primaryKey"` 70 | IntermediaryId []byte `gorm:"not null; index"` 71 | OffsetNum int64 `gorm:"not null; index"` 72 | TransmissionRSA []byte `gorm:"not null"` 73 | Signature []byte `gorm:"not null"` 74 | Token string `gorm:"not null"` 75 | Ephemerals []Ephemeral `gorm:"foreignKey:transmission_rsa_hash;references:transmission_rsa_hash;constraint:OnDelete:CASCADE;"` 76 | } 77 | 78 | type Token struct { 79 | Token string `gorm:"primaryKey"` 80 | App string 81 | TransmissionRSAHash []byte `gorm:"not null;references users(transmission_rsa_hash)"` 82 | } 83 | 84 | type User struct { 85 | TransmissionRSAHash []byte `gorm:"primaryKey"` 86 | TransmissionRSA []byte `gorm:"not null"` 87 | Tokens []Token `gorm:"foreignKey:TransmissionRSAHash;constraint:OnDelete:CASCADE;"` 88 | Identities []Identity `gorm:"many2many:user_identities;"` 89 | } 90 | 91 | // CREATES JOIN TABLE user_identities 92 | // Table "public.user_identities" 93 | // Column | Type | Collation | Nullable | Default 94 | // ----------------------------+-------+-----------+----------+--------- 95 | // user_transmission_rsa_hash | bytea | | not null | 96 | // identity_intermediary_id | bytea | | not null | 97 | // Indexes: 98 | // "user_identities_pkey" PRIMARY KEY, btree (user_transmission_rsa_hash, identity_intermediary_id) 99 | // Foreign-key constraints: 100 | // "fk_user_identities_identity" FOREIGN KEY (identity_intermediary_id) REFERENCES identities(intermediary_id) 101 | // "fk_user_identities_user" FOREIGN KEY (user_transmission_rsa_hash) REFERENCES users(transmission_rsa_hash) 102 | 103 | type Identity struct { 104 | IntermediaryId []byte `gorm:"primaryKey"` 105 | OffsetNum int64 `gorm:"not null; index"` 106 | Users []User `gorm:"many2many:user_identities;"` 107 | Ephemerals []Ephemeral 
`gorm:"foreignKey:intermediary_id;references:intermediary_id;constraint:OnDelete:CASCADE;"` 108 | } 109 | 110 | type Ephemeral struct { 111 | ID uint `gorm:"primaryKey"` 112 | IntermediaryId []byte `gorm:"not null;references identities(intermediary_id)"` 113 | EphemeralId int64 `gorm:"not null; index"` 114 | Epoch int32 `gorm:"not null; index"` 115 | } 116 | 117 | // Initialize the database interface with database backend 118 | // Returns a database interface, close function, and error 119 | func newDatabase(username, password, dbName, address, 120 | port string) (database, error) { 121 | var err error 122 | var db *gorm.DB 123 | var dialector gorm.Dialector 124 | // Connect to the database if the correct information is provided 125 | usePostgres := address != "" && port != "" 126 | if usePostgres { 127 | // Create the database connection 128 | connectString := fmt.Sprintf( 129 | postgresConnectString, 130 | address, port, username, dbName) 131 | // Handle empty database password 132 | if len(password) > 0 { 133 | connectString += fmt.Sprintf(" password=%s", password) 134 | } 135 | dialector = postgres.Open(connectString) 136 | } else { 137 | jww.WARN.Printf("Database backend connection information not provided") 138 | temporaryDbPath := fmt.Sprintf(sqliteDatabasePath, dbName) 139 | dialector = sqlite.Open(temporaryDbPath) 140 | } 141 | 142 | // Create the database connection 143 | db, err = gorm.Open(dialector, &gorm.Config{ 144 | Logger: logger.New(jww.TRACE, logger.Config{LogLevel: logger.Info}), 145 | }) 146 | if err != nil { 147 | return nil, errors.Errorf("Unable to initialize in-memory sqlite database backend: %+v", err) 148 | } 149 | 150 | if !usePostgres { 151 | // Enable foreign keys because they are disabled in SQLite by default 152 | if err = db.Exec("PRAGMA foreign_keys = ON", nil).Error; err != nil { 153 | return nil, err 154 | } 155 | 156 | // Enable Write Ahead Logging to enable multiple DB connections 157 | if err = db.Exec("PRAGMA journal_mode = 
WAL;", nil).Error; err != nil { 158 | return nil, err 159 | } 160 | } 161 | 162 | // Get and configure the internal database ConnPool 163 | sqlDb, err := db.DB() 164 | if err != nil { 165 | return nil, errors.Errorf("Unable to configure database connection pool: %+v", err) 166 | } 167 | // SetMaxIdleConns sets the maximum number of connections in the idle connection pool. 168 | sqlDb.SetMaxIdleConns(10) 169 | // SetMaxOpenConns sets the maximum number of open connections to the Database. 170 | sqlDb.SetMaxOpenConns(50) 171 | // SetConnMaxLifetime sets the maximum amount of time a connection may be idle. 172 | sqlDb.SetConnMaxIdleTime(10 * time.Minute) 173 | // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. 174 | sqlDb.SetConnMaxLifetime(12 * time.Hour) 175 | 176 | // Initialize the database schema 177 | // WARNING: Order is important. Do not change without database testing 178 | models := []interface{}{&Token{}, &User{}, &Identity{}, &Ephemeral{}, &State{}} 179 | for _, model := range models { 180 | err = db.AutoMigrate(model) 181 | if err != nil { 182 | return nil, err 183 | } 184 | } 185 | 186 | // Build the interface 187 | di := &DatabaseImpl{ 188 | db: db, 189 | } 190 | 191 | jww.INFO.Println("Database backend initialized successfully!") 192 | return database(di), nil 193 | } 194 | -------------------------------------------------------------------------------- /storage/storage.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | "fmt" 12 | "github.com/pkg/errors" 13 | "gitlab.com/elixxir/crypto/hash" 14 | "gitlab.com/xx_network/primitives/id/ephemeral" 15 | "gorm.io/gorm" 16 | "time" 17 | ) 18 | 19 | type Storage struct { 20 | database 21 | notificationBuffer *NotificationBuffer 22 | } 23 | 24 | // NewStorage creates a new Storage object with the given connection parameters 25 | func NewStorage(username, password, dbName, address, port string) (*Storage, error) { 26 | db, err := newDatabase(username, password, dbName, address, port) 27 | nb := NewNotificationBuffer() 28 | storage := &Storage{db, nb} 29 | return storage, err 30 | } 31 | 32 | // RegisterToken registers a token to a user based on their transmission RSA 33 | func (s *Storage) RegisterToken(token, app string, transmissionRSA []byte) error { 34 | transmissionRSAHash, err := getHash(transmissionRSA) 35 | if err != nil { 36 | return errors.WithMessage(err, "Failed to hash transmisssion RSA") 37 | } 38 | 39 | u, err := s.GetUser(transmissionRSAHash) 40 | if err != nil { 41 | if errors.Is(err, gorm.ErrRecordNotFound) { 42 | u = &User{ 43 | TransmissionRSAHash: transmissionRSAHash, 44 | TransmissionRSA: transmissionRSA, 45 | Tokens: []Token{ 46 | {Token: token, TransmissionRSAHash: transmissionRSAHash, App: app}, 47 | }, 48 | } 49 | return s.insertUser(u) 50 | } else { 51 | return err 52 | } 53 | } 54 | 55 | return s.database.insertToken(Token{ 56 | App: app, 57 | Token: token, 58 | TransmissionRSAHash: transmissionRSAHash, 59 | }) 60 | } 61 | 62 | // UnregisterToken token unregisters a token from the user with the passed in RSA 63 | func (s *Storage) UnregisterToken(token string, transmissionRSA []byte) error { 64 | transmissionRSAHash, err := getHash(transmissionRSA) 65 | if err != nil { 66 | return errors.WithMessage(err, "Failed to hash transmisssion RSA") 67 | } 68 | 69 | u, err := 
s.GetUser(transmissionRSAHash) 70 | if err != nil { 71 | if !errors.Is(err, gorm.ErrRecordNotFound) { 72 | return errors.WithMessage(err, "Failed to retrieve user") 73 | } 74 | return nil 75 | } 76 | 77 | err = s.database.unregisterTokens(u, []Token{{Token: token}}) 78 | if err != nil { 79 | if errors.Is(err, gorm.ErrRecordNotFound) { 80 | return nil 81 | } 82 | return err 83 | } 84 | return nil 85 | } 86 | 87 | // RegisterTrackedID registers a tracked ID for the user with the passed in RSA 88 | func (s *Storage) RegisterTrackedID(iidList [][]byte, transmissionRSA []byte, epoch int32, addressSpace uint8) error { 89 | transmissionRSAHash, err := getHash(transmissionRSA) 90 | if err != nil { 91 | return errors.WithMessage(err, "Failed to hash transmisssion RSA") 92 | } 93 | 94 | u, err := s.GetUser(transmissionRSAHash) 95 | if err != nil { 96 | if errors.Is(err, gorm.ErrRecordNotFound) { 97 | u = &User{ 98 | TransmissionRSAHash: transmissionRSAHash, 99 | TransmissionRSA: transmissionRSA, 100 | } 101 | err = s.insertUser(u) 102 | if err != nil { 103 | return errors.WithMessage(err, "Failed to register user") 104 | } 105 | } else { 106 | return errors.WithMessage(err, "Failed to look up user") 107 | } 108 | } 109 | 110 | var ids []Identity 111 | for _, iid := range iidList { 112 | identity, err := s.GetIdentity(iid) 113 | if err != nil { 114 | if errors.Is(err, gorm.ErrRecordNotFound) { 115 | identity = &Identity{ 116 | IntermediaryId: iid, 117 | OffsetNum: ephemeral.GetOffsetNum(ephemeral.GetOffset(iid)), 118 | } 119 | err = s.insertIdentity(identity) 120 | if err != nil { 121 | return err 122 | } 123 | _, err = s.AddLatestEphemeral(identity, epoch, uint(addressSpace)) 124 | if err != nil { 125 | return err 126 | } 127 | } else { 128 | return err 129 | } 130 | } 131 | ids = append(ids, *identity) 132 | } 133 | 134 | return s.database.registerTrackedIdentities(*u, ids) 135 | } 136 | 137 | // UnregisterTrackedIDs unregisters a tracked id from the user with the passed in 
RSA 138 | func (s *Storage) UnregisterTrackedIDs(trackedIdList [][]byte, transmissionRSA []byte) error { 139 | transmissionRSAHash, err := getHash(transmissionRSA) 140 | if err != nil { 141 | return errors.WithMessage(err, "Failed to hash transmisssion RSA") 142 | } 143 | 144 | u, err := s.GetUser(transmissionRSAHash) 145 | if err != nil { 146 | if !errors.Is(err, gorm.ErrRecordNotFound) { 147 | return errors.WithMessage(err, "Failed to retrieve user") 148 | } 149 | return nil 150 | } 151 | 152 | var ids []Identity 153 | for _, i := range trackedIdList { 154 | ids = append(ids, Identity{IntermediaryId: i}) 155 | } 156 | 157 | err = s.database.unregisterIdentities(u, ids) 158 | if err != nil { 159 | if errors.Is(err, gorm.ErrRecordNotFound) { 160 | return nil 161 | } 162 | return err 163 | } 164 | return nil 165 | } 166 | 167 | // RegisterForNotifications registers a user with the passed in transmissionRSA 168 | // to receive notifications on the identity with intermediary id iid, with the passed in token 169 | func (s *Storage) RegisterForNotifications(iid, transmissionRSA []byte, token, app string, epoch int32, addressSpace uint8) (*User, error) { 170 | transmissionRSAHash, err := getHash(transmissionRSA) 171 | if err != nil { 172 | return nil, errors.WithMessage(err, "Failed to hash transmisssion RSA") 173 | } 174 | identity, err := s.GetIdentity(iid) 175 | if err != nil { 176 | if errors.Is(err, gorm.ErrRecordNotFound) { 177 | identity = &Identity{ 178 | IntermediaryId: iid, 179 | OffsetNum: ephemeral.GetOffsetNum(ephemeral.GetOffset(iid)), 180 | } 181 | err = s.insertIdentity(identity) 182 | if err != nil { 183 | return nil, err 184 | } 185 | _, err = s.AddLatestEphemeral(identity, epoch, uint(addressSpace)) 186 | if err != nil { 187 | return nil, err 188 | } 189 | } else { 190 | return nil, err 191 | } 192 | } 193 | 194 | u, err := s.GetUser(transmissionRSAHash) 195 | if err != nil { 196 | if errors.Is(err, gorm.ErrRecordNotFound) { 197 | u = &User{ 198 | 
TransmissionRSAHash: transmissionRSAHash, 199 | TransmissionRSA: transmissionRSA, 200 | Tokens: []Token{ 201 | {Token: token, TransmissionRSAHash: transmissionRSAHash, App: app}, 202 | }, Identities: []Identity{*identity}, 203 | } 204 | return u, s.insertUser(u) 205 | } else { 206 | return nil, err 207 | } 208 | } 209 | 210 | return u, s.registerForNotifications(u, *identity, Token{Token: token, App: app, TransmissionRSAHash: transmissionRSAHash}) 211 | } 212 | 213 | // AddLatestEphemeral generates an ephemeral ID for the passed in identity and adds it to storage 214 | func (s *Storage) AddLatestEphemeral(i *Identity, epoch int32, size uint) (*Ephemeral, error) { 215 | now := time.Now() 216 | eid, _, _, err := ephemeral.GetIdFromIntermediary(i.IntermediaryId, size, now.UnixNano()) 217 | if err != nil { 218 | return nil, errors.WithMessage(err, "Failed to get ephemeral id for user") 219 | } 220 | 221 | e := &Ephemeral{ 222 | IntermediaryId: i.IntermediaryId, 223 | EphemeralId: eid.Int64(), 224 | Epoch: epoch, 225 | } 226 | err = s.insertEphemeral(e) 227 | if err != nil { 228 | return nil, err 229 | } 230 | 231 | eid2, _, _, err := ephemeral.GetIdFromIntermediary(i.IntermediaryId, size, now.Add(5*time.Minute).UnixNano()) 232 | if err != nil { 233 | return nil, errors.WithMessage(err, "Failed to get ephemeral id for user") 234 | } 235 | if eid2.Int64() != eid.Int64() { 236 | e := &Ephemeral{ 237 | IntermediaryId: i.IntermediaryId, 238 | EphemeralId: eid2.Int64(), 239 | Epoch: epoch + 1, 240 | } 241 | fmt.Printf("Adding ephemeral: %+v\n", e) 242 | err = s.insertEphemeral(e) 243 | if err != nil { 244 | return nil, err 245 | } 246 | } 247 | 248 | return e, err 249 | } 250 | 251 | // AddEphemeralsForOffset generates new ephemerals for all identities with the given offset, using the passed in parameters 252 | func (s *Storage) AddEphemeralsForOffset(offset int64, epoch int32, size uint, t time.Time) error { 253 | identities, err := s.getIdentitiesByOffset(offset) 254 | if 
err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 255 | return errors.WithMessage(err, "Failed to get users for given offset") 256 | } 257 | if len(identities) > 0 { 258 | fmt.Println(fmt.Sprintf("Adding ephemerals for identities: %+v", identities)) 259 | } 260 | for _, i := range identities { 261 | eid, _, _, err := ephemeral.GetIdFromIntermediary(i.IntermediaryId, size, t.UnixNano()) 262 | if err != nil { 263 | return errors.WithMessage(err, "Failed to get eid for user") 264 | } 265 | err = s.insertEphemeral(&Ephemeral{ 266 | IntermediaryId: i.IntermediaryId, 267 | EphemeralId: eid.Int64(), 268 | Epoch: epoch, 269 | }) 270 | if err != nil { 271 | return errors.WithMessage(err, "Failed to insert ephemeral ID for user") 272 | } 273 | } 274 | return nil 275 | } 276 | 277 | func (s *Storage) GetNotificationBuffer() *NotificationBuffer { 278 | return s.notificationBuffer 279 | } 280 | 281 | func getHash(transmissionRSA []byte) (transmissionRSAHash []byte, err error) { 282 | h, err := hash.NewCMixHash() 283 | if err != nil { 284 | return 285 | } 286 | _, err = h.Write(transmissionRSA) 287 | if err != nil { 288 | return 289 | } 290 | transmissionRSAHash = h.Sum(nil) 291 | return 292 | } 293 | -------------------------------------------------------------------------------- /storage/storage_test.go: -------------------------------------------------------------------------------- 1 | // ////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // 4 | // // 5 | // 6 | // Use of this source code is governed by a license that can be found in the // 7 | // LICENSE file. 
// 8 | // ////////////////////////////////////////////////////////////////////////////// 9 | package storage 10 | 11 | import ( 12 | "gitlab.com/elixxir/notifications-bot/constants" 13 | "gitlab.com/xx_network/crypto/csprng" 14 | "gitlab.com/xx_network/crypto/signature/rsa" 15 | "gitlab.com/xx_network/primitives/id" 16 | "gitlab.com/xx_network/primitives/id/ephemeral" 17 | "testing" 18 | "time" 19 | ) 20 | 21 | func TestStorage_RegisterToken(t *testing.T) { 22 | s, err := NewStorage("", "", "", "", "") 23 | if err != nil { 24 | t.Fatalf("Failed to create new storage object: %+v", err) 25 | } 26 | 27 | token := "TestToken" 28 | app := "HavenIOS" 29 | trsaPrivate, err := rsa.GenerateKey(csprng.NewSystemRNG(), 512) 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | pub := rsa.CreatePublicKeyPem(trsaPrivate.GetPublic()) 34 | 35 | err = s.RegisterToken(token, app, pub) 36 | if err != nil { 37 | t.Fatalf("Failed to register token: %+v", err) 38 | } 39 | 40 | err = s.RegisterToken(token, app, pub) 41 | if err != nil { 42 | t.Fatalf("Duplicate register token returned unexpected error: %+v", err) 43 | } 44 | } 45 | 46 | func TestStorage_RegisterTrackedID(t *testing.T) { 47 | s, err := NewStorage("", "", "", "", "") 48 | if err != nil { 49 | t.Fatalf("Failed to create new storage object: %+v", err) 50 | } 51 | 52 | token := "TestToken" 53 | app := "HavenIOS" 54 | trsaPrivate, err := rsa.GenerateKey(csprng.NewSystemRNG(), 512) 55 | if err != nil { 56 | t.Fatal(err) 57 | } 58 | pub := rsa.CreatePublicKeyPem(trsaPrivate.GetPublic()) 59 | testId, err := id.NewRandomID(csprng.NewSystemRNG(), id.User) 60 | if err != nil { 61 | t.Fatalf("Failed to generate test ID: %+v", err) 62 | } 63 | iid, err := ephemeral.GetIntermediaryId(testId) 64 | if err != nil { 65 | t.Fatalf("Failed to generate intermediary ID: %+v", err) 66 | } 67 | _, epoch := ephemeral.HandleQuantization(time.Now()) 68 | 69 | err = s.RegisterToken(token, app, pub) 70 | if err != nil { 71 | t.Fatalf("Failed to register 
token: %+v", err) 72 | } 73 | 74 | err = s.RegisterTrackedID([][]byte{iid}, pub, epoch, 16) 75 | if err != nil { 76 | t.Fatalf("Received error registering identity: %+v", err) 77 | } 78 | 79 | err = s.RegisterTrackedID([][]byte{iid}, pub, epoch, 16) 80 | if err != nil { 81 | t.Fatalf("Received unexpected error on duplicate identity registration: %+v", err) 82 | } 83 | } 84 | 85 | func TestStorage_UnregisterToken(t *testing.T) { 86 | s, err := NewStorage("", "", "", "", "") 87 | if err != nil { 88 | t.Fatalf("Failed to create new storage object: %+v", err) 89 | } 90 | 91 | token := "TestToken" 92 | otherToken := "TestToken2" 93 | app := "HavenIOS" 94 | trsaPrivate, err := rsa.GenerateKey(csprng.NewSystemRNG(), 512) 95 | if err != nil { 96 | t.Fatal(err) 97 | } 98 | pub := rsa.CreatePublicKeyPem(trsaPrivate.GetPublic()) 99 | 100 | err = s.UnregisterToken(token, pub) 101 | if err != nil { 102 | t.Fatalf("Received error on unregister with nothing inserted: %+v", err) 103 | } 104 | 105 | err = s.RegisterToken(token, app, pub) 106 | if err != nil { 107 | t.Fatalf("Failed to register token: %+v", err) 108 | } 109 | 110 | err = s.UnregisterToken(otherToken, pub) 111 | if err != nil { 112 | t.Fatalf("Received error on unregister when token not inserted: %+v", err) 113 | } 114 | 115 | err = s.RegisterToken(otherToken, app, pub) 116 | if err != nil { 117 | t.Fatalf("Failed to register second token: %+v", err) 118 | } 119 | 120 | trsaHash, err := getHash(pub) 121 | if err != nil { 122 | t.Fatalf("Failed to get trsa hash: %+v", err) 123 | } 124 | u, err := s.GetUser(trsaHash) 125 | if err != nil { 126 | t.Fatalf("Failed to get user: %+v", err) 127 | } 128 | 129 | if len(u.Tokens) != 2 { 130 | t.Fatalf("Did not receive expected tokens on user") 131 | } 132 | 133 | err = s.UnregisterToken(token, pub) 134 | if err != nil { 135 | t.Fatalf("Received error on unregister: %+v", err) 136 | } 137 | 138 | u, err = s.GetUser(trsaHash) 139 | if err != nil { 140 | t.Fatalf("Failed to get 
user after token deletion: %+v", err) 141 | } 142 | 143 | if len(u.Tokens) != 1 { 144 | t.Fatalf("Tokens on user should have been reduced to 1") 145 | } 146 | 147 | err = s.UnregisterToken(otherToken, pub) 148 | if err != nil { 149 | t.Fatalf("Received error on second token unregister: %+v", err) 150 | } 151 | 152 | u, err = s.GetUser(trsaHash) 153 | if err != nil { 154 | t.Fatalf("User should still exist after unregister, instead got: %+v", err) 155 | } 156 | 157 | } 158 | 159 | func TestStorage_UnregisterTrackedID(t *testing.T) { 160 | s, err := NewStorage("", "", "", "", "") 161 | if err != nil { 162 | t.Fatalf("Failed to create new storage object: %+v", err) 163 | } 164 | 165 | token := "TestToken" 166 | app := "HavenIOS" 167 | trsaPrivate, err := rsa.GenerateKey(csprng.NewSystemRNG(), 512) 168 | if err != nil { 169 | t.Fatal(err) 170 | } 171 | pub := rsa.CreatePublicKeyPem(trsaPrivate.GetPublic()) 172 | testId, err := id.NewRandomID(csprng.NewSystemRNG(), id.User) 173 | if err != nil { 174 | t.Fatalf("Failed to generate test ID: %+v", err) 175 | } 176 | iid, err := ephemeral.GetIntermediaryId(testId) 177 | if err != nil { 178 | t.Fatalf("Failed to generate IID: %+v", err) 179 | } 180 | testId2, err := id.NewRandomID(csprng.NewSystemRNG(), id.User) 181 | if err != nil { 182 | t.Fatalf("Failed to generate test ID: %+v", err) 183 | } 184 | iid2, err := ephemeral.GetIntermediaryId(testId2) 185 | if err != nil { 186 | t.Fatalf("Failed to generate IID: %+v", err) 187 | } 188 | _, epoch := ephemeral.HandleQuantization(time.Now()) 189 | 190 | err = s.UnregisterTrackedIDs([][]byte{iid}, pub) 191 | if err != nil { 192 | t.Fatalf("Error on unregister tracked ID with nothing inserted: %+v", err) 193 | } 194 | 195 | err = s.RegisterToken(token, app, pub) 196 | if err != nil { 197 | t.Fatalf("Failed to register token: %+v", err) 198 | } 199 | 200 | err = s.UnregisterTrackedIDs([][]byte{iid}, pub) 201 | if err != nil { 202 | t.Fatalf("Error on unregister tracked ID with user 
inserted, but no tracked IDs: %+v", err) 203 | } 204 | 205 | err = s.RegisterToken(token, app, pub) 206 | if err != nil { 207 | t.Fatalf("Failed to register token: %+v", err) 208 | } 209 | 210 | err = s.RegisterTrackedID([][]byte{iid}, pub, epoch, 16) 211 | if err != nil { 212 | t.Fatalf("Received error registering identity: %+v", err) 213 | } 214 | 215 | err = s.UnregisterTrackedIDs([][]byte{iid2}, pub) 216 | if err != nil { 217 | t.Fatalf("Error on unregister untracked ID: %+v", err) 218 | } 219 | 220 | err = s.RegisterTrackedID([][]byte{iid2}, pub, epoch, 16) 221 | if err != nil { 222 | t.Fatalf("Received error registering identity: %+v", err) 223 | } 224 | 225 | trsaHash, err := getHash(pub) 226 | if err != nil { 227 | t.Fatalf("Failed to get trsa hash: %+v", err) 228 | } 229 | u, err := s.GetUser(trsaHash) 230 | if err != nil { 231 | t.Fatalf("Failed to get user: %+v", err) 232 | } 233 | 234 | if len(u.Identities) != 2 { 235 | t.Fatalf("Did not receive expected identities for user") 236 | } 237 | 238 | err = s.UnregisterTrackedIDs([][]byte{iid}, pub) 239 | if err != nil { 240 | t.Fatalf("Failed to unregister tracked ID: %+v", err) 241 | } 242 | 243 | u, err = s.GetUser(trsaHash) 244 | if err != nil { 245 | t.Fatalf("Failed to get user after first delete: %+v", err) 246 | } 247 | 248 | if len(u.Identities) != 1 { 249 | t.Fatalf("Identity was not properly deleted") 250 | } 251 | 252 | err = s.UnregisterTrackedIDs([][]byte{iid2}, pub) 253 | if err != nil { 254 | t.Fatalf("Failed to unregister tracked ID: %+v", err) 255 | } 256 | 257 | u, err = s.GetUser(trsaHash) 258 | if err != nil { 259 | t.Fatalf("User should still exist after unregister, instead got: %+v", err) 260 | } 261 | if len(u.Tokens) != 1 { 262 | t.Fatalf("User tokens should be unaffected by unregistering ID") 263 | } 264 | } 265 | 266 | func TestStorage_RegisterForNotifications(t *testing.T) { 267 | s, err := NewStorage("", "", "", "", "") 268 | if err != nil { 269 | t.Errorf("Failed to create new 
storage object: %+v", err) 270 | } 271 | uid := id.NewIdFromString("zezima", id.User, t) 272 | iid, err := ephemeral.GetIntermediaryId(uid) 273 | if err != nil { 274 | t.Errorf("Failed to create iid: %+v", err) 275 | } 276 | if err != nil { 277 | t.Errorf("Could not parse precanned time: %v", err.Error()) 278 | } 279 | _, err = s.RegisterForNotifications(iid, []byte("transmissionrsa"), "token", constants.MessengerIOS.String(), 0, 8) 280 | if err != nil { 281 | t.Errorf("Failed to add user: %+v", err) 282 | } 283 | } 284 | 285 | func TestStorage_AddLatestEphemeral(t *testing.T) { 286 | s, err := NewStorage("", "", "", "", "") 287 | if err != nil { 288 | t.Errorf("Failed to create new storage object: %+v", err) 289 | } 290 | uid := id.NewIdFromString("zezima", id.User, t) 291 | iid, err := ephemeral.GetIntermediaryId(uid) 292 | if err != nil { 293 | t.Errorf("Failed to create iid: %+v", err) 294 | } 295 | if err != nil { 296 | t.Errorf("Could not parse precanned time: %v", err.Error()) 297 | } 298 | ident := &Identity{ 299 | IntermediaryId: iid, 300 | OffsetNum: ephemeral.GetOffsetNum(ephemeral.GetOffset(iid)), 301 | } 302 | err = s.insertIdentity(ident) 303 | if err != nil { 304 | t.Errorf("Failed to add user: %+v", err) 305 | } 306 | _, err = s.AddLatestEphemeral(ident, 5, 16) 307 | if err != nil { 308 | t.Errorf("Failed to add latest ephemeral: %+v", err) 309 | } 310 | } 311 | 312 | func TestStorage_AddEphemeralsForOffset(t *testing.T) { 313 | _, err := NewStorage("", "", "", "", "") 314 | if err != nil { 315 | t.Errorf("Failed to create new storage object: %+v", err) 316 | } 317 | } 318 | -------------------------------------------------------------------------------- /cmd/root.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be 
found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Package cmd initializes the CLI and config parsers as well as the logger 9 | 10 | package cmd 11 | 12 | import ( 13 | "fmt" 14 | "github.com/spf13/cobra" 15 | jww "github.com/spf13/jwalterweatherman" 16 | "github.com/spf13/viper" 17 | "gitlab.com/elixxir/comms/mixmessages" 18 | "gitlab.com/elixxir/notifications-bot/notifications" 19 | "gitlab.com/elixxir/notifications-bot/notifications/providers" 20 | "gitlab.com/elixxir/notifications-bot/storage" 21 | "gitlab.com/xx_network/comms/connect" 22 | "gitlab.com/xx_network/primitives/id" 23 | "gitlab.com/xx_network/primitives/utils" 24 | "net" 25 | "os" 26 | "sync/atomic" 27 | "time" 28 | ) 29 | 30 | var ( 31 | cfgFile, logPath string 32 | verbose bool 33 | noTLS bool 34 | validConfig bool 35 | NotificationParams notifications.Params 36 | loopDelay int 37 | ) 38 | 39 | // rootCmd represents the base command when called without any subcommands 40 | var rootCmd = &cobra.Command{ 41 | Use: "registration", 42 | Short: "Runs a registration server for cMix", 43 | Long: `This server provides registration functions on cMix`, 44 | Args: cobra.NoArgs, 45 | Run: func(cmd *cobra.Command, args []string) { 46 | initConfig() 47 | initLog() 48 | 49 | if verbose { 50 | err := os.Setenv("GRPC_GO_LOG_SEVERITY_LEVEL", "info") 51 | if err != nil { 52 | jww.ERROR.Printf("Could not set GRPC_GO_LOG_SEVERITY_LEVEL: %+v", err) 53 | } 54 | 55 | err = os.Setenv("GRPC_GO_LOG_VERBOSITY_LEVEL", "2") 56 | if err != nil { 57 | jww.ERROR.Printf("Could not set GRPC_GO_LOG_VERBOSITY_LEVEL: %+v", err) 58 | } 59 | } 60 | 61 | // Parse config file options 62 | certPath := viper.GetString("certPath") 63 | keyPath := viper.GetString("keyPath") 64 | localAddress := fmt.Sprintf("0.0.0.0:%d", viper.GetInt("port")) 65 | fbCreds, err := utils.ExpandPath(viper.GetString("firebaseCredentialsPath")) 66 | if err != nil { 67 | 
jww.FATAL.Panicf("Unable to expand credentials path: %+v", err) 68 | } 69 | 70 | havenFbCreds, err := utils.ExpandPath(viper.GetString("havenFirebaseCredentialsPath")) 71 | if err != nil { 72 | jww.FATAL.Panicf("Unable to expand haven credentials path: %+v", err) 73 | } 74 | 75 | apnsKeyPath, err := utils.ExpandPath(viper.GetString("apnsKeyPath")) 76 | if err != nil { 77 | jww.FATAL.Panicf("Unable to expand apns key path: %+v", err) 78 | } 79 | havenApnsKeyPath, err := utils.ExpandPath(viper.GetString("havenApnsKeyPath")) 80 | if err != nil { 81 | jww.FATAL.Panicf("Unable to expand apns key path: %+v", err) 82 | } 83 | 84 | httpsCertPath, err := utils.ExpandPath(viper.GetString("httpsCert")) 85 | if err != nil { 86 | jww.FATAL.Panicf("Unable to expand https key path: %+v", err) 87 | } 88 | httpsKeyPath, err := utils.ExpandPath(viper.GetString("httpsKey")) 89 | if err != nil { 90 | jww.FATAL.Panicf("Unable to expand https cert path: %+v", err) 91 | } 92 | viper.SetDefault("notificationRate", 30) 93 | viper.SetDefault("notificationsPerBatch", 20) 94 | // This is set to approx. 
90% of the stated limit (4096) 95 | viper.SetDefault("maxNotificationPayload", 3686) 96 | // Populate params 97 | NotificationParams = notifications.Params{ 98 | Address: localAddress, 99 | CertPath: certPath, 100 | KeyPath: keyPath, 101 | FBCreds: fbCreds, 102 | NotificationRate: viper.GetInt("notificationRate"), 103 | NotificationsPerBatch: viper.GetInt("notificationsPerBatch"), 104 | MaxNotificationPayload: viper.GetInt("maxNotificationPayload"), 105 | APNS: providers.APNSParams{ 106 | KeyPath: apnsKeyPath, 107 | KeyID: viper.GetString("apnsKeyID"), 108 | Issuer: viper.GetString("apnsIssuer"), 109 | BundleID: viper.GetString("apnsBundleID"), 110 | Dev: viper.GetBool("apnsDev"), 111 | }, 112 | HavenAPNS: providers.APNSParams{ 113 | KeyPath: havenApnsKeyPath, 114 | KeyID: viper.GetString("havenApnsKeyID"), 115 | Issuer: viper.GetString("havenApnsIssuer"), 116 | BundleID: viper.GetString("havenApnsBundleID"), 117 | Dev: viper.GetBool("havenApnsDev"), 118 | }, 119 | HavenFBCreds: havenFbCreds, 120 | HttpsCertPath: httpsCertPath, 121 | HttpsKeyPath: httpsKeyPath, 122 | } 123 | 124 | rawAddr := viper.GetString("dbAddress") 125 | var addr, port string 126 | if rawAddr != "" { 127 | addr, port, err = net.SplitHostPort(rawAddr) 128 | if err != nil { 129 | jww.FATAL.Panicf("Unable to get database port from %s: %+v", rawAddr, err) 130 | } 131 | } 132 | // Initialize the storage backend 133 | s, err := storage.NewStorage( 134 | viper.GetString("dbUsername"), 135 | viper.GetString("dbPassword"), 136 | viper.GetString("dbName"), 137 | addr, 138 | port, 139 | ) 140 | if err != nil { 141 | jww.FATAL.Panicf("Failed to initialize storage: %+v", err) 142 | } 143 | 144 | // Start notifications server 145 | jww.INFO.Println("Starting Notifications...") 146 | impl, err := notifications.StartNotifications(NotificationParams, noTLS, false) 147 | if err != nil { 148 | jww.FATAL.Panicf("Failed to start notifications server: %+v", err) 149 | } 150 | 151 | impl.Storage = s 152 | 153 | // 
Read in permissioning certificate 154 | cert, err := utils.ReadFile(viper.GetString("permissioningCertPath")) 155 | if err != nil { 156 | jww.FATAL.Panicf("Could not read permissioning cert: %+v", err) 157 | } 158 | 159 | // Add host for permissioning server 160 | hostParams := connect.GetDefaultHostParams() 161 | hostParams.AuthEnabled = false 162 | _, err = impl.Comms.AddHost(&id.Permissioning, viper.GetString("permissioningAddress"), cert, hostParams) 163 | if err != nil { 164 | jww.FATAL.Panicf("Failed to Create permissioning host: %+v", err) 165 | } 166 | 167 | // Start ephemeral ID tracking 168 | errChan := make(chan error) 169 | impl.TrackNdf() 170 | for atomic.LoadUint32(impl.ReceivedNdf()) != 1 { 171 | time.Sleep(time.Second) 172 | } 173 | go impl.EphIdCreator() 174 | go impl.EphIdDeleter() 175 | 176 | // Wait forever to prevent process from ending 177 | err = <-errChan 178 | jww.FATAL.Panicf("Notifications loop error received: %+v", err) 179 | }, 180 | } 181 | 182 | // Execute adds all child commands to the root command and sets flags 183 | // appropriately. This is called by main.main(). It only needs to 184 | // happen once to the rootCmd. 185 | func Execute() { 186 | if err := rootCmd.Execute(); err != nil { 187 | jww.ERROR.Println(err) 188 | os.Exit(1) 189 | } 190 | } 191 | 192 | // init is the initialization function for Cobra which defines commands 193 | // and flags. 194 | func init() { 195 | // NOTE: The point of init() is to be declarative. 196 | // There is one init in each sub command. Do not put variable declarations 197 | // here, and ensure all the Flags are of the *P variety, unless there's a 198 | // very good reason not to have them as local params to sub command." 199 | 200 | // Here you will define your flags and configuration settings. 201 | // Cobra supports persistent flags, which, if defined here, 202 | // will be global for your application. 
203 | rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, 204 | "Show verbose logs for debugging") 205 | 206 | rootCmd.Flags().StringVarP(&cfgFile, "config", "c", 207 | "", "Sets a custom config file path") 208 | 209 | rootCmd.Flags().BoolVar(&noTLS, "noTLS", false, 210 | "Runs without TLS enabled") 211 | 212 | rootCmd.Flags().IntVarP(&loopDelay, "loopDelay", "", 500, 213 | "Set the delay between notification loops (in milliseconds)") 214 | 215 | // Bind config and command line flags of the same name 216 | err := viper.BindPFlag("verbose", rootCmd.Flags().Lookup("verbose")) 217 | handleBindingError(err, "verbose") 218 | } 219 | 220 | // Handle flag binding errors 221 | func handleBindingError(err error, flag string) { 222 | if err != nil { 223 | jww.FATAL.Panicf("Error on binding flag \"%s\":%+v", flag, err) 224 | } 225 | } 226 | 227 | // initConfig reads in config file and ENV variables if set. 228 | func initConfig() { 229 | //Use default config location if none is passed 230 | var err error 231 | validConfig = true 232 | if cfgFile == "" { 233 | cfgFile, err = utils.SearchDefaultLocations("notifications.yaml", "xxnetwork") 234 | if err != nil { 235 | validConfig = false 236 | jww.FATAL.Panicf("Failed to find config file: %+v", err) 237 | } 238 | } else { 239 | cfgFile, err = utils.ExpandPath(cfgFile) 240 | if err != nil { 241 | validConfig = false 242 | jww.FATAL.Panicf("Failed to expand config file path: %+v", err) 243 | } 244 | } 245 | 246 | viper.SetConfigFile(cfgFile) 247 | viper.AutomaticEnv() // read in environment variables that match 248 | 249 | // If a config file is found, read it in. 250 | if err := viper.ReadInConfig(); err != nil { 251 | fmt.Printf("Unable to read config file (%s): %+v", cfgFile, err.Error()) 252 | validConfig = false 253 | } 254 | } 255 | 256 | // initLog initializes logging thresholds and the log path. 
257 | func initLog() { 258 | vipLogLevel := viper.GetUint("logLevel") 259 | 260 | // Check the level of logs to display 261 | if vipLogLevel > 1 { 262 | // Set the GRPC log level 263 | err := os.Setenv("GRPC_GO_LOG_SEVERITY_LEVEL", "info") 264 | if err != nil { 265 | jww.ERROR.Printf("Could not set GRPC_GO_LOG_SEVERITY_LEVEL: %+v", err) 266 | } 267 | 268 | err = os.Setenv("GRPC_GO_LOG_VERBOSITY_LEVEL", "99") 269 | if err != nil { 270 | jww.ERROR.Printf("Could not set GRPC_GO_LOG_VERBOSITY_LEVEL: %+v", err) 271 | } 272 | // Turn on trace logs 273 | jww.SetLogThreshold(jww.LevelTrace) 274 | jww.SetStdoutThreshold(jww.LevelTrace) 275 | mixmessages.TraceMode() 276 | } else if vipLogLevel == 1 { 277 | // Turn on debugging logs 278 | jww.SetLogThreshold(jww.LevelDebug) 279 | jww.SetStdoutThreshold(jww.LevelDebug) 280 | mixmessages.DebugMode() 281 | } else { 282 | // Turn on info logs 283 | jww.SetLogThreshold(jww.LevelInfo) 284 | jww.SetStdoutThreshold(jww.LevelInfo) 285 | } 286 | 287 | logPath = viper.GetString("log") 288 | 289 | logFile, err := os.OpenFile(logPath, 290 | os.O_CREATE|os.O_WRONLY|os.O_APPEND, 291 | 0644) 292 | if err != nil { 293 | fmt.Printf("Could not open log file %s!\n", logPath) 294 | } else { 295 | jww.SetLogOutput(logFile) 296 | } 297 | } 298 | -------------------------------------------------------------------------------- /storage/databaseImpl.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Handles implementation of the database backend 9 | 10 | package storage 11 | 12 | import ( 13 | "encoding/base64" 14 | "github.com/pkg/errors" 15 | jww "github.com/spf13/jwalterweatherman" 16 | "gorm.io/gorm" 17 | "gorm.io/gorm/clause" 18 | ) 19 | 20 | // UpsertState inserts the given State into Storage if it does not exist, 21 | // or updates the Database State if its value does not match the given State. 22 | func (d *DatabaseImpl) UpsertState(state *State) error { 23 | jww.TRACE.Printf("Attempting to insert State into DB: %+v", state) 24 | 25 | // Build a transaction to prevent race conditions 26 | return d.db.Transaction(func(tx *gorm.DB) error { 27 | return tx.Clauses(clause.OnConflict{ 28 | Columns: []clause.Column{{Name: "key"}}, 29 | DoUpdates: clause.AssignmentColumns([]string{"value"}), 30 | }).Create(state).Error 31 | }) 32 | } 33 | 34 | // GetStateValue returns a State's value from Storage with the given key 35 | // or an error if a matching State does not exist. 36 | func (d *DatabaseImpl) GetStateValue(key string) (string, error) { 37 | result := &State{Key: key} 38 | err := d.db.Take(result).Error 39 | jww.TRACE.Printf("Obtained State from DB: %+v", result) 40 | return result.Value, err 41 | } 42 | 43 | // DeleteToken deletes the given token from storage. 44 | func (d *DatabaseImpl) DeleteToken(token string) error { 45 | return d.db.Where("token = ?", token).Delete(&Token{Token: token}).Error 46 | } 47 | 48 | // insertUser inserts or updates a User in storage. 49 | func (d *DatabaseImpl) insertUser(user *User) error { 50 | return d.db.Clauses(clause.OnConflict{DoNothing: true}).Create(user).Error 51 | } 52 | 53 | // GetUser retrieves a user from storage with the passed in key. 
54 | func (d *DatabaseImpl) GetUser(transmissionRsaHash []byte) (*User, error) { 55 | u := &User{} 56 | err := d.db.Preload("Identities").Preload("Tokens").Take(u, "transmission_rsa_hash = ?", transmissionRsaHash).Error 57 | if err != nil { 58 | return nil, err 59 | } 60 | return u, nil 61 | } 62 | 63 | // deleteUser removes the User with the passed in key from storage. 64 | func (d *DatabaseImpl) deleteUser(transmissionRsaHash []byte) error { 65 | err := d.db.Delete(&User{ 66 | TransmissionRSAHash: transmissionRsaHash, 67 | }).Error 68 | if err != nil { 69 | return errors.Errorf("Failed to delete user with tRSA hash %s: %+v", transmissionRsaHash, err) 70 | } 71 | return nil 72 | } 73 | 74 | // GetAllUsers returns a list of all users in storage. 75 | func (d *DatabaseImpl) GetAllUsers() ([]*User, error) { 76 | var dest []*User 77 | return dest, d.db.Find(&dest).Error 78 | } 79 | 80 | // GetIdentity retrieves an Identity from storage by primary key. 81 | func (d *DatabaseImpl) GetIdentity(iid []byte) (*Identity, error) { 82 | i := &Identity{} 83 | err := d.db.Preload("Users").Take(i, "intermediary_id = ?", iid).Error 84 | if err != nil { 85 | return nil, err 86 | } 87 | return i, nil 88 | } 89 | 90 | // insertIdentity adds an identity to storage. 91 | func (d *DatabaseImpl) insertIdentity(identity *Identity) error { 92 | return d.db.Clauses(clause.OnConflict{ 93 | DoNothing: true, 94 | }).Create(identity).Error 95 | } 96 | 97 | // getIdentitiesByOffset returns a list of all identities with the given offset. 98 | func (d *DatabaseImpl) getIdentitiesByOffset(offset int64) ([]*Identity, error) { 99 | var result []*Identity 100 | err := d.db.Where(&Identity{OffsetNum: offset}).Find(&result).Error 101 | return result, err 102 | } 103 | 104 | // GetOrphanedIdentities returns a list of identities with no associated ephemerals. 
105 | func (d *DatabaseImpl) GetOrphanedIdentities() ([]*Identity, error) { 106 | var dest []*Identity 107 | return dest, d.db.Find(&dest, "NOT EXISTS (select * from ephemerals where ephemerals.intermediary_id = identities.intermediary_id)").Error 108 | } 109 | 110 | // insertEphemeral inserts an Ephemeral into storage. 111 | func (d *DatabaseImpl) insertEphemeral(ephemeral *Ephemeral) error { 112 | return d.db.Create(&ephemeral).Error 113 | } 114 | 115 | // GetEphemeral retrieves a list of ephemerals with the given ID. 116 | func (d *DatabaseImpl) GetEphemeral(ephemeralId int64) ([]*Ephemeral, error) { 117 | var result []*Ephemeral 118 | err := d.db.Where("ephemeral_id = ?", ephemeralId).Find(&result).Error 119 | if err != nil { 120 | return nil, err 121 | } 122 | if len(result) < 1 { 123 | return nil, gorm.ErrRecordNotFound 124 | } 125 | return result, nil 126 | } 127 | 128 | // GTNResult is a type wrapping the custom query for GetToNotify. 129 | type GTNResult struct { 130 | Token string 131 | App string 132 | TransmissionRSAHash []byte 133 | EphemeralId int64 134 | } 135 | 136 | // The following struct can be used to scan in the intermediary result tables t1 and t2 137 | //type T1Result struct { 138 | // EphemeralId int64 139 | // IntermediaryId []byte 140 | //} 141 | // 142 | //type T2Result struct { 143 | // EphemeralId int64 144 | // TransmissionRsaHash []byte 145 | //} 146 | // 147 | //type T3Result struct { 148 | // TransmissionRsaHash []byte 149 | // EphemeralId int64 150 | //} 151 | 152 | // GetToNotify returns a list of GTNResult data matching the list of ephemeral IDs passed in. 
153 | func (d *DatabaseImpl) GetToNotify(ephemeralIds []int64) ([]GTNResult, error) { 154 | var result []GTNResult 155 | err := d.db.Transaction(func(tx *gorm.DB) error { 156 | t1 := tx.Table("identities").Select("ephemerals.ephemeral_id, identities.intermediary_id").Joins("inner join ephemerals on ephemerals.intermediary_id = identities.intermediary_id").Where("ephemerals.ephemeral_id in ?", ephemeralIds) 157 | t2 := tx.Table("user_identities").Select("t1.ephemeral_id, user_identities.user_transmission_rsa_hash as transmission_rsa_hash").Joins("right join (?) as t1 on t1.intermediary_id = user_identities.identity_intermediary_id", t1) 158 | t3 := tx.Model(&User{}).Select("users.transmission_rsa_hash, t2.ephemeral_id").Joins("right join (?) as t2 on users.transmission_rsa_hash = t2.transmission_rsa_hash", t2) 159 | return tx.Model(&Token{}).Distinct().Select("tokens.token, tokens.app, t3.transmission_rsa_hash, t3.ephemeral_id").Joins("right join (?) as t3 on tokens.transmission_rsa_hash = t3.transmission_rsa_hash", t3).Scan(&result).Error 160 | }) 161 | return result, err 162 | } 163 | 164 | // DeleteOldEphemerals deletes all ephemerals from storage with an epoch before the passed in value. 165 | func (d *DatabaseImpl) DeleteOldEphemerals(currentEpoch int32) error { 166 | res := d.db.Where("epoch < ?", currentEpoch).Delete(&Ephemeral{}) 167 | return res.Error 168 | } 169 | 170 | // GetLatestEphemeral retrieves an ephemeral with the highest epoch from storage. 171 | func (d *DatabaseImpl) GetLatestEphemeral() (*Ephemeral, error) { 172 | var result []*Ephemeral 173 | err := d.db.Order("epoch desc").Limit(1).Find(&result).Error 174 | if err != nil { 175 | return nil, err 176 | } 177 | if len(result) < 1 { 178 | return nil, gorm.ErrRecordNotFound 179 | } 180 | return result[0], nil 181 | } 182 | 183 | // registerForNotifications is primarily used for legacy calls. 184 | // It links an extant user with the given identity and token. 
185 | func (d *DatabaseImpl) registerForNotifications(u *User, identity Identity, token Token) error { 186 | return d.db.Transaction(func(tx *gorm.DB) error { 187 | err := tx.Model(u).Association("Identities").Append(&identity) 188 | if err != nil { 189 | return errors.WithMessage(err, "Failed to register identity") 190 | } 191 | 192 | err = tx.Model(u).Association("Tokens").Append(&token) 193 | if err != nil { 194 | return errors.WithMessage(err, "Failed to register token") 195 | } 196 | return nil 197 | }) 198 | } 199 | 200 | // unregisterIdentities deletes all given identities from the given user. 201 | // It does not remove the user or the identities, just the association. 202 | func (d *DatabaseImpl) unregisterIdentities(u *User, iids []Identity) error { 203 | return d.db.Transaction(func(tx *gorm.DB) error { 204 | err := tx.Model(&u).Association("Identities").Delete(iids) 205 | if err != nil { 206 | return errors.WithMessage(err, "Failed to break association") 207 | } 208 | // This code will clean up users and identities with no associations 209 | // it has been intentionally disabled 210 | // TODO: long-running cleanup thread for identities? 
211 | // it has been intentionally disabled 212 | //for _, iid := range iids { 213 | // var count int64 214 | // err = tx.Table("user_identities").Where("identity_intermediary_id = ?", iid.IntermediaryId).Count(&count).Error 215 | // if err != nil { 216 | // return errors.WithMessagef(err, "Failed count user_identities for identity %+v", iid.IntermediaryId) 217 | // } 218 | // if count == 0 { 219 | // err = tx.Delete(iid).Error 220 | // if err != nil { 221 | // return errors.WithMessage(err, "Failed to delete identity") 222 | // } 223 | // } 224 | // 225 | // err = tx.Table("user_identities").Where("user_transmission_rsa_hash = ?", u.TransmissionRSAHash).Count(&count).Error 226 | // if err != nil { 227 | // return errors.WithMessagef(err, "Failed to count user_identities for user %+v", u.TransmissionRSAHash) 228 | // } 229 | // if count == 0 { 230 | // err = tx.Delete(u).Error 231 | // if err != nil { 232 | // return errors.WithMessage(err, "Failed to delete user") 233 | // } 234 | // } 235 | //} 236 | return nil 237 | }) 238 | } 239 | 240 | // unregisterTokens deletes all given tokens from the passed in user. 241 | // It does not remove the tokens or user, just their association. 
242 | func (d *DatabaseImpl) unregisterTokens(u *User, tokens []Token) error { 243 | return d.db.Transaction(func(tx *gorm.DB) error { 244 | for _, t := range tokens { 245 | err := tx.Delete(t).Error 246 | if err != nil { 247 | return errors.WithMessage(err, "Failed to delete token") 248 | } 249 | } 250 | 251 | // This code will delete the user if the unregistered token is its last 252 | // it has been intentionally disabled 253 | //count := tx.Model(u).Association("Tokens").Count() 254 | // 255 | //if count == 0 { 256 | // err := tx.Model(&u).Association("Identities").Clear() 257 | // if err != nil { 258 | // return errors.WithMessage(err, "Failed to prep user for delete") 259 | // } 260 | // err = tx.Delete(&u).Error 261 | // if err != nil { 262 | // return errors.WithMessage(err, "Failed to delete user") 263 | // } 264 | //} 265 | return nil 266 | }) 267 | } 268 | 269 | // LegacyUnregister is a function to mimic the old unregister logic. 270 | // It will delete a user and identity if they have a 1:1 relationship. 
271 | func (d *DatabaseImpl) LegacyUnregister(iid []byte) error { 272 | return d.db.Transaction(func(tx *gorm.DB) error { 273 | var res Identity 274 | err := tx.Preload("Users").Find(&res, "intermediary_id = ?", iid).Error 275 | if err != nil { 276 | return err 277 | } 278 | if len(res.Users) > 1 { 279 | return errors.Errorf("legacyUnregister can only be called for identities with a single associated user") 280 | } 281 | 282 | err = tx.Model(&Identity{IntermediaryId: iid}).Association("Users").Clear() 283 | if err != nil { 284 | return errors.WithMessage(err, "Failed to break association") 285 | } 286 | 287 | err = tx.Delete(&Identity{IntermediaryId: iid}).Error 288 | if err != nil { 289 | return errors.WithMessage(err, "Failed to delete identity") 290 | } 291 | err = tx.Delete(&User{TransmissionRSAHash: res.Users[0].TransmissionRSAHash}).Error 292 | if err != nil { 293 | return errors.WithMessage(err, "Failed to delete user") 294 | } 295 | return nil 296 | }) 297 | } 298 | 299 | // insertToken adds a token to storage. 300 | func (d *DatabaseImpl) insertToken(token Token) error { 301 | return d.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&token).Error 302 | } 303 | 304 | // registerTrackedIdentity links an Identity to a User. 
305 | func (d *DatabaseImpl) registerTrackedIdentity(user User, identity Identity) error { 306 | return d.db.Model(&user).Association("Identities").Append(&identity) 307 | } 308 | 309 | func (d *DatabaseImpl) registerTrackedIdentities(user User, ids []Identity) error { 310 | return d.db.Transaction(func(tx *gorm.DB) error { 311 | for _, iid := range ids { 312 | err := tx.Model(&user).Association("Identities").Append(&iid) 313 | if err != nil { 314 | return errors.WithMessagef(err, "Failed to register identity %s to user with transmission RSA hash %s", 315 | base64.StdEncoding.EncodeToString(iid.IntermediaryId), base64.StdEncoding.EncodeToString(user.TransmissionRSAHash)) 316 | } 317 | } 318 | return nil 319 | }) 320 | } 321 | -------------------------------------------------------------------------------- /notifications/registration_test.go: -------------------------------------------------------------------------------- 1 | package notifications 2 | 3 | import ( 4 | "gitlab.com/elixxir/comms/mixmessages" 5 | "gitlab.com/elixxir/crypto/notifications" 6 | "gitlab.com/elixxir/crypto/registration" 7 | rsa2 "gitlab.com/elixxir/crypto/rsa" 8 | "gitlab.com/elixxir/notifications-bot/constants" 9 | "gitlab.com/elixxir/notifications-bot/storage" 10 | "gitlab.com/xx_network/comms/connect" 11 | "gitlab.com/xx_network/crypto/csprng" 12 | "gitlab.com/xx_network/crypto/signature/rsa" 13 | "gitlab.com/xx_network/primitives/id" 14 | "gitlab.com/xx_network/primitives/id/ephemeral" 15 | "gitlab.com/xx_network/primitives/utils" 16 | "os" 17 | "testing" 18 | "time" 19 | ) 20 | 21 | func TestImpl_RegisterToken(t *testing.T) { 22 | impl := getNewImpl() 23 | var err error 24 | impl.Storage, err = storage.NewStorage("", "", "", "", "") 25 | if err != nil { 26 | t.Errorf("Failed to create storage: %+v", err) 27 | } 28 | wd, err := os.Getwd() 29 | if err != nil { 30 | t.Errorf("Failed to get working dir: %+v", err) 31 | } 32 | 33 | permCert, err := utils.ReadFile(wd + 
"/../testutil/cmix.rip.crt") 34 | if err != nil { 35 | t.Errorf("Failed to read test cert file: %+v", err) 36 | } 37 | _, err = impl.Comms.AddHost(&id.Permissioning, "0.0.0.0", permCert, connect.GetDefaultHostParams()) 38 | if err != nil { 39 | t.Errorf("Failed to add host: %+v", err) 40 | } 41 | permKey, err := utils.ReadFile(wd + "/../testutil/cmix.rip.key") 42 | if err != nil { 43 | t.Errorf("Failed to read test key file: %+v", err) 44 | } 45 | private, err := rsa2.GetScheme().Generate(csprng.NewSystemRNG(), 4096) 46 | if err != nil { 47 | t.Errorf("Failed to create private key: %+v", err) 48 | } 49 | public := private.Public() 50 | 51 | crt := public.MarshalPem() 52 | //uid := id.NewIdFromString("zezima", id.User, t) 53 | ////iid, err := ephemeral.GetIntermediaryId(uid) 54 | ////if err != nil { 55 | //// t.Errorf("Failed to get intermediary ID: %+v", err) 56 | ////} 57 | loadedPermKey, err := rsa.LoadPrivateKeyFromPem(permKey) 58 | if err != nil { 59 | t.Errorf("Failed to load perm key from bytes: %+v", err) 60 | } 61 | ts := time.Now().UnixNano() 62 | psig, err := registration.SignWithTimestamp(csprng.NewSystemRNG(), loadedPermKey, ts, string(crt)) 63 | 64 | token := "testtoken" 65 | reqTs := time.Now() 66 | sig, err := notifications.SignToken(private, token, constants.MessengerAndroid.String(), reqTs, notifications.RegisterTokenTag, csprng.NewSystemRNG()) 67 | 68 | err = impl.RegisterToken(&mixmessages.RegisterTokenRequest{ 69 | App: constants.MessengerAndroid.String(), 70 | Token: token, 71 | TransmissionRsaPem: crt, 72 | RegistrationTimestamp: ts, 73 | TransmissionRsaRegistrarSig: []byte("whoops"), 74 | RequestTimestamp: reqTs.UnixNano(), 75 | TokenSignature: sig, 76 | }) 77 | if err == nil { 78 | t.Fatal("Expected error verifying perm sig") 79 | } 80 | 81 | err = impl.RegisterToken(&mixmessages.RegisterTokenRequest{ 82 | App: constants.MessengerAndroid.String(), 83 | Token: token, 84 | TransmissionRsaPem: crt, 85 | RegistrationTimestamp: ts, 86 | 
TransmissionRsaRegistrarSig: psig, 87 | RequestTimestamp: reqTs.UnixNano(), 88 | TokenSignature: []byte("whoops"), 89 | }) 90 | if err == nil { 91 | t.Fatal("Expected error verifying token sig") 92 | } 93 | 94 | err = impl.RegisterToken(&mixmessages.RegisterTokenRequest{ 95 | App: constants.MessengerAndroid.String(), 96 | Token: token, 97 | TransmissionRsaPem: crt, 98 | RegistrationTimestamp: ts, 99 | TransmissionRsaRegistrarSig: psig, 100 | RequestTimestamp: reqTs.UnixNano(), 101 | TokenSignature: sig, 102 | }) 103 | if err != nil { 104 | t.Fatal(err) 105 | } 106 | } 107 | 108 | func TestImpl_RegisterTrackedID(t *testing.T) { 109 | impl := getNewImpl() 110 | var err error 111 | impl.Storage, err = storage.NewStorage("", "", "", "", "") 112 | if err != nil { 113 | t.Errorf("Failed to create storage: %+v", err) 114 | } 115 | wd, err := os.Getwd() 116 | if err != nil { 117 | t.Errorf("Failed to get working dir: %+v", err) 118 | } 119 | 120 | permCert, err := utils.ReadFile(wd + "/../testutil/cmix.rip.crt") 121 | if err != nil { 122 | t.Errorf("Failed to read test cert file: %+v", err) 123 | } 124 | _, err = impl.Comms.AddHost(&id.Permissioning, "0.0.0.0", permCert, connect.GetDefaultHostParams()) 125 | if err != nil { 126 | t.Errorf("Failed to add host: %+v", err) 127 | } 128 | permKey, err := utils.ReadFile(wd + "/../testutil/cmix.rip.key") 129 | if err != nil { 130 | t.Errorf("Failed to read test key file: %+v", err) 131 | } 132 | private, err := rsa2.GetScheme().Generate(csprng.NewSystemRNG(), 4096) 133 | if err != nil { 134 | t.Errorf("Failed to create private key: %+v", err) 135 | } 136 | public := private.Public() 137 | 138 | crt := public.MarshalPem() 139 | uid := id.NewIdFromString("zezima", id.User, t) 140 | iid, err := ephemeral.GetIntermediaryId(uid) 141 | if err != nil { 142 | t.Errorf("Failed to get intermediary ID: %+v", err) 143 | } 144 | loadedPermKey, err := rsa.LoadPrivateKeyFromPem(permKey) 145 | if err != nil { 146 | t.Errorf("Failed to load perm 
key from bytes: %+v", err) 147 | } 148 | ts := time.Now().UnixNano() 149 | psig, err := registration.SignWithTimestamp(csprng.NewSystemRNG(), loadedPermKey, ts, string(crt)) 150 | 151 | token := "testtoken" 152 | reqTs := time.Now() 153 | tokenSig, err := notifications.SignToken(private, token, constants.MessengerAndroid.String(), reqTs, notifications.RegisterTokenTag, csprng.NewSystemRNG()) 154 | 155 | err = impl.RegisterToken(&mixmessages.RegisterTokenRequest{ 156 | App: constants.MessengerAndroid.String(), 157 | Token: token, 158 | TransmissionRsaPem: crt, 159 | RegistrationTimestamp: ts, 160 | TransmissionRsaRegistrarSig: psig, 161 | RequestTimestamp: reqTs.UnixNano(), 162 | TokenSignature: tokenSig, 163 | }) 164 | if err != nil { 165 | t.Fatal(err) 166 | } 167 | 168 | iidSig, err := notifications.SignIdentity(private, [][]byte{iid}, reqTs, notifications.RegisterTrackedIDTag, csprng.NewSystemRNG()) 169 | 170 | err = impl.RegisterTrackedID(&mixmessages.RegisterTrackedIdRequest{ 171 | Request: &mixmessages.TrackedIntermediaryIdRequest{ 172 | TrackedIntermediaryID: [][]byte{iid}, 173 | TransmissionRsaPem: crt, 174 | RequestTimestamp: reqTs.UnixNano(), 175 | Signature: nil, 176 | }, 177 | RegistrationTimestamp: ts, 178 | TransmissionRsaRegistrarSig: psig, 179 | }) 180 | if err == nil { 181 | t.Fatal("Expected error verifying tracked ID sig") 182 | } 183 | 184 | err = impl.RegisterTrackedID(&mixmessages.RegisterTrackedIdRequest{ 185 | Request: &mixmessages.TrackedIntermediaryIdRequest{ 186 | TrackedIntermediaryID: [][]byte{iid}, 187 | TransmissionRsaPem: crt, 188 | RequestTimestamp: reqTs.UnixNano(), 189 | Signature: iidSig, 190 | }, 191 | RegistrationTimestamp: ts, 192 | TransmissionRsaRegistrarSig: psig, 193 | }) 194 | if err != nil { 195 | t.Fatal(err) 196 | } 197 | } 198 | 199 | func TestImpl_UnregisterToken(t *testing.T) { 200 | impl := getNewImpl() 201 | var err error 202 | impl.Storage, err = storage.NewStorage("", "", "", "", "") 203 | if err != nil { 204 | 
t.Errorf("Failed to create storage: %+v", err) 205 | } 206 | wd, err := os.Getwd() 207 | if err != nil { 208 | t.Errorf("Failed to get working dir: %+v", err) 209 | } 210 | 211 | permCert, err := utils.ReadFile(wd + "/../testutil/cmix.rip.crt") 212 | if err != nil { 213 | t.Errorf("Failed to read test cert file: %+v", err) 214 | } 215 | _, err = impl.Comms.AddHost(&id.Permissioning, "0.0.0.0", permCert, connect.GetDefaultHostParams()) 216 | if err != nil { 217 | t.Errorf("Failed to add host: %+v", err) 218 | } 219 | permKey, err := utils.ReadFile(wd + "/../testutil/cmix.rip.key") 220 | if err != nil { 221 | t.Errorf("Failed to read test key file: %+v", err) 222 | } 223 | private, err := rsa2.GetScheme().Generate(csprng.NewSystemRNG(), 4096) 224 | if err != nil { 225 | t.Errorf("Failed to create private key: %+v", err) 226 | } 227 | public := private.Public() 228 | 229 | crt := public.MarshalPem() 230 | //uid := id.NewIdFromString("zezima", id.User, t) 231 | ////iid, err := ephemeral.GetIntermediaryId(uid) 232 | ////if err != nil { 233 | //// t.Errorf("Failed to get intermediary ID: %+v", err) 234 | ////} 235 | loadedPermKey, err := rsa.LoadPrivateKeyFromPem(permKey) 236 | if err != nil { 237 | t.Errorf("Failed to load perm key from bytes: %+v", err) 238 | } 239 | ts := time.Now().UnixNano() 240 | psig, err := registration.SignWithTimestamp(csprng.NewSystemRNG(), loadedPermKey, ts, string(crt)) 241 | 242 | token := "testtoken" 243 | reqTs := time.Now() 244 | sig, err := notifications.SignToken(private, token, constants.MessengerAndroid.String(), reqTs, notifications.RegisterTokenTag, csprng.NewSystemRNG()) 245 | 246 | err = impl.RegisterToken(&mixmessages.RegisterTokenRequest{ 247 | App: constants.MessengerAndroid.String(), 248 | Token: token, 249 | TransmissionRsaPem: crt, 250 | RegistrationTimestamp: ts, 251 | TransmissionRsaRegistrarSig: psig, 252 | RequestTimestamp: reqTs.UnixNano(), 253 | TokenSignature: sig, 254 | }) 255 | if err != nil { 256 | t.Fatal(err) 
257 | } 258 | 259 | err = impl.UnregisterToken(&mixmessages.UnregisterTokenRequest{ 260 | App: constants.MessengerAndroid.String(), 261 | Token: token, 262 | TransmissionRsaPem: crt, 263 | RequestTimestamp: reqTs.UnixNano(), 264 | TokenSignature: sig, 265 | }) 266 | if err == nil { 267 | t.Fatal("Expected error verifying register signature") 268 | } 269 | 270 | unregSig, err := notifications.SignToken(private, token, constants.MessengerAndroid.String(), reqTs, notifications.UnregisterTokenTag, csprng.NewSystemRNG()) 271 | err = impl.UnregisterToken(&mixmessages.UnregisterTokenRequest{ 272 | App: constants.MessengerAndroid.String(), 273 | Token: token, 274 | TransmissionRsaPem: crt, 275 | RequestTimestamp: reqTs.UnixNano(), 276 | TokenSignature: unregSig, 277 | }) 278 | if err != nil { 279 | t.Fatal(err) 280 | } 281 | } 282 | 283 | func TestImpl_UnregisterTrackedID(t *testing.T) { 284 | impl := getNewImpl() 285 | var err error 286 | impl.Storage, err = storage.NewStorage("", "", "", "", "") 287 | if err != nil { 288 | t.Errorf("Failed to create storage: %+v", err) 289 | } 290 | wd, err := os.Getwd() 291 | if err != nil { 292 | t.Errorf("Failed to get working dir: %+v", err) 293 | } 294 | 295 | permCert, err := utils.ReadFile(wd + "/../testutil/cmix.rip.crt") 296 | if err != nil { 297 | t.Errorf("Failed to read test cert file: %+v", err) 298 | } 299 | _, err = impl.Comms.AddHost(&id.Permissioning, "0.0.0.0", permCert, connect.GetDefaultHostParams()) 300 | if err != nil { 301 | t.Errorf("Failed to add host: %+v", err) 302 | } 303 | permKey, err := utils.ReadFile(wd + "/../testutil/cmix.rip.key") 304 | if err != nil { 305 | t.Errorf("Failed to read test key file: %+v", err) 306 | } 307 | private, err := rsa2.GetScheme().Generate(csprng.NewSystemRNG(), 4096) 308 | if err != nil { 309 | t.Errorf("Failed to create private key: %+v", err) 310 | } 311 | public := private.Public() 312 | 313 | crt := public.MarshalPem() 314 | uid := id.NewIdFromString("zezima", id.User, t) 
315 | iid, err := ephemeral.GetIntermediaryId(uid) 316 | if err != nil { 317 | t.Errorf("Failed to get intermediary ID: %+v", err) 318 | } 319 | loadedPermKey, err := rsa.LoadPrivateKeyFromPem(permKey) 320 | if err != nil { 321 | t.Errorf("Failed to load perm key from bytes: %+v", err) 322 | } 323 | ts := time.Now().UnixNano() 324 | psig, err := registration.SignWithTimestamp(csprng.NewSystemRNG(), loadedPermKey, ts, string(crt)) 325 | 326 | token := "testtoken" 327 | reqTs := time.Now() 328 | tokenSig, err := notifications.SignToken(private, token, constants.MessengerAndroid.String(), reqTs, notifications.RegisterTokenTag, csprng.NewSystemRNG()) 329 | 330 | err = impl.RegisterToken(&mixmessages.RegisterTokenRequest{ 331 | App: constants.MessengerAndroid.String(), 332 | Token: token, 333 | TransmissionRsaPem: crt, 334 | RegistrationTimestamp: ts, 335 | TransmissionRsaRegistrarSig: psig, 336 | RequestTimestamp: reqTs.UnixNano(), 337 | TokenSignature: tokenSig, 338 | }) 339 | if err != nil { 340 | t.Fatal(err) 341 | } 342 | 343 | iidSig, err := notifications.SignIdentity(private, [][]byte{iid}, reqTs, notifications.RegisterTrackedIDTag, csprng.NewSystemRNG()) 344 | 345 | err = impl.RegisterTrackedID(&mixmessages.RegisterTrackedIdRequest{ 346 | Request: &mixmessages.TrackedIntermediaryIdRequest{ 347 | TrackedIntermediaryID: [][]byte{iid}, 348 | TransmissionRsaPem: crt, 349 | RequestTimestamp: reqTs.UnixNano(), 350 | Signature: iidSig, 351 | }, 352 | RegistrationTimestamp: ts, 353 | TransmissionRsaRegistrarSig: psig, 354 | }) 355 | if err != nil { 356 | t.Fatal(err) 357 | } 358 | 359 | err = impl.UnregisterTrackedID(&mixmessages.TrackedIntermediaryIdRequest{ 360 | TrackedIntermediaryID: [][]byte{iid}, 361 | TransmissionRsaPem: crt, 362 | RequestTimestamp: reqTs.UnixNano(), 363 | Signature: iidSig, 364 | }) 365 | if err == nil { 366 | t.Fatal("Expected err attempting to unregister with same sig") 367 | } 368 | 369 | unregSig, err := notifications.SignIdentity(private, 
[][]byte{iid}, reqTs, notifications.UnregisterTrackedIDTag, csprng.NewSystemRNG()) 370 | err = impl.UnregisterTrackedID(&mixmessages.TrackedIntermediaryIdRequest{ 371 | TrackedIntermediaryID: [][]byte{iid}, 372 | TransmissionRsaPem: crt, 373 | RequestTimestamp: reqTs.UnixNano(), 374 | Signature: unregSig, 375 | }) 376 | if err != nil { 377 | t.Fatal(err) 378 | } 379 | } 380 | -------------------------------------------------------------------------------- /storage/databaseImpl_test.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "gitlab.com/elixxir/crypto/hash" 7 | "gitlab.com/elixxir/notifications-bot/constants" 8 | "gitlab.com/xx_network/crypto/csprng" 9 | "gitlab.com/xx_network/crypto/signature/rsa" 10 | "gitlab.com/xx_network/primitives/id" 11 | "gitlab.com/xx_network/primitives/id/ephemeral" 12 | "gorm.io/gorm" 13 | "testing" 14 | ) 15 | 16 | func TestDatabaseImpl_UpsertState(t *testing.T) { 17 | db, err := newDatabase("", "", "TestDatabaseImpl_UpsertState", "", "") 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | expectedState := &State{ 22 | Key: "state_key", 23 | Value: "state_val", 24 | } 25 | err = db.UpsertState(expectedState) 26 | if err != nil { 27 | t.Fatal(err) 28 | } 29 | 30 | retrievedState, err := db.GetStateValue(expectedState.Key) 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | 35 | if retrievedState != expectedState.Value { 36 | t.Fatalf("Did not get expected state value\n\tExpected: %s\n\tReceived: %s\n", expectedState.Value, retrievedState) 37 | } 38 | 39 | expectedState2 := &State{ 40 | Key: expectedState.Key, 41 | Value: "state_value_two", 42 | } 43 | err = db.UpsertState(expectedState2) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | 48 | retrievedState, err = db.GetStateValue(expectedState.Key) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | if retrievedState != expectedState2.Value { 54 | t.Fatalf("State value did not 
change after upsert\n\tExpected: %s\n\tReceived: %s\n", expectedState2.Value, retrievedState) 55 | } 56 | } 57 | 58 | func TestDatabaseImpl_GetStateValue(t *testing.T) { 59 | db, err := newDatabase("", "", "TestDatabaseImpl_GetStateValue", "", "") 60 | if err != nil { 61 | t.Fatal(err) 62 | } 63 | expectedState := &State{ 64 | Key: "state_key", 65 | Value: "state_val", 66 | } 67 | 68 | retrievedState, err := db.GetStateValue(expectedState.Key) 69 | if err == nil { 70 | t.Fatalf("Should have received error when state not inserted, instead got %s", retrievedState) 71 | } 72 | if retrievedState != "" { 73 | t.Fatal("Should not have received data for state not yet inserted") 74 | } 75 | 76 | err = db.UpsertState(expectedState) 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | 81 | retrievedState, err = db.GetStateValue(expectedState.Key) 82 | if err != nil { 83 | t.Fatal(err) 84 | } 85 | 86 | if retrievedState != expectedState.Value { 87 | t.Fatalf("Did not receive expected value\n\tExpected: %s\n\tReceived: %s\n", expectedState.Value, retrievedState) 88 | } 89 | } 90 | 91 | func TestDatabaseImpl_DeleteToken(t *testing.T) { 92 | db, err := newDatabase("", "", "TestDatabaseImpl_DeleteToken", "", "") 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | 97 | identity := generateTestIdentity(t) 98 | u := generateTestUser(t) 99 | 100 | err = db.insertUser(u) 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | 105 | err = db.insertIdentity(&identity) 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | 110 | token := "apnstoken01" 111 | err = db.registerForNotifications(u, identity, Token{ 112 | Token: token, 113 | App: constants.MessengerIOS.String(), 114 | TransmissionRSAHash: u.TransmissionRSAHash, 115 | }) 116 | if err != nil { 117 | t.Fatal(err) 118 | } 119 | 120 | receivedUser, err := db.GetUser(u.TransmissionRSAHash) 121 | if err != nil { 122 | t.Fatal(err) 123 | } 124 | 125 | if len(receivedUser.Tokens) != 1 { 126 | t.Fatalf("User should have %d tokens 
registered, instead had %d", 1, len(receivedUser.Tokens)) 127 | } 128 | 129 | err = db.DeleteToken(token) 130 | if err != nil { 131 | t.Fatal(err) 132 | } 133 | 134 | receivedUser, err = db.GetUser(u.TransmissionRSAHash) 135 | if err != nil { 136 | t.Fatal(err) 137 | } 138 | 139 | if len(receivedUser.Tokens) != 0 { 140 | t.Fatalf("User should have %d tokens registered, instead had %d", 1, len(receivedUser.Tokens)) 141 | } 142 | } 143 | 144 | func TestDatabaseImpl_insertUser(t *testing.T) { 145 | db, err := newDatabase("", "", "TestDatabaseImpl_insertUser", "", "") 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | u := generateTestUser(t) 151 | 152 | err = db.insertUser(u) 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | 157 | receivedUser, err := db.GetUser(u.TransmissionRSAHash) 158 | if err != nil { 159 | t.Fatal(err) 160 | } 161 | if !bytes.Equal(u.TransmissionRSA, receivedUser.TransmissionRSA) || !bytes.Equal(u.TransmissionRSAHash, receivedUser.TransmissionRSAHash) { 162 | t.Fatalf("Did not receive expected user data\n\tExpected: %+v\n\tReceived: %+v\n", u, receivedUser) 163 | } 164 | 165 | err = db.insertUser(u) 166 | if err != nil { 167 | t.Fatal(err) 168 | } 169 | 170 | receivedUser, err = db.GetUser(u.TransmissionRSAHash) 171 | if err != nil { 172 | t.Fatal(err) 173 | } 174 | if !bytes.Equal(u.TransmissionRSA, receivedUser.TransmissionRSA) || !bytes.Equal(u.TransmissionRSAHash, receivedUser.TransmissionRSAHash) { 175 | t.Fatalf("Did not receive expected user data\n\tExpected: %+v\n\tReceived: %+v\n", u, receivedUser) 176 | } 177 | } 178 | 179 | func TestDatabaseImpl_GetUser(t *testing.T) { 180 | db, err := newDatabase("", "", "TestDatabaseImpl_GetUser", "", "") 181 | if err != nil { 182 | t.Fatal(err) 183 | } 184 | 185 | u := generateTestUser(t) 186 | 187 | receivedUser, err := db.GetUser(u.TransmissionRSAHash) 188 | if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 189 | t.Fatalf("Expected gorm.ErrRecordNotFound when no user 
exists, instead got %+v", err) 190 | } 191 | 192 | err = db.insertUser(u) 193 | if err != nil { 194 | t.Fatal(err) 195 | } 196 | 197 | receivedUser, err = db.GetUser(u.TransmissionRSAHash) 198 | if err != nil { 199 | t.Fatal(err) 200 | } 201 | if !bytes.Equal(u.TransmissionRSA, receivedUser.TransmissionRSA) || !bytes.Equal(u.TransmissionRSAHash, receivedUser.TransmissionRSAHash) { 202 | t.Fatalf("Did not receive expected user data\n\tExpected: %+v\n\tReceived: %+v\n", u, receivedUser) 203 | } 204 | } 205 | 206 | func TestDatabaseImpl_deleteUser(t *testing.T) { 207 | db, err := newDatabase("", "", "TestDatabaseImpl_deleteUser", "", "") 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | 212 | u := generateTestUser(t) 213 | 214 | receivedUser, err := db.GetUser(u.TransmissionRSAHash) 215 | if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 216 | t.Fatalf("Expected gorm.ErrRecordNotFound when no user exists, instead got %+v", err) 217 | } 218 | 219 | err = db.insertUser(u) 220 | if err != nil { 221 | t.Fatal(err) 222 | } 223 | 224 | receivedUser, err = db.GetUser(u.TransmissionRSAHash) 225 | if err != nil { 226 | t.Fatal(err) 227 | } 228 | if !bytes.Equal(u.TransmissionRSA, receivedUser.TransmissionRSA) || !bytes.Equal(u.TransmissionRSAHash, receivedUser.TransmissionRSAHash) { 229 | t.Fatalf("Did not receive expected user data\n\tExpected: %+v\n\tReceived: %+v\n", u, receivedUser) 230 | } 231 | 232 | err = db.deleteUser(u.TransmissionRSAHash) 233 | if err != nil { 234 | t.Fatal(err) 235 | } 236 | 237 | receivedUser, err = db.GetUser(u.TransmissionRSAHash) 238 | if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 239 | t.Fatalf("Expected gorm.ErrRecordNotFound when no user exists, instead got %+v", err) 240 | } 241 | } 242 | 243 | func TestDatabaseImpl_GetAllUsers(t *testing.T) { 244 | db, err := newDatabase("", "", "TestDatabaseImpl_GetAllUsers", "", "") 245 | if err != nil { 246 | t.Fatal(err) 247 | } 248 | 249 | startUsers, err := db.GetAllUsers() 
250 | if err != nil { 251 | t.Fatal(err) 252 | } 253 | if len(startUsers) != 0 { 254 | t.Fatalf("Did not receive expected user count\n\tExpected: %d\n\tReceived: %d\n", 0, len(startUsers)) 255 | } 256 | 257 | expectedUsers := 5 258 | for i := 1; i <= expectedUsers; i++ { 259 | u := generateTestUser(t) 260 | err = db.insertUser(u) 261 | if err != nil { 262 | t.Fatal(err) 263 | } 264 | 265 | receivedUsers, err := db.GetAllUsers() 266 | if err != nil { 267 | t.Fatal(err) 268 | } 269 | if len(receivedUsers) != i { 270 | t.Fatalf("Did not receive expected user count\n\tExpected: %d\n\tReceived: %d\n", i, len(receivedUsers)) 271 | } 272 | } 273 | 274 | } 275 | 276 | func TestDatabaseImpl_getIdentity(t *testing.T) { 277 | db, err := newDatabase("", "", "TestDatabaseImpl_getIdentity", "", "") 278 | if err != nil { 279 | t.Fatal(err) 280 | } 281 | 282 | identity := generateTestIdentity(t) 283 | 284 | err = db.insertIdentity(&identity) 285 | if err != nil { 286 | t.Fatal(err) 287 | } 288 | 289 | receivedIdentity, err := db.GetIdentity(identity.IntermediaryId) 290 | if err != nil { 291 | t.Fatal(err) 292 | } 293 | if !bytes.Equal(receivedIdentity.IntermediaryId, identity.IntermediaryId) || receivedIdentity.OffsetNum != identity.OffsetNum { 294 | t.Fatalf("Did not receive expected identity data\n\tExpected: %+v\n\tReceived: %+v\n", identity, receivedIdentity) 295 | } 296 | } 297 | 298 | func TestDatabaseImpl_getIdentitiesByOffset(t *testing.T) { 299 | db, err := newDatabase("", "", "TestDatabaseImpl_getIdentitiesByOffset", "", "") 300 | if err != nil { 301 | t.Fatal(err) 302 | } 303 | 304 | identity := generateTestIdentity(t) 305 | 306 | err = db.insertIdentity(&identity) 307 | if err != nil { 308 | t.Fatal(err) 309 | } 310 | 311 | offsetIdentities, err := db.getIdentitiesByOffset(identity.OffsetNum) 312 | if err != nil { 313 | t.Fatal(err) 314 | } 315 | if len(offsetIdentities) != 1 { 316 | t.Fatalf("Did not receive expected offset identities") 317 | } 318 | 319 | 
offsetIdentities, err = db.getIdentitiesByOffset(identity.OffsetNum + 1) 320 | if err != nil { 321 | t.Fatal(err) 322 | } 323 | if len(offsetIdentities) != 0 { 324 | t.Fatalf("Did not receive expected offset identities") 325 | } 326 | 327 | } 328 | 329 | func TestDatabaseImpl_GetOrphanedIdentities(t *testing.T) { 330 | db, err := newDatabase("", "", "TestDatabaseImpl_GetOrphanedIdentities", "", "") 331 | if err != nil { 332 | t.Fatal(err) 333 | } 334 | 335 | identity := generateTestIdentity(t) 336 | 337 | err = db.insertIdentity(&identity) 338 | if err != nil { 339 | t.Fatal(err) 340 | } 341 | 342 | orphaned, err := db.GetOrphanedIdentities() 343 | if err != nil { 344 | t.Fatal(err) 345 | } 346 | if len(orphaned) != 1 { 347 | t.Fatalf("Did not receive expected count of orphaned identities\n\tExpected: %+v\n\tReceived: %+v\n", 1, len(orphaned)) 348 | } 349 | 350 | identity2 := generateTestIdentity(t) 351 | 352 | err = db.insertIdentity(&identity2) 353 | if err != nil { 354 | t.Fatal(err) 355 | } 356 | 357 | orphaned, err = db.GetOrphanedIdentities() 358 | if err != nil { 359 | t.Fatal(err) 360 | } 361 | if len(orphaned) != 2 { 362 | t.Fatalf("Did not receive expected count of orphaned identities\n\tExpected: %+v\n\tReceived: %+v\n", 2, len(orphaned)) 363 | } 364 | 365 | err = db.insertEphemeral(&Ephemeral{ 366 | ID: 0, 367 | IntermediaryId: identity.IntermediaryId, 368 | EphemeralId: 123, 369 | Epoch: 123, 370 | }) 371 | if err != nil { 372 | t.Fatal(err) 373 | } 374 | 375 | orphaned, err = db.GetOrphanedIdentities() 376 | if err != nil { 377 | t.Fatal(err) 378 | } 379 | if len(orphaned) != 1 { 380 | t.Fatalf("Did not receive expected count of orphaned identities\n\tExpected: %+v\n\tReceived: %+v\n", 1, len(orphaned)) 381 | } 382 | } 383 | 384 | func TestDatabaseImpl_insertEphemeral(t *testing.T) { 385 | db, err := newDatabase("", "", "TestDatabaseImpl_insertEphemeral", "", "") 386 | if err != nil { 387 | t.Fatal(err) 388 | } 389 | 390 | identity := 
generateTestIdentity(t) 391 | 392 | e1 := &Ephemeral{ 393 | ID: 0, 394 | IntermediaryId: identity.IntermediaryId, 395 | EphemeralId: 123, 396 | Epoch: 123, 397 | } 398 | 399 | err = db.insertEphemeral(e1) 400 | if err == nil { 401 | t.Fatal("Should fail to insert ephemeral with no associated identity") 402 | } 403 | 404 | err = db.insertIdentity(&identity) 405 | if err != nil { 406 | t.Fatal(err) 407 | } 408 | 409 | err = db.insertEphemeral(e1) 410 | if err != nil { 411 | t.Fatal(err) 412 | } 413 | 414 | } 415 | 416 | func TestDatabaseImpl_GetEphemeral(t *testing.T) { 417 | db, err := newDatabase("", "", "TestDatabaseImpl_GetEphemeral", "", "") 418 | if err != nil { 419 | t.Fatal(err) 420 | } 421 | 422 | identity := generateTestIdentity(t) 423 | 424 | e1 := &Ephemeral{ 425 | ID: 0, 426 | IntermediaryId: identity.IntermediaryId, 427 | EphemeralId: 123, 428 | Epoch: 123, 429 | } 430 | 431 | ephs, err := db.GetEphemeral(e1.EphemeralId) 432 | if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 433 | t.Fatal(err) 434 | } 435 | if len(ephs) != 0 { 436 | t.Fatalf("Did not receive expected ephemerals\n\tExpected: %+v\n\tReceived: %+v\n", ephs[0], e1) 437 | } 438 | 439 | err = db.insertIdentity(&identity) 440 | if err != nil { 441 | t.Fatal(err) 442 | } 443 | 444 | err = db.insertEphemeral(e1) 445 | if err != nil { 446 | t.Fatal(err) 447 | } 448 | 449 | ephs, err = db.GetEphemeral(e1.EphemeralId) 450 | if err != nil { 451 | t.Fatal(err) 452 | } 453 | if len(ephs) != 1 || ephs[0].EphemeralId != e1.EphemeralId { 454 | t.Fatalf("Did not receive expected ephemerals\n\tExpected: %+v\n\tReceived: %+v\n", ephs[0], e1) 455 | } 456 | } 457 | 458 | func TestDatabaseImpl_DeleteOldEphemerals(t *testing.T) { 459 | db, err := newDatabase("", "", "TestDatabaseImpl_DeleteOldEphemerals", "", "") 460 | if err != nil { 461 | t.Fatal(err) 462 | } 463 | 464 | identity := generateTestIdentity(t) 465 | 466 | e1 := &Ephemeral{ 467 | ID: 0, 468 | IntermediaryId: identity.IntermediaryId, 469 
| EphemeralId: 123, 470 | Epoch: 123, 471 | } 472 | 473 | err = db.insertIdentity(&identity) 474 | if err != nil { 475 | t.Fatal(err) 476 | } 477 | 478 | err = db.insertEphemeral(e1) 479 | if err != nil { 480 | t.Fatal(err) 481 | } 482 | 483 | ephs, err := db.GetEphemeral(e1.EphemeralId) 484 | if err != nil { 485 | t.Fatal(err) 486 | } 487 | if len(ephs) != 1 || ephs[0].EphemeralId != e1.EphemeralId { 488 | t.Fatalf("Did not receive expected ephemerals\n\tExpected: %+v\n\tReceived: %+v\n", ephs[0], e1) 489 | } 490 | 491 | err = db.DeleteOldEphemerals(e1.Epoch) 492 | if err != nil { 493 | t.Fatal(err) 494 | } 495 | 496 | ephs, err = db.GetEphemeral(e1.EphemeralId) 497 | if err != nil { 498 | t.Fatal(err) 499 | } 500 | if len(ephs) != 1 || ephs[0].EphemeralId != e1.EphemeralId { 501 | t.Fatalf("Did not receive expected ephemerals\n\tExpected: %+v\n\tReceived: %+v\n", ephs[0], e1) 502 | } 503 | 504 | err = db.DeleteOldEphemerals(e1.Epoch + 1) 505 | if err != nil { 506 | t.Fatal(err) 507 | } 508 | 509 | ephs, err = db.GetEphemeral(e1.EphemeralId) 510 | if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 511 | t.Fatal("Did not receive expected gorm.ErrRecordNotFound") 512 | } 513 | if len(ephs) != 0 { 514 | t.Fatalf("Did not receive expected ephemerals\n\tExpected: %+v\n\tReceived: %+v\n", ephs[0], e1) 515 | } 516 | } 517 | 518 | func TestDatabaseImpl_GetLatestEphemeral(t *testing.T) { 519 | db, err := newDatabase("", "", "TestDatabaseImpl_GetLatestEphemeral", "", "") 520 | if err != nil { 521 | t.Fatal(err) 522 | } 523 | 524 | identity := generateTestIdentity(t) 525 | 526 | e1 := &Ephemeral{ 527 | IntermediaryId: identity.IntermediaryId, 528 | EphemeralId: 123, 529 | Epoch: 123, 530 | } 531 | 532 | err = db.insertIdentity(&identity) 533 | if err != nil { 534 | t.Fatal(err) 535 | } 536 | 537 | err = db.insertEphemeral(e1) 538 | if err != nil { 539 | t.Fatal(err) 540 | } 541 | 542 | latest, err := db.GetLatestEphemeral() 543 | if err != nil { 544 | t.Fatal(err) 
545 | } 546 | if latest.EphemeralId != e1.EphemeralId { 547 | t.Fatalf("Did not receive expected ephemeral\n\tExpected: %d\n\tReceived: %d\n", e1.ID, latest.ID) 548 | } 549 | 550 | e2 := &Ephemeral{ 551 | IntermediaryId: identity.IntermediaryId, 552 | EphemeralId: 124, 553 | Epoch: 123, 554 | } 555 | 556 | err = db.insertEphemeral(e2) 557 | if err != nil { 558 | t.Fatal(err) 559 | } 560 | 561 | latest, err = db.GetLatestEphemeral() 562 | if err != nil { 563 | t.Fatal(err) 564 | } 565 | if latest.EphemeralId != e2.EphemeralId { 566 | t.Fatalf("Did not receive expected ephemeral\n\tExpected: %d\n\tReceived: %d\n", e2.ID, latest.ID) 567 | } 568 | 569 | e3 := &Ephemeral{ 570 | IntermediaryId: identity.IntermediaryId, 571 | EphemeralId: 124, 572 | Epoch: 122, 573 | } 574 | 575 | err = db.insertEphemeral(e3) 576 | if err != nil { 577 | t.Fatal(err) 578 | } 579 | 580 | latest, err = db.GetLatestEphemeral() 581 | if err != nil { 582 | t.Fatal(err) 583 | } 584 | if latest.EphemeralId != e2.EphemeralId { 585 | t.Fatalf("Did not receive expected ephemeral\n\tExpected: %d\n\tReceived: %d\n", e2.ID, latest.ID) 586 | } 587 | 588 | e4 := &Ephemeral{ 589 | IntermediaryId: identity.IntermediaryId, 590 | EphemeralId: 126, 591 | Epoch: 125, 592 | } 593 | 594 | err = db.insertEphemeral(e4) 595 | if err != nil { 596 | t.Fatal(err) 597 | } 598 | 599 | latest, err = db.GetLatestEphemeral() 600 | if err != nil { 601 | t.Fatal(err) 602 | } 603 | if latest.EphemeralId != e4.EphemeralId { 604 | t.Fatalf("Did not receive expected ephemeral\n\tExpected: %d\n\tReceived: %d\n", e4.ID, latest.ID) 605 | } 606 | 607 | e5 := &Ephemeral{ 608 | IntermediaryId: identity.IntermediaryId, 609 | EphemeralId: 127, 610 | Epoch: 121, 611 | } 612 | 613 | err = db.insertEphemeral(e5) 614 | if err != nil { 615 | t.Fatal(err) 616 | } 617 | 618 | latest, err = db.GetLatestEphemeral() 619 | if err != nil { 620 | t.Fatal(err) 621 | } 622 | if latest.EphemeralId != e4.EphemeralId { 623 | t.Fatalf("Did not receive 
expected ephemeral\n\tExpected: %d\n\tReceived: %d\n", e4.ID, latest.ID) 624 | } 625 | } 626 | 627 | func TestDatabaseImpl_registerForNotifications(t *testing.T) { 628 | db, err := newDatabase("", "", "TestDatabaseImpl_registerForNotifications", "", "") 629 | if err != nil { 630 | t.Fatal(err) 631 | } 632 | 633 | identity := generateTestIdentity(t) 634 | u := generateTestUser(t) 635 | 636 | err = db.insertUser(u) 637 | if err != nil { 638 | t.Fatal(err) 639 | } 640 | 641 | err = db.insertIdentity(&identity) 642 | if err != nil { 643 | t.Fatal(err) 644 | } 645 | 646 | ru, err := db.GetUser(u.TransmissionRSAHash) 647 | if err != nil { 648 | t.Fatal(err) 649 | } 650 | if len(ru.Tokens) != 0 || len(ru.Identities) != 0 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 651 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 652 | } 653 | 654 | token := "apnstoken02" 655 | err = db.registerForNotifications(u, identity, Token{ 656 | Token: token, 657 | App: constants.MessengerIOS.String(), 658 | }) 659 | if err != nil { 660 | t.Fatal(err) 661 | } 662 | 663 | identity2 := generateTestIdentity(t) 664 | u2 := generateTestUser(t) 665 | 666 | err = db.insertUser(u2) 667 | if err != nil { 668 | t.Fatal(err) 669 | } 670 | 671 | err = db.insertIdentity(&identity) 672 | if err != nil { 673 | t.Fatal(err) 674 | } 675 | 676 | token2 := "fcm:token2" 677 | err = db.registerForNotifications(u2, identity2, Token{ 678 | Token: token2, 679 | App: constants.MessengerAndroid.String(), 680 | }) 681 | if err != nil { 682 | t.Fatal(err) 683 | } 684 | 685 | ru, err = db.GetUser(u2.TransmissionRSAHash) 686 | if err != nil { 687 | t.Fatal(err) 688 | } 689 | if len(ru.Tokens) != 1 || len(ru.Identities) != 1 || !bytes.Equal(ru.TransmissionRSAHash, u2.TransmissionRSAHash) { 690 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u2, ru) 691 | } 692 | 693 | err = db.registerForNotifications(u, identity2, Token{ 694 | 
Token: token, 695 | App: constants.MessengerIOS.String(), 696 | }) 697 | if err != nil { 698 | t.Fatal(err) 699 | } 700 | ru, err = db.GetUser(u.TransmissionRSAHash) 701 | if err != nil { 702 | t.Fatal(err) 703 | } 704 | if len(ru.Tokens) != 1 || len(ru.Identities) != 2 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 705 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u2, ru) 706 | } 707 | 708 | err = db.registerForNotifications(u, identity2, Token{ 709 | Token: token2, 710 | App: constants.MessengerAndroid.String(), 711 | }) 712 | if err != nil { 713 | t.Fatal(err) 714 | } 715 | ru, err = db.GetUser(u.TransmissionRSAHash) 716 | if err != nil { 717 | t.Fatal(err) 718 | } 719 | if len(ru.Tokens) != 2 || len(ru.Identities) != 2 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 720 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 721 | } 722 | 723 | err = db.registerForNotifications(u2, identity, Token{ 724 | Token: token, 725 | App: constants.MessengerIOS.String(), 726 | }) 727 | if err != nil { 728 | t.Fatal(err) 729 | } 730 | ru, err = db.GetUser(u2.TransmissionRSAHash) 731 | if err != nil { 732 | t.Fatal(err) 733 | } 734 | if len(ru.Tokens) != 2 || len(ru.Identities) != 2 || !bytes.Equal(ru.TransmissionRSAHash, u2.TransmissionRSAHash) { 735 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 736 | } 737 | } 738 | 739 | func TestDatabaseImpl_unregisterIdentities(t *testing.T) { 740 | db, err := newDatabase("", "", "TestDatabaseImpl_unregisterIdentities", "", "") 741 | if err != nil { 742 | t.Fatal(err) 743 | } 744 | 745 | u := generateTestUser(t) 746 | identity := generateTestIdentity(t) 747 | 748 | err = db.insertUser(u) 749 | if err != nil { 750 | t.Fatal(err) 751 | } 752 | 753 | err = db.unregisterIdentities(u, []Identity{identity}) 754 | if err != nil { 755 | t.Fatalf("Should not return error even if identity doesn't 
exist: %+v", err) 756 | } 757 | 758 | err = db.insertUser(u) 759 | if err != nil { 760 | t.Fatal(err) 761 | } 762 | 763 | err = db.insertIdentity(&identity) 764 | if err != nil { 765 | t.Fatal(err) 766 | } 767 | 768 | token := "apnstoken02" 769 | err = db.registerForNotifications(u, identity, Token{ 770 | Token: token, 771 | App: constants.MessengerIOS.String(), 772 | }) 773 | if err != nil { 774 | t.Fatal(err) 775 | } 776 | 777 | identity2 := generateTestIdentity(t) 778 | 779 | token2 := "fcm:token2" 780 | err = db.registerForNotifications(u, identity2, Token{ 781 | Token: token2, 782 | App: constants.MessengerAndroid.String(), 783 | }) 784 | if err != nil { 785 | t.Fatal(err) 786 | } 787 | 788 | ru, err := db.GetUser(u.TransmissionRSAHash) 789 | if err != nil { 790 | t.Fatal(err) 791 | } 792 | if len(ru.Tokens) != 2 || len(ru.Identities) != 2 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 793 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 794 | } 795 | 796 | err = db.unregisterIdentities(u, []Identity{identity}) 797 | if err != nil { 798 | t.Fatalf("Failed to unregister identity: %+v", err) 799 | } 800 | 801 | ru, err = db.GetUser(u.TransmissionRSAHash) 802 | if err != nil { 803 | t.Fatal(err) 804 | } 805 | if len(ru.Tokens) != 2 || len(ru.Identities) != 1 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 806 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 807 | } 808 | } 809 | 810 | func TestDatabaseImpl_unregisterTokens(t *testing.T) { 811 | db, err := newDatabase("", "", "TestDatabaseImpl_unregisterTokens", "", "") 812 | if err != nil { 813 | t.Fatal(err) 814 | } 815 | 816 | identity := generateTestIdentity(t) 817 | u := generateTestUser(t) 818 | 819 | err = db.insertUser(u) 820 | if err != nil { 821 | t.Fatal(err) 822 | } 823 | 824 | token := "apnstoken02" 825 | err = db.unregisterTokens(u, []Token{Token{Token: token}}) 826 | if err != nil { 
827 | t.Fatalf("Should not return error even if identity doesn't exist: %+v", err) 828 | } 829 | 830 | _, err = db.GetUser(u.TransmissionRSAHash) 831 | if err != nil { 832 | t.Fatalf("User should still exist after unregister, instead got: %+v", err) 833 | } 834 | 835 | err = db.insertUser(u) 836 | if err != nil { 837 | t.Fatal(err) 838 | } 839 | 840 | err = db.insertIdentity(&identity) 841 | if err != nil { 842 | t.Fatal(err) 843 | } 844 | 845 | err = db.registerForNotifications(u, identity, Token{ 846 | Token: token, 847 | App: constants.MessengerIOS.String(), 848 | }) 849 | if err != nil { 850 | t.Fatal(err) 851 | } 852 | 853 | identity2 := generateTestIdentity(t) 854 | 855 | err = db.insertIdentity(&identity) 856 | if err != nil { 857 | t.Fatal(err) 858 | } 859 | 860 | token2 := "fcm:token2" 861 | err = db.registerForNotifications(u, identity2, Token{ 862 | Token: token2, 863 | App: constants.MessengerAndroid.String(), 864 | }) 865 | if err != nil { 866 | t.Fatal(err) 867 | } 868 | 869 | ru, err := db.GetUser(u.TransmissionRSAHash) 870 | if err != nil { 871 | t.Fatal(err) 872 | } 873 | if len(ru.Tokens) != 2 || len(ru.Identities) != 2 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 874 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 875 | } 876 | 877 | err = db.unregisterTokens(u, []Token{Token{Token: token}}) 878 | if err != nil { 879 | t.Fatalf("Failed to unregister token: %+v", err) 880 | } 881 | ru, err = db.GetUser(u.TransmissionRSAHash) 882 | if err != nil { 883 | t.Fatal(err) 884 | } 885 | if len(ru.Tokens) != 1 || len(ru.Identities) != 2 || !bytes.Equal(ru.TransmissionRSAHash, u.TransmissionRSAHash) { 886 | t.Fatalf("Did not receive expected user\n\tExpected: %+v\n\t: Receiveid: %+v\n", u, ru) 887 | } 888 | 889 | err = db.unregisterTokens(u, []Token{Token{Token: token2}}) 890 | if err != nil { 891 | t.Fatalf("Failed to unregister token: %+v", err) 892 | } 893 | _, err = 
db.GetUser(u.TransmissionRSAHash) 894 | if err != nil { 895 | t.Fatalf("User should still exist after unregister, instead got: %+v", err) 896 | } 897 | if len(u.Identities) != 2 { 898 | t.Fatalf("User identities should be unaffected by token removal") 899 | } 900 | 901 | _, err = db.GetIdentity(identity.IntermediaryId) 902 | if err != nil { 903 | t.Fatalf("Failed to get identity: %+v", err) 904 | } 905 | } 906 | 907 | func TestDatabaseImpl_LegacyUnregister(t *testing.T) { 908 | db, err := newDatabase("", "", "TestDatabaseImpl_LegacyUnregister", "", "") 909 | if err != nil { 910 | t.Fatal(err) 911 | } 912 | 913 | identity := generateTestIdentity(t) 914 | u := generateTestUser(t) 915 | 916 | err = db.insertUser(u) 917 | if err != nil { 918 | t.Fatal(err) 919 | } 920 | 921 | err = db.insertIdentity(&identity) 922 | if err != nil { 923 | t.Fatal(err) 924 | } 925 | 926 | token := "apnstoken01" 927 | err = db.registerForNotifications(u, identity, Token{ 928 | Token: token, 929 | App: constants.MessengerIOS.String(), 930 | }) 931 | if err != nil { 932 | t.Fatal(err) 933 | } 934 | 935 | _, err = db.GetUser(u.TransmissionRSAHash) 936 | if err != nil { 937 | t.Fatal(err) 938 | } 939 | 940 | err = db.LegacyUnregister(identity.IntermediaryId) 941 | if err != nil { 942 | t.Fatal(err) 943 | } 944 | 945 | _, err = db.GetUser(u.TransmissionRSAHash) 946 | if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 947 | t.Fatalf("Expected gorm.ErrRecordNotFound, instead got %+v", err) 948 | } 949 | 950 | err = db.insertUser(u) 951 | if err != nil { 952 | t.Fatal(err) 953 | } 954 | 955 | err = db.insertIdentity(&identity) 956 | if err != nil { 957 | t.Fatal(err) 958 | } 959 | 960 | err = db.registerForNotifications(u, identity, Token{ 961 | Token: token, 962 | App: constants.MessengerIOS.String(), 963 | }) 964 | if err != nil { 965 | t.Fatal(err) 966 | } 967 | 968 | token2 := "fcm:newtoken" 969 | u2 := generateTestUser(t) 970 | err = db.insertUser(u2) 971 | if err != nil { 972 | 
t.Fatal(err) 973 | } 974 | 975 | err = db.insertIdentity(&identity) 976 | if err != nil { 977 | t.Fatal(err) 978 | } 979 | err = db.registerForNotifications(u2, identity, Token{ 980 | Token: token2, 981 | App: constants.MessengerAndroid.String(), 982 | }) 983 | if err != nil { 984 | t.Fatal(err) 985 | } 986 | 987 | err = db.LegacyUnregister(identity.IntermediaryId) 988 | if err == nil { 989 | t.Fatal("Should have received error trying to unregister iid with multiple associated users") 990 | } 991 | 992 | receivedIdent, err := db.GetIdentity(identity.IntermediaryId) 993 | if err != nil { 994 | t.Fatal(err) 995 | } 996 | if !bytes.Equal(receivedIdent.IntermediaryId, identity.IntermediaryId) { 997 | t.Fatal("Did not receive expected identity") 998 | } 999 | _, err = db.GetUser(u.TransmissionRSAHash) 1000 | if err != nil { 1001 | t.Fatal(err) 1002 | } 1003 | _, err = db.GetUser(u2.TransmissionRSAHash) 1004 | if err != nil { 1005 | t.Fatal(err) 1006 | } 1007 | } 1008 | 1009 | func generateTestIdentity(t *testing.T) Identity { 1010 | uid, err := id.NewRandomID(csprng.NewSystemRNG(), id.User) 1011 | if err != nil { 1012 | t.Fatal(err) 1013 | } 1014 | iid, err := ephemeral.GetIntermediaryId(uid) 1015 | if err != nil { 1016 | t.Fatal(err) 1017 | } 1018 | identity := Identity{ 1019 | IntermediaryId: iid, 1020 | OffsetNum: ephemeral.GetOffsetNum(ephemeral.GetOffset(iid)), 1021 | } 1022 | return identity 1023 | } 1024 | func generateTestUser(t *testing.T) *User { 1025 | trsa, err := rsa.GenerateKey(csprng.NewSystemRNG(), 512) 1026 | if err != nil { 1027 | t.Fatal(err) 1028 | } 1029 | h := hash.CMixHash.New() 1030 | h.Write(trsa.GetPublic().Bytes()) 1031 | u := &User{ 1032 | TransmissionRSAHash: h.Sum(nil), 1033 | TransmissionRSA: trsa.GetPublic().Bytes(), 1034 | } 1035 | return u 1036 | } 1037 | --------------------------------------------------------------------------------