├── .gitignore ├── .travis.yml ├── Gopkg.lock ├── Gopkg.toml ├── LICENSE ├── README.md ├── brute.go ├── brute_test.go ├── cidranger.go ├── cidranger_test.go ├── example └── custom-ranger-asn.go ├── go.mod ├── go.sum ├── net ├── ip.go └── ip_test.go ├── testdata └── aws_ip_ranges.json ├── trie.go ├── trie_test.go └── version.go /.gitignore: -------------------------------------------------------------------------------- 1 | vendor 2 | .idea -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.13.x 4 | - 1.14.x 5 | - tip 6 | before_install: 7 | - travis_retry go get github.com/mattn/goveralls 8 | script: 9 | - go test -v -covermode=count -coverprofile=coverage.out ./... 10 | - travis_retry $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci 11 | -------------------------------------------------------------------------------- /Gopkg.lock: -------------------------------------------------------------------------------- 1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
2 | 3 | 4 | [[projects]] 5 | digest = "1:a2c1d0e43bd3baaa071d1b9ed72c27d78169b2b269f71c105ac4ba34b1be4a39" 6 | name = "github.com/davecgh/go-spew" 7 | packages = ["spew"] 8 | pruneopts = "UT" 9 | revision = "346938d642f2ec3594ed81d874461961cd0faa76" 10 | version = "v1.1.0" 11 | 12 | [[projects]] 13 | digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" 14 | name = "github.com/pmezard/go-difflib" 15 | packages = ["difflib"] 16 | pruneopts = "UT" 17 | revision = "792786c7400a136282c1664665ae0a8db921c6c2" 18 | version = "v1.0.0" 19 | 20 | [[projects]] 21 | digest = "1:f85e109eda8f6080877185d1c39e98dd8795e1780c08beca28304b87fd855a1c" 22 | name = "github.com/stretchr/testify" 23 | packages = ["assert"] 24 | pruneopts = "UT" 25 | revision = "12b6f73e6084dad08a7c6e575284b177ecafbc71" 26 | version = "v1.2.1" 27 | 28 | [solve-meta] 29 | analyzer-name = "dep" 30 | analyzer-version = 1 31 | input-imports = ["github.com/stretchr/testify/assert"] 32 | solver-name = "gps-cdcl" 33 | solver-version = 1 34 | -------------------------------------------------------------------------------- /Gopkg.toml: -------------------------------------------------------------------------------- 1 | # Gopkg.toml example 2 | # 3 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 4 | # for detailed Gopkg.toml documentation. 
5 | # 6 | # required = ["github.com/user/thing/cmd/thing"] 7 | # ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] 8 | # 9 | # [[constraint]] 10 | # name = "github.com/user/project" 11 | # version = "1.0.0" 12 | # 13 | # [[constraint]] 14 | # name = "github.com/user/project2" 15 | # branch = "dev" 16 | # source = "github.com/myfork/project2" 17 | # 18 | # [[override]] 19 | # name = "github.com/x/y" 20 | # version = "2.4.0" 21 | # 22 | # [prune] 23 | # non-go = false 24 | # go-tests = true 25 | # unused-packages = true 26 | 27 | 28 | [[constraint]] 29 | name = "github.com/stretchr/testify" 30 | version = "1.2.1" 31 | 32 | [prune] 33 | go-tests = true 34 | unused-packages = true 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Yulin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cidranger 2 | Fast IP to CIDR block(s) lookup using a trie in Golang, inspired by [IPv4 route lookup linux](https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux). Possible use cases include detecting if an IP address is from published cloud provider CIDR blocks (e.g. 52.95.110.1 is contained in published AWS Route53 CIDR 52.95.110.0/24), IP routing rules, etc. 3 | 4 | [![GoDoc Reference](https://img.shields.io/badge/godoc-reference-5272B4.svg?style=flat-square)](https://godoc.org/github.com/yl2chen/cidranger) 5 | [![Build Status](https://img.shields.io/travis/yl2chen/cidranger.svg?branch=master&style=flat-square)](https://travis-ci.org/yl2chen/cidranger) 6 | [![Coverage Status](https://img.shields.io/coveralls/yl2chen/cidranger.svg?branch=master&style=flat-square)](https://coveralls.io/github/yl2chen/cidranger?branch=master) 7 | [![Go Report Card](https://goreportcard.com/badge/github.com/yl2chen/cidranger?&style=flat-square)](https://goreportcard.com/report/github.com/yl2chen/cidranger) 8 | 9 | This is a visualization of a trie storing CIDR blocks `128.0.0.0/2` `192.0.0.0/2` `200.0.0.0/5` without path compression. The 0/1 number on the path indicates the bit value of the IP address at the specified bit position, hence the path from the root node to a child node represents a CIDR block that contains all IP ranges of its children, and children's children. 10 |

11 | 12 | Visualization of trie storing same CIDR blocks with path compression, improving both lookup speed and memory footprint. 13 |

14 | 15 | ## Getting Started 16 | Configure imports. 17 | ```go 18 | import ( 19 | "net" 20 | 21 | "github.com/yl2chen/cidranger" 22 | ) 23 | ``` 24 | Create a new ranger implemented using a path-compressed prefix trie. 25 | ```go 26 | ranger := cidranger.NewPCTrieRanger() 27 | ``` 28 | Insert CIDR blocks. 29 | ```go 30 | _, network1, _ := net.ParseCIDR("192.168.1.0/24") 31 | _, network2, _ := net.ParseCIDR("128.168.1.0/24") 32 | ranger.Insert(cidranger.NewBasicRangerEntry(*network1)) 33 | ranger.Insert(cidranger.NewBasicRangerEntry(*network2)) 34 | ``` 35 | To attach any additional value(s) to the entry, simply create a custom struct 36 | storing the desired value(s) that implements the RangerEntry interface: 37 | ```go 38 | type RangerEntry interface { 39 | Network() net.IPNet 40 | } 41 | ``` 42 | The prefix trie can be visualized as: 43 | ``` 44 | 0.0.0.0/0 (target_pos:31:has_entry:false) 45 | | 1--> 128.0.0.0/1 (target_pos:30:has_entry:false) 46 | | | 0--> 128.168.1.0/24 (target_pos:7:has_entry:true) 47 | | | 1--> 192.168.1.0/24 (target_pos:7:has_entry:true) 48 | ``` 49 | To test if a given IP is contained in the constructed ranger, 50 | ```go 51 | contains, err := ranger.Contains(net.ParseIP("128.168.1.0")) // returns true, nil 52 | contains, err = ranger.Contains(net.ParseIP("192.168.2.0")) // returns false, nil 53 | ``` 54 | To get all the networks the given IP is contained in, 55 | ```go 56 | containingNetworks, err := ranger.ContainingNetworks(net.ParseIP("128.168.1.0")) 57 | ``` 58 | To get all networks in the ranger, 59 | ```go 60 | entries, err := ranger.CoveredNetworks(*cidranger.AllIPv4) // for IPv4 61 | entries, err = ranger.CoveredNetworks(*cidranger.AllIPv6) // for IPv6 62 | ``` 63 | 64 | ## Benchmark 65 | Compare hit/miss cases for IPv4/IPv6 using PC trie vs brute force implementation. Ranger is initialized with published AWS IP ranges (889 IPv4 CIDR blocks and 360 IPv6) 66 | ```go 67 | // IPv4 lookup hit scenario 68 | BenchmarkPCTrieHitIPv4UsingAWSRanges-4 5000000 353 ns/op 69 |
BenchmarkBruteRangerHitIPv4UsingAWSRanges-4 100000 13719 ns/op 70 | 71 | // Ipv6 lookup hit scenario, counter-intuitively faster then IPv4 due to less IPv6 CIDR 72 | // blocks in the AWS dataset, hence the constructed trie has less path splits and depth. 73 | BenchmarkPCTrieHitIPv6UsingAWSRanges-4 10000000 143 ns/op 74 | BenchmarkBruteRangerHitIPv6UsingAWSRanges-4 300000 5178 ns/op 75 | 76 | // Ipv4 lookup miss scenario 77 | BenchmarkPCTrieMissIPv4UsingAWSRanges-4 20000000 96.5 ns/op 78 | BenchmarkBruteRangerMissIPv4UsingAWSRanges-4 50000 24781 ns/op 79 | 80 | // Ipv6 lookup miss scenario 81 | BenchmarkPCTrieHMissIPv6UsingAWSRanges-4 10000000 115 ns/op 82 | BenchmarkBruteRangerMissIPv6UsingAWSRanges-4 100000 10824 ns/op 83 | ``` 84 | 85 | ## Example of IPv6 trie: 86 | ``` 87 | ::/0 (target_pos:127:has_entry:false) 88 | | 0--> 2400::/14 (target_pos:113:has_entry:false) 89 | | | 0--> 2400:6400::/22 (target_pos:105:has_entry:false) 90 | | | | 0--> 2400:6500::/32 (target_pos:95:has_entry:false) 91 | | | | | 0--> 2400:6500::/39 (target_pos:88:has_entry:false) 92 | | | | | | 0--> 2400:6500:0:7000::/53 (target_pos:74:has_entry:false) 93 | | | | | | | 0--> 2400:6500:0:7000::/54 (target_pos:73:has_entry:false) 94 | | | | | | | | 0--> 2400:6500:0:7000::/55 (target_pos:72:has_entry:false) 95 | | | | | | | | | 0--> 2400:6500:0:7000::/56 (target_pos:71:has_entry:true) 96 | | | | | | | | | 1--> 2400:6500:0:7100::/56 (target_pos:71:has_entry:true) 97 | | | | | | | | 1--> 2400:6500:0:7200::/56 (target_pos:71:has_entry:true) 98 | | | | | | | 1--> 2400:6500:0:7400::/55 (target_pos:72:has_entry:false) 99 | | | | | | | | 0--> 2400:6500:0:7400::/56 (target_pos:71:has_entry:true) 100 | | | | | | | | 1--> 2400:6500:0:7500::/56 (target_pos:71:has_entry:true) 101 | | | | | | 1--> 2400:6500:100:7000::/54 (target_pos:73:has_entry:false) 102 | | | | | | | 0--> 2400:6500:100:7100::/56 (target_pos:71:has_entry:true) 103 | | | | | | | 1--> 2400:6500:100:7200::/56 (target_pos:71:has_entry:true) 
104 | | | | | 1--> 2400:6500:ff00::/64 (target_pos:63:has_entry:true) 105 | | | | 1--> 2400:6700:ff00::/64 (target_pos:63:has_entry:true) 106 | | | 1--> 2403:b300:ff00::/64 (target_pos:63:has_entry:true) 107 | ``` 108 | -------------------------------------------------------------------------------- /brute.go: -------------------------------------------------------------------------------- 1 | package cidranger 2 | 3 | import ( 4 | "net" 5 | 6 | rnet "github.com/yl2chen/cidranger/net" 7 | ) 8 | 9 | // bruteRanger is a brute force implementation of Ranger. Insertion and 10 | // deletion of networks is performed on an internal storage in the form of 11 | // map[string]RangerEntry, one map per IP version (constant time operations). However, inclusion tests are 12 | // always performed linearly at no guaranteed traversal order of recorded networks, 13 | // so one can assume a worst case performance of O(N). The performance can be 14 | // boosted many ways, e.g. changing usage of net.IPNet.Contains() to using masked 15 | // bits equality checking, but the main purpose of this implementation is for 16 | // testing because the correctness of this implementation can be easily guaranteed, 17 | // and used as the ground truth when running a wider range of 'random' tests on 18 | // other more sophisticated implementations. 19 | type bruteRanger struct { 20 | ipV4Entries map[string]RangerEntry // IPv4 entries keyed by network.String() 21 | ipV6Entries map[string]RangerEntry // IPv6 entries keyed by network.String() 22 | } 23 | 24 | // newBruteRanger returns a new, empty bruteRanger as a Ranger. 25 | func newBruteRanger() Ranger { 26 | return &bruteRanger{ 27 | ipV4Entries: make(map[string]RangerEntry), 28 | ipV6Entries: make(map[string]RangerEntry), 29 | } 30 | } 31 | 32 | // Insert inserts a RangerEntry into ranger.
33 | func (b *bruteRanger) Insert(entry RangerEntry) error { 34 | network := entry.Network() 35 | key := network.String() 36 | entries, err := b.getEntriesByVersion(network.IP) 37 | if err != nil { 38 | return err 39 | } 40 | if _, found := entries[key]; !found { // check the map for this entry's IP version (previously always ipV4Entries); an existing entry for the same network is kept 41 | entries[key] = entry 42 | } 43 | return nil 44 | } 45 | 46 | // Remove removes the RangerEntry identified by the given network from ranger and returns it; it returns nil, nil if the network is not present. 47 | func (b *bruteRanger) Remove(network net.IPNet) (RangerEntry, error) { 48 | networks, err := b.getEntriesByVersion(network.IP) 49 | if err != nil { 50 | return nil, err 51 | } 52 | key := network.String() 53 | if networkToDelete, found := networks[key]; found { 54 | delete(networks, key) 55 | return networkToDelete, nil 56 | } 57 | return nil, nil 58 | } 59 | 60 | // Contains reports whether the given ip is contained by any 61 | // network in ranger. 62 | func (b *bruteRanger) Contains(ip net.IP) (bool, error) { 63 | entries, err := b.getEntriesByVersion(ip) 64 | if err != nil { 65 | return false, err 66 | } 67 | for _, entry := range entries { 68 | network := entry.Network() 69 | if network.Contains(ip) { 70 | return true, nil 71 | } 72 | } 73 | return false, nil 74 | } 75 | 76 | // ContainingNetworks returns all RangerEntry(s) whose network contains the given ip. 77 | func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { 78 | entries, err := b.getEntriesByVersion(ip) 79 | if err != nil { 80 | return nil, err 81 | } 82 | results := []RangerEntry{} 83 | for _, entry := range entries { 84 | network := entry.Network() 85 | if network.Contains(ip) { 86 | results = append(results, entry) 87 | } 88 | } 89 | return results, nil 90 | } 91 | 92 | // CoveredNetworks returns the list of RangerEntry(s) the given ipnet 93 | // covers. That is, the networks that are completely subsumed by the 94 | // specified network.
95 | func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { 96 | entries, err := b.getEntriesByVersion(network.IP) 97 | if err != nil { 98 | return nil, err 99 | } 100 | var results []RangerEntry 101 | testNetwork := rnet.NewNetwork(network) 102 | for _, entry := range entries { 103 | entryNetwork := rnet.NewNetwork(entry.Network()) 104 | if testNetwork.Covers(entryNetwork) { 105 | results = append(results, entry) 106 | } 107 | } 108 | return results, nil 109 | } 110 | 111 | // Len returns the number of networks in ranger. 112 | func (b *bruteRanger) Len() int { 113 | return len(b.ipV4Entries) + len(b.ipV6Entries) 114 | } 115 | 116 | func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) { 117 | if ip.To4() != nil { 118 | return b.ipV4Entries, nil 119 | } 120 | if ip.To16() != nil { 121 | return b.ipV6Entries, nil 122 | } 123 | return nil, ErrInvalidNetworkInput 124 | } 125 | -------------------------------------------------------------------------------- /brute_test.go: -------------------------------------------------------------------------------- 1 | package cidranger 2 | 3 | import ( 4 | "net" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestInsert(t *testing.T) { 12 | ranger := newBruteRanger().(*bruteRanger) 13 | _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24") 14 | _, networkIPv6, _ := net.ParseCIDR("8000::/96") 15 | entryIPv4 := NewBasicRangerEntry(*networkIPv4) 16 | entryIPv6 := NewBasicRangerEntry(*networkIPv6) 17 | 18 | assert.NoError(t, ranger.Insert(entryIPv4)) 19 | assert.NoError(t, ranger.Insert(entryIPv6)) 20 | 21 | assert.Equal(t, 1, len(ranger.ipV4Entries)) 22 | assert.Equal(t, entryIPv4, ranger.ipV4Entries["0.0.1.0/24"]) 23 | assert.Equal(t, 1, len(ranger.ipV6Entries)) 24 | assert.Equal(t, entryIPv6, ranger.ipV6Entries["8000::/96"]) 25 | } 26 | 27 | func TestInsertError(t *testing.T) { 28 | bRanger := newBruteRanger().(*bruteRanger) 29 | _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
30 | networkIPv4.IP = append(networkIPv4.IP, byte(4)) 31 | err := bRanger.Insert(NewBasicRangerEntry(*networkIPv4)) 32 | assert.Equal(t, ErrInvalidNetworkInput, err) 33 | } 34 | 35 | func TestRemove(t *testing.T) { 36 | ranger := newBruteRanger().(*bruteRanger) 37 | _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24") 38 | _, networkIPv6, _ := net.ParseCIDR("8000::/96") 39 | _, notInserted, _ := net.ParseCIDR("9000::/96") 40 | 41 | insertIPv4 := NewBasicRangerEntry(*networkIPv4) 42 | insertIPv6 := NewBasicRangerEntry(*networkIPv6) 43 | 44 | assert.NoError(t, ranger.Insert(insertIPv4)) 45 | deletedIPv4, err := ranger.Remove(*networkIPv4) 46 | assert.NoError(t, err) 47 | 48 | assert.NoError(t, ranger.Insert(insertIPv6)) 49 | deletedIPv6, err := ranger.Remove(*networkIPv6) 50 | assert.NoError(t, err) 51 | 52 | entry, err := ranger.Remove(*notInserted) 53 | assert.NoError(t, err) 54 | assert.Nil(t, entry) 55 | 56 | assert.Equal(t, insertIPv4, deletedIPv4) 57 | assert.Equal(t, 0, len(ranger.ipV4Entries)) 58 | assert.Equal(t, insertIPv6, deletedIPv6) 59 | assert.Equal(t, 0, len(ranger.ipV6Entries)) 60 | } 61 | 62 | func TestRemoveError(t *testing.T) { 63 | r := newBruteRanger().(*bruteRanger) 64 | _, invalidNetwork, _ := net.ParseCIDR("0.0.1.0/24") 65 | invalidNetwork.IP = append(invalidNetwork.IP, byte(4)) 66 | 67 | _, err := r.Remove(*invalidNetwork) 68 | assert.Equal(t, ErrInvalidNetworkInput, err) 69 | } 70 | 71 | func TestContains(t *testing.T) { 72 | r := newBruteRanger().(*bruteRanger) 73 | _, network, _ := net.ParseCIDR("0.0.1.0/24") 74 | _, network1, _ := net.ParseCIDR("8000::/112") 75 | assert.NoError(t, r.Insert(NewBasicRangerEntry(*network))) 76 | assert.NoError(t, r.Insert(NewBasicRangerEntry(*network1))) 77 | 78 | cases := []struct { 79 | ip net.IP 80 | contains bool 81 | err error 82 | name string 83 | }{ 84 | {net.ParseIP("0.0.1.255"), true, nil, "IPv4 should contain"}, 85 | {net.ParseIP("0.0.0.255"), false, nil, "IPv4 shouldn't contain"}, 86 | {net.ParseIP("8000::ffff"), true, nil, "IPv6 should contain"}, 87 |
{net.ParseIP("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"}, 88 | {append(net.ParseIP("8000::1:ffff"), byte(0)), false, ErrInvalidNetworkInput, "Invalid IP"}, 89 | } 90 | 91 | for _, tc := range cases { 92 | t.Run(tc.name, func(t *testing.T) { 93 | contains, err := r.Contains(tc.ip) 94 | if tc.err != nil { 95 | assert.Equal(t, tc.err, err) 96 | } else { 97 | assert.NoError(t, err) 98 | assert.Equal(t, tc.contains, contains) 99 | } 100 | }) 101 | } 102 | } 103 | 104 | func TestContainingNetworks(t *testing.T) { 105 | r := newBruteRanger().(*bruteRanger) 106 | _, network1, _ := net.ParseCIDR("0.0.1.0/24") 107 | _, network2, _ := net.ParseCIDR("0.0.1.0/25") 108 | _, network3, _ := net.ParseCIDR("8000::/112") 109 | _, network4, _ := net.ParseCIDR("8000::/113") 110 | entry1 := NewBasicRangerEntry(*network1) 111 | entry2 := NewBasicRangerEntry(*network2) 112 | entry3 := NewBasicRangerEntry(*network3) 113 | entry4 := NewBasicRangerEntry(*network4) 114 | assert.NoError(t, r.Insert(entry1)) 115 | assert.NoError(t, r.Insert(entry2)) 116 | assert.NoError(t, r.Insert(entry3)) 117 | assert.NoError(t, r.Insert(entry4)) 118 | cases := []struct { 119 | ip net.IP 120 | containingNetworks []RangerEntry 121 | err error 122 | name string 123 | }{ 124 | {net.ParseIP("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"}, 125 | {net.ParseIP("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"}, 126 | {net.ParseIP("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"}, 127 | {net.ParseIP("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should contain"}, 128 | {net.ParseIP("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"}, 129 | {net.ParseIP("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"}, 130 | {append(net.ParseIP("8000::1:7fff"), byte(0)), nil, ErrInvalidNetworkInput, "Invalid IP"}, 131 | } 132 | 133 | for _, tc := range cases { 134 | t.Run(tc.name, func(t *testing.T) { 135 | networks, err := r.ContainingNetworks(tc.ip) 136 | if tc.err != nil { 137 |
assert.Equal(t, tc.err, err) 138 | } else { 139 | assert.NoError(t, err) 140 | assert.Equal(t, len(tc.containingNetworks), len(networks)) 141 | for _, network := range tc.containingNetworks { 142 | assert.Contains(t, networks, network) 143 | } 144 | } 145 | }) 146 | } 147 | } 148 | 149 | func TestCoveredNetworks(t *testing.T) { 150 | for _, tc := range coveredNetworkTests { 151 | t.Run(tc.name, func(t *testing.T) { 152 | ranger := newBruteRanger() 153 | for _, insert := range tc.inserts { 154 | _, network, _ := net.ParseCIDR(insert) 155 | err := ranger.Insert(NewBasicRangerEntry(*network)) 156 | assert.NoError(t, err) 157 | } 158 | var expectedEntries []string 159 | for _, network := range tc.networks { 160 | expectedEntries = append(expectedEntries, network) 161 | } 162 | sort.Strings(expectedEntries) 163 | _, snet, _ := net.ParseCIDR(tc.search) 164 | networks, err := ranger.CoveredNetworks(*snet) 165 | assert.NoError(t, err) 166 | 167 | var results []string 168 | for _, result := range networks { 169 | network := result.Network() 170 | results = append(results, network.String()) 171 | } 172 | sort.Strings(results) 173 | 174 | assert.Equal(t, expectedEntries, results) 175 | }) 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /cidranger.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package cidranger provides utilities to store CIDR blocks and perform IP 3 | inclusion tests against them.
4 | 5 | To create a new instance of the path-compressed trie: 6 | 7 | ranger := NewPCTrieRanger() 8 | 9 | To insert or remove an entry (any object that satisfies the RangerEntry 10 | interface): 11 | 12 | _, network, _ := net.ParseCIDR("192.168.0.0/24") 13 | ranger.Insert(NewBasicRangerEntry(*network)) 14 | ranger.Remove(*network) 15 | 16 | To attach any value to the entry, simply 17 | create a custom struct that satisfies the RangerEntry interface: 18 | 19 | type RangerEntry interface { 20 | Network() net.IPNet 21 | } 22 | 23 | To test whether an IP is contained in the constructed ranger: 24 | 25 | // returns bool, error 26 | containsBool, err := ranger.Contains(net.ParseIP("192.168.0.1")) 27 | 28 | To get a list of CIDR blocks in the constructed ranger that contain an IP: 29 | 30 | // returns []RangerEntry, error 31 | entries, err := ranger.ContainingNetworks(net.ParseIP("192.168.0.1")) 32 | 33 | To get a list of all IPv4/IPv6 networks respectively: 34 | 35 | // returns []RangerEntry, error 36 | entries, err := ranger.CoveredNetworks(*AllIPv4) 37 | entries, err := ranger.CoveredNetworks(*AllIPv6) 38 | 39 | */ 40 | package cidranger 41 | 42 | import ( 43 | "fmt" 44 | "net" 45 | ) 46 | 47 | // ErrInvalidNetworkInput is returned upon invalid network input. 48 | var ErrInvalidNetworkInput = fmt.Errorf("Invalid network input") 49 | 50 | // ErrInvalidNetworkNumberInput is returned upon invalid network number input. 51 | var ErrInvalidNetworkNumberInput = fmt.Errorf("Invalid network number input") 52 | 53 | // AllIPv4 is an IPv4 CIDR that contains all networks. 54 | var AllIPv4 = parseCIDRUnsafe("0.0.0.0/0") 55 | 56 | // AllIPv6 is an IPv6 CIDR that contains all networks. 57 | var AllIPv6 = parseCIDRUnsafe("0::0/0") 58 | 59 | func parseCIDRUnsafe(s string) *net.IPNet { 60 | _, cidr, _ := net.ParseCIDR(s) // error deliberately ignored: cidr is nil if s is not valid CIDR notation 61 | return cidr 62 | } 63 | 64 | // RangerEntry is the interface for entries that can be inserted into a Ranger.
65 | type RangerEntry interface { 66 | Network() net.IPNet 67 | } 68 | 69 | type basicRangerEntry struct { 70 | ipNet net.IPNet 71 | } 72 | 73 | func (b *basicRangerEntry) Network() net.IPNet { 74 | return b.ipNet 75 | } 76 | 77 | // NewBasicRangerEntry returns a basic RangerEntry that only stores the network 78 | // itself. 79 | func NewBasicRangerEntry(ipNet net.IPNet) RangerEntry { 80 | return &basicRangerEntry{ 81 | ipNet: ipNet, 82 | } 83 | } 84 | 85 | // Ranger is an interface for CIDR block containment lookups. 86 | type Ranger interface { 87 | Insert(entry RangerEntry) error 88 | Remove(network net.IPNet) (RangerEntry, error) 89 | Contains(ip net.IP) (bool, error) 90 | ContainingNetworks(ip net.IP) ([]RangerEntry, error) 91 | CoveredNetworks(network net.IPNet) ([]RangerEntry, error) 92 | Len() int 93 | } 94 | 95 | // NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6 96 | // using the path-compressed trie implementation. 97 | func NewPCTrieRanger() Ranger { 98 | return newVersionedRanger(newPrefixTree) 99 | } 100 | -------------------------------------------------------------------------------- /cidranger_test.go: -------------------------------------------------------------------------------- 1 | package cidranger 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | "math/rand" 7 | "net" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | rnet "github.com/yl2chen/cidranger/net" 13 | ) 14 | 15 | /* 16 | ****************************************************************** 17 | Test Contains/ContainingNetworks/CoveredNetworks against basic brute force ranger.
18 | ****************************************************************** 19 | */ 20 | 21 | func TestContainsAgainstBaseIPv4(t *testing.T) { 22 | testContainsAgainstBase(t, 100000, randIPv4Gen) 23 | } 24 | 25 | func TestContainingNetworksAgainstBaseIPv4(t *testing.T) { 26 | testContainingNetworksAgainstBase(t, 100000, randIPv4Gen) 27 | } 28 | 29 | func TestCoveredNetworksAgainstBaseIPv4(t *testing.T) { 30 | testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV4AWSRangesIPNets)) 31 | } 32 | 33 | // IPv6 spans an extremely large address space (2^128), so randomly generated IPs 34 | // will often fall outside of the test ranges (AWS public CIDR blocks); it is 35 | // more meaningful for testing to run from a curated list of IPv6 IPs. 36 | func TestContainsAgainstBaseIPv6(t *testing.T) { 37 | testContainsAgainstBase(t, 100000, curatedAWSIPv6Gen) 38 | } 39 | 40 | func TestContainingNetworksAgainstBaseIPv6(t *testing.T) { 41 | testContainingNetworksAgainstBase(t, 100000, curatedAWSIPv6Gen) 42 | } 43 | 44 | func TestCoveredNetworksAgainstBaseIPv6(t *testing.T) { 45 | testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV6AWSRangesIPNets)) 46 | } 47 | 48 | func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { 49 | if testing.Short() { 50 | t.Skip("Skipping memory test in `-short` mode") 51 | } 52 | rangers := []Ranger{NewPCTrieRanger()} 53 | baseRanger := newBruteRanger() 54 | for _, ranger := range rangers { 55 | configureRangerWithAWSRanges(t, ranger) 56 | } 57 | configureRangerWithAWSRanges(t, baseRanger) 58 | 59 | for i := 0; i < iterations; i++ { 60 | nn := ipGen() 61 | expected, err := baseRanger.Contains(nn.ToIP()) 62 | assert.NoError(t, err) 63 | for _, ranger := range rangers { 64 | actual, err := ranger.Contains(nn.ToIP()) 65 | assert.NoError(t, err) 66 | assert.Equal(t, expected, actual) 67 | } 68 | } 69 | } 70 | 71 | func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { 72 | if
testing.Short() { 73 | t.Skip("Skipping memory test in `-short` mode") 74 | } 75 | rangers := []Ranger{NewPCTrieRanger()} 76 | baseRanger := newBruteRanger() 77 | for _, ranger := range rangers { 78 | configureRangerWithAWSRanges(t, ranger) 79 | } 80 | configureRangerWithAWSRanges(t, baseRanger) 81 | 82 | for i := 0; i < iterations; i++ { 83 | nn := ipGen() 84 | expected, err := baseRanger.ContainingNetworks(nn.ToIP()) 85 | assert.NoError(t, err) 86 | for _, ranger := range rangers { 87 | actual, err := ranger.ContainingNetworks(nn.ToIP()) 88 | assert.NoError(t, err) 89 | assert.Equal(t, len(expected), len(actual)) 90 | for _, entry := range actual { 91 | assert.Contains(t, expected, entry) 92 | } 93 | } 94 | } 95 | } 96 | 97 | func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkGenerator) { 98 | if testing.Short() { 99 | t.Skip("Skipping memory test in `-short` mode") 100 | } 101 | rangers := []Ranger{NewPCTrieRanger()} 102 | baseRanger := newBruteRanger() 103 | for _, ranger := range rangers { 104 | configureRangerWithAWSRanges(t, ranger) 105 | } 106 | configureRangerWithAWSRanges(t, baseRanger) 107 | 108 | for i := 0; i < iterations; i++ { 109 | network := netGen() 110 | expected, err := baseRanger.CoveredNetworks(network.IPNet) 111 | assert.NoError(t, err) 112 | for _, ranger := range rangers { 113 | actual, err := ranger.CoveredNetworks(network.IPNet) 114 | assert.NoError(t, err) 115 | assert.Equal(t, len(expected), len(actual)) 116 | for _, entry := range actual { 117 | assert.Contains(t, expected, entry) 118 | } 119 | } 120 | } 121 | } 122 | 123 | /* 124 | ****************************************************************** 125 | Benchmarks.
126 | ****************************************************************** 127 | */ 128 | 129 | func BenchmarkPCTrieHitIPv4UsingAWSRanges(b *testing.B) { 130 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger()) 131 | } 132 | func BenchmarkBruteRangerHitIPv4UsingAWSRanges(b *testing.B) { 133 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger()) 134 | } 135 | 136 | func BenchmarkPCTrieHitIPv6UsingAWSRanges(b *testing.B) { 137 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) 138 | } 139 | func BenchmarkBruteRangerHitIPv6UsingAWSRanges(b *testing.B) { 140 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) 141 | } 142 | 143 | func BenchmarkPCTrieMissIPv4UsingAWSRanges(b *testing.B) { 144 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger()) 145 | } 146 | func BenchmarkBruteRangerMissIPv4UsingAWSRanges(b *testing.B) { 147 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger()) 148 | } 149 | 150 | func BenchmarkPCTrieHMissIPv6UsingAWSRanges(b *testing.B) { 151 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger()) 152 | } 153 | func BenchmarkBruteRangerMissIPv6UsingAWSRanges(b *testing.B) { 154 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger()) 155 | } 156 | 157 | func BenchmarkPCTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { 158 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger()) 159 | } 160 | func BenchmarkBruteRangerHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { 161 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger()) 162 | } 163 | 164 | func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { 165 | benchmarkContainingNetworksUsingAWSRanges(b, 
net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) 166 | } 167 | func BenchmarkBruteRangerHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { 168 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) 169 | } 170 | 171 | func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { 172 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger()) 173 | } 174 | func BenchmarkBruteRangerMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { 175 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger()) 176 | } 177 | 178 | func BenchmarkPCTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { 179 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger()) 180 | } 181 | func BenchmarkBruteRangerMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { 182 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger()) 183 | } 184 | 185 | func BenchmarkNewPathprefixTriev4(b *testing.B) { 186 | benchmarkNewPathprefixTrie(b, "192.128.0.0/24") 187 | } 188 | 189 | func BenchmarkNewPathprefixTriev6(b *testing.B) { 190 | benchmarkNewPathprefixTrie(b, "8000::/24") 191 | } 192 | 193 | func benchmarkContainsUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) { 194 | configureRangerWithAWSRanges(tb, ranger) 195 | for n := 0; n < tb.(*testing.B).N; n++ { 196 | ranger.Contains(nn) 197 | } 198 | } 199 | 200 | func benchmarkContainingNetworksUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) { 201 | configureRangerWithAWSRanges(tb, ranger) 202 | for n := 0; n < tb.(*testing.B).N; n++ { 203 | ranger.ContainingNetworks(nn) 204 | } 205 | } 206 | 207 | func benchmarkNewPathprefixTrie(b *testing.B, net1 string) { 208 | _, ipNet1, _ := net.ParseCIDR(net1) 209 | ones, _ := ipNet1.Mask.Size() 210 | 211 | n1 := rnet.NewNetwork(*ipNet1) 212 | uOnes := uint(ones) 213 
| 214 | b.ResetTimer() 215 | for n := 0; n < b.N; n++ { 216 | newPathprefixTrie(n1, uOnes) 217 | } 218 | } 219 | 220 | /* 221 | ****************************************************************** 222 | Helper methods and initialization. 223 | ****************************************************************** 224 | */ 225 | 226 | type ipGenerator func() rnet.NetworkNumber 227 | 228 | func randIPv4Gen() rnet.NetworkNumber { 229 | return rnet.NetworkNumber{rand.Uint32()} 230 | } 231 | func randIPv6Gen() rnet.NetworkNumber { 232 | return rnet.NetworkNumber{rand.Uint32(), rand.Uint32(), rand.Uint32(), rand.Uint32()} 233 | } 234 | func curatedAWSIPv6Gen() rnet.NetworkNumber { 235 | randIdx := rand.Intn(len(ipV6AWSRangesIPNets)) 236 | 237 | // Randomly generate an IP somewhat near the range. 238 | network := ipV6AWSRangesIPNets[randIdx] 239 | nn := rnet.NewNetworkNumber(network.IP) 240 | ones, bits := network.Mask.Size() 241 | zeros := bits - ones 242 | nnPartIdx := zeros / rnet.BitsPerUint32 243 | nn[nnPartIdx] = rand.Uint32() 244 | return nn 245 | } 246 | 247 | type networkGenerator func() rnet.Network 248 | 249 | func randomIPNetGenFactory(pool []*net.IPNet) networkGenerator { 250 | return func() rnet.Network { 251 | return rnet.NewNetwork(*pool[rand.Intn(len(pool))]) 252 | } 253 | } 254 | 255 | type AWSRanges struct { 256 | Prefixes []Prefix `json:"prefixes"` 257 | IPv6Prefixes []IPv6Prefix `json:"ipv6_prefixes"` 258 | } 259 | 260 | type Prefix struct { 261 | IPPrefix string `json:"ip_prefix"` 262 | Region string `json:"region"` 263 | Service string `json:"service"` 264 | } 265 | 266 | type IPv6Prefix struct { 267 | IPPrefix string `json:"ipv6_prefix"` 268 | Region string `json:"region"` 269 | Service string `json:"service"` 270 | } 271 | 272 | var awsRanges *AWSRanges 273 | var ipV4AWSRangesIPNets []*net.IPNet 274 | var ipV6AWSRangesIPNets []*net.IPNet 275 | 276 | func loadAWSRanges() *AWSRanges { 277 | file, err := ioutil.ReadFile("./testdata/aws_ip_ranges.json") 278 
| if err != nil { 279 | panic(err) 280 | } 281 | var ranges AWSRanges 282 | err = json.Unmarshal(file, &ranges) 283 | if err != nil { 284 | panic(err) 285 | } 286 | return &ranges 287 | } 288 | 289 | func configureRangerWithAWSRanges(tb testing.TB, ranger Ranger) { 290 | for _, prefix := range awsRanges.Prefixes { 291 | _, network, err := net.ParseCIDR(prefix.IPPrefix) 292 | assert.NoError(tb, err) 293 | ranger.Insert(NewBasicRangerEntry(*network)) 294 | } 295 | for _, prefix := range awsRanges.IPv6Prefixes { 296 | _, network, err := net.ParseCIDR(prefix.IPPrefix) 297 | assert.NoError(tb, err) 298 | ranger.Insert(NewBasicRangerEntry(*network)) 299 | } 300 | } 301 | 302 | func init() { 303 | awsRanges = loadAWSRanges() 304 | for _, prefix := range awsRanges.IPv6Prefixes { 305 | _, network, _ := net.ParseCIDR(prefix.IPPrefix) 306 | ipV6AWSRangesIPNets = append(ipV6AWSRangesIPNets, network) 307 | } 308 | for _, prefix := range awsRanges.Prefixes { 309 | _, network, _ := net.ParseCIDR(prefix.IPPrefix) 310 | ipV4AWSRangesIPNets = append(ipV4AWSRangesIPNets, network) 311 | } 312 | rand.Seed(time.Now().Unix()) 313 | } 314 | -------------------------------------------------------------------------------- /example/custom-ranger-asn.go: -------------------------------------------------------------------------------- 1 | /* 2 | Example of how to extend github.com/yl2chen/cidranger 3 | 4 | This adds ASN as a string field, along with methods to get the ASN and the CIDR as strings 5 | 6 | Thank you to yl2chen for his assistance and work on this library 7 | */ 8 | package main 9 | 10 | import ( 11 | "fmt" 12 | "net" 13 | "os" 14 | 15 | "github.com/yl2chen/cidranger" 16 | ) 17 | 18 | // custom structure that conforms to RangerEntry interface 19 | type customRangerEntry struct { 20 | ipNet net.IPNet 21 | asn string 22 | } 23 | 24 | // get function for network 25 | func (b *customRangerEntry) Network() net.IPNet { 26 | return b.ipNet 27 | } 28 | 29 | // get function for network 
converted to string 30 | func (b *customRangerEntry) NetworkStr() string { 31 | return b.ipNet.String() 32 | } 33 | 34 | // get function for ASN 35 | func (b *customRangerEntry) Asn() string { 36 | return b.asn 37 | } 38 | 39 | // create customRangerEntry object using net and asn 40 | func newCustomRangerEntry(ipNet net.IPNet, asn string) cidranger.RangerEntry { 41 | return &customRangerEntry{ 42 | ipNet: ipNet, 43 | asn: asn, 44 | } 45 | } 46 | 47 | // entry point 48 | func main() { 49 | 50 | // instantiate NewPCTrieRanger 51 | ranger := cidranger.NewPCTrieRanger() 52 | 53 | // Load sample data using our custom function 54 | _, network, _ := net.ParseCIDR("192.168.1.0/24") 55 | ranger.Insert(newCustomRangerEntry(*network, "0001")) 56 | 57 | _, network, _ = net.ParseCIDR("128.168.1.0/24") 58 | ranger.Insert(newCustomRangerEntry(*network, "0002")) 59 | 60 | // Check if IP is contained within ranger 61 | contains, err := ranger.Contains(net.ParseIP("128.168.1.7")) 62 | if err != nil { 63 | fmt.Println("ranger.Contains()", err.Error()) 64 | os.Exit(1) 65 | } 66 | fmt.Println("Contains:", contains) 67 | 68 | // request networks containing this IP 69 | ip := "192.168.1.42" 70 | entries, err := ranger.ContainingNetworks(net.ParseIP(ip)) 71 | if err != nil { 72 | fmt.Println("ranger.ContainingNetworks()", err.Error()) 73 | os.Exit(1) 74 | } 75 | 76 | fmt.Printf("Entries for %s:\n", ip) 77 | for _, e := range entries { 78 | 79 | // Cast e (cidranger.RangerEntry to struct customRangerEntry 80 | entry, ok := e.(*customRangerEntry) 81 | if !ok { 82 | continue 83 | } 84 | 85 | // Get network (converted to string by function) 86 | n := entry.NetworkStr() 87 | 88 | // Get ASN 89 | a := entry.Asn() 90 | 91 | // Display 92 | fmt.Println("\t", n, a) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yl2chen/cidranger 2 | 3 
| go 1.13 4 | 5 | require ( 6 | github.com/stretchr/testify v1.6.1 7 | gopkg.in/yaml.v2 v2.2.2 // indirect 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 6 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 7 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 8 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= 9 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 13 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 14 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 16 | -------------------------------------------------------------------------------- /net/ip.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package net provides utility functions for working with IPs (net.IP). 
3 | */ 4 | package net 5 | 6 | import ( 7 | "bytes" 8 | "encoding/binary" 9 | "fmt" 10 | "math" 11 | "net" 12 | ) 13 | 14 | // IPVersion is version of IP address. 15 | type IPVersion string 16 | 17 | // Helper constants. 18 | const ( 19 | IPv4Uint32Count = 1 20 | IPv6Uint32Count = 4 21 | 22 | BitsPerUint32 = 32 23 | BytePerUint32 = 4 24 | 25 | IPv4 IPVersion = "IPv4" 26 | IPv6 IPVersion = "IPv6" 27 | ) 28 | 29 | // ErrInvalidBitPosition is returned when bits requested is not valid. 30 | var ErrInvalidBitPosition = fmt.Errorf("bit position not valid") 31 | 32 | // ErrVersionMismatch is returned upon mismatch in network input versions. 33 | var ErrVersionMismatch = fmt.Errorf("Network input version mismatch") 34 | 35 | // ErrNoGreatestCommonBit is an error returned when no greatest common bit 36 | // exists for the cidr ranges. 37 | var ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit") 38 | 39 | // NetworkNumber represents an IP address using uint32 as internal storage. 40 | // IPv4 usings 1 uint32, while IPv6 uses 4 uint32. 41 | type NetworkNumber []uint32 42 | 43 | // NewNetworkNumber returns a equivalent NetworkNumber to given IP address, 44 | // return nil if ip is neither IPv4 nor IPv6. 45 | func NewNetworkNumber(ip net.IP) NetworkNumber { 46 | if ip == nil { 47 | return nil 48 | } 49 | coercedIP := ip.To4() 50 | parts := 1 51 | if coercedIP == nil { 52 | coercedIP = ip.To16() 53 | parts = 4 54 | } 55 | if coercedIP == nil { 56 | return nil 57 | } 58 | nn := make(NetworkNumber, parts) 59 | for i := 0; i < parts; i++ { 60 | idx := i * net.IPv4len 61 | nn[i] = binary.BigEndian.Uint32(coercedIP[idx : idx+net.IPv4len]) 62 | } 63 | return nn 64 | } 65 | 66 | // ToV4 returns ip address if ip is IPv4, returns nil otherwise. 67 | func (n NetworkNumber) ToV4() NetworkNumber { 68 | if len(n) != IPv4Uint32Count { 69 | return nil 70 | } 71 | return n 72 | } 73 | 74 | // ToV6 returns ip address if ip is IPv6, returns nil otherwise. 
75 | func (n NetworkNumber) ToV6() NetworkNumber { 76 | if len(n) != IPv6Uint32Count { 77 | return nil 78 | } 79 | return n 80 | } 81 | 82 | // ToIP returns equivalent net.IP. 83 | func (n NetworkNumber) ToIP() net.IP { 84 | ip := make(net.IP, len(n)*BytePerUint32) 85 | for i := 0; i < len(n); i++ { 86 | idx := i * net.IPv4len 87 | binary.BigEndian.PutUint32(ip[idx:idx+net.IPv4len], n[i]) 88 | } 89 | if len(ip) == net.IPv4len { 90 | ip = net.IPv4(ip[0], ip[1], ip[2], ip[3]) 91 | } 92 | return ip 93 | } 94 | 95 | // Equal is the equality test for 2 network numbers. 96 | func (n NetworkNumber) Equal(n1 NetworkNumber) bool { 97 | if len(n) != len(n1) { 98 | return false 99 | } 100 | if n[0] != n1[0] { 101 | return false 102 | } 103 | if len(n) == IPv6Uint32Count { 104 | return n[1] == n1[1] && n[2] == n1[2] && n[3] == n1[3] 105 | } 106 | return true 107 | } 108 | 109 | // Next returns the next logical network number. 110 | func (n NetworkNumber) Next() NetworkNumber { 111 | newIP := make(NetworkNumber, len(n)) 112 | copy(newIP, n) 113 | for i := len(newIP) - 1; i >= 0; i-- { 114 | newIP[i]++ 115 | if newIP[i] > 0 { 116 | break 117 | } 118 | } 119 | return newIP 120 | } 121 | 122 | // Previous returns the previous logical network number. 123 | func (n NetworkNumber) Previous() NetworkNumber { 124 | newIP := make(NetworkNumber, len(n)) 125 | copy(newIP, n) 126 | for i := len(newIP) - 1; i >= 0; i-- { 127 | newIP[i]-- 128 | if newIP[i] < math.MaxUint32 { 129 | break 130 | } 131 | } 132 | return newIP 133 | } 134 | 135 | // Bit returns uint32 representing the bit value at given position, e.g., 136 | // "128.0.0.0" has bit value of 1 at position 31, and 0 for positions 30 to 0. 137 | func (n NetworkNumber) Bit(position uint) (uint32, error) { 138 | if int(position) > len(n)*BitsPerUint32-1 { 139 | return 0, ErrInvalidBitPosition 140 | } 141 | idx := len(n) - 1 - int(position/BitsPerUint32) 142 | // Mod 31 to get array index. 
143 | rShift := position & (BitsPerUint32 - 1) 144 | return (n[idx] >> rShift) & 1, nil 145 | } 146 | 147 | // LeastCommonBitPosition returns the smallest position of the preceding common 148 | // bits of the 2 network numbers, and returns an error ErrNoGreatestCommonBit 149 | // if the two network number diverges from the first bit. 150 | // e.g., if the network number diverges after the 1st bit, it returns 131 for 151 | // IPv6 and 31 for IPv4 . 152 | func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) { 153 | if len(n) != len(n1) { 154 | return 0, ErrVersionMismatch 155 | } 156 | for i := 0; i < len(n); i++ { 157 | mask := uint32(1) << 31 158 | pos := uint(31) 159 | for ; mask > 0; mask >>= 1 { 160 | if n[i]&mask != n1[i]&mask { 161 | if i == 0 && pos == 31 { 162 | return 0, ErrNoGreatestCommonBit 163 | } 164 | return (pos + 1) + uint(BitsPerUint32)*uint(len(n)-i-1), nil 165 | } 166 | pos-- 167 | } 168 | } 169 | return 0, nil 170 | } 171 | 172 | // Network represents a block of network numbers, also known as CIDR. 173 | type Network struct { 174 | net.IPNet 175 | Number NetworkNumber 176 | Mask NetworkNumberMask 177 | } 178 | 179 | // NewNetwork returns Network built using given net.IPNet. 180 | func NewNetwork(ipNet net.IPNet) Network { 181 | return Network{ 182 | IPNet: ipNet, 183 | Number: NewNetworkNumber(ipNet.IP), 184 | Mask: NetworkNumberMask(NewNetworkNumber(net.IP(ipNet.Mask))), 185 | } 186 | } 187 | 188 | // Masked returns a new network conforming to new mask. 189 | func (n Network) Masked(ones int) Network { 190 | mask := net.CIDRMask(ones, len(n.Number)*BitsPerUint32) 191 | return NewNetwork(net.IPNet{ 192 | IP: n.IP.Mask(mask), 193 | Mask: mask, 194 | }) 195 | } 196 | 197 | // Contains returns true if NetworkNumber is in range of Network, false 198 | // otherwise. 
199 | func (n Network) Contains(nn NetworkNumber) bool { 200 | if len(n.Mask) != len(nn) { 201 | return false 202 | } 203 | if nn[0]&n.Mask[0] != n.Number[0] { 204 | return false 205 | } 206 | if len(nn) == IPv6Uint32Count { 207 | return nn[1]&n.Mask[1] == n.Number[1] && nn[2]&n.Mask[2] == n.Number[2] && nn[3]&n.Mask[3] == n.Number[3] 208 | } 209 | return true 210 | } 211 | 212 | // Contains returns true if Network covers o, false otherwise 213 | func (n Network) Covers(o Network) bool { 214 | if len(n.Number) != len(o.Number) { 215 | return false 216 | } 217 | nMaskSize, _ := n.IPNet.Mask.Size() 218 | oMaskSize, _ := o.IPNet.Mask.Size() 219 | return n.Contains(o.Number) && nMaskSize <= oMaskSize 220 | } 221 | 222 | // LeastCommonBitPosition returns the smallest position of the preceding common 223 | // bits of the 2 networks, and returns an error ErrNoGreatestCommonBit 224 | // if the two network number diverges from the first bit. 225 | func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) { 226 | maskSize, _ := n.IPNet.Mask.Size() 227 | if maskSize1, _ := n1.IPNet.Mask.Size(); maskSize1 < maskSize { 228 | maskSize = maskSize1 229 | } 230 | maskPosition := len(n1.Number)*BitsPerUint32 - maskSize 231 | lcb, err := n.Number.LeastCommonBitPosition(n1.Number) 232 | if err != nil { 233 | return 0, err 234 | } 235 | return uint(math.Max(float64(maskPosition), float64(lcb))), nil 236 | } 237 | 238 | // Equal is the equality test for 2 networks. 239 | func (n Network) Equal(n1 Network) bool { 240 | return bytes.Equal(n.IPNet.IP, n1.IPNet.IP) && bytes.Equal(n.IPNet.Mask, n1.IPNet.Mask) 241 | } 242 | 243 | func (n Network) String() string { 244 | return n.IPNet.String() 245 | } 246 | 247 | // NetworkNumberMask is an IP address. 248 | type NetworkNumberMask NetworkNumber 249 | 250 | // Mask returns a new masked NetworkNumber from given NetworkNumber. 
251 | func (m NetworkNumberMask) Mask(n NetworkNumber) (NetworkNumber, error) { 252 | if len(m) != len(n) { 253 | return nil, ErrVersionMismatch 254 | } 255 | result := make(NetworkNumber, len(m)) 256 | result[0] = m[0] & n[0] 257 | if len(m) == IPv6Uint32Count { 258 | result[1] = m[1] & n[1] 259 | result[2] = m[2] & n[2] 260 | result[3] = m[3] & n[3] 261 | } 262 | return result, nil 263 | } 264 | 265 | // NextIP returns the next sequential ip. 266 | func NextIP(ip net.IP) net.IP { 267 | return NewNetworkNumber(ip).Next().ToIP() 268 | } 269 | 270 | // PreviousIP returns the previous sequential ip. 271 | func PreviousIP(ip net.IP) net.IP { 272 | return NewNetworkNumber(ip).Previous().ToIP() 273 | } 274 | -------------------------------------------------------------------------------- /net/ip_test.go: -------------------------------------------------------------------------------- 1 | package net 2 | 3 | import ( 4 | "math" 5 | "net" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNewNetworkNumber(t *testing.T) { 12 | cases := []struct { 13 | ip net.IP 14 | nn NetworkNumber 15 | name string 16 | }{ 17 | {nil, nil, "nil input"}, 18 | {net.IP([]byte{1, 1, 1, 1, 1}), nil, "bad input"}, 19 | {net.ParseIP("128.0.0.0"), NetworkNumber([]uint32{2147483648}), "IPv4"}, 20 | { 21 | net.ParseIP("2001:0db8::ff00:0042:8329"), 22 | NetworkNumber([]uint32{536939960, 0, 65280, 4358953}), 23 | "IPv6", 24 | }, 25 | } 26 | for _, tc := range cases { 27 | t.Run(tc.name, func(t *testing.T) { 28 | assert.Equal(t, tc.nn, NewNetworkNumber(tc.ip)) 29 | }) 30 | } 31 | } 32 | 33 | func TestNetworkNumberAssertion(t *testing.T) { 34 | cases := []struct { 35 | ip NetworkNumber 36 | to4 NetworkNumber 37 | to6 NetworkNumber 38 | name string 39 | }{ 40 | {NetworkNumber([]uint32{1}), NetworkNumber([]uint32{1}), nil, "is IPv4"}, 41 | {NetworkNumber([]uint32{1, 1, 1, 1}), nil, NetworkNumber([]uint32{1, 1, 1, 1}), "is IPv6"}, 42 | {NetworkNumber([]uint32{1, 1}), nil, 
nil, "is invalid"}, 43 | } 44 | for _, tc := range cases { 45 | t.Run(tc.name, func(t *testing.T) { 46 | assert.Equal(t, tc.to4, tc.ip.ToV4()) 47 | assert.Equal(t, tc.to6, tc.ip.ToV6()) 48 | }) 49 | } 50 | } 51 | 52 | func TestNetworkNumberBit(t *testing.T) { 53 | cases := []struct { 54 | ip NetworkNumber 55 | ones map[uint]bool 56 | name string 57 | }{ 58 | {NewNetworkNumber(net.ParseIP("128.0.0.0")), map[uint]bool{31: true}, "128.0.0.0"}, 59 | {NewNetworkNumber(net.ParseIP("1.1.1.1")), map[uint]bool{0: true, 8: true, 16: true, 24: true}, "1.1.1.1"}, 60 | {NewNetworkNumber(net.ParseIP("8000::")), map[uint]bool{127: true}, "8000::"}, 61 | {NewNetworkNumber(net.ParseIP("8000::8000")), map[uint]bool{127: true, 15: true}, "8000::8000"}, 62 | } 63 | for _, tc := range cases { 64 | t.Run(tc.name, func(t *testing.T) { 65 | for i := uint(0); i < uint(len(tc.ip)*BitsPerUint32); i++ { 66 | bit, err := tc.ip.Bit(i) 67 | assert.NoError(t, err) 68 | if _, isOne := tc.ones[i]; isOne { 69 | assert.Equal(t, uint32(1), bit) 70 | } else { 71 | assert.Equal(t, uint32(0), bit) 72 | } 73 | } 74 | }) 75 | } 76 | } 77 | 78 | func TestNetworkNumberBitError(t *testing.T) { 79 | cases := []struct { 80 | ip NetworkNumber 81 | position uint 82 | err error 83 | name string 84 | }{ 85 | {NewNetworkNumber(net.ParseIP("128.0.0.0")), 0, nil, "IPv4 index in bound"}, 86 | {NewNetworkNumber(net.ParseIP("128.0.0.0")), 31, nil, "IPv4 index in bound"}, 87 | {NewNetworkNumber(net.ParseIP("128.0.0.0")), 32, ErrInvalidBitPosition, "IPv4 index out of bounds"}, 88 | {NewNetworkNumber(net.ParseIP("8000::")), 0, nil, "IPv6 index in bound"}, 89 | {NewNetworkNumber(net.ParseIP("8000::")), 127, nil, "IPv6 index in bound"}, 90 | {NewNetworkNumber(net.ParseIP("8000::")), 128, ErrInvalidBitPosition, "IPv6 index out of bounds"}, 91 | } 92 | for _, tc := range cases { 93 | t.Run(tc.name, func(t *testing.T) { 94 | _, err := tc.ip.Bit(tc.position) 95 | assert.Equal(t, tc.err, err) 96 | }) 97 | } 98 | } 99 | 100 | func 
TestNetworkNumberEqual(t *testing.T) { 101 | cases := []struct { 102 | n1 NetworkNumber 103 | n2 NetworkNumber 104 | equals bool 105 | name string 106 | }{ 107 | {NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, true, "IPv4 equals"}, 108 | {NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - 1}, false, "IPv4 does not equal"}, 109 | {NetworkNumber{1, 1, 1, 1}, NetworkNumber{1, 1, 1, 1}, true, "IPv6 equals"}, 110 | {NetworkNumber{1, 1, 1, 1}, NetworkNumber{1, 1, 1, 2}, false, "IPv6 does not equal"}, 111 | {NetworkNumber{1}, NetworkNumber{1, 2, 3, 4}, false, "Version mismatch"}, 112 | } 113 | for _, tc := range cases { 114 | t.Run(tc.name, func(t *testing.T) { 115 | assert.Equal(t, tc.equals, tc.n1.Equal(tc.n2)) 116 | }) 117 | } 118 | } 119 | 120 | func TestNetworkNumberNext(t *testing.T) { 121 | cases := []struct { 122 | ip string 123 | next string 124 | name string 125 | }{ 126 | {"0.0.0.0", "0.0.0.1", "IPv4 basic"}, 127 | {"0.0.0.255", "0.0.1.0", "IPv4 rollover"}, 128 | {"0.255.255.255", "1.0.0.0", "IPv4 consecutive rollover"}, 129 | {"8000::0", "8000::1", "IPv6 basic"}, 130 | {"0::ffff", "0::1:0", "IPv6 rollover"}, 131 | {"0:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "1::", "IPv6 consecutive rollover"}, 132 | } 133 | 134 | for _, tc := range cases { 135 | t.Run(tc.name, func(t *testing.T) { 136 | ip := NewNetworkNumber(net.ParseIP(tc.ip)) 137 | expected := NewNetworkNumber(net.ParseIP(tc.next)) 138 | assert.Equal(t, expected, ip.Next()) 139 | }) 140 | } 141 | } 142 | 143 | func TestNeworkNumberPrevious(t *testing.T) { 144 | cases := []struct { 145 | ip string 146 | previous string 147 | name string 148 | }{ 149 | {"0.0.0.1", "0.0.0.0", "IPv4 basic"}, 150 | {"0.0.1.0", "0.0.0.255", "IPv4 rollover"}, 151 | {"1.0.0.0", "0.255.255.255", "IPv4 consecutive rollover"}, 152 | {"8000::1", "8000::0", "IPv6 basic"}, 153 | {"0::1:0", "0::ffff", "IPv6 rollover"}, 154 | {"1::0", "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "IPv6 consecutive rollover"}, 155 | } 156 | 
157 | for _, tc := range cases { 158 | t.Run(tc.name, func(t *testing.T) { 159 | ip := NewNetworkNumber(net.ParseIP(tc.ip)) 160 | expected := NewNetworkNumber(net.ParseIP(tc.previous)) 161 | assert.Equal(t, expected, ip.Previous()) 162 | }) 163 | } 164 | } 165 | 166 | func TestLeastCommonBitPositionForNetworks(t *testing.T) { 167 | cases := []struct { 168 | ip1 NetworkNumber 169 | ip2 NetworkNumber 170 | position uint 171 | err error 172 | name string 173 | }{ 174 | { 175 | NetworkNumber([]uint32{2147483648}), 176 | NetworkNumber([]uint32{3221225472, 0, 0, 0}), 177 | 0, ErrVersionMismatch, "Version mismatch", 178 | }, 179 | { 180 | NetworkNumber([]uint32{2147483648}), 181 | NetworkNumber([]uint32{3221225472}), 182 | 31, nil, "IPv4 31st position", 183 | }, 184 | { 185 | NetworkNumber([]uint32{2147483648}), 186 | NetworkNumber([]uint32{2147483648}), 187 | 0, nil, "IPv4 0th position", 188 | }, 189 | { 190 | NetworkNumber([]uint32{2147483648}), 191 | NetworkNumber([]uint32{1}), 192 | 0, ErrNoGreatestCommonBit, "IPv4 diverge at first bit", 193 | }, 194 | { 195 | NetworkNumber([]uint32{2147483648, 0, 0, 0}), 196 | NetworkNumber([]uint32{3221225472, 0, 0, 0}), 197 | 127, nil, "IPv6 127th position", 198 | }, 199 | { 200 | NetworkNumber([]uint32{2147483648, 1, 1, 1}), 201 | NetworkNumber([]uint32{2147483648, 1, 1, 1}), 202 | 0, nil, "IPv6 0th position", 203 | }, 204 | { 205 | NetworkNumber([]uint32{2147483648, 0, 0, 0}), 206 | NetworkNumber([]uint32{0, 0, 0, 1}), 207 | 0, ErrNoGreatestCommonBit, "IPv6 diverge at first bit", 208 | }, 209 | } 210 | for _, tc := range cases { 211 | t.Run(tc.name, func(t *testing.T) { 212 | pos, err := tc.ip1.LeastCommonBitPosition(tc.ip2) 213 | assert.Equal(t, tc.err, err) 214 | assert.Equal(t, tc.position, pos) 215 | }) 216 | } 217 | } 218 | 219 | func TestNewNetwork(t *testing.T) { 220 | _, ipNet, _ := net.ParseCIDR("192.128.0.0/24") 221 | n := NewNetwork(*ipNet) 222 | 223 | assert.Equal(t, *ipNet, n.IPNet) 224 | assert.Equal(t, 
NetworkNumber{3229614080}, n.Number) 225 | assert.Equal(t, NetworkNumberMask{math.MaxUint32 - uint32(math.MaxUint8)}, n.Mask) 226 | } 227 | 228 | func TestNetworkMasked(t *testing.T) { 229 | cases := []struct { 230 | network string 231 | mask int 232 | maskedNetwork string 233 | }{ 234 | {"192.168.0.0/16", 16, "192.168.0.0/16"}, 235 | {"192.168.0.0/16", 14, "192.168.0.0/14"}, 236 | {"192.168.0.0/16", 18, "192.168.0.0/18"}, 237 | {"192.168.0.0/16", 8, "192.0.0.0/8"}, 238 | {"8000::/128", 96, "8000::/96"}, 239 | {"8000::/128", 128, "8000::/128"}, 240 | {"8000::/96", 112, "8000::/112"}, 241 | {"8000:ffff::/96", 16, "8000::/16"}, 242 | } 243 | for _, testcase := range cases { 244 | _, network, _ := net.ParseCIDR(testcase.network) 245 | _, expected, _ := net.ParseCIDR(testcase.maskedNetwork) 246 | n1 := NewNetwork(*network) 247 | e1 := NewNetwork(*expected) 248 | assert.True(t, e1.String() == n1.Masked(testcase.mask).String()) 249 | } 250 | } 251 | 252 | func TestNetworkEqual(t *testing.T) { 253 | cases := []struct { 254 | n1 string 255 | n2 string 256 | equal bool 257 | name string 258 | }{ 259 | {"192.128.0.0/24", "192.128.0.0/24", true, "IPv4 equals"}, 260 | {"192.128.0.0/24", "192.128.0.0/23", false, "IPv4 not equals"}, 261 | {"8000::/24", "8000::/24", true, "IPv6 equals"}, 262 | {"8000::/24", "8000::/23", false, "IPv6 not equals"}, 263 | } 264 | for _, tc := range cases { 265 | t.Run(tc.name, func(t *testing.T) { 266 | _, ipNet1, _ := net.ParseCIDR(tc.n1) 267 | _, ipNet2, _ := net.ParseCIDR(tc.n2) 268 | assert.Equal(t, tc.equal, NewNetwork(*ipNet1).Equal(NewNetwork(*ipNet2))) 269 | }) 270 | } 271 | } 272 | 273 | func TestNetworkContains(t *testing.T) { 274 | cases := []struct { 275 | network string 276 | firstIP string 277 | lastIP string 278 | name string 279 | }{ 280 | {"192.168.0.0/24", "192.168.0.0", "192.168.0.255", "192.168.0.0/24 contains"}, 281 | {"8000::0/120", "8000::0", "8000::ff", "8000::0/120 contains"}, 282 | } 283 | for _, tc := range cases { 284 | 
t.Run(tc.name, func(t *testing.T) { 285 | _, net1, _ := net.ParseCIDR(tc.network) 286 | network := NewNetwork(*net1) 287 | ip := NewNetworkNumber(net.ParseIP(tc.firstIP)) 288 | lastIP := NewNetworkNumber(net.ParseIP(tc.lastIP)) 289 | assert.False(t, network.Contains(ip.Previous())) 290 | assert.False(t, network.Contains(lastIP.Next())) 291 | for ; !ip.Equal(lastIP.Next()); ip = ip.Next() { 292 | assert.True(t, network.Contains(ip)) 293 | } 294 | }) 295 | } 296 | } 297 | 298 | func TestNetworkContainsVersionMismatch(t *testing.T) { 299 | cases := []struct { 300 | network string 301 | ip string 302 | name string 303 | }{ 304 | {"192.168.0.0/24", "8000::0", "IPv6 in IPv4 network"}, 305 | {"8000::0/120", "192.168.0.0", "IPv4 in IPv6 network"}, 306 | } 307 | for _, tc := range cases { 308 | t.Run(tc.name, func(t *testing.T) { 309 | _, net1, _ := net.ParseCIDR(tc.network) 310 | network := NewNetwork(*net1) 311 | assert.False(t, network.Contains(NewNetworkNumber(net.ParseIP(tc.ip)))) 312 | }) 313 | } 314 | } 315 | 316 | func TestNetworkCovers(t *testing.T) { 317 | cases := []struct { 318 | network string 319 | covers string 320 | result bool 321 | name string 322 | }{ 323 | {"10.0.0.0/24", "10.0.0.1/25", true, "contains"}, 324 | {"10.0.0.0/24", "11.0.0.1/25", false, "not contains"}, 325 | {"10.0.0.0/16", "10.0.0.0/15", false, "prefix false"}, 326 | {"10.0.0.0/15", "10.0.0.0/16", true, "prefix true"}, 327 | {"10.0.0.0/15", "10.0.0.0/15", true, "same"}, 328 | {"10::0/15", "10.0.0.0/15", false, "ip version mismatch"}, 329 | {"10::0/15", "10::0/16", true, "ipv6"}, 330 | } 331 | 332 | for _, tc := range cases { 333 | t.Run(tc.name, func(t *testing.T) { 334 | _, n, _ := net.ParseCIDR(tc.network) 335 | network := NewNetwork(*n) 336 | _, n, _ = net.ParseCIDR(tc.covers) 337 | covers := NewNetwork(*n) 338 | assert.Equal(t, tc.result, network.Covers(covers)) 339 | }) 340 | } 341 | } 342 | 343 | func TestNetworkLeastCommonBitPosition(t *testing.T) { 344 | cases := []struct { 345 | 
cidr1 string 346 | cidr2 string 347 | expectedPos uint 348 | expectedErr error 349 | name string 350 | }{ 351 | {"0.0.1.0/24", "0.0.0.0/24", uint(9), nil, "IPv4 diverge before mask pos"}, 352 | {"0.0.0.0/24", "0.0.0.0/24", uint(8), nil, "IPv4 diverge after mask pos"}, 353 | {"0.0.0.128/24", "0.0.0.0/16", uint(16), nil, "IPv4 different mask pos"}, 354 | {"128.0.0.0/24", "0.0.0.0/24", 0, ErrNoGreatestCommonBit, "IPv4 diverge at 1st pos"}, 355 | {"8000::/96", "8000::1:0:0/96", uint(33), nil, "IPv6 diverge before mask pos"}, 356 | {"8000::/96", "8000::8:0/96", uint(32), nil, "IPv6 diverge after mask pos"}, 357 | {"8000::/96", "8000::/95", uint(33), nil, "IPv6 different mask pos"}, 358 | {"ffff::0/24", "0::1/24", 0, ErrNoGreatestCommonBit, "IPv6 diverge at 1st pos"}, 359 | } 360 | for _, c := range cases { 361 | _, cidr1, err := net.ParseCIDR(c.cidr1) 362 | assert.NoError(t, err) 363 | _, cidr2, err := net.ParseCIDR(c.cidr2) 364 | assert.NoError(t, err) 365 | n1 := NewNetwork(*cidr1) 366 | pos, err := n1.LeastCommonBitPosition(NewNetwork(*cidr2)) 367 | if c.expectedErr != nil { 368 | assert.Equal(t, c.expectedErr, err) 369 | } else { 370 | assert.Equal(t, c.expectedPos, pos) 371 | } 372 | } 373 | } 374 | 375 | func TestMask(t *testing.T) { 376 | cases := []struct { 377 | mask NetworkNumberMask 378 | ip NetworkNumber 379 | masked NetworkNumber 380 | err error 381 | name string 382 | }{ 383 | {NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, nil, "nop IPv4 mask"}, 384 | {NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint16 + 1}, NetworkNumber{math.MaxUint16 + 1}, nil, "nop IPv4 mask"}, 385 | {NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - math.MaxUint16}, nil, "IPv4 masked"}, 386 | {NetworkNumberMask{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, nil, "nop IPv6 mask"}, 387 
| {NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, nil, "nop IPv6 mask"}, 388 | {NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, nil, "IPv6 masked"}, 389 | {NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32, 0}, nil, ErrVersionMismatch, "Version mismatch"}, 390 | } 391 | for _, tc := range cases { 392 | t.Run(tc.name, func(t *testing.T) { 393 | masked, err := tc.mask.Mask(tc.ip) 394 | assert.Equal(t, tc.masked, masked) 395 | assert.Equal(t, tc.err, err) 396 | }) 397 | } 398 | } 399 | 400 | /* TestNextIP table-tests NextIP: each case parses ip, increments it, and expects next, including carries across octet (IPv4) and hextet (IPv6) boundaries. */ func TestNextIP(t *testing.T) { 401 | cases := []struct { 402 | ip string 403 | next string 404 | name string 405 | }{ 406 | {"0.0.0.0", "0.0.0.1", "IPv4 basic"}, 407 | {"0.0.0.255", "0.0.1.0", "IPv4 rollover"}, 408 | {"0.255.255.255", "1.0.0.0", "IPv4 consecutive rollover"}, 409 | {"8000::0", "8000::1", "IPv6 basic"}, 410 | {"0::ffff", "0::1:0", "IPv6 rollover"}, 411 | {"0:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "1::", "IPv6 consecutive rollover"}, 412 | } 413 | 414 | for _, tc := range cases { 415 | t.Run(tc.name, func(t *testing.T) { 416 | assert.Equal(t, net.ParseIP(tc.next), NextIP(net.ParseIP(tc.ip))) 417 | }) 418 | } 419 | } 420 | 421 | /* TestPreviousIP table-tests PreviousIP. Note: the field named next holds the expected PREVIOUS address (it is compared against PreviousIP(ip)). */ func TestPreviousIP(t *testing.T) { 422 | cases := []struct { 423 | ip string 424 | next string 425 | name string 426 | }{ 427 | {"0.0.0.1", "0.0.0.0", "IPv4 basic"}, 428 | {"0.0.1.0", "0.0.0.255", "IPv4 rollover"}, 429 | {"1.0.0.0", "0.255.255.255", "IPv4 consecutive rollover"}, 430 | {"8000::1", "8000::0", "IPv6 basic"}, 431 | {"0::1:0", "0::ffff", "IPv6 rollover"}, 432 | {"1::0", "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "IPv6 consecutive rollover"}, 433 | } 434 | 435 | for _, tc := range cases { 436 | t.Run(tc.name, func(t *testing.T) { 437 | assert.Equal(t, net.ParseIP(tc.next), PreviousIP(net.ParseIP(tc.ip))) 438 | })
439 | } 440 | } 441 | 442 | /* 443 | ********************************* 444 | Benchmarking ip manipulations. 445 | ********************************* 446 | */ 447 | func BenchmarkNetworkNumberBitIPv4(b *testing.B) { 448 | benchmarkNetworkNumberBit(b, "52.95.110.1", 6) 449 | } 450 | func BenchmarkNetworkNumberBitIPv6(b *testing.B) { 451 | benchmarkNetworkNumberBit(b, "2600:1ffe:e000::", 44) 452 | } 453 | 454 | func BenchmarkNetworkNumberEqualIPv4(b *testing.B) { 455 | benchmarkNetworkNumberEqual(b, "52.95.110.1", "52.95.110.1") 456 | } 457 | 458 | func BenchmarkNetworkNumberEqualIPv6(b *testing.B) { 459 | benchmarkNetworkNumberEqual(b, "2600:1ffe:e000::", "2600:1ffe:e000::") 460 | } 461 | 462 | func BenchmarkNetworkContainsIPv4(b *testing.B) { 463 | benchmarkNetworkContains(b, "52.95.110.0/24", "52.95.110.1") 464 | } 465 | 466 | func BenchmarkNetworkContainsIPv6(b *testing.B) { 467 | benchmarkNetworkContains(b, "2600:1ffe:e000::/40", "2600:1ffe:f000::") 468 | } 469 | 470 | func BenchmarkNetworkEqualIPv4(b *testing.B) { 471 | benchmarkNetworkEqual(b, "192.128.0.0/24", "192.128.0.0/24") 472 | } 473 | 474 | func BenchmarkNetworkEqualIPv6(b *testing.B) { 475 | benchmarkNetworkEqual(b, "8000::/24", "8000::/24") 476 | } 477 | 478 | /* benchmarkNetworkNumberBit times NetworkNumber.Bit(pos) on the parsed ip; the result is discarded. */ func benchmarkNetworkNumberBit(b *testing.B, ip string, pos uint) { 479 | nn := NewNetworkNumber(net.ParseIP(ip)) 480 | for n := 0; n < b.N; n++ { 481 | nn.Bit(pos) 482 | } 483 | } 484 | 485 | /* benchmarkNetworkNumberEqual times NetworkNumber.Equal between two parsed addresses. */ func benchmarkNetworkNumberEqual(b *testing.B, ip1 string, ip2 string) { 486 | nn1 := NewNetworkNumber(net.ParseIP(ip1)) 487 | nn2 := NewNetworkNumber(net.ParseIP(ip2)) 488 | for n := 0; n < b.N; n++ { 489 | nn1.Equal(nn2) 490 | } 491 | } 492 | 493 | /* benchmarkNetworkContains times Network.Contains of the parsed ip within the parsed cidr; parse errors are ignored. */ func benchmarkNetworkContains(b *testing.B, cidr string, ip string) { 494 | nn := NewNetworkNumber(net.ParseIP(ip)) 495 | _, ipNet, _ := net.ParseCIDR(cidr) 496 | network := NewNetwork(*ipNet) 497 | for n := 0; n < b.N; n++ { 498 | network.Contains(nn) 499 | } 500 | } 501 | 502 | /* benchmarkNetworkEqual times Network.Equal between two parsed CIDRs; parse errors are ignored. */ func benchmarkNetworkEqual(b
*testing.B, net1 string, net2 string) { 503 | _, ipNet1, _ := net.ParseCIDR(net1) 504 | _, ipNet2, _ := net.ParseCIDR(net2) 505 | n1 := NewNetwork(*ipNet1) 506 | n2 := NewNetwork(*ipNet2) 507 | for n := 0; n < b.N; n++ { 508 | n1.Equal(n2) 509 | } 510 | } 511 | -------------------------------------------------------------------------------- /trie.go: -------------------------------------------------------------------------------- 1 | package cidranger 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "strings" 7 | 8 | rnet "github.com/yl2chen/cidranger/net" 9 | ) 10 | 11 | // prefixTrie is a path-compressed (PC) trie implementation of the 12 | // ranger interface inspired by this blog post: 13 | // https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux 14 | // 15 | // CIDR blocks are stored using a prefix tree structure where each node has its 16 | // parent as prefix, and the path from the root node represents current CIDR 17 | // block. 18 | // 19 | // For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are 20 | // 32 bits long and each bit represents a prefix tree starting at that bit. This 21 | // property also guarantees constant lookup time in Big-O notation. 22 | // 23 | // Path compression collapses a chain of nodes that each have only 1 child into a 24 | // single node, decreasing the number of lookups necessary during containment tests. 25 | // 26 | // Level compression dictates the number of direct children of a node by 27 | // allowing it to handle multiple bits in the path. The heuristic (based on 28 | // children population) to decide when the compression and decompression happens 29 | // is outlined in the prior linked blog, and will be experimented with in more 30 | // depth in this project in the future. 31 | // 32 | // Note: Cannot insert both IPv4 and IPv6 network addresses into the same 33 | // prefix trie; use the versionedRanger wrapper instead. 34 | // 35 | // TODO: Implement level-compressed component of the LPC trie.
36 | type prefixTrie struct { 37 | parent *prefixTrie 38 | children []*prefixTrie 39 | 40 | numBitsSkipped uint 41 | numBitsHandled uint 42 | 43 | network rnet.Network 44 | entry RangerEntry 45 | 46 | size int // This is only maintained in the root trie. 47 | } 48 | 49 | // newPrefixTree creates a new prefixTrie. 50 | func newPrefixTree(version rnet.IPVersion) Ranger { 51 | _, rootNet, _ := net.ParseCIDR("0.0.0.0/0") 52 | if version == rnet.IPv6 { 53 | _, rootNet, _ = net.ParseCIDR("0::0/0") 54 | } 55 | return &prefixTrie{ 56 | children: make([]*prefixTrie, 2, 2), 57 | numBitsSkipped: 0, 58 | numBitsHandled: 1, 59 | network: rnet.NewNetwork(*rootNet), 60 | } 61 | } 62 | 63 | func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie { 64 | path := &prefixTrie{ 65 | children: make([]*prefixTrie, 2, 2), 66 | numBitsSkipped: numBitsSkipped, 67 | numBitsHandled: 1, 68 | network: network.Masked(int(numBitsSkipped)), 69 | } 70 | return path 71 | } 72 | 73 | func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { 74 | ones, _ := network.IPNet.Mask.Size() 75 | leaf := newPathprefixTrie(network, uint(ones)) 76 | leaf.entry = entry 77 | return leaf 78 | } 79 | 80 | // Insert inserts a RangerEntry into prefix trie. 81 | func (p *prefixTrie) Insert(entry RangerEntry) error { 82 | network := entry.Network() 83 | sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry) 84 | if sizeIncreased { 85 | p.size++ 86 | } 87 | return err 88 | } 89 | 90 | // Remove removes RangerEntry identified by given network from trie. 91 | func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) { 92 | entry, err := p.remove(rnet.NewNetwork(network)) 93 | if entry != nil { 94 | p.size-- 95 | } 96 | return entry, err 97 | } 98 | 99 | // Contains returns boolean indicating whether given ip is contained in any 100 | // of the inserted networks. 
101 | func (p *prefixTrie) Contains(ip net.IP) (bool, error) { 102 | nn := rnet.NewNetworkNumber(ip) 103 | if nn == nil { 104 | return false, ErrInvalidNetworkNumberInput 105 | } 106 | return p.contains(nn) 107 | } 108 | 109 | // ContainingNetworks returns the list of RangerEntry(s) the given ip is 110 | // contained in in ascending prefix order. 111 | func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { 112 | nn := rnet.NewNetworkNumber(ip) 113 | if nn == nil { 114 | return nil, ErrInvalidNetworkNumberInput 115 | } 116 | return p.containingNetworks(nn) 117 | } 118 | 119 | // CoveredNetworks returns the list of RangerEntry(s) the given ipnet 120 | // covers. That is, the networks that are completely subsumed by the 121 | // specified network. 122 | func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { 123 | net := rnet.NewNetwork(network) 124 | return p.coveredNetworks(net) 125 | } 126 | 127 | // Len returns number of networks in ranger. 128 | func (p *prefixTrie) Len() int { 129 | return p.size 130 | } 131 | 132 | // String returns string representation of trie, mainly for visualization and 133 | // debugging. 
134 | func (p *prefixTrie) String() string { 135 | children := []string{} 136 | padding := strings.Repeat("| ", p.level()+1) 137 | for bits, child := range p.children { 138 | if child == nil { 139 | continue 140 | } 141 | childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String()) 142 | children = append(children, childStr) 143 | } 144 | return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network, 145 | p.targetBitPosition(), p.hasEntry(), strings.Join(children, "")) 146 | } 147 | 148 | func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) { 149 | if !p.network.Contains(number) { 150 | return false, nil 151 | } 152 | if p.hasEntry() { 153 | return true, nil 154 | } 155 | if p.targetBitPosition() < 0 { 156 | return false, nil 157 | } 158 | bit, err := p.targetBitFromIP(number) 159 | if err != nil { 160 | return false, err 161 | } 162 | child := p.children[bit] 163 | if child != nil { 164 | return child.contains(number) 165 | } 166 | return false, nil 167 | } 168 | 169 | func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) { 170 | results := []RangerEntry{} 171 | if !p.network.Contains(number) { 172 | return results, nil 173 | } 174 | if p.hasEntry() { 175 | results = []RangerEntry{p.entry} 176 | } 177 | if p.targetBitPosition() < 0 { 178 | return results, nil 179 | } 180 | bit, err := p.targetBitFromIP(number) 181 | if err != nil { 182 | return nil, err 183 | } 184 | child := p.children[bit] 185 | if child != nil { 186 | ranges, err := child.containingNetworks(number) 187 | if err != nil { 188 | return nil, err 189 | } 190 | if len(ranges) > 0 { 191 | if len(results) > 0 { 192 | results = append(results, ranges...) 
193 | } else { 194 | results = ranges 195 | } 196 | } 197 | } 198 | return results, nil 199 | } 200 | 201 | func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) { 202 | var results []RangerEntry 203 | if network.Covers(p.network) { 204 | for entry := range p.walkDepth() { 205 | results = append(results, entry) 206 | } 207 | } else if p.targetBitPosition() >= 0 { 208 | bit, err := p.targetBitFromIP(network.Number) 209 | if err != nil { 210 | return results, err 211 | } 212 | child := p.children[bit] 213 | if child != nil { 214 | return child.coveredNetworks(network) 215 | } 216 | } 217 | return results, nil 218 | } 219 | 220 | func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) { 221 | if p.network.Equal(network) { 222 | sizeIncreased := p.entry == nil 223 | p.entry = entry 224 | return sizeIncreased, nil 225 | } 226 | 227 | bit, err := p.targetBitFromIP(network.Number) 228 | if err != nil { 229 | return false, err 230 | } 231 | existingChild := p.children[bit] 232 | 233 | // No existing child, insert new leaf trie. 234 | if existingChild == nil { 235 | p.appendTrie(bit, newEntryTrie(network, entry)) 236 | return true, nil 237 | } 238 | 239 | // Check whether it is necessary to insert additional path prefix between current trie and existing child, 240 | // in the case that inserted network diverges on its path to existing child. 
241 | lcb, err := network.LeastCommonBitPosition(existingChild.network) 242 | divergingBitPos := int(lcb) - 1 243 | if divergingBitPos > existingChild.targetBitPosition() { 244 | pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb) 245 | err := p.insertPrefix(bit, pathPrefix, existingChild) 246 | if err != nil { 247 | return false, err 248 | } 249 | // Update new child 250 | existingChild = pathPrefix 251 | } 252 | return existingChild.insert(network, entry) 253 | } 254 | 255 | func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) { 256 | p.children[bit] = prefix 257 | prefix.parent = p 258 | } 259 | 260 | func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error { 261 | // Set parent/child relationship between current trie and inserted pathPrefix 262 | p.children[bit] = pathPrefix 263 | pathPrefix.parent = p 264 | 265 | // Set parent/child relationship between inserted pathPrefix and original child 266 | pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number) 267 | if err != nil { 268 | return err 269 | } 270 | pathPrefix.children[pathPrefixBit] = child 271 | child.parent = pathPrefix 272 | return nil 273 | } 274 | 275 | func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) { 276 | if p.hasEntry() && p.network.Equal(network) { 277 | entry := p.entry 278 | p.entry = nil 279 | 280 | err := p.compressPathIfPossible() 281 | if err != nil { 282 | return nil, err 283 | } 284 | return entry, nil 285 | } 286 | if p.targetBitPosition() < 0 { 287 | return nil, nil 288 | } 289 | bit, err := p.targetBitFromIP(network.Number) 290 | if err != nil { 291 | return nil, err 292 | } 293 | child := p.children[bit] 294 | if child != nil { 295 | return child.remove(network) 296 | } 297 | return nil, nil 298 | } 299 | 300 | func (p *prefixTrie) qualifiesForPathCompression() bool { 301 | // Current prefix trie can be path compressed if it meets all following. 302 | // 1. records no CIDR entry 303 | // 2. 
has single or no child 304 | // 3. is not root trie 305 | return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil 306 | } 307 | 308 | func (p *prefixTrie) compressPathIfPossible() error { 309 | if !p.qualifiesForPathCompression() { 310 | // Does not qualify to be compressed 311 | return nil 312 | } 313 | 314 | // Find lone child. 315 | var loneChild *prefixTrie 316 | for _, child := range p.children { 317 | if child != nil { 318 | loneChild = child 319 | break 320 | } 321 | } 322 | 323 | // Find root of currnt single child lineage. 324 | parent := p.parent 325 | for ; parent.qualifiesForPathCompression(); parent = parent.parent { 326 | } 327 | parentBit, err := parent.targetBitFromIP(p.network.Number) 328 | if err != nil { 329 | return err 330 | } 331 | parent.children[parentBit] = loneChild 332 | 333 | // Attempts to furthur apply path compression at current lineage parent, in case current lineage 334 | // compressed into parent. 335 | return parent.compressPathIfPossible() 336 | } 337 | 338 | func (p *prefixTrie) childrenCount() int { 339 | count := 0 340 | for _, child := range p.children { 341 | if child != nil { 342 | count++ 343 | } 344 | } 345 | return count 346 | } 347 | 348 | func (p *prefixTrie) totalNumberOfBits() uint { 349 | return rnet.BitsPerUint32 * uint(len(p.network.Number)) 350 | } 351 | 352 | func (p *prefixTrie) targetBitPosition() int { 353 | return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1 354 | } 355 | 356 | func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) { 357 | // This is a safe uint boxing of int since we should never attempt to get 358 | // target bit at a negative position. 
359 | return n.Bit(uint(p.targetBitPosition())) 360 | } 361 | 362 | func (p *prefixTrie) hasEntry() bool { 363 | return p.entry != nil 364 | } 365 | 366 | func (p *prefixTrie) level() int { 367 | if p.parent == nil { 368 | return 0 369 | } 370 | return p.parent.level() + 1 371 | } 372 | 373 | // walkDepth walks the trie in depth order, for unit testing. 374 | func (p *prefixTrie) walkDepth() <-chan RangerEntry { 375 | entries := make(chan RangerEntry) 376 | go func() { 377 | if p.hasEntry() { 378 | entries <- p.entry 379 | } 380 | childEntriesList := []<-chan RangerEntry{} 381 | for _, trie := range p.children { 382 | if trie == nil { 383 | continue 384 | } 385 | childEntriesList = append(childEntriesList, trie.walkDepth()) 386 | } 387 | for _, childEntries := range childEntriesList { 388 | for entry := range childEntries { 389 | entries <- entry 390 | } 391 | } 392 | close(entries) 393 | }() 394 | return entries 395 | } 396 | -------------------------------------------------------------------------------- /trie_test.go: -------------------------------------------------------------------------------- 1 | package cidranger 2 | 3 | import ( 4 | "encoding/binary" 5 | "math/rand" 6 | "net" 7 | "runtime" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | rnet "github.com/yl2chen/cidranger/net" 13 | ) 14 | 15 | func getAllByVersion(version rnet.IPVersion) *net.IPNet { 16 | if version == rnet.IPv6 { 17 | return AllIPv6 18 | } 19 | return AllIPv4 20 | } 21 | 22 | func TestPrefixTrieInsert(t *testing.T) { 23 | cases := []struct { 24 | version rnet.IPVersion 25 | inserts []string 26 | expectedNetworksInDepthOrder []string 27 | name string 28 | }{ 29 | {rnet.IPv4, []string{"192.168.0.1/24"}, []string{"192.168.0.1/24"}, "basic insert"}, 30 | { 31 | rnet.IPv4, 32 | []string{"1.2.3.4/32", "1.2.3.5/32"}, 33 | []string{"1.2.3.4/32", "1.2.3.5/32"}, 34 | "single ip IPv4 network insert", 35 | }, 36 | { 37 | rnet.IPv6, 38 | []string{"0::1/128", "0::2/128"}, 
39 | []string{"0::1/128", "0::2/128"}, 40 | "single ip IPv6 network insert", 41 | }, 42 | { 43 | rnet.IPv4, 44 | []string{"192.168.0.1/16", "192.168.0.1/24"}, 45 | []string{"192.168.0.1/16", "192.168.0.1/24"}, 46 | "in order insert", 47 | }, 48 | { 49 | rnet.IPv4, 50 | []string{"192.168.0.1/32", "192.168.0.1/32"}, 51 | []string{"192.168.0.1/32"}, 52 | "duplicate network insert", 53 | }, 54 | { 55 | rnet.IPv4, 56 | []string{"192.168.0.1/24", "192.168.0.1/16"}, 57 | []string{"192.168.0.1/16", "192.168.0.1/24"}, 58 | "reverse insert", 59 | }, 60 | { 61 | rnet.IPv4, 62 | []string{"192.168.0.1/24", "192.168.1.1/24"}, 63 | []string{"192.168.0.1/24", "192.168.1.1/24"}, 64 | "branch insert", 65 | }, 66 | { 67 | rnet.IPv4, 68 | []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"}, 69 | []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"}, 70 | "branch inserts", 71 | }, 72 | } 73 | for _, tc := range cases { 74 | t.Run(tc.name, func(t *testing.T) { 75 | trie := newPrefixTree(tc.version).(*prefixTrie) 76 | for _, insert := range tc.inserts { 77 | _, network, _ := net.ParseCIDR(insert) 78 | err := trie.Insert(NewBasicRangerEntry(*network)) 79 | assert.NoError(t, err) 80 | } 81 | 82 | assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match") 83 | 84 | allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version)) 85 | assert.Nil(t, err) 86 | assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match") 87 | 88 | walk := trie.walkDepth() 89 | for _, network := range tc.expectedNetworksInDepthOrder { 90 | _, ipnet, _ := net.ParseCIDR(network) 91 | expected := NewBasicRangerEntry(*ipnet) 92 | actual := <-walk 93 | assert.Equal(t, expected, actual) 94 | } 95 | 96 | // Ensure no unexpected elements in trie. 
97 | for network := range walk { 98 | assert.Nil(t, network) 99 | } 100 | }) 101 | } 102 | } 103 | 104 | func TestPrefixTrieString(t *testing.T) { 105 | inserts := []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"} 106 | trie := newPrefixTree(rnet.IPv4).(*prefixTrie) 107 | for _, insert := range inserts { 108 | _, network, _ := net.ParseCIDR(insert) 109 | trie.Insert(NewBasicRangerEntry(*network)) 110 | } 111 | expected := `0.0.0.0/0 (target_pos:31:has_entry:false) 112 | | 1--> 192.168.0.0/23 (target_pos:8:has_entry:false) 113 | | | 0--> 192.168.0.0/24 (target_pos:7:has_entry:true) 114 | | | 1--> 192.168.1.0/24 (target_pos:7:has_entry:true) 115 | | | | 0--> 192.168.1.0/30 (target_pos:1:has_entry:true)` 116 | assert.Equal(t, expected, trie.String()) 117 | } 118 | 119 | func TestPrefixTrieRemove(t *testing.T) { 120 | cases := []struct { 121 | version rnet.IPVersion 122 | inserts []string 123 | removes []string 124 | expectedRemoves []string 125 | expectedNetworksInDepthOrder []string 126 | expectedTrieString string 127 | name string 128 | }{ 129 | { 130 | rnet.IPv4, 131 | []string{"192.168.0.1/24"}, 132 | []string{"192.168.0.1/24"}, 133 | []string{"192.168.0.1/24"}, 134 | []string{}, 135 | "0.0.0.0/0 (target_pos:31:has_entry:false)", 136 | "basic remove", 137 | }, 138 | { 139 | rnet.IPv4, 140 | []string{"192.168.0.1/32"}, 141 | []string{"192.168.0.1/24"}, 142 | []string{""}, 143 | []string{"192.168.0.1/32"}, 144 | `0.0.0.0/0 (target_pos:31:has_entry:false) 145 | | 1--> 192.168.0.1/32 (target_pos:-1:has_entry:true)`, 146 | "remove from ranger that contains a single ip block", 147 | }, 148 | { 149 | rnet.IPv4, 150 | []string{"1.2.3.4/32", "1.2.3.5/32"}, 151 | []string{"1.2.3.5/32"}, 152 | []string{"1.2.3.5/32"}, 153 | []string{"1.2.3.4/32"}, 154 | `0.0.0.0/0 (target_pos:31:has_entry:false) 155 | | 0--> 1.2.3.4/32 (target_pos:-1:has_entry:true)`, 156 | "single ip IPv4 network remove", 157 | }, 158 | { 159 | rnet.IPv4, 160 | []string{"0::1/128", "0::2/128"}, 
161 | []string{"0::2/128"}, 162 | []string{"0::2/128"}, 163 | []string{"0::1/128"}, 164 | `0.0.0.0/0 (target_pos:31:has_entry:false) 165 | | 0--> ::1/128 (target_pos:-1:has_entry:true)`, 166 | "single ip IPv6 network remove", 167 | }, 168 | { 169 | rnet.IPv4, 170 | []string{"192.168.0.1/24", "192.168.0.1/25", "192.168.0.1/26"}, 171 | []string{"192.168.0.1/25"}, 172 | []string{"192.168.0.1/25"}, 173 | []string{"192.168.0.1/24", "192.168.0.1/26"}, 174 | `0.0.0.0/0 (target_pos:31:has_entry:false) 175 | | 1--> 192.168.0.0/24 (target_pos:7:has_entry:true) 176 | | | 0--> 192.168.0.0/26 (target_pos:5:has_entry:true)`, 177 | "remove path prefix", 178 | }, 179 | { 180 | rnet.IPv4, 181 | []string{"192.168.0.1/24", "192.168.0.1/25", "192.168.0.64/26", "192.168.0.1/26"}, 182 | []string{"192.168.0.1/25"}, 183 | []string{"192.168.0.1/25"}, 184 | []string{"192.168.0.1/24", "192.168.0.1/26", "192.168.0.64/26"}, 185 | `0.0.0.0/0 (target_pos:31:has_entry:false) 186 | | 1--> 192.168.0.0/24 (target_pos:7:has_entry:true) 187 | | | 0--> 192.168.0.0/25 (target_pos:6:has_entry:false) 188 | | | | 0--> 192.168.0.0/26 (target_pos:5:has_entry:true) 189 | | | | 1--> 192.168.0.64/26 (target_pos:5:has_entry:true)`, 190 | "remove path prefix with more than 1 children", 191 | }, 192 | { 193 | rnet.IPv4, 194 | []string{"192.168.0.1/24", "192.168.0.1/25"}, 195 | []string{"192.168.0.1/26"}, 196 | []string{""}, 197 | []string{"192.168.0.1/24", "192.168.0.1/25"}, 198 | `0.0.0.0/0 (target_pos:31:has_entry:false) 199 | | 1--> 192.168.0.0/24 (target_pos:7:has_entry:true) 200 | | | 0--> 192.168.0.0/25 (target_pos:6:has_entry:true)`, 201 | "remove non existent", 202 | }, 203 | } 204 | 205 | for _, tc := range cases { 206 | t.Run(tc.name, func(t *testing.T) { 207 | trie := newPrefixTree(tc.version).(*prefixTrie) 208 | for _, insert := range tc.inserts { 209 | _, network, _ := net.ParseCIDR(insert) 210 | err := trie.Insert(NewBasicRangerEntry(*network)) 211 | assert.NoError(t, err) 212 | } 213 | for i, remove 
:= range tc.removes { 214 | _, network, _ := net.ParseCIDR(remove) 215 | removed, err := trie.Remove(*network) 216 | assert.NoError(t, err) 217 | if str := tc.expectedRemoves[i]; str != "" { 218 | _, ipnet, _ := net.ParseCIDR(str) 219 | expected := NewBasicRangerEntry(*ipnet) 220 | assert.Equal(t, expected, removed) 221 | } else { 222 | assert.Nil(t, removed) 223 | } 224 | } 225 | 226 | assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval") 227 | 228 | allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version)) 229 | assert.Nil(t, err) 230 | assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match") 231 | 232 | walk := trie.walkDepth() 233 | for _, network := range tc.expectedNetworksInDepthOrder { 234 | _, ipnet, _ := net.ParseCIDR(network) 235 | expected := NewBasicRangerEntry(*ipnet) 236 | actual := <-walk 237 | assert.Equal(t, expected, actual) 238 | } 239 | 240 | // Ensure no unexpected elements in trie. 241 | for network := range walk { 242 | assert.Nil(t, network) 243 | } 244 | 245 | assert.Equal(t, tc.expectedTrieString, trie.String()) 246 | }) 247 | } 248 | } 249 | 250 | func TestToReplicateIssue(t *testing.T) { 251 | cases := []struct { 252 | version rnet.IPVersion 253 | inserts []string 254 | ip net.IP 255 | networks []string 256 | name string 257 | }{ 258 | { 259 | rnet.IPv4, 260 | []string{"192.168.0.1/32"}, 261 | net.ParseIP("192.168.0.1"), 262 | []string{"192.168.0.1/32"}, 263 | "basic containing network for /32 mask", 264 | }, 265 | { 266 | rnet.IPv6, 267 | []string{"a::1/128"}, 268 | net.ParseIP("a::1"), 269 | []string{"a::1/128"}, 270 | "basic containing network for /128 mask", 271 | }, 272 | } 273 | for _, tc := range cases { 274 | t.Run(tc.name, func(t *testing.T) { 275 | trie := newPrefixTree(tc.version) 276 | for _, insert := range tc.inserts { 277 | _, network, _ := net.ParseCIDR(insert) 278 | err := trie.Insert(NewBasicRangerEntry(*network)) 279 | assert.NoError(t, 
err) 280 | } 281 | expectedEntries := []RangerEntry{} 282 | for _, network := range tc.networks { 283 | _, net, _ := net.ParseCIDR(network) 284 | expectedEntries = append(expectedEntries, NewBasicRangerEntry(*net)) 285 | } 286 | contains, err := trie.Contains(tc.ip) 287 | assert.NoError(t, err) 288 | assert.True(t, contains) 289 | networks, err := trie.ContainingNetworks(tc.ip) 290 | assert.NoError(t, err) 291 | assert.Equal(t, expectedEntries, networks) 292 | }) 293 | } 294 | } 295 | 296 | type expectedIPRange struct { 297 | start net.IP 298 | end net.IP 299 | } 300 | 301 | func TestPrefixTrieContains(t *testing.T) { 302 | cases := []struct { 303 | version rnet.IPVersion 304 | inserts []string 305 | expectedIPs []expectedIPRange 306 | name string 307 | }{ 308 | { 309 | rnet.IPv4, 310 | []string{"192.168.0.0/24"}, 311 | []expectedIPRange{ 312 | {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.1.0")}, 313 | }, 314 | "basic contains", 315 | }, 316 | { 317 | rnet.IPv4, 318 | []string{"192.168.0.0/24", "128.168.0.0/24"}, 319 | []expectedIPRange{ 320 | {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.1.0")}, 321 | {net.ParseIP("128.168.0.0"), net.ParseIP("128.168.1.0")}, 322 | }, 323 | "multiple ranges contains", 324 | }, 325 | } 326 | 327 | for _, tc := range cases { 328 | t.Run(tc.name, func(t *testing.T) { 329 | trie := newPrefixTree(tc.version) 330 | for _, insert := range tc.inserts { 331 | _, network, _ := net.ParseCIDR(insert) 332 | err := trie.Insert(NewBasicRangerEntry(*network)) 333 | assert.NoError(t, err) 334 | } 335 | for _, expectedIPRange := range tc.expectedIPs { 336 | var contains bool 337 | var err error 338 | start := expectedIPRange.start 339 | for ; !expectedIPRange.end.Equal(start); start = rnet.NextIP(start) { 340 | contains, err = trie.Contains(start) 341 | assert.NoError(t, err) 342 | assert.True(t, contains) 343 | } 344 | 345 | // Check out of bounds ips on both ends 346 | contains, err = trie.Contains(rnet.PreviousIP(expectedIPRange.start)) 
347 | assert.NoError(t, err) 348 | assert.False(t, contains) 349 | contains, err = trie.Contains(rnet.NextIP(expectedIPRange.end)) 350 | assert.NoError(t, err) 351 | assert.False(t, contains) 352 | } 353 | }) 354 | } 355 | } 356 | 357 | func TestPrefixTrieContainingNetworks(t *testing.T) { 358 | cases := []struct { 359 | version rnet.IPVersion 360 | inserts []string 361 | ip net.IP 362 | networks []string 363 | name string 364 | }{ 365 | { 366 | rnet.IPv4, 367 | []string{"192.168.0.0/24"}, 368 | net.ParseIP("192.168.0.1"), 369 | []string{"192.168.0.0/24"}, 370 | "basic containing networks", 371 | }, 372 | { 373 | rnet.IPv4, 374 | []string{"192.168.0.0/24", "192.168.0.0/25"}, 375 | net.ParseIP("192.168.0.1"), 376 | []string{"192.168.0.0/24", "192.168.0.0/25"}, 377 | "inclusive networks", 378 | }, 379 | } 380 | for _, tc := range cases { 381 | t.Run(tc.name, func(t *testing.T) { 382 | trie := newPrefixTree(tc.version) 383 | for _, insert := range tc.inserts { 384 | _, network, _ := net.ParseCIDR(insert) 385 | err := trie.Insert(NewBasicRangerEntry(*network)) 386 | assert.NoError(t, err) 387 | } 388 | expectedEntries := []RangerEntry{} 389 | for _, network := range tc.networks { 390 | _, net, _ := net.ParseCIDR(network) 391 | expectedEntries = append(expectedEntries, NewBasicRangerEntry(*net)) 392 | } 393 | networks, err := trie.ContainingNetworks(tc.ip) 394 | assert.NoError(t, err) 395 | assert.Equal(t, expectedEntries, networks) 396 | }) 397 | } 398 | } 399 | 400 | type coveredNetworkTest struct { 401 | version rnet.IPVersion 402 | inserts []string 403 | search string 404 | networks []string 405 | name string 406 | } 407 | 408 | var coveredNetworkTests = []coveredNetworkTest{ 409 | { 410 | rnet.IPv4, 411 | []string{"192.168.0.0/24"}, 412 | "192.168.0.0/16", 413 | []string{"192.168.0.0/24"}, 414 | "basic covered networks", 415 | }, 416 | { 417 | rnet.IPv4, 418 | []string{"192.168.0.0/24"}, 419 | "10.1.0.0/16", 420 | nil, 421 | "nothing", 422 | }, 423 | { 424 | 
rnet.IPv4, 425 | []string{"192.168.0.0/24", "192.168.0.0/25"}, 426 | "192.168.0.0/16", 427 | []string{"192.168.0.0/24", "192.168.0.0/25"}, 428 | "multiple networks", 429 | }, 430 | { 431 | rnet.IPv4, 432 | []string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"}, 433 | "192.168.0.0/16", 434 | []string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"}, 435 | "multiple networks 2", 436 | }, 437 | { 438 | rnet.IPv4, 439 | []string{"192.168.1.1/32"}, 440 | "192.168.0.0/16", 441 | []string{"192.168.1.1/32"}, 442 | "leaf", 443 | }, 444 | { 445 | rnet.IPv4, 446 | []string{"0.0.0.0/0", "192.168.1.1/32"}, 447 | "192.168.0.0/16", 448 | []string{"192.168.1.1/32"}, 449 | "leaf with root", 450 | }, 451 | { 452 | rnet.IPv4, 453 | []string{ 454 | "0.0.0.0/0", "192.168.0.0/24", "192.168.1.1/32", 455 | "10.1.0.0/16", "10.1.1.0/24", 456 | }, 457 | "192.168.0.0/16", 458 | []string{"192.168.0.0/24", "192.168.1.1/32"}, 459 | "path not taken", 460 | }, 461 | { 462 | rnet.IPv4, 463 | []string{ 464 | "192.168.0.0/15", 465 | }, 466 | "192.168.0.0/16", 467 | nil, 468 | "only masks different", 469 | }, 470 | } 471 | 472 | func TestPrefixTrieCoveredNetworks(t *testing.T) { 473 | for _, tc := range coveredNetworkTests { 474 | t.Run(tc.name, func(t *testing.T) { 475 | trie := newPrefixTree(tc.version) 476 | for _, insert := range tc.inserts { 477 | _, network, _ := net.ParseCIDR(insert) 478 | err := trie.Insert(NewBasicRangerEntry(*network)) 479 | assert.NoError(t, err) 480 | } 481 | var expectedEntries []RangerEntry 482 | for _, network := range tc.networks { 483 | _, net, _ := net.ParseCIDR(network) 484 | expectedEntries = append(expectedEntries, 485 | NewBasicRangerEntry(*net)) 486 | } 487 | _, snet, _ := net.ParseCIDR(tc.search) 488 | networks, err := trie.CoveredNetworks(*snet) 489 | assert.NoError(t, err) 490 | assert.Equal(t, expectedEntries, networks) 491 | }) 492 | } 493 | } 494 | 495 | func TestTrieMemUsage(t *testing.T) { 496 | if testing.Short() { 497 | t.Skip("Skipping 
memory test in `-short` mode") 498 | } 499 | numIPs := 100000 500 | runs := 10 501 | 502 | // Avg heap allocation over all runs should not be more than the heap allocation of first run multiplied 503 | // by threshold, picking 1% as sane number for detecting memory leak. 504 | thresh := 1.01 505 | 506 | trie := newPrefixTree(rnet.IPv4) 507 | 508 | var baseLineHeap, totalHeapAllocOverRuns uint64 509 | for i := 0; i < runs; i++ { 510 | t.Logf("Executing Run %d of %d", i+1, runs) 511 | 512 | // Insert networks. 513 | for n := 0; n < numIPs; n++ { 514 | trie.Insert(NewBasicRangerEntry(GenLeafIPNet(GenIPV4()))) 515 | } 516 | t.Logf("Inserted All (%d networks)", trie.Len()) 517 | assert.Less(t, 0, trie.Len(), "Len should > 0") 518 | assert.LessOrEqualf(t, trie.Len(), numIPs, "Len should <= %d", numIPs) 519 | 520 | allNetworks, err := trie.CoveredNetworks(*getAllByVersion(rnet.IPv4)) 521 | assert.Nil(t, err) 522 | assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match") 523 | 524 | // Remove networks. 525 | _, all, _ := net.ParseCIDR("0.0.0.0/0") 526 | ll, _ := trie.CoveredNetworks(*all) 527 | for i := 0; i < len(ll); i++ { 528 | trie.Remove(ll[i].Network()) 529 | } 530 | t.Logf("Removed All (%d networks)", len(ll)) 531 | assert.Equal(t, 0, trie.Len(), "Len after removal should == 0") 532 | 533 | // Perform GC 534 | runtime.GC() 535 | 536 | // Get HeapAlloc stats. 537 | heapAlloc := GetHeapAllocation() 538 | totalHeapAllocOverRuns += heapAlloc 539 | if i == 0 { 540 | baseLineHeap = heapAlloc 541 | } 542 | } 543 | 544 | // Assert that heap allocation from first loop is within set threshold of avg over all runs. 
545 | assert.Less(t, uint64(0), baseLineHeap) 546 | assert.LessOrEqual(t, float64(baseLineHeap), float64(totalHeapAllocOverRuns/uint64(runs))*thresh) 547 | } 548 | 549 | func GenLeafIPNet(ip net.IP) net.IPNet { 550 | return net.IPNet{ 551 | IP: ip, 552 | Mask: net.CIDRMask(32, 32), 553 | } 554 | } 555 | 556 | // GenIPV4 generates an IPV4 address 557 | func GenIPV4() net.IP { 558 | rand.Seed(time.Now().UnixNano()) 559 | nn := rand.Uint32() 560 | if nn < 4294967295 { 561 | nn++ 562 | } 563 | ip := make(net.IP, 4) 564 | binary.BigEndian.PutUint32(ip, uint32(nn)) 565 | return ip 566 | } 567 | 568 | func GetHeapAllocation() uint64 { 569 | var m runtime.MemStats 570 | runtime.ReadMemStats(&m) 571 | return m.HeapAlloc 572 | } 573 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package cidranger 2 | 3 | import ( 4 | "net" 5 | 6 | rnet "github.com/yl2chen/cidranger/net" 7 | ) 8 | 9 | type rangerFactory func(rnet.IPVersion) Ranger 10 | 11 | type versionedRanger struct { 12 | ipV4Ranger Ranger 13 | ipV6Ranger Ranger 14 | } 15 | 16 | func newVersionedRanger(factory rangerFactory) Ranger { 17 | return &versionedRanger{ 18 | ipV4Ranger: factory(rnet.IPv4), 19 | ipV6Ranger: factory(rnet.IPv6), 20 | } 21 | } 22 | 23 | func (v *versionedRanger) Insert(entry RangerEntry) error { 24 | network := entry.Network() 25 | ranger, err := v.getRangerForIP(network.IP) 26 | if err != nil { 27 | return err 28 | } 29 | return ranger.Insert(entry) 30 | } 31 | 32 | func (v *versionedRanger) Remove(network net.IPNet) (RangerEntry, error) { 33 | ranger, err := v.getRangerForIP(network.IP) 34 | if err != nil { 35 | return nil, err 36 | } 37 | return ranger.Remove(network) 38 | } 39 | 40 | func (v *versionedRanger) Contains(ip net.IP) (bool, error) { 41 | ranger, err := v.getRangerForIP(ip) 42 | if err != nil { 43 | return false, err 44 | } 45 | return 
ranger.Contains(ip) 46 | } 47 | 48 | func (v *versionedRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { 49 | ranger, err := v.getRangerForIP(ip) 50 | if err != nil { 51 | return nil, err 52 | } 53 | return ranger.ContainingNetworks(ip) 54 | } 55 | 56 | func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { 57 | ranger, err := v.getRangerForIP(network.IP) 58 | if err != nil { 59 | return nil, err 60 | } 61 | return ranger.CoveredNetworks(network) 62 | } 63 | 64 | // Len returns number of networks in ranger. 65 | func (v *versionedRanger) Len() int { 66 | return v.ipV4Ranger.Len() + v.ipV6Ranger.Len() 67 | } 68 | 69 | func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) { 70 | if ip.To4() != nil { 71 | return v.ipV4Ranger, nil 72 | } 73 | if ip.To16() != nil { 74 | return v.ipV6Ranger, nil 75 | } 76 | return nil, ErrInvalidNetworkNumberInput 77 | } 78 | --------------------------------------------------------------------------------