├── .gitignore
├── .travis.yml
├── Gopkg.lock
├── Gopkg.toml
├── LICENSE
├── README.md
├── brute.go
├── brute_test.go
├── cidranger.go
├── cidranger_test.go
├── example
│   └── custom-ranger-asn.go
├── go.mod
├── go.sum
├── net
│   ├── ip.go
│   └── ip_test.go
├── testdata
│   └── aws_ip_ranges.json
├── trie.go
├── trie_test.go
└── version.go
/.gitignore:
--------------------------------------------------------------------------------
1 | vendor
2 | .idea
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | go:
3 | - 1.13.x
4 | - 1.14.x
5 | - tip
6 | before_install:
7 | - travis_retry go get github.com/mattn/goveralls
8 | script:
9 | - go test -v -covermode=count -coverprofile=coverage.out ./...
10 | - travis_retry $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci
11 |
--------------------------------------------------------------------------------
/Gopkg.lock:
--------------------------------------------------------------------------------
1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
2 |
3 |
4 | [[projects]]
5 | digest = "1:a2c1d0e43bd3baaa071d1b9ed72c27d78169b2b269f71c105ac4ba34b1be4a39"
6 | name = "github.com/davecgh/go-spew"
7 | packages = ["spew"]
8 | pruneopts = "UT"
9 | revision = "346938d642f2ec3594ed81d874461961cd0faa76"
10 | version = "v1.1.0"
11 |
12 | [[projects]]
13 | digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe"
14 | name = "github.com/pmezard/go-difflib"
15 | packages = ["difflib"]
16 | pruneopts = "UT"
17 | revision = "792786c7400a136282c1664665ae0a8db921c6c2"
18 | version = "v1.0.0"
19 |
20 | [[projects]]
21 | digest = "1:f85e109eda8f6080877185d1c39e98dd8795e1780c08beca28304b87fd855a1c"
22 | name = "github.com/stretchr/testify"
23 | packages = ["assert"]
24 | pruneopts = "UT"
25 | revision = "12b6f73e6084dad08a7c6e575284b177ecafbc71"
26 | version = "v1.2.1"
27 |
28 | [solve-meta]
29 | analyzer-name = "dep"
30 | analyzer-version = 1
31 | input-imports = ["github.com/stretchr/testify/assert"]
32 | solver-name = "gps-cdcl"
33 | solver-version = 1
34 |
--------------------------------------------------------------------------------
/Gopkg.toml:
--------------------------------------------------------------------------------
1 | # Gopkg.toml example
2 | #
3 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
4 | # for detailed Gopkg.toml documentation.
5 | #
6 | # required = ["github.com/user/thing/cmd/thing"]
7 | # ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
8 | #
9 | # [[constraint]]
10 | # name = "github.com/user/project"
11 | # version = "1.0.0"
12 | #
13 | # [[constraint]]
14 | # name = "github.com/user/project2"
15 | # branch = "dev"
16 | # source = "github.com/myfork/project2"
17 | #
18 | # [[override]]
19 | # name = "github.com/x/y"
20 | # version = "2.4.0"
21 | #
22 | # [prune]
23 | # non-go = false
24 | # go-tests = true
25 | # unused-packages = true
26 |
27 |
28 | [[constraint]]
29 | name = "github.com/stretchr/testify"
30 | version = "1.2.1"
31 |
32 | [prune]
33 | go-tests = true
34 | unused-packages = true
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Yulin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cidranger
2 | Fast IP to CIDR block(s) lookup using trie in Golang, inspired by [IPv4 route lookup linux](https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux). Possible use cases include detecting whether an IP address is from published cloud provider CIDR blocks (e.g. 52.95.110.1 is contained in published AWS Route53 CIDR 52.95.110.0/24), IP routing rules, etc.
3 |
4 | [](https://godoc.org/github.com/yl2chen/cidranger)
5 | [](https://travis-ci.org/yl2chen/cidranger)
6 | [](https://coveralls.io/github/yl2chen/cidranger?branch=master)
7 | [](https://goreportcard.com/report/github.com/yl2chen/cidranger)
8 |
9 | This is a visualization of a trie storing CIDR blocks `128.0.0.0/2` `192.0.0.0/2` `200.0.0.0/5` without path compression. The 0/1 number on each path indicates the bit value of the IP address at the specified bit position, hence the path from the root node to a child node represents a CIDR block that contains all IP ranges of its children, and its children's children.
10 |

11 |
12 | Visualization of trie storing same CIDR blocks with path compression, improving both lookup speed and memory footprint.
13 | 
14 |
15 | ## Getting Started
16 | Configure imports.
17 | ```go
18 | import (
19 | "net"
20 |
21 | "github.com/yl2chen/cidranger"
22 | )
23 | ```
24 | Create a new ranger implemented using Path-Compressed prefix trie.
25 | ```go
26 | ranger := cidranger.NewPCTrieRanger()
27 | ```
28 | Inserts CIDR blocks.
29 | ```go
30 | _, network1, _ := net.ParseCIDR("192.168.1.0/24")
31 | _, network2, _ := net.ParseCIDR("128.168.1.0/24")
32 | ranger.Insert(cidranger.NewBasicRangerEntry(*network1))
33 | ranger.Insert(cidranger.NewBasicRangerEntry(*network2))
34 | ```
35 | To attach any additional value(s) to the entry, simply create custom struct
36 | storing the desired value(s) that implements the RangerEntry interface:
37 | ```go
38 | type RangerEntry interface {
39 | Network() net.IPNet
40 | }
41 | ```
42 | The prefix trie can be visualized as:
43 | ```
44 | 0.0.0.0/0 (target_pos:31:has_entry:false)
45 | | 1--> 128.0.0.0/1 (target_pos:30:has_entry:false)
46 | | | 0--> 128.168.1.0/24 (target_pos:7:has_entry:true)
47 | | | 1--> 192.168.1.0/24 (target_pos:7:has_entry:true)
48 | ```
49 | To test if given IP is contained in constructed ranger,
50 | ```go
51 | contains, err := ranger.Contains(net.ParseIP("128.168.1.0")) // returns true, nil
52 | contains, err = ranger.Contains(net.ParseIP("192.168.2.0")) // returns false, nil
53 | ```
54 | To get all the networks the given IP is contained in,
55 | ```go
56 | containingNetworks, err := ranger.ContainingNetworks(net.ParseIP("128.168.1.0"))
57 | ```
58 | To get all networks in ranger,
59 | ```go
60 | entries, err := ranger.CoveredNetworks(*cidranger.AllIPv4) // for IPv4
61 | entries, err = ranger.CoveredNetworks(*cidranger.AllIPv6)  // for IPv6
62 | ```
63 |
64 | ## Benchmark
65 | Compare hit/miss cases for IPv4/IPv6 using the PC trie vs the brute force implementation. The ranger is initialized with published AWS IP ranges (889 IPv4 CIDR blocks and 360 IPv6).
66 | ```go
67 | // Ipv4 lookup hit scenario
68 | BenchmarkPCTrieHitIPv4UsingAWSRanges-4 5000000 353 ns/op
69 | BenchmarkBruteRangerHitIPv4UsingAWSRanges-4 100000 13719 ns/op
70 |
71 | // Ipv6 lookup hit scenario, counter-intuitively faster than IPv4 due to fewer IPv6 CIDR
72 | // blocks in the AWS dataset, hence the constructed trie has fewer path splits and less depth.
73 | BenchmarkPCTrieHitIPv6UsingAWSRanges-4 10000000 143 ns/op
74 | BenchmarkBruteRangerHitIPv6UsingAWSRanges-4 300000 5178 ns/op
75 |
76 | // Ipv4 lookup miss scenario
77 | BenchmarkPCTrieMissIPv4UsingAWSRanges-4 20000000 96.5 ns/op
78 | BenchmarkBruteRangerMissIPv4UsingAWSRanges-4 50000 24781 ns/op
79 |
80 | // Ipv6 lookup miss scenario
81 | BenchmarkPCTrieHMissIPv6UsingAWSRanges-4 10000000 115 ns/op
82 | BenchmarkBruteRangerMissIPv6UsingAWSRanges-4 100000 10824 ns/op
83 | ```
84 |
85 | ## Example of IPv6 trie:
86 | ```
87 | ::/0 (target_pos:127:has_entry:false)
88 | | 0--> 2400::/14 (target_pos:113:has_entry:false)
89 | | | 0--> 2400:6400::/22 (target_pos:105:has_entry:false)
90 | | | | 0--> 2400:6500::/32 (target_pos:95:has_entry:false)
91 | | | | | 0--> 2400:6500::/39 (target_pos:88:has_entry:false)
92 | | | | | | 0--> 2400:6500:0:7000::/53 (target_pos:74:has_entry:false)
93 | | | | | | | 0--> 2400:6500:0:7000::/54 (target_pos:73:has_entry:false)
94 | | | | | | | | 0--> 2400:6500:0:7000::/55 (target_pos:72:has_entry:false)
95 | | | | | | | | | 0--> 2400:6500:0:7000::/56 (target_pos:71:has_entry:true)
96 | | | | | | | | | 1--> 2400:6500:0:7100::/56 (target_pos:71:has_entry:true)
97 | | | | | | | | 1--> 2400:6500:0:7200::/56 (target_pos:71:has_entry:true)
98 | | | | | | | 1--> 2400:6500:0:7400::/55 (target_pos:72:has_entry:false)
99 | | | | | | | | 0--> 2400:6500:0:7400::/56 (target_pos:71:has_entry:true)
100 | | | | | | | | 1--> 2400:6500:0:7500::/56 (target_pos:71:has_entry:true)
101 | | | | | | 1--> 2400:6500:100:7000::/54 (target_pos:73:has_entry:false)
102 | | | | | | | 0--> 2400:6500:100:7100::/56 (target_pos:71:has_entry:true)
103 | | | | | | | 1--> 2400:6500:100:7200::/56 (target_pos:71:has_entry:true)
104 | | | | | 1--> 2400:6500:ff00::/64 (target_pos:63:has_entry:true)
105 | | | | 1--> 2400:6700:ff00::/64 (target_pos:63:has_entry:true)
106 | | | 1--> 2403:b300:ff00::/64 (target_pos:63:has_entry:true)
107 | ```
108 |
--------------------------------------------------------------------------------
/brute.go:
--------------------------------------------------------------------------------
1 | package cidranger
2 |
3 | import (
4 | "net"
5 |
6 | rnet "github.com/yl2chen/cidranger/net"
7 | )
8 |
9 | // bruteRanger is a brute force implementation of Ranger. Insertion and
10 | // deletion of networks is performed on an internal storage in the form of
11 | // map[string]net.IPNet (constant time operations). However, inclusion tests are
12 | // always performed linearly at no guaranteed traversal order of recorded networks,
13 | // so one can assume a worst case performance of O(N). The performance can be
14 | // boosted many ways, e.g. changing usage of net.IPNet.Contains() to using masked
15 | // bits equality checking, but the main purpose of this implementation is for
16 | // testing because the correctness of this implementation can be easily guaranteed,
17 | // and used as the ground truth when running a wider range of 'random' tests on
18 | // other more sophisticated implementations.
19 | type bruteRanger struct {
20 | ipV4Entries map[string]RangerEntry
21 | ipV6Entries map[string]RangerEntry
22 | }
23 |
24 | // newBruteRanger returns a new Ranger.
25 | func newBruteRanger() Ranger {
26 | return &bruteRanger{
27 | ipV4Entries: make(map[string]RangerEntry),
28 | ipV6Entries: make(map[string]RangerEntry),
29 | }
30 | }
31 |
32 | // Insert inserts a RangerEntry into ranger.
33 | func (b *bruteRanger) Insert(entry RangerEntry) error {
34 | network := entry.Network()
35 | key := network.String()
36 | if _, found := b.ipV4Entries[key]; !found {
37 | entries, err := b.getEntriesByVersion(entry.Network().IP)
38 | if err != nil {
39 | return err
40 | }
41 | entries[key] = entry
42 | }
43 | return nil
44 | }
45 |
46 | // Remove removes a RangerEntry identified by given network from ranger.
47 | func (b *bruteRanger) Remove(network net.IPNet) (RangerEntry, error) {
48 | networks, err := b.getEntriesByVersion(network.IP)
49 | if err != nil {
50 | return nil, err
51 | }
52 | key := network.String()
53 | if networkToDelete, found := networks[key]; found {
54 | delete(networks, key)
55 | return networkToDelete, nil
56 | }
57 | return nil, nil
58 | }
59 |
60 | // Contains returns bool indicating whether given ip is contained by any
61 | // network in ranger.
62 | func (b *bruteRanger) Contains(ip net.IP) (bool, error) {
63 | entries, err := b.getEntriesByVersion(ip)
64 | if err != nil {
65 | return false, err
66 | }
67 | for _, entry := range entries {
68 | network := entry.Network()
69 | if network.Contains(ip) {
70 | return true, nil
71 | }
72 | }
73 | return false, nil
74 | }
75 |
76 | // ContainingNetworks returns all RangerEntry(s) that given ip contained in.
77 | func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
78 | entries, err := b.getEntriesByVersion(ip)
79 | if err != nil {
80 | return nil, err
81 | }
82 | results := []RangerEntry{}
83 | for _, entry := range entries {
84 | network := entry.Network()
85 | if network.Contains(ip) {
86 | results = append(results, entry)
87 | }
88 | }
89 | return results, nil
90 | }
91 |
92 | // CoveredNetworks returns the list of RangerEntry(s) the given ipnet
93 | // covers. That is, the networks that are completely subsumed by the
94 | // specified network.
95 | func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
96 | entries, err := b.getEntriesByVersion(network.IP)
97 | if err != nil {
98 | return nil, err
99 | }
100 | var results []RangerEntry
101 | testNetwork := rnet.NewNetwork(network)
102 | for _, entry := range entries {
103 | entryNetwork := rnet.NewNetwork(entry.Network())
104 | if testNetwork.Covers(entryNetwork) {
105 | results = append(results, entry)
106 | }
107 | }
108 | return results, nil
109 | }
110 |
111 | // Len returns number of networks in ranger.
112 | func (b *bruteRanger) Len() int {
113 | return len(b.ipV4Entries) + len(b.ipV6Entries)
114 | }
115 |
116 | func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) {
117 | if ip.To4() != nil {
118 | return b.ipV4Entries, nil
119 | }
120 | if ip.To16() != nil {
121 | return b.ipV6Entries, nil
122 | }
123 | return nil, ErrInvalidNetworkInput
124 | }
125 |
--------------------------------------------------------------------------------
/brute_test.go:
--------------------------------------------------------------------------------
1 | package cidranger
2 |
3 | import (
4 | "net"
5 | "sort"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestInsert(t *testing.T) {
12 | ranger := newBruteRanger().(*bruteRanger)
13 | _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
14 | _, networkIPv6, _ := net.ParseCIDR("8000::/96")
15 | entryIPv4 := NewBasicRangerEntry(*networkIPv4)
16 | entryIPv6 := NewBasicRangerEntry(*networkIPv6)
17 |
18 | ranger.Insert(entryIPv4)
19 | ranger.Insert(entryIPv6)
20 |
21 | assert.Equal(t, 1, len(ranger.ipV4Entries))
22 | assert.Equal(t, entryIPv4, ranger.ipV4Entries["0.0.1.0/24"])
23 | assert.Equal(t, 1, len(ranger.ipV6Entries))
24 | assert.Equal(t, entryIPv6, ranger.ipV6Entries["8000::/96"])
25 | }
26 |
27 | func TestInsertError(t *testing.T) {
28 | bRanger := newBruteRanger().(*bruteRanger)
29 | _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
30 | networkIPv4.IP = append(networkIPv4.IP, byte(4))
31 | err := bRanger.Insert(NewBasicRangerEntry(*networkIPv4))
32 | assert.Equal(t, ErrInvalidNetworkInput, err)
33 | }
34 |
35 | func TestRemove(t *testing.T) {
36 | ranger := newBruteRanger().(*bruteRanger)
37 | _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
38 | _, networkIPv6, _ := net.ParseCIDR("8000::/96")
39 | _, notInserted, _ := net.ParseCIDR("8000::/96")
40 |
41 | insertIPv4 := NewBasicRangerEntry(*networkIPv4)
42 | insertIPv6 := NewBasicRangerEntry(*networkIPv6)
43 |
44 | ranger.Insert(insertIPv4)
45 | deletedIPv4, err := ranger.Remove(*networkIPv4)
46 | assert.NoError(t, err)
47 |
48 | ranger.Insert(insertIPv6)
49 | deletedIPv6, err := ranger.Remove(*networkIPv6)
50 | assert.NoError(t, err)
51 |
52 | entry, err := ranger.Remove(*notInserted)
53 | assert.NoError(t, err)
54 | assert.Nil(t, entry)
55 |
56 | assert.Equal(t, insertIPv4, deletedIPv4)
57 | assert.Equal(t, 0, len(ranger.ipV4Entries))
58 | assert.Equal(t, insertIPv6, deletedIPv6)
59 | assert.Equal(t, 0, len(ranger.ipV6Entries))
60 | }
61 |
62 | func TestRemoveError(t *testing.T) {
63 | r := newBruteRanger().(*bruteRanger)
64 | _, invalidNetwork, _ := net.ParseCIDR("0.0.1.0/24")
65 | invalidNetwork.IP = append(invalidNetwork.IP, byte(4))
66 |
67 | _, err := r.Remove(*invalidNetwork)
68 | assert.Equal(t, ErrInvalidNetworkInput, err)
69 | }
70 |
71 | func TestContains(t *testing.T) {
72 | r := newBruteRanger().(*bruteRanger)
73 | _, network, _ := net.ParseCIDR("0.0.1.0/24")
74 | _, network1, _ := net.ParseCIDR("8000::/112")
75 | r.Insert(NewBasicRangerEntry(*network))
76 | r.Insert(NewBasicRangerEntry(*network1))
77 |
78 | cases := []struct {
79 | ip net.IP
80 | contains bool
81 | err error
82 | name string
83 | }{
84 | {net.ParseIP("0.0.1.255"), true, nil, "IPv4 should contain"},
85 | {net.ParseIP("0.0.0.255"), false, nil, "IPv4 houldn't contain"},
86 | {net.ParseIP("8000::ffff"), true, nil, "IPv6 shouldn't contain"},
87 | {net.ParseIP("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"},
88 | {append(net.ParseIP("8000::1:ffff"), byte(0)), false, ErrInvalidNetworkInput, "Invalid IP"},
89 | }
90 |
91 | for _, tc := range cases {
92 | t.Run(tc.name, func(t *testing.T) {
93 | contains, err := r.Contains(tc.ip)
94 | if tc.err != nil {
95 | assert.Equal(t, tc.err, err)
96 | } else {
97 | assert.NoError(t, err)
98 | assert.Equal(t, tc.contains, contains)
99 | }
100 | })
101 | }
102 | }
103 |
104 | func TestContainingNetworks(t *testing.T) {
105 | r := newBruteRanger().(*bruteRanger)
106 | _, network1, _ := net.ParseCIDR("0.0.1.0/24")
107 | _, network2, _ := net.ParseCIDR("0.0.1.0/25")
108 | _, network3, _ := net.ParseCIDR("8000::/112")
109 | _, network4, _ := net.ParseCIDR("8000::/113")
110 | entry1 := NewBasicRangerEntry(*network1)
111 | entry2 := NewBasicRangerEntry(*network2)
112 | entry3 := NewBasicRangerEntry(*network3)
113 | entry4 := NewBasicRangerEntry(*network4)
114 | r.Insert(entry1)
115 | r.Insert(entry2)
116 | r.Insert(entry3)
117 | r.Insert(entry4)
118 | cases := []struct {
119 | ip net.IP
120 | containingNetworks []RangerEntry
121 | err error
122 | name string
123 | }{
124 | {net.ParseIP("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"},
125 | {net.ParseIP("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"},
126 | {net.ParseIP("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"},
127 | {net.ParseIP("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should constain"},
128 | {net.ParseIP("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"},
129 | {net.ParseIP("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"},
130 | {append(net.ParseIP("8000::1:7fff"), byte(0)), nil, ErrInvalidNetworkInput, "Invalid IP"},
131 | }
132 |
133 | for _, tc := range cases {
134 | t.Run(tc.name, func(t *testing.T) {
135 | networks, err := r.ContainingNetworks(tc.ip)
136 | if tc.err != nil {
137 | assert.Equal(t, tc.err, err)
138 | } else {
139 | assert.NoError(t, err)
140 | assert.Equal(t, len(tc.containingNetworks), len(networks))
141 | for _, network := range tc.containingNetworks {
142 | assert.Contains(t, networks, network)
143 | }
144 | }
145 | })
146 | }
147 | }
148 |
149 | func TestCoveredNetworks(t *testing.T) {
150 | for _, tc := range coveredNetworkTests {
151 | t.Run(tc.name, func(t *testing.T) {
152 | ranger := newBruteRanger()
153 | for _, insert := range tc.inserts {
154 | _, network, _ := net.ParseCIDR(insert)
155 | err := ranger.Insert(NewBasicRangerEntry(*network))
156 | assert.NoError(t, err)
157 | }
158 | var expectedEntries []string
159 | for _, network := range tc.networks {
160 | expectedEntries = append(expectedEntries, network)
161 | }
162 | sort.Strings(expectedEntries)
163 | _, snet, _ := net.ParseCIDR(tc.search)
164 | networks, err := ranger.CoveredNetworks(*snet)
165 | assert.NoError(t, err)
166 |
167 | var results []string
168 | for _, result := range networks {
169 | net := result.Network()
170 | results = append(results, net.String())
171 | }
172 | sort.Strings(results)
173 |
174 | assert.Equal(t, expectedEntries, results)
175 | })
176 | }
177 | }
178 |
--------------------------------------------------------------------------------
/cidranger.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package cidranger provides utility to store CIDR blocks and perform ip
3 | inclusion tests against it.
4 |
5 | To create a new instance of the path-compressed trie:
6 |
7 | ranger := NewPCTrieRanger()
8 |
9 | To insert or remove an entry (any object that satisfies the RangerEntry
10 | interface):
11 |
12 | _, network, _ := net.ParseCIDR("192.168.0.0/24")
13 | ranger.Insert(NewBasicRangerEntry(*network))
14 | ranger.Remove(*network)
15 |
16 | If you desire for any value to be attached to the entry, simply
17 | create custom struct that satisfies the RangerEntry interface:
18 |
19 | type RangerEntry interface {
20 | Network() net.IPNet
21 | }
22 |
23 | To test whether an IP is contained in the constructed networks ranger:
24 |
25 | // returns bool, error
26 | containsBool, err := ranger.Contains(net.ParseIP("192.168.0.1"))
27 |
28 | To get a list of CIDR blocks in the constructed ranger that contain the IP:
29 |
30 | // returns []RangerEntry, error
31 | entries, err := ranger.ContainingNetworks(net.ParseIP("192.168.0.1"))
32 |
33 | To get a list of all IPv4/IPv6 networks in the ranger, respectively:
34 |
35 | // returns []RangerEntry, error
36 | entries, err := ranger.CoveredNetworks(*AllIPv4)
37 | entries, err := ranger.CoveredNetworks(*AllIPv6)
38 |
39 | */
40 | package cidranger
41 |
42 | import (
43 | "fmt"
44 | "net"
45 | )
46 |
var (
	// ErrInvalidNetworkInput is returned upon invalid network input.
	ErrInvalidNetworkInput = fmt.Errorf("Invalid network input")

	// ErrInvalidNetworkNumberInput is returned upon invalid network number input.
	ErrInvalidNetworkNumberInput = fmt.Errorf("Invalid network number input")

	// AllIPv4 is the IPv4 CIDR 0.0.0.0/0, which contains all IPv4 networks.
	AllIPv4 = parseCIDRUnsafe("0.0.0.0/0")

	// AllIPv6 is the IPv6 CIDR ::/0, which contains all IPv6 networks.
	AllIPv6 = parseCIDRUnsafe("0::0/0")
)

// parseCIDRUnsafe parses s in CIDR notation and returns the network.
//
// It is intended for literal inputs used to initialise package variables,
// such as AllIPv4 and AllIPv6 above. A malformed literal is a programming
// error, so it panics instead of silently producing a nil *net.IPNet that
// would only fail later at a dereference.
func parseCIDRUnsafe(s string) *net.IPNet {
	_, cidr, err := net.ParseCIDR(s)
	if err != nil {
		panic(fmt.Sprintf("cidranger: invalid CIDR literal %q: %v", s, err))
	}
	return cidr
}
63 |
// RangerEntry is the unit stored in a Ranger. Implementations may carry any
// extra payload; a ranger only needs to know which network the entry covers.
type RangerEntry interface {
	Network() net.IPNet
}

// basicRangerEntry is the minimal RangerEntry: it records a network and
// nothing else.
type basicRangerEntry struct {
	ipNet net.IPNet
}

// Network returns the network held by the entry.
func (b *basicRangerEntry) Network() net.IPNet {
	return b.ipNet
}

// NewBasicRangerEntry returns a basic RangerEntry that only stores the network
// itself.
func NewBasicRangerEntry(ipNet net.IPNet) RangerEntry {
	entry := new(basicRangerEntry)
	entry.ipNet = ipNet
	return entry
}
84 |
85 | // Ranger is an interface for cidr block containment lookups.
86 | type Ranger interface {
87 | Insert(entry RangerEntry) error
88 | Remove(network net.IPNet) (RangerEntry, error)
89 | Contains(ip net.IP) (bool, error)
90 | ContainingNetworks(ip net.IP) ([]RangerEntry, error)
91 | CoveredNetworks(network net.IPNet) ([]RangerEntry, error)
92 | Len() int
93 | }
94 |
95 | // NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6
96 | // using the path compressed trie implemention.
97 | func NewPCTrieRanger() Ranger {
98 | return newVersionedRanger(newPrefixTree)
99 | }
100 |
--------------------------------------------------------------------------------
/cidranger_test.go:
--------------------------------------------------------------------------------
1 | package cidranger
2 |
3 | import (
4 | "encoding/json"
5 | "io/ioutil"
6 | "math/rand"
7 | "net"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/assert"
12 | rnet "github.com/yl2chen/cidranger/net"
13 | )
14 |
15 | /*
16 | ******************************************************************
17 | Test Contains/ContainingNetworks against basic brute force ranger.
18 | ******************************************************************
19 | */
20 |
21 | func TestContainsAgainstBaseIPv4(t *testing.T) {
22 | testContainsAgainstBase(t, 100000, randIPv4Gen)
23 | }
24 |
25 | func TestContainingNetworksAgaistBaseIPv4(t *testing.T) {
26 | testContainingNetworksAgainstBase(t, 100000, randIPv4Gen)
27 | }
28 |
29 | func TestCoveredNetworksAgainstBaseIPv4(t *testing.T) {
30 | testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV4AWSRangesIPNets))
31 | }
32 |
33 | // IPv6 spans an extremely large address space (2^128), randomly generated IPs
34 | // will often fall outside of the test ranges (AWS public CIDR blocks), so it
35 | // it more meaningful for testing to run from a curated list of IPv6 IPs.
36 | func TestContainsAgaistBaseIPv6(t *testing.T) {
37 | testContainsAgainstBase(t, 100000, curatedAWSIPv6Gen)
38 | }
39 |
40 | func TestContainingNetworksAgaistBaseIPv6(t *testing.T) {
41 | testContainingNetworksAgainstBase(t, 100000, curatedAWSIPv6Gen)
42 | }
43 |
44 | func TestCoveredNetworksAgainstBaseIPv6(t *testing.T) {
45 | testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV6AWSRangesIPNets))
46 | }
47 |
48 | func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) {
49 | if testing.Short() {
50 | t.Skip("Skipping memory test in `-short` mode")
51 | }
52 | rangers := []Ranger{NewPCTrieRanger()}
53 | baseRanger := newBruteRanger()
54 | for _, ranger := range rangers {
55 | configureRangerWithAWSRanges(t, ranger)
56 | }
57 | configureRangerWithAWSRanges(t, baseRanger)
58 |
59 | for i := 0; i < iterations; i++ {
60 | nn := ipGen()
61 | expected, err := baseRanger.Contains(nn.ToIP())
62 | assert.NoError(t, err)
63 | for _, ranger := range rangers {
64 | actual, err := ranger.Contains(nn.ToIP())
65 | assert.NoError(t, err)
66 | assert.Equal(t, expected, actual)
67 | }
68 | }
69 | }
70 |
71 | func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) {
72 | if testing.Short() {
73 | t.Skip("Skipping memory test in `-short` mode")
74 | }
75 | rangers := []Ranger{NewPCTrieRanger()}
76 | baseRanger := newBruteRanger()
77 | for _, ranger := range rangers {
78 | configureRangerWithAWSRanges(t, ranger)
79 | }
80 | configureRangerWithAWSRanges(t, baseRanger)
81 |
82 | for i := 0; i < iterations; i++ {
83 | nn := ipGen()
84 | expected, err := baseRanger.ContainingNetworks(nn.ToIP())
85 | assert.NoError(t, err)
86 | for _, ranger := range rangers {
87 | actual, err := ranger.ContainingNetworks(nn.ToIP())
88 | assert.NoError(t, err)
89 | assert.Equal(t, len(expected), len(actual))
90 | for _, network := range actual {
91 | assert.Contains(t, expected, network)
92 | }
93 | }
94 | }
95 | }
96 |
97 | func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkGenerator) {
98 | if testing.Short() {
99 | t.Skip("Skipping memory test in `-short` mode")
100 | }
101 | rangers := []Ranger{NewPCTrieRanger()}
102 | baseRanger := newBruteRanger()
103 | for _, ranger := range rangers {
104 | configureRangerWithAWSRanges(t, ranger)
105 | }
106 | configureRangerWithAWSRanges(t, baseRanger)
107 |
108 | for i := 0; i < iterations; i++ {
109 | network := netGen()
110 | expected, err := baseRanger.CoveredNetworks(network.IPNet)
111 | assert.NoError(t, err)
112 | for _, ranger := range rangers {
113 | actual, err := ranger.CoveredNetworks(network.IPNet)
114 | assert.NoError(t, err)
115 | assert.Equal(t, len(expected), len(actual))
116 | for _, network := range actual {
117 | assert.Contains(t, expected, network)
118 | }
119 | }
120 | }
121 | }
122 |
123 | /*
124 | ******************************************************************
125 | Benchmarks.
126 | ******************************************************************
127 | */
128 |
129 | func BenchmarkPCTrieHitIPv4UsingAWSRanges(b *testing.B) {
130 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger())
131 | }
132 | func BenchmarkBruteRangerHitIPv4UsingAWSRanges(b *testing.B) {
133 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger())
134 | }
135 |
136 | func BenchmarkPCTrieHitIPv6UsingAWSRanges(b *testing.B) {
137 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger())
138 | }
139 | func BenchmarkBruteRangerHitIPv6UsingAWSRanges(b *testing.B) {
140 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger())
141 | }
142 |
143 | func BenchmarkPCTrieMissIPv4UsingAWSRanges(b *testing.B) {
144 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger())
145 | }
146 | func BenchmarkBruteRangerMissIPv4UsingAWSRanges(b *testing.B) {
147 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger())
148 | }
149 |
150 | func BenchmarkPCTrieHMissIPv6UsingAWSRanges(b *testing.B) {
151 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger())
152 | }
153 | func BenchmarkBruteRangerMissIPv6UsingAWSRanges(b *testing.B) {
154 | benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger())
155 | }
156 |
157 | func BenchmarkPCTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) {
158 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger())
159 | }
160 | func BenchmarkBruteRangerHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) {
161 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger())
162 | }
163 |
164 | func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) {
165 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger())
166 | }
167 | func BenchmarkBruteRangerHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) {
168 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger())
169 | }
170 |
171 | func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) {
172 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger())
173 | }
174 | func BenchmarkBruteRangerMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) {
175 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger())
176 | }
177 |
178 | func BenchmarkPCTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) {
179 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger())
180 | }
181 | func BenchmarkBruteRangerMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) {
182 | benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger())
183 | }
184 |
185 | func BenchmarkNewPathprefixTriev4(b *testing.B) {
186 | benchmarkNewPathprefixTrie(b, "192.128.0.0/24")
187 | }
188 |
189 | func BenchmarkNewPathprefixTriev6(b *testing.B) {
190 | benchmarkNewPathprefixTrie(b, "8000::/24")
191 | }
192 |
193 | func benchmarkContainsUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) {
194 | configureRangerWithAWSRanges(tb, ranger)
195 | for n := 0; n < tb.(*testing.B).N; n++ {
196 | ranger.Contains(nn)
197 | }
198 | }
199 |
200 | func benchmarkContainingNetworksUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) {
201 | configureRangerWithAWSRanges(tb, ranger)
202 | for n := 0; n < tb.(*testing.B).N; n++ {
203 | ranger.ContainingNetworks(nn)
204 | }
205 | }
206 |
// benchmarkNewPathprefixTrie measures newPathprefixTrie for the network given
// in CIDR notation by net1; parsing happens before the timer is reset.
func benchmarkNewPathprefixTrie(b *testing.B, net1 string) {
	_, parsed, _ := net.ParseCIDR(net1)
	prefixLen, _ := parsed.Mask.Size()
	network := rnet.NewNetwork(*parsed)
	numBits := uint(prefixLen)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		newPathprefixTrie(network, numBits)
	}
}
219 |
220 | /*
221 | ******************************************************************
222 | Helper methods and initialization.
223 | ******************************************************************
224 | */
225 |
226 | type ipGenerator func() rnet.NetworkNumber
227 |
228 | func randIPv4Gen() rnet.NetworkNumber {
229 | return rnet.NetworkNumber{rand.Uint32()}
230 | }
231 | func randIPv6Gen() rnet.NetworkNumber {
232 | return rnet.NetworkNumber{rand.Uint32(), rand.Uint32(), rand.Uint32(), rand.Uint32()}
233 | }
234 | func curatedAWSIPv6Gen() rnet.NetworkNumber {
235 | randIdx := rand.Intn(len(ipV6AWSRangesIPNets))
236 |
237 | // Randomly generate an IP somewhat near the range.
238 | network := ipV6AWSRangesIPNets[randIdx]
239 | nn := rnet.NewNetworkNumber(network.IP)
240 | ones, bits := network.Mask.Size()
241 | zeros := bits - ones
242 | nnPartIdx := zeros / rnet.BitsPerUint32
243 | nn[nnPartIdx] = rand.Uint32()
244 | return nn
245 | }
246 |
247 | type networkGenerator func() rnet.Network
248 |
249 | func randomIPNetGenFactory(pool []*net.IPNet) networkGenerator {
250 | return func() rnet.Network {
251 | return rnet.NewNetwork(*pool[rand.Intn(len(pool))])
252 | }
253 | }
254 |
type (
	// AWSRanges is the decoded form of the ./testdata/aws_ip_ranges.json
	// fixture: separate lists of IPv4 and IPv6 prefixes.
	AWSRanges struct {
		Prefixes     []Prefix     `json:"prefixes"`
		IPv6Prefixes []IPv6Prefix `json:"ipv6_prefixes"`
	}

	// Prefix is one entry of the "prefixes" (IPv4) list.
	Prefix struct {
		IPPrefix string `json:"ip_prefix"`
		Region   string `json:"region"`
		Service  string `json:"service"`
	}

	// IPv6Prefix is one entry of the "ipv6_prefixes" list; only the JSON key of
	// the CIDR field differs from Prefix.
	IPv6Prefix struct {
		IPPrefix string `json:"ipv6_prefix"`
		Region   string `json:"region"`
		Service  string `json:"service"`
	}
)
271 |
272 | var awsRanges *AWSRanges
273 | var ipV4AWSRangesIPNets []*net.IPNet
274 | var ipV6AWSRangesIPNets []*net.IPNet
275 |
276 | func loadAWSRanges() *AWSRanges {
277 | file, err := ioutil.ReadFile("./testdata/aws_ip_ranges.json")
278 | if err != nil {
279 | panic(err)
280 | }
281 | var ranges AWSRanges
282 | err = json.Unmarshal(file, &ranges)
283 | if err != nil {
284 | panic(err)
285 | }
286 | return &ranges
287 | }
288 |
// configureRangerWithAWSRanges inserts every IPv4 prefix and then every IPv6
// prefix of the loaded AWS fixture into ranger. A prefix that fails to parse,
// or that ranger refuses to insert, is reported through tb.
func configureRangerWithAWSRanges(tb testing.TB, ranger Ranger) {
	insert := func(cidr string) {
		_, network, err := net.ParseCIDR(cidr)
		if !assert.NoError(tb, err) {
			// network is nil here; dereferencing it would panic.
			return
		}
		// Insert reports failures through its error result; don't drop it.
		assert.NoError(tb, ranger.Insert(NewBasicRangerEntry(*network)))
	}
	for _, prefix := range awsRanges.Prefixes {
		insert(prefix.IPPrefix)
	}
	for _, prefix := range awsRanges.IPv6Prefixes {
		insert(prefix.IPPrefix)
	}
}
301 |
// init loads the AWS fixture once, pre-parses its prefixes into
// ipV6AWSRangesIPNets and ipV4AWSRangesIPNets for the random generators, and
// seeds math/rand from the current time.
func init() {
	awsRanges = loadAWSRanges()

	parse := func(cidr string) *net.IPNet {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			// Appending a nil entry would only surface later as a nil-pointer
			// dereference in curatedAWSIPv6Gen or randomIPNetGenFactory.
			panic(err)
		}
		return network
	}
	for _, prefix := range awsRanges.IPv6Prefixes {
		ipV6AWSRangesIPNets = append(ipV6AWSRangesIPNets, parse(prefix.IPPrefix))
	}
	for _, prefix := range awsRanges.Prefixes {
		ipV4AWSRangesIPNets = append(ipV4AWSRangesIPNets, parse(prefix.IPPrefix))
	}
	rand.Seed(time.Now().Unix())
}
314 |
--------------------------------------------------------------------------------
/example/custom-ranger-asn.go:
--------------------------------------------------------------------------------
1 | /*
2 | Example of how to extend github.com/yl2chen/cidranger
3 |
4 | This adds ASN as a string field, along with methods to get the ASN and the CIDR as strings
5 |
6 | Thank you to yl2chen for his assistance and work on this library
7 | */
8 | package main
9 |
10 | import (
11 | "fmt"
12 | "net"
13 | "os"
14 |
15 | "github.com/yl2chen/cidranger"
16 | )
17 |
18 | // custom structure that conforms to RangerEntry interface
19 | type customRangerEntry struct {
20 | ipNet net.IPNet
21 | asn string
22 | }
23 |
24 | // get function for network
25 | func (b *customRangerEntry) Network() net.IPNet {
26 | return b.ipNet
27 | }
28 |
29 | // get function for network converted to string
30 | func (b *customRangerEntry) NetworkStr() string {
31 | return b.ipNet.String()
32 | }
33 |
34 | // get function for ASN
35 | func (b *customRangerEntry) Asn() string {
36 | return b.asn
37 | }
38 |
// newCustomRangerEntry builds a *customRangerEntry for ipNet tagged with asn
// and returns it as a cidranger.RangerEntry, ready for ranger.Insert.
func newCustomRangerEntry(ipNet net.IPNet, asn string) cidranger.RangerEntry {
	entry := new(customRangerEntry)
	entry.ipNet = ipNet
	entry.asn = asn
	return entry
}
46 |
// main builds a PC-trie ranger holding two customRangerEntry values. It then
// demonstrates Contains and ContainingNetworks, printing the CIDR and ASN of
// each entry that contains the queried IP. Any error is printed and the program
// exits with status 1.
func main() {
	// Instantiate the PC-trie based ranger.
	ranger := cidranger.NewPCTrieRanger()

	// Load sample data using our custom entry type. Parse and insert errors are
	// checked rather than discarded, since this example is meant to be copied.
	samples := []struct {
		cidr string
		asn  string
	}{
		{"192.168.1.0/24", "0001"},
		{"128.168.1.0/24", "0002"},
	}
	for _, s := range samples {
		_, network, err := net.ParseCIDR(s.cidr)
		if err != nil {
			fmt.Println("net.ParseCIDR()", err.Error())
			os.Exit(1)
		}
		if err := ranger.Insert(newCustomRangerEntry(*network, s.asn)); err != nil {
			fmt.Println("ranger.Insert()", err.Error())
			os.Exit(1)
		}
	}

	// Check if IP is contained within ranger.
	contains, err := ranger.Contains(net.ParseIP("128.168.1.7"))
	if err != nil {
		fmt.Println("ranger.Contains()", err.Error())
		os.Exit(1)
	}
	fmt.Println("Contains:", contains)

	// Request networks containing this IP.
	ip := "192.168.1.42"
	entries, err := ranger.ContainingNetworks(net.ParseIP(ip))
	if err != nil {
		fmt.Println("ranger.ContainingNetworks()", err.Error())
		os.Exit(1)
	}

	fmt.Printf("Entries for %s:\n", ip)
	for _, e := range entries {
		// Type-assert e (a cidranger.RangerEntry) back to *customRangerEntry;
		// entries of any other type are skipped.
		entry, ok := e.(*customRangerEntry)
		if !ok {
			continue
		}

		// Display the network (as a string) and its ASN.
		fmt.Println("\t", entry.NetworkStr(), entry.Asn())
	}
}
95 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/yl2chen/cidranger
2 |
3 | go 1.13
4 |
5 | require (
6 | github.com/stretchr/testify v1.6.1
7 | gopkg.in/yaml.v2 v2.2.2 // indirect
8 | )
9 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
6 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
7 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
8 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
9 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
12 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
13 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
14 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
16 |
--------------------------------------------------------------------------------
/net/ip.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package net provides utility functions for working with IPs (net.IP).
3 | */
4 | package net
5 |
6 | import (
7 | "bytes"
8 | "encoding/binary"
9 | "fmt"
10 | "math"
11 | "net"
12 | )
13 |
14 | // IPVersion is version of IP address.
15 | type IPVersion string
16 |
17 | // Helper constants.
18 | const (
19 | IPv4Uint32Count = 1
20 | IPv6Uint32Count = 4
21 |
22 | BitsPerUint32 = 32
23 | BytePerUint32 = 4
24 |
25 | IPv4 IPVersion = "IPv4"
26 | IPv6 IPVersion = "IPv6"
27 | )
28 |
29 | // ErrInvalidBitPosition is returned when bits requested is not valid.
30 | var ErrInvalidBitPosition = fmt.Errorf("bit position not valid")
31 |
32 | // ErrVersionMismatch is returned upon mismatch in network input versions.
33 | var ErrVersionMismatch = fmt.Errorf("Network input version mismatch")
34 |
35 | // ErrNoGreatestCommonBit is an error returned when no greatest common bit
36 | // exists for the cidr ranges.
37 | var ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit")
38 |
39 | // NetworkNumber represents an IP address using uint32 as internal storage.
40 | // IPv4 usings 1 uint32, while IPv6 uses 4 uint32.
41 | type NetworkNumber []uint32
42 |
// NewNetworkNumber converts ip into NetworkNumber form: one word when ip is an
// IPv4 address (including its 16-byte IPv4-mapped form), four words for any
// other 16-byte address. It returns nil when ip is nil or has an invalid
// length.
func NewNetworkNumber(ip net.IP) NetworkNumber {
	if ip == nil {
		return nil
	}
	raw, words := ip.To4(), 1
	if raw == nil {
		raw, words = ip.To16(), 4
		if raw == nil {
			return nil
		}
	}
	nn := make(NetworkNumber, words)
	for i := range nn {
		off := i * net.IPv4len
		nn[i] = binary.BigEndian.Uint32(raw[off : off+net.IPv4len])
	}
	return nn
}
65 |
66 | // ToV4 returns ip address if ip is IPv4, returns nil otherwise.
67 | func (n NetworkNumber) ToV4() NetworkNumber {
68 | if len(n) != IPv4Uint32Count {
69 | return nil
70 | }
71 | return n
72 | }
73 |
74 | // ToV6 returns ip address if ip is IPv6, returns nil otherwise.
75 | func (n NetworkNumber) ToV6() NetworkNumber {
76 | if len(n) != IPv6Uint32Count {
77 | return nil
78 | }
79 | return n
80 | }
81 |
// ToIP converts n back into a net.IP. A one-word (IPv4) number is returned in
// net.IPv4's 16-byte form; any other length yields len(n)*4 raw bytes.
func (n NetworkNumber) ToIP() net.IP {
	buf := make(net.IP, BytePerUint32*len(n))
	for i, word := range n {
		binary.BigEndian.PutUint32(buf[i*net.IPv4len:], word)
	}
	if len(buf) != net.IPv4len {
		return buf
	}
	return net.IPv4(buf[0], buf[1], buf[2], buf[3])
}
94 |
95 | // Equal is the equality test for 2 network numbers.
96 | func (n NetworkNumber) Equal(n1 NetworkNumber) bool {
97 | if len(n) != len(n1) {
98 | return false
99 | }
100 | if n[0] != n1[0] {
101 | return false
102 | }
103 | if len(n) == IPv6Uint32Count {
104 | return n[1] == n1[1] && n[2] == n1[2] && n[3] == n1[3]
105 | }
106 | return true
107 | }
108 |
// Next returns a new NetworkNumber one greater than n, carrying from the last
// (least significant) word toward the first; the all-ones value wraps to zero.
// n itself is not modified.
func (n NetworkNumber) Next() NetworkNumber {
	next := make(NetworkNumber, len(n))
	carry := uint32(1)
	for i := len(n) - 1; i >= 0; i-- {
		next[i] = n[i] + carry
		if next[i] != 0 {
			// No overflow out of this word: nothing left to carry.
			carry = 0
		}
	}
	return next
}

// Previous returns a new NetworkNumber one less than n, borrowing from the
// last (least significant) word toward the first; zero wraps to all-ones.
// n itself is not modified.
func (n NetworkNumber) Previous() NetworkNumber {
	prev := make(NetworkNumber, len(n))
	borrow := uint32(1)
	for i := len(n) - 1; i >= 0; i-- {
		prev[i] = n[i] - borrow
		if prev[i] != math.MaxUint32 {
			// This word did not underflow: nothing left to borrow.
			borrow = 0
		}
	}
	return prev
}
134 |
135 | // Bit returns uint32 representing the bit value at given position, e.g.,
136 | // "128.0.0.0" has bit value of 1 at position 31, and 0 for positions 30 to 0.
137 | func (n NetworkNumber) Bit(position uint) (uint32, error) {
138 | if int(position) > len(n)*BitsPerUint32-1 {
139 | return 0, ErrInvalidBitPosition
140 | }
141 | idx := len(n) - 1 - int(position/BitsPerUint32)
142 | // Mod 31 to get array index.
143 | rShift := position & (BitsPerUint32 - 1)
144 | return (n[idx] >> rShift) & 1, nil
145 | }
146 |
147 | // LeastCommonBitPosition returns the smallest position of the preceding common
148 | // bits of the 2 network numbers, and returns an error ErrNoGreatestCommonBit
149 | // if the two network number diverges from the first bit.
150 | // e.g., if the network number diverges after the 1st bit, it returns 131 for
151 | // IPv6 and 31 for IPv4 .
152 | func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) {
153 | if len(n) != len(n1) {
154 | return 0, ErrVersionMismatch
155 | }
156 | for i := 0; i < len(n); i++ {
157 | mask := uint32(1) << 31
158 | pos := uint(31)
159 | for ; mask > 0; mask >>= 1 {
160 | if n[i]&mask != n1[i]&mask {
161 | if i == 0 && pos == 31 {
162 | return 0, ErrNoGreatestCommonBit
163 | }
164 | return (pos + 1) + uint(BitsPerUint32)*uint(len(n)-i-1), nil
165 | }
166 | pos--
167 | }
168 | }
169 | return 0, nil
170 | }
171 |
172 | // Network represents a block of network numbers, also known as CIDR.
173 | type Network struct {
174 | net.IPNet
175 | Number NetworkNumber
176 | Mask NetworkNumberMask
177 | }
178 |
179 | // NewNetwork returns Network built using given net.IPNet.
180 | func NewNetwork(ipNet net.IPNet) Network {
181 | return Network{
182 | IPNet: ipNet,
183 | Number: NewNetworkNumber(ipNet.IP),
184 | Mask: NetworkNumberMask(NewNetworkNumber(net.IP(ipNet.Mask))),
185 | }
186 | }
187 |
// Masked returns the Network formed by applying a prefix of length ones to n's
// address, e.g. 192.168.0.0/16 masked to 8 gives 192.0.0.0/8.
func (n Network) Masked(ones int) Network {
	totalBits := BitsPerUint32 * len(n.Number)
	newMask := net.CIDRMask(ones, totalBits)
	masked := net.IPNet{IP: n.IP.Mask(newMask), Mask: newMask}
	return NewNetwork(masked)
}
196 |
197 | // Contains returns true if NetworkNumber is in range of Network, false
198 | // otherwise.
199 | func (n Network) Contains(nn NetworkNumber) bool {
200 | if len(n.Mask) != len(nn) {
201 | return false
202 | }
203 | if nn[0]&n.Mask[0] != n.Number[0] {
204 | return false
205 | }
206 | if len(nn) == IPv6Uint32Count {
207 | return nn[1]&n.Mask[1] == n.Number[1] && nn[2]&n.Mask[2] == n.Number[2] && nn[3]&n.Mask[3] == n.Number[3]
208 | }
209 | return true
210 | }
211 |
212 | // Contains returns true if Network covers o, false otherwise
213 | func (n Network) Covers(o Network) bool {
214 | if len(n.Number) != len(o.Number) {
215 | return false
216 | }
217 | nMaskSize, _ := n.IPNet.Mask.Size()
218 | oMaskSize, _ := o.IPNet.Mask.Size()
219 | return n.Contains(o.Number) && nMaskSize <= oMaskSize
220 | }
221 |
222 | // LeastCommonBitPosition returns the smallest position of the preceding common
223 | // bits of the 2 networks, and returns an error ErrNoGreatestCommonBit
224 | // if the two network number diverges from the first bit.
225 | func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) {
226 | maskSize, _ := n.IPNet.Mask.Size()
227 | if maskSize1, _ := n1.IPNet.Mask.Size(); maskSize1 < maskSize {
228 | maskSize = maskSize1
229 | }
230 | maskPosition := len(n1.Number)*BitsPerUint32 - maskSize
231 | lcb, err := n.Number.LeastCommonBitPosition(n1.Number)
232 | if err != nil {
233 | return 0, err
234 | }
235 | return uint(math.Max(float64(maskPosition), float64(lcb))), nil
236 | }
237 |
// Equal reports whether n and n1 have byte-identical IP and mask slices (so a
// 4-byte and a 16-byte encoding of the same IPv4 address compare unequal).
func (n Network) Equal(n1 Network) bool {
	sameIP := bytes.Equal(n.IPNet.IP, n1.IPNet.IP)
	return sameIP && bytes.Equal(n.IPNet.Mask, n1.IPNet.Mask)
}

// String returns n in CIDR notation, as formatted by net.IPNet.
func (n Network) String() string {
	ipNet := n.IPNet
	return ipNet.String()
}
246 |
247 | // NetworkNumberMask is an IP address.
248 | type NetworkNumberMask NetworkNumber
249 |
250 | // Mask returns a new masked NetworkNumber from given NetworkNumber.
251 | func (m NetworkNumberMask) Mask(n NetworkNumber) (NetworkNumber, error) {
252 | if len(m) != len(n) {
253 | return nil, ErrVersionMismatch
254 | }
255 | result := make(NetworkNumber, len(m))
256 | result[0] = m[0] & n[0]
257 | if len(m) == IPv6Uint32Count {
258 | result[1] = m[1] & n[1]
259 | result[2] = m[2] & n[2]
260 | result[3] = m[3] & n[3]
261 | }
262 | return result, nil
263 | }
264 |
// NextIP returns the address immediately after ip (see NetworkNumber.Next).
func NextIP(ip net.IP) net.IP {
	nn := NewNetworkNumber(ip)
	return nn.Next().ToIP()
}

// PreviousIP returns the address immediately before ip (see
// NetworkNumber.Previous).
func PreviousIP(ip net.IP) net.IP {
	nn := NewNetworkNumber(ip)
	return nn.Previous().ToIP()
}
274 |
--------------------------------------------------------------------------------
/net/ip_test.go:
--------------------------------------------------------------------------------
1 | package net
2 |
3 | import (
4 | "math"
5 | "net"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
// TestNewNetworkNumber checks conversion of nil, malformed, IPv4 and IPv6 input.
func TestNewNetworkNumber(t *testing.T) {
	tests := []struct {
		name string
		ip   net.IP
		want NetworkNumber
	}{
		{name: "nil input", ip: nil, want: nil},
		{name: "bad input", ip: net.IP([]byte{1, 1, 1, 1, 1}), want: nil},
		{name: "IPv4", ip: net.ParseIP("128.0.0.0"), want: NetworkNumber([]uint32{2147483648})},
		{
			name: "IPv6",
			ip:   net.ParseIP("2001:0db8::ff00:0042:8329"),
			want: NetworkNumber([]uint32{536939960, 0, 65280, 4358953}),
		},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			got := NewNetworkNumber(test.ip)
			assert.Equal(t, test.want, got)
		})
	}
}
32 |
// TestNetworkNumberAssertion checks ToV4/ToV6 on one-, four- and two-word numbers.
func TestNetworkNumberAssertion(t *testing.T) {
	tests := []struct {
		name   string
		input  NetworkNumber
		wantV4 NetworkNumber
		wantV6 NetworkNumber
	}{
		{name: "is IPv4", input: NetworkNumber([]uint32{1}), wantV4: NetworkNumber([]uint32{1})},
		{name: "is IPv6", input: NetworkNumber([]uint32{1, 1, 1, 1}), wantV6: NetworkNumber([]uint32{1, 1, 1, 1})},
		{name: "is invalid", input: NetworkNumber([]uint32{1, 1})},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			assert.Equal(t, test.wantV4, test.input.ToV4())
			assert.Equal(t, test.wantV6, test.input.ToV6())
		})
	}
}
51 |
// TestNetworkNumberBit checks every bit of a few numbers against the set of
// positions expected to be 1.
func TestNetworkNumberBit(t *testing.T) {
	tests := []struct {
		name string
		nn   NetworkNumber
		ones map[uint]bool
	}{
		{"128.0.0.0", NewNetworkNumber(net.ParseIP("128.0.0.0")), map[uint]bool{31: true}},
		{"1.1.1.1", NewNetworkNumber(net.ParseIP("1.1.1.1")), map[uint]bool{0: true, 8: true, 16: true, 24: true}},
		{"8000::", NewNetworkNumber(net.ParseIP("8000::")), map[uint]bool{127: true}},
		{"8000::8000", NewNetworkNumber(net.ParseIP("8000::8000")), map[uint]bool{127: true, 15: true}},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			width := uint(len(test.nn) * BitsPerUint32)
			for pos := uint(0); pos < width; pos++ {
				got, err := test.nn.Bit(pos)
				assert.NoError(t, err)
				want := uint32(0)
				if _, set := test.ones[pos]; set {
					want = 1
				}
				assert.Equal(t, want, got)
			}
		})
	}
}
77 |
// TestNetworkNumberBitError checks the bounds of Bit for IPv4 and IPv6. Every
// subtest has a distinct name; duplicates are auto-suffixed by t.Run, which
// makes failures hard to attribute.
func TestNetworkNumberBitError(t *testing.T) {
	cases := []struct {
		ip       NetworkNumber
		position uint
		err      error
		name     string
	}{
		{NewNetworkNumber(net.ParseIP("128.0.0.0")), 0, nil, "IPv4 lowest index in bound"},
		{NewNetworkNumber(net.ParseIP("128.0.0.0")), 31, nil, "IPv4 highest index in bound"},
		{NewNetworkNumber(net.ParseIP("128.0.0.0")), 32, ErrInvalidBitPosition, "IPv4 index out of bounds"},
		{NewNetworkNumber(net.ParseIP("8000::")), 0, nil, "IPv6 lowest index in bound"},
		{NewNetworkNumber(net.ParseIP("8000::")), 127, nil, "IPv6 highest index in bound"},
		{NewNetworkNumber(net.ParseIP("8000::")), 128, ErrInvalidBitPosition, "IPv6 index out of bounds"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			_, err := tc.ip.Bit(tc.position)
			assert.Equal(t, tc.err, err)
		})
	}
}
99 |
// TestNetworkNumberEqual checks Equal for same/different IPv4 and IPv6 numbers
// and across versions.
func TestNetworkNumberEqual(t *testing.T) {
	tests := []struct {
		name string
		a, b NetworkNumber
		want bool
	}{
		{"IPv4 equals", NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, true},
		{"IPv4 does not equal", NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - 1}, false},
		{"IPv6 equals", NetworkNumber{1, 1, 1, 1}, NetworkNumber{1, 1, 1, 1}, true},
		{"IPv6 does not equal", NetworkNumber{1, 1, 1, 1}, NetworkNumber{1, 1, 1, 2}, false},
		{"Version mismatch", NetworkNumber{1}, NetworkNumber{1, 2, 3, 4}, false},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			got := test.a.Equal(test.b)
			assert.Equal(t, test.want, got)
		})
	}
}
119 |
// TestNetworkNumberNext checks Next, including single and chained carries.
func TestNetworkNumberNext(t *testing.T) {
	tests := []struct{ name, from, want string }{
		{"IPv4 basic", "0.0.0.0", "0.0.0.1"},
		{"IPv4 rollover", "0.0.0.255", "0.0.1.0"},
		{"IPv4 consecutive rollover", "0.255.255.255", "1.0.0.0"},
		{"IPv6 basic", "8000::0", "8000::1"},
		{"IPv6 rollover", "0::ffff", "0::1:0"},
		{"IPv6 consecutive rollover", "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "1::"},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			start := NewNetworkNumber(net.ParseIP(test.from))
			want := NewNetworkNumber(net.ParseIP(test.want))
			assert.Equal(t, want, start.Next())
		})
	}
}
142 |
// TestNeworkNumberPrevious checks Previous, including single and chained
// borrows.
func TestNeworkNumberPrevious(t *testing.T) {
	tests := []struct{ name, from, want string }{
		{"IPv4 basic", "0.0.0.1", "0.0.0.0"},
		{"IPv4 rollover", "0.0.1.0", "0.0.0.255"},
		{"IPv4 consecutive rollover", "1.0.0.0", "0.255.255.255"},
		{"IPv6 basic", "8000::1", "8000::0"},
		{"IPv6 rollover", "0::1:0", "0::ffff"},
		{"IPv6 consecutive rollover", "1::0", "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			start := NewNetworkNumber(net.ParseIP(test.from))
			want := NewNetworkNumber(net.ParseIP(test.want))
			assert.Equal(t, want, start.Previous())
		})
	}
}
165 |
// TestLeastCommonBitPositionForNetworks pins NetworkNumber.LeastCommonBitPosition
// for matching, diverging and version-mismatched numbers.
func TestLeastCommonBitPositionForNetworks(t *testing.T) {
	tests := []struct {
		name    string
		a       NetworkNumber
		b       NetworkNumber
		wantPos uint
		wantErr error
	}{
		{
			name:    "Version mismatch",
			a:       NetworkNumber([]uint32{2147483648}),
			b:       NetworkNumber([]uint32{3221225472, 0, 0, 0}),
			wantErr: ErrVersionMismatch,
		},
		{
			name:    "IPv4 31st position",
			a:       NetworkNumber([]uint32{2147483648}),
			b:       NetworkNumber([]uint32{3221225472}),
			wantPos: 31,
		},
		{
			name: "IPv4 0th position",
			a:    NetworkNumber([]uint32{2147483648}),
			b:    NetworkNumber([]uint32{2147483648}),
		},
		{
			name:    "IPv4 diverge at first bit",
			a:       NetworkNumber([]uint32{2147483648}),
			b:       NetworkNumber([]uint32{1}),
			wantErr: ErrNoGreatestCommonBit,
		},
		{
			name:    "IPv6 127th position",
			a:       NetworkNumber([]uint32{2147483648, 0, 0, 0}),
			b:       NetworkNumber([]uint32{3221225472, 0, 0, 0}),
			wantPos: 127,
		},
		{
			name: "IPv6 0th position",
			a:    NetworkNumber([]uint32{2147483648, 1, 1, 1}),
			b:    NetworkNumber([]uint32{2147483648, 1, 1, 1}),
		},
		{
			name:    "IPv6 diverge at first bit",
			a:       NetworkNumber([]uint32{2147483648, 0, 0, 0}),
			b:       NetworkNumber([]uint32{0, 0, 0, 1}),
			wantErr: ErrNoGreatestCommonBit,
		},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			gotPos, gotErr := test.a.LeastCommonBitPosition(test.b)
			assert.Equal(t, test.wantErr, gotErr)
			assert.Equal(t, test.wantPos, gotPos)
		})
	}
}
218 |
// TestNewNetwork checks that NewNetwork keeps the IPNet and derives the
// word-form number and mask of an IPv4 /24.
func TestNewNetwork(t *testing.T) {
	_, parsed, _ := net.ParseCIDR("192.128.0.0/24")
	network := NewNetwork(*parsed)

	wantNumber := NetworkNumber{3229614080}
	wantMask := NetworkNumberMask{math.MaxUint32 - uint32(math.MaxUint8)}
	assert.Equal(t, *parsed, network.IPNet)
	assert.Equal(t, wantNumber, network.Number)
	assert.Equal(t, wantMask, network.Mask)
}
227 |
// TestNetworkMasked checks that Masked applies a prefix of the requested
// length to the network address. It uses assert.Equal (not assert.True on
// ==) so a failure prints both networks.
func TestNetworkMasked(t *testing.T) {
	cases := []struct {
		network       string
		mask          int
		maskedNetwork string
	}{
		{"192.168.0.0/16", 16, "192.168.0.0/16"},
		{"192.168.0.0/16", 14, "192.168.0.0/14"},
		{"192.168.0.0/16", 18, "192.168.0.0/18"},
		{"192.168.0.0/16", 8, "192.0.0.0/8"},
		{"8000::/128", 96, "8000::/96"},
		{"8000::/128", 128, "8000::/128"},
		{"8000::/96", 112, "8000::/112"},
		{"8000:ffff::/96", 16, "8000::/16"},
	}
	for _, testcase := range cases {
		_, network, err := net.ParseCIDR(testcase.network)
		assert.NoError(t, err)
		_, expected, err := net.ParseCIDR(testcase.maskedNetwork)
		assert.NoError(t, err)
		n1 := NewNetwork(*network)
		e1 := NewNetwork(*expected)
		assert.Equal(t, e1.String(), n1.Masked(testcase.mask).String(),
			"%s masked to /%d", testcase.network, testcase.mask)
	}
}
251 |
// TestNetworkEqual checks Network.Equal for same and differing prefix lengths.
func TestNetworkEqual(t *testing.T) {
	tests := []struct {
		name string
		a, b string
		want bool
	}{
		{"IPv4 equals", "192.128.0.0/24", "192.128.0.0/24", true},
		{"IPv4 not equals", "192.128.0.0/24", "192.128.0.0/23", false},
		{"IPv6 equals", "8000::/24", "8000::/24", true},
		{"IPv6 not equals", "8000::/24", "8000::/23", false},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			_, a, _ := net.ParseCIDR(test.a)
			_, b, _ := net.ParseCIDR(test.b)
			got := NewNetwork(*a).Equal(NewNetwork(*b))
			assert.Equal(t, test.want, got)
		})
	}
}
272 |
// TestNetworkContains checks that a network contains exactly the addresses
// from its first to its last address, and neither neighbour outside.
func TestNetworkContains(t *testing.T) {
	tests := []struct{ name, cidr, first, last string }{
		{"192.168.0.0/24 contains", "192.168.0.0/24", "192.168.0.0", "192.168.0.255"},
		{"8000::0/120 contains", "8000::0/120", "8000::0", "8000::ff"},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			_, ipNet, _ := net.ParseCIDR(test.cidr)
			network := NewNetwork(*ipNet)
			first := NewNetworkNumber(net.ParseIP(test.first))
			last := NewNetworkNumber(net.ParseIP(test.last))

			// The addresses just outside the range are rejected...
			assert.False(t, network.Contains(first.Previous()))
			assert.False(t, network.Contains(last.Next()))
			// ...and every address from first to last inclusive is accepted.
			end := last.Next()
			for ip := first; !ip.Equal(end); ip = ip.Next() {
				assert.True(t, network.Contains(ip))
			}
		})
	}
}
297 |
// TestNetworkContainsVersionMismatch checks that an address of the other IP
// version is never contained.
func TestNetworkContainsVersionMismatch(t *testing.T) {
	tests := []struct{ name, cidr, ip string }{
		{"IPv6 in IPv4 network", "192.168.0.0/24", "8000::0"},
		{"IPv4 in IPv6 network", "8000::0/120", "192.168.0.0"},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			_, ipNet, _ := net.ParseCIDR(test.cidr)
			network := NewNetwork(*ipNet)
			nn := NewNetworkNumber(net.ParseIP(test.ip))
			assert.False(t, network.Contains(nn))
		})
	}
}
315 |
// TestNetworkCovers checks Covers for containment, prefix length and version
// mismatch.
func TestNetworkCovers(t *testing.T) {
	tests := []struct {
		name         string
		outer, inner string
		want         bool
	}{
		{"contains", "10.0.0.0/24", "10.0.0.1/25", true},
		{"not contains", "10.0.0.0/24", "11.0.0.1/25", false},
		{"prefix false", "10.0.0.0/16", "10.0.0.0/15", false},
		{"prefix true", "10.0.0.0/15", "10.0.0.0/16", true},
		{"same", "10.0.0.0/15", "10.0.0.0/15", true},
		{"ip version mismatch", "10::0/15", "10.0.0.0/15", false},
		{"ipv6", "10::0/15", "10::0/16", true},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			_, outerNet, _ := net.ParseCIDR(test.outer)
			_, innerNet, _ := net.ParseCIDR(test.inner)
			outer, inner := NewNetwork(*outerNet), NewNetwork(*innerNet)
			assert.Equal(t, test.want, outer.Covers(inner))
		})
	}
}
342 |
// TestNetworkLeastCommonBitPosition checks Network.LeastCommonBitPosition. Both
// the error and the position are asserted in every case: success cases must
// return a nil error, and error cases must return position 0.
func TestNetworkLeastCommonBitPosition(t *testing.T) {
	cases := []struct {
		cidr1       string
		cidr2       string
		expectedPos uint
		expectedErr error
		name        string
	}{
		{"0.0.1.0/24", "0.0.0.0/24", uint(9), nil, "IPv4 diverge before mask pos"},
		{"0.0.0.0/24", "0.0.0.0/24", uint(8), nil, "IPv4 diverge after mask pos"},
		{"0.0.0.128/24", "0.0.0.0/16", uint(16), nil, "IPv4 different mask pos"},
		{"128.0.0.0/24", "0.0.0.0/24", 0, ErrNoGreatestCommonBit, "IPv4 diverge at 1st pos"},
		{"8000::/96", "8000::1:0:0/96", uint(33), nil, "IPv6 diverge before mask pos"},
		{"8000::/96", "8000::8:0/96", uint(32), nil, "IPv6 diverge after mask pos"},
		{"8000::/96", "8000::/95", uint(33), nil, "IPv6 different mask pos"},
		{"ffff::0/24", "0::1/24", 0, ErrNoGreatestCommonBit, "IPv6 diverge at 1st pos"},
	}
	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			_, cidr1, err := net.ParseCIDR(c.cidr1)
			assert.NoError(t, err)
			_, cidr2, err := net.ParseCIDR(c.cidr2)
			assert.NoError(t, err)
			n1 := NewNetwork(*cidr1)
			pos, err := n1.LeastCommonBitPosition(NewNetwork(*cidr2))
			assert.Equal(t, c.expectedErr, err)
			assert.Equal(t, c.expectedPos, pos)
		})
	}
}
374 |
// TestMask checks NetworkNumberMask.Mask for no-op and truncating masks, and
// for a version mismatch. Every subtest has a distinct name; duplicates are
// auto-suffixed by t.Run, which makes failures hard to attribute.
func TestMask(t *testing.T) {
	cases := []struct {
		mask   NetworkNumberMask
		ip     NetworkNumber
		masked NetworkNumber
		err    error
		name   string
	}{
		{NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, nil, "nop IPv4 /32 mask"},
		{NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint16 + 1}, NetworkNumber{math.MaxUint16 + 1}, nil, "nop IPv4 /16 mask"},
		{NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - math.MaxUint16}, nil, "IPv4 masked"},
		{NetworkNumberMask{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, nil, "nop IPv6 /32 mask"},
		{NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, nil, "nop IPv6 /16 mask"},
		{NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, nil, "IPv6 masked"},
		{NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32, 0}, nil, ErrVersionMismatch, "Version mismatch"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			masked, err := tc.mask.Mask(tc.ip)
			assert.Equal(t, tc.masked, masked)
			assert.Equal(t, tc.err, err)
		})
	}
}
399 |
// TestNextIP checks the net.IP-level NextIP wrapper, including carries.
func TestNextIP(t *testing.T) {
	tests := []struct{ name, ip, want string }{
		{"IPv4 basic", "0.0.0.0", "0.0.0.1"},
		{"IPv4 rollover", "0.0.0.255", "0.0.1.0"},
		{"IPv4 consecutive rollover", "0.255.255.255", "1.0.0.0"},
		{"IPv6 basic", "8000::0", "8000::1"},
		{"IPv6 rollover", "0::ffff", "0::1:0"},
		{"IPv6 consecutive rollover", "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "1::"},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			want := net.ParseIP(test.want)
			assert.Equal(t, want, NextIP(net.ParseIP(test.ip)))
		})
	}
}
420 |
// TestPreviousIP checks the net.IP-level PreviousIP wrapper, including borrows.
func TestPreviousIP(t *testing.T) {
	tests := []struct{ name, ip, previous string }{
		{"IPv4 basic", "0.0.0.1", "0.0.0.0"},
		{"IPv4 rollover", "0.0.1.0", "0.0.0.255"},
		{"IPv4 consecutive rollover", "1.0.0.0", "0.255.255.255"},
		{"IPv6 basic", "8000::1", "8000::0"},
		{"IPv6 rollover", "0::1:0", "0::ffff"},
		{"IPv6 consecutive rollover", "1::0", "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
	}
	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			want := net.ParseIP(test.previous)
			assert.Equal(t, want, PreviousIP(net.ParseIP(test.ip)))
		})
	}
}
441 |
442 | /*
443 | *********************************
444 | Benchmarking ip manipulations.
445 | *********************************
446 | */
447 | func BenchmarkNetworkNumberBitIPv4(b *testing.B) {
448 | benchmarkNetworkNumberBit(b, "52.95.110.1", 6)
449 | }
450 | func BenchmarkNetworkNumberBitIPv6(b *testing.B) {
451 | benchmarkNetworkNumberBit(b, "2600:1ffe:e000::", 44)
452 | }
453 |
454 | func BenchmarkNetworkNumberEqualIPv4(b *testing.B) {
455 | benchmarkNetworkNumberEqual(b, "52.95.110.1", "52.95.110.1")
456 | }
457 |
458 | func BenchmarkNetworkNumberEqualIPv6(b *testing.B) {
459 | benchmarkNetworkNumberEqual(b, "2600:1ffe:e000::", "2600:1ffe:e000::")
460 | }
461 |
462 | func BenchmarkNetworkContainsIPv4(b *testing.B) {
463 | benchmarkNetworkContains(b, "52.95.110.0/24", "52.95.110.1")
464 | }
465 |
466 | func BenchmarkNetworkContainsIPv6(b *testing.B) {
467 | benchmarkNetworkContains(b, "2600:1ffe:e000::/40", "2600:1ffe:f000::")
468 | }
469 |
470 | func BenchmarkNetworkEqualIPv4(b *testing.B) {
471 | benchmarkNetworkEqual(b, "192.128.0.0/24", "192.128.0.0/24")
472 | }
473 |
474 | func BenchmarkNetworkEqualIPv6(b *testing.B) {
475 | benchmarkNetworkEqual(b, "8000::/24", "8000::/24")
476 | }
477 |
478 | func benchmarkNetworkNumberBit(b *testing.B, ip string, pos uint) {
479 | nn := NewNetworkNumber(net.ParseIP(ip))
480 | for n := 0; n < b.N; n++ {
481 | nn.Bit(pos)
482 | }
483 | }
484 |
485 | func benchmarkNetworkNumberEqual(b *testing.B, ip1 string, ip2 string) {
486 | nn1 := NewNetworkNumber(net.ParseIP(ip1))
487 | nn2 := NewNetworkNumber(net.ParseIP(ip2))
488 | for n := 0; n < b.N; n++ {
489 | nn1.Equal(nn2)
490 | }
491 | }
492 |
493 | func benchmarkNetworkContains(b *testing.B, cidr string, ip string) {
494 | nn := NewNetworkNumber(net.ParseIP(ip))
495 | _, ipNet, _ := net.ParseCIDR(cidr)
496 | network := NewNetwork(*ipNet)
497 | for n := 0; n < b.N; n++ {
498 | network.Contains(nn)
499 | }
500 | }
501 |
502 | func benchmarkNetworkEqual(b *testing.B, net1 string, net2 string) {
503 | _, ipNet1, _ := net.ParseCIDR(net1)
504 | _, ipNet2, _ := net.ParseCIDR(net2)
505 | n1 := NewNetwork(*ipNet1)
506 | n2 := NewNetwork(*ipNet2)
507 | for n := 0; n < b.N; n++ {
508 | n1.Equal(n2)
509 | }
510 | }
511 |
--------------------------------------------------------------------------------
/trie.go:
--------------------------------------------------------------------------------
1 | package cidranger
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "strings"
7 |
8 | rnet "github.com/yl2chen/cidranger/net"
9 | )
10 |
11 | // prefixTrie is a path-compressed (PC) trie implementation of the
12 | // ranger interface inspired by this blog post:
13 | // https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
14 | //
15 | // CIDR blocks are stored using a prefix tree structure where each node has its
16 | // parent as prefix, and the path from the root node represents current CIDR
17 | // block.
18 | //
19 | // For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
20 | // 32 bits long and each bit represents a prefix tree starting at that bit. This
21 | // property also guarantees constant lookup time in Big-O notation.
22 | //
23 | // Path compression compresses a string of node with only 1 child into a single
24 | // node, decrease the amount of lookups necessary during containment tests.
25 | //
26 | // Level compression dictates the amount of direct children of a node by
27 | // allowing it to handle multiple bits in the path. The heuristic (based on
28 | // children population) to decide when the compression and decompression happens
29 | // is outlined in the prior linked blog, and will be experimented with in more
30 | // depth in this project in the future.
31 | //
32 | // Note: Can not insert both IPv4 and IPv6 network addresses into the same
33 | // prefix trie, use versionedRanger wrapper instead.
34 | //
35 | // TODO: Implement level-compressed component of the LPC trie.
36 | type prefixTrie struct {
37 | parent *prefixTrie
38 | children []*prefixTrie
39 |
40 | numBitsSkipped uint
41 | numBitsHandled uint
42 |
43 | network rnet.Network
44 | entry RangerEntry
45 |
46 | size int // This is only maintained in the root trie.
47 | }
48 |
49 | // newPrefixTree creates a new prefixTrie.
50 | func newPrefixTree(version rnet.IPVersion) Ranger {
51 | _, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
52 | if version == rnet.IPv6 {
53 | _, rootNet, _ = net.ParseCIDR("0::0/0")
54 | }
55 | return &prefixTrie{
56 | children: make([]*prefixTrie, 2, 2),
57 | numBitsSkipped: 0,
58 | numBitsHandled: 1,
59 | network: rnet.NewNetwork(*rootNet),
60 | }
61 | }
62 |
63 | func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
64 | path := &prefixTrie{
65 | children: make([]*prefixTrie, 2, 2),
66 | numBitsSkipped: numBitsSkipped,
67 | numBitsHandled: 1,
68 | network: network.Masked(int(numBitsSkipped)),
69 | }
70 | return path
71 | }
72 |
73 | func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
74 | ones, _ := network.IPNet.Mask.Size()
75 | leaf := newPathprefixTrie(network, uint(ones))
76 | leaf.entry = entry
77 | return leaf
78 | }
79 |
80 | // Insert inserts a RangerEntry into prefix trie.
81 | func (p *prefixTrie) Insert(entry RangerEntry) error {
82 | network := entry.Network()
83 | sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry)
84 | if sizeIncreased {
85 | p.size++
86 | }
87 | return err
88 | }
89 |
90 | // Remove removes RangerEntry identified by given network from trie.
91 | func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
92 | entry, err := p.remove(rnet.NewNetwork(network))
93 | if entry != nil {
94 | p.size--
95 | }
96 | return entry, err
97 | }
98 |
99 | // Contains returns boolean indicating whether given ip is contained in any
100 | // of the inserted networks.
101 | func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
102 | nn := rnet.NewNetworkNumber(ip)
103 | if nn == nil {
104 | return false, ErrInvalidNetworkNumberInput
105 | }
106 | return p.contains(nn)
107 | }
108 |
109 | // ContainingNetworks returns the list of RangerEntry(s) the given ip is
110 | // contained in in ascending prefix order.
111 | func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
112 | nn := rnet.NewNetworkNumber(ip)
113 | if nn == nil {
114 | return nil, ErrInvalidNetworkNumberInput
115 | }
116 | return p.containingNetworks(nn)
117 | }
118 |
119 | // CoveredNetworks returns the list of RangerEntry(s) the given ipnet
120 | // covers. That is, the networks that are completely subsumed by the
121 | // specified network.
122 | func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
123 | net := rnet.NewNetwork(network)
124 | return p.coveredNetworks(net)
125 | }
126 |
127 | // Len returns number of networks in ranger.
128 | func (p *prefixTrie) Len() int {
129 | return p.size
130 | }
131 |
132 | // String returns string representation of trie, mainly for visualization and
133 | // debugging.
134 | func (p *prefixTrie) String() string {
135 | children := []string{}
136 | padding := strings.Repeat("| ", p.level()+1)
137 | for bits, child := range p.children {
138 | if child == nil {
139 | continue
140 | }
141 | childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
142 | children = append(children, childStr)
143 | }
144 | return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
145 | p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
146 | }
147 |
148 | func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
149 | if !p.network.Contains(number) {
150 | return false, nil
151 | }
152 | if p.hasEntry() {
153 | return true, nil
154 | }
155 | if p.targetBitPosition() < 0 {
156 | return false, nil
157 | }
158 | bit, err := p.targetBitFromIP(number)
159 | if err != nil {
160 | return false, err
161 | }
162 | child := p.children[bit]
163 | if child != nil {
164 | return child.contains(number)
165 | }
166 | return false, nil
167 | }
168 |
169 | func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
170 | results := []RangerEntry{}
171 | if !p.network.Contains(number) {
172 | return results, nil
173 | }
174 | if p.hasEntry() {
175 | results = []RangerEntry{p.entry}
176 | }
177 | if p.targetBitPosition() < 0 {
178 | return results, nil
179 | }
180 | bit, err := p.targetBitFromIP(number)
181 | if err != nil {
182 | return nil, err
183 | }
184 | child := p.children[bit]
185 | if child != nil {
186 | ranges, err := child.containingNetworks(number)
187 | if err != nil {
188 | return nil, err
189 | }
190 | if len(ranges) > 0 {
191 | if len(results) > 0 {
192 | results = append(results, ranges...)
193 | } else {
194 | results = ranges
195 | }
196 | }
197 | }
198 | return results, nil
199 | }
200 |
201 | func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
202 | var results []RangerEntry
203 | if network.Covers(p.network) {
204 | for entry := range p.walkDepth() {
205 | results = append(results, entry)
206 | }
207 | } else if p.targetBitPosition() >= 0 {
208 | bit, err := p.targetBitFromIP(network.Number)
209 | if err != nil {
210 | return results, err
211 | }
212 | child := p.children[bit]
213 | if child != nil {
214 | return child.coveredNetworks(network)
215 | }
216 | }
217 | return results, nil
218 | }
219 |
220 | func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
221 | if p.network.Equal(network) {
222 | sizeIncreased := p.entry == nil
223 | p.entry = entry
224 | return sizeIncreased, nil
225 | }
226 |
227 | bit, err := p.targetBitFromIP(network.Number)
228 | if err != nil {
229 | return false, err
230 | }
231 | existingChild := p.children[bit]
232 |
233 | // No existing child, insert new leaf trie.
234 | if existingChild == nil {
235 | p.appendTrie(bit, newEntryTrie(network, entry))
236 | return true, nil
237 | }
238 |
239 | // Check whether it is necessary to insert additional path prefix between current trie and existing child,
240 | // in the case that inserted network diverges on its path to existing child.
241 | lcb, err := network.LeastCommonBitPosition(existingChild.network)
242 | divergingBitPos := int(lcb) - 1
243 | if divergingBitPos > existingChild.targetBitPosition() {
244 | pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
245 | err := p.insertPrefix(bit, pathPrefix, existingChild)
246 | if err != nil {
247 | return false, err
248 | }
249 | // Update new child
250 | existingChild = pathPrefix
251 | }
252 | return existingChild.insert(network, entry)
253 | }
254 |
255 | func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
256 | p.children[bit] = prefix
257 | prefix.parent = p
258 | }
259 |
260 | func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
261 | // Set parent/child relationship between current trie and inserted pathPrefix
262 | p.children[bit] = pathPrefix
263 | pathPrefix.parent = p
264 |
265 | // Set parent/child relationship between inserted pathPrefix and original child
266 | pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
267 | if err != nil {
268 | return err
269 | }
270 | pathPrefix.children[pathPrefixBit] = child
271 | child.parent = pathPrefix
272 | return nil
273 | }
274 |
275 | func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
276 | if p.hasEntry() && p.network.Equal(network) {
277 | entry := p.entry
278 | p.entry = nil
279 |
280 | err := p.compressPathIfPossible()
281 | if err != nil {
282 | return nil, err
283 | }
284 | return entry, nil
285 | }
286 | if p.targetBitPosition() < 0 {
287 | return nil, nil
288 | }
289 | bit, err := p.targetBitFromIP(network.Number)
290 | if err != nil {
291 | return nil, err
292 | }
293 | child := p.children[bit]
294 | if child != nil {
295 | return child.remove(network)
296 | }
297 | return nil, nil
298 | }
299 |
300 | func (p *prefixTrie) qualifiesForPathCompression() bool {
301 | // Current prefix trie can be path compressed if it meets all following.
302 | // 1. records no CIDR entry
303 | // 2. has single or no child
304 | // 3. is not root trie
305 | return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
306 | }
307 |
308 | func (p *prefixTrie) compressPathIfPossible() error {
309 | if !p.qualifiesForPathCompression() {
310 | // Does not qualify to be compressed
311 | return nil
312 | }
313 |
314 | // Find lone child.
315 | var loneChild *prefixTrie
316 | for _, child := range p.children {
317 | if child != nil {
318 | loneChild = child
319 | break
320 | }
321 | }
322 |
323 | // Find root of currnt single child lineage.
324 | parent := p.parent
325 | for ; parent.qualifiesForPathCompression(); parent = parent.parent {
326 | }
327 | parentBit, err := parent.targetBitFromIP(p.network.Number)
328 | if err != nil {
329 | return err
330 | }
331 | parent.children[parentBit] = loneChild
332 |
333 | // Attempts to furthur apply path compression at current lineage parent, in case current lineage
334 | // compressed into parent.
335 | return parent.compressPathIfPossible()
336 | }
337 |
338 | func (p *prefixTrie) childrenCount() int {
339 | count := 0
340 | for _, child := range p.children {
341 | if child != nil {
342 | count++
343 | }
344 | }
345 | return count
346 | }
347 |
348 | func (p *prefixTrie) totalNumberOfBits() uint {
349 | return rnet.BitsPerUint32 * uint(len(p.network.Number))
350 | }
351 |
352 | func (p *prefixTrie) targetBitPosition() int {
353 | return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
354 | }
355 |
356 | func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
357 | // This is a safe uint boxing of int since we should never attempt to get
358 | // target bit at a negative position.
359 | return n.Bit(uint(p.targetBitPosition()))
360 | }
361 |
362 | func (p *prefixTrie) hasEntry() bool {
363 | return p.entry != nil
364 | }
365 |
366 | func (p *prefixTrie) level() int {
367 | if p.parent == nil {
368 | return 0
369 | }
370 | return p.parent.level() + 1
371 | }
372 |
373 | // walkDepth walks the trie in depth order, for unit testing.
374 | func (p *prefixTrie) walkDepth() <-chan RangerEntry {
375 | entries := make(chan RangerEntry)
376 | go func() {
377 | if p.hasEntry() {
378 | entries <- p.entry
379 | }
380 | childEntriesList := []<-chan RangerEntry{}
381 | for _, trie := range p.children {
382 | if trie == nil {
383 | continue
384 | }
385 | childEntriesList = append(childEntriesList, trie.walkDepth())
386 | }
387 | for _, childEntries := range childEntriesList {
388 | for entry := range childEntries {
389 | entries <- entry
390 | }
391 | }
392 | close(entries)
393 | }()
394 | return entries
395 | }
396 |
--------------------------------------------------------------------------------
/trie_test.go:
--------------------------------------------------------------------------------
1 | package cidranger
2 |
3 | import (
4 | "encoding/binary"
5 | "math/rand"
6 | "net"
7 | "runtime"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/assert"
12 | rnet "github.com/yl2chen/cidranger/net"
13 | )
14 |
15 | func getAllByVersion(version rnet.IPVersion) *net.IPNet {
16 | if version == rnet.IPv6 {
17 | return AllIPv6
18 | }
19 | return AllIPv4
20 | }
21 |
22 | func TestPrefixTrieInsert(t *testing.T) {
23 | cases := []struct {
24 | version rnet.IPVersion
25 | inserts []string
26 | expectedNetworksInDepthOrder []string
27 | name string
28 | }{
29 | {rnet.IPv4, []string{"192.168.0.1/24"}, []string{"192.168.0.1/24"}, "basic insert"},
30 | {
31 | rnet.IPv4,
32 | []string{"1.2.3.4/32", "1.2.3.5/32"},
33 | []string{"1.2.3.4/32", "1.2.3.5/32"},
34 | "single ip IPv4 network insert",
35 | },
36 | {
37 | rnet.IPv6,
38 | []string{"0::1/128", "0::2/128"},
39 | []string{"0::1/128", "0::2/128"},
40 | "single ip IPv6 network insert",
41 | },
42 | {
43 | rnet.IPv4,
44 | []string{"192.168.0.1/16", "192.168.0.1/24"},
45 | []string{"192.168.0.1/16", "192.168.0.1/24"},
46 | "in order insert",
47 | },
48 | {
49 | rnet.IPv4,
50 | []string{"192.168.0.1/32", "192.168.0.1/32"},
51 | []string{"192.168.0.1/32"},
52 | "duplicate network insert",
53 | },
54 | {
55 | rnet.IPv4,
56 | []string{"192.168.0.1/24", "192.168.0.1/16"},
57 | []string{"192.168.0.1/16", "192.168.0.1/24"},
58 | "reverse insert",
59 | },
60 | {
61 | rnet.IPv4,
62 | []string{"192.168.0.1/24", "192.168.1.1/24"},
63 | []string{"192.168.0.1/24", "192.168.1.1/24"},
64 | "branch insert",
65 | },
66 | {
67 | rnet.IPv4,
68 | []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"},
69 | []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"},
70 | "branch inserts",
71 | },
72 | }
73 | for _, tc := range cases {
74 | t.Run(tc.name, func(t *testing.T) {
75 | trie := newPrefixTree(tc.version).(*prefixTrie)
76 | for _, insert := range tc.inserts {
77 | _, network, _ := net.ParseCIDR(insert)
78 | err := trie.Insert(NewBasicRangerEntry(*network))
79 | assert.NoError(t, err)
80 | }
81 |
82 | assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match")
83 |
84 | allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version))
85 | assert.Nil(t, err)
86 | assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match")
87 |
88 | walk := trie.walkDepth()
89 | for _, network := range tc.expectedNetworksInDepthOrder {
90 | _, ipnet, _ := net.ParseCIDR(network)
91 | expected := NewBasicRangerEntry(*ipnet)
92 | actual := <-walk
93 | assert.Equal(t, expected, actual)
94 | }
95 |
96 | // Ensure no unexpected elements in trie.
97 | for network := range walk {
98 | assert.Nil(t, network)
99 | }
100 | })
101 | }
102 | }
103 |
104 | func TestPrefixTrieString(t *testing.T) {
105 | inserts := []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"}
106 | trie := newPrefixTree(rnet.IPv4).(*prefixTrie)
107 | for _, insert := range inserts {
108 | _, network, _ := net.ParseCIDR(insert)
109 | trie.Insert(NewBasicRangerEntry(*network))
110 | }
111 | expected := `0.0.0.0/0 (target_pos:31:has_entry:false)
112 | | 1--> 192.168.0.0/23 (target_pos:8:has_entry:false)
113 | | | 0--> 192.168.0.0/24 (target_pos:7:has_entry:true)
114 | | | 1--> 192.168.1.0/24 (target_pos:7:has_entry:true)
115 | | | | 0--> 192.168.1.0/30 (target_pos:1:has_entry:true)`
116 | assert.Equal(t, expected, trie.String())
117 | }
118 |
119 | func TestPrefixTrieRemove(t *testing.T) {
120 | cases := []struct {
121 | version rnet.IPVersion
122 | inserts []string
123 | removes []string
124 | expectedRemoves []string
125 | expectedNetworksInDepthOrder []string
126 | expectedTrieString string
127 | name string
128 | }{
129 | {
130 | rnet.IPv4,
131 | []string{"192.168.0.1/24"},
132 | []string{"192.168.0.1/24"},
133 | []string{"192.168.0.1/24"},
134 | []string{},
135 | "0.0.0.0/0 (target_pos:31:has_entry:false)",
136 | "basic remove",
137 | },
138 | {
139 | rnet.IPv4,
140 | []string{"192.168.0.1/32"},
141 | []string{"192.168.0.1/24"},
142 | []string{""},
143 | []string{"192.168.0.1/32"},
144 | `0.0.0.0/0 (target_pos:31:has_entry:false)
145 | | 1--> 192.168.0.1/32 (target_pos:-1:has_entry:true)`,
146 | "remove from ranger that contains a single ip block",
147 | },
148 | {
149 | rnet.IPv4,
150 | []string{"1.2.3.4/32", "1.2.3.5/32"},
151 | []string{"1.2.3.5/32"},
152 | []string{"1.2.3.5/32"},
153 | []string{"1.2.3.4/32"},
154 | `0.0.0.0/0 (target_pos:31:has_entry:false)
155 | | 0--> 1.2.3.4/32 (target_pos:-1:has_entry:true)`,
156 | "single ip IPv4 network remove",
157 | },
158 | {
159 | rnet.IPv4,
160 | []string{"0::1/128", "0::2/128"},
161 | []string{"0::2/128"},
162 | []string{"0::2/128"},
163 | []string{"0::1/128"},
164 | `0.0.0.0/0 (target_pos:31:has_entry:false)
165 | | 0--> ::1/128 (target_pos:-1:has_entry:true)`,
166 | "single ip IPv6 network remove",
167 | },
168 | {
169 | rnet.IPv4,
170 | []string{"192.168.0.1/24", "192.168.0.1/25", "192.168.0.1/26"},
171 | []string{"192.168.0.1/25"},
172 | []string{"192.168.0.1/25"},
173 | []string{"192.168.0.1/24", "192.168.0.1/26"},
174 | `0.0.0.0/0 (target_pos:31:has_entry:false)
175 | | 1--> 192.168.0.0/24 (target_pos:7:has_entry:true)
176 | | | 0--> 192.168.0.0/26 (target_pos:5:has_entry:true)`,
177 | "remove path prefix",
178 | },
179 | {
180 | rnet.IPv4,
181 | []string{"192.168.0.1/24", "192.168.0.1/25", "192.168.0.64/26", "192.168.0.1/26"},
182 | []string{"192.168.0.1/25"},
183 | []string{"192.168.0.1/25"},
184 | []string{"192.168.0.1/24", "192.168.0.1/26", "192.168.0.64/26"},
185 | `0.0.0.0/0 (target_pos:31:has_entry:false)
186 | | 1--> 192.168.0.0/24 (target_pos:7:has_entry:true)
187 | | | 0--> 192.168.0.0/25 (target_pos:6:has_entry:false)
188 | | | | 0--> 192.168.0.0/26 (target_pos:5:has_entry:true)
189 | | | | 1--> 192.168.0.64/26 (target_pos:5:has_entry:true)`,
190 | "remove path prefix with more than 1 children",
191 | },
192 | {
193 | rnet.IPv4,
194 | []string{"192.168.0.1/24", "192.168.0.1/25"},
195 | []string{"192.168.0.1/26"},
196 | []string{""},
197 | []string{"192.168.0.1/24", "192.168.0.1/25"},
198 | `0.0.0.0/0 (target_pos:31:has_entry:false)
199 | | 1--> 192.168.0.0/24 (target_pos:7:has_entry:true)
200 | | | 0--> 192.168.0.0/25 (target_pos:6:has_entry:true)`,
201 | "remove non existent",
202 | },
203 | }
204 |
205 | for _, tc := range cases {
206 | t.Run(tc.name, func(t *testing.T) {
207 | trie := newPrefixTree(tc.version).(*prefixTrie)
208 | for _, insert := range tc.inserts {
209 | _, network, _ := net.ParseCIDR(insert)
210 | err := trie.Insert(NewBasicRangerEntry(*network))
211 | assert.NoError(t, err)
212 | }
213 | for i, remove := range tc.removes {
214 | _, network, _ := net.ParseCIDR(remove)
215 | removed, err := trie.Remove(*network)
216 | assert.NoError(t, err)
217 | if str := tc.expectedRemoves[i]; str != "" {
218 | _, ipnet, _ := net.ParseCIDR(str)
219 | expected := NewBasicRangerEntry(*ipnet)
220 | assert.Equal(t, expected, removed)
221 | } else {
222 | assert.Nil(t, removed)
223 | }
224 | }
225 |
226 | assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval")
227 |
228 | allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version))
229 | assert.Nil(t, err)
230 | assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match")
231 |
232 | walk := trie.walkDepth()
233 | for _, network := range tc.expectedNetworksInDepthOrder {
234 | _, ipnet, _ := net.ParseCIDR(network)
235 | expected := NewBasicRangerEntry(*ipnet)
236 | actual := <-walk
237 | assert.Equal(t, expected, actual)
238 | }
239 |
240 | // Ensure no unexpected elements in trie.
241 | for network := range walk {
242 | assert.Nil(t, network)
243 | }
244 |
245 | assert.Equal(t, tc.expectedTrieString, trie.String())
246 | })
247 | }
248 | }
249 |
250 | func TestToReplicateIssue(t *testing.T) {
251 | cases := []struct {
252 | version rnet.IPVersion
253 | inserts []string
254 | ip net.IP
255 | networks []string
256 | name string
257 | }{
258 | {
259 | rnet.IPv4,
260 | []string{"192.168.0.1/32"},
261 | net.ParseIP("192.168.0.1"),
262 | []string{"192.168.0.1/32"},
263 | "basic containing network for /32 mask",
264 | },
265 | {
266 | rnet.IPv6,
267 | []string{"a::1/128"},
268 | net.ParseIP("a::1"),
269 | []string{"a::1/128"},
270 | "basic containing network for /128 mask",
271 | },
272 | }
273 | for _, tc := range cases {
274 | t.Run(tc.name, func(t *testing.T) {
275 | trie := newPrefixTree(tc.version)
276 | for _, insert := range tc.inserts {
277 | _, network, _ := net.ParseCIDR(insert)
278 | err := trie.Insert(NewBasicRangerEntry(*network))
279 | assert.NoError(t, err)
280 | }
281 | expectedEntries := []RangerEntry{}
282 | for _, network := range tc.networks {
283 | _, net, _ := net.ParseCIDR(network)
284 | expectedEntries = append(expectedEntries, NewBasicRangerEntry(*net))
285 | }
286 | contains, err := trie.Contains(tc.ip)
287 | assert.NoError(t, err)
288 | assert.True(t, contains)
289 | networks, err := trie.ContainingNetworks(tc.ip)
290 | assert.NoError(t, err)
291 | assert.Equal(t, expectedEntries, networks)
292 | })
293 | }
294 | }
295 |
// expectedIPRange is the half-open address interval [start, end) that a test
// expects a ranger to contain.
type expectedIPRange struct {
	start, end net.IP
}
300 |
301 | func TestPrefixTrieContains(t *testing.T) {
302 | cases := []struct {
303 | version rnet.IPVersion
304 | inserts []string
305 | expectedIPs []expectedIPRange
306 | name string
307 | }{
308 | {
309 | rnet.IPv4,
310 | []string{"192.168.0.0/24"},
311 | []expectedIPRange{
312 | {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.1.0")},
313 | },
314 | "basic contains",
315 | },
316 | {
317 | rnet.IPv4,
318 | []string{"192.168.0.0/24", "128.168.0.0/24"},
319 | []expectedIPRange{
320 | {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.1.0")},
321 | {net.ParseIP("128.168.0.0"), net.ParseIP("128.168.1.0")},
322 | },
323 | "multiple ranges contains",
324 | },
325 | }
326 |
327 | for _, tc := range cases {
328 | t.Run(tc.name, func(t *testing.T) {
329 | trie := newPrefixTree(tc.version)
330 | for _, insert := range tc.inserts {
331 | _, network, _ := net.ParseCIDR(insert)
332 | err := trie.Insert(NewBasicRangerEntry(*network))
333 | assert.NoError(t, err)
334 | }
335 | for _, expectedIPRange := range tc.expectedIPs {
336 | var contains bool
337 | var err error
338 | start := expectedIPRange.start
339 | for ; !expectedIPRange.end.Equal(start); start = rnet.NextIP(start) {
340 | contains, err = trie.Contains(start)
341 | assert.NoError(t, err)
342 | assert.True(t, contains)
343 | }
344 |
345 | // Check out of bounds ips on both ends
346 | contains, err = trie.Contains(rnet.PreviousIP(expectedIPRange.start))
347 | assert.NoError(t, err)
348 | assert.False(t, contains)
349 | contains, err = trie.Contains(rnet.NextIP(expectedIPRange.end))
350 | assert.NoError(t, err)
351 | assert.False(t, contains)
352 | }
353 | })
354 | }
355 | }
356 |
357 | func TestPrefixTrieContainingNetworks(t *testing.T) {
358 | cases := []struct {
359 | version rnet.IPVersion
360 | inserts []string
361 | ip net.IP
362 | networks []string
363 | name string
364 | }{
365 | {
366 | rnet.IPv4,
367 | []string{"192.168.0.0/24"},
368 | net.ParseIP("192.168.0.1"),
369 | []string{"192.168.0.0/24"},
370 | "basic containing networks",
371 | },
372 | {
373 | rnet.IPv4,
374 | []string{"192.168.0.0/24", "192.168.0.0/25"},
375 | net.ParseIP("192.168.0.1"),
376 | []string{"192.168.0.0/24", "192.168.0.0/25"},
377 | "inclusive networks",
378 | },
379 | }
380 | for _, tc := range cases {
381 | t.Run(tc.name, func(t *testing.T) {
382 | trie := newPrefixTree(tc.version)
383 | for _, insert := range tc.inserts {
384 | _, network, _ := net.ParseCIDR(insert)
385 | err := trie.Insert(NewBasicRangerEntry(*network))
386 | assert.NoError(t, err)
387 | }
388 | expectedEntries := []RangerEntry{}
389 | for _, network := range tc.networks {
390 | _, net, _ := net.ParseCIDR(network)
391 | expectedEntries = append(expectedEntries, NewBasicRangerEntry(*net))
392 | }
393 | networks, err := trie.ContainingNetworks(tc.ip)
394 | assert.NoError(t, err)
395 | assert.Equal(t, expectedEntries, networks)
396 | })
397 | }
398 | }
399 |
// coveredNetworkTest describes one CoveredNetworks scenario. coveredNetworkTests
// sets these fields positionally, so their order must not change.
type coveredNetworkTest struct {
	version  rnet.IPVersion // IP version of the trie under test
	inserts  []string       // CIDRs inserted before querying
	search   string         // CIDR passed to CoveredNetworks
	networks []string       // expected result in order; nil means nothing is covered
	name     string         // subtest name
}
407 |
// coveredNetworkTests is the table driving TestPrefixTrieCoveredNetworks. Each
// case inserts networks into a fresh trie. It then lists the inserted networks
// lying entirely inside search, in the order CoveredNetworks must return them.
var coveredNetworkTests = []coveredNetworkTest{
	{
		rnet.IPv4,
		[]string{"192.168.0.0/24"},
		"192.168.0.0/16",
		[]string{"192.168.0.0/24"},
		"basic covered networks",
	},
	{
		// The search network is disjoint from the inserted one.
		rnet.IPv4,
		[]string{"192.168.0.0/24"},
		"10.1.0.0/16",
		nil,
		"nothing",
	},
	{
		rnet.IPv4,
		[]string{"192.168.0.0/24", "192.168.0.0/25"},
		"192.168.0.0/16",
		[]string{"192.168.0.0/24", "192.168.0.0/25"},
		"multiple networks",
	},
	{
		rnet.IPv4,
		[]string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"},
		"192.168.0.0/16",
		[]string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"},
		"multiple networks 2",
	},
	{
		rnet.IPv4,
		[]string{"192.168.1.1/32"},
		"192.168.0.0/16",
		[]string{"192.168.1.1/32"},
		"leaf",
	},
	{
		// 0.0.0.0/0 contains the search network rather than lying inside it,
		// so it is not part of the result.
		rnet.IPv4,
		[]string{"0.0.0.0/0", "192.168.1.1/32"},
		"192.168.0.0/16",
		[]string{"192.168.1.1/32"},
		"leaf with root",
	},
	{
		// The 10.1.x networks lie outside the search network and must be skipped.
		rnet.IPv4,
		[]string{
			"0.0.0.0/0", "192.168.0.0/24", "192.168.1.1/32",
			"10.1.0.0/16", "10.1.1.0/24",
		},
		"192.168.0.0/16",
		[]string{"192.168.0.0/24", "192.168.1.1/32"},
		"path not taken",
	},
	{
		// Same base address, shorter mask: the /15 is wider than the /16
		// search network, so it is not covered.
		rnet.IPv4,
		[]string{
			"192.168.0.0/15",
		},
		"192.168.0.0/16",
		nil,
		"only masks different",
	},
}
471 |
472 | func TestPrefixTrieCoveredNetworks(t *testing.T) {
473 | for _, tc := range coveredNetworkTests {
474 | t.Run(tc.name, func(t *testing.T) {
475 | trie := newPrefixTree(tc.version)
476 | for _, insert := range tc.inserts {
477 | _, network, _ := net.ParseCIDR(insert)
478 | err := trie.Insert(NewBasicRangerEntry(*network))
479 | assert.NoError(t, err)
480 | }
481 | var expectedEntries []RangerEntry
482 | for _, network := range tc.networks {
483 | _, net, _ := net.ParseCIDR(network)
484 | expectedEntries = append(expectedEntries,
485 | NewBasicRangerEntry(*net))
486 | }
487 | _, snet, _ := net.ParseCIDR(tc.search)
488 | networks, err := trie.CoveredNetworks(*snet)
489 | assert.NoError(t, err)
490 | assert.Equal(t, expectedEntries, networks)
491 | })
492 | }
493 | }
494 |
495 | func TestTrieMemUsage(t *testing.T) {
496 | if testing.Short() {
497 | t.Skip("Skipping memory test in `-short` mode")
498 | }
499 | numIPs := 100000
500 | runs := 10
501 |
502 | // Avg heap allocation over all runs should not be more than the heap allocation of first run multiplied
503 | // by threshold, picking 1% as sane number for detecting memory leak.
504 | thresh := 1.01
505 |
506 | trie := newPrefixTree(rnet.IPv4)
507 |
508 | var baseLineHeap, totalHeapAllocOverRuns uint64
509 | for i := 0; i < runs; i++ {
510 | t.Logf("Executing Run %d of %d", i+1, runs)
511 |
512 | // Insert networks.
513 | for n := 0; n < numIPs; n++ {
514 | trie.Insert(NewBasicRangerEntry(GenLeafIPNet(GenIPV4())))
515 | }
516 | t.Logf("Inserted All (%d networks)", trie.Len())
517 | assert.Less(t, 0, trie.Len(), "Len should > 0")
518 | assert.LessOrEqualf(t, trie.Len(), numIPs, "Len should <= %d", numIPs)
519 |
520 | allNetworks, err := trie.CoveredNetworks(*getAllByVersion(rnet.IPv4))
521 | assert.Nil(t, err)
522 | assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match")
523 |
524 | // Remove networks.
525 | _, all, _ := net.ParseCIDR("0.0.0.0/0")
526 | ll, _ := trie.CoveredNetworks(*all)
527 | for i := 0; i < len(ll); i++ {
528 | trie.Remove(ll[i].Network())
529 | }
530 | t.Logf("Removed All (%d networks)", len(ll))
531 | assert.Equal(t, 0, trie.Len(), "Len after removal should == 0")
532 |
533 | // Perform GC
534 | runtime.GC()
535 |
536 | // Get HeapAlloc stats.
537 | heapAlloc := GetHeapAllocation()
538 | totalHeapAllocOverRuns += heapAlloc
539 | if i == 0 {
540 | baseLineHeap = heapAlloc
541 | }
542 | }
543 |
544 | // Assert that heap allocation from first loop is within set threshold of avg over all runs.
545 | assert.Less(t, uint64(0), baseLineHeap)
546 | assert.LessOrEqual(t, float64(baseLineHeap), float64(totalHeapAllocOverRuns/uint64(runs))*thresh)
547 | }
548 |
// GenLeafIPNet wraps ip in a single-address network with a 32-bit-wide /32
// mask. It is intended for 4-byte IPv4 addresses such as those from GenIPV4.
func GenLeafIPNet(ip net.IP) net.IPNet {
	var leaf net.IPNet
	leaf.IP = ip
	leaf.Mask = net.CIDRMask(32, 32)
	return leaf
}
555 |
// genIPRand is GenIPV4's random source. It is seeded once at package
// initialisation. Reseeding from the clock on every call (as before) is slow
// and, on platforms with a coarse clock, makes consecutive calls return the
// same address. Like any *rand.Rand it is not safe for concurrent use.
var genIPRand = rand.New(rand.NewSource(time.Now().UnixNano()))

// GenIPV4 generates a random IPv4 address in 4-byte form. The drawn value is
// bumped by one unless it is already the maximum, so the result is never
// 0.0.0.0; 255.255.255.255 is slightly more likely than other addresses.
// Not safe for concurrent use.
func GenIPV4() net.IP {
	nn := genIPRand.Uint32()
	if nn < ^uint32(0) {
		nn++
	}
	ip := make(net.IP, net.IPv4len)
	binary.BigEndian.PutUint32(ip, nn)
	return ip
}
567 |
// GetHeapAllocation returns runtime.MemStats.HeapAlloc, the number of bytes of
// allocated heap objects at the time of the call.
func GetHeapAllocation() uint64 {
	stats := new(runtime.MemStats)
	runtime.ReadMemStats(stats)
	return stats.HeapAlloc
}
573 |
--------------------------------------------------------------------------------
/version.go:
--------------------------------------------------------------------------------
1 | package cidranger
2 |
3 | import (
4 | "net"
5 |
6 | rnet "github.com/yl2chen/cidranger/net"
7 | )
8 |
// rangerFactory builds a Ranger for one IP version; newVersionedRanger calls it
// once with rnet.IPv4 and once with rnet.IPv6.
type rangerFactory func(rnet.IPVersion) Ranger
10 |
11 | type versionedRanger struct {
12 | ipV4Ranger Ranger
13 | ipV6Ranger Ranger
14 | }
15 |
16 | func newVersionedRanger(factory rangerFactory) Ranger {
17 | return &versionedRanger{
18 | ipV4Ranger: factory(rnet.IPv4),
19 | ipV6Ranger: factory(rnet.IPv6),
20 | }
21 | }
22 |
23 | func (v *versionedRanger) Insert(entry RangerEntry) error {
24 | network := entry.Network()
25 | ranger, err := v.getRangerForIP(network.IP)
26 | if err != nil {
27 | return err
28 | }
29 | return ranger.Insert(entry)
30 | }
31 |
32 | func (v *versionedRanger) Remove(network net.IPNet) (RangerEntry, error) {
33 | ranger, err := v.getRangerForIP(network.IP)
34 | if err != nil {
35 | return nil, err
36 | }
37 | return ranger.Remove(network)
38 | }
39 |
40 | func (v *versionedRanger) Contains(ip net.IP) (bool, error) {
41 | ranger, err := v.getRangerForIP(ip)
42 | if err != nil {
43 | return false, err
44 | }
45 | return ranger.Contains(ip)
46 | }
47 |
48 | func (v *versionedRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
49 | ranger, err := v.getRangerForIP(ip)
50 | if err != nil {
51 | return nil, err
52 | }
53 | return ranger.ContainingNetworks(ip)
54 | }
55 |
56 | func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
57 | ranger, err := v.getRangerForIP(network.IP)
58 | if err != nil {
59 | return nil, err
60 | }
61 | return ranger.CoveredNetworks(network)
62 | }
63 |
64 | // Len returns number of networks in ranger.
65 | func (v *versionedRanger) Len() int {
66 | return v.ipV4Ranger.Len() + v.ipV6Ranger.Len()
67 | }
68 |
69 | func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) {
70 | if ip.To4() != nil {
71 | return v.ipV4Ranger, nil
72 | }
73 | if ip.To16() != nil {
74 | return v.ipV6Ranger, nil
75 | }
76 | return nil, ErrInvalidNetworkNumberInput
77 | }
78 |
--------------------------------------------------------------------------------