├── .gitignore
├── LICENSE
├── README.md
├── go.mod
├── go.sum
├── shuffle.go
├── shuffle_bench_test.go
├── shuffle_test.go
├── spec
│   ├── shuffle_test_gen.py
│   └── tests.csv
└── test_util.go

/.gitignore:
--------------------------------------------------------------------------------
### Go template
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, build with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

venv

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 @protolambda

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# eth2-shuffle

Shuffling algorithm for ETH 2.0.

Implemented in four ways:

1. Shuffle the elements of an array (`ShuffleList`)
2. Shuffle an individual element, i.e. get its permuted index (`PermuteIndex`)
3. Un-shuffle (reverse the effect of a shuffle) the elements of an array (`UnshuffleList`)
4. Un-shuffle an individual element, i.e. get its un-permuted index (`UnpermuteIndex`)

The implementation can be found in `shuffle.go`.

Note: you can change the hash function and the number of rounds.
Tests use SHA-256 (upcoming in the ETH 2.0 spec) and 90 rounds (already a constant in the spec).

## Tests

Tests can be found in `shuffle_test.go`.

There are 240 test cases (30 seeds × 8 list sizes), all run in parallel. They are generated by code that follows the spec; the generator is `spec/shuffle_test_gen.py`, which produces the `tests.csv` file read by the tests.
Each case has a sub-test for each of the four shuffle functions mentioned earlier.

## Benchmarks

These benchmarks were run on a dev laptop, nothing special.
The primary goal of these benchmarks is to compare per-index shuffling with complete-list shuffling, not to be faster than any other particular implementation.
Feel free to run them on your own hardware to compare with your implementations.
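As a starting point for such a comparison, the per-index approach can be sketched in a few lines of self-contained Go. This sketch re-implements the spec's `get_permuted_index` (as written in `spec/shuffle_test_gen.py`) with SHA-256 and 90 rounds instead of importing this package. It then builds a shuffled list with one call per element, which is the work `Indexwise_ShuffleList_X` measures. The names `permuteIndex` and `shuffleIndexwise` are hypothetical. The mapping is intended to agree with `PermuteIndex` for the same hash, rounds and seed, but treat that as an assumption to check against `spec/tests.csv`.

```go
package main

import (
	"crypto/sha256"
	"encoding/binary"
	"fmt"
)

// Assumption: 90 rounds, as in the spec and in the tests of this repository.
const rounds = 90

// permuteIndex is a hypothetical, straightforward port of the spec's
// get_permuted_index ("swap or not"), using SHA-256.
// Preconditions: listSize > 0 and index < listSize.
func permuteIndex(index, listSize uint64, seed [32]byte) uint64 {
	// Hash input layout: seed (32 bytes) | round (1 byte) | position / 256 (4 bytes).
	buf := make([]byte, 37)
	copy(buf, seed[:])
	for r := 0; r < rounds; r++ {
		buf[32] = byte(r)
		// pivot = bytes_to_int(hash(seed + round)[0:8]) % list_size
		h := sha256.Sum256(buf[:33])
		pivot := binary.LittleEndian.Uint64(h[:8]) % listSize
		// flip = (pivot - index) % list_size, written so that it cannot underflow
		flip := (pivot + listSize - index) % listSize
		position := index
		if flip > position {
			position = flip
		}
		binary.LittleEndian.PutUint32(buf[33:], uint32(position/256))
		source := sha256.Sum256(buf)
		b := source[(position%256)/8]
		// The "coin flip": one bit of the source hash decides whether to swap.
		if (b>>(position%8))&1 == 1 {
			index = flip
		}
	}
	return index
}

// shuffleIndexwise shuffles a list by permuting every index separately,
// placing input[i] at position permuteIndex(i).
func shuffleIndexwise(input []uint64, seed [32]byte) []uint64 {
	n := uint64(len(input))
	out := make([]uint64, n)
	for i := uint64(0); i < n; i++ {
		out[permuteIndex(i, n, seed)] = input[i]
	}
	return out
}

func main() {
	seed := [32]byte{123, 42} // same "random" seed as the benchmarks
	list := []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	fmt.Println(shuffleIndexwise(list, seed))
}
```

Done this way, every element costs 90 rounds of two SHA-256 calls, so a list of `n` items needs about `180n` hashes. `ShuffleList` instead reuses one source hash for 256 consecutive positions in each round, which is why it pulls ahead on large lists.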

With `-test.benchtime=10s`:

```
goos: linux
goarch: amd64
pkg: eth2-shuffle

BenchmarkPermuteIndex/PermuteIndex_4000000-8              300000       49013 ns/op
BenchmarkPermuteIndex/PermuteIndex_40000-8                300000       48936 ns/op
BenchmarkPermuteIndex/PermuteIndex_400-8                  300000       48709 ns/op
BenchmarkIndexComparison/Indexwise_ShuffleList_40000-8        10  1947872791 ns/op
BenchmarkIndexComparison/Indexwise_ShuffleList_400-8        1000    19435826 ns/op
BenchmarkShuffleList/ShuffleList_4000000-8                    10  1253702761 ns/op
BenchmarkShuffleList/ShuffleList_40000-8                    1000    12152166 ns/op
BenchmarkShuffleList/ShuffleList_400-8                    100000      191813 ns/op

```

### `PermuteIndex_X`
Benchmarks shuffling a single item in a virtual context of `X` items, which are not themselves shuffled.
Note that the list size `X` barely matters: the cost is dominated by the work of shuffling a single index.

### `Indexwise_ShuffleList_X`
Benchmarks shuffling `X` items by permuting each of them individually with `PermuteIndex`.
There is no `4,000,000` case; it is too inefficient.
Also note that shuffling `40,000` items this way (about 1.95 s/op) is slower than shuffling a list 100x that size the efficient way with `ShuffleList` (about 1.25 s/op).

### `ShuffleList_X`
Benchmarks shuffling a list of `X` items the efficient way, i.e. all items simultaneously.


## Contributing

Contributions are welcome; please keep the implementation in line with the ETH 2.0 spec.

## License

MIT, see the `LICENSE` file.

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/protolambda/eth2-shuffle

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protolambda/eth2-shuffle/fd840f1036c1f8f6d7625ffe6ff4d9c60f942876/go.sum
--------------------------------------------------------------------------------
/shuffle.go:
--------------------------------------------------------------------------------
package eth2_shuffle

import "encoding/binary"

// HashFn is the hash function used for shuffling. It must return at least 32 bytes.
type HashFn func(input []byte) []byte

const hSeedSize = int8(32)
const hRoundSize = int8(1)
const hPositionWindowSize = int8(4)
const hPivotViewSize = hSeedSize + hRoundSize
const hTotalSize = hSeedSize + hRoundSize + hPositionWindowSize

// To make it completely clear, the memory layout of the hash input is:
//
//   | seed (32 bytes)  | round (1 byte) | position window (4 bytes) |
//   | <------ pivot hash input ------>  | (ignored for pivot hash)  |
//   | <-------------------- source hash input --------------------> |

/*
Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1` with ``seed`` as entropy.

Utilizes 'swap or not' shuffling found in
https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf
See the 'generalized domain' algorithm on page 3.

Eth 2.0 spec implementation here:
https://github.com/ethereum/eth2.0-specs/blob/dev/specs/core/0_beacon-chain.md#get_permuted_index
*/

// PermuteIndex shuffles an individual list item without allocating a complete list.
// It returns the index of that item in the would-be shuffled list.
func PermuteIndex(hashFn HashFn, rounds uint8, index uint64, listSize uint64, seed [32]byte) uint64 {
	return innerPermuteIndex(hashFn, rounds, index, listSize, seed, true)
}

// UnpermuteIndex is the inverse of PermuteIndex: given the same shuffling parameters and a permuted index,
// it returns the original index.
func UnpermuteIndex(hashFn HashFn, rounds uint8, index uint64, listSize uint64, seed [32]byte) uint64 {
	return innerPermuteIndex(hashFn, rounds, index, listSize, seed, false)
}

func innerPermuteIndex(hashFn HashFn, rounds uint8, index uint64, listSize uint64, seed [32]byte, dir bool) uint64 {
	if rounds == 0 {
		return index
	}
	buf := make([]byte, hTotalSize, hTotalSize)
	r := uint8(0)
	if !dir {
		// Start at the last round.
		// Iterating through the rounds in reverse un-swaps everything, effectively un-shuffling the list.
		r = rounds - 1
	}
	// The seed is always the first 32 bytes of the hash input; this part of the buffer never changes.
	copy(buf[:hSeedSize], seed[:])
	for {
		// spec: pivot = bytes_to_int(hash(seed + int_to_bytes1(round))[0:8]) % list_size
		// This is the "int_to_bytes1(round)", appended to the seed.
		buf[hSeedSize] = r
		// The seed is already in place: hash the pivot view of the buffer, take a uint64 from the result,
		// and reduce it modulo listSize to get a pivot within range.
		pivot := binary.LittleEndian.Uint64(hashFn(buf[:hPivotViewSize])[:8]) % listSize
		// spec: flip = (pivot - index) % list_size
		// Add an extra listSize to prevent underflow.
		// "flip" is the other side of the pair.
		flip := (pivot + (listSize - index)) % listSize
		// spec: position = max(index, flip)
		// Why? Don't do double work: we consider every pair only once.
		// (Otherwise we would swap it back in place.)
		// Pick the highest index of the pair as the position to retrieve randomness with.
		position := index
		if flip > position {
			position = flip
		}
		// spec: source = hash(seed + int_to_bytes1(round) + int_to_bytes4(position // 256))
		// - seed is still in 0:32 (excl., 32 bytes)
		// - round number is still in 32
		// - mix in the position for randomness, except its lowest byte,
		//   which is used later to select a bit from the resulting hash.
		binary.LittleEndian.PutUint32(buf[hPivotViewSize:], uint32(position>>8))
		source := hashFn(buf)
		// spec: byte = source[(position % 256) // 8]
		// Keep the upper 5 bits of the position's lowest byte,
		// and use them to select one of the 32 (= 2^5) bytes of the hash.
		byteV := source[(position&0xff)>>3]
		// Use the lowest 3 bits of the position to select one of the 8 (= 2^3) bits of that hash byte.
		// spec: bit = (byte >> (position % 8)) % 2
		bitV := (byteV >> (position & 0x7)) & 0x1
		// Now that we have our "coin flip", swap the index, or don't.
		// If bitV is set, flip.
		if bitV == 1 {
			index = flip
		}
		// go forwards?
		if dir {
			// -> shuffle
			r++
			if r == rounds {
				break
			}
		} else {
			if r == 0 {
				break
			}
			// -> un-shuffle
			r--
		}
	}
	return index
}

/*

def shuffle(list_size, seed):
    indices = list(range(list_size))
    for round in range(90):
        hash_bytes = b''.join([
            hash(seed + round.to_bytes(1, 'little') + (i).to_bytes(4, 'little'))
            for i in range((list_size + 255) // 256)
        ])
        pivot = int.from_bytes(hash(seed + round.to_bytes(1, 'little')), 'little') % list_size

        powers_of_two = [1, 2, 4, 8, 16, 32, 64, 128]

        for i, index in enumerate(indices):
            flip = (pivot - index) % list_size
            hash_pos = index if index > flip else flip
            byte = hash_bytes[hash_pos // 8]
            if byte & powers_of_two[hash_pos % 8]:
                indices[i] = flip
    return indices

Heavily optimized version of the set-shuffling algorithm proposed by Vitalik to shuffle all items in a list together.

Original here:
https://github.com/ethereum/eth2.0-specs/pull/576#issue-250741806

Main differences, implemented by @protolambda:
- The user can supply the input slice to shuffle; simply provide [0,1,2,3,4, ...] to get a list of cleanly shuffled indices.
- The input slice is shuffled in place (hence no return value); no new array is allocated.
- Allocations are as small as possible: only a very small buffer for hashing
  (this should be allocated on the stack; the compiler will find it with escape analysis).
  This is not bigger than what's used for shuffling a single index!
  As opposed to the larger allocations (size O(n) instead of O(1)) made in the original.
- Replaced pseudocode/python workarounds with bit logic.
- The user can provide their own hash function (as long as it outputs a 32-byte slice).

*/

// ShuffleList shuffles the list in place.
func ShuffleList(hashFn HashFn, input []uint64, rounds uint8, seed [32]byte) {
	innerShuffleList(hashFn, input, rounds, seed, true)
}

// UnshuffleList un-shuffles the list in place; it is the inverse of ShuffleList with the same parameters.
func UnshuffleList(hashFn HashFn, input []uint64, rounds uint8, seed [32]byte) {
	innerShuffleList(hashFn, input, rounds, seed, false)
}

// Shuffles or unshuffles, depending on `dir` (true for shuffling, false for unshuffling).
func innerShuffleList(hashFn HashFn, input []uint64, rounds uint8, seed [32]byte, dir bool) {
	if len(input) <= 1 {
		// nothing to (un)shuffle
		return
	}
	if rounds == 0 {
		return
	}
	listSize := uint64(len(input))
	buf := make([]byte, hTotalSize, hTotalSize)
	r := uint8(0)
	if !dir {
		// Start at the last round.
		// Iterating through the rounds in reverse un-swaps everything, effectively un-shuffling the list.
		r = rounds - 1
	}
	// The seed is always the first 32 bytes of the hash input; this part of the buffer never changes.
	copy(buf[:hSeedSize], seed[:])
	for {
		// spec: pivot = bytes_to_int(hash(seed + int_to_bytes1(round))[0:8]) % list_size
		// This is the "int_to_bytes1(round)", appended to the seed.
		buf[hSeedSize] = r
		// The seed is already in place: hash the pivot view of the buffer, take a uint64 from the result,
		// and reduce it modulo listSize to get a pivot within range.
		pivot := binary.LittleEndian.Uint64(hashFn(buf[:hPivotViewSize])[:8]) % listSize

		// Split the loop in two:
		// 1. Handle the part from 0 (incl.) to pivot (incl.). This is mirrored around (pivot / 2).
		// 2. Handle the part from pivot (excl.) to N (excl.). This is mirrored around ((pivot / 2) + (size / 2)).
		// The pivot splits the array, and each part mirrors its data within itself.
		// Print out some example even/odd sized index lists with some even/odd pivots,
		// and you can deduce exactly how the mirroring works.
		// Note that the mirror is strict enough not to consider swapping the index @mirror with itself.
		mirror := (pivot + 1) >> 1
		// Since we iterate through the "positions" in order, we only need to recompute the hash every 256th position.
		// There is no need to pre-compute every possible hash for efficiency like in the example code.
		// We only need the hashes consecutively (we go through the positions in reverse order, but the same idea applies).
		//
		// spec: source = hash(seed + int_to_bytes1(round) + int_to_bytes4(position // 256))
		// - seed is still in 0:32 (excl., 32 bytes)
		// - round number is still in 32
		// - mix in the position for randomness, except its lowest byte,
		//   which is used later to select a bit from the resulting hash.
		// We start from the pivot position, and work back to the mirror position (of the part left of the pivot).
		// This makes us process each pair exactly once (instead of unnecessarily twice, like in the spec).
		binary.LittleEndian.PutUint32(buf[hPivotViewSize:], uint32(pivot>>8))
		source := hashFn(buf)
		byteV := source[(pivot&0xff)>>3]
		for i, j := uint64(0), pivot; i < mirror; i, j = i+1, j-1 {
			// The pair is i,j, with j being the bigger of the two, hence the "position" identifier of the pair.
			// Every 256th bit (aligned to j).
			if j&0xff == 0xff {
				// just overwrite the last part of the buffer, reuse the start (seed, round)
				binary.LittleEndian.PutUint32(buf[hPivotViewSize:], uint32(j>>8))
				source = hashFn(buf)
			}
			// Same trick with byte retrieval. Only every 8th.
			if j&0x7 == 0x7 {
				byteV = source[(j&0xff)>>3]
			}
			bitV := (byteV >> (j & 0x7)) & 0x1

			if bitV == 1 {
				// swap the pair items
				input[i], input[j] = input[j], input[i]
			}
		}
		// Now repeat, but for the part after the pivot.
		mirror = (pivot + listSize + 1) >> 1
		end := listSize - 1
		// Again, the seed and round input are in place; just update the position.
		// We start at the end, and work back to the mirror point.
		// This makes us process each pair exactly once (instead of unnecessarily twice, like in the spec).
		binary.LittleEndian.PutUint32(buf[hPivotViewSize:], uint32(end>>8))
		source = hashFn(buf)
		byteV = source[(end&0xff)>>3]
		for i, j := pivot+1, end; i < mirror; i, j = i+1, j-1 {
			// Exact same thing (copy of above loop body)
			//--------------------------------------------
			// The pair is i,j, with j being the bigger of the two, hence the "position" identifier of the pair.
			// Every 256th bit (aligned to j).
			if j&0xff == 0xff {
				// just overwrite the last part of the buffer, reuse the start (seed, round)
				binary.LittleEndian.PutUint32(buf[hPivotViewSize:], uint32(j>>8))
				source = hashFn(buf)
			}
			// Same trick with byte retrieval. Only every 8th.
			if j&0x7 == 0x7 {
				byteV = source[(j&0xff)>>3]
			}
			bitV := (byteV >> (j & 0x7)) & 0x1

			if bitV == 1 {
				// swap the pair items
				input[i], input[j] = input[j], input[i]
			}
			//--------------------------------------------
		}
		// go forwards?
		if dir {
			// -> shuffle
			r++
			if r == rounds {
				break
			}
		} else {
			if r == 0 {
				break
			}
			// -> un-shuffle
			r--
		}
	}
}

--------------------------------------------------------------------------------
/shuffle_bench_test.go:
--------------------------------------------------------------------------------
package eth2_shuffle

import (
	"fmt"
	"testing"
)

func BenchmarkPermuteIndex(b *testing.B) {
	listSizes := []uint64{4000000, 40000, 400}

	hashFn := getStandardHashFn()
	// "random" seed for testing. Can be any 32 bytes.
	seed := [32]byte{123, 42}

	// rounds of shuffling, constant in spec
	rounds := uint8(90)

	for _, listSize := range listSizes {
		// benchmark!
		b.Run(fmt.Sprintf("PermuteIndex_%d", listSize), func(ib *testing.B) {
			for i := uint64(0); i < uint64(ib.N); i++ {
				PermuteIndex(hashFn, rounds, i%listSize, listSize, seed)
			}
		})
	}
}

func BenchmarkIndexComparison(b *testing.B) {
	// 4M is just too inefficient to even start comparing.
	listSizes := []uint64{40000, 400}

	hashFn := getStandardHashFn()
	// "random" seed for testing. Can be any 32 bytes.
	seed := [32]byte{123, 42}

	// rounds of shuffling, constant in spec
	rounds := uint8(90)

	for _, listSize := range listSizes {
		// benchmark!
		b.Run(fmt.Sprintf("Indexwise_ShuffleList_%d", listSize), func(ib *testing.B) {
			for i := 0; i < ib.N; i++ {
				// Simulate a list shuffle by running PermuteIndex listSize times.
				for j := uint64(0); j < listSize; j++ {
					PermuteIndex(hashFn, rounds, j, listSize, seed)
				}
			}
		})
	}
}

func BenchmarkShuffleList(b *testing.B) {
	listSizes := []uint64{4000000, 40000, 400}

	hashFn := getStandardHashFn()
	// "random" seed for testing. Can be any 32 bytes.
	seed := [32]byte{123, 42}

	// rounds of shuffling, constant in spec
	rounds := uint8(90)

	for _, listSize := range listSizes {
		// list to test
		testIndices := make([]uint64, listSize, listSize)
		// fill
		for i := uint64(0); i < listSize; i++ {
			testIndices[i] = i
		}
		// benchmark!
		b.Run(fmt.Sprintf("ShuffleList_%d", listSize), func(ib *testing.B) {
			for i := 0; i < ib.N; i++ {
				ShuffleList(hashFn, testIndices, rounds, seed)
			}
		})
	}
}

//// TODO optimize memory allocations even more by analysis of statistics
//func BenchmarkShuffleListWithAllocsReport(b *testing.B) {
//	b.ReportAllocs()
//	BenchmarkShuffleList(b)
//}

--------------------------------------------------------------------------------
/shuffle_test.go:
--------------------------------------------------------------------------------
package eth2_shuffle

import (
	"encoding/csv"
	"encoding/hex"
	"fmt"
	"os"
	"strconv"
	"strings"
	"testing"
)

// readEncodedListInput parses a colon-separated list of unsigned integers and checks its length.
func readEncodedListInput(input string, requiredLen int64, lineIndex int) ([]uint64, error) {
	var itemStrs []string
	if input != "" {
		itemStrs = strings.Split(input, ":")
	} else {
		itemStrs = make([]string, 0)
	}
	if int64(len(itemStrs)) != requiredLen {
		return nil, fmt.Errorf("list length does not match list size on line %d", lineIndex)
	}
	items := make([]uint64, len(itemStrs), len(itemStrs))
	for i, itemStr := range itemStrs {
		item, err := strconv.ParseUint(itemStr, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("list item %d on line %d cannot be parsed: %v", i, lineIndex, err)
		}
		items[i] = item
	}
	return items, nil
}

func TestAgainstSpec(t *testing.T) {
	// Open CSV file
	f, err := os.Open("spec/tests.csv")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	// Read the whole file into memory
	lines, err := csv.NewReader(f).ReadAll()
	if err != nil {
		t.Fatal(err)
	}

	// constant in spec
	rounds := uint8(90)

	// Loop through the lines and turn each one into a test case
	for lineIndex, line := range lines {
		parsedSeed, err := hex.DecodeString(line[0])
		if err != nil {
			t.Fatalf("seed on line %d cannot be parsed", lineIndex)
		}
		listSize, err := strconv.ParseInt(line[1], 10, 32)
		if err != nil {
			t.Fatalf("list size on line %d cannot be parsed", lineIndex)
		}
		inputItems, err := readEncodedListInput(line[2], listSize, lineIndex)
		if err != nil {
			t.Fatal(err)
		}
		expectedItems, err := readEncodedListInput(line[3], listSize, lineIndex)
		if err != nil {
			t.Fatal(err)
		}

		t.Run("", func(listSize uint64, shuffleIn []uint64, shuffleOut []uint64) func(st *testing.T) {
			return func(st *testing.T) {
				seed := [32]byte{}
				copy(seed[:], parsedSeed)
				// Run every test case in parallel. The per-line data is declared inside the loop body
				// or passed in as arguments, so later iterations won't overwrite it.
				st.Parallel()

				hashFn := getStandardHashFn()

				st.Run("PermuteIndex", func(it *testing.T) {
					for i := uint64(0); i < listSize; i++ {
						// calculate the permuted index (i.e. shuffle a single index)
						permuted := PermuteIndex(hashFn, rounds, i, listSize, seed)
						// compare with expectation
						if shuffleIn[i] != shuffleOut[permuted] {
							it.FailNow()
						}
					}
				})

				st.Run("UnpermuteIndex", func(it *testing.T) {
					// for each index, test un-permuting
					for i := uint64(0); i < listSize; i++ {
						// calculate the un-permuted index (i.e. un-shuffle a single index)
						unpermuted := UnpermuteIndex(hashFn, rounds, i, listSize, seed)
						// compare with expectation
						if shuffleOut[i] != shuffleIn[unpermuted] {
							it.FailNow()
						}
					}
				})

				st.Run("ShuffleList", func(it *testing.T) {
					// create input, this slice will be shuffled.
					testInput := make([]uint64, listSize, listSize)
					copy(testInput, shuffleIn)
					// shuffle!
					ShuffleList(hashFn, testInput, rounds, seed)
					// compare shuffled list to expected output
					for i := uint64(0); i < listSize; i++ {
						if testInput[i] != shuffleOut[i] {
							it.FailNow()
						}
					}
				})

				st.Run("UnshuffleList", func(it *testing.T) {
					// create input, this slice will be un-shuffled.
					testInput := make([]uint64, listSize, listSize)
					copy(testInput, shuffleOut)
					// un-shuffle!
					UnshuffleList(hashFn, testInput, rounds, seed)
					// compare un-shuffled list to original input
					for i := uint64(0); i < listSize; i++ {
						if testInput[i] != shuffleIn[i] {
							it.FailNow()
						}
					}
				})
			}
		}(uint64(listSize), inputItems, expectedItems))
	}
}

--------------------------------------------------------------------------------
/spec/shuffle_test_gen.py:
--------------------------------------------------------------------------------
import binascii
import csv
import random
from hashlib import sha256

SHUFFLE_ROUND_COUNT = 90


def bytes_to_int(data: bytes) -> int:
    return int.from_bytes(data, 'little')


def int_to_bytes1(x):
    return x.to_bytes(1, 'little')


def int_to_bytes4(x):
    return x.to_bytes(4, 'little')


def hash(data: bytes) -> bytes:
    return sha256(data).digest()


def get_permuted_index(index: int, list_size: int, seed: bytes) -> int:
    """
    Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1` with ``seed`` as entropy.

    Utilizes 'swap or not' shuffling found in
    https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf
    See the 'generalized domain' algorithm on page 3.
    """
    assert index < list_size
    assert list_size <= 2 ** 40

    for round in range(SHUFFLE_ROUND_COUNT):
        pivot = bytes_to_int(hash(seed + int_to_bytes1(round))[0:8]) % list_size
        flip = (pivot - index) % list_size
        position = max(index, flip)
        source = hash(seed + int_to_bytes1(round) + int_to_bytes4(position // 256))
        byte = source[(position % 256) // 8]
        bit = (byte >> (position % 8)) % 2
        index = flip if bit else index

    return index


# newline='' is what the csv module documentation recommends for files passed to csv.writer.
with open('tests.csv', mode='w', newline='') as tests_file:
    tests_writer = csv.writer(tests_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)

    for seed in [hash(int_to_bytes4(seed_init_value)) for seed_init_value in range(30)]:
        for list_size in [0, 1, 2, 3, 5, 10, 100, 1000]:
            start_list = [i for i in range(list_size)]
            # Random input, using Python's shuffle. The seed is static here; we just want consistent test generation.
            # Checking the shuffling on a simple incremental list is not good enough,
            # i.e. we want to verify that the shuffle is independent of the contents of the list.
            random.seed(123)
            random.shuffle(start_list)
            encoded_start = ":".join([str(x) for x in start_list])
            shuffling = [0 for _ in range(list_size)]
            for i in range(list_size):
                shuffling[get_permuted_index(i, list_size, seed)] = i
            end_list = [start_list[x] for x in shuffling]
            encoded_shuffled = ":".join([str(v) for v in end_list])
            tests_writer.writerow([binascii.hexlify(seed).decode("utf-8"), list_size, encoded_start, encoded_shuffled])

--------------------------------------------------------------------------------
/test_util.go:
--------------------------------------------------------------------------------
package eth2_shuffle

import "crypto/sha256"

// getStandardHashFn returns a SHA-256 HashFn. It reuses a single hash.Hash,
// so the returned function is not safe for concurrent use; create one per goroutine.
func getStandardHashFn() HashFn {
	hash := sha256.New()
	hashFn := func(in []byte) []byte {
		hash.Reset()
		hash.Write(in)
		return hash.Sum(nil)
	}
	return hashFn
}

--------------------------------------------------------------------------------