├── main.go ├── .gitignore ├── scripts ├── active-users.service └── README.md ├── main_test.go ├── cmd ├── https_test.go ├── epoch_test.go ├── version.go ├── ipAddress │ ├── convert_test.go │ └── convert.go ├── profile.go ├── epoch.go ├── idf.go ├── filteredUpdates.go ├── bloom.go ├── version_vars.go ├── poll_test.go ├── autocert.go ├── gossip.go ├── rateLimitGossip.go ├── bloom_test.go ├── knownRoundsWrapper.go ├── knownRoundsWrapper_test.go ├── poll.go ├── bloomGossip.go ├── filteredUpdates_test.go ├── params.go └── root.go ├── Makefile ├── notifications ├── notifications_test.go └── notifications.go ├── LICENSE ├── storage ├── unmixedMessageBuffer.go ├── extendedRoundStorage.go ├── extendedRoundStorage_test.go ├── storage.go ├── unmixedMapBuffer_test.go ├── unmixedMapBuffer.go ├── storage_test.go ├── gatewayDb.go └── database.go ├── gateway.yaml ├── .gitlab-ci.yml ├── go.mod ├── autocert ├── interface.go ├── dns_test.go └── dns.go └── README.md /main.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package main 9 | 10 | import "gitlab.com/elixxir/gateway/cmd" 11 | 12 | func main() { 13 | cmd.Execute() 14 | } 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | # Ignore glide files/folders 3 | glide.lock 4 | vendor/ 5 | # Ignore Jetbrains IDE folder 6 | .idea/* 7 | # Ignore vim .swp buffers for open files 8 | .*.swp 9 | .*.swo 10 | # Ignore local development scripts 11 | localdev_* 12 | # Ignore logs 13 | *.log 14 | # Android things 15 | *.iml 16 | /android/.gradle 17 | /android/local.properties 18 | /android/.idea/workspace.xml 19 | /android/.idea/libraries 20 | /android/.DS_Store 21 | /android/build 22 | /android/captures 23 | /android/.externalNativeBuild 24 | *.apk 25 | *.ap_ 26 | *.dex 27 | *.class 28 | *.aar 29 | *.jar 30 | # Ignore genered version file 31 | cmd/version_vars.go 32 | -------------------------------------------------------------------------------- /scripts/active-users.service: -------------------------------------------------------------------------------- 1 | # Example service file for the active_users script. 2 | # Emphasizes required keys and minimal additional configuration 3 | # and expects the standard gateway file paths as specified in the handbook. 4 | # Other paths can easily be provided by specifying other command line args. 5 | # Should share aws access keys with the ones provided to wrapper script. 
6 |
7 | [Unit]
8 | Description=Job that starts the Active Users Script
9 |
10 | [Service]
11 | User=elixxir
12 | Type=simple
13 | ExecStart=/opt/xxnetwork/active-users.py --pass "" --aws-key "" --aws-secret ""
14 |
15 | [Install]
16 | WantedBy=multi-user.target
17 |
--------------------------------------------------------------------------------
/main_test.go:
--------------------------------------------------------------------------------
1 | ////////////////////////////////////////////////////////////////////////////////
2 | // Copyright © 2022 xx foundation //
3 | // //
4 | // Use of this source code is governed by a license that can be found in the //
5 | // LICENSE file. //
6 | ////////////////////////////////////////////////////////////////////////////////
7 |
8 | package main
9 |
10 | import (
11 | 	"os/exec"
12 | 	"testing"
13 | )
14 |
15 | // TestMainVersion is a smoke test: it runs the "version" subcommand through
16 | // `go run` and fails if the command cannot be started or exits with an error.
17 | func TestMainVersion(t *testing.T) {
18 | 	command := exec.Command("go", "run", "main.go", "version")
19 | 	// CombinedOutput captures stdout and stderr so a failure explains itself.
20 | 	output, err := command.CombinedOutput()
21 | 	if err != nil {
22 | 		// Any error fails the test: an *exec.ExitError is a non-zero exit, and
23 | 		// any other error (e.g. the go tool cannot be started) means the
24 | 		// smoke test never ran and must not pass silently.
25 | 		t.Errorf("Smoke test failed with %v\n%s", err, output)
26 | 	}
27 | }
28 |
--------------------------------------------------------------------------------
/cmd/https_test.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | 	"bytes"
5 | 	"gitlab.com/elixxir/gateway/storage"
6 | 	"testing"
7 | 	"time"
8 | )
9 |
10 | // TestStoreHttpsCreds stores a certificate/key pair via storeHttpsCreds and
11 | // checks that loadHttpsCreds returns the same bytes.
12 | func TestStoreHttpsCreds(t *testing.T) {
13 | 	db, err := storage.NewStorage("", "", "", "", "", true)
14 | 	if err != nil {
15 | 		t.Fatal(err)
16 | 	}
17 | 	creds := storedHttpsCreds{
18 | 		Key:  []byte("TestKey"),
19 | 		Cert: []byte("TestCert"),
20 | 	}
21 |
22 | 	if err = storeHttpsCreds(creds.Cert, creds.Key, db); err != nil {
23 | 		t.Fatal(err)
24 | 	}
25 |
26 | 	cert, key, err := loadHttpsCreds(db)
27 | 	if err != nil {
28 | 		t.Fatal(err)
29 | 	}
30 |
31 | 	if !bytes.Equal(cert, creds.Cert) || !bytes.Equal(key, creds.Key) {
32 | 		t.Fatalf("Did not receive expected creds\n\tExpected: "+
33 | 			"\n\t\tKey: %+v\n\t\tCert: %+v\n\tReceived: \n\t\t"+
34 | 			"Key: %+v\n\t\tCert: %+v\n", creds.Key, creds.Cert, key, cert)
35 | 	}
36 | }
37 |
38 | // TestGetReplaceAt checks that, for a certificate expiring 90 days from now,
39 | // getReplaceAt does not schedule replacement after the expiry.
40 | func TestGetReplaceAt(t *testing.T) {
41 | 	day := time.Hour * 24
42 | 	notAfter := time.Now().Add(90 * day)
43 | 	t.Log(notAfter)
44 | 	replaceAt := getReplaceAt(notAfter, 30*day, 7*day)
45 | 	t.Log(replaceAt)
46 | 	if replaceAt.After(notAfter) {
47 | 		t.Fatalf("ReplaceAt %s should be before notAfter %s", replaceAt, notAfter)
48 | 	}
49 | }
50 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: update master release update_master update_release build clean version
2 |
3 | version:
4 | go run main.go generate
5 | mv version_vars.go cmd/version_vars.go
6 |
7 | clean:
8 | rm -rf vendor/
9 | go mod vendor
10 |
11 | update:
12 | -GOFLAGS="" go get all
13 |
14 | build:
15 | go build ./...
16 | go mod tidy 17 | 18 | update_release: 19 | GOFLAGS="" go get gitlab.com/xx_network/primitives@release 20 | GOFLAGS="" go get gitlab.com/elixxir/primitives@release 21 | GOFLAGS="" go get gitlab.com/xx_network/crypto@release 22 | GOFLAGS="" go get gitlab.com/elixxir/crypto@release 23 | GOFLAGS="" go get gitlab.com/xx_network/comms@release 24 | GOFLAGS="" go get gitlab.com/elixxir/comms@release 25 | GOFLAGS="" go get gitlab.com/elixxir/bloomfilter@release 26 | 27 | update_master: 28 | GOFLAGS="" go get gitlab.com/xx_network/primitives@master 29 | GOFLAGS="" go get gitlab.com/elixxir/primitives@master 30 | GOFLAGS="" go get gitlab.com/xx_network/crypto@master 31 | GOFLAGS="" go get gitlab.com/elixxir/crypto@master 32 | GOFLAGS="" go get gitlab.com/xx_network/comms@master 33 | GOFLAGS="" go get gitlab.com/elixxir/comms@master 34 | GOFLAGS="" go get gitlab.com/elixxir/bloomfilter@master 35 | 36 | master: update_master clean build version 37 | 38 | release: update_release clean build version 39 | -------------------------------------------------------------------------------- /notifications/notifications_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package notifications 9 | 10 | import ( 11 | "bytes" 12 | "gitlab.com/xx_network/primitives/id" 13 | "testing" 14 | ) 15 | 16 | // unit test for notify function 17 | func TestUserNotifications_Notify(t *testing.T) { 18 | un := UserNotifications{} 19 | un.Notify(id.NewIdFromBytes([]byte("test"), t)) 20 | if len(un.ids) != 1 && bytes.Compare(un.ids[0].Bytes(), []byte("test")) != 0 { 21 | t.Errorf("Failed to properly add user notification") 22 | } 23 | 24 | un.Notify(id.NewIdFromBytes([]byte("test"), t)) 25 | if len(un.ids) != 1 { 26 | t.Errorf("Number of ids should still be one, since id is the same") 27 | } 28 | } 29 | 30 | // Unit test for notified function 31 | func TestUserNotifications_Notified(t *testing.T) { 32 | un := UserNotifications{} 33 | un.Notify(id.NewIdFromBytes([]byte("test"), t)) 34 | ret := un.Notified() 35 | if len(ret) != 1 && ret[0] != id.NewIdFromBytes([]byte("test"), t) { 36 | t.Error("Did not properly return list of ids") 37 | } 38 | if un.ids != nil { 39 | t.Error("Did not clear IDs after returning") 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /cmd/epoch_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import "testing" 11 | 12 | // Happy path 13 | // Can't test panic paths, obviously 14 | func TestGetEpoch(t *testing.T) { 15 | ts := int64(300000) 16 | period := int64(5000) 17 | expected := uint32(60) 18 | result := GetEpoch(ts, period) 19 | if result != expected { 20 | t.Errorf("Invalid GetEpoch result: Got %d Expected %d", result, expected) 21 | } 22 | } 23 | 24 | // Various happy paths 25 | func TestGetEpochTimestamp(t *testing.T) { 26 | epoch := uint32(60) 27 | period := int64(5000) 28 | expected := int64(300000) 29 | result := GetEpochTimestamp(epoch, period) 30 | if result != expected { 31 | t.Errorf("Invalid GetEpochTimestamp result: Got %d Expected %d", result, expected) 32 | } 33 | 34 | period = 0 35 | expected = 0 36 | result = GetEpochTimestamp(epoch, period) 37 | if result != expected { 38 | t.Errorf("Invalid GetEpochTimestamp result: Got %d Expected %d", result, expected) 39 | } 40 | 41 | period = -5000 42 | expected = -300000 43 | result = GetEpochTimestamp(epoch, period) 44 | if result != expected { 45 | t.Errorf("Invalid GetEpochTimestamp result: Got %d Expected %d", result, expected) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | To whom it may concern, 2 | 3 | You can download, modify, compile and deploy this source code for the purpose of participating in the xx network as a 4 | beta node, in accordance with your beta node participation agreement. 5 | 6 | You can also download, modify, compile and deploy the source code for non-commercial testing and verification 7 | (i.e., security and bug review) purposes. You can repost aspects of the source code on both the BetaNet forum 8 | (forum.xx.network) and the official Discord (https://discord.gg/D4NHmv4) consistent with these purposes. 
9 | To release a bug or security report outside the BetaNet forum or official Discord, you must notify bugs@xx.network at 10 | least three business days in advance Pacific time. 11 | 12 | The xx network SEZC hereby grants you a non-transferable license under its legal rights limited to the purposes above. 13 | 14 | This Agreement and the license that it grants you expires the earlier of April 1st 2022 or the launch of the xx network 15 | MainNet. 16 | 17 | THE SOURCE CODE IS PROVIDED TO YOU ON AN “AS IS” BASIS WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, 18 | INCLUDING ANY WARRANTY OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR USE, OR ANY WARRANTY THAT THE SOURCE CODE 19 | DOES NOT INFRINGE THE RIGHTS OF OTHERS (WHETHER PATENT RIGHTS, COPYRIGHTS OR OTHERWISE). 20 | 21 | THE XX NETWORK SEZC WILL NOT BE LIABLE TO YOU FOR ANY DAMAGES OF ANY KIND, WHETHER DIRECT, SPECIAL, CONSEQUENTIAL, 22 | INCIDENTAL, INDIRECT OR OTHERWISE, EVEN IF THE XX NETWORK SEZC HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES, 23 | WHICH ARISE OUT OF THIS AGREEMENT OR THE USE OR PERFORMANCE OF THE SOURCE CODE. 24 | 25 | The xx network SEZC 26 | -------------------------------------------------------------------------------- /cmd/version.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Handles command-line version functionality 9 | 10 | package cmd 11 | 12 | import ( 13 | "fmt" 14 | 15 | "github.com/spf13/cobra" 16 | "gitlab.com/xx_network/primitives/utils" 17 | ) 18 | 19 | // Change this value to set the version for this build 20 | const currentVersion = "3.16.1" 21 | 22 | func printVersion() { 23 | fmt.Printf("xx network Gateway v%s -- %s\n\n", SEMVER, GITVERSION) 24 | fmt.Printf("Dependencies:\n\n%s\n", DEPENDENCIES) 25 | } 26 | 27 | func init() { 28 | rootCmd.AddCommand(versionCmd) 29 | rootCmd.AddCommand(generateCmd) 30 | } 31 | 32 | var versionCmd = &cobra.Command{ 33 | Use: "version", 34 | Short: "Print the version and dependency information for the xx network binary", 35 | Long: `Print the version and dependency information for the xx network binary`, 36 | Run: func(cmd *cobra.Command, args []string) { 37 | printVersion() 38 | }, 39 | } 40 | 41 | var generateCmd = &cobra.Command{ 42 | Use: "generate", 43 | Short: "Generates version and dependency information for the xx network binary", 44 | Long: `Generates version and dependency information for the xx network binary`, 45 | Run: func(cmd *cobra.Command, args []string) { 46 | utils.GenerateVersionFile(currentVersion) 47 | }, 48 | } 49 | -------------------------------------------------------------------------------- /notifications/notifications.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // notifications contains the structure and functions for tracking users who should be sent push notifications 9 | 10 | package notifications 11 | 12 | import "gitlab.com/xx_network/primitives/id" 13 | 14 | // UserNotifications stores the list of user ids to be notified 15 | type UserNotifications struct { 16 | ids []*id.ID 17 | } 18 | 19 | // Notify adds a user to the list of users to be notified 20 | // If the user is already in the list, a duplicate record is not added 21 | func (n *UserNotifications) Notify(uid *id.ID) { 22 | if n.ids == nil { 23 | n.ids = make([]*id.ID, 0) 24 | } 25 | _, found := find(n.ids, uid) 26 | if found { 27 | return 28 | } 29 | n.ids = append(n.ids, uid) 30 | } 31 | 32 | // Notified returns a list of string representations of user ids to be notified 33 | func (n *UserNotifications) Notified() []*id.ID { 34 | var ret []*id.ID 35 | for _, uid := range n.ids { 36 | ret = append(ret, uid) 37 | } 38 | n.ids = nil 39 | return ret 40 | } 41 | 42 | // find is a helper method for Notify, used to determine if a given user id is already in its list 43 | func find(slice []*id.ID, val *id.ID) (int, bool) { 44 | for i, item := range slice { 45 | if item.Cmp(val) { 46 | return i, true 47 | } 48 | } 49 | return -1, false 50 | } 51 | -------------------------------------------------------------------------------- /cmd/ipAddress/convert_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package ipAddress 9 | 10 | import ( 11 | "bytes" 12 | "testing" 13 | ) 14 | 15 | // Happy path. 
16 | func TestStringToByte(t *testing.T) {
17 | 	ipAddr := "1.2.3.4"
18 |
19 | 	expected := []byte{1, 2, 3, 4}
20 |
21 | 	received, err := StringToByte(ipAddr)
22 | 	if err != nil {
23 | 		// t.Fatal rather than t.Fatalf(err.Error()): the error text must not
24 | 		// be interpreted as a format string (go vet printf check).
25 | 		t.Fatal(err)
26 | 	}
27 |
28 | 	if !bytes.Equal(expected, received) {
29 | 		t.Fatalf("Unexpected output converting IP address from string to byte."+
30 | 			"\n\tExpected: %v"+
31 | 			"\n\tReceived: %v", expected, received)
32 | 	}
33 | }
34 |
35 | // Error path: malformed octets and the wrong number of octets must be
36 | // rejected.
37 | func TestStringToByte2(t *testing.T) {
38 | 	invalidIpAddrs := []string{"1a.2b.3c.4d", "1.2.3", "1.2.3.4.5", ""}
39 |
40 | 	for _, invalidIpAddr := range invalidIpAddrs {
41 | 		_, err := StringToByte(invalidIpAddr)
42 | 		if err == nil {
43 | 			t.Errorf("Expected error case, should not be able to convert %q to a byte slice", invalidIpAddr)
44 | 		}
45 | 	}
46 | }
47 |
48 | // Happy path.
49 | func TestByteToString(t *testing.T) {
50 | 	expected := "1.2.3.4"
51 |
52 | 	ipAddr := []byte{1, 2, 3, 4}
53 |
54 | 	received, err := ByteToString(ipAddr)
55 | 	if err != nil {
56 | 		t.Fatal(err)
57 | 	}
58 |
59 | 	if expected != received {
60 | 		t.Fatalf("Unexpected output converting IP address from byte to string."+
61 | 			"\n\tExpected: %v"+
62 | 			"\n\tReceived: %v", expected, received)
63 | 	}
64 | }
65 |
66 | // Error path: a slice that is not exactly 4 bytes long must be rejected.
67 | func TestByteToString_InvalidLength(t *testing.T) {
68 | 	for _, ipAddr := range [][]byte{nil, {1, 2, 3}, {1, 2, 3, 4, 5}} {
69 | 		if _, err := ByteToString(ipAddr); err == nil {
70 | 			t.Errorf("Expected error converting %v to an IP string", ipAddr)
71 | 		}
72 | 	}
73 | }
74 |
--------------------------------------------------------------------------------
/cmd/profile.go:
--------------------------------------------------------------------------------
1 | ////////////////////////////////////////////////////////////////////////////////
2 | // Copyright © 2022 xx foundation //
3 | // //
4 | // Use of this source code is governed by a license that can be found in the //
5 | // LICENSE file.
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | jww "github.com/spf13/jwalterweatherman" 12 | "runtime" 13 | "sync" 14 | ) 15 | 16 | type stats struct { 17 | MemoryAllocated uint64 18 | NumThreads int 19 | } 20 | 21 | var prevStats *stats 22 | var statsMutex sync.Mutex 23 | 24 | func PrintProfilingStatistics() { 25 | statsMutex.Lock() 26 | // Get Total Allocated Memory 27 | var memStats runtime.MemStats 28 | runtime.ReadMemStats(&memStats) 29 | memoryAllocated := memStats.Alloc 30 | 31 | // Number of threads 32 | numThreads := runtime.NumGoroutine() 33 | 34 | curStats := &stats{ 35 | MemoryAllocated: memoryAllocated, 36 | NumThreads: numThreads, 37 | } 38 | 39 | memDelta := int64(memoryAllocated) 40 | threadDelta := numThreads 41 | 42 | if prevStats != nil { 43 | memDelta -= int64(prevStats.MemoryAllocated) 44 | threadDelta -= prevStats.NumThreads 45 | } 46 | 47 | prevStats = curStats 48 | 49 | plusOrMinus := "+" 50 | if memDelta < 0 { 51 | plusOrMinus = "" 52 | } 53 | jww.INFO.Printf("Total memory allocation: %d (%s%d)", memoryAllocated, 54 | plusOrMinus, memDelta) 55 | 56 | plusOrMinus = "+" 57 | if threadDelta < 0 { 58 | plusOrMinus = "" 59 | } 60 | jww.INFO.Printf("Total thread count: %d (%s%d)", numThreads, 61 | plusOrMinus, threadDelta) 62 | statsMutex.Unlock() 63 | } 64 | -------------------------------------------------------------------------------- /storage/unmixedMessageBuffer.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | pb "gitlab.com/elixxir/comms/mixmessages" 12 | "gitlab.com/xx_network/primitives/id" 13 | ) 14 | 15 | // Interface for interacting with the UnmixedMessageBuffer. 16 | type UnmixedMessageBuffer interface { 17 | // AddUnmixedMessage adds an unmixed message to send to the cMix node. 18 | AddUnmixedMessage(msg *pb.Slot, sender *id.ID, ip string, round id.Round) error 19 | 20 | // AddManyUnmixedMessage adds many unmixed messages to send to the cMix node. 21 | AddManyUnmixedMessages(msg []*pb.GatewaySlot, sender *id.ID, ip string, round id.Round) error 22 | 23 | // GetRoundMessages returns the batch associated with the roundID 24 | PopRound(rndId id.Round) (*pb.Batch, []*id.ID, []string) 25 | 26 | // LenUnmixed return the number of messages within the requested round 27 | LenUnmixed(rndId id.Round) int 28 | 29 | // SetAsRoundLeader initializes a round as our responsibility batchSize := 4 30 | 31 | SetAsRoundLeader(rndId id.Round, batchSize uint32) 32 | 33 | // IsRoundFull returns true if the number of slots associated with 34 | // the round ID matches the batchsize of that round 35 | IsRoundFull(rndId id.Round) bool 36 | 37 | // IsRoundLeader returns true if object mapped to this round has 38 | // been previously set 39 | IsRoundLeader(rndId id.Round) bool 40 | 41 | } 42 | -------------------------------------------------------------------------------- /cmd/epoch.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "github.com/pkg/errors" 12 | jww "github.com/spf13/jwalterweatherman" 13 | ) 14 | 15 | // GetEpochEdge determines the Epoch value of the given timestamp with the 16 | // given period while returning an error. To be used when either of the 17 | // inputs come from the network. 18 | func GetEpochEdge(ts int64, period int64) (uint32, error) { 19 | if period == 0 { 20 | return 0, errors.New("GetEpochEdge: Period length is 0, " + 21 | "cannot divide by zero") 22 | } else if ts < 0 { 23 | return 0, errors.Errorf("GetEpochEdge: Cannot calculate "+ 24 | "epoch with a negative timestamp: %d", ts) 25 | } else if period < 0 { 26 | return 0, errors.Errorf("GetEpochEdge: Cannot calculate "+ 27 | "epoch with a negative period size: %d", period) 28 | } 29 | return uint32(ts / period), nil 30 | } 31 | 32 | // GetEpoch determines the Epoch value of the given timestamp 33 | // with the given period. Panics on error. For internal use 34 | func GetEpoch(ts int64, period int64) uint32 { 35 | epoch, err := GetEpochEdge(ts, period) 36 | 37 | if err != nil { 38 | jww.FATAL.Panicf("%+v", err) 39 | } 40 | 41 | return epoch 42 | } 43 | 44 | // GetEpochTimestamp determines the timestamp value of the given epoch 45 | func GetEpochTimestamp(epoch uint32, period int64) int64 { 46 | return period * int64(epoch) 47 | } 48 | -------------------------------------------------------------------------------- /cmd/idf.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "bytes" 12 | "github.com/pkg/errors" 13 | jww "github.com/spf13/jwalterweatherman" 14 | "gitlab.com/xx_network/primitives/id" 15 | "gitlab.com/xx_network/primitives/id/idf" 16 | "gitlab.com/xx_network/primitives/ndf" 17 | ) 18 | 19 | // Helper that updates parses the NDF in order to create our IDF 20 | func (gw *Instance) setupIDF(nodeId []byte, ourNdf *ndf.NetworkDefinition) (err error) { 21 | 22 | // Determine the index of this gateway 23 | for i, node := range ourNdf.Nodes { 24 | // Find our node in the ndf 25 | if bytes.Compare(node.ID, nodeId) == 0 { 26 | 27 | // Save the IDF to the idfPath 28 | err := writeIDF(ourNdf, i, idfPath) 29 | if err != nil { 30 | jww.WARN.Printf("Could not write ID File: %s", 31 | idfPath) 32 | } 33 | 34 | return nil 35 | } 36 | } 37 | 38 | return errors.Errorf("Unable to locate ID %v in NDF!", nodeId) 39 | } 40 | 41 | // writeIDF writes the identity file for the gateway into the given location 42 | func writeIDF(ndf *ndf.NetworkDefinition, index int, idfPath string) error { 43 | // Create IDF based on NDF ID 44 | zeroSalt := make([]byte, 32) 45 | gwID, err := id.Unmarshal(ndf.Gateways[index].ID) 46 | // Save new ID to file 47 | if err == nil { 48 | err = idf.LoadIDF(idfPath, zeroSalt, gwID) 49 | } 50 | if err != nil { 51 | errors.Errorf("Failed to save IDF: %+v", err) 52 | } 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /gateway.yaml: -------------------------------------------------------------------------------- 1 | ## 2 | # Gateway Configuration File 3 | ## 4 | 5 | # Level of debugging to print (0 = info, 1 = debug, >1 = trace). (Default info) 6 | logLevel: 0 7 | 8 | # Path where log file will be saved. (Default "log/gateway.log") 9 | log: "/opt/xxnetwork/log/gateway.log" 10 | 11 | # Port for Gateway to listen on. 
Gateway must be the only listener on this port. 12 | # (Required) 13 | port: 22840 14 | 15 | # The IP address of the machine running cMix that the Gateway communicates with. 16 | # Expects an IPv4 address with a port. (Required) 17 | cmixAddress: "0.0.0.0:11420" 18 | 19 | # Path to where the identity file (IDF) is saved. The IDF stores the Gateway's 20 | # network identity. This is used by the wrapper management script. (Required) 21 | idfPath: "/opt/xxnetwork/cred/gateway-IDF.json" 22 | 23 | # Path to the private key associated with the self-signed TLS certificate. 24 | # (Required) 25 | keyPath: "/opt/xxnetwork/cred/gateway-key.key" 26 | 27 | # Path to the self-signed TLS certificate for Gateway. Expects PEM format. 28 | # (Required) 29 | certPath: "/opt/xxnetwork/cred/gateway-cert.crt" 30 | 31 | # Path to the self-signed TLS certificate for cMix. Expects PEM format. 32 | # (Required) 33 | cmixCertPath: "/opt/xxnetwork/cred/cmix-cert.crt" 34 | 35 | # Path to the self-signed TLS certificate for the Scheduling server. Expects 36 | # PEM format. (Required) 37 | schedulingCertPath: "/opt/xxnetwork/cred/scheduling-cert.crt" 38 | 39 | # Database connection information. (Required) 40 | dbName: "cmix_gateway" 41 | dbAddress: "0.0.0.0:5432" 42 | dbUsername: "cmix" 43 | dbPassword: "[password for database]" 44 | 45 | ## 46 | # WARNING: Do not modify the options below unless explicitly required. 47 | ## 48 | 49 | # Local IP address of the Gateway, used for internal listening. Expects an IPv4 50 | # address without a port. (Default "0.0.0.0") 51 | #listeningAddress: "0.0.0.0" 52 | 53 | # The public IPv4 address of the Gateway, as reported to the network. When not 54 | # set, external IP address lookup services are used to set this value. If a 55 | # port is not included, then the port from the port flag is used instead. 
56 | #overridePublicIP: "1.2.3.4"
--------------------------------------------------------------------------------
/cmd/ipAddress/convert.go:
--------------------------------------------------------------------------------
1 | ////////////////////////////////////////////////////////////////////////////////
2 | // Copyright © 2022 xx foundation //
3 | // //
4 | // Use of this source code is governed by a license that can be found in the //
5 | // LICENSE file. //
6 | ////////////////////////////////////////////////////////////////////////////////
7 |
8 | package ipAddress
9 |
10 | import (
11 | 	"github.com/pkg/errors"
12 | 	"strconv"
13 | 	"strings"
14 | )
15 |
16 | // StringToByte converts a dotted-quad IPv4 address into a 4-byte slice.
17 | // Example: "1.2.3.4" -> []byte{1,2,3,4}.
18 | //
19 | // Returns an error unless ipAddr has exactly four '.'-separated fields, each
20 | // an unsigned decimal integer in the range 0-255.
21 | func StringToByte(ipAddr string) ([]byte, error) {
22 | 	// Split IP values separated by the '.' delimiter
23 | 	addrVals := strings.Split(ipAddr, ".")
24 |
25 | 	// Check validity of address by ensuring 4 values
26 | 	if len(addrVals) != 4 {
27 | 		return nil, errors.Errorf("Invalid input, %s is not recognized as an IP", ipAddr)
28 | 	}
29 |
30 | 	b := make([]byte, 4)
31 | 	for i, addrVal := range addrVals {
32 | 		// ParseUint with bitSize 8 rejects negative and >255 values, which
33 | 		// strconv.Atoi followed by byte() would silently wrap (256 -> 0).
34 | 		octet, err := strconv.ParseUint(addrVal, 10, 8)
35 | 		if err != nil {
36 | 			return nil, errors.WithMessagef(err, "Could not convert IP address (%s) to byte data", ipAddr)
37 | 		}
38 | 		b[i] = byte(octet)
39 | 	}
40 |
41 | 	// Return IP address as byte data
42 | 	return b, nil
43 | }
44 |
45 | // ByteToString converts a 4-byte IPv4 address to its dotted-quad string.
46 | // Example: []byte{1,2,3,4} -> "1.2.3.4". Returns an error unless ipAddr is
47 | // exactly 4 bytes long.
48 | func ByteToString(ipAddr []byte) (string, error) {
49 | 	// Check validity of address by ensuring 4 values
50 | 	if len(ipAddr) != 4 {
51 | 		// %v, not %s: %s would print the raw bytes as (likely unprintable) text.
52 | 		return "", errors.Errorf("Invalid input, %v is not recognized as an IP", ipAddr)
53 | 	}
54 |
55 | 	// Convert each value to a string, place in a slice of strings
56 | 	addrVals := make([]string, 4)
57 | 	for i, b := range ipAddr {
58 | 		addrVals[i] = strconv.Itoa(int(b))
59 | 	}
60 |
61 | 	// Return joined string slice by "." delimiter
62 | 	return strings.Join(addrVals, "."), nil
63 | }
64 |
--------------------------------------------------------------------------------
/cmd/filteredUpdates.go:
--------------------------------------------------------------------------------
1 | ////////////////////////////////////////////////////////////////////////////////
2 | // Copyright © 2022 xx foundation //
3 | // //
4 | // Use of this source code is governed by a license that can be found in the //
5 | // LICENSE file. //
6 | ////////////////////////////////////////////////////////////////////////////////
7 |
8 | package cmd
9 |
10 | import (
11 | 	pb "gitlab.com/elixxir/comms/mixmessages"
12 | 	"gitlab.com/elixxir/comms/network"
13 | 	ds "gitlab.com/elixxir/comms/network/dataStructures"
14 | 	"gitlab.com/elixxir/primitives/states"
15 | 	"gitlab.com/xx_network/crypto/signature/ec"
16 | )
17 |
18 | // FilteredUpdates keeps the subset of round updates that clients care about
19 | // (COMPLETED, FAILED and QUEUED rounds), with the RSA signature cleared so
20 | // only the EC signature remains.
21 | type FilteredUpdates struct {
22 | 	updates  *ds.Updates
23 | 	instance *network.Instance
24 | 	ecPubKey *ec.PublicKey
25 | }
26 |
27 | // NewFilteredUpdates creates an empty FilteredUpdates for the given network
28 | // instance. Returns an error if the instance's elliptic-curve public key
29 | // cannot be loaded.
30 | func NewFilteredUpdates(instance *network.Instance) (*FilteredUpdates, error) {
31 | 	ecPubKey, err := ec.LoadPublicKey(instance.GetEllipticPublicKey())
32 | 	if err != nil {
33 | 		return nil, err
34 | 	}
35 |
36 | 	return &FilteredUpdates{
37 | 		updates:  ds.NewUpdates(),
38 | 		instance: instance,
39 | 		ecPubKey: ecPubKey,
40 | 	}, nil
41 | }
42 |
43 | // GetRoundUpdate returns the filtered round info stored under updateID.
44 | func (fu *FilteredUpdates) GetRoundUpdate(updateID int) (*pb.RoundInfo, error) {
45 | 	return fu.updates.GetUpdate(updateID)
46 | }
47 |
48 | // GetRoundUpdates returns filtered round infos from the underlying
49 | // ds.Updates store, selected by id. NOTE(review): presumably these are the
50 | // updates newer than update ID id; confirm against ds.Updates.GetUpdates.
51 | func (fu *FilteredUpdates) GetRoundUpdates(id int) []*pb.RoundInfo {
52 | 	return fu.updates.GetUpdates(id)
53 | }
54 |
55 | // GetLastUpdateID returns the most recent update ID.
56 | func (fu *FilteredUpdates) GetLastUpdateID() int {
57 | 	return fu.updates.GetLastUpdateID()
58 | }
59 |
60 | // RoundUpdates is the pluralized version of RoundUpdate: it applies
61 | // RoundUpdate to each round in order and stops at, and returns, the first
62 | // error. Rounds processed before the failing one remain added.
63 | func (fu *FilteredUpdates) RoundUpdates(rounds []*pb.RoundInfo) error {
64 | 	for _, round := range rounds {
65 | 		if err := fu.RoundUpdate(round); err != nil {
66 | 			return err
67 | 		}
68 | 	}
69 | 	return nil
70 | }
71 |
72 | // RoundUpdate adds a round to the updates filter if its state is one clients
73 | // care about (COMPLETED, FAILED or QUEUED); rounds in any other state are
74 | // ignored and nil is returned. The stored copy has its RSA signature cleared
75 | // so only the EC signature is kept.
76 | func (fu *FilteredUpdates) RoundUpdate(info *pb.RoundInfo) error {
77 | 	switch states.Round(info.State) {
78 | 	case states.COMPLETED, states.FAILED, states.QUEUED:
79 | 		// Work on a deep copy so the caller's RoundInfo keeps its signature.
80 | 		roundCopy := info.DeepCopy()
81 |
82 | 		// Clear out the rsa signature, keeping the EC signature
83 | 		// only for FilteredUpdates
84 | 		roundCopy.Signature = nil
85 |
86 | 		// Create a wrapped round object and store it
87 | 		rnd := ds.NewRound(roundCopy, nil, fu.ecPubKey)
88 |
89 | 		return fu.updates.AddRound(rnd)
90 | 	default:
91 | 		// Other states are not tracked for clients.
92 | 		return nil
93 | 	}
94 | }
95 |
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | before_script:
2 | ##
3 | ## Go Setup
4 | ##
5 | - go version || echo "Go executable not found."
6 | - echo $CI_BUILD_REF 7 | - echo $CI_PROJECT_DIR 8 | - echo $PWD 9 | - eval $(ssh-agent -s) 10 | - echo "$SSH_PRIVATE_KEY" | tr -d '\r' | ssh-add - > /dev/null 11 | - mkdir -p ~/.ssh 12 | - chmod 700 ~/.ssh 13 | - ssh-keyscan -t rsa $GITLAB_SERVER > ~/.ssh/known_hosts 14 | - git config --global url."git@$GITLAB_SERVER:".insteadOf "https://gitlab.com/" 15 | - git config --global url."git@$GITLAB_SERVER:".insteadOf "https://git.xx.network/" --add 16 | - export PATH=$HOME/go/bin:$PATH 17 | 18 | stages: 19 | - build 20 | - trigger_integration 21 | 22 | build: 23 | stage: build 24 | image: $DOCKER_IMAGE 25 | except: 26 | - tags 27 | script: 28 | - git clean -ffdx 29 | - go mod vendor -v 30 | - go build ./... 31 | - go mod tidy 32 | - mkdir -p testdata 33 | 34 | # Test coverage 35 | - go-acc --covermode atomic --output testdata/coverage.out ./... -- -v 36 | # Exclude some specific packages and files 37 | - grep -v -e cmd -e gatewayDb.go testdata/coverage.out > testdata/coverage-real.out 38 | - go tool cover -func=testdata/coverage-real.out 39 | - go tool cover -html=testdata/coverage-real.out -o testdata/coverage.html 40 | 41 | # Test Coverage Check 42 | - go tool cover -func=testdata/coverage-real.out | grep "total:" | awk '{print $3}' | sed 's/\%//g' > testdata/coverage-percentage.txt 43 | - export CODE_CHECK=$(echo "$(cat testdata/coverage-percentage.txt) >= $MIN_CODE_COVERAGE" | bc -l) 44 | - (if [ "$CODE_CHECK" == "1" ]; then echo "Minimum coverage of $MIN_CODE_COVERAGE succeeded"; else echo "Minimum coverage of $MIN_CODE_COVERAGE failed"; exit 1; fi); 45 | 46 | - mkdir -p release 47 | - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' ./... 
48 | - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/gateway.linux64 main.go 49 | - GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/gateway.win64 main.go 50 | # - GOOS=windows GOARCH=386 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/gateway.win32 main.go 51 | - GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o release/gateway.darwin64 main.go 52 | - /upload-artifacts.sh release/ 53 | - /upload-artifact.sh gateway release/gateway.linux64 54 | artifacts: 55 | paths: 56 | - vendor/ 57 | - testdata/ 58 | - release/ 59 | 60 | tag: 61 | stage: trigger_integration 62 | only: 63 | - master 64 | image: $DOCKER_IMAGE 65 | script: 66 | - git remote add origin_tags git@git.xx.network:elixxir/gateway.git || true 67 | - git remote set-url origin_tags git@git.xx.network:elixxir/gateway.git || true 68 | - git tag $(./release/gateway.linux64 version | grep "xx network Gateway v"| cut -d ' ' -f4) -f 69 | - git push origin_tags -f --tags 70 | 71 | trigger_integration: 72 | stage: trigger_integration 73 | trigger: 74 | project: elixxir/integration 75 | branch: $CI_COMMIT_REF_NAME 76 | only: 77 | - release 78 | - master 79 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Active Users Script 2 | 3 | ### Description 4 | 5 | The active users script is a lightweight tool for inferring the number of active clients that are 6 | running on cMix over time. It uses bloom filter data stored in gateway databases in order 7 | to provide that approximation. Additionally, it uploads these data to CloudWatch so that 8 | it can be aggregated from many gateways in order to show a more realistic approximation in 9 | a decentralized network. 10 | 11 | ### Installation 12 | 13 | > __Note:__ This guide assumes a standard gateway installation as prescribed by the node handbook. 
14 | > The service file and script inputs can be modified for alternate installations as needed 15 | > but will not be covered here. 16 | > 17 | > Run `./active-users.py --help` for more information on these options. 18 | 19 | 1. Prepare the service file in a text editor of your choice. 20 | There are four relevant items you must ensure are set correctly: 21 | 1. `User=elixxir` - The username provided here should match the user that runs the gateway wrapper. It can be 22 | determined by running `ls -l /opt/xxnetwork` and examining the output, for example. 23 | 2. `--pass` - Provide your gateway database password. It can be extracted from `/opt/xxnetwork/config/gateway.yaml` 24 | under the `dbPassword` field. 25 | 3. `--aws-key` - AWS access key for CloudWatch. It can be extracted from 26 | `/opt/xxnetwork/xxnetwork-gateway.service` under the `--s3-access-key` field. 27 | 4. `--aws-secret` - AWS secret key for CloudWatch. It can be extracted from 28 | `/opt/xxnetwork/xxnetwork-gateway.service` under the `--s3-secret` field. 29 | 2. Install the required Python dependencies on the gateway server by running `pip3 install psycopg2-binary` 30 | 3. Stage the script. Place `active-users.py` at `/opt/xxnetwork/active-users.py` on your gateway machine, matching 31 | the script path given by `ExecStart` in the service file. 32 | 4. Make the script executable by running `chmod +x /opt/xxnetwork/active-users.py`. 33 | 1. Additionally, verify that the user specified in the service file owns the script. 34 | If not, you can change the owner via `chown elixxir:elixxir /opt/xxnetwork/active-users.py`, for example. 35 | 5. Stage the service file. Take the completed service file from __Step 1__ and place it at 36 | `/etc/systemd/system/active-users.service`. This will require root permissions. 37 | 1. For example, `sudo nano /etc/systemd/system/active-users.service`, paste, and save. 38 | 6. Enable the service by running `sudo systemctl enable active-users` 39 | 7.
Start the service by running `sudo systemctl restart active-users` 40 | 8. Verify the active users script is running correctly. This can be accomplished in a variety of ways: 41 | 1. Run `systemctl status active-users`. The resulting page should indicate the service is `active (running)`. 42 | 2. Check the log directory by running `ls /opt/xxnetwork/log`. Both the `active-users.log` and `active-users.csv` 43 | files should be present. 44 | 9. If these checks fail and you were specifically petitioned to run this tool, please get in touch with the team. 45 | Otherwise, the lightweight script will continue to run in the background unless stopped via 46 | `sudo systemctl stop active-users` 47 | -------------------------------------------------------------------------------- /cmd/bloom.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "encoding/binary" 12 | "github.com/pkg/errors" 13 | jww "github.com/spf13/jwalterweatherman" 14 | bloom "gitlab.com/elixxir/bloomfilter" 15 | "gitlab.com/elixxir/primitives/states" 16 | "gitlab.com/xx_network/primitives/id" 17 | "gitlab.com/xx_network/primitives/id/ephemeral" 18 | "strings" 19 | ) 20 | 21 | // This file handles creating per-recipient bloom filters and upserting them into storage 22 | 23 | // Constants for constructing a bloom filter 24 | const bloomFilterSize = 648 // In Bits 25 | const bloomFilterHashes = 10 26 | 27 | // UpsertFilters upserts a bloom filter for every recipient of roundId. The epoch is derived from the round's QUEUED timestamp; per-recipient failures are collected and returned as a single error joined by errorDelimiter. 28 | func (gw *Instance) UpsertFilters(recipients map[ephemeral.Id]interface{}, roundId id.Round) error { 29 | var errReturn error 30 | var errs []string 31 | 32 | // Look up the round and derive its epoch from the QUEUED timestamp 33 | round, err := gw.NetInf.GetRound(roundId) 34 | if err != nil { 35 | return err 36 | } 37 | roundTimestamp := round.Timestamps[states.QUEUED] 38 | epoch := GetEpoch(int64(roundTimestamp), gw.period) 39 | // Keep going on failure so every recipient gets an upsert attempt 40 | for recipient := range recipients { 41 | err := gw.UpsertFilter(recipient, roundId, epoch) 42 | if err != nil { 43 | errs = append(errs, err.Error()) 44 | } 45 | } 46 | 47 | if len(errs) > 0 { 48 | errReturn = errors.New(strings.Join(errs, errorDelimiter)) 49 | } 50 | 51 | return errReturn 52 | } 53 | 54 | // UpsertFilter builds a new bloom filter containing only roundId and upserts it to storage for recipientId under the given epoch. 55 | func (gw *Instance) UpsertFilter(recipientId ephemeral.Id, roundId id.Round, epoch uint32) error { 56 | jww.DEBUG.Printf("Adding bloom filter for client %d on round %d with epoch %d", 57 | recipientId.Int64(), roundId, epoch) 58 | 59 | // Initialize a new bloom filter using the package-level size (in bits) 60 | // and hash-count constants 61 | newBloom, err := bloom.InitByParameters(bloomFilterSize, bloomFilterHashes) 62 | if err != nil { 63 | return errors.Errorf("Unable to generate new bloom filter: %s", err) 64 | } 65
| 66 | // Add the round to the bloom filter 67 | serializedRound := serializeRound(roundId) 68 | newBloom.Add(serializedRound) 69 | 70 | // Marshal the new bloom filter into its binary form 71 | // so it can be persisted 72 | marshaledBloom, err := newBloom.MarshalBinary() 73 | if err != nil { 74 | return errors.Errorf("Unable to marshal new bloom filter: %s", err) 75 | } 76 | 77 | // Upsert the filter to storage 78 | return gw.storage.HandleBloomFilter(recipientId, marshaledBloom, roundId, epoch) 79 | } 80 | 81 | // serializeRound encodes roundId as an 8-byte little-endian slice. 82 | // fixme: this is needed because the bloom filter requires a byte slice 83 | // for insertion. Is there a better way? Look into the internals of the 84 | // bloom filter; it likely has a marshal function that could be used 85 | // instead. 86 | func serializeRound(roundId id.Round) []byte { 87 | b := make([]byte, 8) 88 | binary.LittleEndian.PutUint64(b, uint64(roundId)) 89 | return b 90 | } 91 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gitlab.com/elixxir/gateway 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 7 | github.com/golang/protobuf v1.5.2 8 | github.com/pkg/errors v0.9.1 9 | github.com/spf13/cobra v1.1.1 10 | github.com/spf13/jwalterweatherman v1.1.0 11 | github.com/spf13/viper v1.7.1 12 | gitlab.com/elixxir/bloomfilter v0.0.0-20230322223210-fa84f6842de8 13 | gitlab.com/elixxir/comms v0.0.4-0.20230310205528-f06faa0d2f0b 14 | gitlab.com/elixxir/crypto v0.0.7-0.20230322181929-8cb5fa100824 15 | gitlab.com/elixxir/primitives v0.0.3-0.20230214180039-9a25e2d3969c 16 | gitlab.com/xx_network/comms v0.0.4-0.20230214180029-5387fb85736d 17 | gitlab.com/xx_network/crypto v0.0.5-0.20230214003943-8a09396e95dd 18 | gitlab.com/xx_network/primitives v0.0.4-0.20230310205521-c440e68e34c4 19 | golang.org/x/crypto v0.5.0 20 | golang.org/x/net
v0.5.0 21 | google.golang.org/grpc v1.49.0 22 | google.golang.org/protobuf v1.28.1 23 | gorm.io/driver/postgres v1.4.5 24 | gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755 25 | ) 26 | 27 | require ( 28 | git.xx.network/elixxir/grpc-web-go-client v0.0.0-20230214175953-5b5a8c33d28a // indirect 29 | github.com/cenkalti/backoff/v4 v4.1.3 // indirect 30 | github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect 31 | github.com/elliotchance/orderedmap v1.4.0 // indirect 32 | github.com/fsnotify/fsnotify v1.4.9 // indirect 33 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect 34 | github.com/gorilla/websocket v1.5.0 // indirect 35 | github.com/hashicorp/hcl v1.0.0 // indirect 36 | github.com/improbable-eng/grpc-web v0.15.0 // indirect 37 | github.com/inconshreveable/mousetrap v1.0.0 // indirect 38 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect 39 | github.com/jackc/pgconn v1.13.0 // indirect 40 | github.com/jackc/pgio v1.0.0 // indirect 41 | github.com/jackc/pgpassfile v1.0.0 // indirect 42 | github.com/jackc/pgproto3/v2 v2.3.1 // indirect 43 | github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect 44 | github.com/jackc/pgtype v1.12.0 // indirect 45 | github.com/jackc/pgx/v4 v4.17.2 // indirect 46 | github.com/jinzhu/inflection v1.0.0 // indirect 47 | github.com/jinzhu/now v1.1.5 // indirect 48 | github.com/klauspost/compress v1.11.7 // indirect 49 | github.com/magiconair/properties v1.8.4 // indirect 50 | github.com/mitchellh/go-homedir v1.1.0 // indirect 51 | github.com/mitchellh/mapstructure v1.4.0 // indirect 52 | github.com/pelletier/go-toml v1.8.1 // indirect 53 | github.com/rs/cors v1.8.2 // indirect 54 | github.com/smartystreets/assertions v1.1.0 // indirect 55 | github.com/soheilhy/cmux v0.1.5 // indirect 56 | github.com/spf13/afero v1.5.1 // indirect 57 | github.com/spf13/cast v1.3.1 // indirect 58 | github.com/spf13/pflag v1.0.5 // indirect 59 | github.com/subosito/gotenv v1.2.0 // 
indirect 60 | github.com/zeebo/blake3 v0.1.1 // indirect 61 | gitlab.com/xx_network/ring v0.0.3-0.20220902183151-a7d3b15bc981 // indirect 62 | go.uber.org/atomic v1.10.0 // indirect 63 | golang.org/x/sys v0.4.0 // indirect 64 | golang.org/x/text v0.6.0 // indirect 65 | google.golang.org/genproto v0.0.0-20220822174746-9e6da59bd2fc // indirect 66 | gopkg.in/ini.v1 v1.62.0 // indirect 67 | gopkg.in/yaml.v2 v2.4.0 // indirect 68 | nhooyr.io/websocket v1.8.7 // indirect 69 | src.agwa.name/tlshacks v0.0.0-20220518131152-d2c6f4e2b780 // indirect 70 | ) 71 | -------------------------------------------------------------------------------- /autocert/interface.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Package autocert requests an ACME certificate using EAB credentials 9 | // and provides helper functions to wait until the certificate is issued. 10 | package autocert 11 | 12 | import ( 13 | "context" 14 | "crypto" 15 | "io" 16 | "time" 17 | 18 | "gitlab.com/elixxir/crypto/rsa" 19 | "golang.org/x/crypto/acme" 20 | ) 21 | 22 | // Client is a simplified ACME client for requesting a certificate 23 | // with External Account Binding (EAB) credentials. 24 | // 25 | // To use it, generate a private key, then: 26 | // 1. Register with the server using EAB credentials. 27 | // 2. Request, which accepts the challenge and returns the record to set. 28 | // 3. CreateCSR and then Issue, which waits until the server accepts and authorizes 29 | // your challenge, then requests and returns your cert in PEM format. 30 | type Client interface { 31 | // Register authorizes this private key with the server.
32 | // eabKeyID is the key ID for External Account Binding and is a string 33 | // eabKey is a base64 raw encoded string for External Account Binding 34 | // email is the e-mail address to use when registering. 35 | // When nil is returned, the registration succeeded and the caller can 36 | // continue on to the Request step. 37 | Register(privateKey rsa.PrivateKey, eabKeyID, eabKey, email string) error 38 | 39 | // Request retrieves and accepts the appropriate ACME challenge for this 40 | // Client, and returns the challenge record to set (e.g., the DNS TXT name and token). 41 | Request(domain string) (key, value string, err error) 42 | 43 | // CreateCSR generates an issuer-compliant certificate signing request in PEM and DER form. 44 | CreateCSR(domain, email, country, nodeID string, rng io.Reader) (csrPEM, 45 | csrDER []byte, err error) 46 | 47 | // Issue blocks, bounded by timeout, until the challenge is accepted by the 48 | // remote server, then returns the issued certificate and its private key, 49 | // both in PEM format. 50 | Issue(csr []byte, timeout time.Duration) (cert, key []byte, err error) 51 | } 52 | 53 | // acmeClient is the subset of the ACME client API used by this package; it 54 | // is an interface so tests can substitute a mock. Extend it as needed.
55 | type acmeClient interface { 56 | // Attribute setters/getters 57 | GetDirectoryURL() string 58 | SetDirectoryURL(d string) 59 | GetKey() crypto.Signer 60 | SetKey(k crypto.Signer) 61 | 62 | // Networking funcs 63 | GetReg(ctx context.Context, _ string) (*acme.Account, error) 64 | Register(ctx context.Context, acct *acme.Account, 65 | tosFn func(tosURL string) bool) (*acme.Account, error) 66 | 67 | DNS01ChallengeRecord(token string) (string, error) 68 | AuthorizeOrder(ctx context.Context, 69 | authzIDs []acme.AuthzID) (*acme.Order, error) 70 | CreateOrderCert(ctx context.Context, finalURL string, csr []byte, 71 | ty bool) ([][]byte, string, error) 72 | GetAuthorization(ctx context.Context, 73 | authzURL string) (*acme.Authorization, error) 74 | Accept(ctx context.Context, 75 | chal *acme.Challenge) (*acme.Challenge, error) 76 | WaitAuthorization(ctx context.Context, 77 | authzURL string) (*acme.Authorization, error) 78 | } 79 | -------------------------------------------------------------------------------- /storage/extendedRoundStorage.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Bridge for ExtendedRoundStorage between comms and the database 9 | // Allows Gateway to store round information longer than would be stored in the 10 | // ring buffer.
11 | 12 | package storage 13 | 14 | import ( 15 | "github.com/golang/protobuf/proto" 16 | "github.com/pkg/errors" 17 | pb "gitlab.com/elixxir/comms/mixmessages" 18 | "gitlab.com/xx_network/primitives/id" 19 | "strings" 20 | ) 21 | 22 | // Store marshals ri and upserts it into the database as a Round record. 23 | func (s *Storage) Store(ri *pb.RoundInfo) error { 24 | // Marshal the data so it can be stored 25 | m, err := proto.Marshal(ri) 26 | if err != nil { 27 | return err 28 | } 29 | 30 | // Create our DB Round object to store 31 | dbr := Round{ 32 | Id: ri.ID, 33 | UpdateId: ri.UpdateID, 34 | InfoBlob: m, 35 | } 36 | 37 | // Store it 38 | err = s.UpsertRound(&dbr) 39 | if err != nil { 40 | return err 41 | } 42 | return nil 43 | } 44 | 45 | // Retrieve loads the round with the given ID from the database. A round that cannot be found is reported as (nil, nil). 46 | func (s *Storage) Retrieve(id id.Round) (*pb.RoundInfo, error) { 47 | // Retrieve round from the database 48 | dbr, err := s.GetRound(id) 49 | // Detect if we have an error, if it is because the round couldn't be found 50 | // we suppress it. Otherwise, bring it up the path.
51 | if err != nil { 52 | if strings.HasPrefix(err.Error(), "Could not find Round with ID ") { 53 | return nil, nil 54 | } else { 55 | return nil, err 56 | } 57 | } 58 | 59 | // Convert it to a pb.RoundInfo object 60 | u := &pb.RoundInfo{} 61 | err = proto.Unmarshal(dbr.InfoBlob, u) 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | // Return it 67 | return u, nil 68 | } 69 | 70 | // RetrieveMany loads each requested round. Entries are nil for rounds that are missing or fail to load; an error is returned only if none of the rounds were found. 71 | func (s *Storage) RetrieveMany(rounds []id.Round) ([]*pb.RoundInfo, error) { 72 | var r = make([]*pb.RoundInfo, len(rounds)) 73 | 74 | foundRounds := false 75 | 76 | // Retrieve reports a missing round as (nil, nil), so a nil result must not count as found 77 | for i, rid := range rounds { 78 | ri, err := s.Retrieve(rid) 79 | if err != nil || ri == nil { 80 | continue 81 | } 82 | r[i] = ri 83 | foundRounds = true 84 | } 85 | 86 | if !foundRounds { 87 | return nil, errors.New("Failed to find any of the rounds") 88 | } 89 | 90 | return r, nil 91 | } 92 | 93 | // RetrieveRange loads every round from first to last, inclusive. Rounds that cannot be found are nil entries; any other error aborts the call. 94 | func (s *Storage) RetrieveRange(first, last id.Round) ([]*pb.RoundInfo, error) { 95 | if first > last { 96 | return nil, errors.New("Failed to retrieve range of rounds: last round must be greater than first.") 97 | } 98 | idRange := uint64(last-first) + 1 99 | 100 | var r = make([]*pb.RoundInfo, idRange) 101 | 102 | // Iterate over all IDs in the range, retrieving them and putting them in the 103 | // round array 104 | for i := uint64(0); i < idRange; i++ { 105 | ri, err := s.Retrieve(id.Round(uint64(first) + i)) 106 | if err != nil { 107 | return nil, err 108 | } 109 | r[i] = ri 110 | } 111 | 112 | return r, nil 113 | } 114 | -------------------------------------------------------------------------------- /cmd/version_vars.go: -------------------------------------------------------------------------------- 1 | // Code generated by go generate; DO NOT EDIT.
2 | // This file was generated by robots at 3 | // 2023-07-26 14:06:11.458296 -0500 CDT m=+0.043888978 4 | 5 | package cmd 6 | 7 | const GITVERSION = `3d1d402 Merge remote-tracking branch 'origin/release' into release` 8 | const SEMVER = "3.16.1" 9 | const DEPENDENCIES = `module gitlab.com/elixxir/gateway 10 | 11 | go 1.19 12 | 13 | require ( 14 | github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 15 | github.com/golang/protobuf v1.5.2 16 | github.com/pkg/errors v0.9.1 17 | github.com/spf13/cobra v1.1.1 18 | github.com/spf13/jwalterweatherman v1.1.0 19 | github.com/spf13/viper v1.7.1 20 | gitlab.com/elixxir/bloomfilter v0.0.0-20230322223210-fa84f6842de8 21 | gitlab.com/elixxir/comms v0.0.4-0.20230310205528-f06faa0d2f0b 22 | gitlab.com/elixxir/crypto v0.0.7-0.20230322181929-8cb5fa100824 23 | gitlab.com/elixxir/primitives v0.0.3-0.20230214180039-9a25e2d3969c 24 | gitlab.com/xx_network/comms v0.0.4-0.20230214180029-5387fb85736d 25 | gitlab.com/xx_network/crypto v0.0.5-0.20230214003943-8a09396e95dd 26 | gitlab.com/xx_network/primitives v0.0.4-0.20230310205521-c440e68e34c4 27 | golang.org/x/crypto v0.5.0 28 | golang.org/x/net v0.5.0 29 | google.golang.org/grpc v1.49.0 30 | google.golang.org/protobuf v1.28.1 31 | gorm.io/driver/postgres v1.4.5 32 | gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755 33 | ) 34 | 35 | require ( 36 | git.xx.network/elixxir/grpc-web-go-client v0.0.0-20230214175953-5b5a8c33d28a // indirect 37 | github.com/cenkalti/backoff/v4 v4.1.3 // indirect 38 | github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect 39 | github.com/elliotchance/orderedmap v1.4.0 // indirect 40 | github.com/fsnotify/fsnotify v1.4.9 // indirect 41 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect 42 | github.com/gorilla/websocket v1.5.0 // indirect 43 | github.com/hashicorp/hcl v1.0.0 // indirect 44 | github.com/improbable-eng/grpc-web v0.15.0 // indirect 45 | github.com/inconshreveable/mousetrap v1.0.0 
// indirect 46 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect 47 | github.com/jackc/pgconn v1.13.0 // indirect 48 | github.com/jackc/pgio v1.0.0 // indirect 49 | github.com/jackc/pgpassfile v1.0.0 // indirect 50 | github.com/jackc/pgproto3/v2 v2.3.1 // indirect 51 | github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect 52 | github.com/jackc/pgtype v1.12.0 // indirect 53 | github.com/jackc/pgx/v4 v4.17.2 // indirect 54 | github.com/jinzhu/inflection v1.0.0 // indirect 55 | github.com/jinzhu/now v1.1.5 // indirect 56 | github.com/klauspost/compress v1.11.7 // indirect 57 | github.com/magiconair/properties v1.8.4 // indirect 58 | github.com/mitchellh/go-homedir v1.1.0 // indirect 59 | github.com/mitchellh/mapstructure v1.4.0 // indirect 60 | github.com/pelletier/go-toml v1.8.1 // indirect 61 | github.com/rs/cors v1.8.2 // indirect 62 | github.com/smartystreets/assertions v1.1.0 // indirect 63 | github.com/soheilhy/cmux v0.1.5 // indirect 64 | github.com/spf13/afero v1.5.1 // indirect 65 | github.com/spf13/cast v1.3.1 // indirect 66 | github.com/spf13/pflag v1.0.5 // indirect 67 | github.com/subosito/gotenv v1.2.0 // indirect 68 | github.com/zeebo/blake3 v0.1.1 // indirect 69 | gitlab.com/xx_network/ring v0.0.3-0.20220902183151-a7d3b15bc981 // indirect 70 | go.uber.org/atomic v1.10.0 // indirect 71 | golang.org/x/sys v0.4.0 // indirect 72 | golang.org/x/text v0.6.0 // indirect 73 | google.golang.org/genproto v0.0.0-20220822174746-9e6da59bd2fc // indirect 74 | gopkg.in/ini.v1 v1.62.0 // indirect 75 | gopkg.in/yaml.v2 v2.4.0 // indirect 76 | nhooyr.io/websocket v1.8.7 // indirect 77 | src.agwa.name/tlshacks v0.0.0-20220518131152-d2c6f4e2b780 // indirect 78 | ) 79 | ` 80 | -------------------------------------------------------------------------------- /cmd/poll_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // 
Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | pb "gitlab.com/elixxir/comms/mixmessages" 12 | "gitlab.com/elixxir/comms/network" 13 | "gitlab.com/elixxir/comms/testkeys" 14 | "gitlab.com/elixxir/gateway/storage" 15 | "gitlab.com/xx_network/primitives/ndf" 16 | "gitlab.com/xx_network/primitives/rateLimiting" 17 | "testing" 18 | "time" 19 | ) 20 | 21 | // Error path: Pass in invalid messages 22 | func TestInstance_Poll_NilCheck(t *testing.T) { 23 | // Build the gateway instance 24 | params := Params{ 25 | NodeAddress: NODE_ADDRESS, 26 | ServerCertPath: testkeys.GetNodeCertPath(), 27 | CertPath: testkeys.GetGatewayCertPath(), 28 | DevMode: true, 29 | } 30 | 31 | params.messageRateLimitParams = &rateLimiting.MapParams{ 32 | Capacity: 10, 33 | LeakedTokens: 1, 34 | LeakDuration: 10 * time.Second, 35 | PollDuration: 10 * time.Second, 36 | BucketMaxAge: 10 * time.Second, 37 | } 38 | 39 | gw := NewGatewayInstance(params) 40 | gw.InitNetwork() 41 | 42 | // Pass in a nil client ID 43 | clientReq := &pb.GatewayPoll{ 44 | Partial: nil, 45 | LastUpdate: 0, 46 | ReceptionID: nil, 47 | } 48 | 49 | testNDF, _ := ndf.Unmarshal(ExampleJSON) 50 | 51 | // This is bad. It needs to be fixed (Ben's fault for not fixing correctly) 52 | var err error 53 | ers := &storage.Storage{} 54 | gw.NetInf, err = network.NewInstance(gatewayInstance.Comms.ProtoComms, testNDF, testNDF, ers, network.Lazy, false) 55 | gw.filteredUpdates, err = NewFilteredUpdates(gw.NetInf) 56 | if err != nil { 57 | t.Fatalf("Failed to create filtered update: %v", err) 58 | } 59 | _, err = gw.Poll(clientReq) 60 | if err == nil { 61 | t.Errorf("Expected error path. 
Should error when passing a nil clientID") 62 | } 63 | 64 | // Pass in a completely nil message 65 | _, err = gw.Poll(nil) 66 | if err == nil { 67 | t.Errorf("Expected error path. Should error when passing a nil message") 68 | } 69 | } 70 | 71 | // Happy path 72 | //func TestInstance_Poll(t *testing.T) { 73 | // //Build the gateway instance 74 | // params := Params{ 75 | // NodeAddress: NODE_ADDRESS, 76 | // ServerCertPath: testkeys.GetNodeCertPath(), 77 | // CertPath: testkeys.GetGatewayCertPath(), 78 | // } 79 | // 80 | // gw := NewGatewayInstance(params) 81 | // gw.InitNetwork() 82 | // gw.period = 30 83 | // 84 | // clientId := id.NewIdFromBytes([]byte("test"), t) 85 | // ephemId, _, _, err := ephemeral.GetId(clientId, 8, time.Now().UnixNano()) 86 | // 87 | // clientReq := &pb.GatewayPoll{ 88 | // Partial: nil, 89 | // LastUpdate: 0, 90 | // ReceptionID: ephemId[:], 91 | // } 92 | // testNDF, _ := ndf.Unmarshal(ExampleJSON) 93 | // 94 | // // This is bad. It needs to be fixed (Ben's fault for not fixing correctly) 95 | // ers := &storage.Storage{} 96 | // gw.NetInf, err = network.NewInstance(gatewayInstance.Comms.ProtoComms, testNDF, testNDF, ers) 97 | // 98 | // // TODO: Remove this when jake fixes the database please [Insert deity] 99 | // // Setup a database based on a map impl 100 | // gw.storage, _ = storage.NewStorage("", "", "", "", "") 101 | // 102 | // _, err = gw.Poll(clientReq) 103 | // if err != nil { 104 | // t.Errorf("Failed to poll: %v", err) 105 | // } 106 | // 107 | //} 108 | -------------------------------------------------------------------------------- /cmd/autocert.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "fmt" 12 | "os" 13 | "time" 14 | 15 | "github.com/spf13/cobra" 16 | jww "github.com/spf13/jwalterweatherman" 17 | "github.com/spf13/viper" 18 | "gitlab.com/elixxir/crypto/fastRNG" 19 | "gitlab.com/elixxir/gateway/autocert" 20 | "gitlab.com/xx_network/crypto/csprng" 21 | "gitlab.com/xx_network/primitives/utils" 22 | ) 23 | 24 | var autocertCmd = &cobra.Command{ 25 | Use: "autocert", 26 | Short: "automatic cert request test command", 27 | Long: `Attempt to request a cert for TLS, used for manual testing`, 28 | Run: func(cmd *cobra.Command, args []string) { 29 | initLog() 30 | 31 | eabKeyID := viper.GetString("eabKeyID") 32 | eabKey := viper.GetString("eabKey") 33 | domain := viper.GetString("domain") 34 | email := viper.GetString("email") 35 | 36 | if eabKey == "" || eabKeyID == "" || domain == "" { 37 | fmt.Printf("need eabKeyID, eabKey, and domain: "+ 38 | "%s,%s,%s\n", eabKeyID, eabKey, domain) 39 | os.Exit(-1) 40 | } 41 | 42 | rng := fastRNG.NewStreamGenerator(10, 1, csprng.NewSystemRNG) 43 | 44 | certGetter := autocert.NewDNS() 45 | 46 | privKeyPEM, err := utils.ReadFile("certkey.pem") 47 | if os.IsNotExist(err) { 48 | certKey, err := autocert.GenerateCertKey( 49 | rng.GetStream()) 50 | if err != nil { 51 | jww.FATAL.Panicf("%+v", err) 52 | } 53 | // Private key material: owner read/write only, never executable 54 | err = utils.WriteFile("certkey.pem", 55 | certKey.MarshalPem(), 56 | 0600, 0755) 57 | if err != nil { 58 | jww.FATAL.Panicf("%+v", err) 59 | } 60 | 61 | err = certGetter.Register(certKey, eabKeyID, eabKey, 62 | email) 63 | if err != nil { 64 | jww.FATAL.Panicf("%+v", err) 65 | } 66 | } else if err != nil { 67 | jww.FATAL.Panicf("Unable to read certkey.pem: %+v", err) 68 | } else { 69 | certGetter, err = autocert.LoadDNS(privKeyPEM) 70 | if err != nil { 71 | jww.FATAL.Panicf("%+v", err) 72 | } 73 | } 74 | 75 | chalDomain, challenge, err := certGetter.Request(domain) 76 | if err != nil { 77 | jww.FATAL.Panicf("%+v", err) 78 | } 79 | 80 | fmt.Printf("ADD TXT RECORD: %s\t%s\n",
chalDomain, challenge) 81 | 82 | csrPEM, csrDER, err := certGetter.CreateCSR(domain, email, "USA", 83 | "NodeID", rng.GetStream()) 84 | if err != nil { 85 | jww.FATAL.Panicf("%+v", err) 86 | return 87 | } 88 | 89 | err = utils.WriteFile("cert-csr.pem", csrPEM, 0644, 0755) 90 | if err != nil { 91 | jww.FATAL.Panicf("%+v", err) 92 | return 93 | } 94 | 95 | cert, key, err := certGetter.Issue(csrDER, time.Hour) 96 | if err != nil { 97 | jww.FATAL.Panicf("%+v", err) 98 | return 99 | } 100 | 101 | err = utils.WriteFile("cert.pem", cert, 0644, 0755) 102 | 103 | if err != nil { 104 | jww.FATAL.Panicf("%+v", err) 105 | return 106 | } 107 | 108 | err = utils.WriteFile("certkey.pem", key, 0600, 0755) 109 | if err != nil { 110 | jww.FATAL.Panicf("%+v", err) 111 | return 112 | } 113 | 114 | }, 115 | } 116 | 117 | func init() { 118 | rootCmd.AddCommand(autocertCmd) 119 | 120 | autocertCmd.Flags().StringP("eabKeyID", "i", "", 121 | "EAB Key ID (Required)") 122 | err := viper.BindPFlag("eabKeyID", autocertCmd.Flags().Lookup( 123 | "eabKeyID")) 124 | handleBindingError(err, "eabKeyID") 125 | 126 | autocertCmd.Flags().StringP("eabKey", "k", "", 127 | "EAB Key base64 format (Required)") 128 | err = viper.BindPFlag("eabKey", autocertCmd.Flags().Lookup( 129 | "eabKey")) 130 | handleBindingError(err, "eabKey") 131 | 132 | autocertCmd.Flags().StringP("domain", "d", "", 133 | "domain name to attempt to register") 134 | err = viper.BindPFlag("domain", autocertCmd.Flags().Lookup( 135 | "domain")) 136 | handleBindingError(err, "domain") 137 | 138 | autocertCmd.Flags().StringP("email", "e", "admins@elixxir.io", 139 | "email for registration, defaults to admins@elixxir.io") 140 | err = viper.BindPFlag("email", autocertCmd.Flags().Lookup( 141 | "email")) 142 | handleBindingError(err, "email") 143 | 144 | } 145 | -------------------------------------------------------------------------------- /cmd/gossip.go: -------------------------------------------------------------------------------- 1 | ////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Contains functionality related to inter-gateway gossip 9 | 10 | package cmd 11 | 12 | import ( 13 | "crypto/rand" 14 | "github.com/pkg/errors" 15 | jww "github.com/spf13/jwalterweatherman" 16 | "gitlab.com/xx_network/comms/gossip" 17 | "gitlab.com/xx_network/crypto/signature/rsa" 18 | "gitlab.com/xx_network/primitives/id" 19 | ) 20 | 21 | // Tag for the client rate limit gossip protocol 22 | const RateLimitGossip = "clientRateLimit" 23 | 24 | // Tag for the bloom filter gossip 25 | const BloomFilterGossip = "bloomFilter" 26 | 27 | // Starts a thread for monitoring and handling changes to gossip peers 28 | func (gw *Instance) StartPeersThread() { 29 | go func() { 30 | rateLimitProtocol, exists := gw.Comms.Manager.Get(RateLimitGossip) 31 | if !exists { 32 | jww.ERROR.Printf("Unable to get gossip rateLimitProtocol!") 33 | return 34 | } 35 | bloomProtocol, exists := gw.Comms.Manager.Get(BloomFilterGossip) 36 | if !exists { 37 | jww.ERROR.Printf("Unable to get gossip BloomFilter!") 38 | return 39 | } 40 | 41 | //add all previously present gateways 42 | /*for _, gateway := range gw.NetInf.GetFullNdf().Get().Gateways{ 43 | gwId, err := id.Unmarshal(gateway.ID) 44 | if err != nil { 45 | jww.WARN.Printf("Unable to unmarshal gossip peer: %+v", err) 46 | continue 47 | } 48 | jww.INFO.Printf("Added %s to gossip peers list", gwId) 49 | err = rateLimitProtocol.AddGossipPeer(gwId) 50 | if err != nil { 51 | jww.WARN.Printf("Unable to add rate limit gossip peer: %+v", err) 52 | } 53 | err = bloomProtocol.AddGossipPeer(gwId) 54 | if err != nil { 55 | jww.WARN.Printf("Unable to add bloom gossip peer: %+v", err) 56 | } 57 | }*/ 58 | 59 | for { 60 | select 
{ 61 | // TODO: Add kill case? 62 | case removeId := <-gw.removeGateway: 63 | jww.INFO.Printf("Removed %s to gossip peers list", removeId) 64 | err := rateLimitProtocol.RemoveGossipPeer(removeId) 65 | if err != nil { 66 | jww.WARN.Printf("Unable to remove rate limit gossip peer: %+v", err) 67 | } 68 | err = bloomProtocol.RemoveGossipPeer(removeId) 69 | if err != nil { 70 | jww.WARN.Printf("Unable to remove bloom gossip peer: %+v", err) 71 | } 72 | case add := <-gw.addGateway: 73 | gwId, err := id.Unmarshal(add.Gateway.ID) 74 | if err != nil { 75 | jww.WARN.Printf("Unable to unmarshal gossip peer: %+v", err) 76 | continue 77 | } 78 | jww.INFO.Printf("Added %s to gossip peers list", gwId) 79 | err = rateLimitProtocol.AddGossipPeer(gwId) 80 | if err != nil { 81 | jww.WARN.Printf("Unable to add rate limit gossip peer: %+v", err) 82 | } 83 | err = bloomProtocol.AddGossipPeer(gwId) 84 | if err != nil { 85 | jww.WARN.Printf("Unable to add bloom gossip peer: %+v", err) 86 | } 87 | } 88 | } 89 | }() 90 | } 91 | 92 | // Verify function for Gossip messages 93 | func (gw *Instance) gossipVerify(msg *gossip.GossipMsg, _ []byte) error { 94 | // Locate origin host 95 | origin, err := id.Unmarshal(msg.Origin) 96 | if err != nil { 97 | return errors.Errorf("Unable to unmarshal origin: %+v", err) 98 | } 99 | host, exists := gw.Comms.GetHost(origin) 100 | if !exists { 101 | return errors.Errorf("Unable to locate origin host: %s", host) 102 | } 103 | 104 | // Prepare message hash 105 | options := rsa.NewDefaultOptions() 106 | hash := options.Hash.New() 107 | hash.Write(gossip.Marshal(msg)) 108 | hashed := hash.Sum(nil) 109 | 110 | // Verify signature of message using origin host's public key 111 | err = rsa.Verify(host.GetPubKey(), options.Hash, hashed, msg.Signature, nil) 112 | if err != nil { 113 | return errors.Errorf("Unable to verify signature: %+v", err) 114 | } 115 | 116 | if msg.Tag == RateLimitGossip { 117 | return nil 118 | } else if msg.Tag == BloomFilterGossip { 119 | 
return nil 120 | } 121 | 122 | return errors.Errorf("Unrecognized tag: %s", msg.Tag) 123 | 124 | } 125 | 126 | // Helper function used to obtain Signature bytes of a given GossipMsg 127 | func buildGossipSignature(gossipMsg *gossip.GossipMsg, privKey *rsa.PrivateKey) ([]byte, error) { 128 | // Hash the message 129 | options := rsa.NewDefaultOptions() 130 | hash := options.Hash.New() 131 | hash.Write(gossip.Marshal(gossipMsg)) 132 | hashed := hash.Sum(nil) 133 | 134 | // Sign the message 135 | return rsa.Sign(rand.Reader, privKey, 136 | options.Hash, hashed, nil) 137 | } 138 | -------------------------------------------------------------------------------- /storage/extendedRoundStorage_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Testing file for extendedRoundStorage.go functions 9 | 10 | package storage 11 | 12 | import ( 13 | pb "gitlab.com/elixxir/comms/mixmessages" 14 | "gitlab.com/xx_network/primitives/id" 15 | "testing" 16 | ) 17 | 18 | // Tests the ERS wrapper functions in one test 19 | // Testing them all individually relies either on the Store function working 20 | // or copy pasting the entire Store function code and embedding the full 21 | // function into each test. 
22 | func TestERS(t *testing.T) { 23 | // Setup a database based on a map impl 24 | m := &MapImpl{ 25 | rounds: map[id.Round]*Round{}, 26 | } 27 | 28 | // Create a fake round info to store 29 | origR10 := pb.RoundInfo{ 30 | ID: 10, 31 | UpdateID: 7, 32 | BatchSize: 9, 33 | ResourceQueueTimeoutMillis: 18, 34 | } 35 | 36 | // Store a round 37 | ers := Storage{m} 38 | err := ers.Store(&origR10) 39 | if err != nil { 40 | t.Error(err) 41 | } 42 | 43 | // Grab that round 44 | grabR10, err := ers.Retrieve(id.Round(10)) 45 | if err != nil { 46 | t.Error(err) 47 | } 48 | if grabR10.ID != origR10.ID && grabR10.UpdateID != origR10.UpdateID && grabR10.BatchSize != origR10.BatchSize && 49 | grabR10.ResourceQueueTimeoutMillis != origR10.ResourceQueueTimeoutMillis { 50 | t.Error("Grabbed round object does not look to be the same as the stored one") 51 | } 52 | 53 | // Grab a round that doesn't exist and check it silently fails 54 | _, err = ers.Retrieve(id.Round(5)) 55 | if err != nil { 56 | t.Error("Retrieve did not silently fail on getting non-existent round.", err) 57 | } 58 | 59 | // Create and store two more fake round infos 60 | origR8 := pb.RoundInfo{ 61 | ID: 8, 62 | UpdateID: 2, 63 | BatchSize: 5, 64 | ResourceQueueTimeoutMillis: 23, 65 | } 66 | origR7 := pb.RoundInfo{ 67 | ID: 7, 68 | UpdateID: 43, 69 | BatchSize: 2, 70 | ResourceQueueTimeoutMillis: 39, 71 | } 72 | err = ers.Store(&origR8) 73 | if err != nil { 74 | t.Error(err) 75 | } 76 | err = ers.Store(&origR7) 77 | if err != nil { 78 | t.Error(err) 79 | } 80 | 81 | // Test RetrieveMany 82 | rounds, err := ers.RetrieveMany([]id.Round{10, 9, 8, 7}) 83 | if err != nil { 84 | t.Error(err) 85 | } 86 | if rounds[0].ID != origR10.ID && rounds[0].UpdateID != origR10.UpdateID && rounds[0].BatchSize != origR10.BatchSize && 87 | rounds[0].ResourceQueueTimeoutMillis != origR10.ResourceQueueTimeoutMillis { 88 | t.Error("Grabbed round object does not look to be the same as the stored one") 89 | } 90 | if rounds[1] != nil { 91 | 
t.Error("Did not receive placeholder for rid 9") 92 | } 93 | if rounds[2].ID != origR8.ID && rounds[2].UpdateID != origR8.UpdateID && rounds[2].BatchSize != origR8.BatchSize && 94 | rounds[2].ResourceQueueTimeoutMillis != origR10.ResourceQueueTimeoutMillis { 95 | t.Error("Grabbed round object does not look to be the same as the stored one") 96 | } 97 | if rounds[3].ID != origR7.ID && rounds[3].UpdateID != origR7.UpdateID && rounds[3].BatchSize != origR7.BatchSize && 98 | rounds[3].ResourceQueueTimeoutMillis != origR7.ResourceQueueTimeoutMillis { 99 | t.Error("Grabbed round object does not look to be the same as the stored one") 100 | } 101 | 102 | // Test RetrieveRange 103 | rounds, err = ers.RetrieveRange(7, 10) 104 | if err != nil { 105 | t.Error(err) 106 | } 107 | if rounds[3].ID != origR10.ID && rounds[3].UpdateID != origR10.UpdateID && rounds[3].BatchSize != origR10.BatchSize && 108 | rounds[3].ResourceQueueTimeoutMillis != origR10.ResourceQueueTimeoutMillis { 109 | t.Error("Grabbed round object does not look to be the same as the stored one") 110 | } 111 | if rounds[2] != nil { 112 | t.Error("Did not receive placeholder for rid 9") 113 | } 114 | if rounds[1].ID != origR8.ID && rounds[1].UpdateID != origR8.UpdateID && rounds[1].BatchSize != origR8.BatchSize && 115 | rounds[1].ResourceQueueTimeoutMillis != origR10.ResourceQueueTimeoutMillis { 116 | t.Error("Grabbed round object does not look to be the same as the stored one") 117 | } 118 | if rounds[0].ID != origR7.ID && rounds[0].UpdateID != origR7.UpdateID && rounds[0].BatchSize != origR7.BatchSize && 119 | rounds[0].ResourceQueueTimeoutMillis != origR7.ResourceQueueTimeoutMillis { 120 | t.Error("Grabbed round object does not look to be the same as the stored one") 121 | } 122 | 123 | rounds, err = ers.RetrieveRange(10, 7) 124 | if err == nil { 125 | t.Error("Should have received an error when first is greater than last") 126 | } 127 | } 128 | 
-------------------------------------------------------------------------------- /cmd/rateLimitGossip.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Contains gossip methods specific to rate limit gossiping 9 | 10 | package cmd 11 | 12 | import ( 13 | "github.com/golang/protobuf/proto" 14 | "github.com/pkg/errors" 15 | jww "github.com/spf13/jwalterweatherman" 16 | pb "gitlab.com/elixxir/comms/mixmessages" 17 | "gitlab.com/elixxir/gateway/cmd/ipAddress" 18 | "gitlab.com/xx_network/comms/gossip" 19 | "gitlab.com/xx_network/primitives/id" 20 | ) 21 | 22 | // Initialize fields required for the gossip protocol specialized to rate limiting 23 | func (gw *Instance) InitRateLimitGossip() { 24 | 25 | flags := gossip.DefaultProtocolFlags() 26 | flags.FanOut = 4 27 | flags.MaximumReSends = 2 28 | flags.NumParallelSends = 1000 29 | flags.SelfGossip = false 30 | 31 | // Register gossip protocol for bloom filters 32 | gw.Comms.Manager.NewGossip(RateLimitGossip, flags, 33 | gw.gossipRateLimitReceive, gw.gossipVerify, nil) 34 | 35 | } 36 | 37 | // Receive function for Gossip messages specialized to rate limiting 38 | func (gw *Instance) gossipRateLimitReceive(msg *gossip.GossipMsg) error { 39 | // Unmarshal the Sender data 40 | payloadMsg := &pb.BatchSenders{} 41 | err := proto.Unmarshal(msg.GetPayload(), payloadMsg) 42 | if err != nil { 43 | return errors.Errorf("Could not unmarshal gossip payload: %v", err) 44 | } 45 | 46 | capacity, leaked, duration := gw.GetRateLimitParams() 47 | 48 | jww.INFO.Printf("rate limit gossip for round %d: %d senders, %d ips", 49 | payloadMsg.RoundID, len(payloadMsg.SenderIds), 50 
| len(payloadMsg.Ips)) 51 | 52 | // Add to leaky bucket for each sender 53 | for _, senderBytes := range payloadMsg.SenderIds { 54 | senderId, err := id.Unmarshal(senderBytes) 55 | if err != nil { 56 | return errors.Errorf("Could not unmarshal sender ID: %+v", err) 57 | } 58 | gw.idRateLimiting.LookupBucket(senderId.String()).AddWithExternalParams(1, capacity, leaked, duration) 59 | } 60 | for _, ipBytes := range payloadMsg.Ips { 61 | ipStr, err := ipAddress.ByteToString(ipBytes) 62 | if err != nil { 63 | jww.WARN.Printf("round %d rate limit gossip sent "+ 64 | "an invalid ip addr %v: %s", payloadMsg.RoundID, ipBytes, err) 65 | } else { 66 | gw.idRateLimiting.LookupBucket(ipStr).AddWithExternalParams(1, capacity, leaked, duration) 67 | } 68 | 69 | } 70 | return nil 71 | } 72 | 73 | // GossipBatch builds a gossip message containing all of the sender IDs 74 | // within the batch and gossips it to all peers 75 | func (gw *Instance) GossipBatch(round id.Round, senders []*id.ID, ips []string) error { 76 | var err error 77 | 78 | // Build the message 79 | gossipMsg := &gossip.GossipMsg{ 80 | Tag: RateLimitGossip, 81 | Origin: gw.Comms.GetId().Marshal(), 82 | } 83 | 84 | // Add the GossipMsg payload 85 | gossipMsg.Payload, err = buildGossipPayloadRateLimit(round, senders, ips) 86 | if err != nil { 87 | return errors.Errorf("Unable to build rate limit gossip payload: %+v", err) 88 | } 89 | 90 | // Add the GossipMsg signature 91 | gossipMsg.Signature, err = buildGossipSignature(gossipMsg, gw.Comms.GetPrivateKey()) 92 | if err != nil { 93 | return errors.Errorf("Unable to build gossip signature: %+v", err) 94 | } 95 | 96 | // Gossip the message 97 | gossipProtocol, ok := gw.Comms.Manager.Get(RateLimitGossip) 98 | if !ok { 99 | return errors.Errorf("Unable to get gossip protocol.") 100 | } 101 | numPeers, errs := gossipProtocol.Gossip(gossipMsg) 102 | 103 | // Return any errors up the stack 104 | if len(errs) != 0 { 105 | jww.TRACE.Printf("Failed to rate limit gossip to: %v", 
errs) 106 | return errors.Errorf("Could not send to %d out of %d peers", len(errs), numPeers) 107 | } 108 | return nil 109 | } 110 | 111 | // Helper function used to convert Batch into a GossipMsg payload 112 | func buildGossipPayloadRateLimit(round id.Round, senders []*id.ID, ips []string) ([]byte, error) { 113 | // Nil check for the received back 114 | if senders == nil || ips == nil { 115 | return nil, errors.New("Batch does not contain necessary round info needed to gossip") 116 | } 117 | 118 | // Collect all of the sender IDs in the batch 119 | ipsBytesSlice := make([][]byte, 0, len(ips)) 120 | for _, ipStr := range ips { 121 | ipsBytes, err := ipAddress.StringToByte(ipStr) 122 | if err != nil { 123 | jww.WARN.Printf("ip %s failed to get added for round %d"+ 124 | " because : %s", ipStr, round, err) 125 | } else { 126 | ipsBytesSlice = append(ipsBytesSlice, ipsBytes) 127 | } 128 | } 129 | 130 | sendersByteSlice := make([][]byte, 0, len(senders)) 131 | for _, sID := range senders { 132 | sendersByteSlice = append(sendersByteSlice, sID.Marshal()) 133 | } 134 | 135 | payloadMsg := &pb.BatchSenders{ 136 | SenderIds: sendersByteSlice, 137 | Ips: ipsBytesSlice, 138 | RoundID: uint64(round), 139 | } 140 | return proto.Marshal(payloadMsg) 141 | } 142 | -------------------------------------------------------------------------------- /cmd/bloom_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "gitlab.com/elixxir/comms/testkeys" 12 | "gitlab.com/elixxir/gateway/storage" 13 | "gitlab.com/xx_network/primitives/id" 14 | "gitlab.com/xx_network/primitives/id/ephemeral" 15 | "gitlab.com/xx_network/primitives/rateLimiting" 16 | "testing" 17 | "time" 18 | ) 19 | 20 | // Happy path 21 | func TestInstance_upsertUserFilter(t *testing.T) { 22 | // Create gateway instance 23 | params := Params{ 24 | NodeAddress: NODE_ADDRESS, 25 | ServerCertPath: testkeys.GetNodeCertPath(), 26 | CertPath: testkeys.GetGatewayCertPath(), 27 | DevMode: true, 28 | } 29 | params.messageRateLimitParams = &rateLimiting.MapParams{ 30 | Capacity: 10, 31 | LeakedTokens: 1, 32 | LeakDuration: 10 * time.Second, 33 | PollDuration: 10 * time.Second, 34 | BucketMaxAge: 10 * time.Second, 35 | } 36 | 37 | gw := NewGatewayInstance(params) 38 | rndId := id.Round(0) 39 | 40 | // Create a mock client 41 | testClientId := id.NewIdFromString("0", id.User, t) 42 | testEphId, _, _, err := ephemeral.GetId(testClientId, 64, time.Now().UnixNano()) 43 | if err != nil { 44 | t.Errorf("Could not create an ephemeral id: %v", err) 45 | } 46 | testEpoch := uint32(0) 47 | 48 | // Pull a bloom filter from the database on the client ID BEFORE INSERTION 49 | retrievedFilters, err := gw.storage.GetClientBloomFilters(testEphId, testEpoch, testEpoch) 50 | 51 | // Check that this filter is nil 52 | if err == nil || retrievedFilters != nil { 53 | t.Errorf("Should not get test client from storage prior to insertion.") 54 | } 55 | 56 | // Create a bloom filter on this client ID 57 | err = gw.UpsertFilter(testEphId, rndId, testEpoch) 58 | if err != nil { 59 | t.Errorf("Failed to create user bloom filter: %s", err) 60 | } 61 | 62 | // Pull a bloom filter from the database on the client ID AFTER INSERTION 63 | retrievedFilters, err = gw.storage.GetClientBloomFilters(testEphId, testEpoch, testEpoch) 64 
| if err != nil { 65 | t.Errorf("Could not get filters from storage: %s", err) 66 | } 67 | 68 | // Check that it is of the expected length and not nil 69 | if retrievedFilters == nil || len(retrievedFilters) != 1 { 70 | t.Errorf("Retrieved client did not store new bloom filter") 71 | } 72 | 73 | // Insert a client already 74 | err = gw.storage.UpsertClient(&storage.Client{ 75 | Id: testClientId.Marshal(), 76 | }) 77 | if err != nil { 78 | t.Errorf("Could not load client into storage: %v", err) 79 | } 80 | 81 | // Create a bloom filter on this client ID 82 | err = gw.UpsertFilter(testEphId, id.Round(1), testEpoch) 83 | if err != nil { 84 | t.Errorf("Failed to create user bloom filter: %s", err) 85 | } 86 | 87 | // Pull a bloom filter from the database on the client ID AFTER INSERTION 88 | retrievedFilters, err = gw.storage.GetClientBloomFilters(testEphId, testEpoch, testEpoch) 89 | if err != nil { 90 | t.Errorf("Could not get filters from storage: %s", err) 91 | } 92 | 93 | // Check that it is of the expected length and not nil 94 | if retrievedFilters == nil { 95 | t.Errorf("Retrieved client did not store new bloom filter") 96 | } 97 | 98 | } 99 | 100 | // Happy path 101 | func TestInstance_UpsertFilters(t *testing.T) { 102 | // Create gateway instance 103 | params := Params{ 104 | NodeAddress: NODE_ADDRESS, 105 | ServerCertPath: testkeys.GetNodeCertPath(), 106 | CertPath: testkeys.GetGatewayCertPath(), 107 | DevMode: true, 108 | } 109 | params.messageRateLimitParams = &rateLimiting.MapParams{ 110 | Capacity: 10, 111 | LeakedTokens: 1, 112 | LeakDuration: 10 * time.Second, 113 | PollDuration: 10 * time.Second, 114 | BucketMaxAge: 10 * time.Second, 115 | } 116 | 117 | gw := NewGatewayInstance(params) 118 | rndId := id.Round(0) 119 | 120 | // Create a mock client 121 | testClientId := id.NewIdFromString("0", id.User, t) 122 | testEphId, _, _, err := ephemeral.GetId(testClientId, 64, time.Now().UnixNano()) 123 | if err != nil { 124 | t.Errorf("Could not create an 
ephemeral id: %v", err) 125 | } 126 | 127 | testEpoch := uint32(0) 128 | 129 | // Check that the databases are empty of filters 130 | retrievedFilter, err := gw.storage.GetClientBloomFilters(testEphId, testEpoch, testEpoch) 131 | // Check that this filter is nil 132 | if err == nil || retrievedFilter != nil { 133 | t.Errorf("Should not get test client from storage prior to insertion.") 134 | } 135 | 136 | // This should result in a bloom filter being created 137 | err = gw.UpsertFilter(testEphId, rndId, testEpoch) 138 | if err != nil { 139 | t.Errorf("Could not create a bloom filter: %v", err) 140 | } 141 | 142 | // Check that a bloom filter has been created 143 | retrievedFilter, err = gw.storage.GetClientBloomFilters(testEphId, testEpoch, testEpoch) 144 | if retrievedFilter == nil || len(retrievedFilter) != 1 { 145 | t.Errorf("Retrieved ehphemeral filter was not expected. Should be non-nil an dlength of 1") 146 | } 147 | 148 | } 149 | -------------------------------------------------------------------------------- /storage/storage.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Handles the high level storage API. 9 | // This layer merges the business logic layer and the database layer 10 | 11 | package storage 12 | 13 | import ( 14 | jww "github.com/spf13/jwalterweatherman" 15 | "gitlab.com/xx_network/primitives/id" 16 | "gitlab.com/xx_network/primitives/id/ephemeral" 17 | "time" 18 | ) 19 | 20 | const ( 21 | // Determines maximum runtime (in seconds) of specific DB queries 22 | dbTimeout = 10 * time.Second 23 | // Determines maximum number of uses for a BloomFilter in a given period. 
24 | maxBloomUses = 64 25 | ) 26 | 27 | // API for the storage layer 28 | type Storage struct { 29 | // Stored database interface 30 | database 31 | } 32 | 33 | // Create a new Storage object wrapping a database interface 34 | // Returns a Storage object and error 35 | func NewStorage(username, password, dbName, address, port string, devmode bool) (*Storage, error) { 36 | db, err := newDatabase(username, password, dbName, address, port, devmode) 37 | storage := &Storage{db} 38 | return storage, err 39 | } 40 | 41 | // Clears certain data from Storage older than the given timestamp 42 | // This includes Round and MixedMessage information 43 | func (s *Storage) ClearOldStorage(ts time.Time) error { 44 | err := s.deleteRound(ts) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | return s.deleteMixedMessages(ts) 50 | } 51 | 52 | // Builds a ClientBloomFilter with the given parameters, then stores it 53 | func (s *Storage) HandleBloomFilter(recipientId ephemeral.Id, filterBytes []byte, roundId id.Round, epoch uint32) error { 54 | // Ignore zero-value recipient ID for now - this is a reserved address 55 | recipientIdInt := recipientId.Int64() 56 | if recipientIdInt == 0 { 57 | return nil 58 | } 59 | 60 | // Build a newly-initialized ClientBloomFilter to be stored 61 | validFilter := &ClientBloomFilter{ 62 | RecipientId: &recipientIdInt, 63 | Epoch: epoch, 64 | // FirstRound is input as CurrentRound for later calculation 65 | FirstRound: uint64(roundId), 66 | // RoundRange is empty for now as it can't be calculated yet 67 | RoundRange: 0, 68 | Filter: filterBytes, 69 | } 70 | 71 | // Commit the new/updated ClientBloomFilter 72 | return s.upsertClientBloomFilter(validFilter) 73 | } 74 | 75 | // Returns a slice of MixedMessage from database with matching recipientId and roundId 76 | // Also returns a boolean for whether the gateway contains other messages for the given Round 77 | func (s *Storage) GetMixedMessages(recipientId ephemeral.Id, roundId id.Round) (msgs 
[]*MixedMessage, hasRound bool, err error) { 78 | // Determine whether this gateway has any messages for the given roundId 79 | count, hasRound, err := s.countMixedMessagesByRound(roundId) 80 | if !hasRound || count == 0 { 81 | return 82 | } 83 | 84 | // If the gateway has messages, return messages relevant to the given recipientId and roundId 85 | msgs, err = s.getMixedMessages(recipientId, roundId) 86 | return 87 | } 88 | 89 | // Helper function for HandleBloomFilter 90 | // Returns the bitwise OR of two byte slices 91 | func or(existingBuffer, additionalBuffer []byte) []byte { 92 | if existingBuffer == nil { 93 | return additionalBuffer 94 | } else if additionalBuffer == nil { 95 | return existingBuffer 96 | } else if len(existingBuffer) != len(additionalBuffer) { 97 | jww.ERROR.Printf("Unable to perform bitwise OR: Slice lens invalid.") 98 | return existingBuffer 99 | } 100 | 101 | result := make([]byte, len(existingBuffer)) 102 | for i := range existingBuffer { 103 | result[i] = existingBuffer[i] | additionalBuffer[i] 104 | } 105 | return result 106 | } 107 | 108 | // Combine with and update this filter using oldFilter 109 | // Used in upsertFilter functionality in order to ensure atomicity 110 | // Kept in business logic layer because functionality is shared 111 | func (f *ClientBloomFilter) combine(oldFilter *ClientBloomFilter) { 112 | 113 | // Initialize FirstRound variable if needed 114 | if oldFilter.FirstRound == uint64(0) { 115 | oldFilter.FirstRound = f.FirstRound 116 | } 117 | 118 | // calculate what the first round should be 119 | firstRound := oldFilter.FirstRound 120 | if f.FirstRound < oldFilter.FirstRound { 121 | firstRound = f.FirstRound 122 | } 123 | 124 | // calculate what the last round should be 125 | lastRound := oldFilter.lastRound() 126 | if f.lastRound() > lastRound { 127 | lastRound = f.lastRound() 128 | } 129 | 130 | // set the first round 131 | // note this MUST be after last round is calculated 132 | // becasue the value in f is used 
in the last round calculation 133 | f.FirstRound = firstRound 134 | 135 | // calculate the round range based upon the first and last round 136 | f.RoundRange = uint32(lastRound - firstRound) 137 | 138 | // Combine the filters 139 | f.Filter = or(oldFilter.Filter, f.Filter) 140 | f.Uses = oldFilter.Uses + 1 141 | f.Id = oldFilter.Id 142 | } 143 | 144 | func (f *ClientBloomFilter) lastRound() uint64 { 145 | return f.FirstRound + uint64(f.RoundRange) 146 | } 147 | -------------------------------------------------------------------------------- /storage/unmixedMapBuffer_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | pb "gitlab.com/elixxir/comms/mixmessages" 12 | id "gitlab.com/xx_network/primitives/id" 13 | "testing" 14 | ) 15 | 16 | // tests that unmixed messages are properly added to the unmixed buffer 17 | func TestUnmixedMapBuffer_AddUnmixedMessage(t *testing.T) { 18 | testMap := make(map[id.Round]*SendRound) 19 | unmixedMessageBuf := &UnmixedMessagesMap{ 20 | messages: testMap, 21 | } 22 | 23 | numOutgoingMsgs := len(unmixedMessageBuf.messages) 24 | unmixedMessageBuf.SetAsRoundLeader(id.Round(0), 5) 25 | 26 | senderId := id.NewIdFromString("test", id.User, t) 27 | 28 | unmixedMessageBuf.AddUnmixedMessage(&pb.Slot{SenderID: id.ZeroUser.Marshal()}, senderId, "", id.Round(0)) 29 | 30 | if len(unmixedMessageBuf.messages) != numOutgoingMsgs+1 { 31 | t.Errorf("AddUnMixedMessage: Message was not added to outgoing" + 32 | " message buffer properly!") 33 | } 34 | } 35 | 36 | // tests that removing messages from unmixed buffer works correctly 37 | func 
TestUnmixedMapBuffer_GetUnmixedMessages(t *testing.T) { 38 | unmixedMessageBuf := NewUnmixedMessagesMap() 39 | 40 | if unmixedMessageBuf.LenUnmixed(id.Round(0)) != 0 { 41 | t.Errorf("GetRoundMessages: Queue should be empty! Has %d messages!", 42 | unmixedMessageBuf.LenUnmixed(id.Round(0))) 43 | } 44 | 45 | batch, _, _ := unmixedMessageBuf.PopRound(0) 46 | if batch != nil { 47 | t.Errorf("GetRoundMessages: Should have returned empty batch") 48 | } 49 | testSlot := &pb.Slot{SenderID: id.ZeroUser.Marshal()} 50 | 51 | unmixedMessageBuf.SetAsRoundLeader(id.Round(0), 4) 52 | senderId := id.NewIdFromString("test", id.User, t) 53 | 54 | unmixedMessageBuf.AddUnmixedMessage(testSlot, senderId, "", id.Round(0)) 55 | 56 | // First confirm there is a message present 57 | if unmixedMessageBuf.LenUnmixed(0) != 1 { 58 | t.Errorf("GetRoundMessages: Queue should have 1 message!") 59 | } 60 | 61 | unmixedMessageBuf.PopRound(0) 62 | 63 | // Test that if minCount is greater than the amount of messages, then the 64 | // batch that is returned is nil 65 | 66 | unmixedMessageBuf.AddUnmixedMessage(testSlot, senderId, "", id.Round(0)) 67 | 68 | batch, receivedId, _ := unmixedMessageBuf.PopRound(0) 69 | 70 | if batch != nil { 71 | t.Errorf("Error case of minCount being greater than the amount of"+ 72 | "messages, should received a nil batch but received: %v", batch) 73 | } 74 | 75 | if !senderId.Cmp(receivedId[0]) { 76 | t.Errorf("Error case, should receive a sender ID") 77 | } 78 | 79 | } 80 | 81 | // Happy path 82 | func TestUnmixedMessagesMap_IsRoundFull(t *testing.T) { 83 | unmixedMessageBuf := NewUnmixedMessagesMap() 84 | rndId := id.Round(4) 85 | batchSize := 3 86 | unmixedMessageBuf.SetAsRoundLeader(rndId, uint32(batchSize)) 87 | senderId := id.NewIdFromString("test", id.User, t) 88 | 89 | for i := 0; i < batchSize; i++ { 90 | unmixedMessageBuf.AddUnmixedMessage(&pb.Slot{},senderId, "", rndId) 91 | } 92 | 93 | if !unmixedMessageBuf.IsRoundFull(rndId) { 94 | t.Errorf("Message buffer 
for round %d should be full."+ 95 | "\n\tExpected messages: %d"+ 96 | "\n\tReceived messaged: %d", rndId, batchSize, unmixedMessageBuf.LenUnmixed(rndId)) 97 | } 98 | } 99 | 100 | // Unit test 101 | func TestUnmixedMessagesMap_IsRoundLeader(t *testing.T) { 102 | unmixedMessageBuf := NewUnmixedMessagesMap() 103 | rndId := id.Round(4) 104 | batchSize := 3 105 | 106 | if unmixedMessageBuf.IsRoundLeader(rndId) { 107 | t.Errorf("Marked as a round leader incorrectly. Should only return true" + 108 | "after a call to SetAsRoundLeader") 109 | } 110 | 111 | unmixedMessageBuf.SetAsRoundLeader(rndId, uint32(batchSize)) 112 | if !unmixedMessageBuf.IsRoundLeader(rndId) { 113 | t.Errorf("Should be marked as a round leader for round %d", rndId) 114 | } 115 | 116 | } 117 | 118 | // Unit test 119 | func TestUnmixedMessagesMap_AddManyUnmixedMessages(t *testing.T) { 120 | testMap := make(map[id.Round]*SendRound) 121 | unmixedMessageBuf := &UnmixedMessagesMap{ 122 | messages: testMap, 123 | } 124 | maxSlots := 5 125 | unmixedMessageBuf.SetAsRoundLeader(id.Round(0), uint32(maxSlots)) 126 | 127 | // Insert slots up to a full batch 128 | slots := make([]*pb.GatewaySlot, 0) 129 | for i := 0; i < maxSlots-1; i++ { 130 | slot := &pb.GatewaySlot{ 131 | Message: &pb.Slot{SenderID: id.ZeroUser.Marshal()}, 132 | } 133 | slots = append(slots, slot) 134 | } 135 | rnd := id.Round(0) 136 | senderId := id.NewIdFromString("test", id.User, t) 137 | 138 | err := unmixedMessageBuf.AddManyUnmixedMessages(slots, senderId, "", rnd) 139 | if err != nil { 140 | t.Fatalf("AddManyUnmixedMessages error: %v", err) 141 | } 142 | 143 | // Construct an extra slot and attempt to insert 144 | slot := &pb.GatewaySlot{ 145 | Message: &pb.Slot{SenderID: id.ZeroUser.Marshal()}, 146 | } 147 | extraSlots := []*pb.GatewaySlot{slot} 148 | err = unmixedMessageBuf.AddManyUnmixedMessages(extraSlots, senderId, "", rnd) 149 | if err == nil { 150 | t.Fatalf("AddManyUnmixedMessages error: " + 151 | "Should not be able to insert into 
already full batch") 152 | } 153 | 154 | } 155 | -------------------------------------------------------------------------------- /storage/unmixedMapBuffer.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | "fmt" 12 | "github.com/pkg/errors" 13 | jww "github.com/spf13/jwalterweatherman" 14 | pb "gitlab.com/elixxir/comms/mixmessages" 15 | "gitlab.com/xx_network/primitives/id" 16 | "sync" 17 | ) 18 | 19 | // UnmixedMessagesMap holds messages that have been received by gateway but have 20 | // yet to been submitted to the network for mixing. 21 | type UnmixedMessagesMap struct { 22 | messages map[id.Round]*SendRound 23 | mux sync.RWMutex 24 | } 25 | 26 | type SendRound struct { 27 | batch *pb.Batch 28 | maxElements uint32 29 | Senders []*id.ID 30 | Ips []string 31 | sent bool 32 | } 33 | 34 | // NewUnmixedMessagesMap initialize a UnmixedMessageBuffer interface. 35 | func NewUnmixedMessagesMap() UnmixedMessageBuffer { 36 | // Build the UnmixedMessagesMap 37 | buffer := &UnmixedMessagesMap{ 38 | messages: map[id.Round]*SendRound{}, 39 | } 40 | 41 | return buffer 42 | } 43 | 44 | // AddUnmixedMessage adds a message to send to the cMix node. 
45 | func (umb *UnmixedMessagesMap) AddUnmixedMessage(msg *pb.Slot, sender *id.ID, ip string, roundId id.Round) error { 46 | umb.mux.Lock() 47 | defer umb.mux.Unlock() 48 | 49 | 50 | retrievedBatch, ok := umb.messages[roundId] 51 | if !ok { 52 | return errors.New("cannot add message to unknown round") 53 | } 54 | 55 | if retrievedBatch.sent { 56 | return errors.New("Cannot add message to already sent batch") 57 | } 58 | 59 | if len(retrievedBatch.batch.Slots) == int(retrievedBatch.maxElements) { 60 | return errors.New("Cannot add message to full batch") 61 | } 62 | 63 | // If the batch for this round was already created, add another message 64 | retrievedBatch.batch.Slots = append(retrievedBatch.batch.Slots, msg) 65 | retrievedBatch.Senders = append(retrievedBatch.Senders, sender) 66 | retrievedBatch.Ips = append(retrievedBatch.Ips, ip) 67 | 68 | umb.messages[roundId] = retrievedBatch 69 | return nil 70 | } 71 | 72 | // AddManyUnmixedMessage adds many unmixed messages to send to the cMix node. 
73 | func (umb *UnmixedMessagesMap) AddManyUnmixedMessages(msgs []*pb.GatewaySlot, sender *id.ID, 74 | ip string, roundId id.Round) error { 75 | umb.mux.Lock() 76 | defer umb.mux.Unlock() 77 | 78 | // Pull batch from store (map) 79 | retrievedBatch, ok := umb.messages[roundId] 80 | if !ok { 81 | return errors.New("cannot add message to unknown round") 82 | } 83 | 84 | // Check that the batch has not 85 | if retrievedBatch.sent { 86 | return errors.New("Cannot add message to already sent batch") 87 | } 88 | 89 | // Check that adding these message wil not exceed the batch size 90 | resultingSlots := len(retrievedBatch.batch.Slots) + len(msgs) 91 | if resultingSlots >= int(retrievedBatch.maxElements) { 92 | fmt.Printf("resulting Slots: %d\nmax %d\n", resultingSlots, int(retrievedBatch.maxElements)) 93 | return errors.New("Cannot add messages to full batch") 94 | } 95 | 96 | // Collect all slots into a list 97 | slots := make([]*pb.Slot, len(msgs)) 98 | for i := 0; i < len(msgs); i++ { 99 | slots[i] = msgs[i].Message 100 | retrievedBatch.Senders = append(retrievedBatch.Senders, sender) 101 | retrievedBatch.Ips = append(retrievedBatch.Ips, ip) 102 | } 103 | 104 | // If the batch for this round was already created, add another message 105 | retrievedBatch.batch.Slots = append(retrievedBatch.batch.Slots, slots...) 
106 | umb.messages[roundId] = retrievedBatch 107 | return nil 108 | } 109 | 110 | // GetRoundMessages returns the batch associated with the roundID 111 | func (umb *UnmixedMessagesMap) PopRound(roundId id.Round) (*pb.Batch, []*id.ID, []string) { 112 | umb.mux.Lock() 113 | defer umb.mux.Unlock() 114 | 115 | retrievedBatch, ok := umb.messages[roundId] 116 | if !ok { 117 | return nil, nil, nil 118 | } 119 | 120 | retrievedBatch.sent = true 121 | 122 | // Handle batches too small to send 123 | batch := retrievedBatch.batch 124 | senders := retrievedBatch.Senders 125 | ips := retrievedBatch.Ips 126 | retrievedBatch.batch = nil 127 | umb.messages[roundId] = retrievedBatch 128 | return batch, senders, ips 129 | } 130 | 131 | // LenUnmixed return the number of messages within the requested round 132 | func (umb *UnmixedMessagesMap) LenUnmixed(rndId id.Round) int { 133 | umb.mux.RLock() 134 | defer umb.mux.RUnlock() 135 | b, ok := umb.messages[rndId] 136 | if !ok { 137 | return 0 138 | } 139 | 140 | return len(b.batch.Slots) 141 | } 142 | 143 | func (umb *UnmixedMessagesMap) IsRoundFull(roundId id.Round) bool { 144 | umb.mux.RLock() 145 | defer umb.mux.RUnlock() 146 | slots := umb.messages[roundId].batch.GetSlots() 147 | return len(slots) == int(umb.messages[roundId].maxElements) 148 | } 149 | 150 | // SetAsRoundLeader initializes a round as our responsibility ny initializing 151 | // marking that round as non-nil within the internal map 152 | func (umb *UnmixedMessagesMap) SetAsRoundLeader(roundId id.Round, batchsize uint32) { 153 | umb.mux.Lock() 154 | defer umb.mux.Unlock() 155 | 156 | if _, ok := umb.messages[roundId]; ok { 157 | jww.FATAL.Panicf("Can set as round leader for extant round %d", 158 | roundId) 159 | } 160 | jww.INFO.Printf("Adding round buffer for round %d", roundId) 161 | umb.messages[roundId] = &SendRound{ 162 | batch: &pb.Batch{Slots: make([]*pb.Slot, 0, batchsize)}, 163 | maxElements: batchsize, 164 | Senders: make([]*id.ID, 0, batchsize), 165 | Ips: 
make([]string,0,batchsize), 166 | } 167 | } 168 | 169 | // IsRoundLeader returns true if object mapped to this round has 170 | // been previously set 171 | func (umb *UnmixedMessagesMap) IsRoundLeader(roundId id.Round) bool { 172 | umb.mux.RLock() 173 | defer umb.mux.RUnlock() 174 | 175 | _, ok := umb.messages[roundId] 176 | return ok 177 | } 178 | -------------------------------------------------------------------------------- /cmd/knownRoundsWrapper.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // The knownRoundsWrapper contains the gateway's known rounds in a structure 9 | // that saves the marshalled known rounds to memory and the database every time 10 | // a change is made instead of marshalling the known rounds every time the 11 | // marshalled bytes are needed. 12 | 13 | package cmd 14 | 15 | import ( 16 | "encoding/base64" 17 | "github.com/pkg/errors" 18 | jww "github.com/spf13/jwalterweatherman" 19 | "gitlab.com/elixxir/gateway/storage" 20 | "gitlab.com/elixxir/primitives/knownRounds" 21 | "gitlab.com/xx_network/primitives/id" 22 | "sync" 23 | "time" 24 | ) 25 | 26 | // Error messages. 
const (
	// Determines round differences that triggers a truncate: when a client's
	// last round is more than this many rounds behind the last checked round,
	// the truncated KnownRounds are served instead (see needsTruncated).
	knownRoundsTruncateThreshold id.Round = 3000

	// Error message formats; each takes the underlying error as its %+v
	// argument.
	storageUpsertErr    = "failed to upsert marshalled KnownRounds to storage: %+v"
	storageGetErr       = "failed to get KnownRounds from storage: %+v"
	storageDecodeErr    = "failed to decode KnownRounds from storage: %+v"
	storageUnmarshalErr = "failed to unmarshal KnownRounds from storage: %+v"
)

// knownRoundsWrapper pairs a KnownRounds with cached marshalled forms of it so
// readers can fetch the bytes without re-marshalling.
type knownRoundsWrapper struct {
	// The wrapped known rounds.
	kr *knownRounds.KnownRounds
	// Full marshalled kr, refreshed by saveUnsafe and load.
	marshalled []byte
	// Marshalled kr.Truncate(lastChecked - knownRoundsTruncateThreshold), or
	// the same bytes as marshalled when lastChecked is below the threshold.
	truncated []byte
	// Held by the accessor methods while they touch kr/marshalled/truncated.
	l sync.RWMutex
	// Buffered (capacity 1) signal that a storage backup is due.
	backupChan chan bool
	// Pause taken by the backup goroutine after each backup.
	backupPeriod time.Duration
}

// newKnownRoundsWrapper creates a new knownRoundsWrapper with a new KnownRounds
// initialised to the round capacity and saves marshalled bytes.
//
// It also starts the background backup goroutine (backupState) for store.
func newKnownRoundsWrapper(roundCapacity int, store *storage.Storage) (*knownRoundsWrapper, error) {
	krw := &knownRoundsWrapper{
		kr:         knownRounds.NewKnownRound(roundCapacity),
		marshalled: []byte{},
		truncated:  []byte{},
		backupChan: make(chan bool, 1),
	}

	// NOTE(review): the backup goroutine starts before backupPeriod is set
	// below, so it may read backupPeriod concurrently with that write (a data
	// race) and may pause for either 0 or 5s after the first backup. The
	// existing tests appear to depend on this timing; confirm before
	// reordering.
	krw.backupState(store)

	// There is no round 0
	krw.kr.Check(0)
	jww.TRACE.Printf("Initial KnownRound State: %+v", krw.kr)

	// Save marshalled knownRounds to memory and signal the backup goroutine to
	// write them to storage
	err := krw.saveUnsafe()
	if err != nil {
		return nil, err
	}

	jww.DEBUG.Printf("Initial KnownRound Marshal: %v", krw.marshalled)

	krw.backupPeriod = 5 * time.Second

	return krw, nil
}

// check marks the round as checked in the KnownRounds and saves them.
75 | func (krw *knownRoundsWrapper) check(rid id.Round) error { 76 | krw.l.Lock() 77 | defer krw.l.Unlock() 78 | 79 | krw.kr.Check(rid) 80 | 81 | return krw.saveUnsafe() 82 | } 83 | 84 | func (krw *knownRoundsWrapper) truncateMarshal() []byte { 85 | krw.l.RLock() 86 | defer krw.l.RUnlock() 87 | 88 | bytes := make([]byte, len(krw.truncated)) 89 | copy(bytes, krw.truncated) 90 | 91 | return bytes 92 | } 93 | 94 | func (krw *knownRoundsWrapper) getLastChecked() id.Round { 95 | krw.l.RLock() 96 | defer krw.l.RUnlock() 97 | 98 | return krw.kr.GetLastChecked() 99 | } 100 | 101 | // forceCheck force checks the round and saves the KnownRounds. 102 | func (krw *knownRoundsWrapper) forceCheck(rid id.Round) error { 103 | krw.l.Lock() 104 | defer krw.l.Unlock() 105 | 106 | krw.kr.ForceCheck(rid) 107 | 108 | return krw.saveUnsafe() 109 | } 110 | 111 | // getMarshal returns a copy of the marshalled bytes of the KnownRounds. 112 | func (krw *knownRoundsWrapper) getMarshal() []byte { 113 | krw.l.RLock() 114 | defer krw.l.RUnlock() 115 | 116 | bytes := make([]byte, len(krw.marshalled)) 117 | copy(bytes, krw.marshalled) 118 | 119 | return bytes 120 | } 121 | 122 | // save the marshalled KnownRounds to memory and storage. This 123 | // function is thread safe. 124 | func (krw *knownRoundsWrapper) save() error { 125 | krw.l.Lock() 126 | defer krw.l.Unlock() 127 | 128 | return krw.saveUnsafe() 129 | } 130 | 131 | // saveUnsafe saves the marshalled KnownRounds but the mutex must be 132 | // locked by the caller. 
133 | func (krw *knownRoundsWrapper) saveUnsafe() error { 134 | // Marshal and save knownRounds 135 | krw.marshalled = krw.kr.Marshal() 136 | if krw.kr.GetLastChecked() > knownRoundsTruncateThreshold { 137 | krw.truncated = krw.kr.Truncate(krw.kr.GetLastChecked() - knownRoundsTruncateThreshold).Marshal() 138 | } else { 139 | krw.truncated = krw.marshalled 140 | } 141 | 142 | // Send a signal to the backup chan 143 | // This is a non-blocking send to a buffered channel - this means that if 144 | // there is already a waiting signal in the channel, another will not be 145 | // sent. The result is that we will run a backup at most once per interval, 146 | // but will not continue backing up if no new data has been added 147 | select { 148 | case krw.backupChan <- true: 149 | default: 150 | } 151 | 152 | return nil 153 | } 154 | 155 | // Store known rounds marshalled in state at most once every 5 seconds 156 | func (krw *knownRoundsWrapper) backupState(store *storage.Storage) { 157 | go func() { 158 | for { 159 | select { 160 | // Wait on backup channel for triggers 161 | case <-krw.backupChan: 162 | // Store knownRounds data 163 | err := store.UpsertState(&storage.State{ 164 | Key: storage.KnownRoundsKey, 165 | Value: base64.StdEncoding.EncodeToString(krw.marshalled), 166 | }) 167 | if err != nil { 168 | jww.ERROR.Printf(storageUpsertErr, err) 169 | } 170 | // Sleep for backupPeriod after running 171 | // backupChan is buffered, so if requests come in during sleep 172 | // this will run again immediately after 173 | time.Sleep(krw.backupPeriod) 174 | } 175 | } 176 | }() 177 | } 178 | 179 | // Returns whether the given round calls for a truncated knownRound 180 | func (krw *knownRoundsWrapper) needsTruncated(round id.Round) bool { 181 | lastChecked := krw.kr.GetLastChecked() 182 | return round < lastChecked && lastChecked-round > knownRoundsTruncateThreshold 183 | } 184 | 185 | // load the KnownRounds from storage into the knownRoundsWrapper. 
186 | func (krw *knownRoundsWrapper) load(store *storage.Storage) error { 187 | krw.l.Lock() 188 | defer krw.l.Unlock() 189 | 190 | // Get an existing knownRounds value from storage 191 | data, err := store.GetStateValue(storage.KnownRoundsKey) 192 | if err != nil { 193 | return errors.Errorf(storageGetErr, err) 194 | } 195 | 196 | dataDecode, err := base64.StdEncoding.DecodeString(data) 197 | if err != nil { 198 | return errors.Errorf(storageDecodeErr, err) 199 | } 200 | 201 | // Parse the data and store the KnownRounds 202 | err = krw.kr.Unmarshal(dataDecode) 203 | if err != nil { 204 | return errors.Errorf(storageUnmarshalErr, err) 205 | } 206 | 207 | // Save the marshalled KnownRounds 208 | krw.marshalled = dataDecode 209 | if krw.kr.GetLastChecked() > knownRoundsTruncateThreshold { 210 | krw.truncated = krw.kr.Truncate(krw.kr.GetLastChecked() - knownRoundsTruncateThreshold).Marshal() 211 | } else { 212 | krw.truncated = dataDecode 213 | } 214 | 215 | return nil 216 | } 217 | -------------------------------------------------------------------------------- /storage/storage_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package storage 9 | 10 | import ( 11 | "bytes" 12 | "gitlab.com/xx_network/primitives/id" 13 | "gitlab.com/xx_network/primitives/id/ephemeral" 14 | "math/rand" 15 | "testing" 16 | ) 17 | 18 | // Happy path 19 | func TestOr(t *testing.T) { 20 | l1 := []byte{65, 0, 0, 0, 0, 172} 21 | l2 := []byte{72, 66, 67, 226, 130, 1} 22 | expected := []byte{73, 66, 67, 226, 130, 173} 23 | 24 | result := or(l1, l2) 25 | if !bytes.Equal(result, expected) { 26 | t.Errorf("Invalid Or Return 1: %v", result) 27 | } 28 | } 29 | 30 | // Nil paths 31 | func TestOr_Nil(t *testing.T) { 32 | var l1 []byte 33 | var l2 []byte 34 | 35 | result := or(l1, l2) 36 | if result != nil { 37 | t.Errorf("Invalid Nil Or Return 1: %v", result) 38 | } 39 | 40 | l1 = []byte("test") 41 | result = or(l1, l2) 42 | if !bytes.Equal(l1, result) { 43 | t.Errorf("Invalid Nil Or Return 2: %v", result) 44 | } 45 | 46 | l1 = []byte("test") 47 | result = or(l2, l1) 48 | if !bytes.Equal(l1, result) { 49 | t.Errorf("Invalid Nil Or Return 3: %v", result) 50 | } 51 | } 52 | 53 | // Unequal length path 54 | func TestOr_Length(t *testing.T) { 55 | l1 := []byte("CHUNGUS") 56 | l2 := []byte("no") 57 | 58 | result := or(l1, l2) 59 | if !bytes.Equal(l1, result) { 60 | t.Errorf("Invalid Len Or Return 1: %v", result) 61 | } 62 | 63 | result = or(l2, l1) 64 | if !bytes.Equal(l2, result) { 65 | t.Errorf("Invalid Len Or Return 2: %v", result) 66 | } 67 | } 68 | 69 | // Happy path - New filter 70 | func TestClientBloomFilter_Combine_New(t *testing.T) { 71 | testFilter := []byte("test") 72 | recipientId := int64(10) 73 | recipientId2 := int64(0) 74 | oldFilter := &ClientBloomFilter{ 75 | RecipientId: &recipientId2, 76 | Epoch: 0, 77 | FirstRound: 0, 78 | RoundRange: 0, 79 | Filter: testFilter, 80 | } 81 | newFilter := &ClientBloomFilter{ 82 | RecipientId: &recipientId, 83 | Epoch: 10, 84 | FirstRound: 10, 85 | RoundRange: 0, 86 | } 87 | 88 | 
newFilter.combine(oldFilter) 89 | 90 | // Ensure some things did not change 91 | if *newFilter.RecipientId != recipientId { 92 | t.Errorf("Unexpected recipient change: %d", newFilter.RecipientId) 93 | } 94 | if newFilter.Epoch != 10 { 95 | t.Errorf("Unexpected epoch change: %d", newFilter.Epoch) 96 | } 97 | if newFilter.RoundRange != 0 { 98 | t.Errorf("Unexpected RoundRange value: %d", newFilter.RoundRange) 99 | } 100 | 101 | // Ensure some things did change 102 | if oldFilter.FirstRound != 10 { 103 | t.Errorf("Expected FirstRound change: %d", oldFilter.FirstRound) 104 | } 105 | if !bytes.Equal(newFilter.Filter, testFilter) { 106 | t.Errorf("Unexpected Filter value: %v", newFilter.Filter) 107 | } 108 | } 109 | 110 | // Happy path - Update filter 111 | func TestClientBloomFilter_Combine_Update(t *testing.T) { 112 | testFilter := []byte("test") 113 | recipientId := int64(10) 114 | oldFilter := &ClientBloomFilter{ 115 | RecipientId: &recipientId, 116 | Epoch: 10, 117 | FirstRound: 10, 118 | RoundRange: 0, 119 | Filter: testFilter, 120 | } 121 | newFilter := &ClientBloomFilter{ 122 | RecipientId: &recipientId, 123 | Epoch: 10, 124 | FirstRound: 20, 125 | RoundRange: 0, 126 | Filter: testFilter, 127 | } 128 | 129 | newFilter.combine(oldFilter) 130 | 131 | // Ensure some things didn't change 132 | if oldFilter.FirstRound != 10 { 133 | t.Errorf("Unexpected FirstRound change: %d", oldFilter.FirstRound) 134 | } 135 | if !bytes.Equal(newFilter.Filter, testFilter) { 136 | t.Errorf("Unexpected Filter value: %v", newFilter.Filter) 137 | } 138 | 139 | // Ensure some things did change 140 | if newFilter.FirstRound != oldFilter.FirstRound { 141 | t.Errorf("Expected FirstRound change: %d", newFilter.FirstRound) 142 | } 143 | if newFilter.RoundRange != 10 { 144 | t.Errorf("Expected RoundRange change: %d", newFilter.RoundRange) 145 | } 146 | } 147 | 148 | // Happy path - Update filter with newer oldFilter 149 | func TestClientBloomFilter_Combine_UpdateOld(t *testing.T) { 150 | 
testFilter := []byte("test") 151 | recipientId := int64(10) 152 | oldFilter := &ClientBloomFilter{ 153 | RecipientId: &recipientId, 154 | Epoch: 10, 155 | FirstRound: 10, 156 | RoundRange: 50, 157 | Filter: testFilter, 158 | } 159 | newFilter := &ClientBloomFilter{ 160 | RecipientId: &recipientId, 161 | Epoch: 10, 162 | FirstRound: 20, 163 | RoundRange: 0, 164 | Filter: testFilter, 165 | } 166 | 167 | newFilter.combine(oldFilter) 168 | 169 | // Ensure some things didn't change 170 | if oldFilter.FirstRound != 10 { 171 | t.Errorf("Unexpected FirstRound change: %d", oldFilter.FirstRound) 172 | } 173 | if !bytes.Equal(newFilter.Filter, testFilter) { 174 | t.Errorf("Unexpected Filter value: %v", newFilter.Filter) 175 | } 176 | 177 | // Ensure some things did change 178 | if newFilter.FirstRound != oldFilter.FirstRound { 179 | t.Errorf("Expected FirstRound change: %d", newFilter.FirstRound) 180 | } 181 | if newFilter.RoundRange != oldFilter.RoundRange { 182 | t.Errorf("Expected RoundRange change: %d", newFilter.RoundRange) 183 | } 184 | } 185 | 186 | // Happy path 187 | func TestStorage_GetMixedMessages(t *testing.T) { 188 | testMsgID := rand.Uint64() 189 | testRoundID := id.Round(rand.Uint64()) 190 | testRecipientID := ephemeral.Id{1, 2, 3} 191 | testMixedMessage := &MixedMessage{ 192 | Id: testMsgID, 193 | RoundId: uint64(testRoundID), 194 | RecipientId: testRecipientID.Int64(), 195 | } 196 | storage := &Storage{ 197 | &MapImpl{ 198 | mixedMessages: MixedMessageMap{ 199 | RoundId: map[id.Round]map[int64]map[uint64]*MixedMessage{ 200 | testRoundID: {testRecipientID.Int64(): {testMsgID: testMixedMessage}}, 201 | }, 202 | RecipientId: map[int64]map[id.Round]map[uint64]*MixedMessage{ 203 | testRecipientID.Int64(): {testRoundID: {testMsgID: testMixedMessage}}, 204 | }, 205 | RoundIdCount: map[id.Round]uint64{testRoundID: 1}, 206 | }, 207 | clientRounds: map[uint64]*ClientRound{ 208 | uint64(testRoundID): {}, 209 | }, 210 | }, 211 | } 212 | 213 | msgs, hasRound, err := 
storage.GetMixedMessages(testRecipientID, testRoundID) 214 | if len(msgs) != 1 { 215 | t.Errorf("Retrieved unexpected number of messages: %d", len(msgs)) 216 | } 217 | if !hasRound { 218 | t.Errorf("Expected valid round!") 219 | } 220 | if err != nil { 221 | t.Errorf(err.Error()) 222 | } 223 | } 224 | 225 | // Invalid gateway path 226 | func TestStorage_GetMixedMessagesInvalidGw(t *testing.T) { 227 | testRoundID := id.Round(rand.Uint64()) 228 | testRecipientID := ephemeral.Id{1, 2, 3} 229 | 230 | storage := &Storage{ 231 | &MapImpl{ 232 | mixedMessages: MixedMessageMap{ 233 | RoundId: map[id.Round]map[int64]map[uint64]*MixedMessage{}, 234 | RecipientId: map[int64]map[id.Round]map[uint64]*MixedMessage{}, 235 | }, 236 | }, 237 | } 238 | 239 | msgs, hasRound, err := storage.GetMixedMessages(testRecipientID, testRoundID) 240 | if len(msgs) != 0 { 241 | t.Errorf("Retrieved unexpected number of messages: %d", len(msgs)) 242 | } 243 | if hasRound { 244 | t.Errorf("Expected invalid round!") 245 | } 246 | if err != nil { 247 | t.Errorf(err.Error()) 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /cmd/knownRoundsWrapper_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "bytes" 12 | "encoding/base64" 13 | "gitlab.com/elixxir/gateway/storage" 14 | "gitlab.com/elixxir/primitives/knownRounds" 15 | "gitlab.com/xx_network/primitives/rateLimiting" 16 | "reflect" 17 | "strings" 18 | "testing" 19 | "time" 20 | ) 21 | 22 | // Unit test of newKnownRoundsWrapper. 
23 | func Test_newKnownRoundsWrapper(t *testing.T) { 24 | // Create new gateway instance 25 | params := Params{DevMode: true} 26 | params.messageRateLimitParams = &rateLimiting.MapParams{ 27 | Capacity: 10, 28 | LeakedTokens: 1, 29 | LeakDuration: 10 * time.Second, 30 | PollDuration: 10 * time.Second, 31 | BucketMaxAge: 10 * time.Second, 32 | } 33 | gw := NewGatewayInstance(params) 34 | roundCapacity := 255 35 | expected := &knownRoundsWrapper{ 36 | kr: knownRounds.NewKnownRound(roundCapacity), 37 | } 38 | expected.kr.Check(0) 39 | expected.marshalled = expected.kr.Marshal() 40 | expected.truncated = expected.marshalled 41 | expected.backupPeriod = 5 * time.Second 42 | expectedData := base64.StdEncoding.EncodeToString(expected.marshalled) 43 | 44 | krw, err := newKnownRoundsWrapper(roundCapacity, gw.storage) 45 | if err != nil { 46 | t.Errorf("newKnownRoundsWrapper returned an error: %+v", err) 47 | } 48 | 49 | expected.backupChan = krw.backupChan 50 | 51 | if !reflect.DeepEqual(expected, krw) { 52 | t.Errorf("newKnownRoundsWrapper failed to return the expected "+ 53 | "knownRoundsWrapper.\nexpected: %+v\nreceived: %+v", expected, krw) 54 | } 55 | 56 | data, err := gw.storage.GetStateValue(storage.KnownRoundsKey) 57 | if err != nil { 58 | t.Errorf("Failed to load saved knownRoundsWrapper: %+v", err) 59 | } 60 | 61 | if data != expectedData { 62 | t.Errorf("newKnownRoundsWrapper failed to save the expected "+ 63 | "knownRoundsWrapper.\nexpected: %+v\nreceived: %+v", expectedData, data) 64 | } 65 | } 66 | 67 | // Unit test of knownRoundsWrapper.check. 
68 | func Test_knownRoundsWrapper_check(t *testing.T) { 69 | params := Params{DevMode: true} 70 | params.messageRateLimitParams = &rateLimiting.MapParams{ 71 | Capacity: 10, 72 | LeakedTokens: 1, 73 | LeakDuration: 10 * time.Second, 74 | PollDuration: 10 * time.Second, 75 | BucketMaxAge: 10 * time.Second, 76 | } 77 | gw := NewGatewayInstance(params) 78 | krw, err := newKnownRoundsWrapper(10, gw.storage) 79 | if err != nil { 80 | t.Errorf("Failed to create new knownRoundsWrapper: %+v", err) 81 | } 82 | krw.backupPeriod = 0 83 | 84 | err = krw.check(10) 85 | if err != nil { 86 | t.Errorf("check returned an error: %+v", err) 87 | } 88 | 89 | if !bytes.Equal(krw.marshalled, krw.kr.Marshal()) { 90 | t.Errorf("check failed to save the expected marshalled bytes."+ 91 | "\nexpected: %+v\nreceived: %+v", krw.kr.Marshal(), krw.marshalled) 92 | } 93 | time.Sleep(time.Second) 94 | 95 | data, err := gw.storage.GetStateValue(storage.KnownRoundsKey) 96 | if err != nil { 97 | t.Errorf("Failed to load saved knownRoundsWrapper: %+v", err) 98 | } 99 | 100 | expectedData := base64.StdEncoding.EncodeToString(krw.marshalled) 101 | if data != expectedData { 102 | t.Errorf("check failed to save the expected knownRoundsWrapper."+ 103 | "\nexpected: %+v\nreceived: %+v", expectedData, data) 104 | } 105 | } 106 | 107 | // Unit test of knownRoundsWrapper.forceCheck. 
108 | func Test_knownRoundsWrapper_forceCheck(t *testing.T) { 109 | params := Params{DevMode: true} 110 | params.messageRateLimitParams = &rateLimiting.MapParams{ 111 | Capacity: 10, 112 | LeakedTokens: 1, 113 | LeakDuration: 10 * time.Second, 114 | PollDuration: 10 * time.Second, 115 | BucketMaxAge: 10 * time.Second, 116 | } 117 | gw := NewGatewayInstance(params) 118 | krw, err := newKnownRoundsWrapper(10, gw.storage) 119 | if err != nil { 120 | t.Errorf("Failed to create new knownRoundsWrapper: %+v", err) 121 | } 122 | krw.backupPeriod = 0 123 | 124 | err = krw.forceCheck(10) 125 | if err != nil { 126 | t.Errorf("forceCheck returned an error: %+v", err) 127 | } 128 | time.Sleep(time.Second) 129 | 130 | if !bytes.Equal(krw.marshalled, krw.kr.Marshal()) { 131 | t.Errorf("forceCheck failed to save the expected marshalled bytes."+ 132 | "\nexpected: %+v\nreceived: %+v", krw.kr.Marshal(), krw.marshalled) 133 | } 134 | 135 | data, err := gw.storage.GetStateValue(storage.KnownRoundsKey) 136 | if err != nil { 137 | t.Errorf("Failed to load saved knownRoundsWrapper: %+v", err) 138 | } 139 | 140 | expectedData := base64.StdEncoding.EncodeToString(krw.marshalled) 141 | if data != expectedData { 142 | t.Errorf("forceCheck failed to save the expected knownRoundsWrapper."+ 143 | "\nexpected: %+v\nreceived: %+v", expectedData, data) 144 | } 145 | } 146 | 147 | // Unit test of knownRoundsWrapper.forceCheck. 
148 | func Test_knownRoundsWrapper_getMarshal(t *testing.T) { 149 | params := Params{DevMode: true} 150 | params.messageRateLimitParams = &rateLimiting.MapParams{ 151 | Capacity: 10, 152 | LeakedTokens: 1, 153 | LeakDuration: 10 * time.Second, 154 | PollDuration: 10 * time.Second, 155 | BucketMaxAge: 10 * time.Second, 156 | } 157 | gw := NewGatewayInstance(params) 158 | krw, err := newKnownRoundsWrapper(10, gw.storage) 159 | if err != nil { 160 | t.Errorf("Failed to create new knownRoundsWrapper: %+v", err) 161 | } 162 | 163 | if !bytes.Equal(krw.getMarshal(), krw.kr.Marshal()) { 164 | t.Errorf("getMarshal did not return the expected bytes."+ 165 | "\nexpected: %+v\nreceived: %+v", krw.kr.Marshal(), krw.getMarshal()) 166 | } 167 | } 168 | 169 | // Tests that a knownRoundsWrapper that is saved and loaded matches the 170 | // original. 171 | func Test_knownRoundsWrapper_save_load(t *testing.T) { 172 | params := Params{DevMode: true} 173 | params.messageRateLimitParams = &rateLimiting.MapParams{ 174 | Capacity: 10, 175 | LeakedTokens: 1, 176 | LeakDuration: 10 * time.Second, 177 | PollDuration: 10 * time.Second, 178 | BucketMaxAge: 10 * time.Second, 179 | } 180 | gw := NewGatewayInstance(params) 181 | krw, err := newKnownRoundsWrapper(10, gw.storage) 182 | if err != nil { 183 | t.Errorf("Failed to create new knownRoundsWrapper: %+v", err) 184 | } 185 | 186 | err = krw.save() 187 | if err != nil { 188 | t.Errorf("save retuned an error: %+v", err) 189 | } 190 | 191 | loadedKrw := &knownRoundsWrapper{ 192 | kr: knownRounds.NewKnownRound(10), 193 | } 194 | 195 | err = loadedKrw.load(gw.storage) 196 | if err != nil { 197 | t.Errorf("load retuned an error: %+v", err) 198 | } 199 | 200 | if !bytes.Equal(loadedKrw.marshalled, krw.marshalled) { 201 | t.Errorf("Saved and loaded knownRoundsWrapper does not match original."+ 202 | "\nexpected: %+v\nreceived: %+v", krw, loadedKrw) 203 | } 204 | } 205 | 206 | // Tests that knownRoundsWrapper.load returns an error if the state value 
cannot 207 | // be found 208 | func Test_knownRoundsWrapper_load_GetStateValueError(t *testing.T) { 209 | store, err := storage.NewStorage("", "", "", "", "", true) 210 | if err != nil { 211 | t.Fatalf("failed to create new storage: %+v", err) 212 | } 213 | krw := &knownRoundsWrapper{kr: knownRounds.NewKnownRound(10)} 214 | expectedErr := strings.SplitN(storageGetErr, "%", 2)[0] 215 | 216 | err = krw.load(store) 217 | if err == nil || !strings.Contains(err.Error(), expectedErr) { 218 | t.Errorf("load did not return the expected error."+ 219 | "\nexpected: %s\nreceived: %+v", expectedErr, err) 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /cmd/poll.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Contains polling-related functionality 9 | 10 | package cmd 11 | 12 | import ( 13 | "gitlab.com/xx_network/primitives/id" 14 | "gitlab.com/xx_network/primitives/netTime" 15 | "time" 16 | 17 | "github.com/pkg/errors" 18 | jww "github.com/spf13/jwalterweatherman" 19 | "gitlab.com/elixxir/comms/gateway" 20 | pb "gitlab.com/elixxir/comms/mixmessages" 21 | "gitlab.com/elixxir/comms/network" 22 | "gitlab.com/elixxir/primitives/version" 23 | "gitlab.com/xx_network/comms/connect" 24 | "gitlab.com/xx_network/primitives/id/ephemeral" 25 | "gitlab.com/xx_network/primitives/ndf" 26 | ) 27 | 28 | // Handler for a client's poll to a gateway. 
Returns all the last updates and known rounds 29 | func (gw *Instance) Poll(clientRequest *pb.GatewayPoll) ( 30 | *pb.GatewayPollResponse, error) { 31 | // Record the beginning of Poll processing; returned with the response 32 | startTime := netTime.Now() 33 | 34 | // Nil check to check for valid clientRequest 35 | if clientRequest == nil { 36 | return &pb.GatewayPollResponse{}, errors.Errorf( 37 | "Poll() clientRequest is empty") 38 | } 39 | 40 | // Make sure Gateway network instance is not nil 41 | if gw.NetInf == nil { 42 | return &pb.GatewayPollResponse{}, errors.New(ndf.NO_NDF) 43 | } 44 | 45 | // Get version sent from client 46 | clientVersion, err := version.ParseVersion(string(clientRequest.ClientVersion)) 47 | if err != nil { 48 | return &pb.GatewayPollResponse{}, errors.Errorf( 49 | "Unable to ParseVersion for clientRequest: %+v", err) 50 | } 51 | // Get version from NDF 52 | expectedClientVersion, err := version.ParseVersion(gw.NetInf.GetFullNdf().Get().ClientVersion) 53 | if err != nil { 54 | return &pb.GatewayPollResponse{}, errors.Errorf( 55 | "Unable to ParseVersion for gateway's NDF: %+v", err) 56 | } 57 | // Check that the two versions are compatible 58 | if version.IsCompatible(expectedClientVersion, clientVersion) == false { 59 | return &pb.GatewayPollResponse{}, errors.Errorf( 60 | "client version \"%s\" was not compatible with NDF defined minimum version", clientRequest.ClientVersion) 61 | } 62 | 63 | earliestRoundId, _, _, err := gw.GetEarliestRound() 64 | if err != nil { 65 | return &pb.GatewayPollResponse{}, errors.WithMessage(err, "Failed to "+ 66 | "retrieve earliest round info, no state currently exists with this gateway") 67 | } 68 | 69 | // Check if the clientID is populated and valid 70 | receptionId, err := ephemeral.Marshal(clientRequest.GetReceptionID()) 71 | if err != nil { 72 | return &pb.GatewayPollResponse{}, errors.Errorf( 73 | "Poll() - Valid ReceptionID required: %+v", err) 74 | } 75 | 76 | // Determine Client epoch range 77 
| startEpoch, err := GetEpochEdge(time.Unix(0, clientRequest.StartTimestamp).UnixNano(), gw.period) 78 | if err != nil { 79 | return &pb.GatewayPollResponse{}, errors.WithMessage(err, "Failed to "+ 80 | "handle client poll due to invalid start timestamp") 81 | } 82 | endEpoch, err := GetEpochEdge(time.Unix(0, clientRequest.EndTimestamp).UnixNano(), gw.period) 83 | if err != nil { 84 | return &pb.GatewayPollResponse{}, errors.WithMessage(err, "Failed to "+ 85 | "handle client poll due to invalid end timestamp") 86 | } 87 | 88 | // get the known rounds before the client filters are received, otherwise there can 89 | // be a race condition because known rounds is updated after the bloom filters, 90 | // so you can get a known rounds that denotes an updated bloom filter while 91 | // it was not received 92 | var knownRounds []byte 93 | if gw.krw.needsTruncated(id.Round(clientRequest.LastRound)) { 94 | knownRounds = gw.krw.truncateMarshal() 95 | } else { 96 | knownRounds = gw.krw.getMarshal() 97 | } 98 | jww.TRACE.Printf("Poll retrieved knownrounds, last checked: %d", gw.krw.getLastChecked()) 99 | 100 | // These errors are suppressed, as DB errors shouldn't go to client 101 | // and if there is trouble getting filters returned, nil filters 102 | // are returned to the client. Debug to avoid message spam. 
103 | clientFilters, err := gw.storage.GetClientBloomFilters( 104 | receptionId, startEpoch, endEpoch) 105 | jww.DEBUG.Printf("Adding %d client filters for %d", len(clientFilters), receptionId.Int64()) 106 | if err != nil { 107 | jww.DEBUG.Printf("Could not get filters in range %d - %d for %d when polling: %v", startEpoch, endEpoch, receptionId.Int64(), err) 108 | } 109 | 110 | // Build ClientBlooms metadata 111 | filtersMsg := &pb.ClientBlooms{ 112 | Period: gw.period, 113 | FirstTimestamp: GetEpochTimestamp(startEpoch, gw.period), 114 | } 115 | 116 | if len(clientFilters) > 0 { 117 | filtersMsg.Filters = make([]*pb.ClientBloom, 0, endEpoch-startEpoch+1) 118 | // Build ClientBloomFilter list for client 119 | for _, f := range clientFilters { 120 | filtersMsg.Filters = append(filtersMsg.Filters, &pb.ClientBloom{ 121 | Filter: f.Filter, 122 | FirstRound: f.FirstRound, 123 | RoundRange: f.RoundRange, 124 | }) 125 | } 126 | } 127 | 128 | // Exclude the NDF and network round updates on client request 129 | if clientRequest.GetDisableUpdates() { 130 | return &pb.GatewayPollResponse{ 131 | KnownRounds: knownRounds, 132 | Filters: filtersMsg, 133 | EarliestRound: earliestRoundId, 134 | ReceivedTs: startTime.UnixNano(), 135 | GatewayDelay: int64(netTime.Now().Sub(startTime)), 136 | }, nil 137 | } 138 | 139 | var netDef *pb.NDF 140 | var updates []*pb.RoundInfo 141 | isSame := gw.NetInf.GetPartialNdf().CompareHash(clientRequest.Partial.Hash) 142 | if !isSame { 143 | netDef = gw.NetInf.GetPartialNdf().GetPb() 144 | } else if clientRequest.FastPolling { 145 | // Get the range of updates from the filtered updates structure for client 146 | // and with an EDDSA signature 147 | updates = gw.filteredUpdates.GetRoundUpdates(int(clientRequest.LastUpdate)) 148 | 149 | } else { 150 | // Get the range of updates from the consensus object, with all updates 151 | // and the RSA Signature 152 | updates = gw.NetInf.GetRoundUpdates(int(clientRequest.LastUpdate)) 153 | } 154 | 155 | return 
&pb.GatewayPollResponse{ 156 | PartialNDF: netDef, 157 | Updates: updates, 158 | KnownRounds: knownRounds, 159 | Filters: filtersMsg, 160 | EarliestRound: earliestRoundId, 161 | ReceivedTs: startTime.UnixNano(), 162 | GatewayDelay: int64(netTime.Now().Sub(startTime)), 163 | }, nil 164 | } 165 | 166 | // PollServer sends a poll message to the server and returns a response. 167 | func PollServer(conn *gateway.Comms, pollee *connect.Host, ndf, 168 | partialNdf *network.SecuredNdf, lastUpdate uint64, addr string) ( 169 | *pb.ServerPollResponse, error) { 170 | jww.TRACE.Printf("Address being sent to server: [%v]", addr) 171 | var ndfHash, partialNdfHash *pb.NDFHash 172 | ndfHash = &pb.NDFHash{ 173 | Hash: make([]byte, 0), 174 | } 175 | 176 | partialNdfHash = &pb.NDFHash{ 177 | Hash: make([]byte, 0), 178 | } 179 | 180 | if ndf != nil { 181 | ndfHash = &pb.NDFHash{Hash: ndf.GetHash()} 182 | } 183 | if partialNdf != nil { 184 | partialNdfHash = &pb.NDFHash{Hash: partialNdf.GetHash()} 185 | } 186 | 187 | pollMsg := &pb.ServerPoll{ 188 | Full: ndfHash, 189 | Partial: partialNdfHash, 190 | LastUpdate: lastUpdate, 191 | Error: "", 192 | GatewayAddress: addr, 193 | GatewayVersion: currentVersion, 194 | } 195 | 196 | resp, err := conn.SendPoll(pollee, pollMsg) 197 | return resp, err 198 | } 199 | -------------------------------------------------------------------------------- /cmd/bloomGossip.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Contains gossip methods specific to bloom filter gossiping 9 | 10 | package cmd 11 | 12 | import ( 13 | "github.com/golang/protobuf/proto" 14 | "github.com/pkg/errors" 15 | jww "github.com/spf13/jwalterweatherman" 16 | pb "gitlab.com/elixxir/comms/mixmessages" 17 | "gitlab.com/xx_network/comms/gossip" 18 | "gitlab.com/xx_network/primitives/id" 19 | "gitlab.com/xx_network/primitives/id/ephemeral" 20 | "sync" 21 | "sync/atomic" 22 | "time" 23 | ) 24 | 25 | const errorDelimiter = "; " 26 | const bloomUploadRetries = 5 27 | 28 | // Initialize fields required for the gossip protocol specialized to bloom filters 29 | func (gw *Instance) InitBloomGossip() { 30 | flags := gossip.DefaultProtocolFlags() 31 | flags.FanOut = 25 32 | flags.MaximumReSends = 2 33 | flags.NumParallelSends = 1000 34 | flags.Fingerprinter = func(msg *gossip.GossipMsg) gossip.Fingerprint { 35 | preSum := append([]byte(msg.Tag), msg.Payload...) 
36 | return gossip.NewFingerprint(preSum) 37 | } 38 | // Register gossip protocol for bloom filters 39 | gw.Comms.Manager.NewGossip(BloomFilterGossip, flags, 40 | gw.gossipBloomFilterReceive, gw.gossipVerify, nil) 41 | } 42 | 43 | // GossipBloom builds a gossip message containing all the recipient IDs 44 | // within the bloom filter and gossips it to all peers 45 | func (gw *Instance) GossipBloom(recipients map[ephemeral.Id]interface{}, roundId id.Round, roundTimestamp int64) error { 46 | var err error 47 | 48 | // Retrieve gossip protocol 49 | gossipProtocol, ok := gw.Comms.Manager.Get(BloomFilterGossip) 50 | if !ok { 51 | return errors.Errorf("Unable to get gossip protocol.") 52 | } 53 | 54 | // Build the message 55 | gossipMsg := &gossip.GossipMsg{ 56 | Tag: BloomFilterGossip, 57 | Origin: gw.Comms.GetId().Marshal(), 58 | } 59 | 60 | // Add the GossipMsg payload 61 | gossipMsg.Payload, err = buildGossipPayloadBloom(recipients, roundId, uint64(roundTimestamp)) 62 | if err != nil { 63 | return errors.Errorf("Unable to build gossip payload: %+v", err) 64 | } 65 | 66 | // Add the GossipMsg signature 67 | gossipMsg.Signature, err = buildGossipSignature(gossipMsg, gw.Comms.GetPrivateKey()) 68 | if err != nil { 69 | return errors.Errorf("Unable to build gossip signature: %+v", err) 70 | } 71 | // Gossip the message 72 | numPeers, errs := gossipProtocol.Gossip(gossipMsg) 73 | 74 | jww.INFO.Printf("Gossiping Blooms for round %v at ts %s", roundId, 75 | time.Unix(0, roundTimestamp)) 76 | 77 | // Return any errors up the stack 78 | if len(errs) != 0 { 79 | jww.TRACE.Printf("Failed to rate limit gossip to: %v", errs) 80 | return errors.Errorf("Could not send to %d out of %d peers", len(errs), numPeers) 81 | } 82 | return nil 83 | } 84 | 85 | // Receive function for Gossip messages regarding bloom filters. 
86 | func (gw *Instance) gossipBloomFilterReceive(msg *gossip.GossipMsg) error { 87 | 88 | received := time.Now() 89 | 90 | // Unmarshal the Recipients data 91 | payloadMsg := &pb.Recipients{} 92 | err := proto.Unmarshal(msg.Payload, payloadMsg) 93 | if err != nil { 94 | return errors.Errorf("Could not unmarshal message into expected format: %s", err) 95 | } 96 | 97 | roundID := id.Round(payloadMsg.RoundID) 98 | 99 | var wg sync.WaitGroup 100 | 101 | epoch := GetEpoch(int64(payloadMsg.RoundTS), gw.period) 102 | 103 | totalNumAttempts := uint32(0) 104 | failedInsert := uint32(0) 105 | 106 | gw.bloomFilterGossip.Lock() 107 | defer gw.bloomFilterGossip.Unlock() 108 | 109 | // WARNING: this needs function IDENTICALLY to the code in ProcessCompletedBatch in 110 | // gateway.go, but due to this being hot code, has subtle differences which 111 | // lead to a different implementation 112 | //Go through each of the recipients 113 | for _, recipient := range payloadMsg.RecipientIds { 114 | wg.Add(1) 115 | go func(localRecipient []byte) { 116 | defer wg.Done() 117 | // Marshal the id 118 | recipientId, localErr := ephemeral.Marshal(localRecipient) 119 | if localErr != nil { 120 | jww.WARN.Printf("Failed to unmarshal recipient %v for "+ 121 | "bloom gossip for round %d: %s", localRecipient, roundID, localErr) 122 | return 123 | } 124 | // retry insertion into the database in the event that there is an 125 | // insertion on the same ephemeral id by multiple rounds at the same 126 | // time, in which case all but one will fail 127 | i := 0 128 | for ; i < bloomUploadRetries && (localErr != nil || i == 0); i++ { 129 | localErr = gw.UpsertFilter(recipientId, roundID, epoch) 130 | if localErr != nil { 131 | jww.WARN.Printf("Failed to upsert recipient %d bloom on "+ 132 | "round %d on attempt %d: %s", localRecipient, roundID, i, localErr.Error()) 133 | } 134 | } 135 | 136 | atomic.AddUint32(&totalNumAttempts, uint32(i)) 137 | if localErr != nil { 138 | jww.ERROR.Printf("Failed to 
upsert recipient %d bloom on "+ 139 | "round %d on all attemps(%d/%d): %+v", localRecipient, roundID, i, i, localErr) 140 | atomic.AddUint32(&failedInsert, 1) 141 | } 142 | }(recipient) 143 | } 144 | wg.Wait() 145 | 146 | finishedInsert := time.Now() 147 | averageAttempts := float32(totalNumAttempts) / float32(len(payloadMsg.RecipientIds)) 148 | 149 | if failedInsert == 0 { 150 | //denote the reception in known rounds 151 | err = gw.krw.forceCheck(roundID) 152 | if err != nil { 153 | jww.ERROR.Printf("Gossip received not recorded due to known rounds error for "+ 154 | "round %d with %d recipients at %s: "+ 155 | "\n\t inserts finished at %s, KR insert finished at %s, "+ 156 | "\n]t round started at ts %s, average attempts: %f (total: %d): %+v", roundID, 157 | len(payloadMsg.RecipientIds), received, finishedInsert, time.Now(), 158 | time.Unix(0, int64(payloadMsg.RoundTS)), averageAttempts, totalNumAttempts, err) 159 | } else { 160 | jww.INFO.Printf("Gossip received for round %d with %d recipients at %s: "+ 161 | "\n\t inserts finished at %s, KR insert finished at %s, "+ 162 | "\n]t round started at ts %s, average attempts: %f (total: %d)", roundID, 163 | len(payloadMsg.RecipientIds), received, finishedInsert, time.Now(), 164 | time.Unix(0, int64(payloadMsg.RoundTS)), averageAttempts, totalNumAttempts) 165 | } 166 | } else { 167 | jww.ERROR.Printf("Gossip received not recorded due to bloom upsert failures for %d recipeints"+ 168 | " for round %d with %d recipients at %s: "+ 169 | "\n\t inserts finished at %s, KR insert finished at %s, "+ 170 | "\n]t round started at ts %s, average attempts: %f (total: %d)", failedInsert, roundID, 171 | len(payloadMsg.RecipientIds), received, finishedInsert, time.Now(), 172 | time.Unix(0, int64(payloadMsg.RoundTS)), averageAttempts, totalNumAttempts) 173 | } 174 | 175 | jww.TRACE.Printf("Gossip received for RID %d, lastChecked: %d", roundID, gw.krw.getLastChecked()) 176 | 177 | return nil 178 | } 179 | 180 | // Helper function used to 
convert recipientIds into a GossipMsg payload 181 | func buildGossipPayloadBloom(recipientIDs map[ephemeral.Id]interface{}, roundId id.Round, roundTS uint64) ([]byte, error) { 182 | // Iterate over the map, placing keys back in a list 183 | // without any duplicates 184 | i := 0 185 | recipients := make([][]byte, len(recipientIDs)) 186 | for key := range recipientIDs { 187 | recipients[i] = make([]byte, len(key)) 188 | copy(recipients[i], key[:]) 189 | i++ 190 | } 191 | 192 | // Build the message payload and return 193 | payloadMsg := &pb.Recipients{ 194 | RecipientIds: recipients, 195 | RoundID: uint64(roundId), 196 | RoundTS: roundTS, 197 | } 198 | return proto.Marshal(payloadMsg) 199 | } 200 | -------------------------------------------------------------------------------- /cmd/filteredUpdates_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package cmd 9 | 10 | import ( 11 | "crypto/rand" 12 | "gitlab.com/elixxir/comms/mixmessages" 13 | "gitlab.com/elixxir/comms/network" 14 | ds "gitlab.com/elixxir/comms/network/dataStructures" 15 | "gitlab.com/elixxir/comms/testutils" 16 | "gitlab.com/elixxir/primitives/states" 17 | "gitlab.com/xx_network/crypto/signature/ec" 18 | "gitlab.com/xx_network/primitives/ndf" 19 | "testing" 20 | ) 21 | 22 | func TestFilteredUpdates_RoundUpdate(t *testing.T) { 23 | validUpdateId := uint64(4) 24 | validMsg := &mixmessages.RoundInfo{ 25 | ID: 2, 26 | UpdateID: validUpdateId, 27 | State: uint32(states.COMPLETED), 28 | BatchSize: 8, 29 | Timestamps: []uint64{0, 1, 2, 3, 4, 5}, 30 | } 31 | 32 | ecPrivKey, err := ec.NewKeyPair(rand.Reader) 33 | if err != nil { 34 | t.Fatalf("Failed to generate test key: %v", err) 35 | } 36 | 37 | pubKey := ecPrivKey.GetPublic() 38 | 39 | fullNdf, err := ds.NewNdf(&ndf.NetworkDefinition{ 40 | Registration: ndf.Registration{EllipticPubKey: pubKey.MarshalText()}, 41 | }) 42 | if err != nil { 43 | t.Fatalf("Failed to generate a mock ndf: %v", err) 44 | } 45 | 46 | netInf, err := network.NewInstance(gatewayInstance.Comms.ProtoComms, fullNdf.Get(), fullNdf.Get(), nil, network.Lazy, true) 47 | testFilter, err := NewFilteredUpdates(netInf) 48 | if err != nil { 49 | t.Fatalf("Failed to create filtered update: %v", err) 50 | } 51 | err = testutils.SignRoundInfoEddsa(validMsg, ecPrivKey, t) 52 | if err != nil { 53 | t.Fatalf("Failed to sign message: %v", err) 54 | } 55 | 56 | err = testFilter.RoundUpdate(validMsg) 57 | // Fixme 58 | /* if err == nil { 59 | t.Error("Should have failed to veNewSecuredNdfrify") 60 | }*/ 61 | 62 | t.Logf("err update: %v", err) 63 | 64 | retrieved, err := testFilter.GetRoundUpdate(int(validMsg.UpdateID)) 65 | if err != nil || retrieved == nil { 66 | t.Logf("retrieved: %v", retrieved) 67 | t.Logf("err: %v", err) 68 | t.Errorf("Should 
have stored msg with state %s", states.Round(validMsg.State)) 69 | } 70 | 71 | invalidUpdateId := uint64(5) 72 | invalidMsg := &mixmessages.RoundInfo{ 73 | ID: 2, 74 | UpdateID: invalidUpdateId, 75 | State: uint32(states.PRECOMPUTING), 76 | BatchSize: 8, 77 | } 78 | 79 | err = testFilter.RoundUpdate(invalidMsg) 80 | if err != nil { 81 | t.Errorf("Failed to update round: %v", err) 82 | } 83 | 84 | retrieved, err = testFilter.GetRoundUpdate(int(invalidUpdateId)) 85 | if err == nil || retrieved != nil { 86 | t.Errorf("Should not have inserted round with state %s", 87 | states.Round(invalidMsg.State)) 88 | } 89 | 90 | } 91 | 92 | func TestFilteredUpdates_RoundUpdates(t *testing.T) { 93 | validUpdateId := uint64(4) 94 | validMsg := &mixmessages.RoundInfo{ 95 | ID: 2, 96 | UpdateID: validUpdateId, 97 | State: uint32(states.COMPLETED), 98 | BatchSize: 8, 99 | Timestamps: []uint64{0, 1, 2, 3, 4, 5}, 100 | } 101 | 102 | ellipticKey, err := ec.NewKeyPair(rand.Reader) 103 | if err != nil { 104 | t.Fatalf("Failed to generate test ellitpic key: %v", err) 105 | } 106 | 107 | fullNdf, err := ds.NewNdf(&ndf.NetworkDefinition{ 108 | Registration: ndf.Registration{ 109 | EllipticPubKey: ellipticKey.GetPublic().MarshalText(), 110 | }, 111 | }) 112 | if err != nil { 113 | t.Fatalf("Failed to generate a mock ndf: %v", err) 114 | } 115 | 116 | netInf, err := network.NewInstance(gatewayInstance.Comms.ProtoComms, fullNdf.Get(), fullNdf.Get(), nil, network.Lazy, true) 117 | if err != nil { 118 | t.Fatalf("Failed to generate instance: %v", err) 119 | } 120 | 121 | testFilter, err := NewFilteredUpdates(netInf) 122 | if err != nil { 123 | t.Fatalf("Failed to create filtered update: %v", err) 124 | } 125 | invalidUpdateId := uint64(5) 126 | invalidMsg := &mixmessages.RoundInfo{ 127 | ID: 2, 128 | UpdateID: invalidUpdateId, 129 | State: uint32(states.PRECOMPUTING), 130 | BatchSize: 8, 131 | Timestamps: []uint64{0, 1}, 132 | } 133 | 134 | err = testutils.SignRoundInfoEddsa(validMsg, ellipticKey, 
t) 135 | if err != nil { 136 | t.Fatalf("Failed to sign message: %v", err) 137 | } 138 | 139 | err = testutils.SignRoundInfoEddsa(invalidMsg, ellipticKey, t) 140 | if err != nil { 141 | t.Fatalf("Failed to sign message: %v", err) 142 | } 143 | 144 | roundUpdates := []*mixmessages.RoundInfo{validMsg, invalidMsg} 145 | 146 | err = testFilter.RoundUpdates(roundUpdates) 147 | if err != nil { 148 | t.Error("Should have failed to get perm host") 149 | } 150 | 151 | retrieved, err := testFilter.GetRoundUpdate(int(validUpdateId)) 152 | if err != nil || retrieved == nil { 153 | t.Errorf("Should have stored msg with state %s", states.Round(validMsg.State)) 154 | } 155 | 156 | retrieved, err = testFilter.GetRoundUpdate(int(invalidUpdateId)) 157 | if err == nil || retrieved != nil { 158 | t.Errorf("Should not have inserted round with state %s", 159 | states.Round(invalidMsg.State)) 160 | } 161 | 162 | } 163 | 164 | func TestFilteredUpdates_GetRoundUpdate(t *testing.T) { 165 | ellipticKey, err := ec.NewKeyPair(rand.Reader) 166 | if err != nil { 167 | t.Fatalf("Failed to generate test ellitpic key: %v", err) 168 | } 169 | 170 | fullNdf, err := ds.NewNdf(&ndf.NetworkDefinition{ 171 | Registration: ndf.Registration{ 172 | EllipticPubKey: ellipticKey.GetPublic().MarshalText(), 173 | }, 174 | }) 175 | if err != nil { 176 | t.Fatalf("Failed to generate a mock ndf: %v", err) 177 | } 178 | 179 | netInf, err := network.NewInstance(gatewayInstance.Comms.ProtoComms, fullNdf.Get(), fullNdf.Get(), nil, network.Lazy, true) 180 | if err != nil { 181 | t.Fatalf("Failed to generate instance: %v", err) 182 | } 183 | 184 | testFilter, err := NewFilteredUpdates(netInf) 185 | if err != nil { 186 | t.Fatalf("Failed to create filtered update: %v", err) 187 | } 188 | ri := &mixmessages.RoundInfo{ 189 | ID: uint64(1), 190 | UpdateID: uint64(1), 191 | State: uint32(states.QUEUED), 192 | Timestamps: []uint64{0, 1, 2, 3}, 193 | } 194 | testutils.SignRoundInfoEddsa(ri, ellipticKey, t) 195 | rnd := 
ds.NewRound(ri, nil, ellipticKey.GetPublic()) 196 | 197 | _ = testFilter.updates.AddRound(rnd) 198 | r, err := testFilter.GetRoundUpdate(1) 199 | if err != nil || r == nil { 200 | t.Errorf("Failed to retrieve round update: %+v", err) 201 | } 202 | } 203 | 204 | func TestFilteredUpdates_GetRoundUpdates(t *testing.T) { 205 | ellipticKey, err := ec.NewKeyPair(rand.Reader) 206 | if err != nil { 207 | t.Fatalf("Failed to generate test ellitpic key: %v", err) 208 | } 209 | 210 | fullNdf, err := ds.NewNdf(&ndf.NetworkDefinition{ 211 | Registration: ndf.Registration{ 212 | EllipticPubKey: ellipticKey.GetPublic().MarshalText(), 213 | }, 214 | }) 215 | if err != nil { 216 | t.Fatalf("Failed to generate a mock ndf: %v", err) 217 | } 218 | 219 | netInf, err := network.NewInstance(gatewayInstance.Comms.ProtoComms, fullNdf.Get(), fullNdf.Get(), nil, network.Lazy, true) 220 | if err != nil { 221 | t.Fatalf("Failed to generate instance: %v", err) 222 | } 223 | 224 | testFilter, err := NewFilteredUpdates(netInf) 225 | if err != nil { 226 | t.Fatalf("Failed to create filtered update: %v", err) 227 | } 228 | roundInfoOne := &mixmessages.RoundInfo{ 229 | ID: uint64(1), 230 | UpdateID: uint64(2), 231 | State: uint32(states.QUEUED), 232 | Timestamps: []uint64{0, 1, 2, 3}, 233 | } 234 | if err = testutils.SignRoundInfoEddsa(roundInfoOne, ellipticKey, t); err != nil { 235 | t.Fatalf("Failed to sign round info: %v", err) 236 | } 237 | roundInfoTwo := &mixmessages.RoundInfo{ 238 | ID: uint64(2), 239 | UpdateID: uint64(3), 240 | State: uint32(states.QUEUED), 241 | Timestamps: []uint64{0, 1, 2, 3}, 242 | } 243 | if err = testutils.SignRoundInfoEddsa(roundInfoTwo, ellipticKey, t); err != nil { 244 | t.Fatalf("Failed to sign round info: %v", err) 245 | } 246 | roundOne := ds.NewRound(roundInfoOne, nil, ellipticKey.GetPublic()) 247 | roundTwo := ds.NewRound(roundInfoTwo, nil, ellipticKey.GetPublic()) 248 | 249 | _ = testFilter.updates.AddRound(roundOne) 250 | _ = 
testFilter.updates.AddRound(roundTwo) 251 | r := testFilter.GetRoundUpdates(1) 252 | if len(r) == 0 { 253 | t.Errorf("Failed to retrieve round updates") 254 | } 255 | 256 | r = testFilter.GetRoundUpdates(2) 257 | if len(r) == 0 { 258 | t.Errorf("Failed to retrieve round updates") 259 | } 260 | 261 | r = testFilter.GetRoundUpdates(23) 262 | if len(r) != 0 { 263 | t.Errorf("Retrieved a round that was never inserted: %v", r) 264 | } 265 | 266 | } 267 | -------------------------------------------------------------------------------- /cmd/params.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Contains Params-related functionality 9 | 10 | package cmd 11 | 12 | import ( 13 | "fmt" 14 | jww "github.com/spf13/jwalterweatherman" 15 | "github.com/spf13/viper" 16 | "gitlab.com/elixxir/comms/publicAddress" 17 | "gitlab.com/xx_network/comms/gossip" 18 | "gitlab.com/xx_network/primitives/rateLimiting" 19 | "net" 20 | "strconv" 21 | "strings" 22 | "time" 23 | ) 24 | 25 | type Params struct { 26 | NodeAddress string `yaml:"cmixAddress"` 27 | Port int 28 | PublicAddress string // Gateway's public IP address (with port) 29 | ListeningAddress string // Gateway's local IP address (with port) 30 | CertPath string 31 | KeyPath string 32 | 33 | DbUsername string 34 | DbPassword string 35 | DbName string 36 | DbAddress string 37 | DbPort string 38 | 39 | ServerCertPath string `yaml:"cmixCertPath"` 40 | IDFPath string 41 | PermissioningCertPath string `yaml:"schedulingCertPath"` 42 | 43 | rateLimitParams *rateLimiting.MapParams 44 | messageRateLimitParams *rateLimiting.MapParams 45 | gossipFlags gossip.ManagerFlags 46 | 
47 | DevMode bool 48 | DisableGossip bool 49 | 50 | HttpsCountry string 51 | AuthorizerAddress string 52 | AutocertIssueTimeout time.Duration 53 | // time.Duration used to calculate lower bound of when to replace TLS cert, based on its expiry 54 | CertReplaceWindow time.Duration 55 | // Maximum random delay for cert replacement after reaching the start of CertReplaceWindow 56 | MaxCertReplaceDelay time.Duration 57 | cleanupInterval time.Duration 58 | 59 | MinRegisteredNodes int 60 | } 61 | 62 | const ( 63 | // Default time period for checking storage for stored items older 64 | // than the retention period value 65 | cleanupIntervalDefault = 5 * time.Minute 66 | ) 67 | 68 | func InitParams(vip *viper.Viper) Params { 69 | if !validConfig { 70 | jww.FATAL.Panicf("Invalid Config File: %s", cfgFile) 71 | } 72 | 73 | // Print all config options 74 | jww.INFO.Printf("All config params: %+v", vip.AllKeys()) 75 | 76 | certPath = viper.GetString("certPath") 77 | if certPath == "" { 78 | jww.FATAL.Panicf("Gateway.yaml certPath is required, path provided is empty.") 79 | } 80 | 81 | idfPath = viper.GetString("idfPath") 82 | if idfPath == "" { 83 | jww.FATAL.Panicf("Gateway.yaml idfPath is required, path provided is empty.") 84 | } 85 | keyPath = viper.GetString("keyPath") 86 | 87 | var nodeAddress string 88 | if viper.IsSet("cmixAddress") { 89 | nodeAddress = viper.GetString("cmixAddress") 90 | } else if viper.IsSet("nodeAddress") { 91 | nodeAddress = viper.GetString("nodeAddress") 92 | } else { 93 | jww.FATAL.Panicf("Gateway.yaml cmixAddress is required, address provided is empty.") 94 | } 95 | 96 | if viper.IsSet("schedulingCertPath") { 97 | permissioningCertPath = viper.GetString("schedulingCertPath") 98 | } else if viper.IsSet("permissioningCertPath") { 99 | permissioningCertPath = viper.GetString("permissioningCertPath") 100 | } else { 101 | jww.FATAL.Panicf("Gateway.yaml schedulingCertPath is required, path provided is empty.") 102 | } 103 | 104 | gwPort = 
viper.GetInt("port") 105 | if gwPort == 0 { 106 | jww.FATAL.Panicf("Gateway.yaml port is required, provided port is empty/not set.") 107 | } 108 | 109 | if viper.IsSet("cmixCertPath") { 110 | serverCertPath = viper.GetString("cmixCertPath") 111 | } else if viper.IsSet("serverCertPath") { 112 | serverCertPath = viper.GetString("serverCertPath") 113 | } else { 114 | jww.FATAL.Panicf("Gateway.yaml cmixCertPath is required, path provided is empty.") 115 | } 116 | 117 | // Get gateway's public IP or use the IP override 118 | overrideIP := viper.GetString("overridePublicIP") 119 | gwAddress, err := publicAddress.GetIpOverride(overrideIP, gwPort) 120 | if err != nil { 121 | jww.FATAL.Panicf("Failed to get public IP: %+v", err) 122 | } 123 | 124 | // Construct listening address 125 | listeningIP := viper.GetString("listeningAddress") 126 | if listeningIP == "" { 127 | listeningIP = "0.0.0.0" 128 | } 129 | listeningAddress := net.JoinHostPort(listeningIP, strconv.Itoa(gwPort)) 130 | 131 | dbpass := viper.GetString("dbPassword") 132 | jww.INFO.Printf("config: %+v", viper.ConfigFileUsed()) 133 | ps := fmt.Sprintf("Params: \n %+v", vip.AllSettings()) 134 | ps = strings.ReplaceAll(ps, 135 | "dbpassword:"+dbpass, 136 | "dbpassword:[dbpass]") 137 | jww.INFO.Printf(ps) 138 | jww.INFO.Printf("Gateway port: %d", gwPort) 139 | jww.INFO.Printf("Gateway public IP: %s", gwAddress) 140 | jww.INFO.Printf("Gateway listening address: %s", listeningAddress) 141 | jww.INFO.Printf("Gateway node: %s", nodeAddress) 142 | 143 | // If the values aren't default, repopulate flag values with customized values 144 | // Otherwise use the default values 145 | gossipFlags := gossip.DefaultManagerFlags() 146 | if gossipFlags.BufferExpirationTime != bufferExpiration || 147 | gossipFlags.MonitorThreadFrequency != monitorThreadFrequency { 148 | 149 | gossipFlags = gossip.ManagerFlags{ 150 | BufferExpirationTime: bufferExpiration, 151 | MonitorThreadFrequency: monitorThreadFrequency, 152 | } 153 | } 154 | 155 
| // Construct the rate limiting params 156 | bucketMapParams := &rateLimiting.MapParams{ 157 | Capacity: capacity, 158 | LeakedTokens: leakedTokens, 159 | LeakDuration: leakDuration, 160 | PollDuration: pollDuration, 161 | BucketMaxAge: bucketMaxAge, 162 | } 163 | 164 | messageLimitingParams := &rateLimiting.MapParams{ 165 | Capacity: 1, 166 | LeakedTokens: 1, 167 | LeakDuration: 2 * time.Second, 168 | PollDuration: pollDuration, 169 | BucketMaxAge: bucketMaxAge, 170 | } 171 | 172 | // Time to periodically check for old objects in storage 173 | viper.SetDefault("cleanupInterval", cleanupIntervalDefault) 174 | cleanupInterval := viper.GetDuration("cleanupInterval") 175 | 176 | // Obtain database connection info 177 | rawAddr := viper.GetString("dbAddress") 178 | var addr, port string 179 | if rawAddr != "" { 180 | addr, port, err = net.SplitHostPort(rawAddr) 181 | if err != nil { 182 | jww.FATAL.Panicf("Unable to get database port from %s: %+v", rawAddr, err) 183 | } 184 | } 185 | 186 | // Authorizer address 187 | authorizerAddressKey := "authorizerAddress" 188 | viper.SetDefault(authorizerAddressKey, "auth.mainnet.cmix.rip:11420") 189 | authorizerAddress := viper.GetString(authorizerAddressKey) 190 | 191 | autocertTimeoutKey := "autocertIssueTimeout" 192 | viper.SetDefault(autocertTimeoutKey, time.Hour) 193 | autocertTimeout := viper.GetDuration(autocertTimeoutKey) 194 | 195 | certReplaceWindowKey := "certReplaceWindow" 196 | viper.SetDefault(certReplaceWindowKey, time.Duration(30*24*time.Hour)) 197 | certReplaceWindow := viper.GetDuration(certReplaceWindowKey) 198 | 199 | maxCertReplaceDelayKey := "maxCertReplaceDelay" 200 | viper.SetDefault(maxCertReplaceDelayKey, time.Duration(5*24*time.Hour)) 201 | maxCertReplaceDelay := viper.GetDuration(maxCertReplaceDelayKey) 202 | 203 | minRegisteredNodesKey := "minRegisteredNodes" 204 | defaultMinRegisteredNodes := 3 205 | viper.SetDefault(minRegisteredNodesKey, defaultMinRegisteredNodes) 206 | minRegisteredNodes := 
viper.GetInt(minRegisteredNodesKey) 207 | 208 | if minRegisteredNodes != defaultMinRegisteredNodes { 209 | jww.ERROR.Printf("WARNING: MINIMUM REGISTERED NODES HAS BEEN "+ 210 | "CHANGED FROM DEFAULT (%d) TO %d", 211 | defaultMinRegisteredNodes, minRegisteredNodes) 212 | } 213 | 214 | return Params{ 215 | Port: gwPort, 216 | PublicAddress: gwAddress, 217 | ListeningAddress: listeningAddress, 218 | NodeAddress: nodeAddress, 219 | CertPath: certPath, 220 | KeyPath: keyPath, 221 | ServerCertPath: serverCertPath, 222 | IDFPath: idfPath, 223 | PermissioningCertPath: permissioningCertPath, 224 | gossipFlags: gossipFlags, 225 | rateLimitParams: bucketMapParams, 226 | messageRateLimitParams: messageLimitingParams, 227 | DbName: viper.GetString("dbName"), 228 | DbUsername: viper.GetString("dbUsername"), 229 | DbPassword: viper.GetString("dbPassword"), 230 | DbAddress: addr, 231 | DbPort: port, 232 | DevMode: viper.GetBool("devMode"), 233 | DisableGossip: viper.GetBool("disableGossip"), 234 | cleanupInterval: cleanupInterval, 235 | AuthorizerAddress: authorizerAddress, 236 | AutocertIssueTimeout: autocertTimeout, 237 | CertReplaceWindow: certReplaceWindow, 238 | MaxCertReplaceDelay: maxCertReplaceDelay, 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # elixxir/gateway 2 | 3 | [![pipeline status](https://gitlab.com/elixxir/gateway/badges/master/pipeline.svg)](https://gitlab.com/elixxir/gateway/commits/master) 4 | [![coverage report](https://gitlab.com/elixxir/gateway/badges/master/coverage.svg)](https://gitlab.com/elixxir/gateway/commits/master) 5 | 6 | ## Purpose 7 | 8 | Gateways are go-betweens for the servers and clients. They retain messages that 9 | have gone through the network for clients to fetch at their leisure, and send 10 | batches of unprocessed messages to the server team that will process them. 
11 | 12 | Gateways are likely to acquire additional functions in the future, including 13 | load balancing and DDoS protection, and connecting to more than one node at 14 | a time. 15 | 16 | ## Running a Gateway 17 | 18 | To run the gateway: 19 | 20 | ``` 21 | go run main.go --config [configuration-file] 22 | ``` 23 | 24 | ## Example configuration file 25 | 26 | The Gateway configuration file must be named `gateway.yaml` and be located in 27 | one of the following directories: 28 | 1. `$HOME/.xxnetwork/` 29 | 2. `/opt/xxnetwork/` 30 | 3. `/etc/xxnetwork/` 31 | 32 | Gateway searches for the YAML file in that order and uses the first occurrence 33 | found. 34 | 35 | Note: YAML prohibits the use of tabs because whitespace has meaning. 36 | 37 | ```yaml 38 | # Level of debugging to print (0 = info, 1 = debug, >1 = trace). (Default info) 39 | logLevel: 1 40 | 41 | # Path where log file will be saved. (Default "log/gateway.log") 42 | log: "/opt/xxnetwork/log/gateway.log" 43 | 44 | # Port for Gateway to listen on. Gateway must be the only listener on this port. 45 | # (Required) 46 | port: 22840 47 | 48 | # Local IP address of the Gateway, used for internal listening. Expects an IPv4 49 | # address without a port. (Default "0.0.0.0") 50 | listeningAddress: "" 51 | 52 | # The public IPv4 address of the Gateway, as reported to the network. When not 53 | # set, external IP address lookup services are used to set this value. If a 54 | # port is not included, then the port from the port flag is used instead. 55 | overridePublicIP: "" 56 | 57 | # The IP address of the machine running cMix that the Gateway communicates with. 58 | # Expects an IPv4 address with a port. (Required) 59 | cmixAddress: "0.0.0.0:11420" 60 | 61 | # Path to where the identity file (IDF) is saved. The IDF stores the Gateway's 62 | # network identity. This is used by the wrapper management script.
(Required) 63 | idfPath: "/opt/xxnetwork/cred/gateway-IDF.json" 64 | 65 | # Path to the private key associated with the self-signed TLS certificate. 66 | # (Required) 67 | keyPath: "/opt/xxnetwork/cred/gateway-key.key" 68 | 69 | # Path to the self-signed TLS certificate for Gateway. Expects PEM format. 70 | # (Required) 71 | certPath: "/opt/xxnetwork/cred/gateway-cert.crt" 72 | 73 | # Path to the self-signed TLS certificate for cMix. Expects PEM format. 74 | # (Required) 75 | cmixCertPath: "/opt/xxnetwork/cred/cmix-cert.crt" 76 | 77 | # Path to the self-signed TLS certificate for the Scheduling server. Expects 78 | # PEM format. (Required) 79 | schedulingCertPath: "/opt/xxnetwork/cred/scheduling-cert.crt" 80 | 81 | # Database connection information. (Required) 82 | dbName: "cmix_gateway" 83 | dbAddress: "0.0.0.0:5432" 84 | dbUsername: "cmix" 85 | dbPassword: "" 86 | 87 | # Flags listed below should be left as their defaults unless you know what you 88 | # are doing. 89 | 90 | # How often the periodic storage tracker checks for items older than the 91 | # retention period value. Expects duration in "s", "m", "h". (Defaults to 5 92 | # minutes) 93 | cleanupInterval: 5m 94 | 95 | # Flags for gossip protocol 96 | 97 | # How long a message record should last in the gossip buffer if it arrives 98 | # before the Gateway starts handling the gossip. (Default 300s) 99 | bufferExpiration: 300s 100 | 101 | # Frequency with which to check the gossip buffer. Should be long, since the 102 | # thread takes a lock each time it checks the buffer. (Default 150s) 103 | monitorThreadFrequency: 150s 104 | 105 | # Flags for rate limiting communications 106 | 107 | # The capacity of rate limiting buckets in the map. (Default 20) 108 | capacity: 20 109 | 110 | # The rate that the rate limiting bucket leaks tokens at [tokens/ns]. (Default 3) 111 | leakedTokens: 3 112 | 113 | # How often the number of leaked tokens is leaked from the bucket. 
(Default 1ms) 114 | leakDuration: 1ms 115 | 116 | # How often inactive buckets are removed. (Default 10s) 117 | pollDuration: 10s 118 | 119 | # The max age of a bucket without activity before it is removed. (Default 10s) 120 | bucketMaxAge: 10s 121 | 122 | # time.Duration used to calculate lower bound of when to replace TLS cert, based on its expiry (default 30 days) 123 | certReplaceWindow: 720h 124 | 125 | # Maximum random delay for cert replacement after reaching the start of CertReplaceWindow (default 5 days) 126 | maxCertReplaceDelay: 120h 127 | ``` 128 | 129 | ## Command line flags 130 | 131 | The command line flags for the gateway can be listed with `--help` as follows: 132 | 133 | 134 | ``` 135 | % go run main.go --help 136 | The cMix gateways coordinate communications between servers and clients 137 | 138 | Usage: 139 | gateway [flags] 140 | gateway [command] 141 | 142 | Available Commands: 143 | autocert automatic cert request test command 144 | generate Generates version and dependency information for the xx network binary 145 | help Help about any command 146 | version Print the version and dependency information for the xx network binary 147 | 148 | Flags: 149 | --bucketMaxAge duration The max age of a bucket without activity before it is removed. (default 10s) 150 | --bufferExpiration duration How long a message record should last in the gossip buffer if it arrives before the Gateway starts handling the gossip. (default 5m0s) 151 | --capacity uint32 The capacity of rate-limiting buckets in the map. (default 20) 152 | --certPath string Path to the self-signed TLS certificate for Gateway. Expects PEM format. (Required) 153 | --cmixAddress string The IP address of the machine running cMix that the Gateway communicates with. Expects an IPv4 address with a port. (Required) 154 | --cmixCertPath string Path to the self-signed TLS certificate for cMix. Expects PEM format. (Required) 155 | -c, --config string Path to load the Gateway configuration file from.
(Required) 156 | --enableGossip Feature flag for in progress gossip functionality 157 | -h, --help help for gateway 158 | --idfPath string Path to where the identity file (IDF) is saved. The IDF stores the Gateway's Node's network identity. This is used by the wrapper management script. (Required) 159 | --keyPath string Path to the private key associated with the self-signed TLS certificate. (Required) 160 | --kr int Amount of rounds to keep track of in kr (default 1024) 161 | --leakDuration duration How often the number of leaked tokens is leaked from the bucket. (default 1ms) 162 | --leakedTokens uint32 The rate that the rate limiting bucket leaks tokens at [tokens/ns]. (default 3) 163 | --log string Path where log file will be saved. (default "log/gateway.log") 164 | -l, --logLevel uint Level of debugging to print (0 = info, 1 = debug, >1 = trace). 165 | --monitorThreadFrequency duration Frequency with which to check the gossip buffer. (default 2m30s) 166 | --pollDuration duration How often inactive buckets are removed. (default 10s) 167 | -p, --port int Port for Gateway to listen on.Gateway must be the only listener on this port. (Required) (default -1) 168 | --profile-cpu string Enable cpu profiling to this file 169 | --schedulingCertPath string Path to the self-signed TLS certificate for the Scheduling server. Expects PEM format. (Required) 170 | 171 | Use "gateway [command] --help" for more information about a command. 172 | ``` 173 | 174 | All of those flags, except `--config`, override values in the configuration 175 | file. 176 | 177 | The `version` subcommand prints the version: 178 | 179 | 180 | ``` 181 | $ go run main.go version 182 | Elixxir Gateway v1.1.0 -- 426617f Fix MessageTimeout, change localAddress to listeningAddress and mark hidden, and change example nodeAddress 183 | 184 | Dependencies: 185 | 186 | module gitlab.com/elixxir/gateway 187 | 188 | go 1.13 189 | ... 
190 | ``` 191 | 192 | The `generate` subcommand is used for updating version information (see the 193 | next section). 194 | 195 | ## Updating Version Info 196 | ``` 197 | $ go run main.go generate 198 | $ mv version_vars.go cmd 199 | ``` 200 | 201 | The `autocert` subcommand should not be used. It allows you to make a 202 | ZeroSSL certificate request and returns DNS settings which can only be 203 | set by the xx network technical team. This subcommand is used for testing only. 204 | 205 | ## Project Structure 206 | 207 | 208 | `cmd` handles command line flags and all gateway logic. 209 | 210 | `notifications` handles notification logic used to push alerts to clients. 211 | 212 | `storage` contains the database and RAM-based storage implementations. 213 | 214 | ## Compiling the Binary 215 | 216 | To compile a binary that will run the server on your platform, 217 | you will need to run one of the commands in the following sections. 218 | The `.gitlab-ci.yml` file also contains cross-build instructions 219 | for all of these platforms. 220 | 221 | 222 | ### Linux 223 | 224 | ``` 225 | GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o gateway main.go 226 | ``` 227 | 228 | ### Windows 229 | 230 | ``` 231 | GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o gateway main.go 232 | ``` 233 | 234 | or 235 | 236 | ``` 237 | GOOS=windows GOARCH=386 CGO_ENABLED=0 go build -ldflags '-w -s' -o gateway main.go 238 | ``` 239 | 240 | for a 32-bit version.
241 | 242 | ### Mac OSX 243 | 244 | ``` 245 | GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o gateway main.go 246 | ``` 247 | -------------------------------------------------------------------------------- /autocert/dns_test.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | package autocert 9 | 10 | import ( 11 | "bytes" 12 | "crypto" 13 | "encoding/pem" 14 | "fmt" 15 | "math/rand" 16 | "os" 17 | "testing" 18 | "time" 19 | 20 | "golang.org/x/crypto/acme" 21 | "golang.org/x/net/context" 22 | ) 23 | // TestMain installs dummyACMEClient through dnsClientObj so that no test contacts a real ACME server. 24 | func TestMain(m *testing.M) { 25 | rand.Seed(42) // consistent answers, more or less 26 | dnsClientObj = func() acmeClient { 27 | return &dummyACMEClient{} 28 | } 29 | os.Exit(m.Run()) 30 | } 31 | // TestGenerateCertTestGenerate checks that GenerateCertKey returns a 4096-bit key and fails on a short-reading RNG. 32 | func TestGenerateCertTestGenerate(t *testing.T) { 33 | // Regular Gen 34 | rng1 := &dummyRNG{} 35 | key, err := GenerateCertKey(rng1) 36 | if err != nil { 37 | t.Fatalf("%+v", err) // key is nil on error; continuing would panic in key.Size() 38 | } 39 | 40 | if key.Size() != 4096/8 { 41 | t.Errorf("bad key size, expected 4096 bits, got %d", key.Size()*8) 42 | } 43 | 44 | // Short read error 45 | // NOTE: The RSA generate function will happily loop forever 46 | // if the RNG just does an implied short read, it has to return an 47 | // error! 48 | rng2 := &shortRNG{} 49 | key, err = GenerateCertKey(rng2) 50 | if key != nil || err == nil { 51 | t.Errorf("expected failure, got success") 52 | } 53 | } 54 | 55 | // TestNewAndLoad is a very basic smoke test that creates an object and then 56 | // "loads" the same object; it only checks that there are no crashes when 57 | // doing this and that keys get loaded properly via PEM (same PEM key and a 58 | // non-nil internal ACME key on the loaded client).
59 | func TestNewAndLoad(t *testing.T) { 60 | n := NewDNS() 61 | actualN := n.(*dnsClient) 62 | 63 | rng1 := &dummyRNG{} 64 | key, err := GenerateCertKey(rng1) 65 | if err != nil { 66 | t.Fatalf("%+v", err) 67 | } 68 | actualN.SetKey(key.GetGoRSA()) 69 | actualN.PrivateKey = key 70 | 71 | n2, err := LoadDNS(key.MarshalPem()) 72 | if err != nil { 73 | t.Fatalf("%+v", err) 74 | } 75 | actualN2 := n2.(*dnsClient) 76 | 77 | if !bytes.Equal(actualN.PrivateKey.MarshalPem(), 78 | actualN2.PrivateKey.MarshalPem()) { 79 | t.Errorf("keys do not match: %v\n%v", 80 | actualN.PrivateKey.MarshalPem(), 81 | actualN2.PrivateKey.MarshalPem()) 82 | } 83 | if actualN2.GetKey() == nil { 84 | t.Errorf("internal key not set") 85 | } 86 | } 87 | 88 | // TestRegister is a smoke test for account registration. 89 | func TestRegister(t *testing.T) { 90 | n := NewDNS() 91 | rng1 := &dummyRNG{} 92 | key, err := GenerateCertKey(rng1) 93 | if err != nil { 94 | t.Fatalf("%+v", err) 95 | } 96 | 97 | // "dGVzdHN0cmluZwo" is unpadded base64url for "teststring\n" 98 | err = n.Register(key, "eabKeyID", "dGVzdHN0cmluZwo", 99 | "email@example.com") 100 | if err != nil { 101 | t.Fatalf("%+v", err) 102 | } 103 | actualN := n.(*dnsClient) 104 | if actualN.GetKey() == nil { 105 | t.Errorf("internal key not set") 106 | } 107 | if !bytes.Equal(actualN.PrivateKey.MarshalPem(), 108 | key.MarshalPem()) { 109 | t.Errorf("key not set, byte mismatch:\n%v\n%v\n", 110 | key.MarshalPem(), actualN.PrivateKey.MarshalPem()) 111 | } 112 | } 113 | // TestRequest checks the TXT record returned by Request and the order state it stores on the client. 114 | func TestRequest(t *testing.T) { 115 | n := NewDNS() 116 | rng1 := &dummyRNG{} 117 | key, err := GenerateCertKey(rng1) 118 | if err != nil { 119 | t.Fatalf("%+v", err) 120 | } 121 | 122 | // "dGVzdHN0cmluZwo" is unpadded base64url for "teststring\n" 123 | err = n.Register(key, "eabKeyID", "dGVzdHN0cmluZwo", 124 | "email@example.com") 125 | if err != nil { 126 | t.Fatalf("%+v", err) 127 | } 128 | k, v, err := n.Request("example.com") 129 | if err != nil { 130 | t.Fatalf("%+v", err) 131 | } 132 | expK := "_acme-challenge.example.com" 133 | if
k != expK { 134 | t.Errorf("unexpected key: expect %s, got %s", expK, k) 135 | } 136 | if v != "dnsChallengeString" { 137 | t.Errorf("unexpect chal: expected dnsChallengeString got %s", 138 | v) 139 | } 140 | 141 | actualN := n.(*dnsClient) 142 | if actualN.GetKey() == nil { 143 | t.Errorf("internal key not set") 144 | } 145 | if !bytes.Equal(actualN.PrivateKey.MarshalPem(), 146 | key.MarshalPem()) { 147 | t.Errorf("key not set, byte mismatch:\n%v\n%v\n", 148 | key.MarshalPem(), actualN.PrivateKey.MarshalPem()) 149 | } 150 | if actualN.AuthzURL != "url1" { 151 | t.Errorf("bad AuthzURL: %s", actualN.AuthzURL) 152 | } 153 | if actualN.Domain != "example.com" { 154 | t.Errorf("bad domain: %s", actualN.Domain) 155 | } 156 | if actualN.AuthzFinalizeURL != "finalme" { 157 | t.Errorf("bad final URL: %s", actualN.AuthzFinalizeURL) 158 | } 159 | } 160 | // TestCSR checks the PEM header and the leading DER bytes of a CSR built with the seeded test RNG. 161 | func TestCSR(t *testing.T) { 162 | n := NewDNS() 163 | rng1 := &dummyRNG{} 164 | key, err := GenerateCertKey(rng1) 165 | if err != nil { 166 | t.Fatalf("%+v", err) 167 | } 168 | 169 | // "dGVzdHN0cmluZwo" is unpadded base64url for "teststring\n" 170 | err = n.Register(key, "eabKeyID", "dGVzdHN0cmluZwo", 171 | "email@example.com") 172 | if err != nil { 173 | t.Fatalf("%+v", err) 174 | } 175 | 176 | csrPEM, csrDER, err := n.CreateCSR("example.com", "test@example.com", 177 | "USA", "nodeID", rng1) 178 | if err != nil { 179 | t.Fatalf("%+v", err) // the CSR is unusable on error; slicing it below could panic 180 | } 181 | 182 | expPEM := []byte{45, 45, 45, 45, 45, 66, 69, 71, 73, 78, 32, 67, 69, 183 | 82, 84, 73, 70, 73, 67, 65} 184 | if !bytes.Equal(expPEM, csrPEM[:len(expPEM)]) { 185 | t.Errorf("bad pem:\n\t%v\n\t%v", expPEM, csrPEM[:len(expPEM)]) 186 | } 187 | expDER := []byte{48, 130, 4, 143, 48, 130, 2, 119, 2, 1, 0, 48, 74, 188 | 49, 12, 48, 10, 6} 189 | if !bytes.Equal(expDER, csrDER[:len(expDER)]) { 190 | t.Errorf("bad der:\n\t%v\n\t%v", expDER, csrDER[:len(expDER)]) 191 | } 192 | } 193 | // TestIssue runs Register, Request, CreateCSR and Issue end to end against the dummy ACME client. 194 | func TestIssue(t *testing.T) { 195 | n := NewDNS() 196 | rng1 := &dummyRNG{} 197 | key, err
:= GenerateCertKey(rng1) 198 | if err != nil { 199 | t.Fatalf("%+v", err) 200 | } 201 | 202 | // "dGVzdHN0cmluZwo" is unpadded base64url for "teststring\n" 203 | err = n.Register(key, "eabKeyID", "dGVzdHN0cmluZwo", 204 | "admins@elixxir.io") 205 | if err != nil { 206 | t.Fatalf("%+v", err) 207 | } 208 | _, _, err = n.Request("rick1.xxn2.work") 209 | if err != nil { 210 | t.Fatalf("%+v", err) 211 | } 212 | 213 | _, csrDER, err := n.CreateCSR("rick1.xxn2.work", "admins@elixxir.io", 214 | "USA", "nodeID", rng1) 215 | if err != nil { 216 | t.Fatalf("%+v", err) 217 | } 218 | 219 | _, _, err = n.Issue(csrDER, time.Minute) 220 | if err != nil { 221 | t.Errorf("%+v", err) 222 | } 223 | 224 | } 225 | // dummyACMEClient is a canned-response acmeClient (installed by TestMain) that never contacts a real ACME server. 226 | type dummyACMEClient struct { 227 | Key crypto.Signer 228 | URL string 229 | } 230 | 231 | func (a *dummyACMEClient) GetDirectoryURL() string { 232 | return a.URL 233 | } 234 | func (a *dummyACMEClient) SetDirectoryURL(d string) { 235 | a.URL = d 236 | } 237 | func (a *dummyACMEClient) GetKey() crypto.Signer { 238 | return a.Key 239 | } 240 | func (a *dummyACMEClient) SetKey(k crypto.Signer) { 241 | a.Key = k 242 | } 243 | func (a *dummyACMEClient) GetReg(ctx context.Context, 244 | x string) (*acme.Account, error) { 245 | return &acme.Account{}, nil 246 | } 247 | func (a *dummyACMEClient) Register(ctx context.Context, acct *acme.Account, 248 | tosFn func(tosURL string) bool) (*acme.Account, error) { 249 | return a.GetReg(ctx, "") 250 | } 251 | func (a *dummyACMEClient) DNS01ChallengeRecord(token string) (string, error) { 252 | return "dnsChallengeString", nil 253 | } 254 | func (a *dummyACMEClient) AuthorizeOrder(ctx context.Context, 255 | authzIDs []acme.AuthzID) (*acme.Order, error) { 256 | return &acme.Order{ 257 | AuthzURLs: []string{ 258 | "url1", 259 | }, 260 | FinalizeURL: "finalme", 261 | }, nil 262 | } 263 | func (a *dummyACMEClient) CreateOrderCert(ctx context.Context, 264 | finalURL string, csr []byte, ty bool) ([][]byte, string, error) { 265 | validCert := `-----BEGIN CERTIFICATE-----
266 | MIIHbjCCBVagAwIBAgIQSM+Uht+j3A/bOj7JPRBDrzANBgkqhkiG9w0BAQwFADBL 267 | MQswCQYDVQQGEwJBVDEQMA4GA1UEChMHWmVyb1NTTDEqMCgGA1UEAxMhWmVyb1NT 268 | TCBSU0EgRG9tYWluIFNlY3VyZSBTaXRlIENBMB4XDTIyMTEwNTAwMDAwMFoXDTIz 269 | MDIwMzIzNTk1OVowGjEYMBYGA1UEAxMPcmljazEueHhuMi53b3JrMIICIjANBgkq 270 | hkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAxn0SZaCiduE3RINmtOI2GsmO4jGputtF 271 | RhV9c8SwbnHZvVsZGJpcF0Zp/ONhSbIiPJEzVoNUylwTNp4JFr1ePkigLOrR4akJ 272 | ovmAE1GzhNuBvgyxjS05e/hgGzDno186r9WaxcDnuMqaXF3syGEGoyDkB4hSJmY3 273 | YOya8KrxZNHw7IDGOIjoyWCz8XzBvXR5KGztSgixG28IYkB+wAyfHj9IrRf0/T3Y 274 | u6CYbBgwD5JLEG2Mn5WLbDfcIDubNxB4ZrxoFKUfZGfoankrO3C7cPjn+JgC8ZZz 275 | kNN1KlD1ZqkNDPI077UvVhhBhbPQehYcEbCwLBSoXQRvVn7Ij4OROIqgTJI5wBIo 276 | FrmYfgbI4xjEgYBfLS3u+FUvM7KrCh9IaFC5gtJSxsqGHkxsAYYEyXNqys/yvCEb 277 | fCBMH08zmoHhqLbvQNIrCY+L31C4jcvS/FaqGiWjusljUW5/1FKi8SnHCWrneAxN 278 | oKlsWHTYbgppf9N42Yy25w/Yp4n7cpUoFmCio6L9KLsZsRcWGuX9kzdcklqVOsuF 279 | Ej+XnOdAVujXbgAkvtlyxlcQky1n9PZ2jxOnyQQ22+9PuGVz9DMI6MIrkScSfWYO 280 | BDN6Lb6O3CZ0DRl3h1c7KUlCssoNcqi67G7vpSuRYMQ7TvSaxZIpf4pIepufrUiV 281 | FXJSdrupTwMCAwEAAaOCAn0wggJ5MB8GA1UdIwQYMBaAFMjZeGii2Rlo1T1y3l8K 282 | Pty1hoamMB0GA1UdDgQWBBTvyeRInS6KVVyO9eK82Fy8Jl2o7TAOBgNVHQ8BAf8E 283 | BAMCBaAwDAYDVR0TAQH/BAIwADAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH 284 | AwIwSQYDVR0gBEIwQDA0BgsrBgEEAbIxAQICTjAlMCMGCCsGAQUFBwIBFhdodHRw 285 | czovL3NlY3RpZ28uY29tL0NQUzAIBgZngQwBAgEwgYgGCCsGAQUFBwEBBHwwejBL 286 | BggrBgEFBQcwAoY/aHR0cDovL3plcm9zc2wuY3J0LnNlY3RpZ28uY29tL1plcm9T 287 | U0xSU0FEb21haW5TZWN1cmVTaXRlQ0EuY3J0MCsGCCsGAQUFBzABhh9odHRwOi8v 288 | emVyb3NzbC5vY3NwLnNlY3RpZ28uY29tMIIBBgYKKwYBBAHWeQIEAgSB9wSB9ADy 289 | AHcArfe++nz/EMiLnT2cHj4YarRnKV3PsQwkyoWGNOvcgooAAAGERYbgFQAABAMA 290 | SDBGAiEA/d+USH39vf8RvckdgiB+a+NyDEb/7xG4VIBZPYIXLQACIQD4XySXos55 291 | o0wyE3XTzylkqoyupOaLtGpSkbDZlIKp2QB3AHoyjFTYty22IOo44FIe6YQWcDIT 292 | hU070ivBOlejUutSAAABhEWG3+MAAAQDAEgwRgIhAOBKcYLCr3EhrTqyywMvCA3Z 293 | OeYesLfOdpJJJgkGB9aVAiEA36DwB++0dEzJf6JiXsXpuUfjZ4KgtDWufl+FYNd3 294 | 
LlgwGgYDVR0RBBMwEYIPcmljazEueHhuMi53b3JrMA0GCSqGSIb3DQEBDAUAA4IC 295 | AQAon7p8PMR/dbGBzqG0zoNotB+BxqdJV1VtJk0XYKQl/i03fBr4fzvhnBEFvKe4 296 | qjk2YUqB52muLl64n0o0yI96p74F5j+X1xbuX1JwT1hA3udGRNrS4vkFTUV/ymhe 297 | CPWnZaSpIudtb5uO/nNCZ/+794NHzmPHbC4oTo83wRFoxZst50jvC2a5E9ewY41h 298 | uZ0la/n6Q+2/F9CBYYPvLeLqPmRco9QhXl6CDndvetwNkOXh75Kt517Xu97TcdK2 299 | S+UdpRcZWonxosNxLEFtu3otRmxzT/3PhMit3GqQxw3gFzqjNIgwmumF2buSGZeH 300 | SVBh+HkqxyE286UtHanCPrYv4ev4WjV05PvRbwSZyx/d//Lvmo9lheXzKN4HVVbh 301 | wZX1HDu+mkmeHquDAQYT3Wbn0f5rmXaCbCBtAOLDpJs8jEu/9Xsb8R9IgsBdUkeL 302 | tPt8haNAF82yYKlcBFJhPsQZSSabMy1Ew6gXqDCLC+YyxY5ASuLk4fOBMqfGuVDS 303 | 9Po+PFUvdas/yypuMauZWazj2kXki06xGylTb5pDgResTuGMQ4X7+9COzoGLNedt 304 | iz2njhHe7DWAQiOFOllO1mCUvzJ0oyYF4C9YHIT4j+Mzq31mmpLAcv/oI10Yaco9 305 | ha+P46sxUx4cDvtzq29TynRlbNAXj27yripo/2Azn6IG/Q== 306 | -----END CERTIFICATE-----` 307 | pemBlk, _ := pem.Decode([]byte(validCert)) // assumes validCert stays a well-formed PEM block; otherwise pemBlk is nil 308 | der := pemBlk.Bytes 309 | return [][]byte{der}, "certURL", nil 310 | } 311 | func (a *dummyACMEClient) GetAuthorization(ctx context.Context, 312 | authzURL string) (*acme.Authorization, error) { 313 | return &acme.Authorization{ 314 | Challenges: []*acme.Challenge{ 315 | { 316 | Type: "dns-01", 317 | }, 318 | }, 319 | }, nil 320 | } 321 | func (a *dummyACMEClient) Accept(ctx context.Context, 322 | chal *acme.Challenge) (*acme.Challenge, error) { 323 | return chal, nil 324 | } 325 | func (a *dummyACMEClient) WaitAuthorization(ctx context.Context, 326 | authzURL string) (*acme.Authorization, error) { 327 | return &acme.Authorization{ 328 | Status: acme.StatusValid, 329 | }, nil 330 | } 331 | // dummyRNG reads from math/rand (seeded in TestMain); it is not cryptographically secure and is for tests only. 332 | type dummyRNG struct{} 333 | 334 | func (z *dummyRNG) Read(b []byte) (int, error) { 335 | return rand.Read(b) 336 | } 337 | // shortRNG fills only the first half of each buffer and always returns an error, exercising GenerateCertKey's failure path. 338 | type shortRNG struct{} 339 | 340 | func (z *shortRNG) Read(b []byte) (int, error) { 341 | k, _ := rand.Read(b[0 : len(b)/2]) 342 | return k, fmt.Errorf("short read") 343 | } 344 |
-------------------------------------------------------------------------------- /autocert/dns.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Package autocert requests an ACME certificate using EAB credentials 9 | // and provides helper functions to wait until the certificate is issued. 10 | package autocert 11 | 12 | import ( 13 | "crypto" 14 | "crypto/x509" 15 | "crypto/x509/pkix" 16 | "encoding/base64" 17 | "encoding/pem" 18 | "fmt" 19 | "io" 20 | "time" 21 | 22 | "github.com/pkg/errors" 23 | 24 | jww "github.com/spf13/jwalterweatherman" 25 | "gitlab.com/elixxir/crypto/rsa" 26 | "golang.org/x/crypto/acme" 27 | "golang.org/x/net/context" 28 | ) 29 | // ZeroSSLACMEURL is the ACME directory URL used by the default acmeClient built in dnsClientObj. 30 | const ZeroSSLACMEURL = "https://acme.zerossl.com/v2/DV90" 31 | const TimedOutWaitingErr = "Timed out waiting for authorization" // message of the error waitForAuthorization returns on timeout 32 | // dnsClient implements Client over an acmeClient. Request records the order state (AuthzURL, AuthzFinalizeURL, Domain) that Issue consumes; PrivateKey both signs the CSR and is returned by Issue as the certificate key. 33 | type dnsClient struct { 34 | acmeClient 35 | AuthzURL string 36 | AuthzFinalizeURL string 37 | Domain string 38 | PrivateKey rsa.PrivateKey 39 | } 40 | // acmeClientImpl adapts *acme.Client to the acmeClient interface. 41 | type acmeClientImpl struct { 42 | *acme.Client 43 | } 44 | 45 | // dnsClientObj builds the acmeClient used by NewDNS and LoadDNS; tests replace it to mock the ACME server. 46 | var dnsClientObj = func() acmeClient { 47 | return &acmeClientImpl{ 48 | Client: &acme.Client{ 49 | DirectoryURL: ZeroSSLACMEURL, 50 | }, 51 | } 52 | } 53 | 54 | // GenerateCertKey generates a 4096-bit RSA private key from csprng that can be used in 55 | // the certificate request.
56 | func GenerateCertKey(csprng io.Reader) (rsa.PrivateKey, error) { 57 | pKey, err := rsa.GetScheme().Generate(csprng, 4096) 58 | if err != nil { 59 | return nil, err 60 | } 61 | return pKey, nil 62 | } 63 | 64 | // NewDNS creates a new, empty DNS Client whose ACME backend comes from dnsClientObj. 65 | func NewDNS() Client { 66 | return &dnsClient{ 67 | acmeClient: dnsClientObj(), 68 | AuthzURL: "", 69 | AuthzFinalizeURL: "", 70 | Domain: "", 71 | } 72 | } 73 | 74 | // LoadDNS recreates a DNS client from a PEM-encoded private key and looks up its ACME account; if only the account lookup fails, the client is still returned together with that error. 75 | func LoadDNS(privateKeyPEM []byte) (Client, error) { 76 | d := &dnsClient{ 77 | acmeClient: dnsClientObj(), 78 | AuthzURL: "", 79 | AuthzFinalizeURL: "", 80 | Domain: "", 81 | } 82 | privateKey, err := rsa.GetScheme().UnmarshalPrivateKeyPEM(privateKeyPEM) 83 | if err != nil { 84 | return nil, err 85 | } 86 | // The same key signs ACME requests and is later exported by Issue. 87 | d.SetKey(privateKey.GetGoRSA()) 88 | d.PrivateKey = privateKey 89 | 90 | ctx, cancelFn := getDefaultContext() 91 | defer cancelFn() 92 | acct, err := d.GetReg(ctx, "") 93 | if err == nil { 94 | jww.DEBUG.Printf("looked up acct: %v", acct) 95 | } 96 | return d, err 97 | } 98 | // Register creates an ACME account for privateKey using external account binding (eabKey must be unpadded base64url) with email as the contact. The key is installed for ACME signing immediately but stored in d.PrivateKey only after registration succeeds. 99 | func (d *dnsClient) Register(privateKey rsa.PrivateKey, 100 | eabKeyID, eabKey, email string) error { 101 | // Let's rule out dumb mistakes and decode/create the external account 102 | // binding first. 103 | eabHMAC, err := base64.RawURLEncoding.DecodeString(eabKey) 104 | if err != nil { 105 | return err 106 | } 107 | 108 | d.SetKey(privateKey.GetGoRSA()) 109 | 110 | acctReq := &acme.Account{ 111 | ExternalAccountBinding: &acme.ExternalAccountBinding{ 112 | KID: eabKeyID, 113 | Key: eabHMAC, 114 | }, 115 | Contact: []string{fmt.Sprintf("mailto:%s", email)}, 116 | } 117 | 118 | // Note: this is wonky, because the account object sent is not modified 119 | // and a new one gets returned.
A review of the internals shows that 120 | // only the ExternalAccountBinding and Contact objects are used 121 | ctx, cancelFn := getDefaultContext() 122 | defer cancelFn() 123 | acct, err := d.acmeClient.Register(ctx, acctReq, acme.AcceptTOS) 124 | if err != nil { 125 | return err 126 | } 127 | 128 | d.PrivateKey = privateKey 129 | 130 | jww.DEBUG.Printf("Account Registered: %v", acct) 131 | return nil 132 | } 133 | // Request starts a DNS-01 order for domain and returns the TXT record name and value to publish, or ("already validated", "none") when an authorization is already valid. 134 | func (d *dnsClient) Request(domain string) (key, value string, err error) { 135 | authzIDs := []acme.AuthzID{ 136 | { 137 | Type: "dns", 138 | Value: domain, 139 | }, 140 | } 141 | 142 | order, err := getAuthOrder(d.acmeClient, authzIDs) 143 | if err != nil { 144 | jww.ERROR.Printf("Authorize failed: %+v", err) 145 | return "", "", err 146 | } 147 | jww.DEBUG.Printf("Order Returned: %v", order) 148 | 149 | dns01, authzURL, err := getDNSChallenge(d.acmeClient, order) 150 | if err != nil { 151 | jww.ERROR.Printf("DNS challenge failed: %+v", err) 152 | return "", "", err 153 | } 154 | 155 | d.AuthzURL = authzURL 156 | d.Domain = domain 157 | d.AuthzFinalizeURL = order.FinalizeURL 158 | 159 | if dns01 == nil { 160 | return "already validated", "none", nil 161 | } 162 | 163 | jww.DEBUG.Printf("DNS Challenge: %v", dns01) 164 | // NOTE(review): the challenge is accepted before the caller has published the TXT record; RFC 8555 expects the record to be in place first — confirm the CA tolerates this. 165 | dns01, err = acceptDNSChallenge(d.acmeClient, dns01) 166 | if err != nil { 167 | jww.ERROR.Printf("accept DNS failed: %+v", err) 168 | return "", "", err 169 | } 170 | 171 | dnsChallenge, err := d.DNS01ChallengeRecord(dns01.Token) 172 | if err != nil { 173 | jww.ERROR.Printf("DNS token challenge failed: %+v", err) 174 | return "", "", err 175 | } 176 | 177 | key = fmt.Sprintf("_acme-challenge.%s", domain) 178 | value = dnsChallenge 179 | 180 | jww.DEBUG.Printf("TXT Record:\n%s\t%s", key, value) 181 | 182 | return key, value, nil 183 | } 184 | // CreateCSR builds a PKCS#10 request for domain, signed with the client's private key, and returns it in both PEM and DER form; any signing error is returned with nil outputs. 185 | func (d *dnsClient) CreateCSR(domain, email, country, nodeID string, 186 | rng io.Reader) (csrPEM, csrDER []byte, err error) { 187 | // NOTE(review): email is currently not included in the CSR. 188 | subject := pkix.Name{ 189 | Country:
[]string{country}, 190 | Organization: []string{"xx network"}, 191 | OrganizationalUnit: []string{nodeID}, 192 | CommonName: domain, 193 | } 194 | 195 | csrTemplate := &x509.CertificateRequest{ 196 | SignatureAlgorithm: x509.SHA512WithRSA, 197 | PublicKeyAlgorithm: x509.RSA, 198 | PublicKey: d.PrivateKey.GetGoRSA().PublicKey, 199 | Subject: subject, 200 | } 201 | csrDER, err = x509.CreateCertificateRequest( 202 | rng, 203 | csrTemplate, 204 | d.PrivateKey.GetGoRSA(), 205 | ) 206 | if err != nil { 207 | return nil, nil, err 208 | } 209 | csrPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrDER}) 210 | 211 | return csrPEM, csrDER, nil 212 | } 213 | // Issue waits up to timeout for the authorization started by Request to become valid, finalizes the order with the DER-encoded csr, and returns the PEM certificate and the PEM private key. 214 | func (d *dnsClient) Issue(csr []byte, timeout time.Duration) (cert, key []byte, err error) { 215 | if d.AuthzURL == "" { 216 | return nil, nil, errors.Errorf("missing auth, call Request") 217 | } 218 | authz, err := waitForAuthorization(d.acmeClient, d.AuthzURL, timeout) 219 | if err != nil { 220 | return nil, nil, err 221 | } 222 | if authz.Status != acme.StatusValid { 223 | return nil, nil, errors.Errorf("invalid status object: %v", 224 | authz) 225 | } 226 | jww.DEBUG.Printf("Final Auth: %v", authz) 227 | 228 | ctx, cancelFn := getDefaultContext() 229 | defer cancelFn() 230 | der, certURL, err := d.CreateOrderCert(ctx, d.AuthzFinalizeURL, csr, false) 231 | if err != nil { 232 | jww.ERROR.Printf("cannot create cert: %+v", err) 233 | return nil, nil, err 234 | } 235 | if len(der) == 0 { 236 | return nil, nil, errors.Errorf("no certificate returned from %s", certURL) 237 | } 238 | jww.DEBUG.Printf("got cert from %s, parsing...", certURL) 239 | certObj, err := x509.ParseCertificate(der[0]) 240 | if err != nil { 241 | jww.ERROR.Printf("cannot parse cert: %+v", err) 242 | return nil, nil, err 243 | } 244 | 245 | err = certObj.VerifyHostname(d.Domain) 246 | if err != nil { 247 | jww.ERROR.Printf("cannot verify cert hostname: %+v", err) 248 | return nil, nil, err 249 | } 250 | // NOTE(review): only der[0] (presumably the leaf) is encoded; any further certificates in der[1:] are dropped. 251 | cert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", 252 | Bytes: der[0]}) 253 | 254 | return cert,
d.PrivateKey.MarshalPem(), nil 255 | } 256 | 257 | // Internal helper network functions 258 | 259 | func getDefaultContext() (context.Context, context.CancelFunc) { 260 | return context.WithTimeout(context.Background(), 60*time.Second) 261 | } 262 | // getAuthOrder creates a new order for authzIDs, bounded by the 60-second default context. 263 | func getAuthOrder(client acmeClient, 264 | authzIDs []acme.AuthzID) (*acme.Order, error) { 265 | ctx, cancelFn := getDefaultContext() 266 | defer cancelFn() 267 | return client.AuthorizeOrder(ctx, authzIDs) 268 | } 269 | // getDNSChallenge returns the first dns-01 challenge (with its authorization URL) among the order's authorizations, or a nil challenge if an authorization is already valid. 270 | func getDNSChallenge(client acmeClient, order *acme.Order) (*acme.Challenge, 271 | string, error) { 272 | for i := 0; i < len(order.AuthzURLs); i++ { 273 | authzURL := order.AuthzURLs[i] 274 | authz, err := getAuth(client, authzURL) 275 | if err != nil { 276 | jww.WARN.Printf("unable to retrieve authz %s: %+v", authzURL, err) 277 | continue 278 | } 279 | if authz.Status == acme.StatusValid { 280 | return nil, authzURL, nil 281 | } 282 | c := findDNSChallenge(authz.Challenges) 283 | if c != nil { 284 | return c, order.AuthzURLs[i], nil 285 | } 286 | } 287 | return nil, "", errors.Errorf("no dns challenge available") 288 | } 289 | // getAuth fetches the authorization at authzURL. 290 | func getAuth(client acmeClient, authzURL string) (*acme.Authorization, error) { 291 | ctx, cancelFn := getDefaultContext() 292 | defer cancelFn() 293 | authz, err := client.GetAuthorization(ctx, authzURL) 294 | if err != nil { 295 | return nil, err 296 | } 297 | return authz, nil 298 | } 299 | // findDNSChallenge returns the first challenge of type "dns-01", or nil. 300 | func findDNSChallenge(challenges []*acme.Challenge) *acme.Challenge { 301 | for i := 0; i < len(challenges); i++ { 302 | challenge := challenges[i] 303 | jww.DEBUG.Printf("Challenge Type: %s", challenge.Type) 304 | if challenge.Type == "dns-01" { 305 | return challenge 306 | } 307 | } 308 | return nil 309 | } 310 | // acceptDNSChallenge tells the ACME server that the dns-01 challenge is ready to be validated. 311 | func acceptDNSChallenge(client acmeClient, 312 | dns01 *acme.Challenge) (*acme.Challenge, error) { 313 | ctx, cancelFn := getDefaultContext() 314 | defer cancelFn() 315 | return client.Accept(ctx, dns01) 316 | } 317 | // waitForAuthorization calls WaitAuthorization on authzURL (15s per attempt, 30s back-off after a failure) until it succeeds or timeout elapses, returning TimedOutWaitingErr in the latter case; at least one attempt is always made. 318 | func
waitForAuthorization(client acmeClient, 319 | authzURL string, timeout time.Duration) (*acme.Authorization, error) { 320 | to := time.NewTimer(timeout) 321 | defer to.Stop() 322 | for { 323 | ctx, cancelFn := context.WithTimeout(context.Background(), 324 | 15*time.Second) 325 | authz, err := client.WaitAuthorization(ctx, authzURL) 326 | cancelFn() 327 | if err == nil { 328 | return authz, nil 329 | } 330 | jww.WARN.Printf("WaitAuthorization: %s, continuing...", 331 | err.Error()) 332 | // Back off before retrying, but give up as soon as the overall timeout fires instead of sleeping past it. 333 | select { 334 | case <-to.C: 335 | return nil, errors.New(TimedOutWaitingErr) 336 | case <-time.After(30 * time.Second): 337 | } 338 | } 339 | } 340 | 341 | // --- Internal acmeClient implementation 342 | func (a *acmeClientImpl) GetDirectoryURL() string { 343 | return a.Client.DirectoryURL 344 | } 345 | func (a *acmeClientImpl) SetDirectoryURL(d string) { 346 | a.Client.DirectoryURL = d 347 | } 348 | func (a *acmeClientImpl) GetKey() crypto.Signer { 349 | return a.Client.Key 350 | } 351 | func (a *acmeClientImpl) SetKey(k crypto.Signer) { 352 | a.Client.Key = k 353 | } 354 | func (a *acmeClientImpl) GetReg(ctx context.Context, 355 | x string) (*acme.Account, error) { 356 | return a.Client.GetReg(ctx, x) 357 | } 358 | func (a *acmeClientImpl) Register(ctx context.Context, acct *acme.Account, 359 | tosFn func(tosURL string) bool) (*acme.Account, error) { 360 | return a.Client.Register(ctx, acct, tosFn) 361 | } 362 | func (a *acmeClientImpl) DNS01ChallengeRecord(token string) (string, error) { 363 | return a.Client.DNS01ChallengeRecord(token) 364 | } 365 | func (a *acmeClientImpl) AuthorizeOrder(ctx context.Context, 366 | authzIDs []acme.AuthzID) (*acme.Order, error) { 367 | return a.Client.AuthorizeOrder(ctx, authzIDs) 368 | } 369 | func (a *acmeClientImpl) CreateOrderCert(ctx context.Context, 370 | finalURL string, csr []byte, ty bool) ([][]byte, string, error) { 371 | return a.Client.CreateOrderCert(ctx, finalURL, csr, ty) 372 | } 373 | func (a *acmeClientImpl) GetAuthorization(ctx
context.Context, 374 | authzURL string) (*acme.Authorization, error) { 375 | return a.Client.GetAuthorization(ctx, authzURL) 376 | } 377 | func (a *acmeClientImpl) Accept(ctx context.Context, 378 | chal *acme.Challenge) (*acme.Challenge, error) { 379 | return a.Client.Accept(ctx, chal) 380 | } 381 | func (a *acmeClientImpl) WaitAuthorization(ctx context.Context, 382 | authzURL string) (*acme.Authorization, error) { 383 | return a.Client.WaitAuthorization(ctx, authzURL) 384 | } 385 | -------------------------------------------------------------------------------- /storage/gatewayDb.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. // 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Handles the database ORM for gateways 9 | 10 | package storage 11 | 12 | import ( 13 | "bytes" 14 | "context" 15 | "database/sql" 16 | "github.com/pkg/errors" 17 | jww "github.com/spf13/jwalterweatherman" 18 | "gitlab.com/xx_network/primitives/id" 19 | "gitlab.com/xx_network/primitives/id/ephemeral" 20 | "gorm.io/gorm" 21 | "strings" 22 | "time" 23 | ) 24 | 25 | // newContext builds a context for database operations that expires after dbTimeout. 26 | func newContext() (context.Context, context.CancelFunc) { 27 | return context.WithTimeout(context.Background(), dbTimeout) 28 | } 29 | 30 | // perfLog logs a query's duration at TRACE level and warns when it exceeds one second.
31 | func perfLog(queryName string, queryStart time.Time) { 32 | queryTime := time.Since(queryStart) 33 | jww.TRACE.Printf("Query %s took %v", queryName, queryTime) 34 | if queryTime > time.Second { 35 | jww.WARN.Printf("Query %s took an unexpectedly long time: %v", 36 | queryName, queryTime) 37 | } 38 | } 39 | 40 | // catchErrors panics on database timeouts and full-disk errors; any other error (including nil) is returned unchanged. 41 | func catchErrors(err error) error { 42 | if err != nil { 43 | if errors.Is(err, context.DeadlineExceeded) { 44 | jww.FATAL.Panicf("Database call timed out: %+v", err.Error()) 45 | } 46 | if strings.Contains(err.Error(), "No space left on device") { 47 | jww.FATAL.Panicf("Storage device full: %+v", err.Error()) 48 | } 49 | } 50 | return err 51 | } 52 | 53 | // UpsertState inserts the given State into the database if its Key does not exist, 54 | // or overwrites the stored State if its Value differs from the given one. 55 | func (d *DatabaseImpl) UpsertState(state *State) error { 56 | queryStart := time.Now() 57 | // Build a transaction to prevent race conditions 58 | ctx, cancel := newContext() 59 | err := d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { 60 | // Make a copy of the provided state 61 | newState := *state 62 | 63 | // Attempt to insert state into the Database, 64 | // or if it already exists, load the Database value into state 65 | err := tx.FirstOrCreate(state, &State{Key: state.Key}).Error 66 | if err != nil { 67 | return err 68 | } 69 | // If the stored value differs, restore the caller's values and save them (so state is not left holding the stale row) 70 | if newState.Value != state.Value { 71 | *state = newState 72 | return tx.Save(state).Error 73 | } 74 | 75 | // Commit 76 | return nil 77 | }) 78 | cancel() 79 | perfLog("UpsertState", queryStart) 80 | return catchErrors(err) 81 | } 82 | 83 | // GetStateValue returns the Value of the State with the given key, 84 | // or an error if a matching State does not exist 85 | func (d *DatabaseImpl) GetStateValue(key string) (string, error) { 86 | queryStart := time.Now() 87
| result := &State{Key: key} 88 | ctx, cancel := newContext() 89 | err := d.db.WithContext(ctx).Take(result).Error 90 | cancel() 91 | perfLog("GetStateValue", queryStart) 92 | return result.Value, catchErrors(err) 93 | } 94 | 95 | // GetClient returns the Client with the given id from the database, 96 | // or an error if a matching Client does not exist 97 | func (d *DatabaseImpl) GetClient(id *id.ID) (*Client, error) { 98 | queryStart := time.Now() 99 | result := &Client{} 100 | ctx, cancel := newContext() 101 | err := d.db.WithContext(ctx).Take(result, "id = ?", id.Marshal()).Error 102 | cancel() 103 | perfLog("GetClient", queryStart) 104 | return result, catchErrors(err) 105 | } 106 | 107 | // UpsertClient inserts client, or replaces the stored Key if it differs, so an interrupted registration doesn't fail 108 | func (d *DatabaseImpl) UpsertClient(client *Client) error { 109 | queryStart := time.Now() 110 | // Make a copy of the provided client 111 | newClient := *client 112 | 113 | // Build a transaction to prevent race conditions 114 | ctx, cancel := newContext() 115 | err := d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { 116 | // Attempt to insert the client into the database, 117 | // or if it already exists, load the database value into client 118 | err := tx.FirstOrCreate(client, &Client{Id: client.Id}).Error 119 | if err != nil { 120 | return err 121 | } 122 | 123 | // If the stored Key differs, restore the caller's values and save them 124 | // (so client is not left holding the stale row) 125 | if !bytes.Equal(client.Key, newClient.Key) { 126 | *client = newClient 127 | return tx.Save(client).Error 128 | } 129 | // Commit 130 | return nil 131 | }) 132 | cancel() 133 | perfLog("UpsertClient", queryStart) 134 | return catchErrors(err) 135 | } 136 | 137 | // GetRound returns the Round with the given id from the database, 138 | // or an error if a matching Round does not exist 139 | func (d *DatabaseImpl) GetRound(id id.Round) (*Round, error) { 140 | result := &Round{} 141 | err :=
d.db.Take(result, "id = ?", uint64(id)).Error 142 | return result, catchErrors(err) 143 | } 144 | 145 | // UpsertRound inserts the given Round into the database if it does not exist, 146 | // or overwrites the stored Round if the provided UpdateId is greater 147 | func (d *DatabaseImpl) UpsertRound(round *Round) error { 148 | // Build a transaction to prevent race conditions 149 | err := d.db.Transaction(func(tx *gorm.DB) error { 150 | oldRound := &Round{ 151 | LastUpdated: time.Now(), 152 | } 153 | 154 | // Attempt to insert a placeholder round into the database, 155 | // or if it already exists, load the database value into oldRound 156 | err := tx.Where(&Round{Id: round.Id}).FirstOrCreate(oldRound).Error 157 | if err != nil { 158 | return err 159 | } 160 | // NOTE(review): a freshly created placeholder has UpdateId 0, so a round whose UpdateId is 0 is never saved in full; confirm UpdateIds start above 0. 161 | // If the provided round has a greater UpdateId than the database value, 162 | // overwrite the database value with the provided round 163 | if oldRound.UpdateId < round.UpdateId { 164 | round.LastUpdated = time.Now() 165 | return tx.Save(round).Error 166 | } 167 | 168 | // Commit 169 | return nil 170 | }) 171 | return catchErrors(err) 172 | } 173 | 174 | // deleteRound deletes all Round objects last updated at or before the given timestamp from the database 175 | func (d *DatabaseImpl) deleteRound(ts time.Time) error { 176 | return catchErrors(d.db.Where("last_updated <= ?", ts).Delete(&Round{}).Error) 177 | } 178 | 179 | // countMixedMessagesByRound returns the number of MixedMessages for roundId and whether a ClientRound with that id exists 180 | func (d *DatabaseImpl) countMixedMessagesByRound(roundId id.Round) (uint64, bool, error) { 181 | queryStart := time.Now() 182 | var roundCount int64 183 | ctx, cancel := newContext() 184 | err := d.db.WithContext(ctx).Model(&ClientRound{}).Where("id = ?", uint64(roundId)).Count(&roundCount).Error 185 | cancel() 186 | if err != nil { 187 | return 0, false, catchErrors(err) 188 | } 189 | hasRound := roundCount > 0 190 | 191 | var count int64 192 | if hasRound { 193 | ctx, cancel = newContext() 194 | err = d.db.WithContext(ctx).Model(&MixedMessage{}).Where("round_id =
?", uint64(roundId)).Count(&count).Error 195 | cancel() 196 | } 197 | 198 | perfLog("countMixedMessagesByRound", queryStart) 199 | return uint64(count), hasRound, catchErrors(err) 200 | } 201 | 202 | // getMixedMessages returns the MixedMessages from the database with matching 203 | // recipientId and roundId; if none match, the slice is empty (gorm's Find 204 | // does not report missing rows as an error) 205 | func (d *DatabaseImpl) getMixedMessages(recipientId ephemeral.Id, roundId id.Round) ([]*MixedMessage, error) { 206 | queryStart := time.Now() 207 | results := make([]*MixedMessage, 0) 208 | ctx, cancel := newContext() 209 | err := d.db.WithContext(ctx).Find(&results, 210 | &MixedMessage{RecipientId: recipientId.Int64(), 211 | RoundId: uint64(roundId)}).Error 212 | cancel() 213 | perfLog("getMixedMessages", queryStart) 214 | return results, catchErrors(err) 215 | } 216 | 217 | // InsertMixedMessages inserts the given ClientRound into the database (presumably together with its associated MixedMessages — confirm against the model) 218 | // NOTE: Do not specify Id attribute for messages, it is autogenerated 219 | func (d *DatabaseImpl) InsertMixedMessages(cr *ClientRound) error { 220 | queryStart := time.Now() 221 | ctx, cancel := newContext() 222 | err := d.db.WithContext(ctx).Create(cr).Error 223 | cancel() 224 | perfLog("InsertMixedMessages", queryStart) 225 | return catchErrors(err) 226 | } 227 | 228 | // deleteMixedMessages deletes all ClientRounds with a timestamp at or before ts (presumably removing their MixedMessages too — confirm) 229 | func (d *DatabaseImpl) deleteMixedMessages(ts time.Time) error { 230 | return catchErrors(d.db.Where("timestamp <= ?", ts).Delete(&ClientRound{}).Error) 231 | } 232 | 233 | // Returns ClientBloomFilter from database with the given recipientId 234 | // and an Epoch between startEpoch and endEpoch (inclusive) 235 | // Or an error if no matching ClientBloomFilter exist 236 | func (d *DatabaseImpl) GetClientBloomFilters(recipientId ephemeral.Id, startEpoch, endEpoch uint32) ([]*ClientBloomFilter, error) { 237 | jww.DEBUG.Printf("Getting filters for client [%v]", recipientId) 238 | queryStart := time.Now() 239 | 240 | var results []*ClientBloomFilter 241 |
recipientIdInt := recipientId.Int64() 242 | ctx, cancel := newContext() 243 | err := d.db.WithContext(ctx).Where("epoch BETWEEN ? AND ?", startEpoch, endEpoch).Find(&results, &ClientBloomFilter{RecipientId: &recipientIdInt}).Error 244 | cancel() 245 | 246 | perfLog("GetClientBloomFilters", queryStart) 247 | jww.DEBUG.Printf("Returning filters [%v] for client [%v]", results, recipientId) 248 | return results, catchErrors(err) 249 | } 250 | 251 | // upsertClientBloomFilter into database if it does not exist, or updates the 252 | // ClientBloomFilter in the database if the ClientBloomFilter already exists. 253 | func (d *DatabaseImpl) upsertClientBloomFilter(filter *ClientBloomFilter) error { 254 | jww.DEBUG.Printf("Upserting filter for client %d at epoch %d", *filter.RecipientId, filter.Epoch) 255 | queryStart := time.Now() 256 | 257 | // Build a transaction to prevent race conditions 258 | ctx, cancel := newContext() 259 | err := d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { 260 | // Initialize variable for returning existing value from the database 261 | oldFilter := &ClientBloomFilter{ 262 | Filter: make([]byte, len(filter.Filter)), 263 | } 264 | 265 | // Attempt to insert filter into the database. 266 | // If it already exists and hasn't reached maxBloomUses, 267 | // replace oldFilter with the database value. 
268 | err := tx.Where(&ClientBloomFilter{ 269 | Epoch: filter.Epoch, 270 | RecipientId: filter.RecipientId, 271 | }).Where("uses < ?", maxBloomUses).FirstOrCreate(oldFilter).Error 272 | if err != nil { 273 | return err 274 | } 275 | 276 | // Combine oldFilter with filter 277 | filter.combine(oldFilter) 278 | 279 | // Commit to the database 280 | err = tx.Save(filter).Error 281 | if err != nil { 282 | return err 283 | } 284 | return nil 285 | }, &sql.TxOptions{Isolation: sql.LevelSerializable}) 286 | cancel() 287 | perfLog("upsertClientBloomFilter", queryStart) 288 | return catchErrors(err) 289 | } 290 | 291 | // Deletes all ClientBloomFilter with Epoch <= the given epoch 292 | // Returns an error if no matching ClientBloomFilter exist 293 | func (d *DatabaseImpl) DeleteClientFiltersBeforeEpoch(epoch uint32) error { 294 | return catchErrors(d.db.Delete(ClientBloomFilter{}, "epoch <= ?", epoch).Error) 295 | } 296 | 297 | // Returns the lowest FirstRound value from ClientBloomFilter 298 | // Or an error if no ClientBloomFilter exist 299 | func (d *DatabaseImpl) GetLowestBloomRound() (uint64, error) { 300 | result := &ClientBloomFilter{} 301 | err := d.db.Order("first_round asc").Take(result).Error 302 | if err != nil { 303 | return 0, catchErrors(err) 304 | } 305 | jww.TRACE.Printf("Obtained lowest ClientBloomFilter FirstRound from DB: %d", result.FirstRound) 306 | return result.FirstRound, nil 307 | } 308 | 309 | // Returns multiple Rounds from database with the given ids 310 | // Or an error if no matching Rounds exist 311 | func (d *DatabaseImpl) GetRounds(ids []id.Round) ([]*Round, error) { 312 | // Convert IDs to plain numbers 313 | plainIds := make([]uint64, len(ids)) 314 | for i, v := range ids { 315 | plainIds[i] = uint64(v) 316 | } 317 | 318 | // Execute the query 319 | results := make([]*Round, 0) 320 | err := d.db.Where("id IN (?)", plainIds).Find(&results).Error 321 | 322 | return results, catchErrors(err) 323 | } 324 | 
--------------------------------------------------------------------------------
/storage/database.go:
--------------------------------------------------------------------------------
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////

// Handles low level database control and interfaces

package storage

import (
	"context"
	"fmt"
	"github.com/pkg/errors"
	jww "github.com/spf13/jwalterweatherman"
	"gitlab.com/xx_network/primitives/id"
	"gitlab.com/xx_network/primitives/id/ephemeral"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"sync"
	"time"
)

// database is the storage interface used by the gateway. It is implemented by
// DatabaseImpl (GORM over PostgreSQL) and by MapImpl (in-memory), which
// newDatabase falls back to in devMode when no database is available.
type database interface {
	// Key-value State storage
	UpsertState(state *State) error
	GetStateValue(key string) (string, error)

	// Client storage
	GetClient(id *id.ID) (*Client, error)
	UpsertClient(client *Client) error

	// Gateway-side Round storage
	GetRound(id id.Round) (*Round, error)
	UpsertRound(round *Round) error
	deleteRound(ts time.Time) error

	// Client-facing ClientRound / MixedMessage storage
	countMixedMessagesByRound(roundId id.Round) (uint64, bool, error)
	getMixedMessages(recipientId ephemeral.Id, roundId id.Round) ([]*MixedMessage, error)
	InsertMixedMessages(cr *ClientRound) error
	deleteMixedMessages(ts time.Time) error

	// ClientBloomFilter storage
	GetClientBloomFilters(recipientId ephemeral.Id, startEpoch, endEpoch uint32) ([]*ClientBloomFilter, error)
	upsertClientBloomFilter(filter *ClientBloomFilter) error
	DeleteClientFiltersBeforeEpoch(epoch uint32) error

	// TODO: Currently not used. May want to remove.
	GetLowestBloomRound() (uint64, error)
	GetRounds(ids []id.Round) ([]*Round, error)
}

// DatabaseImpl implements the database interface with an underlying DB
type DatabaseImpl struct {
	db *gorm.DB // Stored database connection
}

// MapImpl implements the database interface with an underlying Map.
// Each map is declared next to an RWMutex whose name suggests it guards that
// map (presumably; confirm against the MapImpl methods). mixedMessages and
// bloomFilters carry their own embedded RWMutex.
type MapImpl struct {
	states           map[string]string
	statesLock       sync.RWMutex
	clients          map[id.ID]*Client
	clientsLock      sync.RWMutex
	rounds           map[id.Round]*Round
	roundsLock       sync.RWMutex
	clientRounds     map[uint64]*ClientRound
	clientRoundsLock sync.RWMutex
	mixedMessages    MixedMessageMap
	bloomFilters     BloomFilterMap
}

// MixedMessageMap contains a list of MixedMessage sorted into two maps so that
// they can key on RoundId and RecipientId. All messages are stored by their
// unique ID.
type MixedMessageMap struct {
	// Round ID -> recipient ID -> message ID -> message
	RoundId map[id.Round]map[int64]map[uint64]*MixedMessage
	// Recipient ID -> round ID -> message ID -> message
	RecipientId map[int64]map[id.Round]map[uint64]*MixedMessage
	// Presumably the number of messages stored per round; TODO confirm
	RoundIdCount map[id.Round]uint64
	// Presumably the source of the next unique message ID; TODO confirm
	IdTrack uint64
	sync.RWMutex
}

// BloomFilterMap contains a list of ClientBloomFilter sorted in a map that can key on RecipientId.
type BloomFilterMap struct {
	RecipientId map[int64]*ClientBloomFilterList
	// NOTE(review): newDatabase does not initialize this pointer; presumably it
	// is allocated lazily by the MapImpl bloom filter methods. Confirm.
	primaryKey *uint64
	sync.RWMutex
}

// ClientBloomFilterList holds the ClientBloomFilters of a single recipient.
// NOTE(review): list appears to be bucketed, with start presumably the epoch
// that list[0] corresponds to. Confirm against the MapImpl methods.
type ClientBloomFilterList struct {
	list  [][]*ClientBloomFilter
	start uint32
}

// State is a Key-Value store used for persisting Gateway information
type State struct {
	Key   string `gorm:"primaryKey"`
	Value string `gorm:"not null"`
}

// Enumerates various Keys in the State table.
const (
	PeriodKey           = "Period"
	LastUpdateKey       = "LastUpdateId"
	KnownRoundsKey      = "KnownRoundsV3"
	HttpsCertificateKey = "HttpsCertificate"
)

// Client stores a client's ID and its associated key.
type Client struct {
	Id  []byte `gorm:"primaryKey"`
	Key []byte `gorm:"not null"`
}

// Round represents the Round information that is relevant to Gateways.
type Round struct {
	Id          uint64 `gorm:"primaryKey;autoIncrement:false"`
	UpdateId    uint64 `gorm:"unique;not null"`
	InfoBlob    []byte
	LastUpdated time.Time `gorm:"index;not null"` // Timestamp of most recent Update
}

// ClientRound represents the Round information that is relevant to Clients.
// Deleting a ClientRound cascades to its Messages (OnDelete:CASCADE below).
type ClientRound struct {
	Id        uint64    `gorm:"primaryKey;autoIncrement:false"`
	Timestamp time.Time `gorm:"index;not null"` // Round Realtime timestamp

	// Has-many relation: MixedMessage.RoundId references ClientRound.Id
	Messages []MixedMessage `gorm:"foreignKey:RoundId;constraint:OnDelete:CASCADE"`
}

// ClientBloomFilter is a per-recipient bloom filter for a given Epoch, covering
// RoundRange rounds starting at FirstRound.
type ClientBloomFilter struct {
	Id uint64 `gorm:"primaryKey;autoIncrement:true"`
	// Pointer to enforce zero-value reading in ORM.
	// Additionally, we desire to make composite indexes on the more distinct column first.
	RecipientId *int64 `gorm:"index:idx_client_bloom_filters_recipient_id_epoch,priority:1;not null"`
	Epoch       uint32 `gorm:"index:idx_client_bloom_filters_recipient_id_epoch,priority:2;not null"`
	FirstRound  uint64 `gorm:"index;not null"`
	RoundRange  uint32 `gorm:"not null"`
	Filter      []byte `gorm:"not null"`
	Uses        uint32 `gorm:"not null;default:0"` // Keep track of how many times used
}

// MixedMessage is a single message delivered to a recipient in a given round.
type MixedMessage struct {
	Id uint64 `gorm:"primaryKey;autoIncrement:true"`
	// NOTE(review): "references rounds(id)" is not in GORM's key:value tag
	// syntax and appears to be ignored; the foreign key to client_rounds is
	// declared by ClientRound.Messages instead. Confirm before relying on it.
	RoundId         uint64 `gorm:"index;not null;references rounds(id)"`
	RecipientId     int64  `gorm:"index;not null"`
	MessageContents []byte `gorm:"not null"`
}

// NewMixedMessage creates a new MixedMessage object with the given attributes.
// NOTE: Do not modify the MixedMessage.Id attribute.
151 | func NewMixedMessage(roundId id.Round, recipientId ephemeral.Id, messageContentsA, messageContentsB []byte) *MixedMessage { 152 | 153 | messageContents := make([]byte, len(messageContentsA)+len(messageContentsB)) 154 | copy(messageContents[:len(messageContentsA)], messageContentsA) 155 | copy(messageContents[len(messageContentsA):], messageContentsB) 156 | 157 | return &MixedMessage{ 158 | RoundId: uint64(roundId), 159 | RecipientId: recipientId.Int64(), 160 | MessageContents: messageContents, 161 | } 162 | } 163 | 164 | // GetMessageContents return the separated message contents of the MixedMessage. 165 | func (m *MixedMessage) GetMessageContents() (messageContentsA, messageContentsB []byte) { 166 | splitPosition := len(m.MessageContents) / 2 167 | messageContentsA = m.MessageContents[:splitPosition] 168 | messageContentsB = m.MessageContents[splitPosition:] 169 | return 170 | } 171 | 172 | // Initialize the database interface with database backend 173 | // Returns a database interface and error 174 | func newDatabase(username, password, dbName, address, 175 | port string, devMode bool) (database, error) { 176 | 177 | var err error 178 | var db *gorm.DB 179 | // Connect to the database if the correct information is provided 180 | if address != "" && port != "" { 181 | // Create the database connection 182 | connectString := fmt.Sprintf( 183 | "host=%s port=%s user=%s dbname=%s sslmode=disable", 184 | address, port, username, dbName) 185 | // Handle empty database password 186 | if len(password) > 0 { 187 | connectString += fmt.Sprintf(" password=%s", password) 188 | } 189 | db, err = gorm.Open(postgres.Open(connectString), &gorm.Config{ 190 | Logger: logger.New(jww.TRACE, logger.Config{LogLevel: logger.Info}), 191 | }) 192 | } 193 | 194 | // Return the map-backend interface 195 | // in the event there is a database error or information is not provided 196 | if (address == "" || port == "") || err != nil { 197 | 198 | var failReason string 199 | if err != nil 
{ 200 | failReason = fmt.Sprintf("Unable to initialize database backend: %+v", err) 201 | jww.WARN.Printf(failReason) 202 | } else { 203 | failReason = "Database backend connection information not provided" 204 | jww.WARN.Printf(failReason) 205 | } 206 | 207 | if !devMode { 208 | jww.FATAL.Panicf("Gateway cannot run in production "+ 209 | "without a database: %s", failReason) 210 | } 211 | 212 | defer jww.INFO.Println("Map backend initialized successfully!") 213 | 214 | mapImpl := &MapImpl{ 215 | clients: map[id.ID]*Client{}, 216 | rounds: map[id.Round]*Round{}, 217 | states: map[string]string{}, 218 | 219 | mixedMessages: MixedMessageMap{ 220 | RoundId: map[id.Round]map[int64]map[uint64]*MixedMessage{}, 221 | RecipientId: map[int64]map[id.Round]map[uint64]*MixedMessage{}, 222 | RoundIdCount: map[id.Round]uint64{}, 223 | IdTrack: 0, 224 | }, 225 | bloomFilters: BloomFilterMap{ 226 | RecipientId: map[int64]*ClientBloomFilterList{}, 227 | }, 228 | clientRounds: map[uint64]*ClientRound{}, 229 | } 230 | 231 | return database(mapImpl), nil 232 | } 233 | 234 | // Get and configure the internal database ConnPool 235 | sqlDb, err := db.DB() 236 | if err != nil { 237 | return database(&DatabaseImpl{}), errors.Errorf( 238 | "Unable to configure database connection pool: %+v", err) 239 | } 240 | // SetMaxIdleConns sets the maximum number of connections in the idle connection pool. 241 | sqlDb.SetMaxIdleConns(10) 242 | // SetMaxOpenConns sets the maximum number of open connections to the Database. 243 | sqlDb.SetMaxOpenConns(50) 244 | // SetConnMaxLifetime sets the maximum amount of time a connection may be idle. 245 | sqlDb.SetConnMaxIdleTime(10 * time.Minute) 246 | // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. 247 | sqlDb.SetConnMaxLifetime(12 * time.Hour) 248 | 249 | // Ensure database structure is up-to-date. 
250 | err = migrate(db) 251 | if err != nil { 252 | return database(&DatabaseImpl{}), errors.Errorf( 253 | "Failed to migrate database: %+v", err) 254 | } 255 | 256 | // Build the interface 257 | di := &DatabaseImpl{ 258 | db: db, 259 | } 260 | 261 | jww.INFO.Println("Database backend initialized successfully!") 262 | return database(di), nil 263 | } 264 | 265 | // migrate is a basic database structure migrator. 266 | func migrate(db *gorm.DB) error { 267 | migrateTimestamp := time.Now() 268 | 269 | // Perform automatic migrations of basic table structure. 270 | // WARNING: Order is important. Do not change without database testing. 271 | err := db.AutoMigrate(&Client{}, &Round{}, &ClientRound{}, 272 | &MixedMessage{}, &ClientBloomFilter{}, State{}) 273 | if err != nil { 274 | return err 275 | } 276 | 277 | // Determine the current version of the database via structural checks. 278 | currentVersion := 0 279 | columns, err := db.Migrator().ColumnTypes(&ClientBloomFilter{}) 280 | if err != nil { 281 | return err 282 | } 283 | for _, column := range columns { 284 | if isPrimaryKey, _ := column.PrimaryKey(); column.Name() == "id" && isPrimaryKey { 285 | currentVersion = 1 286 | break 287 | } 288 | } 289 | if !db.Migrator().HasIndex(ClientBloomFilter{}, "idx_client_bloom_filters_recipient_id") { 290 | currentVersion = 2 291 | } 292 | 293 | jww.INFO.Printf("Current database version: v%d", currentVersion) 294 | 295 | // Perform any required manual migrations. 
296 | if minVersion := 1; currentVersion < minVersion { 297 | jww.INFO.Printf("Performing database migration from v%d -> v%d", 298 | currentVersion, minVersion) 299 | ctx, cancel := context.WithTimeout(context.Background(), dbTimeout*5) 300 | err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { 301 | err := tx.Exec("ALTER TABLE client_bloom_filters DROP CONSTRAINT client_bloom_filters_pkey;").Error 302 | if err != nil { 303 | return err 304 | } 305 | 306 | // Commit 307 | return tx.Exec("ALTER TABLE client_bloom_filters ADD PRIMARY KEY (id);").Error 308 | }) 309 | cancel() 310 | if err != nil { 311 | return err 312 | } 313 | currentVersion = minVersion 314 | } 315 | if minVersion := 2; currentVersion < minVersion { 316 | jww.INFO.Printf("Performing database migration from v%d -> v%d", 317 | currentVersion, minVersion) 318 | ctx, cancel := context.WithTimeout(context.Background(), dbTimeout*5) 319 | err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { 320 | err := tx.Migrator().DropIndex(ClientBloomFilter{}, "idx_client_bloom_filters_epoch") 321 | if err != nil { 322 | return err 323 | } 324 | 325 | // Commit 326 | return tx.Migrator().DropIndex(ClientBloomFilter{}, "idx_client_bloom_filters_recipient_id") 327 | }) 328 | cancel() 329 | if err != nil { 330 | return err 331 | } 332 | currentVersion = minVersion 333 | } 334 | 335 | jww.DEBUG.Printf("Database initialization took %s", 336 | time.Now().Sub(migrateTimestamp).String()) 337 | return nil 338 | } 339 | -------------------------------------------------------------------------------- /cmd/root.go: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////// 2 | // Copyright © 2022 xx foundation // 3 | // // 4 | // Use of this source code is governed by a license that can be found in the // 5 | // LICENSE file. 
// 6 | //////////////////////////////////////////////////////////////////////////////// 7 | 8 | // Package cmd initializes the CLI and config parsers as well as the logger. 9 | package cmd 10 | 11 | import ( 12 | "crypto/sha256" 13 | "encoding/binary" 14 | "fmt" 15 | "os" 16 | "os/signal" 17 | "runtime/pprof" 18 | "strconv" 19 | "strings" 20 | "syscall" 21 | "time" 22 | 23 | "github.com/spf13/cobra" 24 | jww "github.com/spf13/jwalterweatherman" 25 | "github.com/spf13/viper" 26 | "gitlab.com/elixxir/comms/mixmessages" 27 | "gitlab.com/elixxir/crypto/cmix" 28 | "gitlab.com/elixxir/gateway/storage" 29 | "gitlab.com/xx_network/primitives/id" 30 | "gitlab.com/xx_network/primitives/utils" 31 | "google.golang.org/grpc/grpclog" 32 | ) 33 | 34 | // Flags to import from command line or config file 35 | var ( 36 | cfgFile, idfPath, logPath string 37 | certPath, keyPath, serverCertPath, 38 | permissioningCertPath string 39 | logLevel uint // 0 = info, 1 = debug, >1 = trace 40 | gwPort int 41 | validConfig bool 42 | 43 | kr int 44 | 45 | // For gossip protocol 46 | bufferExpiration, monitorThreadFrequency time.Duration 47 | 48 | // For rate limiting 49 | capacity, leakedTokens uint32 50 | leakDuration, pollDuration, bucketMaxAge time.Duration 51 | ) 52 | 53 | // RootCmd represents the base command when called without any sub-commands 54 | var rootCmd = &cobra.Command{ 55 | Use: "gateway", 56 | Short: "Runs a cMix gateway", 57 | Long: `The cMix gateways coordinate communications between servers and clients`, 58 | Args: cobra.NoArgs, 59 | Run: func(cmd *cobra.Command, args []string) { 60 | initConfig() 61 | initLog() 62 | profileOut := viper.GetString("profile-cpu") 63 | if profileOut != "" { 64 | f, err := os.Create(profileOut) 65 | if err != nil { 66 | jww.FATAL.Panicf("%+v", err) 67 | } 68 | pprof.StartCPUProfile(f) 69 | } 70 | 71 | params := InitParams(viper.GetViper()) 72 | // Build gateway implementation object 73 | gateway := NewGatewayInstance(params) 74 | err := 
gateway.SetPeriod() 75 | if err != nil { 76 | jww.FATAL.Panicf("Unable to set gateway period: %+v", err) 77 | } 78 | 79 | // start gateway network interactions 80 | for { 81 | err := gateway.InitNetwork() 82 | if err == nil { 83 | break 84 | } 85 | errMsg := err.Error() 86 | tic := strings.Contains(errMsg, "transport is closing") 87 | cde := strings.Contains(errMsg, "DeadlineExceeded") 88 | if tic || cde { 89 | if gateway.Comms != nil { 90 | gateway.Comms.Shutdown() 91 | } 92 | 93 | jww.ERROR.Printf("Cannot connect to node, "+ 94 | "retrying in 10s: %+v", err) 95 | time.Sleep(10 * time.Second) 96 | continue 97 | } 98 | jww.FATAL.Panicf(err.Error()) 99 | } 100 | 101 | if params.DevMode { 102 | jww.WARN.Printf("Starting in developer mode (devMode)" + 103 | " -- this will break on betanet or mainnet...") 104 | addPrecannedIDs(gateway) 105 | } 106 | 107 | jww.INFO.Printf("Starting xx network gateway v%s", SEMVER) 108 | 109 | // Begin gateway persistent components 110 | if !params.DisableGossip { 111 | jww.INFO.Println("Gossip is enabled") 112 | gateway.StartPeersThread() 113 | } 114 | 115 | gateway.Start() 116 | 117 | // Open Signal Handler for safe program exit 118 | stopCh := ReceiveExitSignal() 119 | 120 | // Block forever to prevent the program ending 121 | // Block until a signal is received, then call the function 122 | // provided 123 | select { 124 | case <-stopCh: 125 | jww.INFO.Printf( 126 | "Received Exit (SIGTERM or SIGINT) signal...\n") 127 | select { 128 | case gateway.ipAddrRateLimitQuit <- struct{}{}: 129 | case <-time.After(20 * time.Second): 130 | jww.ERROR.Println("Failed to stop ipAddrRateLimit") 131 | } 132 | 133 | select { 134 | case gateway.idRateLimitQuit <- struct{}{}: 135 | case <-time.After(20 * time.Second): 136 | jww.ERROR.Println("Failed to stop idRateLimitQuit") 137 | } 138 | 139 | select { 140 | case gateway.earliestRoundQuitChan <- struct{}{}: 141 | case <-time.After(20 * time.Second): 142 | jww.ERROR.Println("Failed to stop 
earliestRoundQuitChan") 143 | } 144 | 145 | select { 146 | case gateway.replaceCertificateQuit <- struct{}{}: 147 | case <-time.After(20 * time.Second): 148 | jww.ERROR.Println("Failed to stop replace certificate thread") 149 | } 150 | 151 | gateway.Comms.Shutdown() 152 | 153 | if profileOut != "" { 154 | pprof.StopCPUProfile() 155 | } 156 | } 157 | 158 | }, 159 | } 160 | 161 | // ReceiveExitSignal signals a stop chan when it receives 162 | // SIGTERM or SIGINT 163 | func ReceiveExitSignal() chan os.Signal { 164 | // Set up channel on which to send signal notifications. 165 | // We must use a buffered channel or risk missing the signal 166 | // if we're not ready to receive when the signal is sent. 167 | c := make(chan os.Signal, 1) 168 | signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) 169 | return c 170 | } 171 | 172 | func addPrecannedIDs(gateway *Instance) { 173 | // add precannedIDs 174 | for i := uint64(0); i < 41; i++ { 175 | u := new(id.ID) 176 | binary.BigEndian.PutUint64(u[:], i) 177 | u.SetType(id.User) 178 | h := sha256.New() 179 | h.Reset() 180 | h.Write([]byte(strconv.Itoa(int(4000 + i)))) 181 | baseKey := gateway.NetInf.GetCmixGroup().NewIntFromBytes(h.Sum(nil)) 182 | jww.INFO.Printf("Added precan transmisssion key: %v", 183 | baseKey.Bytes()) 184 | cgKey := cmix.GenerateClientGatewayKey(baseKey) 185 | // Insert client information to database 186 | newClient := &storage.Client{ 187 | Id: u.Marshal(), 188 | Key: cgKey, 189 | } 190 | 191 | err := gateway.storage.UpsertClient(newClient) 192 | if err != nil { 193 | jww.ERROR.Printf("Unable to insert precanned client: %+v", err) 194 | } 195 | } 196 | jww.INFO.Printf("Added precanned users") 197 | } 198 | 199 | // Execute adds all child commands to the root command and sets flags 200 | // appropriately. This is called by main.main(). It only needs to 201 | // happen once to the RootCmd. 
202 | func Execute() { 203 | if err := rootCmd.Execute(); err != nil { 204 | jww.ERROR.Println(err) 205 | os.Exit(1) 206 | } 207 | } 208 | 209 | // init is the initialization function for Cobra which defines commands 210 | // and flags. 211 | func init() { 212 | // NOTE: The point of init() is to be declarative. 213 | // There is one init in each sub command. Do not put variable declarations 214 | // here, and ensure all the Flags are of the *P variety, unless there's a 215 | // very good reason not to have them as local Params to sub command." 216 | 217 | // Here you will define your flags and configuration settings. 218 | // Cobra supports persistent flags, which, if defined here, 219 | // will be global for your application. 220 | rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "", 221 | "Path to load the Gateway configuration file from. (Required)") 222 | 223 | rootCmd.Flags().IntP("port", "p", -1, "Port for Gateway to listen on."+ 224 | "Gateway must be the only listener on this port. (Required)") 225 | err := viper.BindPFlag("port", rootCmd.Flags().Lookup("port")) 226 | handleBindingError(err, "port") 227 | 228 | rootCmd.Flags().StringVar(&idfPath, "idfPath", "", 229 | "Path to where the identity file (IDF) is saved. The IDF stores the "+ 230 | "Gateway's Node's network identity. This is used by the wrapper "+ 231 | "management script. 
(Required)") 232 | err = viper.BindPFlag("idfPath", rootCmd.Flags().Lookup("idfPath")) 233 | handleBindingError(err, "idfPath") 234 | 235 | rootCmd.Flags().UintVarP(&logLevel, "logLevel", "l", 0, 236 | "Level of debugging to print (0 = info, 1 = debug, >1 = trace).") 237 | err = viper.BindPFlag("logLevel", rootCmd.Flags().Lookup("logLevel")) 238 | handleBindingError(err, "logLevel") 239 | 240 | rootCmd.Flags().StringVar(&logPath, "log", "log/gateway.log", 241 | "Path where log file will be saved.") 242 | err = viper.BindPFlag("log", rootCmd.Flags().Lookup("log")) 243 | handleBindingError(err, "log") 244 | 245 | rootCmd.Flags().String("cmixAddress", "", 246 | "The IP address of the machine running cMix that the Gateway "+ 247 | "communicates with. Expects an IPv4 address with a port. (Required)") 248 | err = viper.BindPFlag("cmixAddress", rootCmd.Flags().Lookup("cmixAddress")) 249 | handleBindingError(err, "cmixAddress") 250 | 251 | rootCmd.Flags().StringVar(&certPath, "certPath", "", 252 | "Path to the self-signed TLS certificate for Gateway. Expects PEM "+ 253 | "format. (Required)") 254 | err = viper.BindPFlag("certPath", rootCmd.Flags().Lookup("certPath")) 255 | handleBindingError(err, "certPath") 256 | 257 | rootCmd.Flags().StringVar(&keyPath, "keyPath", "", 258 | "Path to the private key associated with the self-signed TLS "+ 259 | "certificate. (Required)") 260 | err = viper.BindPFlag("keyPath", rootCmd.Flags().Lookup("keyPath")) 261 | handleBindingError(err, "keyPath") 262 | 263 | rootCmd.Flags().StringVar(&serverCertPath, "cmixCertPath", "", 264 | "Path to the self-signed TLS certificate for cMix. Expects PEM "+ 265 | "format. (Required)") 266 | err = viper.BindPFlag("cmixCertPath", rootCmd.Flags().Lookup("cmixCertPath")) 267 | handleBindingError(err, "cmixCertPath") 268 | 269 | rootCmd.Flags().StringVar(&permissioningCertPath, "schedulingCertPath", "", 270 | "Path to the self-signed TLS certificate for the Scheduling server. "+ 271 | "Expects PEM format. 
(Required)") 272 | err = viper.BindPFlag("schedulingCertPath", rootCmd.Flags().Lookup("schedulingCertPath")) 273 | handleBindingError(err, "schedulingCertPath") 274 | 275 | // RATE LIMITING FLAGS 276 | rootCmd.Flags().Uint32Var(&capacity, "capacity", 20, 277 | "The capacity of rate-limiting buckets in the map.") 278 | err = viper.BindPFlag("capacity", rootCmd.Flags().Lookup("capacity")) 279 | handleBindingError(err, "Rate_Limiting_Capacity") 280 | 281 | rootCmd.Flags().Uint32Var(&leakedTokens, "leakedTokens", 3, 282 | "The rate that the rate limiting bucket leaks tokens at [tokens/ns].") 283 | err = viper.BindPFlag("leakedTokens", rootCmd.Flags().Lookup("leakedTokens")) 284 | handleBindingError(err, "Rate_Limiting_LeakedTokens") 285 | 286 | rootCmd.Flags().DurationVar(&leakDuration, "leakDuration", 1*time.Millisecond, 287 | "How often the number of leaked tokens is leaked from the bucket.") 288 | err = viper.BindPFlag("leakDuration", rootCmd.Flags().Lookup("leakDuration")) 289 | handleBindingError(err, "Rate_Limiting_LeakDuration") 290 | 291 | rootCmd.Flags().DurationVar(&pollDuration, "pollDuration", 10*time.Second, 292 | "How often inactive buckets are removed.") 293 | err = viper.BindPFlag("pollDuration", rootCmd.Flags().Lookup("pollDuration")) 294 | handleBindingError(err, "Rate_Limiting_PollDuration") 295 | 296 | rootCmd.Flags().DurationVar(&bucketMaxAge, "bucketMaxAge", 10*time.Second, 297 | "The max age of a bucket without activity before it is removed.") 298 | err = viper.BindPFlag("bucketMaxAge", rootCmd.Flags().Lookup("bucketMaxAge")) 299 | handleBindingError(err, "Rate_Limiting_BucketMaxAge") 300 | 301 | // GOSSIP MANAGER FLAGS 302 | rootCmd.Flags().BoolP("enableGossip", "", false, 303 | "Feature flag for in progress gossip functionality") 304 | err = viper.BindPFlag("enableGossip", rootCmd.Flags().Lookup("enableGossip")) 305 | handleBindingError(err, "Enable_Gossip") 306 | 307 | rootCmd.Flags().DurationVar(&bufferExpiration, "bufferExpiration", 
300*time.Second, 308 | "How long a message record should last in the gossip buffer if it "+ 309 | "arrives before the Gateway starts handling the gossip.") 310 | err = viper.BindPFlag("bufferExpiration", rootCmd.Flags().Lookup("bufferExpiration")) 311 | handleBindingError(err, "Rate_Limiting_BufferExpiration") 312 | 313 | rootCmd.Flags().DurationVar(&monitorThreadFrequency, "monitorThreadFrequency", 150*time.Second, 314 | "Frequency with which to check the gossip buffer.") 315 | err = viper.BindPFlag("monitorThreadFrequency", rootCmd.Flags().Lookup("monitorThreadFrequency")) 316 | handleBindingError(err, "Rate_Limiting_MonitorThreadFrequency") 317 | 318 | rootCmd.Flags().IntVar(&kr, "kr", 1024, // fixme: probably should be orders of magnitudes bigger? 319 | "Amount of rounds to keep track of in kr") 320 | err = viper.BindPFlag("kr", rootCmd.Flags().Lookup("kr")) 321 | handleBindingError(err, "Known_Rounds") 322 | 323 | // DevMode enables developer mode, which allows you to run without 324 | // a database and with unsafe "precanned" users 325 | rootCmd.Flags().BoolP("devMode", "", false, 326 | "Run in development/testing mode. Do not use on beta or main "+ 327 | "nets") 328 | err = viper.BindPFlag("devMode", rootCmd.Flags().Lookup("devMode")) 329 | handleBindingError(err, "Rate_Limiting_MonitorThreadFrequency") 330 | _ = rootCmd.Flags().MarkHidden("devMode") 331 | 332 | rootCmd.Flags().String("profile-cpu", "", 333 | "Enable cpu profiling to this file") 334 | viper.BindPFlag("profile-cpu", rootCmd.Flags().Lookup("profile-cpu")) 335 | 336 | } 337 | 338 | // Handle flag binding errors 339 | func handleBindingError(err error, flag string) { 340 | if err != nil { 341 | jww.FATAL.Panicf("Error on binding flag \"%s\":%+v", flag, err) 342 | } 343 | } 344 | 345 | // initConfig reads in config file and ENV variables if set. 
346 | func initConfig() { 347 | validConfig = true 348 | if cfgFile == "" { 349 | jww.FATAL.Panicf("No config file provided.") 350 | } 351 | 352 | cfgFile, _ = utils.ExpandPath(cfgFile) 353 | viper.SetConfigFile(cfgFile) 354 | viper.AutomaticEnv() // read in environment variables that match 355 | 356 | // If a config file is found, read it in. 357 | if err := viper.ReadInConfig(); err != nil { 358 | fmt.Printf("Unable to read config file (%s): %+v", cfgFile, err.Error()) 359 | validConfig = false 360 | } 361 | 362 | } 363 | 364 | // initLog initializes logging thresholds and the log path. 365 | func initLog() { 366 | // Set log file 367 | logPath = viper.GetString("log") 368 | logFile, err := os.OpenFile(logPath, 369 | os.O_CREATE|os.O_WRONLY|os.O_APPEND, 370 | 0644) 371 | if err != nil { 372 | fmt.Printf("Could not open log file %s!\n", logPath) 373 | jww.SetLogOutput(os.Stderr) 374 | } else { 375 | jww.SetLogOutput(logFile) 376 | } 377 | 378 | // Check the level of logs to display 379 | vipLogLevel := viper.GetUint("logLevel") 380 | if vipLogLevel > 1 { 381 | // Set GRPC trace logging 382 | grpcLogger := grpclog.NewLoggerV2WithVerbosity( 383 | logFile, logFile, logFile, 99) 384 | grpclog.SetLoggerV2(grpcLogger) 385 | 386 | // Turn on trace logs 387 | jww.SetLogThreshold(jww.LevelTrace) 388 | jww.SetStdoutThreshold(jww.LevelTrace) 389 | mixmessages.TraceMode() 390 | } else if vipLogLevel == 1 { 391 | // Turn on debugging logs 392 | jww.SetLogThreshold(jww.LevelDebug) 393 | jww.SetStdoutThreshold(jww.LevelDebug) 394 | mixmessages.DebugMode() 395 | } else { 396 | // Turn on info logs 397 | jww.SetLogThreshold(jww.LevelInfo) 398 | jww.SetStdoutThreshold(jww.LevelInfo) 399 | } 400 | } 401 | --------------------------------------------------------------------------------