├── .circleci └── config.yml ├── .github └── workflows │ ├── go.yml │ └── test.yaml ├── .gitignore ├── LICENSE.md ├── README.md ├── codec.go ├── codec_test.go ├── go.mod ├── go.sum ├── messages.go ├── messages_test.go ├── odoh.go └── odoh_test.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Golang CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-go/ for more details 4 | version: 2 5 | jobs: 6 | build: 7 | docker: 8 | - image: circleci/golang:1.14 9 | working_directory: /go/src/github.com/chris-wood/odoh 10 | steps: 11 | - checkout 12 | 13 | # specify any bash command here prefixed with `run: ` 14 | - run: go get -v -t -d ./... 15 | - run: go test -v ./... 16 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.13 20 | id: go 21 | 22 | - name: Check out code into the Go module directory 23 | uses: actions/checkout@v2 24 | 25 | - name: Get dependencies 26 | run: | 27 | go get -v -t -d ./... 28 | if [ -f Gopkg.toml ]; then 29 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 30 | dep ensure 31 | fi 32 | 33 | - name: Build 34 | run: go build -v . 35 | 36 | - name: Test 37 | run: go test -v . 38 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | test: 11 | name: Test with Coverage 12 | runs-on: ubuntu-latest 13 | steps: 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v1 17 | with: 18 | go-version: '1.13' 19 | 20 | - name: Check out code 21 | uses: actions/checkout@v2 22 | 23 | - name: Install dependencies 24 | run: | 25 | go mod download 26 | 27 | - name: Run Unit tests 28 | run: | 29 | go test -race -covermode atomic -coverprofile=covprofile ./... 30 | 31 | - name: Install goveralls 32 | env: 33 | GO111MODULE: off 34 | run: go get github.com/mattn/goveralls 35 | 36 | - name: Send coverage 37 | env: 38 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | run: $(go env GOPATH)/bin/goveralls -coverprofile=covprofile -service=github -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019 Apple, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Oblivious DoH](https://tools.ietf.org/html/draft-pauly-dprive-oblivious-doh) 2 | 3 | [![CircleCI](https://circleci.com/gh/chris-wood/odoh.svg?style=svg)](https://circleci.com/gh/chris-wood/odoh) 4 | [![Coverage Status](https://coveralls.io/repos/github/chris-wood/odoh/badge.svg?branch=master)](https://coveralls.io/github/chris-wood/odoh?branch=master) 5 | [![GoDoc](https://godoc.org/github.com/chris-wood/odoh?status.svg)](https://godoc.org/github.com/chris-wood/odoh) 6 | 7 | This library implements draft -02 of [Oblivious DoH](https://tools.ietf.org/html/draft-pauly-dprive-oblivious-doh-02). 8 | 9 | ## Test vector generation 10 | 11 | To generate test vectors, run: 12 | 13 | ``` 14 | $ ODOH_TEST_VECTORS_OUT=test-vectors.json go test -v -run TestVectorGenerate 15 | ``` 16 | 17 | To check test vectors, run: 18 | 19 | ``` 20 | $ ODOH_TEST_VECTORS_IN=test-vectors.json go test -v -run TestVectorVerify 21 | ``` 22 | -------------------------------------------------------------------------------- /codec.go: -------------------------------------------------------------------------------- 1 | // The MIT License 2 | // 3 | // Copyright (c) 2019 Apple, Inc. 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in 13 | // all copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | // THE SOFTWARE. 22 | 23 | package odoh 24 | 25 | import ( 26 | "encoding/binary" 27 | "fmt" 28 | ) 29 | 30 | func encodeLengthPrefixedSlice(slice []byte) []byte { 31 | result := make([]byte, 2) 32 | binary.BigEndian.PutUint16(result, uint16(len(slice))) 33 | return append(result, slice...) 34 | } 35 | 36 | func decodeLengthPrefixedSlice(slice []byte) ([]byte, int, error) { 37 | if len(slice) < 2 { 38 | return nil, 0, fmt.Errorf("Expected at least 2 bytes of length encoded prefix") 39 | } 40 | 41 | length := binary.BigEndian.Uint16(slice) 42 | if int(2+length) > len(slice) { 43 | return nil, 0, fmt.Errorf("Insufficient data. Expected %d, got %d", 2+length, len(slice)) 44 | } 45 | 46 | return slice[2 : 2+length], int(2 + length), nil 47 | } 48 | -------------------------------------------------------------------------------- /codec_test.go: -------------------------------------------------------------------------------- 1 | package odoh 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestEncodeEmptySlice(t *testing.T) { 9 | expectedBytes := []byte{0x00, 0x00} 10 | if !bytes.Equal(encodeLengthPrefixedSlice(nil), expectedBytes) { 11 | t.Fatalf("Result mismatch.") 12 | } 13 | } 14 | 15 | func TestEncodeLengthPrefixedSlice(t *testing.T) { 16 | testData := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} 17 | result := encodeLengthPrefixedSlice(testData) 18 | expectedBytes := []byte{0x00, 0x09, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} 19 | 20 | if !bytes.Equal(result, expectedBytes) { 21 | t.Fatalf("Result mismatch.") 22 | } 23 | } 24 | 25 | func TestDecodeLengthPrefixedSlice(t *testing.T) { 26 | testData := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} 27 | result := encodeLengthPrefixedSlice(testData) 28 | decodedBytes, length, err := decodeLengthPrefixedSlice(result) 29 | if err != nil { 30 | t.Fatalf("Raised an error. Decoding error.") 31 | } 32 | if !bytes.Equal(testData, decodedBytes) { 33 | t.Fatalf("Decoding result mismatch.") 34 | } 35 | if len(testData)+2 != length { 36 | t.Fatalf("Incorrect length in the encoded message.") 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/chris-wood/odoh 2 | 3 | go 1.14 4 | 5 | require github.com/cisco/go-hpke v0.0.0-20200904203048-9e7d3e90b7c3 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6 h1:w8IZgCntCe0RuBJp+dENSMwEBl/k8saTgJ5hPca5IWw= 2 | git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6/go.mod h1:wQaGCqEu44ykB17jZHCevrgSVl3KJnwQBObUtrKU4uU= 3 | github.com/cisco/go-hpke v0.0.0-20200904203048-9e7d3e90b7c3 h1:3PT/MB4kSeuHr78O8Dkf538V7HrtdckYyi4STn8iJYM= 4 | github.com/cisco/go-hpke v0.0.0-20200904203048-9e7d3e90b7c3/go.mod h1:AyK7f6CWiLAvOFmAyCEF5xDN51zS6PIZgj3Qq7hla1Y= 5 | github.com/cisco/go-tls-syntax v0.0.0-20200617162716-46b0cfb76b9b h1:Ves2turKTX7zruivAcUOQg155xggcbv3suVdbKCBQNM= 6 | github.com/cisco/go-tls-syntax v0.0.0-20200617162716-46b0cfb76b9b/go.mod h1:0AZAV7lYvynZQ5ErHlGMKH+4QYMyNCFd+AiL9MlrCYA= 7 | github.com/cloudflare/circl v1.0.0 h1:64b6pyfCFbYm623ncIkYGNZaOcmIbyd+CjyMi2L9vdI= 8 | github.com/cloudflare/circl v1.0.0/go.mod h1:MhjB3NEEhJbTOdLLq964NIUisXDxaE1WkQPUxtgZXiY= 9 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 10 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 14 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= 15 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 16 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 17 | golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= 18 | golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 19 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 20 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 21 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 22 | golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed h1:uPxWBzB3+mlnjy9W58qY1j/cjyFjutgw/Vhan2zLy/A= 23 | golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 24 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 25 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 27 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 28 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | -------------------------------------------------------------------------------- /messages.go: -------------------------------------------------------------------------------- 1 | // The MIT License 2 | // 3 | // Copyright (c) 2019 Apple, Inc. 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in 13 | // all copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | // THE SOFTWARE. 22 | 23 | package odoh 24 | 25 | import ( 26 | "encoding/binary" 27 | "fmt" 28 | ) 29 | 30 | type ObliviousMessageType uint8 31 | 32 | const ( 33 | QueryType ObliviousMessageType = 0x01 34 | ResponseType ObliviousMessageType = 0x02 35 | ) 36 | 37 | // 38 | // struct { 39 | // opaque dns_message<1..2^16-1>; 40 | // opaque padding<0..2^16-1>; 41 | // } ObliviousDoHQueryBody; 42 | // 43 | type ObliviousDNSMessageBody struct { 44 | DnsMessage []byte 45 | Padding []byte 46 | } 47 | 48 | func (m ObliviousDNSMessageBody) Marshal() []byte { 49 | return append(encodeLengthPrefixedSlice(m.DnsMessage), encodeLengthPrefixedSlice(m.Padding)...) 50 | } 51 | 52 | func UnmarshalMessageBody(data []byte) (ObliviousDNSMessageBody, error) { 53 | messageLength := binary.BigEndian.Uint16(data) 54 | if int(2+messageLength) > len(data) { 55 | return ObliviousDNSMessageBody{}, fmt.Errorf("Invalid DNS message length") 56 | } 57 | message := data[2 : 2+messageLength] 58 | 59 | paddingLength := binary.BigEndian.Uint16(data[2+messageLength:]) 60 | if int(2+messageLength+2+paddingLength) > len(data) { 61 | return ObliviousDNSMessageBody{}, fmt.Errorf("Invalid DNS padding length") 62 | } 63 | 64 | padding := data[2+messageLength+2 : 2+messageLength+2+paddingLength] 65 | return ObliviousDNSMessageBody{ 66 | DnsMessage: message, 67 | Padding: padding, 68 | }, nil 69 | } 70 | 71 | func (m ObliviousDNSMessageBody) Message() []byte { 72 | return m.DnsMessage 73 | } 74 | 75 | type ObliviousDNSQuery struct { 76 | ObliviousDNSMessageBody 77 | } 78 | 79 | func CreateObliviousDNSQuery(query []byte, paddingBytes uint16) *ObliviousDNSQuery { 80 | msg := ObliviousDNSMessageBody{ 81 | DnsMessage: query, 82 | Padding: make([]byte, int(paddingBytes)), 83 | } 84 | return &ObliviousDNSQuery{ 85 | msg, 86 | } 87 | } 88 | 89 | func UnmarshalQueryBody(data []byte) (*ObliviousDNSQuery, error) { 90 | msg, err := UnmarshalMessageBody(data) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | return &ObliviousDNSQuery{msg}, nil 96 | } 97 | 98 | type ObliviousDNSResponse struct { 99 | ObliviousDNSMessageBody 100 | } 101 | 102 | func CreateObliviousDNSResponse(response []byte, paddingBytes uint16) *ObliviousDNSResponse { 103 | msg := ObliviousDNSMessageBody{ 104 | DnsMessage: response, 105 | Padding: make([]byte, int(paddingBytes)), 106 | } 107 | return &ObliviousDNSResponse{ 108 | msg, 109 | } 110 | } 111 | 112 | func UnmarshalResponseBody(data []byte) (*ObliviousDNSResponse, error) { 113 | msg, err := UnmarshalMessageBody(data) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | return &ObliviousDNSResponse{msg}, nil 119 | } 120 | 121 | // 122 | // struct { 123 | // uint8 message_type; 124 | // opaque key_id<0..2^16-1>; 125 | // opaque encrypted_message<1..2^16-1>; 126 | // } ObliviousDoHMessage; 127 | // 128 | type ObliviousDNSMessage struct { 129 | MessageType ObliviousMessageType 130 | KeyID []byte 131 | EncryptedMessage []byte 132 | } 133 | 134 | func (m ObliviousDNSMessage) Type() ObliviousMessageType { 135 | return m.MessageType 136 | } 137 | 138 | func CreateObliviousDNSMessage(messageType ObliviousMessageType, keyID []byte, encryptedMessage []byte) *ObliviousDNSMessage { 139 | return &ObliviousDNSMessage{ 140 | MessageType: messageType, 141 | KeyID: keyID, 142 | EncryptedMessage: encryptedMessage, 143 | } 144 | } 145 | 146 | func (m ObliviousDNSMessage) Marshal() []byte { 147 | encodedKey := encodeLengthPrefixedSlice(m.KeyID) 148 | encodedMessage := encodeLengthPrefixedSlice(m.EncryptedMessage) 149 | 150 | result := append([]byte{uint8(m.MessageType)}, encodedKey...) 151 | result = append(result, encodedMessage...) 152 | 153 | return result 154 | } 155 | 156 | func UnmarshalDNSMessage(data []byte) (ObliviousDNSMessage, error) { 157 | if len(data) < 1 { 158 | return ObliviousDNSMessage{}, fmt.Errorf("Invalid data length: %d", len(data)) 159 | } 160 | 161 | messageType := data[0] 162 | keyID, messageOffset, err := decodeLengthPrefixedSlice(data[1:]) 163 | if err != nil { 164 | return ObliviousDNSMessage{}, err 165 | } 166 | encryptedMessage, _, err := decodeLengthPrefixedSlice(data[1+messageOffset:]) 167 | if err != nil { 168 | return ObliviousDNSMessage{}, err 169 | } 170 | 171 | return ObliviousDNSMessage{ 172 | MessageType: ObliviousMessageType(messageType), 173 | KeyID: keyID, 174 | EncryptedMessage: encryptedMessage, 175 | }, nil 176 | } 177 | -------------------------------------------------------------------------------- /messages_test.go: -------------------------------------------------------------------------------- 1 | package odoh 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestObliviousMessageMarshalEmptyKeyId(t *testing.T) { 9 | testMessage := []byte{0x06, 0x07, 0x08, 0x09} 10 | message := ObliviousDNSMessage{ 11 | MessageType: 0xFF, 12 | KeyID: nil, 13 | EncryptedMessage: testMessage, 14 | } 15 | 16 | serializedMessage := message.Marshal() 17 | expectedBytes := []byte{0xFF} 18 | expectedBytes = append(expectedBytes, []byte{0x00, 0x00}...) // empty key ID 19 | expectedBytes = append(expectedBytes, []byte{0x00, 0x04}...) // non-empty message 20 | expectedBytes = append(expectedBytes, testMessage...) 21 | if !bytes.Equal(serializedMessage, expectedBytes) { 22 | t.Fatalf("Marshalling mismatch in the encoding. Got %x, received %x", serializedMessage, expectedBytes) 23 | } 24 | } 25 | 26 | func TestObliviousMessageMarshalEmptyMessage(t *testing.T) { 27 | testKeyId := []byte{0x02, 0x03} 28 | message := ObliviousDNSMessage{ 29 | MessageType: 0xFF, 30 | KeyID: testKeyId, 31 | EncryptedMessage: nil, 32 | } 33 | 34 | serializedMessage := message.Marshal() 35 | expectedBytes := []byte{0xFF} 36 | expectedBytes = append(expectedBytes, []byte{0x00, 0x02}...) // non-empty key ID 37 | expectedBytes = append(expectedBytes, testKeyId...) 38 | expectedBytes = append(expectedBytes, []byte{0x00, 0x00}...) // empty message 39 | if !bytes.Equal(serializedMessage, expectedBytes) { 40 | t.Fatalf("Marshalling mismatch in the encoding. Got %x, received %x", serializedMessage, expectedBytes) 41 | } 42 | } 43 | 44 | func TestObliviousMessageMarshalNonEmptyKeyId(t *testing.T) { 45 | testMessage := []byte{0x06, 0x07, 0x08, 0x09} 46 | testKeyId := []byte{0x02, 0x03} 47 | message := ObliviousDNSMessage{ 48 | MessageType: 0xFF, 49 | KeyID: testKeyId, 50 | EncryptedMessage: testMessage, 51 | } 52 | 53 | serializedMessage := message.Marshal() 54 | expectedBytes := []byte{0xFF} 55 | expectedBytes = append(expectedBytes, []byte{0x00, 0x02}...) // non-empty key ID 56 | expectedBytes = append(expectedBytes, testKeyId...) 57 | expectedBytes = append(expectedBytes, []byte{0x00, 0x04}...) // non-empty message 58 | expectedBytes = append(expectedBytes, testMessage...) 59 | if !bytes.Equal(serializedMessage, expectedBytes) { 60 | t.Fatalf("Marshalling mismatch in the encoding. Got %x, received %x", serializedMessage, expectedBytes) 61 | } 62 | } 63 | 64 | func TestObliviousDoHQueryNoPaddingMarshal(t *testing.T) { 65 | dnsMessage := []byte{0x06, 0x07, 0x08, 0x09} 66 | query := CreateObliviousDNSQuery(dnsMessage, 0) 67 | 68 | serializedMessage := query.Marshal() 69 | expectedBytes := []byte{ 70 | 0x00, 0x04, 71 | 0x06, 0x07, 0x08, 0x09, 72 | 0x00, 0x00} 73 | if !bytes.Equal(serializedMessage, expectedBytes) { 74 | t.Fatalf("Marshalling mismatch in the encoding.") 75 | } 76 | } 77 | 78 | func TestObliviousDoHQueryPaddingMarshal(t *testing.T) { 79 | dnsMessage := []byte{0x06, 0x07, 0x08, 0x09} 80 | 81 | paddingLength := uint16(8) 82 | paddedBytes := make([]byte, paddingLength) 83 | query := CreateObliviousDNSQuery(dnsMessage, paddingLength) 84 | 85 | serializedMessage := query.Marshal() 86 | expectedBytes := []byte{ 87 | 0x00, 0x04, 88 | 0x06, 0x07, 0x08, 0x09, 89 | 0x00, uint8(paddingLength)} 90 | expectedBytes = append(expectedBytes, paddedBytes...) 91 | if !bytes.Equal(serializedMessage, expectedBytes) { 92 | t.Fatalf("Marshalling mismatch in the encoding.") 93 | } 94 | } 95 | 96 | func TestObliviousDoHMessage_Marshal(t *testing.T) { 97 | messageType := QueryType 98 | keyId := []byte{0x00, 0x01, 0x02, 0x03, 0x04} 99 | encryptedMessage := []byte{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F} 100 | 101 | odnsMessage := ObliviousDNSMessage{ 102 | MessageType: messageType, 103 | KeyID: keyId, 104 | EncryptedMessage: encryptedMessage, 105 | } 106 | 107 | serializedMessage := odnsMessage.Marshal() 108 | expectedBytes := []byte{0x01, 109 | 0x00, 0x05, 110 | 0x00, 0x01, 0x02, 0x03, 0x04, 111 | 0x00, 0x0B, 112 | 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F} 113 | 114 | if !bytes.Equal(serializedMessage, expectedBytes) { 115 | t.Fatalf("Failed to serialize correctly. Got %x, expected %x", serializedMessage, expectedBytes) 116 | } 117 | } 118 | 119 | func TestObliviousDoHMessage_Unmarshal(t *testing.T) { 120 | messageType := QueryType 121 | keyId := []byte{0x00, 0x01, 0x02, 0x03, 0x04} 122 | encryptedMessage := []byte{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F} 123 | 124 | odnsMessage := ObliviousDNSMessage{ 125 | MessageType: messageType, 126 | KeyID: keyId, 127 | EncryptedMessage: encryptedMessage, 128 | } 129 | 130 | expectedBytes := []byte{0x01, 131 | 0x00, 0x05, 132 | 0x00, 0x01, 0x02, 0x03, 0x04, 133 | 0x00, 0x0B, 134 | 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F} 135 | 136 | deserializedMessage, err := UnmarshalDNSMessage(expectedBytes) 137 | 138 | if err != nil { 139 | t.Fatalf("Failed to unmarshal ObliviousDNSMessage") 140 | } 141 | 142 | if !(deserializedMessage.MessageType == odnsMessage.MessageType) { 143 | t.Fatalf("Message type mismatch after unmarshaling") 144 | } 145 | 146 | if !bytes.Equal(deserializedMessage.KeyID, odnsMessage.KeyID) { 147 | t.Fatalf("Failed to unmarshal the KeyID correctly.") 148 | } 149 | 150 | if !bytes.Equal(deserializedMessage.EncryptedMessage, odnsMessage.EncryptedMessage) { 151 | t.Fatalf("Failed to unmarshal the Encrypted Message Correctly.") 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /odoh.go: -------------------------------------------------------------------------------- 1 | // The MIT License 2 | // 3 | // Copyright (c) 2019 Apple, Inc. 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in 13 | // all copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | // THE SOFTWARE. 22 | 23 | package odoh 24 | 25 | import ( 26 | "crypto/rand" 27 | "crypto/subtle" 28 | "encoding/binary" 29 | "errors" 30 | "fmt" 31 | "github.com/cisco/go-hpke" 32 | ) 33 | 34 | const ( 35 | ODOH_VERSION = uint16(0xff02) 36 | ODOH_SECRET_LENGTH = 32 37 | ODOH_PADDING_BYTE = uint8(0) 38 | ODOH_LABEL_KEY_ID = "odoh key id" 39 | ODOH_LABEL_KEY = "odoh key" 40 | ODOH_LABEL_NONCE = "odoh nonce" 41 | ODOH_LABEL_SECRET = "odoh secret" 42 | ODOH_LABEL_QUERY = "odoh query" 43 | ODOH_DEFAULT_KEMID hpke.KEMID = hpke.DHKEM_X25519 44 | ODOH_DEFAULT_KDFID hpke.KDFID = hpke.KDF_HKDF_SHA256 45 | ODOH_DEFAULT_AEADID hpke.AEADID = hpke.AEAD_AESGCM128 46 | ) 47 | 48 | type ObliviousDoHConfigContents struct { 49 | KemID hpke.KEMID 50 | KdfID hpke.KDFID 51 | AeadID hpke.AEADID 52 | PublicKeyBytes []byte 53 | } 54 | 55 | func CreateObliviousDoHConfigContents(kemID hpke.KEMID, kdfID hpke.KDFID, aeadID hpke.AEADID, publicKeyBytes []byte) (ObliviousDoHConfigContents, error) { 56 | suite, err := hpke.AssembleCipherSuite(kemID, kdfID, aeadID) 57 | if err != nil { 58 | return ObliviousDoHConfigContents{}, err 59 | } 60 | 61 | _, err = suite.KEM.Deserialize(publicKeyBytes) 62 | if err != nil { 63 | return ObliviousDoHConfigContents{}, err 64 | } 65 | 66 | return ObliviousDoHConfigContents{ 67 | KemID: kemID, 68 | KdfID: kdfID, 69 | AeadID: aeadID, 70 | PublicKeyBytes: publicKeyBytes, 71 | }, nil 72 | } 73 | 74 | func (k ObliviousDoHConfigContents) KeyID() []byte { 75 | suite, err := hpke.AssembleCipherSuite(k.KemID, k.KdfID, k.AeadID) 76 | if err != nil { 77 | return nil 78 | } 79 | 80 | identifiers := make([]byte, 8) 81 | binary.BigEndian.PutUint16(identifiers[0:], uint16(k.KemID)) 82 | binary.BigEndian.PutUint16(identifiers[2:], uint16(k.KdfID)) 83 | binary.BigEndian.PutUint16(identifiers[4:], uint16(k.AeadID)) 84 | binary.BigEndian.PutUint16(identifiers[6:], uint16(len(k.PublicKeyBytes))) 85 | config := append(identifiers, k.PublicKeyBytes...) 86 | 87 | prk := suite.KDF.Extract(nil, config) 88 | identifier := suite.KDF.Expand(prk, []byte(ODOH_LABEL_KEY_ID), suite.KDF.OutputSize()) 89 | 90 | return identifier 91 | } 92 | 93 | func (k ObliviousDoHConfigContents) Marshal() []byte { 94 | identifiers := make([]byte, 8) 95 | binary.BigEndian.PutUint16(identifiers[0:], uint16(k.KemID)) 96 | binary.BigEndian.PutUint16(identifiers[2:], uint16(k.KdfID)) 97 | binary.BigEndian.PutUint16(identifiers[4:], uint16(k.AeadID)) 98 | binary.BigEndian.PutUint16(identifiers[6:], uint16(len(k.PublicKeyBytes))) 99 | 100 | response := append(identifiers, k.PublicKeyBytes...) 101 | return response 102 | } 103 | 104 | func UnmarshalObliviousDoHConfigContents(buffer []byte) (ObliviousDoHConfigContents, error) { 105 | if len(buffer) < 8 { 106 | return ObliviousDoHConfigContents{}, errors.New("Invalid serialized ObliviousDoHConfigContents") 107 | } 108 | 109 | kemId := binary.BigEndian.Uint16(buffer[0:]) 110 | kdfId := binary.BigEndian.Uint16(buffer[2:]) 111 | aeadId := binary.BigEndian.Uint16(buffer[4:]) 112 | publicKeyLength := binary.BigEndian.Uint16(buffer[6:]) 113 | 114 | if len(buffer[8:]) < int(publicKeyLength) { 115 | return ObliviousDoHConfigContents{}, errors.New("Invalid serialized ObliviousDoHConfigContents") 116 | } 117 | 118 | publicKeyBytes := buffer[8 : 8+publicKeyLength] 119 | 120 | var KemID hpke.KEMID 121 | var KdfID hpke.KDFID 122 | var AeadID hpke.AEADID 123 | 124 | switch kemId { 125 | case 0x0010: 126 | KemID = hpke.DHKEM_P256 127 | break 128 | case 0x0012: 129 | KemID = hpke.DHKEM_P521 130 | break 131 | case 0x0020: 132 | KemID = hpke.DHKEM_X25519 133 | break 134 | case 0x0021: 135 | KemID = hpke.DHKEM_X448 136 | break 137 | case 0xFFFE: 138 | KemID = hpke.KEM_SIKE503 139 | break 140 | case 0xFFFF: 141 | KemID = hpke.KEM_SIKE751 142 | break 143 | default: 144 | return ObliviousDoHConfigContents{}, errors.New(fmt.Sprintf("Unsupported KEMID: %04x", kemId)) 145 | } 146 | 147 | switch kdfId { 148 | case 0x0001: 149 | KdfID = hpke.KDF_HKDF_SHA256 150 | break 151 | case 0x0002: 152 | KdfID = hpke.KDF_HKDF_SHA384 153 | break 154 | case 0x0003: 155 | KdfID = hpke.KDF_HKDF_SHA512 156 | break 157 | default: 158 | return ObliviousDoHConfigContents{}, errors.New(fmt.Sprintf("Unsupported KDFID: %04x", kdfId)) 159 | } 160 | 161 | switch aeadId { 162 | case 0x0001: 163 | AeadID = hpke.AEAD_AESGCM128 164 | break 165 | case 0x0002: 166 | AeadID = hpke.AEAD_AESGCM256 167 | break 168 | case 0x0003: 169 | AeadID = hpke.AEAD_CHACHA20POLY1305 170 | break 171 | default: 172 | return ObliviousDoHConfigContents{}, errors.New(fmt.Sprintf("Unsupported AEADID: %04x", aeadId)) 173 | } 174 | 175 | suite, err := hpke.AssembleCipherSuite(KemID, KdfID, AeadID) 176 | if err != nil { 177 | return ObliviousDoHConfigContents{}, errors.New(fmt.Sprintf("Unsupported HPKE ciphersuite")) 178 | } 179 | 180 | _, err = suite.KEM.Deserialize(publicKeyBytes) 181 | if err != nil { 182 | return ObliviousDoHConfigContents{}, errors.New(fmt.Sprintf("Invalid HPKE public key bytes")) 183 | } 184 | 185 | return ObliviousDoHConfigContents{ 186 | KemID: KemID, 187 | KdfID: KdfID, 188 | AeadID: AeadID, 189 | PublicKeyBytes: publicKeyBytes, 190 | }, nil 191 | } 192 | 193 | func (k ObliviousDoHConfigContents) PublicKey() []byte { 194 | return k.PublicKeyBytes 195 | } 196 | 197 | func (k ObliviousDoHConfigContents) CipherSuite() (hpke.CipherSuite, error) { 198 | return hpke.AssembleCipherSuite(k.KemID, k.KdfID, k.AeadID) 199 | } 200 | 201 | type ObliviousDoHConfig struct { 202 | Version uint16 203 | Contents ObliviousDoHConfigContents 204 | } 205 | 206 | func CreateObliviousDoHConfig(contents ObliviousDoHConfigContents) ObliviousDoHConfig { 207 | return ObliviousDoHConfig{ 208 | Version: ODOH_VERSION, 209 | Contents: contents, 210 | } 211 | } 212 | 213 | func (c ObliviousDoHConfig) Marshal() []byte { 214 | marshalledConfig := c.Contents.Marshal() 215 | 216 | buffer := make([]byte, 4) 217 | binary.BigEndian.PutUint16(buffer[0:], uint16(c.Version)) 218 | binary.BigEndian.PutUint16(buffer[2:], uint16(len(marshalledConfig))) 219 | 220 | configBytes := append(buffer, marshalledConfig...) 221 | return configBytes 222 | } 223 | 224 | func parseConfigHeader(buffer []byte) (uint16, uint16, error) { 225 | if len(buffer) < 4 { 226 | return uint16(0), uint16(0), errors.New("Invalid ObliviousDoHConfig encoding") 227 | } 228 | 229 | version := binary.BigEndian.Uint16(buffer[0:]) 230 | length := binary.BigEndian.Uint16(buffer[2:]) 231 | return version, length, nil 232 | } 233 | 234 | func isSupportedConfigVersion(version uint16) bool { 235 | return version == ODOH_VERSION 236 | } 237 | 238 | func UnmarshalObliviousDoHConfig(buffer []byte) (ObliviousDoHConfig, error) { 239 | version, length, err := parseConfigHeader(buffer) 240 | if err != nil { 241 | return ObliviousDoHConfig{}, err 242 | } 243 | 244 | if !isSupportedConfigVersion(version) { 245 | return ObliviousDoHConfig{}, errors.New(fmt.Sprintf("Unsupported version: %04x", version)) 246 | } 247 | if len(buffer[4:]) < int(length) { 248 | return ObliviousDoHConfig{}, errors.New(fmt.Sprintf("Invalid serialized ObliviousDoHConfig, expected %v bytes, got %v", length, len(buffer[4:]))) 249 | } 250 | 251 | configContents, err := UnmarshalObliviousDoHConfigContents(buffer[4:]) 252 | if err != nil { 253 | return ObliviousDoHConfig{}, err 254 | } 255 | 256 | return ObliviousDoHConfig{ 257 | Version: version, 258 | Contents: configContents, 259 | }, nil 260 | } 261 | 262 | type ObliviousDoHConfigs struct { 263 | Configs []ObliviousDoHConfig 264 | } 265 | 266 | func CreateObliviousDoHConfigs(configs []ObliviousDoHConfig) ObliviousDoHConfigs { 267 | return ObliviousDoHConfigs{ 268 | Configs: configs, 269 | } 270 | } 271 | 272 | func (c ObliviousDoHConfigs) Marshal() []byte { 273 | serializedConfigs := make([]byte, 0) 274 | for _, config := range c.Configs { 275 | serializedConfigs = append(serializedConfigs, config.Marshal()...) 276 | } 277 | 278 | buffer := make([]byte, 2) 279 | binary.BigEndian.PutUint16(buffer[0:], uint16(len(serializedConfigs))) 280 | 281 | result := append(buffer, serializedConfigs...) 282 | return result 283 | } 284 | 285 | func UnmarshalObliviousDoHConfigs(buffer []byte) (ObliviousDoHConfigs, error) { 286 | if len(buffer) < 2 { 287 | return ObliviousDoHConfigs{}, errors.New("Invalid ObliviousDoHConfigs encoding") 288 | } 289 | 290 | configs := make([]ObliviousDoHConfig, 0) 291 | length := binary.BigEndian.Uint16(buffer[0:]) 292 | offset := uint16(2) 293 | 294 | for { 295 | configVersion, configLength, err := parseConfigHeader(buffer[offset:]) 296 | if err != nil { 297 | return ObliviousDoHConfigs{}, errors.New("Invalid ObliviousDoHConfigs encoding") 298 | } 299 | 300 | if uint16(len(buffer[offset:])) < configLength { 301 | // The configs vector is encoded incorrectly, so discard the whole thing 302 | return ObliviousDoHConfigs{}, errors.New(fmt.Sprintf("Invalid serialized ObliviousDoHConfig, expected %v bytes, got %v", length, len(buffer[offset:]))) 303 | } 304 | 305 | if isSupportedConfigVersion(configVersion) { 306 | config, err := UnmarshalObliviousDoHConfig(buffer[offset:]) 307 | if err == nil { 308 | configs = append(configs, config) 309 | } 310 | } else { 311 | // Skip over unsupported versions 312 | } 313 | 314 | offset += 4 + configLength 315 | if offset >= 2+length { 316 | // Stop reading 317 | break 318 | } 319 | } 320 | 321 | return CreateObliviousDoHConfigs(configs), nil 322 | } 323 | 324 | type ObliviousDoHKeyPair struct { 325 | Config ObliviousDoHConfig 326 | secretKey hpke.KEMPrivateKey 327 | Seed []byte 328 | } 329 | 330 | func CreateKeyPairFromSeed(kemID hpke.KEMID, kdfID hpke.KDFID, aeadID hpke.AEADID, ikm []byte) (ObliviousDoHKeyPair, error) { 331 | suite, err := hpke.AssembleCipherSuite(kemID, kdfID, aeadID) 332 | if err != nil { 333 | return ObliviousDoHKeyPair{}, err 334 | } 335 | 336 | sk, pk, err := suite.KEM.DeriveKeyPair(ikm) 337 | if err != nil { 338 | return ObliviousDoHKeyPair{}, err 339 | } 340 | 341 | configContents, err := CreateObliviousDoHConfigContents(kemID, kdfID, aeadID, suite.KEM.Serialize(pk)) 342 | if err != nil { 343 | return ObliviousDoHKeyPair{}, err 344 | } 345 | 346 | config := CreateObliviousDoHConfig(configContents) 347 | 348 | return ObliviousDoHKeyPair{ 349 | Config: config, 350 | secretKey: sk, 351 | Seed: ikm, 352 | }, nil 353 | } 354 | 355 | func CreateDefaultKeyPairFromSeed(seed []byte) (ObliviousDoHKeyPair, error) { 356 | return CreateKeyPairFromSeed(ODOH_DEFAULT_KEMID, ODOH_DEFAULT_KDFID, ODOH_DEFAULT_AEADID, seed) 357 | } 358 | 359 | func CreateKeyPair(kemID hpke.KEMID, kdfID hpke.KDFID, aeadID hpke.AEADID) (ObliviousDoHKeyPair, error) { 360 | suite, err := hpke.AssembleCipherSuite(kemID, kdfID, aeadID) 361 | if err != nil { 362 | return ObliviousDoHKeyPair{}, err 363 | } 364 | 365 | ikm := make([]byte, suite.KEM.PrivateKeySize()) 366 | rand.Reader.Read(ikm) 367 | sk, pk, err := suite.KEM.DeriveKeyPair(ikm) 368 | if err != nil { 369 | return ObliviousDoHKeyPair{}, err 370 | } 371 | 372 | configContents, err := CreateObliviousDoHConfigContents(kemID, kdfID, aeadID, suite.KEM.Serialize(pk)) 373 | if err != nil { 374 | return ObliviousDoHKeyPair{}, err 375 | } 376 | 377 | config := CreateObliviousDoHConfig(configContents) 378 | 379 | return ObliviousDoHKeyPair{ 380 | Config: config, 381 | secretKey: sk, 382 | Seed: ikm, 383 | }, nil 384 | } 385 | 386 | func CreateDefaultKeyPair() (ObliviousDoHKeyPair, error) { 387 | return CreateKeyPair(ODOH_DEFAULT_KEMID, ODOH_DEFAULT_KDFID, ODOH_DEFAULT_AEADID) 388 | } 389 | 390 | type QueryContext struct { 391 | odohSecret []byte 392 | suite hpke.CipherSuite 393 | query []byte 394 | publicKey ObliviousDoHConfigContents 395 | } 396 | 397 | func (c QueryContext) DecryptResponse(message ObliviousDNSMessage) ([]byte, error) { 398 | aad := append([]byte{byte(ResponseType)}, []byte{0x00, 0x00}...) // 0-length encoded KeyID 399 | 400 | odohPRK := c.suite.KDF.Extract(c.query, c.odohSecret) 401 | key := c.suite.KDF.Expand(odohPRK, []byte(ODOH_LABEL_KEY), c.suite.AEAD.KeySize()) 402 | nonce := c.suite.KDF.Expand(odohPRK, []byte(ODOH_LABEL_NONCE), c.suite.AEAD.NonceSize()) 403 | 404 | aead, err := c.suite.AEAD.New(key) 405 | if err != nil { 406 | return nil, err 407 | } 408 | 409 | return aead.Open(nil, nonce, message.EncryptedMessage, aad) 410 | } 411 | 412 | type ResponseContext struct { 413 | query []byte 414 | suite hpke.CipherSuite 415 | odohSecret []byte 416 | } 417 | 418 | func (c ResponseContext) EncryptResponse(response *ObliviousDNSResponse) (ObliviousDNSMessage, error) { 419 | aad := append([]byte{byte(ResponseType)}, []byte{0x00, 0x00}...) // 0-length encoded KeyID 420 | 421 | odohPRK := c.suite.KDF.Extract(c.query, c.odohSecret) 422 | key := c.suite.KDF.Expand(odohPRK, []byte(ODOH_LABEL_KEY), c.suite.AEAD.KeySize()) 423 | nonce := c.suite.KDF.Expand(odohPRK, []byte(ODOH_LABEL_NONCE), c.suite.AEAD.NonceSize()) 424 | 425 | aead, err := c.suite.AEAD.New(key) 426 | if err != nil { 427 | return ObliviousDNSMessage{}, err 428 | } 429 | 430 | ciphertext := aead.Seal(nil, nonce, response.Marshal(), aad) 431 | 432 | odohMessage := ObliviousDNSMessage{ 433 | KeyID: nil, 434 | MessageType: ResponseType, 435 | EncryptedMessage: ciphertext, 436 | } 437 | 438 | return odohMessage, nil 439 | } 440 | 441 | func (targetKey ObliviousDoHConfigContents) EncryptQuery(query *ObliviousDNSQuery) (ObliviousDNSMessage, QueryContext, error) { 442 | suite, err := hpke.AssembleCipherSuite(targetKey.KemID, targetKey.KdfID, targetKey.AeadID) 443 | if err != nil { 444 | return ObliviousDNSMessage{}, QueryContext{}, err 445 | } 446 | 447 | pkR, err := suite.KEM.Deserialize(targetKey.PublicKeyBytes) 448 | if err != nil { 449 | return ObliviousDNSMessage{}, QueryContext{}, err 450 | } 451 | 452 | enc, ctxI, err := hpke.SetupBaseS(suite, rand.Reader, pkR, []byte(ODOH_LABEL_QUERY)) 453 | if err != nil { 454 | return ObliviousDNSMessage{}, QueryContext{}, err 455 | } 456 | 457 | keyID := targetKey.KeyID() 458 | keyIDLength := make([]byte, 2) 459 | binary.BigEndian.PutUint16(keyIDLength, uint16(len(keyID))) 460 | aad := append([]byte{byte(QueryType)}, keyIDLength...) 461 | aad = append(aad, keyID...) 462 | 463 | encodedMessage := query.Marshal() 464 | ct := ctxI.Seal(aad, encodedMessage) 465 | odohSecret := ctxI.Export([]byte(ODOH_LABEL_SECRET), ODOH_SECRET_LENGTH) 466 | 467 | return ObliviousDNSMessage{ 468 | KeyID: targetKey.KeyID(), 469 | MessageType: QueryType, 470 | EncryptedMessage: append(enc, ct...), 471 | }, QueryContext{ 472 | odohSecret: odohSecret, 473 | suite: suite, 474 | query: query.Marshal(), 475 | publicKey: targetKey, 476 | }, nil 477 | } 478 | 479 | func validateMessagePadding(padding []byte) bool { 480 | validPadding := 1 481 | for _, v := range padding { 482 | validPadding &= subtle.ConstantTimeByteEq(v, ODOH_PADDING_BYTE) 483 | } 484 | return validPadding == 1 485 | } 486 | 487 | func (privateKey ObliviousDoHKeyPair) DecryptQuery(message ObliviousDNSMessage) (*ObliviousDNSQuery, ResponseContext, error) { 488 | if message.MessageType != QueryType { 489 | return nil, ResponseContext{}, errors.New("message is not a query") 490 | } 491 | 492 | suite, err := hpke.AssembleCipherSuite(privateKey.Config.Contents.KemID, privateKey.Config.Contents.KdfID, privateKey.Config.Contents.AeadID) 493 | if err != nil { 494 | return nil, ResponseContext{}, err 495 | } 496 | 497 | keySize := suite.KEM.PublicKeySize() 498 | enc := message.EncryptedMessage[0:keySize] 499 | ct := message.EncryptedMessage[keySize:] 500 | 501 | ctxR, err := hpke.SetupBaseR(suite, privateKey.secretKey, enc, []byte(ODOH_LABEL_QUERY)) 502 | if err != nil { 503 | return nil, ResponseContext{}, err 504 | } 505 | 506 | odohSecret := ctxR.Export([]byte(ODOH_LABEL_SECRET), ODOH_SECRET_LENGTH) 507 | 508 | keyID := privateKey.Config.Contents.KeyID() 509 | keyIDLength := make([]byte, 2) 510 | binary.BigEndian.PutUint16(keyIDLength, uint16(len(keyID))) 511 | aad := append([]byte{byte(QueryType)}, keyIDLength...) 512 | aad = append(aad, keyID...) 513 | 514 | dnsMessage, err := ctxR.Open(aad, ct) 515 | if err != nil { 516 | return nil, ResponseContext{}, err 517 | } 518 | 519 | query, err := UnmarshalQueryBody(dnsMessage) 520 | if err != nil { 521 | return nil, ResponseContext{}, err 522 | } 523 | 524 | if !validateMessagePadding(query.Padding) { 525 | return nil, ResponseContext{}, errors.New("invalid padding") 526 | } 527 | 528 | responseContext := ResponseContext{ 529 | odohSecret: odohSecret, 530 | suite: suite, 531 | query: query.Marshal(), 532 | } 533 | 534 | return query, responseContext, nil 535 | } 536 | 537 | func SealQuery(dnsQuery []byte, publicKey ObliviousDoHConfigContents) (ObliviousDNSMessage, QueryContext, error) { 538 | odohQuery := CreateObliviousDNSQuery(dnsQuery, 0) 539 | 540 | odohMessage, queryContext, err := publicKey.EncryptQuery(odohQuery) 541 | if err != nil { 542 | return ObliviousDNSMessage{}, QueryContext{}, err 543 | } 544 | 545 | return odohMessage, queryContext, nil 546 | } 547 | 548 | func (c QueryContext) OpenAnswer(message ObliviousDNSMessage) ([]byte, error) { 549 | if message.MessageType != ResponseType { 550 | return nil, errors.New("message is not a response") 551 | } 552 | 553 | decryptedResponseBytes, err := c.DecryptResponse(message) 554 | if err != nil { 555 | return nil, errors.New("unable to decrypt the obtained response using the symmetric key sent") 556 | } 557 | 558 | decryptedResponse, err := UnmarshalResponseBody(decryptedResponseBytes) 559 | if err != nil { 560 | return nil, err 561 | } 562 | 563 | return decryptedResponse.DnsMessage, nil 564 | } 565 | -------------------------------------------------------------------------------- /odoh_test.go: -------------------------------------------------------------------------------- 1 | // The MIT License 2 | // 3 | // Copyright (c) 2019 Apple, Inc. 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in 13 | // all copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | // THE SOFTWARE. 22 | 23 | package odoh 24 | 25 | import ( 26 | "bytes" 27 | "crypto/rand" 28 | "encoding/binary" 29 | "encoding/hex" 30 | "encoding/json" 31 | "fmt" 32 | "io" 33 | "io/ioutil" 34 | "os" 35 | "testing" 36 | 37 | "github.com/cisco/go-hpke" 38 | ) 39 | 40 | const ( 41 | outputTestVectorEnvironmentKey = "ODOH_TEST_VECTORS_OUT" 42 | inputTestVectorEnvironmentKey = "ODOH_TEST_VECTORS_IN" 43 | ) 44 | 45 | func TestConfigDerivation(t *testing.T) { 46 | keyPair, err := CreateDefaultKeyPair() 47 | if err != nil { 48 | t.Fatalf("CreateDefaultKeyPair failed") 49 | } 50 | 51 | derivedKeyPair, err := CreateDefaultKeyPairFromSeed(keyPair.Seed) 52 | if err != nil { 53 | t.Fatalf("CreateDefaultKeyPair failed") 54 | } 55 | 56 | if keyPair.Config.Version != derivedKeyPair.Config.Version { 57 | t.Fatalf("Mismatched versions.") 58 | } 59 | if !bytes.Equal(keyPair.Config.Marshal(), derivedKeyPair.Config.Marshal()) { 60 | t.Fatalf("Mismatched configs.") 61 | } 62 | if !bytes.Equal(keyPair.Seed, derivedKeyPair.Seed) { 63 | t.Fatalf("Mismatched seeds.") 64 | } 65 | } 66 | 67 | func TestConfigDeserialization(t *testing.T) { 68 | keyPair, err := CreateDefaultKeyPair() 69 | if err != nil { 70 | t.Fatalf("CreateDefaultKeyPair failed") 71 | } 72 | 73 | serializedConfig := keyPair.Config.Marshal() 74 | 75 | recoveredConfig, err := UnmarshalObliviousDoHConfig(serializedConfig) 76 | if err != nil { 77 | t.Fatalf("Failed deserializing config") 78 | } 79 | if recoveredConfig.Version != keyPair.Config.Version { 80 | t.Fatalf("Mismatched versions.") 81 | } 82 | if !bytes.Equal(keyPair.Config.Marshal(), recoveredConfig.Marshal()) { 83 | t.Fatalf("Mismatched configs.") 84 | } 85 | 86 | serializedConfig = append(serializedConfig, 0x00) // append an extra byte 87 | 88 | recoveredConfig, err = UnmarshalObliviousDoHConfig(serializedConfig) 89 | if err != nil { 90 | t.Fatalf("Failed deserializing config") 91 | } 92 | if recoveredConfig.Version != keyPair.Config.Version { 93 | t.Fatalf("Mismatched versions.") 94 | } 95 | if !bytes.Equal(keyPair.Config.Marshal(), recoveredConfig.Marshal()) { 96 | t.Fatalf("Mismatched configs.") 97 | } 98 | } 99 | 100 | func TestConfigDeserializationFailures(t *testing.T) { 101 | keyPair, err := CreateDefaultKeyPair() 102 | if err != nil { 103 | t.Fatalf("CreateDefaultKeyPair failed") 104 | } 105 | 106 | serializedConfig := keyPair.Config.Marshal() 107 | 108 | // Encoding without full version or length 109 | _, err = UnmarshalObliviousDoHConfig(serializedConfig[0:1]) 110 | if err == nil { 111 | t.Fatalf("Failed to deserialize with insufficient length") 112 | } 113 | 114 | // Encoding with a mismatched version 115 | invalidSerializedConfig := serializedConfig[:] 116 | invalidSerializedConfig[0] = invalidSerializedConfig[0] ^ 0xFF 117 | _, err = UnmarshalObliviousDoHConfig(invalidSerializedConfig) 118 | if err == nil { 119 | t.Fatalf("Failed to deserialize with invalid version") 120 | } 121 | 122 | // Encoding with an invalid length 123 | invalidSerializedConfig = serializedConfig[:] 124 | _, err = UnmarshalObliviousDoHConfig(invalidSerializedConfig[0:4]) 125 | if err == nil { 126 | t.Fatalf("Failed to deserialize with insufficient contents length") 127 | } 128 | } 129 | 130 | func mustCreateDefaultKeyPair(t *testing.T) ObliviousDoHKeyPair { 131 | keyPair, err := CreateDefaultKeyPair() 132 | if err != nil { 133 | t.Fatalf("CreateDefaultKeyPair failed") 134 | } 135 | return keyPair 136 | } 137 | 138 | func copySlice(src []byte) []byte { 139 | copied := make([]byte, len(src)) 140 | copy(copied, src) 141 | return copied 142 | } 143 | 144 | func TestConfigsDeserialization(t *testing.T) { 145 | keyPairA := mustCreateDefaultKeyPair(t) 146 | keyPairB := mustCreateDefaultKeyPair(t) 147 | 148 | configSet := []ObliviousDoHConfig{keyPairA.Config, keyPairB.Config} 149 | configs := CreateObliviousDoHConfigs(configSet) 150 | 151 | serializedConfigA := keyPairA.Config.Marshal() 152 | serializedConfigB := keyPairB.Config.Marshal() 153 | serializedConfigs := configs.Marshal() 154 | 155 | if len(serializedConfigs) != 2+len(serializedConfigA)+len(serializedConfigB) { 156 | t.Fatalf("Invalid serialized length. Expected %v, got %v", 2+len(serializedConfigA)+len(serializedConfigB), len(serializedConfigs)) 157 | } 158 | 159 | _, err := UnmarshalObliviousDoHConfigs(serializedConfigs) 160 | if err != nil { 161 | t.Fatalf("UnmarshalObliviousDoHConfigs failed: %v", err) 162 | } 163 | 164 | longSerializedConfigs := append(serializedConfigs, 0x00) 165 | _, err = UnmarshalObliviousDoHConfigs(longSerializedConfigs) 166 | if err != nil { 167 | t.Fatalf("UnmarshalObliviousDoHConfigs failed: %v", err) 168 | } 169 | 170 | invalidSerializedConfigs := copySlice(serializedConfigs) 171 | deserializedConfigs, err := UnmarshalObliviousDoHConfigs(invalidSerializedConfigs[0 : len(invalidSerializedConfigs)-1]) 172 | if err != nil { 173 | t.Fatalf("UnmarshalObliviousDoHConfigs failed to parse first config") 174 | } 175 | if len(deserializedConfigs.Configs) != 1 { 176 | t.Fatalf("UnmarshalObliviousDoHConfigs parsed more than one ObliviousDoHConfig elements, got %v", len(deserializedConfigs.Configs)) 177 | } 178 | 179 | invalidSerializedConfigs = copySlice(serializedConfigs) 180 | invalidSerializedConfigs[0] = 0xFF // Invalidate the outer vector length 181 | _, err = UnmarshalObliviousDoHConfigs(invalidSerializedConfigs) 182 | if err == nil { 183 | t.Fatalf("UnmarshalObliviousDoHConfigs succeeded without enough bytes") 184 | } 185 | 186 | invalidSerializedConfigs = copySlice(serializedConfigs) 187 | invalidSerializedConfigs[2] ^= 0xFF // Flip the version value 188 | deserializedConfigs, err = UnmarshalObliviousDoHConfigs(invalidSerializedConfigs) 189 | if err != nil { 190 | t.Fatalf("UnmarshalObliviousDoHConfigs failed to parse one of the valid configs: %v", err) 191 | } 192 | if len(deserializedConfigs.Configs) != 1 { 193 | t.Fatalf("UnmarshalObliviousDoHConfigs parsed more than one ObliviousDoHConfig elements, got %v", len(deserializedConfigs.Configs)) 194 | } 195 | 196 | invalidSerializedConfigs = copySlice(serializedConfigs) 197 | invalidSerializedConfigs[4] = 0xFF // Extend the length of the first config 198 | _, err = UnmarshalObliviousDoHConfigs(invalidSerializedConfigs) 199 | if err == nil { 200 | t.Fatalf("UnmarshalObliviousDoHConfigs succeeded without enough bytes") 201 | } 202 | } 203 | 204 | func createDefaultSerializedPublicKey(t *testing.T) []byte { 205 | suite, err := hpke.AssembleCipherSuite(ODOH_DEFAULT_KEMID, ODOH_DEFAULT_KDFID, ODOH_DEFAULT_AEADID) 206 | if err != nil { 207 | t.Fatalf("Failed generating HPKE suite") 208 | } 209 | 210 | ikm := make([]byte, suite.KEM.PrivateKeySize()) 211 | _, _ = rand.Read(ikm) 212 | _, publicKey, err := suite.KEM.DeriveKeyPair(ikm) 213 | if err != nil { 214 | t.Fatalf("Failed generating public key") 215 | } 216 | 217 | return suite.KEM.Serialize(publicKey) 218 | } 219 | 220 | func validateSerializedContents(t *testing.T, configContents ObliviousDoHConfigContents, serializedContents []byte) { 221 | kemId := binary.BigEndian.Uint16(serializedContents[0:]) 222 | kdfId := binary.BigEndian.Uint16(serializedContents[2:]) 223 | aeadId := binary.BigEndian.Uint16(serializedContents[4:]) 224 | publicKeyLength := int(binary.BigEndian.Uint16(serializedContents[6:])) 225 | 226 | if kemId != uint16(ODOH_DEFAULT_KEMID) { 227 | t.Fatalf("Invalid serialized KEMID. Expected %v, got %v.", ODOH_DEFAULT_KEMID, kemId) 228 | } 229 | if kdfId != uint16(ODOH_DEFAULT_KDFID) { 230 | t.Fatalf("Invalid serialized KDFID. Expected %v, got %v.", ODOH_DEFAULT_KDFID, kdfId) 231 | } 232 | if aeadId != uint16(ODOH_DEFAULT_AEADID) { 233 | t.Fatalf("Invalid serialized AEADID. Expected %v, got %v.", ODOH_DEFAULT_AEADID, aeadId) 234 | } 235 | if publicKeyLength != len(configContents.PublicKeyBytes) { 236 | t.Fatalf("Invalid serialized public key length. Expected %v, got %v.", len(configContents.PublicKeyBytes), publicKeyLength) 237 | } 238 | if !bytes.Equal(configContents.PublicKeyBytes, serializedContents[8:8+publicKeyLength]) { 239 | t.Fatalf("Invalid bytes serialized. Expected %x, got %x", configContents.PublicKeyBytes, serializedContents[8:8+publicKeyLength]) 240 | } 241 | } 242 | 243 | func TestConfigContentsSerialization(t *testing.T) { 244 | publicKeyBytes := createDefaultSerializedPublicKey(t) 245 | configContents, err := CreateObliviousDoHConfigContents(ODOH_DEFAULT_KEMID, ODOH_DEFAULT_KDFID, ODOH_DEFAULT_AEADID, publicKeyBytes) 246 | if err != nil { 247 | t.Fatalf("CreateObliviousDoHConfigContents failed: %v", err) 248 | } 249 | 250 | serializedContents := configContents.Marshal() 251 | if len(serializedContents) != 8+len(publicKeyBytes) { 252 | t.Fatalf("Invalid length of serialized ObliviousDoHConfigContents. Expected %v, got %v.", 8+len(publicKeyBytes), len(serializedContents)) 253 | } 254 | 255 | validateSerializedContents(t, configContents, serializedContents) 256 | 257 | serializedContents = append(serializedContents, 0x00) 258 | validateSerializedContents(t, configContents, serializedContents) 259 | } 260 | 261 | func TestConfigContentsDeserialization(t *testing.T) { 262 | publicKeyBytes := createDefaultSerializedPublicKey(t) 263 | configContents, err := CreateObliviousDoHConfigContents(ODOH_DEFAULT_KEMID, ODOH_DEFAULT_KDFID, ODOH_DEFAULT_AEADID, publicKeyBytes) 264 | if err != nil { 265 | t.Fatalf("CreateObliviousDoHConfigContents failed: %v", err) 266 | } 267 | 268 | serializedContents := configContents.Marshal() 269 | 270 | _, err = UnmarshalObliviousDoHConfigContents(serializedContents[0:7]) 271 | if err == nil { 272 | t.Fatalf("Failed to deserialize with insufficient length") 273 | } 274 | 275 | _, err = UnmarshalObliviousDoHConfigContents(serializedContents[0:8]) 276 | if err == nil { 277 | t.Fatalf("Failed to deserialize with insufficient public key length") 278 | } 279 | 280 | invalidSerializedConfig := serializedContents[:] 281 | invalidSerializedConfig[0] = invalidSerializedConfig[0] ^ 0xFF 282 | _, err = UnmarshalObliviousDoHConfigContents(invalidSerializedConfig) 283 | if err == nil { 284 | t.Fatalf("Failed to deserialize with invalid KEMID") 285 | } 286 | 287 | invalidSerializedConfig = serializedContents[:] 288 | invalidSerializedConfig[2] = invalidSerializedConfig[2] ^ 0xFF 289 | _, err = UnmarshalObliviousDoHConfigContents(invalidSerializedConfig) 290 | if err == nil { 291 | t.Fatalf("Failed to deserialize with invalid KDFID") 292 | } 293 | 294 | invalidSerializedConfig = serializedContents[:] 295 | invalidSerializedConfig[4] = invalidSerializedConfig[4] ^ 0xFF 296 | _, err = UnmarshalObliviousDoHConfigContents(invalidSerializedConfig) 297 | if err == nil { 298 | t.Fatalf("Failed to deserialize with invalid AEADID") 299 | } 300 | 301 | invalidSerializedConfig = serializedContents[:] 302 | invalidSerializedConfig[7] = 0x1F // instead of 0x20, for x25519 public keys 303 | _, err = UnmarshalObliviousDoHConfigContents(invalidSerializedConfig[0 : len(invalidSerializedConfig)-1]) 304 | if err == nil { 305 | t.Fatalf("Failed to deserialize with invalid x25519 public key") 306 | } 307 | } 308 | 309 | func TestQueryBodyMarshal(t *testing.T) { 310 | message := []byte{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F} 311 | 312 | queryBody := CreateObliviousDNSQuery(message, 0) 313 | 314 | encoded := queryBody.Marshal() 315 | decoded, err := UnmarshalQueryBody(encoded) 316 | if err != nil { 317 | t.Fatalf("Encode/decode failed") 318 | } 319 | if !bytes.Equal(decoded.DnsMessage, message) { 320 | t.Fatalf("Key mismatch") 321 | } 322 | } 323 | 324 | func TestDNSMessageMarshal(t *testing.T) { 325 | keyID := []byte{0x00, 0x01, 0x02, 0x04} 326 | encryptedMessage := []byte{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F} 327 | 328 | message := ObliviousDNSMessage{ 329 | MessageType: 0x01, 330 | KeyID: keyID, 331 | EncryptedMessage: encryptedMessage, 332 | } 333 | 334 | encoded := message.Marshal() 335 | decoded, err := UnmarshalDNSMessage(encoded) 336 | if err != nil { 337 | t.Fatalf("Encode/decode failed") 338 | } 339 | if decoded.MessageType != 0x01 { 340 | t.Fatalf("MessageType mismatch") 341 | } 342 | if !bytes.Equal(decoded.KeyID, keyID) { 343 | t.Fatalf("KeyID mismatch") 344 | } 345 | if !bytes.Equal(decoded.EncryptedMessage, encryptedMessage) { 346 | t.Fatalf("EncryptedMessage mismatch") 347 | } 348 | } 349 | 350 | func TestQueryEncryption(t *testing.T) { 351 | kemID := hpke.DHKEM_X25519 352 | kdfID := hpke.KDF_HKDF_SHA256 353 | aeadID := hpke.AEAD_AESGCM128 354 | 355 | suite, err := hpke.AssembleCipherSuite(kemID, kdfID, aeadID) 356 | if err != nil { 357 | t.Fatalf("[%x, %x, %x] Error looking up ciphersuite: %s", kemID, kdfID, aeadID, err) 358 | } 359 | 360 | ikm := make([]byte, suite.KEM.PrivateKeySize()) 361 | rand.Reader.Read(ikm) 362 | skR, pkR, err := suite.KEM.DeriveKeyPair(ikm) 363 | if err != nil { 364 | t.Fatalf("[%x, %x, %x] Error generating DH key pair: %s", kemID, kdfID, aeadID, err) 365 | } 366 | 367 | targetKey := ObliviousDoHConfigContents{ 368 | KemID: kemID, 369 | KdfID: kdfID, 370 | AeadID: aeadID, 371 | PublicKeyBytes: suite.KEM.Serialize(pkR), 372 | } 373 | 374 | targetConfig := ObliviousDoHConfig{ 375 | Contents: targetKey, 376 | } 377 | 378 | odohKeyPair := ObliviousDoHKeyPair{targetConfig, skR, ikm} 379 | 380 | dnsMessage := []byte{0x01, 0x02} 381 | 382 | message := CreateObliviousDNSQuery(dnsMessage, 0) 383 | 384 | encryptedMessage, _, err := targetKey.EncryptQuery(message) 385 | if err != nil { 386 | t.Fatalf("EncryptQuery failed: %s", err) 387 | } 388 | 389 | result, _, err := odohKeyPair.DecryptQuery(encryptedMessage) 390 | if err != nil { 391 | t.Fatalf("DecryptQuery failed: %s", err) 392 | } 393 | 394 | if !bytes.Equal(result.DnsMessage, dnsMessage) { 395 | t.Fatalf("Incorrect DnsMessage returned") 396 | } 397 | } 398 | 399 | func Test_Sender_ODOHQueryEncryption(t *testing.T) { 400 | kemID := hpke.DHKEM_P256 // 0x0010 401 | kdfID := hpke.KDF_HKDF_SHA256 // 0x0001 402 | aeadID := hpke.AEAD_AESGCM128 // 0x0001 403 | 404 | suite, err := hpke.AssembleCipherSuite(kemID, kdfID, aeadID) 405 | if err != nil { 406 | t.Fatalf("[%x, %x, %x] Error looking up ciphersuite: %s", kemID, kdfID, aeadID, err) 407 | } 408 | 409 | responseKey := make([]byte, suite.AEAD.KeySize()) 410 | if _, err := io.ReadFull(rand.Reader, responseKey); err != nil { 411 | t.Fatalf("Failed generating random key: %s", err) 412 | } 413 | 414 | ikm := make([]byte, suite.KEM.PrivateKeySize()) 415 | rand.Reader.Read(ikm) 416 | 417 | skR, pkR, err := suite.KEM.DeriveKeyPair(ikm) 418 | if err != nil { 419 | t.Fatalf("[%x, %x, %x] Error generating DH key pair: %s", kemID, kdfID, aeadID, err) 420 | } 421 | 422 | targetKey := ObliviousDoHConfigContents{ 423 | KemID: kemID, 424 | KdfID: kdfID, 425 | AeadID: aeadID, 426 | PublicKeyBytes: suite.KEM.Serialize(pkR), 427 | } 428 | 429 | targetConfig := ObliviousDoHConfig{ 430 | Contents: targetKey, 431 | } 432 | 433 | odohKeyPair := ObliviousDoHKeyPair{targetConfig, skR, ikm} 434 | symmetricKey := make([]byte, suite.AEAD.KeySize()) 435 | rand.Read(symmetricKey) 436 | 437 | dnsMessage := []byte{0x01, 0x02, 0x03} 438 | message := CreateObliviousDNSQuery(dnsMessage, 0) 439 | 440 | encryptedMessage, _, err := targetKey.EncryptQuery(message) 441 | if err != nil { 442 | t.Fatalf("Failed to encrypt the message using the public key.") 443 | } 444 | 445 | dnsQuery, _, err := odohKeyPair.DecryptQuery(encryptedMessage) 446 | if err != nil { 447 | t.Fatalf("Failed to decrypt message with error: %s", err) 448 | } 449 | 450 | if !bytes.Equal(dnsQuery.DnsMessage, dnsMessage) { 451 | t.Fatalf("Incorrect dnsMessage returned") 452 | } 453 | } 454 | 455 | func TestEncoding(t *testing.T) { 456 | emptySlice := make([]byte, 0) 457 | if !bytes.Equal([]byte{0x00, 0x00}, encodeLengthPrefixedSlice(emptySlice)) { 458 | t.Fatalf("encodeLengthPrefixedSlice for empty slice failed") 459 | } 460 | } 461 | 462 | func TestOdohPublicKeyMarshalUnmarshal(t *testing.T) { 463 | kemID := hpke.DHKEM_P256 // 0x0010 464 | kdfID := hpke.KDF_HKDF_SHA256 // 0x0001 465 | aeadID := hpke.AEAD_AESGCM128 // 0x0001 466 | 467 | suite, err := hpke.AssembleCipherSuite(kemID, kdfID, aeadID) 468 | if err != nil { 469 | t.Fatalf("[%x, %x, %x] Error looking up ciphersuite: %s", kemID, kdfID, aeadID, err) 470 | } 471 | 472 | responseKey := make([]byte, suite.AEAD.KeySize()) 473 | if _, err := io.ReadFull(rand.Reader, responseKey); err != nil { 474 | t.Fatalf("Failed generating random key: %s", err) 475 | } 476 | 477 | ikm := make([]byte, suite.KEM.PrivateKeySize()) 478 | rand.Reader.Read(ikm) 479 | 480 | _, pkR, err := suite.KEM.DeriveKeyPair(ikm) 481 | if err != nil { 482 | t.Fatalf("[%x, %x, %x] Error generating DH key pair: %s", kemID, kdfID, aeadID, err) 483 | } 484 | 485 | targetKey := ObliviousDoHConfigContents{ 486 | KemID: kemID, 487 | KdfID: kdfID, 488 | AeadID: aeadID, 489 | PublicKeyBytes: suite.KEM.Serialize(pkR), 490 | } 491 | 492 | serializedPublicKey := targetKey.Marshal() 493 | deserializedPublicKey, err := UnmarshalObliviousDoHConfigContents(serializedPublicKey) 494 | if err != nil { 495 | t.Fatalf("UnmarshalObliviousDoHConfigContents failed: %v", err) 496 | } 497 | 498 | if !bytes.Equal(deserializedPublicKey.PublicKeyBytes, targetKey.PublicKeyBytes) { 499 | t.Fatalf("The deserialized and serialized bytes do not match.") 500 | } 501 | 502 | if deserializedPublicKey.KemID != targetKey.KemID { 503 | t.Fatalf("The KEM IDs do not match.") 504 | } 505 | 506 | if deserializedPublicKey.KdfID != targetKey.KdfID { 507 | t.Fatalf("The KDF IDs do not match.") 508 | } 509 | 510 | if deserializedPublicKey.AeadID != targetKey.AeadID { 511 | t.Fatalf("The AEAD IDs do not match.") 512 | } 513 | } 514 | 515 | func TestFixedOdohKeyPairCreation(t *testing.T) { 516 | const ( 517 | kemID = hpke.DHKEM_X25519 518 | kdfID = hpke.KDF_HKDF_SHA256 519 | aeadID = hpke.AEAD_AESGCM128 520 | ) 521 | 522 | // Fixed 16 byte seed 523 | seedHex := "f7c664a7959b2aa02ffa7abb0d2022ab" 524 | seed, err := hex.DecodeString(seedHex) 525 | if err != nil { 526 | t.Fatalf("Unable to decode seed to bytes") 527 | } 528 | keyPair, err := CreateKeyPairFromSeed(kemID, kdfID, aeadID, seed) 529 | if err != nil { 530 | t.Fatalf("Unable to derive a ObliviousDoHKeyPair") 531 | } 532 | for i := 0; i < 10; i++ { 533 | keyPairDerived, err := CreateKeyPairFromSeed(kemID, kdfID, aeadID, seed) 534 | if err != nil { 535 | t.Fatalf("Unable to derive a ObliviousDoHKeyPair") 536 | } 537 | if !bytes.Equal(keyPairDerived.Config.Contents.Marshal(), keyPair.Config.Contents.Marshal()) { 538 | t.Fatalf("Public Key Derived does not match") 539 | } 540 | } 541 | } 542 | 543 | func TestSealQueryAndOpenAnswer(t *testing.T) { 544 | kemID := hpke.DHKEM_X25519 545 | kdfID := hpke.KDF_HKDF_SHA256 546 | aeadID := hpke.AEAD_AESGCM128 547 | 548 | kp, err := CreateKeyPair(kemID, kdfID, aeadID) 549 | if err != nil { 550 | t.Fatalf("Unable to create a Key Pair") 551 | } 552 | 553 | dnsQueryData := make([]byte, 40) 554 | _, err = rand.Read(dnsQueryData) 555 | 556 | encryptedData, queryContext, err := SealQuery(dnsQueryData, kp.Config.Contents) 557 | 558 | mockAnswerData := make([]byte, 100) 559 | _, err = rand.Read(mockAnswerData) 560 | 561 | _, responseContext, err := kp.DecryptQuery(encryptedData) 562 | 563 | mockResponse := CreateObliviousDNSResponse(mockAnswerData, 0) 564 | encryptedAnswer, err := responseContext.EncryptResponse(mockResponse) 565 | 566 | response, err := queryContext.OpenAnswer(encryptedAnswer) 567 | 568 | if !bytes.Equal(response, mockAnswerData) { 569 | t.Fatalf("Decryption of the result does not match encrypted value") 570 | } 571 | } 572 | 573 | /////// 574 | // Assertions 575 | func assert(t *testing.T, msg string, test bool) { 576 | if !test { 577 | t.Fatalf("%s", msg) 578 | } 579 | } 580 | 581 | func assertBytesEqual(t *testing.T, msg string, lhs, rhs []byte) { 582 | realMsg := fmt.Sprintf("%s: [%x] != [%x]", msg, lhs, rhs) 583 | assert(t, realMsg, bytes.Equal(lhs, rhs)) 584 | } 585 | 586 | func assertNotError(t *testing.T, msg string, err error) { 587 | realMsg := fmt.Sprintf("%s: %v", msg, err) 588 | assert(t, realMsg, err == nil) 589 | } 590 | 591 | func fatalOnError(t *testing.T, err error, msg string) { 592 | realMsg := fmt.Sprintf("%s: %v", msg, err) 593 | if err != nil { 594 | if t != nil { 595 | t.Fatalf(realMsg) 596 | } else { 597 | panic(realMsg) 598 | } 599 | } 600 | } 601 | 602 | func mustUnhex(t *testing.T, h string) []byte { 603 | out, err := hex.DecodeString(h) 604 | fatalOnError(t, err, "Unhex failed") 605 | return out 606 | } 607 | 608 | func mustHex(d []byte) string { 609 | return hex.EncodeToString(d) 610 | } 611 | 612 | func mustDeserializePub(t *testing.T, suite hpke.CipherSuite, h string, required bool) hpke.KEMPublicKey { 613 | pkm := mustUnhex(t, h) 614 | pk, err := suite.KEM.Deserialize(pkm) 615 | if required { 616 | fatalOnError(t, err, "Deserialize failed") 617 | } 618 | return pk 619 | } 620 | 621 | func mustSerializePub(suite hpke.CipherSuite, pub hpke.KEMPublicKey) string { 622 | return mustHex(suite.KEM.Serialize(pub)) 623 | } 624 | 625 | /////// 626 | // Query/Response transaction test vector structure 627 | type rawTransactionTestVector struct { 628 | Query string `json:"query"` 629 | QueryPaddingLength int `json:"queryPaddingLength"` 630 | Response string `json:"response"` 631 | ResponsePaddingLength int `json:"responsePaddingLength"` 632 | ObliviousQuery string `json:"obliviousQuery"` 633 | ObliviousResponse string `json:"obliviousResponse"` 634 | } 635 | 636 | type transactionTestVector struct { 637 | query []byte 638 | queryPaddingLength uint16 639 | response []byte 640 | responsePaddingLength uint16 641 | obliviousQuery ObliviousDNSMessage 642 | obliviousResponse ObliviousDNSMessage 643 | } 644 | 645 | func (etv transactionTestVector) MarshalJSON() ([]byte, error) { 646 | return json.Marshal(rawTransactionTestVector{ 647 | Query: mustHex(etv.query), 648 | QueryPaddingLength: int(etv.queryPaddingLength), 649 | Response: mustHex(etv.response), 650 | ResponsePaddingLength: int(etv.responsePaddingLength), 651 | ObliviousQuery: mustHex(etv.obliviousQuery.Marshal()), 652 | ObliviousResponse: mustHex(etv.obliviousResponse.Marshal()), 653 | }) 654 | } 655 | 656 | func (etv *transactionTestVector) UnmarshalJSON(data []byte) error { 657 | raw := rawTransactionTestVector{} 658 | err := json.Unmarshal(data, &raw) 659 | if err != nil { 660 | return err 661 | } 662 | 663 | etv.query = mustUnhex(nil, raw.Query) 664 | etv.queryPaddingLength = uint16(raw.QueryPaddingLength) 665 | etv.response = mustUnhex(nil, raw.Response) 666 | etv.responsePaddingLength = uint16(raw.ResponsePaddingLength) 667 | 668 | obliviousQueryBytes := mustUnhex(nil, raw.ObliviousQuery) 669 | obliviousResponseBytes := mustUnhex(nil, raw.ObliviousResponse) 670 | 671 | etv.obliviousQuery, err = UnmarshalDNSMessage(obliviousQueryBytes) 672 | if err != nil { 673 | return err 674 | } 675 | etv.obliviousResponse, err = UnmarshalDNSMessage(obliviousResponseBytes) 676 | if err != nil { 677 | return err 678 | } 679 | 680 | return nil 681 | } 682 | 683 | type rawTestVector struct { 684 | KemID int `json:"kem_id"` 685 | KdfID int `json:"kdf_id"` 686 | AeadID int `json:"aead_id"` 687 | Configs string `json:"odohconfigs"` 688 | PublicKeySeed string `json:"public_key_seed"` 689 | KeyId string `json:"key_id"` 690 | 691 | Transactions []transactionTestVector `json:"transactions"` 692 | } 693 | 694 | type testVector struct { 695 | t *testing.T 696 | kem_id hpke.KEMID 697 | kdf_id hpke.KDFID 698 | aead_id hpke.AEADID 699 | odoh_configs []byte 700 | public_key_seed []byte 701 | key_id []byte 702 | 703 | transactions []transactionTestVector 704 | } 705 | 706 | func (tv testVector) MarshalJSON() ([]byte, error) { 707 | return json.Marshal(rawTestVector{ 708 | KemID: int(tv.kem_id), 709 | KdfID: int(tv.kdf_id), 710 | AeadID: int(tv.aead_id), 711 | Configs: mustHex(tv.odoh_configs), 712 | PublicKeySeed: mustHex(tv.public_key_seed), 713 | KeyId: mustHex(tv.key_id), 714 | Transactions: tv.transactions, 715 | }) 716 | } 717 | 718 | func (tv *testVector) UnmarshalJSON(data []byte) error { 719 | raw := rawTestVector{} 720 | err := json.Unmarshal(data, &raw) 721 | if err != nil { 722 | return err 723 | } 724 | 725 | tv.kem_id = hpke.KEMID(raw.KemID) 726 | tv.kdf_id = hpke.KDFID(raw.KdfID) 727 | tv.aead_id = hpke.AEADID(raw.AeadID) 728 | tv.public_key_seed = mustUnhex(tv.t, raw.PublicKeySeed) 729 | tv.odoh_configs = mustUnhex(tv.t, raw.Configs) 730 | tv.key_id = mustUnhex(tv.t, raw.KeyId) 731 | 732 | tv.transactions = raw.Transactions 733 | return nil 734 | } 735 | 736 | type testVectorArray struct { 737 | t *testing.T 738 | vectors []testVector 739 | } 740 | 741 | func (tva testVectorArray) MarshalJSON() ([]byte, error) { 742 | return json.Marshal(tva.vectors) 743 | } 744 | 745 | func (tva *testVectorArray) UnmarshalJSON(data []byte) error { 746 | err := json.Unmarshal(data, &tva.vectors) 747 | if err != nil { 748 | return err 749 | } 750 | 751 | for i := range tva.vectors { 752 | tva.vectors[i].t = tva.t 753 | } 754 | return nil 755 | } 756 | 757 | func generateRandomData(n int) []byte { 758 | data := make([]byte, n) 759 | _, err := rand.Read(data) 760 | if err != nil { 761 | panic(err) 762 | } 763 | return data 764 | } 765 | 766 | func generateTransaction(t *testing.T, kp ObliviousDoHKeyPair, querySize int, queryPadding, responsePadding uint16) transactionTestVector { 767 | publicKey := kp.Config.Contents 768 | 769 | mockQueryData := generateRandomData(querySize) 770 | mockResponseData := append(mockQueryData, mockQueryData...) // answer = query || query 771 | 772 | mockQuery := CreateObliviousDNSQuery(mockQueryData, queryPadding) 773 | mockResponse := CreateObliviousDNSResponse(mockResponseData, responsePadding) 774 | 775 | // Run the query/response transaction 776 | obliviousQuery, queryContext, err := publicKey.EncryptQuery(mockQuery) 777 | if err != nil { 778 | t.Fatalf("Query encryption failed: %v", err) 779 | } 780 | 781 | recoveredQuery, responseContext, err := kp.DecryptQuery(obliviousQuery) 782 | if !bytes.Equal(recoveredQuery.Marshal(), mockQuery.Marshal()) { 783 | t.Fatalf("Query decryption did not match plaintext value: %v", err) 784 | } 785 | 786 | obliviousResponse, err := responseContext.EncryptResponse(mockResponse) 787 | if err != nil { 788 | t.Fatalf("Response encryption failed: %v", err) 789 | } 790 | 791 | responseData, err := queryContext.OpenAnswer(obliviousResponse) 792 | if err != nil || !bytes.Equal(responseData, mockResponseData) { 793 | t.Fatalf("Decryption of the result does not match encrypted value: %v", err) 794 | } 795 | 796 | return transactionTestVector{ 797 | query: mockQueryData, 798 | queryPaddingLength: queryPadding, 799 | obliviousQuery: obliviousQuery, 800 | response: mockResponseData, 801 | responsePaddingLength: responsePadding, 802 | obliviousResponse: obliviousResponse, 803 | } 804 | } 805 | 806 | func generateTestVector(t *testing.T, kem_id hpke.KEMID, kdf_id hpke.KDFID, aead_id hpke.AEADID) testVector { 807 | kp, err := CreateKeyPair(kem_id, kdf_id, aead_id) 808 | if err != nil { 809 | t.Fatalf("Unable to create a Key Pair") 810 | } 811 | 812 | queryBlockPaddingLengths := []int{0, 32, 64, 128} 813 | responseBlockPaddingLengths := []int{0, 128, 256, 468} 814 | queryLength := 32 815 | 816 | transactions := make([]transactionTestVector, 0) 817 | for _, queryBlockLength := range queryBlockPaddingLengths { 818 | for _, responseBlockLength := range responseBlockPaddingLengths { 819 | queryPadding := 0 820 | if queryBlockLength > 0 { 821 | queryPadding = queryBlockLength - (queryLength % queryBlockLength) 822 | } 823 | responsePadding := 0 824 | if responseBlockLength > 0 { 825 | responsePadding = responseBlockLength - ((queryLength * 2) % responseBlockLength) 826 | } 827 | 828 | transactions = append(transactions, generateTransaction(t, kp, queryLength, uint16(queryPadding), uint16(responsePadding))) 829 | } 830 | } 831 | 832 | configs := []ObliviousDoHConfig{kp.Config} 833 | vector := testVector{ 834 | t: t, 835 | kem_id: kem_id, 836 | kdf_id: kdf_id, 837 | aead_id: aead_id, 838 | odoh_configs: CreateObliviousDoHConfigs(configs).Marshal(), 839 | public_key_seed: kp.Seed, 840 | key_id: kp.Config.Contents.KeyID(), 841 | transactions: transactions, 842 | } 843 | 844 | return vector 845 | } 846 | 847 | func verifyTestVector(t *testing.T, tv testVector) { 848 | configs, err := UnmarshalObliviousDoHConfigs(tv.odoh_configs) 849 | assertNotError(t, "UnmarshalObliviousDoHConfigs failed", err) 850 | 851 | config := configs.Configs[0] 852 | 853 | kp, err := CreateKeyPairFromSeed(config.Contents.KemID, config.Contents.KdfID, config.Contents.AeadID, tv.public_key_seed) 854 | assertNotError(t, "CreateKeyPairFromSeed failed", err) 855 | 856 | expectedKeyId := kp.Config.Contents.KeyID() 857 | assertBytesEqual(t, "KeyID mismatch", expectedKeyId, tv.key_id) 858 | 859 | for _, transaction := range tv.transactions { 860 | query, responseContext, err := kp.DecryptQuery(transaction.obliviousQuery) 861 | assertNotError(t, "Query decryption failed", err) 862 | assertBytesEqual(t, "Query decryption mismatch", query.DnsMessage, transaction.query) 863 | 864 | testResponse := CreateObliviousDNSResponse(transaction.response, transaction.responsePaddingLength) 865 | obliviousResponse, err := responseContext.EncryptResponse(testResponse) 866 | assertNotError(t, "Response encryption failed", err) 867 | assertBytesEqual(t, "Response encryption mismatch", obliviousResponse.Marshal(), transaction.obliviousResponse.Marshal()) 868 | 869 | // Rebuild decryption context, since we don't control the client's ephemeral key 870 | queryContext := QueryContext{ 871 | odohSecret: responseContext.odohSecret, 872 | query: query.Marshal(), 873 | suite: responseContext.suite, 874 | } 875 | response, err := queryContext.OpenAnswer(obliviousResponse) 876 | assertNotError(t, "Response decryption failed", err) 877 | assertBytesEqual(t, "Final response encryption mismatch", response, transaction.response) 878 | } 879 | } 880 | 881 | func vectorTest(vector testVector) func(t *testing.T) { 882 | return func(t *testing.T) { 883 | verifyTestVector(t, vector) 884 | } 885 | } 886 | 887 | func verifyTestVectors(t *testing.T, vectorString []byte, subtest bool) { 888 | vectors := testVectorArray{t: t} 889 | err := json.Unmarshal(vectorString, &vectors) 890 | if err != nil { 891 | t.Fatalf("Error decoding test vector string: %v", err) 892 | } 893 | 894 | for _, tv := range vectors.vectors { 895 | test := vectorTest(tv) 896 | if !subtest { 897 | test(t) 898 | } else { 899 | label := fmt.Sprintf("odohconfigs=%x", tv.odoh_configs) 900 | t.Run(label, test) 901 | } 902 | } 903 | } 904 | 905 | func TestVectorGenerate(t *testing.T) { 906 | // This is the mandatory HPKE ciphersuite 907 | supportedKEMs := []hpke.KEMID{hpke.DHKEM_X25519} 908 | supportedKDFs := []hpke.KDFID{hpke.KDF_HKDF_SHA256} 909 | supportedAEADs := []hpke.AEADID{hpke.AEAD_AESGCM128} 910 | 911 | vectors := make([]testVector, 0) 912 | for _, kem_id := range supportedKEMs { 913 | for _, kdf_id := range supportedKDFs { 914 | for _, aead_id := range supportedAEADs { 915 | vectors = append(vectors, generateTestVector(t, kem_id, kdf_id, aead_id)) 916 | } 917 | } 918 | } 919 | 920 | // Encode the test vectors 921 | encoded, err := json.Marshal(vectors) 922 | if err != nil { 923 | t.Fatalf("Error producing test vectors: %v", err) 924 | } 925 | 926 | // Verify that we process them correctly 927 | verifyTestVectors(t, encoded, false) 928 | 929 | // Write them to a file if requested 930 | var outputFile string 931 | if outputFile = os.Getenv(outputTestVectorEnvironmentKey); len(outputFile) > 0 { 932 | err = ioutil.WriteFile(outputFile, encoded, 0644) 933 | if err != nil { 934 | t.Fatalf("Error writing test vectors: %v", err) 935 | } 936 | } 937 | } 938 | 939 | func TestVectorVerify(t *testing.T) { 940 | var inputFile string 941 | if inputFile = os.Getenv(inputTestVectorEnvironmentKey); len(inputFile) == 0 { 942 | t.Skip("Test vectors were not provided") 943 | } 944 | 945 | encoded, err := ioutil.ReadFile(inputFile) 946 | if err != nil { 947 | t.Fatalf("Failed reading test vectors: %v", err) 948 | } 949 | 950 | verifyTestVectors(t, encoded, true) 951 | } 952 | --------------------------------------------------------------------------------