├── .github └── workflows │ └── go.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── common.go ├── common_test.go ├── credential.go ├── credential_test.go ├── crypto.go ├── crypto_test.go ├── extensions.go ├── extensions_test.go ├── go.mod ├── go.sum ├── key-schedule.go ├── key-schedule_test.go ├── messages.go ├── messages_test.go ├── profile.cov ├── state.go ├── state_test.go ├── test-vectors_test.go ├── tree-math.go ├── tree-math_test.go ├── treekem.go └── treekem_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.13 20 | id: go 21 | 22 | - name: Check out code 23 | uses: actions/checkout@v2 24 | 25 | - name: Get dependencies 26 | run: | 27 | go get -v -t -d ./... 28 | if [ -f Gopkg.toml ]; then 29 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 30 | dep ensure 31 | fi 32 | 33 | - name: Build 34 | run: go build -v . 35 | 36 | - name: Test 37 | run: go test -race -covermode atomic -coverprofile=profile.cov ./... 
38 | 39 | - name: Send Coverage 40 | env: 41 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 42 | run: | 43 | GO111MODULE=off go get github.com/mattn/goveralls 44 | $(go env GOPATH)/bin/goveralls -coverprofile=profile.cov -service=github 45 | 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | doc/MLS-Protocol.pdf 2 | doc/auto/* 3 | # Created by https://www.gitignore.io/api/emacs,go,latex 4 | 5 | ### Emacs ### 6 | # -*- mode: gitignore; -*- 7 | *~ 8 | \#*\# 9 | /.emacs.desktop 10 | /.emacs.desktop.lock 11 | *.elc 12 | auto-save-list 13 | tramp 14 | .\#* 15 | 16 | # Org-mode 17 | .org-id-locations 18 | *_archive 19 | 20 | # flymake-mode 21 | *_flymake.* 22 | 23 | # eshell files 24 | /eshell/history 25 | /eshell/lastdir 26 | 27 | # elpa packages 28 | /elpa/ 29 | 30 | # reftex files 31 | *.rel 32 | 33 | # AUCTeX auto folder 34 | /auto/ 35 | 36 | # cask packages 37 | .cask/ 38 | dist/ 39 | 40 | # Flycheck 41 | flycheck_*.el 42 | 43 | # server auth directory 44 | /server/ 45 | 46 | # projectiles files 47 | .projectile 48 | projectile-bookmarks.eld 49 | 50 | # directory configuration 51 | .dir-locals.el 52 | 53 | # saveplace 54 | places 55 | 56 | # url cache 57 | url/cache/ 58 | 59 | # cedet 60 | ede-projects.el 61 | 62 | # smex 63 | smex-items 64 | 65 | # company-statistics 66 | company-statistics-cache.el 67 | 68 | # anaconda-mode 69 | anaconda-mode/ 70 | 71 | ### Go ### 72 | # Binaries for programs and plugins 73 | *.exe 74 | *.dll 75 | *.so 76 | *.dylib 77 | 78 | # Test binary, build with `go test -c` 79 | *.test 80 | 81 | # Output of the go coverage tool, specifically when used with LiteIDE 82 | *.out 83 | 84 | # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 85 | .glide/ 86 | 87 | ### LaTeX ### 88 | ## Core latex/pdflatex auxiliary files: 89 | *.aux 90 | *.lof 91 | *.log 92 | *.lot 93 | *.fls 94 | *.toc 
95 | *.fmt 96 | *.fot 97 | *.cb 98 | *.cb2 99 | 100 | ## Intermediate documents: 101 | *.dvi 102 | *.xdv 103 | *-converted-to.* 104 | # these rules might exclude image files for figures etc. 105 | # *.ps 106 | # *.eps 107 | # *.pdf 108 | 109 | ## Generated if empty string is given at "Please type another file name for output:" 110 | .pdf 111 | 112 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 113 | *.bbl 114 | *.bcf 115 | *.blg 116 | *-blx.aux 117 | *-blx.bib 118 | *.run.xml 119 | 120 | ## Build tool auxiliary files: 121 | *.fdb_latexmk 122 | *.synctex 123 | *.synctex(busy) 124 | *.synctex.gz 125 | *.synctex.gz(busy) 126 | *.pdfsync 127 | *Notes.bib 128 | 129 | ## Auxiliary and intermediate files from other packages: 130 | # algorithms 131 | *.alg 132 | *.loa 133 | 134 | # achemso 135 | acs-*.bib 136 | 137 | # amsthm 138 | *.thm 139 | 140 | # beamer 141 | *.nav 142 | *.pre 143 | *.snm 144 | *.vrb 145 | 146 | # changes 147 | *.soc 148 | 149 | # cprotect 150 | *.cpt 151 | 152 | # elsarticle (documentclass of Elsevier journals) 153 | *.spl 154 | 155 | # endnotes 156 | *.ent 157 | 158 | # fixme 159 | *.lox 160 | 161 | # feynmf/feynmp 162 | *.mf 163 | *.mp 164 | *.t[1-9] 165 | *.t[1-9][0-9] 166 | *.tfm 167 | 168 | #(r)(e)ledmac/(r)(e)ledpar 169 | *.end 170 | *.?end 171 | *.[1-9] 172 | *.[1-9][0-9] 173 | *.[1-9][0-9][0-9] 174 | *.[1-9]R 175 | *.[1-9][0-9]R 176 | *.[1-9][0-9][0-9]R 177 | *.eledsec[1-9] 178 | *.eledsec[1-9]R 179 | *.eledsec[1-9][0-9] 180 | *.eledsec[1-9][0-9]R 181 | *.eledsec[1-9][0-9][0-9] 182 | *.eledsec[1-9][0-9][0-9]R 183 | 184 | # glossaries 185 | *.acn 186 | *.acr 187 | *.glg 188 | *.glo 189 | *.gls 190 | *.glsdefs 191 | 192 | # gnuplottex 193 | *-gnuplottex-* 194 | 195 | # gregoriotex 196 | *.gaux 197 | *.gtex 198 | 199 | # hyperref 200 | *.brf 201 | 202 | # knitr 203 | *-concordance.tex 204 | # TODO Comment the next line if you want to keep your tikz graphics files 205 | *.tikz 206 | *-tikzDictionary 207 | 208 | # listings 209 | *.lol 
210 | 211 | # makeidx 212 | *.idx 213 | *.ilg 214 | *.ind 215 | *.ist 216 | 217 | # minitoc 218 | *.maf 219 | *.mlf 220 | *.mlt 221 | *.mtc[0-9]* 222 | *.slf[0-9]* 223 | *.slt[0-9]* 224 | *.stc[0-9]* 225 | 226 | # minted 227 | _minted* 228 | *.pyg 229 | 230 | # morewrites 231 | *.mw 232 | 233 | # nomencl 234 | *.nlo 235 | 236 | # pax 237 | *.pax 238 | 239 | # pdfpcnotes 240 | *.pdfpc 241 | 242 | # sagetex 243 | *.sagetex.sage 244 | *.sagetex.py 245 | *.sagetex.scmd 246 | 247 | # scrwfile 248 | *.wrt 249 | 250 | # sympy 251 | *.sout 252 | *.sympy 253 | sympy-plots-for-*.tex/ 254 | 255 | # pdfcomment 256 | *.upa 257 | *.upb 258 | 259 | # pythontex 260 | *.pytxcode 261 | pythontex-files-*/ 262 | 263 | # thmtools 264 | *.loe 265 | 266 | # TikZ & PGF 267 | *.dpth 268 | *.md5 269 | *.auxlock 270 | 271 | # todonotes 272 | *.tdo 273 | 274 | # easy-todo 275 | *.lod 276 | 277 | # xindy 278 | *.xdy 279 | 280 | # xypic precompiled matrices 281 | *.xyc 282 | 283 | # endfloat 284 | *.ttt 285 | *.fff 286 | 287 | # Latexian 288 | TSWLatexianTemp* 289 | 290 | ## Editors: 291 | # WinEdt 292 | *.bak 293 | *.sav 294 | 295 | # Texpad 296 | .texpadtmp 297 | 298 | # Kile 299 | *.backup 300 | 301 | # KBibTeX 302 | *~[0-9]* 303 | 304 | # auto folder when using emacs and auctex 305 | /auto/* 306 | 307 | # expex forward references with \gathertags 308 | *-tags.tex 309 | 310 | # End of https://www.gitignore.io/api/emacs,go,latex 311 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.13.x 4 | before_install: 5 | - go get github.com/mattn/goveralls 6 | script: 7 | - $HOME/gopath/bin/goveralls -v -service=travis-ci 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 
2020, Cisco Systems 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Message Layer Security 2 | ====================== 3 | 4 | [![Coverage Status](https://coveralls.io/repos/github/cisco/go-mls/badge.svg)](https://coveralls.io/github/cisco/go-mls) 5 | 6 | This is a protocol to do group key establishment in an asynchronous, 7 | message-oriented setting. Its core ideas borrow a lot from 8 | [Asynchronous Ratchet Trees](https://eprint.iacr.org/2017/666.pdf). 
9 | 10 | Right now, this is just a Go library that implements the core 11 | protocol. It is missing key things like message sequencing, 12 | deconfliction, and retransmission. The interface should not be 13 | considered stable. 14 | 15 | The most you can really do with it is run the tests: 16 | 17 | ``` 18 | > go test -v 19 | ``` 20 | 21 | The tests in `state_test.go` will illustrate the basic flows that 22 | are supported. 23 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func dup(in []byte) []byte { 8 | out := make([]byte, len(in)) 9 | copy(out, in) 10 | return out 11 | } 12 | 13 | func validateEnum(v interface{}, known ...interface{}) error { 14 | for _, kv := range known { 15 | if v == kv { 16 | return nil 17 | } 18 | } 19 | return fmt.Errorf("Unknown enum value: %v", v) 20 | } 21 | -------------------------------------------------------------------------------- /common_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "encoding/hex" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | type TestEnum uint8 11 | 12 | var ( 13 | TestEnumInvalid TestEnum = 0xFF 14 | TestEnumVal0 TestEnum = 0 15 | TestEnumVal1 TestEnum = 1 16 | ) 17 | 18 | func TestValidateEnum(t *testing.T) { 19 | err := validateEnum(TestEnumVal0, TestEnumVal0, TestEnumVal1) 20 | require.Nil(t, err) 21 | 22 | err = validateEnum(TestEnumInvalid, TestEnumVal0, TestEnumVal1) 23 | require.Error(t, err) 24 | } 25 | 26 | ////////// 27 | 28 | func unhex(h string) []byte { 29 | b, err := hex.DecodeString(h) 30 | if err != nil { 31 | panic(err) 32 | } 33 | return b 34 | } 35 | -------------------------------------------------------------------------------- /credential.go: 
-------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "crypto/ed25519" 6 | "crypto/elliptic" 7 | "crypto/x509" 8 | "fmt" 9 | "reflect" 10 | 11 | "github.com/cisco/go-tls-syntax" 12 | ) 13 | 14 | type CredentialType uint8 15 | 16 | const ( 17 | CredentialTypeInvalid CredentialType = 255 18 | CredentialTypeBasic CredentialType = 0 19 | CredentialTypeX509 CredentialType = 1 20 | ) 21 | 22 | func (ct CredentialType) ValidForTLS() error { 23 | return validateEnum(ct, CredentialTypeBasic, CredentialTypeX509) 24 | } 25 | 26 | // struct { 27 | // opaque identity<0..2^16-1>; 28 | // SignatureScheme algorithm; 29 | // SignaturePublicKey public_key; 30 | // } BasicCredential; 31 | type BasicCredential struct { 32 | Identity []byte `tls:"head=2"` 33 | SignatureScheme SignatureScheme 34 | PublicKey SignaturePublicKey 35 | } 36 | 37 | // case x509: 38 | // opaque cert_data<1..2^24-1>; 39 | type X509Credential struct { 40 | Chain []*x509.Certificate 41 | } 42 | 43 | func (cred X509Credential) Scheme() SignatureScheme { 44 | leaf := cred.Chain[0] 45 | switch leaf.PublicKeyAlgorithm { 46 | case x509.ECDSA: 47 | ecKey := leaf.PublicKey.(*ecdsa.PublicKey) 48 | switch ecKey.Curve { 49 | case elliptic.P256(): 50 | return ECDSA_SECP256R1_SHA256 51 | case elliptic.P521(): 52 | return ECDSA_SECP521R1_SHA512 53 | default: 54 | panic("Unsupported elliptic curve") 55 | } 56 | 57 | case x509.Ed25519: 58 | return Ed25519 59 | } 60 | 61 | panic("Unsupported algorithm in certificate") 62 | } 63 | 64 | func (cred X509Credential) PublicKey() *SignaturePublicKey { 65 | switch pub := cred.Chain[0].PublicKey.(type) { 66 | case *ecdsa.PublicKey: 67 | keyData := elliptic.Marshal(pub.Curve, pub.X, pub.Y) 68 | return &SignaturePublicKey{Data: keyData} 69 | 70 | case ed25519.PublicKey: 71 | return &SignaturePublicKey{Data: pub} 72 | } 73 | 74 | panic("Unsupported public key type in certificate") 75 | } 76 | 77 | type 
certChainData struct { 78 | Data []byte `tls:"head=3"` 79 | } 80 | 81 | func (cred X509Credential) Equals(other *X509Credential) bool { 82 | if len(cred.Chain) != len(other.Chain) { 83 | return false 84 | } 85 | 86 | for i, cert := range cred.Chain { 87 | if !cert.Equal(other.Chain[i]) { 88 | return false 89 | } 90 | } 91 | 92 | return true 93 | } 94 | 95 | func (cred X509Credential) MarshalTLS() ([]byte, error) { 96 | allCerts := []byte{} 97 | for _, cert := range cred.Chain { 98 | allCerts = append(allCerts, cert.Raw...) 99 | } 100 | 101 | return syntax.Marshal(certChainData{allCerts}) 102 | } 103 | 104 | func (cred *X509Credential) UnmarshalTLS(data []byte) (int, error) { 105 | allCerts := new(certChainData) 106 | read, err := syntax.Unmarshal(data, allCerts) 107 | if err != nil { 108 | return 0, err 109 | } 110 | 111 | cred.Chain, err = x509.ParseCertificates(allCerts.Data) 112 | if err != nil { 113 | return 0, err 114 | } 115 | 116 | return read, nil 117 | } 118 | 119 | // This is essentially a copy of what is in crypto/x509, but with things exposed 120 | // that are hidden in that module. 
type certPool struct {
	byKeyID map[string]*x509.Certificate
	byName  map[string]*x509.Certificate
}

// newCertPool indexes the trusted certificates by raw subject name and, when
// present, by Subject Key Identifier. When several anchors share a name or
// SKI, the last one in trusted wins.
func newCertPool(trusted []*x509.Certificate) *certPool {
	byKeyID := map[string]*x509.Certificate{}
	byName := map[string]*x509.Certificate{}

	for _, anchor := range trusted {
		byName[string(anchor.RawSubject)] = anchor
		if ski := string(anchor.SubjectKeyId); ski != "" {
			byKeyID[ski] = anchor
		}
	}

	return &certPool{byKeyID: byKeyID, byName: byName}
}

// parent looks up a trusted issuer candidate for cert. A match on the
// Authority Key Identifier (when cert carries one) takes precedence.
// Otherwise it falls back to matching cert's raw issuer name against anchor
// subjects. The boolean reports whether any candidate was found.
func (pool certPool) parent(cert *x509.Certificate) (*x509.Certificate, bool) {
	if aki := string(cert.AuthorityKeyId); aki != "" {
		if candidate, found := pool.byKeyID[aki]; found {
			return candidate, true
		}
	}

	candidate, found := pool.byName[string(cert.RawIssuer)]
	return candidate, found
}

// XXX(RLB): This is a very simple chain validation, just looking at signatures
// and whatever basic hop-by-hop policy is applied by CheckSignatureFrom. More
// complex things like name constraints are not considered. They would be if we
// were using x509.Certificate.Verify, but that method (1) requires a DNS name
// as the authentication anchor, and (2) builds its own chain without strict
// ordering.
166 | func (cred X509Credential) Verify(trusted []*x509.Certificate) error { 167 | pool := newCertPool(trusted) 168 | 169 | var curr, next *x509.Certificate 170 | for i := 0; i < len(cred.Chain)-1; i++ { 171 | curr = cred.Chain[i] 172 | next = cred.Chain[i+1] 173 | 174 | // If there is a valid signature from a trusted certificate, the chain is valid 175 | parent, ok := pool.parent(curr) 176 | if ok && curr.CheckSignatureFrom(parent) == nil { 177 | return nil 178 | } 179 | 180 | // Otherwise the cert must be signed by the next cert in the chain 181 | if err := curr.CheckSignatureFrom(next); err != nil { 182 | return err 183 | } 184 | } 185 | 186 | // If no previous certificate has been signed under a trusted certificate, 187 | // then the last certificate in the chain must be signed by a trusted 188 | // certificate 189 | last := cred.Chain[len(cred.Chain)-1] 190 | parent, ok := pool.parent(last) 191 | if !ok { 192 | return fmt.Errorf("No candidate trust anchor found") 193 | } 194 | 195 | return last.CheckSignatureFrom(parent) 196 | } 197 | 198 | // struct { 199 | // CredentialType credential_type; 200 | // select (Credential.credential_type) { 201 | // case basic: 202 | // BasicCredential; 203 | // case x509: 204 | // opaque cert_data<1..2^24-1>; 205 | // }; 206 | //} Credential; 207 | type Credential struct { 208 | X509 *X509Credential 209 | Basic *BasicCredential 210 | } 211 | 212 | func NewBasicCredential(userId []byte, scheme SignatureScheme, pub SignaturePublicKey) *Credential { 213 | basicCredential := &BasicCredential{ 214 | Identity: userId, 215 | SignatureScheme: scheme, 216 | PublicKey: pub, 217 | } 218 | return &Credential{Basic: basicCredential} 219 | } 220 | 221 | func NewX509Credential(chain []*x509.Certificate) (*Credential, error) { 222 | if len(chain) == 0 { 223 | return nil, fmt.Errorf("Malformed credential: At least one certificate is required") 224 | } 225 | 226 | x509Credential := &X509Credential{ 227 | Chain: chain, 228 | } 229 | 230 | return 
&Credential{X509: x509Credential}, nil 231 | } 232 | 233 | // compare the public aspects 234 | func (c Credential) Equals(o Credential) bool { 235 | switch c.Type() { 236 | case CredentialTypeX509: 237 | return c.X509.Equals(o.X509) 238 | case CredentialTypeBasic: 239 | return reflect.DeepEqual(c.Basic, o.Basic) 240 | default: 241 | panic("Malformed credential") 242 | } 243 | } 244 | 245 | func (c Credential) Type() CredentialType { 246 | switch { 247 | case c.X509 != nil: 248 | return CredentialTypeX509 249 | case c.Basic != nil: 250 | return CredentialTypeBasic 251 | default: 252 | panic("Malformed credential") 253 | } 254 | } 255 | 256 | func (c Credential) Identity() []byte { 257 | switch c.Type() { 258 | case CredentialTypeX509: 259 | return c.X509.Chain[0].RawSubject 260 | case CredentialTypeBasic: 261 | return c.Basic.Identity 262 | default: 263 | panic("mls.credential: Can't retrieve PublicKey") 264 | } 265 | } 266 | 267 | func (c Credential) Scheme() SignatureScheme { 268 | switch c.Type() { 269 | case CredentialTypeX509: 270 | return c.X509.Scheme() 271 | case CredentialTypeBasic: 272 | return c.Basic.SignatureScheme 273 | default: 274 | panic("mls.credential: Can't retrieve SignatureScheme") 275 | } 276 | } 277 | 278 | func (c Credential) PublicKey() *SignaturePublicKey { 279 | switch c.Type() { 280 | case CredentialTypeX509: 281 | return c.X509.PublicKey() 282 | case CredentialTypeBasic: 283 | return &c.Basic.PublicKey 284 | default: 285 | panic("mls.credential: Can't retrieve PublicKey") 286 | } 287 | } 288 | 289 | func (c Credential) MarshalTLS() ([]byte, error) { 290 | s := syntax.NewWriteStream() 291 | credentialType := c.Type() 292 | err := s.Write(credentialType) 293 | if err != nil { 294 | return nil, err 295 | } 296 | switch credentialType { 297 | case CredentialTypeX509: 298 | err = s.Write(c.X509) 299 | case CredentialTypeBasic: 300 | err = s.Write(c.Basic) 301 | default: 302 | err = fmt.Errorf("mls.credential: CredentialType type not 
allowed") 303 | } 304 | 305 | if err != nil { 306 | return nil, err 307 | } 308 | 309 | return s.Data(), nil 310 | } 311 | 312 | func (c *Credential) UnmarshalTLS(data []byte) (int, error) { 313 | s := syntax.NewReadStream(data) 314 | var credentialType CredentialType 315 | _, err := s.Read(&credentialType) 316 | if err != nil { 317 | return 0, err 318 | } 319 | 320 | switch credentialType { 321 | case CredentialTypeX509: 322 | c.X509 = new(X509Credential) 323 | _, err = s.Read(c.X509) 324 | case CredentialTypeBasic: 325 | c.Basic = new(BasicCredential) 326 | _, err = s.Read(c.Basic) 327 | default: 328 | err = fmt.Errorf("mls.credential: CredentialType type not allowed %v", err) 329 | } 330 | 331 | if err != nil { 332 | return 0, err 333 | } 334 | return s.Position(), nil 335 | } 336 | -------------------------------------------------------------------------------- /credential_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "crypto" 5 | "crypto/ed25519" 6 | "crypto/rand" 7 | "crypto/x509" 8 | "math/big" 9 | "testing" 10 | "time" 11 | 12 | "github.com/cisco/go-tls-syntax" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | var ( 17 | caTemplate = &x509.Certificate{ 18 | BasicConstraintsValid: true, 19 | IsCA: true, 20 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 21 | } 22 | 23 | leafTemplate = &x509.Certificate{ 24 | BasicConstraintsValid: true, 25 | IsCA: false, 26 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 27 | } 28 | ) 29 | 30 | func newEd25519(t *testing.T) ed25519.PrivateKey { 31 | _, priv, err := ed25519.GenerateKey(rand.Reader) 32 | require.Nil(t, err) 33 | return priv 34 | } 35 | 36 | func makeCert(t *testing.T, template, parent *x509.Certificate, parentPriv crypto.Signer, addSKI bool) (crypto.Signer, *x509.Certificate) { 37 | backdate := time.Hour 38 | lifetime := 
24 * time.Hour 39 | skiSize := 4 // bytes 40 | 41 | // Set expiry 42 | template.NotBefore = time.Now().Add(-backdate) 43 | template.NotAfter = template.NotBefore.Add(lifetime) 44 | 45 | // Set serial number 46 | serialNumberLimit := big.NewInt(0).Lsh(big.NewInt(1), 128) 47 | serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 48 | require.Nil(t, err) 49 | template.SerialNumber = serialNumber 50 | 51 | // Add random SKI if requried 52 | template.SubjectKeyId = nil 53 | if addSKI { 54 | template.SubjectKeyId = make([]byte, skiSize) 55 | rand.Read(template.SubjectKeyId) 56 | } 57 | 58 | // Generate and parse the certificate 59 | priv := parentPriv 60 | realParent := template 61 | if parent != nil { 62 | priv = newEd25519(t) 63 | realParent = parent 64 | } 65 | 66 | certData, err := x509.CreateCertificate(rand.Reader, template, realParent, priv.Public(), parentPriv) 67 | require.Nil(t, err) 68 | cert, err := x509.ParseCertificate(certData) 69 | require.Nil(t, err) 70 | return priv, cert 71 | } 72 | 73 | func makeCertChain(t *testing.T, rootPriv crypto.Signer, depth int, addSKI bool) (*SignaturePrivateKey, *x509.Certificate, []*x509.Certificate) { 74 | chain := make([]*x509.Certificate, depth) 75 | 76 | _, rootCert := makeCert(t, caTemplate, nil, rootPriv, addSKI) 77 | 78 | currPriv := rootPriv 79 | cert := rootCert 80 | for i := depth - 1; i > 0; i-- { 81 | currPriv, cert = makeCert(t, caTemplate, cert, currPriv, addSKI) 82 | chain[i] = cert 83 | } 84 | 85 | currPriv, cert = makeCert(t, leafTemplate, cert, currPriv, addSKI) 86 | chain[0] = cert 87 | 88 | sigPriv := &SignaturePrivateKey{ 89 | Data: currPriv.(ed25519.PrivateKey), 90 | PublicKey: SignaturePublicKey{ 91 | Data: currPriv.Public().(ed25519.PublicKey), 92 | }, 93 | } 94 | 95 | return sigPriv, rootCert, chain 96 | } 97 | 98 | func makeX509Credential(t *testing.T, depth int, addSKI bool) (*Credential, *x509.Certificate) { 99 | rootPriv := newEd25519(t) 100 | _, rootCert, chain := makeCertChain(t, 
rootPriv, depth, addSKI) 101 | 102 | cred, err := NewX509Credential(chain) 103 | require.Nil(t, err) 104 | return cred, rootCert 105 | } 106 | 107 | func TestBasicCredential(t *testing.T) { 108 | identity := []byte("res ipsa") 109 | scheme := Ed25519 110 | priv, err := scheme.Generate() 111 | require.Nil(t, err) 112 | 113 | cred := NewBasicCredential(identity, scheme, priv.PublicKey) 114 | require.True(t, cred.Equals(*cred)) 115 | require.Equal(t, cred.Type(), CredentialTypeBasic) 116 | require.Equal(t, cred.Scheme(), scheme) 117 | require.Equal(t, *cred.PublicKey(), priv.PublicKey) 118 | 119 | credData, err := syntax.Marshal(cred) 120 | require.Nil(t, err) 121 | 122 | cred2 := new(Credential) 123 | _, err = syntax.Unmarshal(credData, cred2) 124 | require.Nil(t, err) 125 | } 126 | 127 | func TestX509Credential(t *testing.T) { 128 | cred, _ := makeX509Credential(t, 3, true) 129 | 130 | require.NotNil(t, cred) 131 | require.True(t, cred.Equals(*cred)) 132 | require.Equal(t, cred.Type(), CredentialTypeX509) 133 | require.Equal(t, cred.Scheme(), Ed25519) 134 | require.NotNil(t, cred.PublicKey()) 135 | 136 | credData, err := syntax.Marshal(cred) 137 | require.Nil(t, err) 138 | 139 | cred2 := new(Credential) 140 | _, err = syntax.Unmarshal(credData, cred2) 141 | require.Nil(t, err) 142 | } 143 | 144 | func TestX509CredentialOne(t *testing.T) { 145 | cred, root := makeX509Credential(t, 1, false) 146 | trusted := []*x509.Certificate{root} 147 | require.Nil(t, cred.X509.Verify(trusted)) 148 | } 149 | 150 | func TestX509CredentialVerifyByName(t *testing.T) { 151 | cred, root := makeX509Credential(t, 3, false) 152 | trusted := []*x509.Certificate{root} 153 | require.Nil(t, cred.X509.Verify(trusted)) 154 | } 155 | 156 | func TestX509CredentialVerifyBySKI(t *testing.T) { 157 | cred, root := makeX509Credential(t, 3, true) 158 | trusted := []*x509.Certificate{root} 159 | require.Nil(t, cred.X509.Verify(trusted)) 160 | } 161 | 162 | func TestCredentialErrorCases(t *testing.T) { 
163 | cred := Credential{} 164 | 165 | require.Panics(t, func() { cred.Equals(cred) }) 166 | require.Panics(t, func() { cred.Type() }) 167 | require.Panics(t, func() { cred.PublicKey() }) 168 | require.Panics(t, func() { cred.Scheme() }) 169 | require.Panics(t, func() { syntax.Marshal(cred) }) 170 | 171 | // No certificate chain for X.509 Credential 172 | _, err := NewX509Credential(nil) 173 | require.Error(t, err) 174 | } 175 | -------------------------------------------------------------------------------- /crypto.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "crypto/ecdsa" 8 | "crypto/elliptic" 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "crypto/sha256" 12 | "crypto/sha512" 13 | "encoding/asn1" 14 | "fmt" 15 | "hash" 16 | "math/big" 17 | 18 | "github.com/cisco/go-hpke" 19 | "github.com/cisco/go-tls-syntax" 20 | "golang.org/x/crypto/chacha20poly1305" 21 | "golang.org/x/crypto/ed25519" 22 | ) 23 | 24 | type CipherSuite uint16 25 | 26 | const ( 27 | X25519_AES128GCM_SHA256_Ed25519 CipherSuite = 0x0001 28 | P256_AES128GCM_SHA256_P256 CipherSuite = 0x0002 29 | X25519_CHACHA20POLY1305_SHA256_Ed25519 CipherSuite = 0x0003 30 | X448_AES256GCM_SHA512_Ed448 CipherSuite = 0x0004 // UNSUPPORTED 31 | P521_AES256GCM_SHA512_P521 CipherSuite = 0x0005 32 | X448_CHACHA20POLY1305_SHA512_Ed448 CipherSuite = 0x0006 // UNSUPPORTED 33 | ) 34 | 35 | func (cs CipherSuite) supported() bool { 36 | switch cs { 37 | case X25519_AES128GCM_SHA256_Ed25519, 38 | P256_AES128GCM_SHA256_P256, 39 | P521_AES256GCM_SHA512_P521, 40 | X25519_CHACHA20POLY1305_SHA256_Ed25519: 41 | return true 42 | } 43 | 44 | return false 45 | } 46 | 47 | func (cs CipherSuite) String() string { 48 | switch cs { 49 | case X25519_AES128GCM_SHA256_Ed25519: 50 | return "X25519_AES128GCM_SHA256_Ed25519" 51 | case P256_AES128GCM_SHA256_P256: 52 | return "P256_AES128GCM_SHA256_P256" 53 | case 
X25519_CHACHA20POLY1305_SHA256_Ed25519: 54 | return "X25519_CHACHA20POLY1305_SHA256_Ed25519" 55 | case X448_AES256GCM_SHA512_Ed448: 56 | return "X448_AES256GCM_SHA512_Ed448" 57 | case P521_AES256GCM_SHA512_P521: 58 | return "P521_AES256GCM_SHA512_P521" 59 | case X448_CHACHA20POLY1305_SHA512_Ed448: 60 | return "X448_CHACHA20POLY1305_SHA512_Ed448" 61 | } 62 | 63 | return "UnknownCipherSuite" 64 | } 65 | 66 | type cipherConstants struct { 67 | KeySize int 68 | NonceSize int 69 | SecretSize int 70 | HPKEKEM hpke.KEMID 71 | HPKEKDF hpke.KDFID 72 | HPKEAEAD hpke.AEADID 73 | } 74 | 75 | func (cs CipherSuite) Constants() cipherConstants { 76 | switch cs { 77 | case X25519_AES128GCM_SHA256_Ed25519: 78 | return cipherConstants{ 79 | KeySize: 16, 80 | NonceSize: 12, 81 | SecretSize: 32, 82 | HPKEKEM: hpke.DHKEM_X25519, 83 | HPKEKDF: hpke.KDF_HKDF_SHA256, 84 | HPKEAEAD: hpke.AEAD_AESGCM128, 85 | } 86 | case P256_AES128GCM_SHA256_P256: 87 | return cipherConstants{ 88 | KeySize: 16, 89 | NonceSize: 12, 90 | SecretSize: 32, 91 | HPKEKEM: hpke.DHKEM_P256, 92 | HPKEKDF: hpke.KDF_HKDF_SHA256, 93 | HPKEAEAD: hpke.AEAD_AESGCM128, 94 | } 95 | case X25519_CHACHA20POLY1305_SHA256_Ed25519: 96 | return cipherConstants{ 97 | KeySize: 32, 98 | NonceSize: 12, 99 | SecretSize: 32, 100 | HPKEKEM: hpke.DHKEM_X25519, 101 | HPKEKDF: hpke.KDF_HKDF_SHA256, 102 | HPKEAEAD: hpke.AEAD_CHACHA20POLY1305, 103 | } 104 | case P521_AES256GCM_SHA512_P521: 105 | return cipherConstants{ 106 | KeySize: 32, 107 | NonceSize: 12, 108 | SecretSize: 64, 109 | HPKEKEM: hpke.DHKEM_P521, 110 | HPKEKDF: hpke.KDF_HKDF_SHA512, 111 | HPKEAEAD: hpke.AEAD_AESGCM256, 112 | } 113 | } 114 | 115 | panic("Unsupported ciphersuite") 116 | } 117 | 118 | func (cs CipherSuite) Scheme() SignatureScheme { 119 | switch cs { 120 | case X25519_AES128GCM_SHA256_Ed25519: 121 | return Ed25519 122 | case P256_AES128GCM_SHA256_P256: 123 | return ECDSA_SECP256R1_SHA256 124 | case X25519_CHACHA20POLY1305_SHA256_Ed25519: 125 | return Ed25519 126 | 
case P521_AES256GCM_SHA512_P521: 127 | return ECDSA_SECP521R1_SHA512 128 | } 129 | 130 | panic("Unsupported ciphersuite") 131 | } 132 | 133 | func (cs CipherSuite) zero() []byte { 134 | return bytes.Repeat([]byte{0x00}, cs.newDigest().Size()) 135 | } 136 | 137 | func (cs CipherSuite) newDigest() hash.Hash { 138 | switch cs { 139 | case X25519_AES128GCM_SHA256_Ed25519, P256_AES128GCM_SHA256_P256, 140 | X25519_CHACHA20POLY1305_SHA256_Ed25519: 141 | return sha256.New() 142 | 143 | case X448_AES256GCM_SHA512_Ed448, P521_AES256GCM_SHA512_P521: 144 | return sha512.New() 145 | } 146 | 147 | panic("Unsupported ciphersuite") 148 | } 149 | 150 | func (cs CipherSuite) Digest(data []byte) []byte { 151 | d := cs.newDigest() 152 | d.Write(data) 153 | return d.Sum(nil) 154 | } 155 | 156 | func (cs CipherSuite) NewHMAC(key []byte) hash.Hash { 157 | return hmac.New(cs.newDigest, key) 158 | } 159 | 160 | func (cs CipherSuite) NewAEAD(key []byte) (cipher.AEAD, error) { 161 | switch cs { 162 | case X25519_AES128GCM_SHA256_Ed25519, P256_AES128GCM_SHA256_P256: 163 | fallthrough 164 | case X448_AES256GCM_SHA512_Ed448, P521_AES256GCM_SHA512_P521: 165 | block, err := aes.NewCipher(key) 166 | if err != nil { 167 | return nil, err 168 | } 169 | 170 | return cipher.NewGCM(block) 171 | case X25519_CHACHA20POLY1305_SHA256_Ed25519: 172 | return chacha20poly1305.New(key) 173 | } 174 | 175 | panic("Unsupported ciphersuite") 176 | } 177 | 178 | func (cs CipherSuite) hkdfExtract(salt, ikm []byte) []byte { 179 | mac := cs.NewHMAC(salt) 180 | mac.Write(ikm) 181 | return mac.Sum(nil) 182 | } 183 | 184 | func (cs CipherSuite) hkdfExpand(secret, info []byte, size int) []byte { 185 | last := []byte{} 186 | buf := []byte{} 187 | counter := byte(1) 188 | for len(buf) < size { 189 | mac := cs.NewHMAC(secret) 190 | mac.Write(last) 191 | mac.Write(info) 192 | mac.Write([]byte{counter}) 193 | 194 | last = mac.Sum(nil) 195 | counter += 1 196 | buf = append(buf, last...) 
197 | } 198 | return buf[:size] 199 | } 200 | 201 | type hkdfLabel struct { 202 | Length uint16 203 | Label []byte `tls:"head=1"` 204 | Context []byte `tls:"head=4"` 205 | } 206 | 207 | func (cs CipherSuite) hkdfExpandLabel(secret []byte, label string, context []byte, length int) []byte { 208 | mlsLabel := []byte("mls10 " + label) 209 | labelData, err := syntax.Marshal(hkdfLabel{uint16(length), mlsLabel, context}) 210 | if err != nil { 211 | panic(fmt.Errorf("Error marshaling HKDF label: %v", err)) 212 | } 213 | return cs.hkdfExpand(secret, labelData, length) 214 | } 215 | 216 | func (cs CipherSuite) deriveSecret(secret []byte, label string, context []byte) []byte { 217 | contextHash := cs.Digest(context) 218 | size := cs.Constants().SecretSize 219 | return cs.hkdfExpandLabel(secret, label, contextHash, size) 220 | } 221 | 222 | type applicationContext struct { 223 | Node NodeIndex 224 | Generation uint32 225 | } 226 | 227 | func (cs CipherSuite) deriveAppSecret(secret []byte, label string, node NodeIndex, generation uint32, length int) []byte { 228 | ctx, err := syntax.Marshal(applicationContext{node, generation}) 229 | if err != nil { 230 | panic(fmt.Errorf("Error marshaling application context: %v", err)) 231 | } 232 | 233 | return cs.hkdfExpandLabel(secret, label, ctx, length) 234 | } 235 | 236 | func (cs CipherSuite) hpke() HPKEInstance { 237 | cc := cs.Constants() 238 | suite, err := hpke.AssembleCipherSuite(cc.HPKEKEM, cc.HPKEKDF, cc.HPKEAEAD) 239 | if err != nil { 240 | panic("Unable to construct HPKE ciphersuite") 241 | } 242 | 243 | return HPKEInstance{cs, suite} 244 | } 245 | 246 | /// 247 | /// HPKE 248 | /// 249 | 250 | type HPKEPrivateKey struct { 251 | Data []byte `tls:"head=2"` 252 | PublicKey HPKEPublicKey 253 | } 254 | 255 | type HPKEPublicKey struct { 256 | Data []byte `tls:"head=2"` 257 | } 258 | 259 | func (k HPKEPublicKey) Equals(o HPKEPublicKey) bool { 260 | return bytes.Equal(k.Data, o.Data) 261 | } 262 | 263 | type HPKECiphertext struct { 
264 | KEMOutput []byte `tls:"head=2"` 265 | Ciphertext []byte `tls:"head=4"` 266 | } 267 | /* HPKEInstance pairs an MLS ciphersuite with the go-hpke suite built for it. */ 268 | type HPKEInstance struct { 269 | BaseSuite CipherSuite 270 | Suite hpke.CipherSuite 271 | } 272 | /* Generate creates a random HPKE key pair using the KEM's own key generation. */ 273 | func (h HPKEInstance) Generate() (HPKEPrivateKey, error) { 274 | priv, pub, err := h.Suite.KEM.GenerateKeyPair(rand.Reader) 275 | if err != nil { 276 | return HPKEPrivateKey{}, err 277 | } 278 | 279 | key := HPKEPrivateKey{ 280 | Data: h.Suite.KEM.MarshalPrivate(priv), 281 | PublicKey: HPKEPublicKey{h.Suite.KEM.Marshal(pub)}, 282 | } 283 | return key, nil 284 | } 285 | /* Derive deterministically derives an HPKE key pair from seed: it expands seed under the "key pair" label to a KEM-specific length and unmarshals the result as the private key. It returns an error for KEMs it has no length for. */ 286 | func (h HPKEInstance) Derive(seed []byte) (HPKEPrivateKey, error) { 287 | keyPairSecretSize := 0 288 | switch h.BaseSuite.Constants().HPKEKEM { 289 | case hpke.DHKEM_X25519: 290 | keyPairSecretSize = 32 291 | case hpke.DHKEM_P256: 292 | keyPairSecretSize = 32 293 | case hpke.DHKEM_P521: 294 | keyPairSecretSize = 66 295 | case hpke.DHKEM_X448: 296 | keyPairSecretSize = 56 297 | default: 298 | return HPKEPrivateKey{}, fmt.Errorf("Unsupported HPKE KEM for ciphersuite %v", h.BaseSuite) } 299 | cs := h.BaseSuite 300 | keyPairSecret := cs.hkdfExpandLabel(seed, "key pair", []byte{}, keyPairSecretSize) 301 | /* Unknown KEMs were rejected above, so one branch below always assigns priv or err. */ 302 | var priv hpke.KEMPrivateKey 303 | var err error 304 | switch h.BaseSuite.Constants().HPKEKEM { 305 | case hpke.DHKEM_P256, hpke.DHKEM_P521, hpke.DHKEM_X25519: 306 | priv, err = h.Suite.KEM.UnmarshalPrivate(keyPairSecret) 307 | case hpke.DHKEM_X448: 308 | priv, err = h.Suite.KEM.UnmarshalPrivate(keyPairSecret) 309 | } 310 | 311 | if err != nil { 312 | return HPKEPrivateKey{}, err 313 | } 314 | 315 | pub := priv.PublicKey() 316 | key := HPKEPrivateKey{ 317 | Data: h.Suite.KEM.MarshalPrivate(priv), 318 | PublicKey: HPKEPublicKey{h.Suite.KEM.Marshal(pub)}, 319 | } 320 | return key, nil 321 | } 322 | /* Encrypt seals pt with aad to pub in HPKE base mode (SetupBaseS with a nil final argument, presumably the HPKE info string) and returns the encapsulation plus ciphertext. */ 323 | func (h HPKEInstance) Encrypt(pub HPKEPublicKey, aad, pt []byte) (HPKECiphertext, error) { 324 | pkR, err := h.Suite.KEM.Unmarshal(pub.Data) 325 | if err != nil { 326 | return HPKECiphertext{}, err 327 | } 328 | 329 | enc, ctx, err := hpke.SetupBaseS(h.Suite, rand.Reader, pkR, nil) 330 | if err !=
nil { 331 | return HPKECiphertext{}, err 332 | } 333 | 334 | ct := ctx.Seal(aad, pt) 335 | return HPKECiphertext{enc, ct}, nil 336 | } 337 | /* Decrypt opens ct with aad using the receiver side of HPKE base mode (SetupBaseR), mirroring Encrypt. */ 338 | func (h HPKEInstance) Decrypt(priv HPKEPrivateKey, aad []byte, ct HPKECiphertext) ([]byte, error) { 339 | skR, err := h.Suite.KEM.UnmarshalPrivate(priv.Data) 340 | if err != nil { 341 | return nil, err 342 | } 343 | 344 | ctx, err := hpke.SetupBaseR(h.Suite, skR, ct.KEMOutput, nil) 345 | if err != nil { 346 | return nil, err 347 | } 348 | 349 | return ctx.Open(aad, ct.Ciphertext) 350 | } 351 | 352 | /// 353 | /// Signing 354 | /// 355 | /* SignaturePrivateKey is a serialized signing key together with its public key. */ 356 | type SignaturePrivateKey struct { 357 | Data []byte `tls:"head=2"` 358 | PublicKey SignaturePublicKey 359 | } 360 | /* SignaturePublicKey is a serialized verification key: an elliptic.Marshal point for ECDSA, the raw key for Ed25519. */ 361 | type SignaturePublicKey struct { 362 | Data []byte `tls:"head=2"` 363 | } 364 | /* Equals reports whether the two serialized public keys are byte-for-byte identical. */ 365 | func (pub SignaturePublicKey) Equals(other SignaturePublicKey) bool { 366 | return bytes.Equal(pub.Data, other.Data) 367 | } 368 | /* SignatureScheme identifies a signature algorithm; the values appear to be the TLS 1.3 SignatureScheme code points. */ 369 | type SignatureScheme uint16 370 | 371 | const ( 372 | ECDSA_SECP256R1_SHA256 SignatureScheme = 0x0403 373 | ECDSA_SECP521R1_SHA512 SignatureScheme = 0x0603 374 | Ed25519 SignatureScheme = 0x0807 375 | ) 376 | /* supported reports whether ss is one of the schemes implemented in this file. */ 377 | func (ss SignatureScheme) supported() bool { 378 | switch ss { 379 | case ECDSA_SECP256R1_SHA256, ECDSA_SECP521R1_SHA512, Ed25519: 380 | return true 381 | } 382 | 383 | return false 384 | } 385 | /* String returns the scheme's constant name, or "UnknownSignatureScheme". */ 386 | func (ss SignatureScheme) String() string { 387 | switch ss { 388 | case ECDSA_SECP256R1_SHA256: 389 | return "ECDSA_SECP256R1_SHA256" 390 | case ECDSA_SECP521R1_SHA512: 391 | return "ECDSA_SECP521R1_SHA512" 392 | case Ed25519: 393 | return "Ed25519" 394 | } 395 | 396 | return "UnknownSignatureScheme" 397 | } 398 | /* Derive deterministically derives a key pair from preSeed: ECDSA private scalars are the SHA-256 / SHA-512 digest of preSeed, and Ed25519 keys use SHA-256(preSeed) as the seed. It panics for unsupported schemes. */ 399 | func (ss SignatureScheme) Derive(preSeed []byte) (SignaturePrivateKey, error) { 400 | switch ss { 401 | case ECDSA_SECP256R1_SHA256: 402 | h := sha256.New() 403 | h.Write(preSeed) 404 | priv := h.Sum(nil) 405 | /* Use the curve's own ScalarBaseMult; the generic CurveParams implementation is not constant-time. */ 406 | curve := elliptic.P256() 407 | x, y := curve.ScalarBaseMult(priv) 408 | pub :=
elliptic.Marshal(curve, x, y) 409 | key := SignaturePrivateKey{ 410 | Data: priv, 411 | PublicKey: SignaturePublicKey{pub}, 412 | } 413 | return key, nil 414 | 415 | case ECDSA_SECP521R1_SHA512: 416 | h := sha512.New() 417 | h.Write(preSeed) 418 | priv := h.Sum(nil) 419 | /* Use the curve's own ScalarBaseMult; the generic CurveParams implementation is not constant-time. */ 420 | curve := elliptic.P521() 421 | x, y := curve.ScalarBaseMult(priv) 422 | pub := elliptic.Marshal(curve, x, y) 423 | key := SignaturePrivateKey{ 424 | Data: priv, 425 | PublicKey: SignaturePublicKey{pub}, 426 | } 427 | return key, nil 428 | 429 | case Ed25519: 430 | h := sha256.New() 431 | h.Write(preSeed) 432 | seed := h.Sum(nil) 433 | priv := ed25519.NewKeyFromSeed(seed) 434 | pub := priv.Public().(ed25519.PublicKey) 435 | key := SignaturePrivateKey{ 436 | Data: priv, 437 | PublicKey: SignaturePublicKey{pub}, 438 | } 439 | return key, nil 440 | } 441 | panic("Unsupported algorithm") 442 | } 443 | /* Generate creates a random key pair; ECDSA public keys are serialized with elliptic.Marshal (uncompressed form). It panics for unsupported schemes. */ 444 | func (ss SignatureScheme) Generate() (SignaturePrivateKey, error) { 445 | switch ss { 446 | case ECDSA_SECP256R1_SHA256: 447 | curve := elliptic.P256() 448 | priv, x, y, err := elliptic.GenerateKey(curve, rand.Reader) 449 | if err != nil { 450 | return SignaturePrivateKey{}, err 451 | } 452 | 453 | pub := elliptic.Marshal(curve, x, y) 454 | key := SignaturePrivateKey{ 455 | Data: priv, 456 | PublicKey: SignaturePublicKey{pub}, 457 | } 458 | return key, nil 459 | 460 | case ECDSA_SECP521R1_SHA512: 461 | curve := elliptic.P521() 462 | priv, x, y, err := elliptic.GenerateKey(curve, rand.Reader) 463 | if err != nil { 464 | return SignaturePrivateKey{}, err 465 | } 466 | 467 | pub := elliptic.Marshal(curve, x, y) 468 | key := SignaturePrivateKey{ 469 | Data: priv, 470 | PublicKey: SignaturePublicKey{pub}, 471 | } 472 | return key, nil 473 | 474 | case Ed25519: 475 | pub, priv, err := ed25519.GenerateKey(rand.Reader) 476 | if err != nil { 477 | return SignaturePrivateKey{}, err 478 | } 479 | 480 | key := SignaturePrivateKey{ 481 | Data: priv, 482 | PublicKey: SignaturePublicKey{pub}, 483 |
} 484 | return key, nil 485 | } 486 | panic("Unsupported algorithm") 487 | } 488 | /* ecdsaSignature is the ASN.1 SEQUENCE of (r, s) that Verify parses ECDSA signatures into. */ 489 | type ecdsaSignature struct { 490 | R, S *big.Int 491 | } 492 | /* Sign hashes message with the scheme's hash and signs it; ECDSA signatures come from ecdsa.PrivateKey.Sign (ASN.1 DER), Ed25519 signatures are raw. It panics for unsupported schemes. NOTE(review): the ECDSA keys are built from D alone with X/Y unset - confirm the Go toolchain in use accepts that. */ 493 | func (ss SignatureScheme) Sign(priv *SignaturePrivateKey, message []byte) ([]byte, error) { 494 | switch ss { 495 | case ECDSA_SECP256R1_SHA256: 496 | h := sha256.New() 497 | h.Write(message) 498 | digest := h.Sum(nil) 499 | 500 | ecPriv := &ecdsa.PrivateKey{ 501 | D: big.NewInt(0).SetBytes(priv.Data), 502 | PublicKey: ecdsa.PublicKey{ 503 | Curve: elliptic.P256(), 504 | }, 505 | } 506 | return ecPriv.Sign(rand.Reader, digest, nil) 507 | 508 | case ECDSA_SECP521R1_SHA512: 509 | h := sha512.New() 510 | h.Write(message) 511 | digest := h.Sum(nil) 512 | 513 | ecPriv := &ecdsa.PrivateKey{ 514 | D: big.NewInt(0).SetBytes(priv.Data), 515 | PublicKey: ecdsa.PublicKey{ 516 | Curve: elliptic.P521(), 517 | }, 518 | } 519 | return ecPriv.Sign(rand.Reader, digest, nil) 520 | 521 | case Ed25519: 522 | priv25519 := ed25519.PrivateKey(priv.Data) 523 | return ed25519.Sign(priv25519, message), nil 524 | } 525 | panic("Unsupported algorithm") 526 | } 527 | /* Verify checks signature over message with pub. For ECDSA the public key must decode to a valid curve point and the signature must be a DER (r, s) SEQUENCE with no trailing bytes; malformed input yields false rather than a panic. It panics for unsupported schemes. */ 528 | func (ss SignatureScheme) Verify(pub *SignaturePublicKey, message, signature []byte) bool { 529 | switch ss { 530 | case ECDSA_SECP256R1_SHA256: 531 | h := sha256.New() 532 | h.Write(message) 533 | digest := h.Sum(nil) 534 | /* elliptic.Unmarshal reports malformed or off-curve points with x == nil; nil coordinates must not reach ecdsa.Verify. */ 535 | curve := elliptic.P256() 536 | x, y := elliptic.Unmarshal(curve, pub.Data) 537 | if x == nil { return false } 538 | var sig ecdsaSignature 539 | rest, err := asn1.Unmarshal(signature, &sig) 540 | if err != nil || len(rest) != 0 { 541 | return false 542 | } 543 | 544 | ecPub := &ecdsa.PublicKey{Curve: curve, X: x, Y: y} 545 | return ecdsa.Verify(ecPub, digest, sig.R, sig.S) 546 | 547 | case ECDSA_SECP521R1_SHA512: 548 | h := sha512.New() 549 | h.Write(message) 550 | digest := h.Sum(nil) 551 | /* Reject undecodable public keys and signatures with trailing bytes, as in the P-256 branch. */ 552 | curve := elliptic.P521() 553 | x, y := elliptic.Unmarshal(curve, pub.Data) 554 | if x == nil { return false } 555 | var sig ecdsaSignature 556 | rest, err := asn1.Unmarshal(signature, &sig) 557 | if err != nil || len(rest) != 0 { 558 | return false
559 | } 560 | 561 | ecPub := &ecdsa.PublicKey{Curve: curve, X: x, Y: y} 562 | return ecdsa.Verify(ecPub, digest, sig.R, sig.S) 563 | 564 | case Ed25519: 565 | pub25519 := ed25519.PublicKey(pub.Data) 566 | return ed25519.Verify(pub25519, message, signature) 567 | } 568 | panic("Unsupported algorithm") 569 | } 570 | -------------------------------------------------------------------------------- /crypto_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "testing" 7 | 8 | "github.com/cisco/go-tls-syntax" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | /* Ciphersuites exercised by the table-driven tests in this file. */ 12 | var supportedSuites = []CipherSuite{ 13 | X25519_AES128GCM_SHA256_Ed25519, 14 | P256_AES128GCM_SHA256_P256, 15 | X25519_CHACHA20POLY1305_SHA256_Ed25519, 16 | P521_AES256GCM_SHA512_P521, 17 | } 18 | /* Signature schemes exercised by TestSignVerify: every scheme accepted by SignatureScheme.supported. */ 19 | var supportedSchemes = []SignatureScheme{ 20 | ECDSA_SECP256R1_SHA256, ECDSA_SECP521R1_SHA512, 21 | Ed25519, 22 | } 23 | /* randomBytes returns size bytes from crypto/rand; the Read error is ignored in this test helper. */ 24 | func randomBytes(size int) []byte { 25 | out := make([]byte, size) 26 | rand.Read(out) 27 | return out 28 | } 29 | /* TestDigest checks Digest against the SHA-256 / SHA-512 outputs for the 448-bit "abcdbcde...nopq" message. */ 30 | func TestDigest(t *testing.T) { 31 | in := unhex("6162636462636465636465666465666765666768666768696768696a68696a6b6" + 32 | "96a6b6c6a6b6c6d6b6c6d6e6c6d6e6f6d6e6f706e6f7071") 33 | out256 := unhex("248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1") 34 | out512 := unhex("204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c3359" + 35 | "6fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445") 36 | 37 | for _, suite := range supportedSuites { 38 | var out []byte 39 | switch suite { 40 | case X25519_AES128GCM_SHA256_Ed25519, P256_AES128GCM_SHA256_P256, 41 | X25519_CHACHA20POLY1305_SHA256_Ed25519: 42 | out = out256 43 | case P521_AES256GCM_SHA512_P521: 44 | out = out512 45 | } 46 | 47 | d := suite.Digest(in) 48 | require.Equal(t, out, d) 49 | } 50 | } 51 | 52 | func TestEncryptDecrypt(t *testing.T) { 53 | // AES-GCM 54 | //
https://tools.ietf.org/html/draft-mcgrew-gcm-test-01#section-4 55 | key128 := unhex("4c80cdefbb5d10da906ac73c3613a634") 56 | nonce128 := unhex("2e443b684956ed7e3b244cfe") 57 | aad128 := unhex("000043218765432100000000") 58 | pt128 := unhex("45000048699a000080114db7c0a80102c0a801010a9bf15638d3010000010000" + 59 | "00000000045f736970045f756470037369700963796265726369747902646b00" + 60 | "0021000101020201") 61 | ct128 := unhex("fecf537e729d5b07dc30df528dd22b768d1b98736696a6fd348509fa13ceac34" + 62 | "cfa2436f14a3f3cf65925bf1f4a13c5d15b21e1884f5ff6247aeabb786b93bce" + 63 | "61bc17d768fd9732459018148f6cbe722fd04796562dfdb4") 64 | /* AES-256-GCM vector (32-byte key), used for the P521 suite below; presumably from the same draft - TODO confirm source. */ 65 | key256 := unhex("abbccddef00112233445566778899aababbccddef00112233445566778899aab") 66 | nonce256 := unhex("112233440102030405060708") 67 | aad256 := unhex("4a2cbfe300000002") 68 | pt256 := unhex("4500003069a6400080062690c0a801029389155e0a9e008b2dc57ee000000000" + 69 | "7002400020bf0000020405b40101040201020201") 70 | ct256 := unhex("ff425c9b724599df7a3bcd510194e00d6a78107f1b0b1cbf06efae9d65a5d763" + 71 | "748a637985771d347f0545659f14e99def842d8eb335f4eecfdbf831824b4c49" + 72 | "15956c96") 73 | 74 | // From RFC 8439 75 | // https://tools.ietf.org/html/rfc8439#appendix-A.5 76 | keyChaCha := unhex("1c9240a5eb55d38af333888604f6b5f0473917c1402b80099dca5cbc207075c0") 77 | nonceChaCha := unhex("000000000102030405060708") 78 | aadChaCha := unhex("f33388860000000000004e91") 79 | ptChaCha := unhex("496e7465726e65742d4472616674732061726520647261667420646f63756d65" + 80 | "6e74732076616c696420666f722061206d6178696d756d206f6620736978206d" + 81 | "6f6e74687320616e64206d617920626520757064617465642c207265706c6163" + 82 | "65642c206f72206f62736f6c65746564206279206f7468657220646f63756d65" + 83 | "6e747320617420616e792074696d652e20497420697320696e617070726f7072" + 84 | "6961746520746f2075736520496e7465726e65742d4472616674732061732072" + 85 | "65666572656e6365206d6174657269616c206f7220746f206369746520746865" + 86 |
"6d206f74686572207468616e206173202fe2809c776f726b20696e2070726f67" + 87 | "726573732e2fe2809d") 88 | ctChaCha := unhex("64a0861575861af460f062c79be643bd5e805cfd345cf389f108670ac76c8cb2" + 89 | "4c6cfc18755d43eea09ee94e382d26b0bdb7b73c321b0100d4f03b7f355894cf" + 90 | "332f830e710b97ce98c8a84abd0b948114ad176e008d33bd60f982b1ff37c855" + 91 | "9797a06ef4f0ef61c186324e2b3506383606907b6a7c02b0f9f6157b53c867e4" + 92 | "b9166c767b804d46a59b5216cde7a4e99040c5a40433225ee282a1b0a06c523e" + 93 | "af4534d7f83fa1155b0047718cbc546a0d072b04b3564eea1b422273f548271a" + 94 | "0bb2316053fa76991955ebd63159434ecebb4e466dae5a1073a6727627097a10" + 95 | "49e617d91d361094fa68f0ff77987130305beaba2eda04df997b714d6c6f2c29" + 96 | "a6ad5cb4022b02709beead9d67890cbb22392336fea1851f38") 97 | 98 | encryptDecrypt := func(suite CipherSuite) func(t *testing.T) { 99 | return func(t *testing.T) { 100 | var key, nonce, aad, pt, ct []byte 101 | switch suite { 102 | case X25519_AES128GCM_SHA256_Ed25519, P256_AES128GCM_SHA256_P256: 103 | key, nonce, aad, pt, ct = key128, nonce128, aad128, pt128, ct128 104 | case X25519_CHACHA20POLY1305_SHA256_Ed25519: 105 | key, nonce, aad, pt, ct = keyChaCha, nonceChaCha, aadChaCha, ptChaCha, ctChaCha 106 | case P521_AES256GCM_SHA512_P521: 107 | key, nonce, aad, pt, ct = key256, nonce256, aad256, pt256, ct256 108 | } 109 | 110 | aead, err := suite.NewAEAD(key) 111 | require.Nil(t, err) 112 | 113 | // Test encryption 114 | encrypted := aead.Seal(nil, nonce, pt, aad) 115 | require.Equal(t, ct, encrypted) 116 | 117 | // Test decryption 118 | decrypted, err := aead.Open(nil, nonce, ct, aad) 119 | require.Nil(t, err) 120 | require.Equal(t, pt, decrypted) 121 | } 122 | } 123 | 124 | for _, suite := range supportedSuites { 125 | t.Run("todo" /*suite.String()*/, encryptDecrypt(suite)) 126 | } 127 | } 128 | 129 | func TestHPKE(t *testing.T) { 130 | aad := []byte("doo-bee-doo") 131 | original := []byte("Attack at dawn!") 132 | seed := []byte("All the flowers of tomorrow are in the 
seeds of today") 133 | /* encryptDecrypt returns a subtest that derives a key pair from seed and round-trips original through HPKE Encrypt/Decrypt. NOTE(review): the Generate key is only checked for an error before being overwritten by Derive. */ 134 | encryptDecrypt := func(suite CipherSuite) func(t *testing.T) { 135 | return func(t *testing.T) { 136 | priv, err := suite.hpke().Generate() 137 | require.Nil(t, err) 138 | 139 | priv, err = suite.hpke().Derive(seed) 140 | require.Nil(t, err) 141 | 142 | encrypted, err := suite.hpke().Encrypt(priv.PublicKey, aad, original) 143 | require.Nil(t, err) 144 | 145 | decrypted, err := suite.hpke().Decrypt(priv, aad, encrypted) 146 | require.Nil(t, err) 147 | require.Equal(t, original, decrypted) 148 | } 149 | } 150 | 151 | for _, suite := range supportedSuites { 152 | t.Run(suite.String(), encryptDecrypt(suite)) 153 | } 154 | } 155 | /* TestSignVerify signs with a key derived from seed and verifies the signature, for each scheme in supportedSchemes. */ 156 | func TestSignVerify(t *testing.T) { 157 | message := []byte("I promise Suhas five dollars") 158 | seed := []byte("All the flowers of tomorrow are in the seeds of today") 159 | 160 | signVerify := func(scheme SignatureScheme) func(t *testing.T) { 161 | return func(t *testing.T) { 162 | priv, err := scheme.Generate() 163 | require.Nil(t, err) 164 | 165 | priv, err = scheme.Derive(seed) 166 | require.Nil(t, err) 167 | 168 | signature, err := scheme.Sign(&priv, message) 169 | require.Nil(t, err) 170 | 171 | verified := scheme.Verify(&priv.PublicKey, message, signature) 172 | require.True(t, verified) 173 | } 174 | } 175 | 176 | for _, scheme := range supportedSchemes { 177 | t.Run(scheme.String(), signVerify(scheme)) 178 | } 179 | } 180 | /* TestCipherSuite_String checks that every supported suite has a name and that an unassigned value maps to "UnknownCipherSuite". */ 181 | func TestCipherSuite_String(t *testing.T) { 182 | for _, suite := range supportedSuites { 183 | require.True(t, len(suite.String()) > 0) 184 | } 185 | 186 | var badCipherSuite CipherSuite = 0x0009 187 | require.Equal(t, "UnknownCipherSuite", badCipherSuite.String()) 188 | } 189 | 190 | /// 191 | /// Test Vectors 192 | /// 193 | /* CryptoTestCase holds one suite's expected outputs for the shared inputs in CryptoTestVectors. */ 194 | type CryptoTestCase struct { 195 | CipherSuite CipherSuite 196 | HKDFExtractOut []byte `tls:"head=1"` 197 | DeriveKeyPairPub HPKEPublicKey 198 | HPKEOut HPKECiphertext 199 | } 200 | /* CryptoTestVectors is the serialized test-vector file: shared inputs plus per-suite cases. */ 201 | type CryptoTestVectors struct { 202
HKDFExtractSalt []byte `tls:"head=1"` 203 | HKDFExtractIKM []byte `tls:"head=1"` 204 | DeriveKeyPairSeed []byte `tls:"head=1"` 205 | HPKEAAD []byte `tls:"head=1"` 206 | HPKEPlaintext []byte `tls:"head=1"` 207 | Cases []CryptoTestCase `tls:"head=4"` 208 | } 209 | 210 | func generateCryptoVectors(t *testing.T) []byte { 211 | tv := CryptoTestVectors{ 212 | HKDFExtractSalt: []byte{0, 1, 2, 3}, 213 | HKDFExtractIKM: []byte{4, 5, 6, 7}, 214 | DeriveKeyPairSeed: []byte{0, 1, 2, 3}, 215 | HPKEAAD: bytes.Repeat([]byte{0xB1}, 128), 216 | HPKEPlaintext: bytes.Repeat([]byte{0xB2}, 128), 217 | Cases: []CryptoTestCase{ 218 | {CipherSuite: X25519_AES128GCM_SHA256_Ed25519}, 219 | {CipherSuite: P256_AES128GCM_SHA256_P256}, 220 | }, 221 | } 222 | 223 | var err error 224 | for i := range tv.Cases { 225 | tc := &tv.Cases[i] 226 | 227 | tc.HKDFExtractOut = tc.CipherSuite.hkdfExtract(tv.HKDFExtractSalt, tv.HKDFExtractIKM) 228 | 229 | priv, err := tc.CipherSuite.hpke().Derive(tv.DeriveKeyPairSeed) 230 | tc.DeriveKeyPairPub = priv.PublicKey 231 | require.Nil(t, err) 232 | 233 | tc.HPKEOut, err = tc.CipherSuite.hpke().Encrypt(tc.DeriveKeyPairPub, tv.HPKEAAD, tv.HPKEPlaintext) 234 | require.Nil(t, err) 235 | } 236 | 237 | vec, err := syntax.Marshal(tv) 238 | require.Nil(t, err) 239 | return vec 240 | } 241 | 242 | func verifyCryptoVectors(t *testing.T, data []byte) { 243 | var tv CryptoTestVectors 244 | _, err := syntax.Unmarshal(data, &tv) 245 | require.Nil(t, err) 246 | 247 | for _, tc := range tv.Cases { 248 | hkdfExtractOut := tc.CipherSuite.hkdfExtract(tv.HKDFExtractSalt, tv.HKDFExtractIKM) 249 | require.Equal(t, hkdfExtractOut, tc.HKDFExtractOut) 250 | 251 | priv, err := tc.CipherSuite.hpke().Derive(tv.DeriveKeyPairSeed) 252 | require.Nil(t, err) 253 | require.Equal(t, priv.PublicKey.Data, tc.DeriveKeyPairPub.Data) 254 | 255 | plaintext, err := tc.CipherSuite.hpke().Decrypt(priv, tv.HPKEAAD, tc.HPKEOut) 256 | require.Nil(t, err) 257 | require.Equal(t, plaintext, tv.HPKEPlaintext) 258 
| } 259 | } 260 | -------------------------------------------------------------------------------- /extensions.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "fmt" 5 | 6 | syntax "github.com/cisco/go-tls-syntax" 7 | ) 8 | /* ExtensionType identifies the kind of payload carried in an Extension. */ 9 | type ExtensionType uint16 10 | 11 | const ( 12 | ExtensionTypeInvalid ExtensionType = 0x0000 13 | ExtensionTypeSupportedVersions ExtensionType = 0x0001 14 | ExtensionTypeSupportedCipherSuites ExtensionType = 0x0002 15 | ExtensionTypeLifetime ExtensionType = 0x0003 16 | ExtensionTypeKeyID ExtensionType = 0x0004 17 | ExtensionTypeParentHash ExtensionType = 0x0005 18 | ) 19 | /* ExtensionBody is implemented by each typed extension payload; Type reports the ExtensionType it is stored under. */ 20 | type ExtensionBody interface { 21 | Type() ExtensionType 22 | } 23 | /* Extension is the wire form of an extension: its type plus the TLS-marshaled body bytes. */ 24 | type Extension struct { 25 | ExtensionType ExtensionType 26 | ExtensionData []byte `tls:"head=2"` 27 | } 28 | /* ExtensionList is an ordered list of extensions; Add keeps at most one entry per type. */ 29 | type ExtensionList struct { 30 | Entries []Extension `tls:"head=2"` 31 | } 32 | /* NewExtensionList returns an empty list with a non-nil Entries slice. */ 33 | func NewExtensionList() ExtensionList { 34 | return ExtensionList{[]Extension{}} 35 | } 36 | /* Add marshals src and stores it, replacing an existing entry of the same type or appending a new one. */ 37 | func (el *ExtensionList) Add(src ExtensionBody) error { 38 | data, err := syntax.Marshal(src) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | // If one already exists with this type, replace it 44 | for i := range el.Entries { 45 | if el.Entries[i].ExtensionType == src.Type() { 46 | el.Entries[i].ExtensionData = data 47 | return nil 48 | } 49 | } 50 | 51 | // Otherwise append 52 | el.Entries = append(el.Entries, Extension{ 53 | ExtensionType: src.Type(), 54 | ExtensionData: data, 55 | }) 56 | return nil 57 | } 58 | /* Has reports whether an entry of type extType is present. */ 59 | func (el ExtensionList) Has(extType ExtensionType) bool { 60 | for _, ext := range el.Entries { 61 | if ext.ExtensionType == extType { 62 | return true 63 | } 64 | } 65 | return false 66 | } 67 | /* Find unmarshals the first entry whose type matches dst.Type() into dst. It returns (false, nil) when no entry matches, and (true, err) when the entry fails to parse or has unconsumed bytes. */ 68 | func (el ExtensionList) Find(dst ExtensionBody) (bool, error) { 69 | for _, ext := range el.Entries { 70 | if ext.ExtensionType == dst.Type() { 71 | read, err := syntax.Unmarshal(ext.ExtensionData, dst) 72 | if
err != nil { 73 | return true, err 74 | } 75 | 76 | if read != len(ext.ExtensionData) { 77 | return true, fmt.Errorf("Extension failed to consume all data") 78 | } 79 | 80 | return true, nil 81 | } 82 | } 83 | return false, nil 84 | } 85 | 86 | ////////// 87 | /* SupportedVersionsExtension lists supported protocol versions. */ 88 | type SupportedVersionsExtension struct { 89 | SupportedVersions []ProtocolVersion `tls:"head=1"` 90 | } 91 | /* Type implements ExtensionBody. */ 92 | func (sve SupportedVersionsExtension) Type() ExtensionType { 93 | return ExtensionTypeSupportedVersions 94 | } 95 | 96 | ////////// 97 | /* SupportedCipherSuitesExtension lists supported ciphersuites. */ 98 | type SupportedCipherSuitesExtension struct { 99 | SupportedCipherSuites []CipherSuite `tls:"head=1"` 100 | } 101 | /* Type implements ExtensionBody. */ 102 | func (sce SupportedCipherSuitesExtension) Type() ExtensionType { 103 | return ExtensionTypeSupportedCipherSuites 104 | } 105 | 106 | ////////// 107 | /* LifetimeExtension carries NotBefore/NotAfter bounds as uint64 values; units are not stated here (presumably seconds since the Unix epoch - TODO confirm). */ 108 | type LifetimeExtension struct { 109 | NotBefore uint64 110 | NotAfter uint64 111 | } 112 | /* Type implements ExtensionBody. */ 113 | func (lte LifetimeExtension) Type() ExtensionType { 114 | return ExtensionTypeLifetime 115 | } 116 | 117 | ////////// 118 | /* ParentHashExtension carries an opaque parent-hash value. */ 119 | type ParentHashExtension struct { 120 | ParentHash []byte `tls:"head=1"` 121 | } 122 | /* Type implements ExtensionBody. */ 123 | func (phe ParentHashExtension) Type() ExtensionType { 124 | return ExtensionTypeParentHash 125 | } 126 | -------------------------------------------------------------------------------- /extensions_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "testing" 5 | 6 | syntax "github.com/cisco/go-tls-syntax" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | /* ExtensionTypeTwoByte is a test-only extension type. */ 10 | const ExtensionTypeTwoByte ExtensionType = 0xffff 11 | /* TwoByteExtension is a fixed-size two-byte test extension body. */ 12 | type TwoByteExtension [2]byte 13 | /* Type implements ExtensionBody. */ 14 | func (ne TwoByteExtension) Type() ExtensionType { 15 | return ExtensionTypeTwoByte 16 | } 17 | /* TestExtensionList exercises Add, Has and Find on an ExtensionList. */ 18 | func TestExtensionList(t *testing.T) { 19 | // Add an extension to the list 20 | extBody1 := &TwoByteExtension{0xFF, 0xFE} 21 | extBody1Data := unhex("FFFE") 22 | el := NewExtensionList() 23 | err := el.Add(extBody1) 24 | require.Nil(t,
err) 25 | require.Equal(t, 1, len(el.Entries)) 26 | require.Equal(t, extBody1.Type(), el.Entries[0].ExtensionType) 27 | require.Equal(t, extBody1Data, el.Entries[0].ExtensionData) 28 | 29 | // Verify that Has() returns the expected values 30 | require.True(t, el.Has(ExtensionTypeTwoByte)) 31 | require.False(t, el.Has(ExtensionTypeSupportedVersions)) 32 | 33 | // Verify that adding again replaces the first 34 | extBody2 := &TwoByteExtension{0xFD, 0xFC} 35 | extBody2Data := unhex("FDFC") 36 | err = el.Add(extBody2) 37 | require.Nil(t, err) 38 | require.Equal(t, 1, len(el.Entries)) 39 | require.Equal(t, extBody2.Type(), el.Entries[0].ExtensionType) 40 | require.Equal(t, extBody2Data, el.Entries[0].ExtensionData) 41 | 42 | // Verify that the body can be retrieved 43 | extBody3 := new(TwoByteExtension) 44 | found, err := el.Find(extBody3) 45 | require.True(t, found) 46 | require.Nil(t, err) 47 | require.Equal(t, extBody2, extBody3) 48 | 49 | // Verify that an error is returned if the extension body doesn't consume all 50 | // of the data in the extension 51 | el.Entries[0].ExtensionData = append(el.Entries[0].ExtensionData, 0x00) 52 | found, err = el.Find(extBody3) 53 | require.True(t, found) 54 | require.Error(t, err) 55 | 56 | // Verify that unknown extensions are reported correctly 57 | extBody4 := new(ParentHashExtension) 58 | found, err = el.Find(extBody4) 59 | require.False(t, found) 60 | require.Nil(t, err) 61 | } 62 | /* ExtensionTestCase pairs an extension body with its expected type and TLS encoding. */ 63 | type ExtensionTestCase struct { 64 | extensionType ExtensionType 65 | blank ExtensionBody 66 | unmarshaled ExtensionBody 67 | marshaledHex string 68 | } 69 | /* run checks Type(), then round-trips the body through syntax.Marshal / syntax.Unmarshal against marshaledHex. */ 70 | func (etc ExtensionTestCase) run(t *testing.T) { 71 | marshaled := unhex(etc.marshaledHex) 72 | 73 | // Test extension type 74 | require.Equal(t, etc.extensionType, etc.unmarshaled.Type()) 75 | 76 | // Test successful marshal 77 | out, err := syntax.Marshal(etc.unmarshaled) 78 | require.Nil(t, err) 79 | require.Equal(t, marshaled, out) 80 | 81 | // Test successful unmarshal 82 | read,
err := syntax.Unmarshal(marshaled, etc.blank) 83 | require.Nil(t, err) 84 | require.Equal(t, etc.unmarshaled, etc.blank) 85 | require.Equal(t, len(marshaled), read) 86 | } 87 | /* lifetimeExtension is the body used by the "Lifetime" case below. */ 88 | var ( 89 | lifetimeExtension = LifetimeExtension{NotBefore: 0, NotAfter: 0xA0A0A0A0A0A0A0A0} 90 | ) 91 | /* validExtensionTestCases maps subtest names to encoding cases for each extension body type. */ 92 | var validExtensionTestCases = map[string]ExtensionTestCase{ 93 | "SupportedVersions": { 94 | extensionType: ExtensionTypeSupportedVersions, 95 | blank: new(SupportedVersionsExtension), 96 | unmarshaled: &SupportedVersionsExtension{[]ProtocolVersion{ProtocolVersionMLS10}}, 97 | marshaledHex: "0100", 98 | }, 99 | "SupportedCiphersuites": { 100 | extensionType: ExtensionTypeSupportedCipherSuites, 101 | blank: new(SupportedCipherSuitesExtension), 102 | unmarshaled: &SupportedCipherSuitesExtension{[]CipherSuite{ 103 | X25519_AES128GCM_SHA256_Ed25519, 104 | P256_AES128GCM_SHA256_P256, 105 | }}, 106 | marshaledHex: "0400010002", 107 | }, 108 | "Lifetime": { 109 | extensionType: ExtensionTypeLifetime, 110 | blank: new(LifetimeExtension), 111 | unmarshaled: &lifetimeExtension, 112 | marshaledHex: "0000000000000000a0a0a0a0a0a0a0a0", 113 | }, 114 | "ParentHash": { 115 | extensionType: ExtensionTypeParentHash, 116 | blank: new(ParentHashExtension), 117 | unmarshaled: &ParentHashExtension{[]byte{0x00, 0x01, 0x02, 0x03}}, 118 | marshaledHex: "0400010203", 119 | }, 120 | } 121 | /* TestExtensionBodyMarshalUnmarshal runs every case in validExtensionTestCases as a subtest. */ 122 | func TestExtensionBodyMarshalUnmarshal(t *testing.T) { 123 | for name, test := range validExtensionTestCases { 124 | t.Run(name, test.run) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cisco/go-mls 2 | 3 | go 1.14 4 | 5 | require ( 6 | git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6 // indirect 7 | github.com/cisco/go-hpke v0.0.0-20200603153819-0a6c8374cd9a 8 | github.com/cisco/go-tls-syntax
v0.0.0-20200615170901-cc95af012391 9 | github.com/cloudflare/circl v1.0.0 // indirect 10 | github.com/stretchr/testify v1.6.1 11 | golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 12 | ) 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6 h1:w8IZgCntCe0RuBJp+dENSMwEBl/k8saTgJ5hPca5IWw= 2 | git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6/go.mod h1:wQaGCqEu44ykB17jZHCevrgSVl3KJnwQBObUtrKU4uU= 3 | github.com/cisco/go-hpke v0.0.0-20200603153819-0a6c8374cd9a h1:avwcoMq3mm7ACKdjsMooUWHPFuVrTc8Q47ZDSGP6GOo= 4 | github.com/cisco/go-hpke v0.0.0-20200603153819-0a6c8374cd9a/go.mod h1:7ykSQZaBVJLIRoJ7OMiJgpdOD74cTHdXRo6XPMIfu20= 5 | github.com/cisco/go-tls-syntax v0.0.0-20200615170901-cc95af012391 h1:psZtmcKE1XNc9SbeTfZTd530f+cS87x2bqI+QbVVEVw= 6 | github.com/cisco/go-tls-syntax v0.0.0-20200615170901-cc95af012391/go.mod h1:KoUJMVoZOKaVsiKsMwnZD0Y5jSUawe3/QHYrwOvld3k= 7 | github.com/cloudflare/circl v1.0.0 h1:64b6pyfCFbYm623ncIkYGNZaOcmIbyd+CjyMi2L9vdI= 8 | github.com/cloudflare/circl v1.0.0/go.mod h1:MhjB3NEEhJbTOdLLq964NIUisXDxaE1WkQPUxtgZXiY= 9 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 10 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 14 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= 15 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 16 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod 
h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 17 | golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 h1:vEg9joUBmeBcK9iSJftGNf3coIG4HqZElCPehJsfAYM= 18 | golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 19 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 20 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 21 | golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= 22 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 23 | golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed h1:uPxWBzB3+mlnjy9W58qY1j/cjyFjutgw/Vhan2zLy/A= 24 | golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 25 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 27 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 28 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | -------------------------------------------------------------------------------- /key-schedule.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/cisco/go-tls-syntax" 7 | ) 8 | /* keyAndNonce is an AEAD key and nonce pair. */ 9 | type keyAndNonce struct { 10 | Key []byte `tls:"head=1"` 11 | Nonce []byte `tls:"head=1"` 12 | } 13 | /* clone returns a copy of k whose Key and Nonce buffers are not shared with k. */ 14 | func (k keyAndNonce) clone() keyAndNonce { 15 | return keyAndNonce{ 16 | Key: dup(k.Key), 17 | Nonce: dup(k.Nonce), 18 | } 19 | } 20 | /* zeroize overwrites every byte of data with zero, in place. */ 21 | func zeroize(data []byte) { 22 | for i := range data { 23 | data[i] = 0 24 | } 25 | } 26 | 27 | /// 28 | /// Hash ratchet 29
| /// 30 | /* hashRatchet derives successive per-generation key/nonce pairs for one node from a chain of secrets, caching each pair until Erase removes it. */ 31 | type hashRatchet struct { 32 | Suite CipherSuite 33 | Node NodeIndex 34 | NextSecret []byte `tls:"head=1"` 35 | NextGeneration uint32 36 | Cache map[uint32]keyAndNonce `tls:"head=4"` 37 | KeySize uint32 38 | NonceSize uint32 39 | SecretSize uint32 40 | } 41 | /* newHashRatchet starts a ratchet at generation 0. It takes ownership of baseSecret: the first Next call zeroizes that buffer. */ 42 | func newHashRatchet(suite CipherSuite, node NodeIndex, baseSecret []byte) *hashRatchet { 43 | return &hashRatchet{ 44 | Suite: suite, 45 | Node: node, 46 | NextSecret: baseSecret, 47 | NextGeneration: 0, 48 | Cache: map[uint32]keyAndNonce{}, 49 | KeySize: uint32(suite.Constants().KeySize), 50 | NonceSize: uint32(suite.Constants().NonceSize), 51 | SecretSize: uint32(suite.Constants().SecretSize), 52 | } 53 | } 54 | /* Next derives the key, nonce and next chain secret for the current generation, wipes the old chain secret, caches the pair, and returns the generation with a copy of the pair. */ 55 | func (hr *hashRatchet) Next() (uint32, keyAndNonce) { 56 | key := hr.Suite.deriveAppSecret(hr.NextSecret, "app-key", hr.Node, hr.NextGeneration, int(hr.KeySize)) 57 | nonce := hr.Suite.deriveAppSecret(hr.NextSecret, "app-nonce", hr.Node, hr.NextGeneration, int(hr.NonceSize)) 58 | secret := hr.Suite.deriveAppSecret(hr.NextSecret, "app-secret", hr.Node, hr.NextGeneration, int(hr.SecretSize)) 59 | 60 | generation := hr.NextGeneration 61 | 62 | hr.NextGeneration += 1 63 | zeroize(hr.NextSecret) 64 | hr.NextSecret = secret 65 | 66 | kn := keyAndNonce{key, nonce} 67 | hr.Cache[generation] = kn 68 | return generation, kn.clone() 69 | } 70 | /* Get returns a copy of the key/nonce for generation, ratcheting forward (caching every intermediate generation) if it has not been reached yet. Earlier generations that are no longer cached are rejected as expired. */ 71 | func (hr *hashRatchet) Get(generation uint32) (keyAndNonce, error) { 72 | if kn, ok := hr.Cache[generation]; ok { 73 | return kn.clone(), nil 74 | } 75 | /* Returning a clone, as Next does, keeps callers from aliasing cache buffers that Erase later zeroizes. */ 76 | if hr.NextGeneration > generation { 77 | return keyAndNonce{}, fmt.Errorf("Request for expired key") 78 | } 79 | 80 | for hr.NextGeneration < generation { 81 | hr.Next() 82 | } 83 | 84 | _, kn := hr.Next() 85 | return kn, nil 86 | } 87 | /* Erase zeroizes and drops the cached key/nonce for generation, if present. */ 88 | func (hr *hashRatchet) Erase(generation uint32) { 89 | if _, ok := hr.Cache[generation]; !ok { 90 | return 91 | } 92 | 93 | zeroize(hr.Cache[generation].Key) 94 | zeroize(hr.Cache[generation].Nonce) 95 | delete(hr.Cache, generation) 96 | } 97 | 98 | ///
99 | /// Base key sources 100 | /// 101 | /* baseKeySource yields the per-sender base secret that seeds a hashRatchet. */ 102 | type baseKeySource interface { 103 | Suite() CipherSuite 104 | Get(sender LeafIndex) []byte 105 | } 106 | /* noFSBaseKeySource derives each sender's base secret directly from RootSecret. RootSecret is never erased here, so base secrets can be re-derived (presumably the "no forward secrecy" the name refers to). */ 107 | type noFSBaseKeySource struct { 108 | CipherSuite CipherSuite 109 | RootSecret []byte `tls:"head=1"` 110 | } 111 | /* newNoFSBaseKeySource stores rootSecret without copying it. */ 112 | func newNoFSBaseKeySource(suite CipherSuite, rootSecret []byte) *noFSBaseKeySource { 113 | return &noFSBaseKeySource{suite, rootSecret} 114 | } 115 | /* Suite implements baseKeySource. */ 116 | func (nfbks *noFSBaseKeySource) Suite() CipherSuite { 117 | return nfbks.CipherSuite 118 | } 119 | /* Get derives sender's base secret from RootSecret under the "hs-secret" label at generation 0; repeated calls return the same value. */ 120 | func (nfbks *noFSBaseKeySource) Get(sender LeafIndex) []byte { 121 | secretSize := nfbks.CipherSuite.Constants().SecretSize 122 | return nfbks.CipherSuite.deriveAppSecret(nfbks.RootSecret, "hs-secret", toNodeIndex(sender), 0, secretSize) 123 | } 124 | /* Bytes1 is a byte string serialized as a head=1 opaque vector; it is the map value type of treeBaseKeySource.Secrets. */ 125 | type Bytes1 []byte 126 | /* MarshalTLS encodes b as a head=1 opaque vector. */ 127 | func (b Bytes1) MarshalTLS() ([]byte, error) { 128 | return syntax.Marshal(struct { 129 | Data []byte `tls:"head=1"` 130 | }{b}) 131 | } 132 | /* UnmarshalTLS decodes a head=1 opaque vector into a fresh copy and reports the bytes read. */ 133 | func (b *Bytes1) UnmarshalTLS(data []byte) (int, error) { 134 | tmp := struct { 135 | Data []byte `tls:"head=1"` 136 | }{} 137 | read, err := syntax.Unmarshal(data, &tmp) 138 | if err != nil { 139 | return read, err 140 | } 141 | 142 | *b = dup(tmp.Data) 143 | return read, nil 144 | } 145 | /* treeBaseKeySource derives sender base secrets down a binary tree from a root secret, erasing each node's secret once its children have been derived. */ 146 | type treeBaseKeySource struct { 147 | CipherSuite CipherSuite 148 | SecretSize uint32 149 | Root NodeIndex 150 | Size LeafCount 151 | Secrets map[NodeIndex]Bytes1 `tls:"head=4"` 152 | } 153 | /* newTreeBaseKeySource seeds the tree with rootSecret at root(size). NOTE(review): rootSecret is stored without a copy and Get zeroizes consumed node secrets, so the caller's buffer is wiped on first use - confirm callers do not rely on it afterwards. */ 154 | func newTreeBaseKeySource(suite CipherSuite, size LeafCount, rootSecret []byte) *treeBaseKeySource { 155 | tbks := &treeBaseKeySource{ 156 | CipherSuite: suite, 157 | SecretSize: uint32(suite.Constants().SecretSize), 158 | Root: root(size), 159 | Size: size, 160 | Secrets: map[NodeIndex]Bytes1{}, 161 | } 162 | 163 | tbks.Secrets[tbks.Root] = rootSecret 164 | return tbks 165 | } 166 | /* Suite implements baseKeySource. */ 167 | func (tbks *treeBaseKeySource) Suite() CipherSuite { 168 | return tbks.CipherSuite 169 | } 170 | /* Get returns sender's leaf secret, deriving it from the nearest populated ancestor. Each leaf secret is handed out once; a repeated request for the same sender panics because no source remains. */ 171 | func (tbks
*treeBaseKeySource) Get(sender LeafIndex) []byte { 172 | // Find an ancestor that is populated 173 | senderNode := toNodeIndex(sender) 174 | d := dirpath(senderNode, tbks.Size) 175 | d = append([]NodeIndex{senderNode}, d...) 176 | found := false 177 | curr := 0 178 | for i, node := range d { 179 | if _, ok := tbks.Secrets[node]; ok { 180 | found = true 181 | curr = i 182 | break 183 | } 184 | } 185 | 186 | if !found { 187 | panic("Unable to find source for base key") 188 | } 189 | 190 | // Derive down 191 | for ; curr > 0; curr -= 1 { 192 | node := d[curr] 193 | L := left(node) 194 | R := right(node, tbks.Size) 195 | 196 | secret := tbks.Secrets[node] 197 | tbks.Secrets[L] = tbks.CipherSuite.deriveAppSecret(secret, "tree", L, 0, int(tbks.SecretSize)) 198 | tbks.Secrets[R] = tbks.CipherSuite.deriveAppSecret(secret, "tree", R, 0, int(tbks.SecretSize)) 199 | zeroize(tbks.Secrets[node]) 200 | delete(tbks.Secrets, node) 201 | } 202 | 203 | // Copy and return the leaf 204 | out := dup(tbks.Secrets[senderNode]) 205 | zeroize(tbks.Secrets[senderNode]) 206 | delete(tbks.Secrets, senderNode) 207 | return out 208 | } 209 | 210 | func (tbks *treeBaseKeySource) dump() { 211 | w := nodeWidth(tbks.Size) 212 | fmt.Println("=== tbks ===") 213 | for i := NodeIndex(0); i < NodeIndex(w); i += 1 { 214 | s, ok := tbks.Secrets[i] 215 | if ok { 216 | fmt.Printf(" %3x [%x]\n", i, s) 217 | } else { 218 | fmt.Printf(" %3x _\n", i) 219 | } 220 | } 221 | } 222 | 223 | /// 224 | /// Group key source 225 | /// 226 | 227 | type groupKeySource struct { 228 | Base baseKeySource 229 | Ratchets map[LeafIndex]*hashRatchet 230 | } 231 | 232 | func (gks groupKeySource) ratchet(sender LeafIndex) *hashRatchet { 233 | if r, ok := gks.Ratchets[sender]; ok { 234 | return r 235 | } 236 | 237 | baseSecret := gks.Base.Get(sender) 238 | gks.Ratchets[sender] = newHashRatchet(gks.Base.Suite(), toNodeIndex(sender), baseSecret) 239 | return gks.Ratchets[sender] 240 | } 241 | 242 | func (gks groupKeySource) 
Next(sender LeafIndex) (uint32, keyAndNonce) { 243 | return gks.ratchet(sender).Next() 244 | } 245 | 246 | func (gks groupKeySource) Get(sender LeafIndex, generation uint32) (keyAndNonce, error) { 247 | return gks.ratchet(sender).Get(generation) 248 | } 249 | 250 | func (gks groupKeySource) Erase(sender LeafIndex, generation uint32) { 251 | gks.ratchet(sender).Erase(generation) 252 | } 253 | 254 | /// 255 | /// GroupInfo keys 256 | /// 257 | 258 | func groupInfoKeyAndNonce(suite CipherSuite, epochSecret []byte) keyAndNonce { 259 | secretSize := suite.Constants().SecretSize 260 | keySize := suite.Constants().KeySize 261 | nonceSize := suite.Constants().NonceSize 262 | 263 | groupInfoSecret := suite.hkdfExpandLabel(epochSecret, "group info", []byte{}, secretSize) 264 | groupInfoKey := suite.hkdfExpandLabel(groupInfoSecret, "key", []byte{}, keySize) 265 | groupInfoNonce := suite.hkdfExpandLabel(groupInfoSecret, "nonce", []byte{}, nonceSize) 266 | 267 | return keyAndNonce{ 268 | Key: groupInfoKey, 269 | Nonce: groupInfoNonce, 270 | } 271 | } 272 | 273 | /// 274 | /// Key schedule epoch 275 | /// 276 | 277 | type keyScheduleEpoch struct { 278 | Suite CipherSuite 279 | GroupContext []byte `tls:"head=1"` 280 | 281 | EpochSecret []byte `tls:"head=1"` 282 | SenderDataSecret []byte `tls:"head=1"` 283 | SenderDataKey []byte `tls:"head=1"` 284 | HandshakeSecret []byte `tls:"head=1"` 285 | ApplicationSecret []byte `tls:"head=1"` 286 | ExporterSecret []byte `tls:"head=1"` 287 | ConfirmationKey []byte `tls:"head=1"` 288 | InitSecret []byte `tls:"head=1"` 289 | 290 | HandshakeBaseKeys *noFSBaseKeySource 291 | ApplicationBaseKeys *treeBaseKeySource 292 | 293 | HandshakeRatchets map[LeafIndex]*hashRatchet `tls:"head=4"` 294 | ApplicationRatchets map[LeafIndex]*hashRatchet `tls:"head=4"` 295 | 296 | ApplicationKeys *groupKeySource `tls:"omit"` 297 | HandshakeKeys *groupKeySource `tls:"omit"` 298 | } 299 | 300 | func newKeyScheduleEpoch(suite CipherSuite, size LeafCount, epochSecret, 
context []byte) keyScheduleEpoch { 301 | senderDataSecret := suite.deriveSecret(epochSecret, "sender data", context) 302 | handshakeSecret := suite.deriveSecret(epochSecret, "handshake", context) 303 | applicationSecret := suite.deriveSecret(epochSecret, "app", context) 304 | exporterSecret := suite.deriveSecret(epochSecret, "exporter", context) 305 | confirmationKey := suite.deriveSecret(epochSecret, "confirm", context) 306 | initSecret := suite.deriveSecret(epochSecret, "init", context) 307 | 308 | senderDataKey := suite.hkdfExpandLabel(senderDataSecret, "sd key", []byte{}, suite.Constants().KeySize) 309 | handshakeBaseKeys := newNoFSBaseKeySource(suite, handshakeSecret) 310 | applicationBaseKeys := newTreeBaseKeySource(suite, size, applicationSecret) 311 | 312 | kse := keyScheduleEpoch{ 313 | Suite: suite, 314 | GroupContext: context, 315 | 316 | EpochSecret: epochSecret, 317 | SenderDataSecret: senderDataSecret, 318 | SenderDataKey: senderDataKey, 319 | HandshakeSecret: handshakeSecret, 320 | ApplicationSecret: applicationSecret, 321 | ExporterSecret: exporterSecret, 322 | ConfirmationKey: confirmationKey, 323 | InitSecret: initSecret, 324 | 325 | HandshakeBaseKeys: handshakeBaseKeys, 326 | ApplicationBaseKeys: applicationBaseKeys, 327 | 328 | HandshakeRatchets: map[LeafIndex]*hashRatchet{}, 329 | ApplicationRatchets: map[LeafIndex]*hashRatchet{}, 330 | } 331 | 332 | kse.enableKeySources() 333 | return kse 334 | } 335 | 336 | // Wire up the key sources as logic on top of data owned by the epoch 337 | func (kse *keyScheduleEpoch) enableKeySources() { 338 | kse.HandshakeKeys = &groupKeySource{kse.HandshakeBaseKeys, kse.HandshakeRatchets} 339 | kse.ApplicationKeys = &groupKeySource{kse.ApplicationBaseKeys, kse.ApplicationRatchets} 340 | } 341 | 342 | func (kse *keyScheduleEpoch) Next(size LeafCount, pskIn, commitSecret, context []byte) keyScheduleEpoch { 343 | psk := pskIn 344 | if len(psk) == 0 { 345 | psk = kse.Suite.zero() 346 | } 347 | 348 | earlySecret := 
kse.Suite.hkdfExtract(psk, kse.InitSecret) 349 | preEpochSecret := kse.Suite.deriveSecret(earlySecret, "derived", context) 350 | epochSecret := kse.Suite.hkdfExtract(commitSecret, preEpochSecret) 351 | return newKeyScheduleEpoch(kse.Suite, size, epochSecret, context) 352 | } 353 | 354 | func (kse *keyScheduleEpoch) Export(label string, context []byte, keyLength int) []byte { 355 | exporterBase := kse.Suite.deriveSecret(kse.ExporterSecret, label, kse.GroupContext) 356 | hctx := kse.Suite.Digest(context) 357 | return kse.Suite.hkdfExpandLabel(exporterBase, "exporter", hctx, keyLength) 358 | } 359 | -------------------------------------------------------------------------------- /key-schedule_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/cisco/go-tls-syntax" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // XXX(rlb): Uncomment this to see a graphical illustration of how the 13 | // tree-based key derivation works 14 | /* 15 | func TestTreeBaseKeySource(t *testing.T) { 16 | size := LeafCount(11) 17 | root := unhex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") 18 | tbks := newTreeBaseKeySource(P256_SHA256_AES128GCM, size, root) 19 | for i := LeafIndex(0); i < LeafIndex(size); i += 1 { 20 | tbks.Get(i) 21 | tbks.dump() 22 | } 23 | } 24 | */ 25 | 26 | // XXX(rlb): This is a very loose check, just exercising the code and verifying 27 | // that it doesnt panic and produces outputs that are the right size. We should 28 | // do actual interop testing. There's not much between here and there. 
29 | func TestKeySchedule(t *testing.T) { 30 | suite := P256_AES128GCM_SHA256_P256 31 | secretSize := suite.Constants().SecretSize 32 | keySize := suite.Constants().KeySize 33 | nonceSize := suite.Constants().NonceSize 34 | 35 | size1 := LeafCount(5) 36 | epochSecret1 := unhex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") 37 | context1 := []byte("first") 38 | 39 | size2 := LeafCount(11) 40 | psk2 := []byte("psk") 41 | commitSecret2 := unhex("404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f") 42 | context2 := []byte("second") 43 | 44 | exportSize := 128 45 | targetGeneration := uint32(3) 46 | 47 | checkEpoch := func(epoch *keyScheduleEpoch, size LeafCount) { 48 | require.Equal(t, epoch.Suite, suite) 49 | require.Equal(t, len(epoch.EpochSecret), secretSize) 50 | require.Equal(t, len(epoch.SenderDataSecret), secretSize) 51 | require.Equal(t, len(epoch.SenderDataKey), keySize) 52 | require.Equal(t, len(epoch.HandshakeSecret), secretSize) 53 | require.Equal(t, len(epoch.ApplicationSecret), secretSize) 54 | require.Equal(t, len(epoch.ExporterSecret), secretSize) 55 | require.Equal(t, len(epoch.ConfirmationKey), secretSize) 56 | require.Equal(t, len(epoch.InitSecret), secretSize) 57 | require.NotNil(t, epoch.HandshakeKeys) 58 | require.NotNil(t, epoch.HandshakeKeys) 59 | 60 | exportedKey := epoch.Export("test", []byte{0, 1, 2, 3}, exportSize) 61 | require.Equal(t, len(exportedKey), exportSize) 62 | 63 | for i := LeafIndex(0); i < LeafIndex(size); i += 1 { 64 | // Test successful generation 65 | hs, err := epoch.HandshakeKeys.Get(i, targetGeneration) 66 | require.Nil(t, err) 67 | require.Equal(t, len(hs.Key), keySize) 68 | require.Equal(t, len(hs.Nonce), nonceSize) 69 | 70 | app, err := epoch.ApplicationKeys.Get(i, targetGeneration) 71 | require.Nil(t, err) 72 | require.Equal(t, len(app.Key), keySize) 73 | require.Equal(t, len(app.Nonce), nonceSize) 74 | 75 | epoch.HandshakeKeys.Erase(i, targetGeneration) 76 | 
epoch.ApplicationKeys.Erase(i, targetGeneration) 77 | 78 | // Test forward secrecy 79 | _, err = epoch.HandshakeKeys.Get(i, targetGeneration) 80 | require.Error(t, err) 81 | 82 | _, err = epoch.ApplicationKeys.Get(i, targetGeneration) 83 | require.Error(t, err) 84 | } 85 | } 86 | 87 | epoch1 := newKeyScheduleEpoch(suite, size1, epochSecret1, context1) 88 | checkEpoch(&epoch1, size1) 89 | 90 | epoch2 := epoch1.Next(size2, psk2, commitSecret2, context2) 91 | checkEpoch(&epoch2, size2) 92 | 93 | // Check that marshal/unmarshal works 94 | epoch2m, err := syntax.Marshal(epoch2) 95 | require.Nil(t, err) 96 | 97 | var epoch2u keyScheduleEpoch 98 | _, err = syntax.Unmarshal(epoch2m, &epoch2u) 99 | require.Nil(t, err) 100 | 101 | epoch2u.enableKeySources() 102 | 103 | // Verify that the contents match (not the group key generators) 104 | require.Equal(t, epoch2.Suite, epoch2u.Suite) 105 | require.Equal(t, epoch2.EpochSecret, epoch2u.EpochSecret) 106 | require.Equal(t, epoch2.SenderDataSecret, epoch2u.SenderDataSecret) 107 | require.Equal(t, epoch2.SenderDataKey, epoch2u.SenderDataKey) 108 | require.Equal(t, epoch2.HandshakeSecret, epoch2u.HandshakeSecret) 109 | require.Equal(t, epoch2.ApplicationSecret, epoch2u.ApplicationSecret) 110 | require.Equal(t, epoch2.ConfirmationKey, epoch2u.ConfirmationKey) 111 | require.Equal(t, epoch2.InitSecret, epoch2u.InitSecret) 112 | require.Equal(t, epoch2.HandshakeBaseKeys, epoch2u.HandshakeBaseKeys) 113 | require.Equal(t, epoch2.ApplicationBaseKeys, epoch2u.ApplicationBaseKeys) 114 | require.Equal(t, epoch2.HandshakeRatchets, epoch2u.HandshakeRatchets) 115 | require.Equal(t, epoch2.ApplicationRatchets, epoch2u.ApplicationRatchets) 116 | 117 | // Verify that we can't get a key for the target generation (because it's 118 | // already consumed) 119 | _, err = epoch2u.HandshakeKeys.Get(0, targetGeneration) 120 | require.Error(t, err) 121 | 122 | // Verify that we can get one for the next epoch, and it's the same as the 123 | // original key 
schedule would have produced 124 | _, err = epoch2u.HandshakeKeys.Get(0, targetGeneration+1) 125 | require.Nil(t, err) 126 | } 127 | 128 | /// 129 | /// Vectors 130 | /// 131 | 132 | type KsEpoch struct { 133 | NumMembers LeafCount 134 | PSK []byte `tls:"head=1"` 135 | CommitSecret []byte `tls:"head=1"` 136 | 137 | EpochSecret []byte `tls:"head=1"` 138 | SenderDataSecret []byte `tls:"head=1"` 139 | SenderDataKey []byte `tls:"head=1"` 140 | HandshakeSecret []byte `tls:"head=1"` 141 | HandshakeKeys []keyAndNonce `tls:"head=4"` 142 | AppSecret []byte `tls:"head=1"` 143 | AppKeys []keyAndNonce `tls:"head=4"` 144 | ExporterSecret []byte `tls:"head=1"` 145 | ExportedSecret []byte `tls:"head=1"` 146 | ConfirmationKey []byte `tls:"head=1"` 147 | InitSecret []byte `tls:"head=1"` 148 | } 149 | 150 | type KsTestCase struct { 151 | CipherSuite CipherSuite 152 | Epochs []KsEpoch `tls:"head=2"` 153 | } 154 | 155 | type KsTestVectors struct { 156 | NumEpochs uint32 157 | TargetGeneration uint32 158 | ExportLabel []byte `tls:"head=1"` 159 | ExportContext []byte `tls:"head=1"` 160 | ExportSize uint32 161 | BaseInitSecret []byte `tls:"head=1"` 162 | BaseGroupContext []byte `tls:"head=4"` 163 | Cases []KsTestCase `tls:"head=4"` 164 | } 165 | 166 | /// Gen and Verify 167 | func generateKeyScheduleVectors(t *testing.T) []byte { 168 | var tv KsTestVectors 169 | suites := []CipherSuite{P256_AES128GCM_SHA256_P256} 170 | baseGrpCtx := GroupContext{ 171 | GroupID: []byte{0xA0, 0xA0, 0xA0, 0xA0}, 172 | Epoch: 0, 173 | TreeHash: bytes.Repeat([]byte{0xA1}, 32), 174 | ConfirmedTranscriptHash: bytes.Repeat([]byte{0xA2}, 32), 175 | } 176 | 177 | encCtx, err := syntax.Marshal(baseGrpCtx) 178 | require.Nil(t, err) 179 | tv.NumEpochs = 50 180 | tv.TargetGeneration = 3 181 | tv.ExportLabel = []byte("exportLabel") 182 | tv.ExportContext = []byte("exportContext") 183 | tv.ExportSize = 24 184 | tv.BaseInitSecret = bytes.Repeat([]byte{0xA3}, 32) 185 | tv.BaseGroupContext = encCtx 186 | 187 | for _, suite 
:= range suites { 188 | var tc KsTestCase 189 | tc.CipherSuite = suite 190 | // start with the base context for epoch0 191 | grpCtx := baseGrpCtx 192 | minMembers := 5 193 | maxMembers := 20 194 | nMembers := minMembers 195 | 196 | var epoch keyScheduleEpoch 197 | epoch.Suite = suite 198 | epoch.InitSecret = tv.BaseInitSecret 199 | for i := 0; i < int(tv.NumEpochs); i++ { 200 | ctx, _ := syntax.Marshal(grpCtx) 201 | 202 | psk := []byte(fmt.Sprintf("psk @ %d", i)) 203 | commitSecret := []byte(fmt.Sprintf("commitSecret @ %d", i)) 204 | epoch = epoch.Next(LeafCount(nMembers), psk, commitSecret, ctx) 205 | 206 | var handshakeKeys []keyAndNonce 207 | var applicationKeys []keyAndNonce 208 | appSecret := dup(epoch.ApplicationSecret) 209 | for j := 0; j < nMembers; j++ { 210 | hs, _ := epoch.HandshakeKeys.Get(LeafIndex(j), tv.TargetGeneration) 211 | handshakeKeys = append(handshakeKeys, hs) 212 | as, _ := epoch.ApplicationKeys.Get(LeafIndex(j), tv.TargetGeneration) 213 | applicationKeys = append(applicationKeys, as) 214 | } 215 | 216 | exportedSecret := epoch.Export(string(tv.ExportLabel), tv.ExportContext, int(tv.ExportSize)) 217 | 218 | kse := KsEpoch{ 219 | PSK: psk, 220 | CommitSecret: commitSecret, 221 | 222 | NumMembers: LeafCount(nMembers), 223 | EpochSecret: epoch.EpochSecret, 224 | SenderDataSecret: epoch.SenderDataSecret, 225 | SenderDataKey: epoch.SenderDataKey, 226 | HandshakeSecret: epoch.HandshakeSecret, 227 | HandshakeKeys: handshakeKeys, 228 | AppSecret: appSecret, 229 | AppKeys: applicationKeys, 230 | ExporterSecret: epoch.ExporterSecret, 231 | ExportedSecret: exportedSecret, 232 | ConfirmationKey: epoch.ConfirmationKey, 233 | InitSecret: epoch.InitSecret, 234 | } 235 | 236 | tc.Epochs = append(tc.Epochs, kse) 237 | 238 | grpCtx.Epoch += 1 239 | nMembers = (nMembers-minMembers)%(maxMembers-minMembers) + minMembers 240 | } 241 | tv.Cases = append(tv.Cases, tc) 242 | } 243 | 244 | vec, err := syntax.Marshal(tv) 245 | require.Nil(t, err) 246 | return vec 247 
| } 248 | 249 | func verifyKeyScheduleVectors(t *testing.T, data []byte) { 250 | var tv KsTestVectors 251 | _, err := syntax.Unmarshal(data, &tv) 252 | require.Nil(t, err) 253 | for _, tc := range tv.Cases { 254 | suite := tc.CipherSuite 255 | var grpCtx GroupContext 256 | _, err := syntax.Unmarshal(tv.BaseGroupContext, &grpCtx) 257 | require.Nil(t, err) 258 | var myEpoch keyScheduleEpoch 259 | myEpoch.Suite = suite 260 | myEpoch.InitSecret = tv.BaseInitSecret 261 | for _, epoch := range tc.Epochs { 262 | ctx, _ := syntax.Marshal(grpCtx) 263 | myEpoch = myEpoch.Next(epoch.NumMembers, epoch.PSK, epoch.CommitSecret, ctx) 264 | 265 | // check the secrets 266 | require.Equal(t, myEpoch.EpochSecret, epoch.EpochSecret) 267 | require.Equal(t, myEpoch.SenderDataSecret, epoch.SenderDataSecret) 268 | require.Equal(t, myEpoch.SenderDataKey, epoch.SenderDataKey) 269 | require.Equal(t, myEpoch.HandshakeSecret, epoch.HandshakeSecret) 270 | require.Equal(t, myEpoch.ApplicationSecret, epoch.AppSecret) 271 | require.Equal(t, myEpoch.ExporterSecret, epoch.ExporterSecret) 272 | require.Equal(t, myEpoch.ConfirmationKey, epoch.ConfirmationKey) 273 | require.Equal(t, myEpoch.InitSecret, epoch.InitSecret) 274 | 275 | // check export 276 | exportedSecret := myEpoch.Export(string(tv.ExportLabel), tv.ExportContext, int(tv.ExportSize)) 277 | require.Equal(t, exportedSecret, epoch.ExportedSecret) 278 | 279 | // check the keys 280 | for i := 0; LeafCount(i) < epoch.NumMembers; i++ { 281 | hs, err := myEpoch.HandshakeKeys.Get(LeafIndex(i), tv.TargetGeneration) 282 | require.Nil(t, err) 283 | require.Equal(t, hs.Key, epoch.HandshakeKeys[i].Key) 284 | require.Equal(t, hs.Nonce, epoch.HandshakeKeys[i].Nonce) 285 | 286 | as, err := myEpoch.ApplicationKeys.Get(LeafIndex(i), tv.TargetGeneration) 287 | require.Nil(t, err) 288 | require.Equal(t, as.Key, epoch.AppKeys[i].Key) 289 | require.Equal(t, as.Nonce, epoch.AppKeys[i].Nonce) 290 | 291 | } 292 | grpCtx.Epoch += 1 293 | } 294 | } 295 | } 296 | 
-------------------------------------------------------------------------------- /messages.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "time" 7 | 8 | syntax "github.com/cisco/go-tls-syntax" 9 | ) 10 | 11 | /// 12 | /// KeyPackage 13 | /// 14 | type Signature struct { 15 | Data []byte `tls:"head=2"` 16 | } 17 | 18 | type ProtocolVersion uint8 19 | 20 | const ( 21 | ProtocolVersionMLS10 ProtocolVersion = 0x00 22 | ) 23 | 24 | var ( 25 | supportedVersions = []ProtocolVersion{ProtocolVersionMLS10} 26 | supportedCipherSuites = []CipherSuite{ 27 | X25519_AES128GCM_SHA256_Ed25519, 28 | P256_AES128GCM_SHA256_P256, 29 | X25519_CHACHA20POLY1305_SHA256_Ed25519, 30 | P521_AES256GCM_SHA512_P521, 31 | } 32 | defaultLifetime = 30 * 24 * time.Hour 33 | ) 34 | 35 | type KeyPackage struct { 36 | Version ProtocolVersion 37 | CipherSuite CipherSuite 38 | InitKey HPKEPublicKey 39 | Credential Credential 40 | Extensions ExtensionList 41 | Signature Signature 42 | } 43 | 44 | func (kp KeyPackage) Equals(other KeyPackage) bool { 45 | version := kp.Version == other.Version 46 | suite := kp.CipherSuite == other.CipherSuite 47 | initKey := reflect.DeepEqual(kp.InitKey, other.InitKey) 48 | credential := kp.Credential.Equals(other.Credential) 49 | extensions := reflect.DeepEqual(kp.Extensions, kp.Extensions) 50 | signature := reflect.DeepEqual(kp.Signature, other.Signature) 51 | return version && suite && initKey && credential && extensions && signature 52 | } 53 | 54 | func (kp KeyPackage) Clone() KeyPackage { 55 | return KeyPackage{ 56 | Version: kp.Version, 57 | CipherSuite: kp.CipherSuite, 58 | InitKey: kp.InitKey, 59 | Credential: kp.Credential, 60 | Extensions: kp.Extensions, 61 | Signature: kp.Signature, 62 | } 63 | } 64 | 65 | func (kp KeyPackage) toBeSigned() ([]byte, error) { 66 | enc, err := syntax.Marshal(struct { 67 | Version ProtocolVersion 68 | CipherSuite CipherSuite 69 | 
InitKey HPKEPublicKey 70 | Credential Credential 71 | Extensions ExtensionList 72 | }{ 73 | Version: kp.Version, 74 | CipherSuite: kp.CipherSuite, 75 | InitKey: kp.InitKey, 76 | Credential: kp.Credential, 77 | Extensions: kp.Extensions, 78 | }) 79 | 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | return enc, nil 85 | } 86 | 87 | func (kp *KeyPackage) SetExtensions(exts []ExtensionBody) error { 88 | for _, ext := range exts { 89 | err := kp.Extensions.Add(ext) 90 | if err != nil { 91 | return err 92 | } 93 | } 94 | 95 | return nil 96 | } 97 | 98 | func (kp *KeyPackage) Sign(priv SignaturePrivateKey) error { 99 | if !priv.PublicKey.Equals(*kp.Credential.PublicKey()) { 100 | return fmt.Errorf("Public key mismatch") 101 | } 102 | 103 | tbs, err := kp.toBeSigned() 104 | if err != nil { 105 | return err 106 | } 107 | 108 | sig, err := kp.Credential.Scheme().Sign(&priv, tbs) 109 | if err != nil { 110 | return err 111 | } 112 | 113 | kp.Signature = Signature{sig} 114 | return nil 115 | } 116 | 117 | func (kp KeyPackage) Verify() bool { 118 | // Check for required extensions, but do not verify contents 119 | var sve SupportedVersionsExtension 120 | var sce SupportedCipherSuitesExtension 121 | foundSV, _ := kp.Extensions.Find(&sve) 122 | foundSC, _ := kp.Extensions.Find(&sce) 123 | if !foundSV || !foundSC { 124 | return false 125 | } 126 | 127 | // Verify that the KeyPackage has not expired 128 | var lifetimeExt LifetimeExtension 129 | found, err := kp.Extensions.Find(&lifetimeExt) 130 | if !found || err != nil { 131 | return false 132 | } 133 | 134 | now := time.Now() 135 | notAfter := time.Unix(int64(lifetimeExt.NotAfter), 0) 136 | if now.After(notAfter) { 137 | return false 138 | } 139 | notBefore := time.Unix(int64(lifetimeExt.NotBefore), 0) 140 | if now.Before(notBefore) { 141 | return false 142 | } 143 | 144 | // Verify the signature 145 | scheme := kp.Credential.Scheme() 146 | if scheme != kp.CipherSuite.Scheme() { 147 | return false 148 | } 149 | 150 | 
tbs, err := kp.toBeSigned() 151 | if err != nil { 152 | return false 153 | } 154 | 155 | return kp.Credential.Scheme().Verify(kp.Credential.PublicKey(), tbs, kp.Signature.Data) 156 | } 157 | 158 | func NewKeyPackageWithSecret(suite CipherSuite, initSecret []byte, cred *Credential, sigPriv SignaturePrivateKey) (*KeyPackage, error) { 159 | initPriv, err := suite.hpke().Derive(initSecret) 160 | if err != nil { 161 | return nil, err 162 | } 163 | 164 | return NewKeyPackageWithInitKey(suite, initPriv.PublicKey, cred, sigPriv) 165 | } 166 | 167 | func NewKeyPackageWithInitKey(suite CipherSuite, initKey HPKEPublicKey, cred *Credential, sigPriv SignaturePrivateKey) (*KeyPackage, error) { 168 | kp := &KeyPackage{ 169 | Version: ProtocolVersionMLS10, 170 | CipherSuite: suite, 171 | InitKey: initKey, 172 | Credential: *cred, 173 | } 174 | 175 | // Add required extensions 176 | err := kp.Extensions.Add(SupportedVersionsExtension{supportedVersions}) 177 | if err != nil { 178 | return nil, err 179 | } 180 | 181 | err = kp.Extensions.Add(SupportedCipherSuitesExtension{supportedCipherSuites}) 182 | if err != nil { 183 | return nil, err 184 | } 185 | 186 | expiry := uint64(time.Now().Add(defaultLifetime).Unix()) 187 | err = kp.Extensions.Add(LifetimeExtension{NotBefore: 0, NotAfter: expiry}) 188 | if err != nil { 189 | return nil, err 190 | } 191 | 192 | // Sign 193 | err = kp.Sign(sigPriv) 194 | if err != nil { 195 | return nil, err 196 | } 197 | return kp, nil 198 | } 199 | 200 | /// 201 | /// Proposal 202 | /// 203 | type ProposalType uint8 204 | 205 | const ( 206 | ProposalTypeInvalid ProposalType = 0 207 | ProposalTypeAdd ProposalType = 1 208 | ProposalTypeUpdate ProposalType = 2 209 | ProposalTypeRemove ProposalType = 3 210 | ) 211 | 212 | func (pt ProposalType) ValidForTLS() error { 213 | return validateEnum(pt, ProposalTypeAdd, ProposalTypeUpdate, ProposalTypeRemove) 214 | } 215 | 216 | type AddProposal struct { 217 | KeyPackage KeyPackage 218 | } 219 | 220 | type 
UpdateProposal struct { 221 | KeyPackage KeyPackage 222 | } 223 | 224 | type RemoveProposal struct { 225 | Removed LeafIndex 226 | } 227 | 228 | type Proposal struct { 229 | Add *AddProposal 230 | Update *UpdateProposal 231 | Remove *RemoveProposal 232 | } 233 | 234 | func (p Proposal) Type() ProposalType { 235 | switch { 236 | case p.Add != nil: 237 | return ProposalTypeAdd 238 | case p.Update != nil: 239 | return ProposalTypeUpdate 240 | case p.Remove != nil: 241 | return ProposalTypeRemove 242 | default: 243 | panic("Malformed proposal") 244 | } 245 | } 246 | 247 | func (p Proposal) MarshalTLS() ([]byte, error) { 248 | s := syntax.NewWriteStream() 249 | proposalType := p.Type() 250 | err := s.Write(proposalType) 251 | if err != nil { 252 | return nil, fmt.Errorf("mls.proposal: Marshal failed for ProposalType: %v", err) 253 | } 254 | 255 | switch proposalType { 256 | case ProposalTypeAdd: 257 | err = s.Write(p.Add) 258 | case ProposalTypeUpdate: 259 | err = s.Write(p.Update) 260 | case ProposalTypeRemove: 261 | err = s.Write(p.Remove) 262 | default: 263 | return nil, fmt.Errorf("mls.proposal: ProposalType type not allowed: %v", err) 264 | } 265 | 266 | if err != nil { 267 | return nil, fmt.Errorf("mls.proposal: Marshal failed: %v", err) 268 | } 269 | 270 | return s.Data(), nil 271 | } 272 | 273 | func (p *Proposal) UnmarshalTLS(data []byte) (int, error) { 274 | s := syntax.NewReadStream(data) 275 | var proposalType ProposalType 276 | _, err := s.Read(&proposalType) 277 | if err != nil { 278 | return 0, fmt.Errorf("mls.proposal: Unmarshal failed for ProposalTpe") 279 | } 280 | 281 | switch proposalType { 282 | case ProposalTypeAdd: 283 | p.Add = new(AddProposal) 284 | _, err = s.Read(p.Add) 285 | case ProposalTypeUpdate: 286 | p.Update = new(UpdateProposal) 287 | _, err = s.Read(p.Update) 288 | case ProposalTypeRemove: 289 | p.Remove = new(RemoveProposal) 290 | _, err = s.Read(p.Remove) 291 | default: 292 | err = fmt.Errorf("mls.proposal: ProposalType type not 
allowed") 293 | } 294 | 295 | if err != nil { 296 | return 0, err 297 | } 298 | 299 | return s.Position(), nil 300 | } 301 | 302 | /// 303 | /// Commit 304 | /// 305 | type ProposalID struct { 306 | Hash []byte `tls:"head=1"` 307 | } 308 | 309 | func (pid ProposalID) String() string { 310 | return fmt.Sprintf("%x", pid.Hash) 311 | } 312 | 313 | type Commit struct { 314 | Updates []ProposalID `tls:"head=2"` 315 | Removes []ProposalID `tls:"head=2"` 316 | Adds []ProposalID `tls:"head=2"` 317 | 318 | Path *DirectPath `tls:"optional"` 319 | } 320 | 321 | func (commit Commit) PathRequired() bool { 322 | haveUpdates := len(commit.Updates) > 0 323 | haveRemoves := len(commit.Removes) > 0 324 | haveAdds := len(commit.Adds) > 0 325 | 326 | nonAddProposals := haveUpdates || haveRemoves 327 | noProposalsAtAll := !haveUpdates && !haveRemoves && !haveAdds 328 | 329 | return nonAddProposals || noProposalsAtAll 330 | } 331 | 332 | func (commit Commit) ValidForTLS() bool { 333 | return commit.Path != nil || !commit.PathRequired() 334 | } 335 | 336 | /// 337 | /// MLSPlaintext and MLSCiphertext 338 | /// 339 | type Epoch uint64 340 | 341 | type ContentType uint8 342 | 343 | const ( 344 | ContentTypeInvalid ContentType = 0 345 | ContentTypeApplication ContentType = 1 346 | ContentTypeProposal ContentType = 2 347 | ContentTypeCommit ContentType = 3 348 | ) 349 | 350 | func (ct ContentType) ValidForTLS() error { 351 | return validateEnum(ct, ContentTypeApplication, ContentTypeProposal, ContentTypeCommit) 352 | } 353 | 354 | type SenderType uint8 355 | 356 | const ( 357 | SenderTypeInvalid SenderType = 0 358 | SenderTypeMember SenderType = 1 359 | SenderTypePreconfigured SenderType = 2 360 | SenderTypeNewMember SenderType = 3 361 | ) 362 | 363 | func (st SenderType) ValidForTLS() error { 364 | return validateEnum(st, SenderTypeMember, SenderTypePreconfigured, SenderTypeNewMember) 365 | } 366 | 367 | type Sender struct { 368 | Type SenderType 369 | Sender uint32 370 | } 371 | 372 | type 
ApplicationData struct { 373 | Data []byte `tls:"head=4"` 374 | } 375 | 376 | type Confirmation struct { 377 | Data []byte `tls:"head=1"` 378 | } 379 | type CommitData struct { 380 | Commit Commit 381 | Confirmation Confirmation 382 | } 383 | 384 | type MLSPlaintextContent struct { 385 | Application *ApplicationData 386 | Proposal *Proposal 387 | Commit *CommitData 388 | } 389 | 390 | func (c MLSPlaintextContent) Type() ContentType { 391 | switch { 392 | case c.Application != nil: 393 | return ContentTypeApplication 394 | case c.Proposal != nil: 395 | return ContentTypeProposal 396 | case c.Commit != nil: 397 | return ContentTypeCommit 398 | default: 399 | panic("Malformed plaintext content") 400 | } 401 | } 402 | 403 | func (c MLSPlaintextContent) MarshalTLS() ([]byte, error) { 404 | s := syntax.NewWriteStream() 405 | contentType := c.Type() 406 | err := s.Write(contentType) 407 | if err != nil { 408 | return nil, err 409 | } 410 | 411 | switch contentType { 412 | case ContentTypeApplication: 413 | err = s.Write(c.Application) 414 | case ContentTypeProposal: 415 | err = s.Write(c.Proposal) 416 | case ContentTypeCommit: 417 | err = s.Write(c.Commit) 418 | default: 419 | return nil, fmt.Errorf("mls.mlsplaintext: ContentType type not allowed") 420 | } 421 | 422 | if err != nil { 423 | return nil, err 424 | } 425 | 426 | return s.Data(), nil 427 | } 428 | 429 | func (c *MLSPlaintextContent) UnmarshalTLS(data []byte) (int, error) { 430 | s := syntax.NewReadStream(data) 431 | var contentType ContentType 432 | _, err := s.Read(&contentType) 433 | if err != nil { 434 | return 0, err 435 | } 436 | 437 | switch contentType { 438 | case ContentTypeApplication: 439 | c.Application = new(ApplicationData) 440 | _, err = s.Read(c.Application) 441 | case ContentTypeProposal: 442 | c.Proposal = new(Proposal) 443 | _, err = s.Read(c.Proposal) 444 | case ContentTypeCommit: 445 | c.Commit = new(CommitData) 446 | _, err = s.Read(c.Commit) 447 | default: 448 | return 0, 
fmt.Errorf("mls.mlsplaintext: ContentType type not allowed") 449 | } 450 | 451 | if err != nil { 452 | return 0, err 453 | } 454 | 455 | return s.Position(), nil 456 | } 457 | 458 | type MLSPlaintext struct { 459 | GroupID []byte `tls:"head=1"` 460 | Epoch Epoch 461 | Sender Sender 462 | AuthenticatedData []byte `tls:"head=4"` 463 | Content MLSPlaintextContent 464 | Signature Signature 465 | } 466 | 467 | func (pt MLSPlaintext) toBeSigned(ctx GroupContext) []byte { 468 | s := syntax.NewWriteStream() 469 | err := s.Write(ctx) 470 | if err != nil { 471 | panic(fmt.Errorf("mls.mlsplaintext: grpCtx marshal failure %v", err)) 472 | } 473 | 474 | err = s.Write(struct { 475 | GroupID []byte `tls:"head=1"` 476 | Epoch Epoch 477 | Sender Sender 478 | AuthenticatedData []byte `tls:"head=4"` 479 | Content MLSPlaintextContent 480 | }{ 481 | GroupID: pt.GroupID, 482 | Epoch: pt.Epoch, 483 | Sender: pt.Sender, 484 | AuthenticatedData: pt.AuthenticatedData, 485 | Content: pt.Content, 486 | }) 487 | 488 | if err != nil { 489 | panic(fmt.Errorf("mls.mlsplaintext: marshal failure %v", err)) 490 | } 491 | return s.Data() 492 | } 493 | 494 | func (pt *MLSPlaintext) sign(ctx GroupContext, priv SignaturePrivateKey, scheme SignatureScheme) error { 495 | tbs := pt.toBeSigned(ctx) 496 | sig, err := scheme.Sign(&priv, tbs) 497 | if err != nil { 498 | return err 499 | } 500 | 501 | pt.Signature = Signature{sig} 502 | return nil 503 | } 504 | 505 | func (pt *MLSPlaintext) verify(ctx GroupContext, pub *SignaturePublicKey, scheme SignatureScheme) bool { 506 | tbs := pt.toBeSigned(ctx) 507 | return scheme.Verify(pub, tbs, pt.Signature.Data) 508 | } 509 | 510 | func (pt MLSPlaintext) commitContent() []byte { 511 | enc, err := syntax.Marshal(struct { 512 | GroupId []byte `tls:"head=1"` 513 | Epoch Epoch 514 | Sender Sender 515 | Commit Commit 516 | ContentType ContentType 517 | }{ 518 | GroupId: pt.GroupID, 519 | Epoch: pt.Epoch, 520 | Sender: pt.Sender, 521 | Commit: pt.Content.Commit.Commit, 522 
| ContentType: pt.Content.Type(), 523 | }) 524 | 525 | if err != nil { 526 | return nil 527 | } 528 | 529 | return enc 530 | } 531 | func (pt MLSPlaintext) commitAuthData() ([]byte, error) { 532 | data := pt.Content.Commit 533 | s := syntax.NewWriteStream() 534 | err := s.WriteAll(data.Confirmation, pt.Signature) 535 | if err != nil { 536 | return nil, err 537 | } 538 | return s.Data(), nil 539 | } 540 | 541 | type MLSCiphertext struct { 542 | GroupID []byte `tls:"head=1"` 543 | Epoch Epoch 544 | ContentType ContentType 545 | SenderDataNonce []byte `tls:"head=1"` 546 | EncryptedSenderData []byte `tls:"head=1"` 547 | AuthenticatedData []byte `tls:"head=4"` 548 | Ciphertext []byte `tls:"head=4"` 549 | } 550 | 551 | /// 552 | /// GroupInfo 553 | /// 554 | 555 | type GroupInfo struct { 556 | GroupID []byte `tls:"head=1"` 557 | Epoch Epoch 558 | Tree TreeKEMPublicKey 559 | ConfirmedTranscriptHash []byte `tls:"head=1"` 560 | InterimTranscriptHash []byte `tls:"head=1"` 561 | Extensions ExtensionList 562 | Confirmation []byte `tls:"head=1"` 563 | SignerIndex LeafIndex 564 | Signature []byte `tls:"head=2"` 565 | } 566 | 567 | func (gi GroupInfo) dump() { 568 | fmt.Printf("\n+++++ groupInfo +++++\n") 569 | fmt.Printf("\tGroupID %x, Epoch %x\n", gi.GroupID, gi.Epoch) 570 | gi.Tree.dump("Tree") 571 | fmt.Printf("ConfirmedTranscriptHash %x, InterimTranscriptHash %x\n", 572 | gi.ConfirmedTranscriptHash, gi.InterimTranscriptHash) 573 | fmt.Printf("\tConfirmation %x, SignerIndex %x\n", gi.Confirmation, gi.SignerIndex) 574 | fmt.Printf("\tSignature %x\n", gi.Signature) 575 | fmt.Printf("\n+++++ groupInfo +++++\n") 576 | } 577 | 578 | func (gi GroupInfo) toBeSigned() ([]byte, error) { 579 | return syntax.Marshal(struct { 580 | GroupID []byte `tls:"head=1"` 581 | Epoch Epoch 582 | Tree TreeKEMPublicKey 583 | ConfirmedTranscriptHash []byte `tls:"head=1"` 584 | InterimTranscriptHash []byte `tls:"head=1"` 585 | Confirmation []byte `tls:"head=1"` 586 | SignerIndex LeafIndex 587 | }{ 588 
| GroupID: gi.GroupID, 589 | Epoch: gi.Epoch, 590 | Tree: gi.Tree, 591 | ConfirmedTranscriptHash: gi.ConfirmedTranscriptHash, 592 | InterimTranscriptHash: gi.InterimTranscriptHash, 593 | Confirmation: gi.Confirmation, 594 | SignerIndex: gi.SignerIndex, 595 | }) 596 | } 597 | 598 | func (gi *GroupInfo) sign(index LeafIndex, priv *SignaturePrivateKey) error { 599 | // Verify that priv corresponds to tree[index] 600 | kp, ok := gi.Tree.KeyPackage(index) 601 | if !ok { 602 | return fmt.Errorf("mls.groupInfo: Attempt to sign from unoccupied leaf") 603 | } 604 | 605 | scheme := kp.CipherSuite.Scheme() 606 | pub := kp.Credential.PublicKey() 607 | if !pub.Equals(priv.PublicKey) { 608 | return fmt.Errorf("mls.groupInfo: Incorrect private key for index") 609 | } 610 | 611 | // Marshal the contents 612 | gi.SignerIndex = index 613 | tbs, err := gi.toBeSigned() 614 | if err != nil { 615 | return err 616 | } 617 | 618 | // Sign toBeSigned() with priv -> SignerIndex, Signature 619 | sig, err := scheme.Sign(priv, tbs) 620 | if err != nil { 621 | return err 622 | } 623 | 624 | gi.Signature = sig 625 | return nil 626 | } 627 | 628 | func (gi GroupInfo) verify() error { 629 | // Get pub from tree[SignerIndex] 630 | kp, ok := gi.Tree.KeyPackage(gi.SignerIndex) 631 | if !ok { 632 | return fmt.Errorf("mls.groupInfo: Attempt to sign from unoccupied leaf") 633 | } 634 | 635 | scheme := kp.CipherSuite.Scheme() 636 | pub := kp.Credential.PublicKey() 637 | 638 | // Marshal the contents of the GroupInfo 639 | tbs, err := gi.toBeSigned() 640 | if err != nil { 641 | return err 642 | } 643 | 644 | // Verify (toBeSigned(), Signature) with pub 645 | ver := scheme.Verify(pub, tbs, gi.Signature) 646 | if !ver { 647 | return fmt.Errorf("mls.groupInfo: Vefication failed") 648 | } 649 | 650 | return nil 651 | } 652 | 653 | /// 654 | /// GroupSecrets 655 | /// 656 | type PathSecret struct { 657 | Data []byte `tls:"head=1"` 658 | } 659 | 660 | type GroupSecrets struct { 661 | EpochSecret []byte 
`tls:"head=1"` 662 | PathSecret *PathSecret `tls:"optional"` 663 | } 664 | 665 | /// 666 | /// EncryptedGroupSecrets 667 | /// 668 | type EncryptedGroupSecrets struct { 669 | KeyPackageHash []byte `tls:"head=1"` 670 | EncryptedGroupSecrets HPKECiphertext 671 | } 672 | 673 | /// 674 | /// Welcome 675 | /// 676 | 677 | type Welcome struct { 678 | Version ProtocolVersion 679 | CipherSuite CipherSuite 680 | Secrets []EncryptedGroupSecrets `tls:"head=4"` 681 | EncryptedGroupInfo []byte `tls:"head=4"` 682 | epochSecret []byte `tls:"omit"` 683 | } 684 | 685 | // XXX(rlb): The pattern we follow here basically locks us into having empty 686 | // AAD. I suspect that eventually we're going to want to have the header to the 687 | // message (version, cipher, encrypted key packages) as AAD. We should consider 688 | // refactoring so that the API flows slightly differently: 689 | // 690 | // * newWelcome() - caches initSecret and *unencrypted* GroupInfo 691 | // * encrypt() for each member 692 | // * finalize() - computes AAD and encrypts GroupInfo 693 | // 694 | // This will also probably require a helper method for decryption. 
695 | func newWelcome(cs CipherSuite, epochSecret []byte, groupInfo *GroupInfo) *Welcome { 696 | // Encrypt the GroupInfo 697 | pt, err := syntax.Marshal(groupInfo) 698 | if err != nil { 699 | panic(fmt.Errorf("mls.welcome: GroupInfo marshal failure %v", err)) 700 | } 701 | 702 | kn := groupInfoKeyAndNonce(cs, epochSecret) 703 | aead, err := cs.NewAEAD(kn.Key) 704 | if err != nil { 705 | panic(fmt.Errorf("mls.welcome: error creating AEAD: %v", err)) 706 | } 707 | ct := aead.Seal(nil, kn.Nonce, pt, []byte{}) 708 | 709 | // Assemble the Welcome 710 | return &Welcome{ 711 | Version: ProtocolVersionMLS10, 712 | CipherSuite: cs, 713 | EncryptedGroupInfo: ct, 714 | epochSecret: epochSecret, 715 | } 716 | } 717 | 718 | // TODO(RLB): Return error instead of panicking 719 | func (w *Welcome) EncryptTo(kp KeyPackage, pathSecret []byte) { 720 | // Check that the ciphersuite is acceptable 721 | if kp.CipherSuite != w.CipherSuite { 722 | panic(fmt.Errorf("mls.welcome: cipher suite mismatch %v != %v", kp.CipherSuite, w.CipherSuite)) 723 | } 724 | 725 | // Compute the hash of the kp 726 | data, err := syntax.Marshal(kp) 727 | if err != nil { 728 | panic(fmt.Errorf("mls.welcome: kp marshal failure %v", err)) 729 | } 730 | 731 | kpHash := w.CipherSuite.Digest(data) 732 | 733 | // Encrypt the group init secret to new member's public key 734 | gs := GroupSecrets{ 735 | EpochSecret: w.epochSecret, 736 | } 737 | 738 | if pathSecret != nil { 739 | gs.PathSecret = &PathSecret{pathSecret} 740 | } 741 | 742 | pt, err := syntax.Marshal(gs) 743 | if err != nil { 744 | panic(fmt.Errorf("mls.welcome: KeyPackage marshal failure %v", err)) 745 | } 746 | 747 | egs, err := w.CipherSuite.hpke().Encrypt(kp.InitKey, []byte{}, pt) 748 | if err != nil { 749 | panic(fmt.Errorf("mls.welcome: encrpyting KeyPackage failure %v", err)) 750 | } 751 | 752 | // Assemble and append the key package 753 | ekp := EncryptedGroupSecrets{ 754 | KeyPackageHash: kpHash, 755 | EncryptedGroupSecrets: egs, 756 | } 757 | 
w.Secrets = append(w.Secrets, ekp) 758 | } 759 | 760 | func (w Welcome) Decrypt(suite CipherSuite, epochSecret []byte) (*GroupInfo, error) { 761 | gikn := groupInfoKeyAndNonce(suite, epochSecret) 762 | 763 | aead, err := suite.NewAEAD(gikn.Key) 764 | if err != nil { 765 | return nil, fmt.Errorf("mls.state: error creating AEAD: %v", err) 766 | } 767 | 768 | data, err := aead.Open(nil, gikn.Nonce, w.EncryptedGroupInfo, []byte{}) 769 | if err != nil { 770 | return nil, fmt.Errorf("mls.state: unable to decrypt groupInfo: %v", err) 771 | } 772 | 773 | gi := new(GroupInfo) 774 | _, err = syntax.Unmarshal(data, gi) 775 | if err != nil { 776 | return nil, fmt.Errorf("mls.state: unable to unmarshal groupInfo: %v", err) 777 | } 778 | 779 | gi.Tree.Suite = suite 780 | gi.Tree.SetHashAll() 781 | 782 | if err = gi.verify(); err != nil { 783 | return nil, fmt.Errorf("mls.state: invalid groupInfo") 784 | } 785 | 786 | gi.Tree.Suite = suite 787 | 788 | return gi, nil 789 | } 790 | -------------------------------------------------------------------------------- /messages_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | syntax "github.com/cisco/go-tls-syntax" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | var ( 13 | sigPublicKey = SignaturePublicKey{[]byte{0xA0, 0xA0, 0xA0, 0xA0}} 14 | basicCredential = &BasicCredential{ 15 | Identity: []byte{0x01, 0x02, 0x03, 0x04}, 16 | SignatureScheme: 0x0403, 17 | PublicKey: sigPublicKey, 18 | } 19 | 20 | credentialBasic = Credential{ 21 | Basic: basicCredential, 22 | } 23 | 24 | extIn = Extension{ 25 | ExtensionType: ExtensionType(0x0001), 26 | ExtensionData: []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4}, 27 | } 28 | 29 | extEmpty = Extension{ 30 | ExtensionType: ExtensionType(0x0002), 31 | ExtensionData: []byte{}, 32 | } 33 | 34 | extListIn = ExtensionList{[]Extension{extIn, extEmpty}} 35 | 36 | extValidIn = Extension{ 37 | 
ExtensionType: ExtensionType(0x000a), 38 | ExtensionData: []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4}, 39 | } 40 | extEmptyIn = Extension{ 41 | ExtensionType: ExtensionType(0x000a), 42 | ExtensionData: []byte{}, 43 | } 44 | 45 | extListValidIn = ExtensionList{[]Extension{extValidIn, extEmptyIn}} 46 | 47 | ikPriv, _ = suite.hpke().Generate() 48 | 49 | keyPackage = &KeyPackage{ 50 | Version: ProtocolVersionMLS10, 51 | CipherSuite: suite, 52 | InitKey: ikPriv.PublicKey, 53 | Credential: credentialBasic, 54 | Extensions: extListValidIn, 55 | Signature: Signature{[]byte{0x00, 0x00, 0x00}}, 56 | } 57 | 58 | addProposal = &Proposal{ 59 | Add: &AddProposal{ 60 | KeyPackage: *keyPackage, 61 | }, 62 | } 63 | 64 | removeProposal = &Proposal{ 65 | Remove: &RemoveProposal{ 66 | Removed: 12, 67 | }, 68 | } 69 | 70 | updateProposal = &Proposal{ 71 | Update: &UpdateProposal{ 72 | KeyPackage: *keyPackage, 73 | }, 74 | } 75 | 76 | nodePublicKey = HPKEPublicKey{ 77 | Data: []byte{0x11, 0x12, 0x13, 0x14, 0x15, 0x16}, 78 | } 79 | 80 | nodes = []DirectPathNode{ 81 | { 82 | PublicKey: nodePublicKey, 83 | EncryptedPathSecrets: []HPKECiphertext{}, 84 | }, 85 | } 86 | 87 | dp = &DirectPath{ 88 | LeafKeyPackage: *keyPackage, 89 | Steps: nodes, 90 | } 91 | 92 | commit = &Commit{ 93 | Updates: []ProposalID{{Hash: []byte{0x00, 0x01}}}, 94 | Removes: []ProposalID{{Hash: []byte{0x02, 0x03}}}, 95 | Adds: []ProposalID{{Hash: []byte{0x04, 0x05}}}, 96 | Path: dp, 97 | } 98 | 99 | mlsPlaintextIn = &MLSPlaintext{ 100 | GroupID: []byte{0x01, 0x02, 0x03, 0x04}, 101 | Epoch: 1, 102 | Sender: Sender{SenderTypeMember, 4}, 103 | AuthenticatedData: []byte{0xAA, 0xBB, 0xcc, 0xdd}, 104 | Content: MLSPlaintextContent{ 105 | Application: &ApplicationData{ 106 | Data: []byte{0x0A, 0x0B, 0x0C, 0x0D}, 107 | }, 108 | }, 109 | Signature: Signature{[]byte{0x00, 0x01, 0x02, 0x03}}, 110 | } 111 | 112 | mlsPlaintextCommitIn = &MLSPlaintext{ 113 | GroupID: []byte{0x01, 0x02, 0x03, 0x04}, 114 | Epoch: 1, 115 | Sender: 
Sender{SenderTypeMember, 4}, 116 | AuthenticatedData: []byte{0xAA, 0xBB, 0xcc, 0xdd}, 117 | Content: MLSPlaintextContent{ 118 | Commit: &CommitData{ 119 | Commit: *commit, 120 | Confirmation: Confirmation{ 121 | Data: []byte{0x0C, 0x00, 0x03, 0x03, 0x01, 0x0f}, 122 | }, 123 | }, 124 | }, 125 | Signature: Signature{[]byte{0x00, 0x01, 0x02, 0x03}}, 126 | } 127 | 128 | mlsPlaintextProposalIn = &MLSPlaintext{ 129 | GroupID: []byte{0x01, 0x02, 0x03, 0x04}, 130 | Epoch: 1, 131 | Sender: Sender{SenderTypeMember, 4}, 132 | AuthenticatedData: []byte{0xAA, 0xBB, 0xcc, 0xdd}, 133 | Content: MLSPlaintextContent{ 134 | Proposal: removeProposal, 135 | }, 136 | Signature: Signature{[]byte{0x00, 0x01, 0x02, 0x03}}, 137 | } 138 | 139 | mlsCiphertextIn = &MLSCiphertext{ 140 | GroupID: []byte{0x01, 0x02, 0x03, 0x04}, 141 | Epoch: 1, 142 | ContentType: 1, 143 | AuthenticatedData: []byte{0xAA, 0xBB, 0xCC}, 144 | SenderDataNonce: []byte{0x01, 0x02}, 145 | EncryptedSenderData: []byte{0x11, 0x12, 0x13, 0x14, 0x15, 0x16}, 146 | Ciphertext: []byte{0x11, 0x12, 0x13, 0x14, 0x15, 0x16}, 147 | } 148 | ) 149 | 150 | func roundTrip(original interface{}, decoded interface{}) func(t *testing.T) { 151 | return func(t *testing.T) { 152 | encoded, err := syntax.Marshal(original) 153 | require.Nil(t, err) 154 | 155 | _, err = syntax.Unmarshal(encoded, decoded) 156 | require.Nil(t, err) 157 | require.Equal(t, decoded, original) 158 | } 159 | } 160 | 161 | func TestMessagesMarshalUnmarshal(t *testing.T) { 162 | t.Run("BasicCredential", roundTrip(&credentialBasic, new(Credential))) 163 | t.Run("KeyPackage", roundTrip(keyPackage, new(KeyPackage))) 164 | t.Run("AddProposal", roundTrip(addProposal, new(Proposal))) 165 | t.Run("RemoveProposal", roundTrip(removeProposal, new(Proposal))) 166 | t.Run("UpdateProposal", roundTrip(updateProposal, new(Proposal))) 167 | t.Run("Commit", roundTrip(commit, new(Commit))) 168 | t.Run("MLSPlaintextContentApplication", roundTrip(mlsPlaintextIn, new(MLSPlaintext))) 169 | 
t.Run("MLSPlaintextContentProposal", roundTrip(mlsPlaintextProposalIn, new(MLSPlaintext))) 170 | t.Run("MLSPlaintextContentCommit", roundTrip(mlsPlaintextCommitIn, new(MLSPlaintext))) 171 | 172 | t.Run("MLSCiphertext", roundTrip(mlsCiphertextIn, new(MLSCiphertext))) 173 | } 174 | 175 | func TestKeyPackageExpiry(t *testing.T) { 176 | // Prepare a new key package, which should be valid 177 | scheme := suite.Scheme() 178 | priv, err := scheme.Generate() 179 | require.Nil(t, err) 180 | 181 | cred := NewBasicCredential(userID, scheme, priv.PublicKey) 182 | kp, err := NewKeyPackageWithSecret(suite, randomBytes(32), cred, priv) 183 | require.Nil(t, err) 184 | 185 | ver := kp.Verify() 186 | require.True(t, ver) 187 | 188 | // Change the expiration time to a time in the past and check that verify() 189 | // now fails 190 | alreadyExpired := LifetimeExtension{ 191 | NotBefore: 0, 192 | NotAfter: uint64(time.Now().Add(-24 * time.Hour).Unix()), 193 | } 194 | err = kp.SetExtensions([]ExtensionBody{alreadyExpired}) 195 | require.Nil(t, err) 196 | err = kp.Sign(priv) 197 | require.Nil(t, err) 198 | 199 | ver = kp.Verify() 200 | require.False(t, ver) 201 | } 202 | 203 | func newTestRatchetTree(t *testing.T, suite CipherSuite, secrets [][]byte) *TreeKEMPublicKey { 204 | scheme := suite.Scheme() 205 | 206 | tree := NewTreeKEMPublicKey(suite) 207 | for _, secret := range secrets { 208 | initPriv, err := suite.hpke().Derive(secret) 209 | require.Nil(t, err) 210 | 211 | sigPriv, err := scheme.Derive(secret) 212 | require.Nil(t, err) 213 | 214 | cred := NewBasicCredential(userID, scheme, sigPriv.PublicKey) 215 | 216 | keyPackage, err = NewKeyPackageWithInitKey(suite, initPriv.PublicKey, cred, sigPriv) 217 | require.Nil(t, err) 218 | 219 | tree.AddLeaf(*keyPackage) 220 | } 221 | 222 | // TODO(RLB): Encap to fill in the tree 223 | 224 | return tree 225 | } 226 | 227 | func TestWelcomeMarshalUnMarshalWithDecryption(t *testing.T) { 228 | // a tree with 2 members 229 | secrets := 
[][]byte{randomBytes(32), randomBytes(32)} 230 | tree := newTestRatchetTree(t, suite, secrets) 231 | 232 | keyPackage, ok := tree.KeyPackage(0) 233 | require.True(t, ok) 234 | 235 | initKey, err := suite.hpke().Derive(secrets[0]) 236 | require.Nil(t, err) 237 | 238 | // setup things needed to welcome c 239 | epochSecret := []byte("we welcome you c") 240 | gi := &GroupInfo{ 241 | GroupID: unhex("0007"), 242 | Epoch: 121, 243 | Tree: *tree, 244 | ConfirmedTranscriptHash: []byte{0x03, 0x04, 0x05, 0x06}, 245 | InterimTranscriptHash: []byte{0x02, 0x03, 0x04, 0x05}, 246 | SignerIndex: 0, 247 | Confirmation: []byte{0x00, 0x00, 0x00, 0x00}, 248 | Signature: []byte{0xAA, 0xBB, 0xCC}, 249 | } 250 | 251 | w1 := newWelcome(suite, epochSecret, gi) 252 | w1.EncryptTo(keyPackage, randomBytes(32)) 253 | // doing this so that test can omit this field when matching w1, w2 254 | w1.epochSecret = nil 255 | w2 := new(Welcome) 256 | t.Run("WelcomeOneMember", roundTrip(w1, w2)) 257 | 258 | // decrypt the group init secret with C's privateKey and check if 259 | // it matches. 
260 | egs := w2.Secrets[0] 261 | pt, err := suite.hpke().Decrypt(initKey, []byte{}, egs.EncryptedGroupSecrets) 262 | require.Nil(t, err) 263 | 264 | w2kp := new(GroupSecrets) 265 | _, err = syntax.Unmarshal(pt, w2kp) 266 | require.Nil(t, err) 267 | require.Equal(t, epochSecret, w2kp.EpochSecret) 268 | } 269 | 270 | func TestProposalErrorCases(t *testing.T) { 271 | p := Proposal{Add: nil, Update: nil, Remove: nil} 272 | require.Panics(t, func() { p.Type() }) 273 | require.Panics(t, func() { syntax.Marshal(p) }) 274 | } 275 | 276 | func TestMLSPlainTestErrorCases(t *testing.T) { 277 | c := MLSPlaintextContent{Application: nil, Proposal: nil, Commit: nil} 278 | require.Panics(t, func() { c.Type() }) 279 | } 280 | 281 | /// 282 | /// Test Vectors 283 | /// 284 | 285 | type MessageTestCase struct { 286 | CipherSuite CipherSuite 287 | SignatureScheme SignatureScheme 288 | 289 | KeyPackage []byte `tls:"head=4"` 290 | GroupInfo []byte `tls:"head=4"` 291 | GroupSecrets []byte `tls:"head=4"` 292 | EncryptedGroupSecrets []byte `tls:"head=4"` 293 | Welcome []byte `tls:"head=4"` 294 | AddProposal []byte `tls:"head=4"` 295 | UpdateProposal []byte `tls:"head=4"` 296 | RemoveProposal []byte `tls:"head=4"` 297 | Commit []byte `tls:"head=4"` 298 | MLSCiphertext []byte `tls:"head=4"` 299 | } 300 | 301 | type MessageTestVectors struct { 302 | Epoch Epoch 303 | SenderType SenderType 304 | SignerIndex LeafIndex 305 | Removed LeafIndex 306 | UserId []byte `tls:"head=1"` 307 | GroupID []byte `tls:"head=1"` 308 | KeyPackageId []byte `tls:"head=1"` 309 | DHSeed []byte `tls:"head=1"` 310 | SigSeed []byte `tls:"head=1"` 311 | Random []byte `tls:"head=1"` 312 | Cases []MessageTestCase `tls:"head=4"` 313 | } 314 | 315 | /// Gen and Verify 316 | func generateMessageVectors(t *testing.T) []byte { 317 | tv := MessageTestVectors{ 318 | Epoch: 0xA0A1A2A3, 319 | SenderType: SenderTypeMember, 320 | SignerIndex: LeafIndex(0xB0B1B2B3), 321 | Removed: LeafIndex(0xC0C1C2C3), 322 | UserId: 
bytes.Repeat([]byte{0xD1}, 16), 323 | GroupID: bytes.Repeat([]byte{0xD2}, 16), 324 | KeyPackageId: bytes.Repeat([]byte{0xD3}, 16), 325 | DHSeed: bytes.Repeat([]byte{0xD4}, 32), 326 | SigSeed: bytes.Repeat([]byte{0xD5}, 32), 327 | Random: bytes.Repeat([]byte{0xD6}, 32), 328 | Cases: []MessageTestCase{}, 329 | } 330 | 331 | suites := []CipherSuite{P256_AES128GCM_SHA256_P256, X25519_AES128GCM_SHA256_Ed25519} 332 | schemes := []SignatureScheme{ECDSA_SECP256R1_SHA256, Ed25519} 333 | 334 | for i := range suites { 335 | suite := suites[i] 336 | scheme := schemes[i] 337 | // hpke 338 | priv, err := suite.hpke().Derive(tv.DHSeed) 339 | require.Nil(t, err) 340 | pub := priv.PublicKey 341 | 342 | // identity 343 | sigPriv, err := scheme.Derive(tv.SigSeed) 344 | require.Nil(t, err) 345 | sigPub := sigPriv.PublicKey 346 | 347 | bc := &BasicCredential{ 348 | Identity: tv.UserId, 349 | SignatureScheme: scheme, 350 | PublicKey: sigPub, 351 | } 352 | cred := Credential{Basic: bc} 353 | 354 | secrets := [][]byte{tv.Random, tv.Random, tv.Random, tv.Random} 355 | ratchetTree := newTestRatchetTree(t, suite, secrets) 356 | 357 | ratchetTree.BlankPath(LeafIndex(2)) 358 | 359 | treeSigPriv, err := scheme.Derive(secrets[0]) 360 | require.Nil(t, err) 361 | 362 | _, _, err = ratchetTree.Encap(LeafIndex(0), []byte{}, tv.Random, treeSigPriv, nil) 363 | require.Nil(t, err) 364 | 365 | // KeyPackage 366 | kp := KeyPackage{ 367 | Version: ProtocolVersionMLS10, 368 | CipherSuite: suite, 369 | InitKey: pub, 370 | Credential: cred, 371 | Signature: Signature{tv.Random}, 372 | } 373 | 374 | dp.LeafKeyPackage = kp 375 | 376 | kpM, err := syntax.Marshal(kp) 377 | require.Nil(t, err) 378 | 379 | // Welcome 380 | 381 | gi := &GroupInfo{ 382 | GroupID: tv.GroupID, 383 | Epoch: tv.Epoch, 384 | Tree: *ratchetTree, 385 | ConfirmedTranscriptHash: tv.Random, 386 | InterimTranscriptHash: tv.Random, 387 | Confirmation: tv.Random, 388 | SignerIndex: tv.SignerIndex, 389 | Signature: tv.Random, 390 | } 391 | 392 | 
giM, err := syntax.Marshal(gi) 393 | require.Nil(t, err) 394 | 395 | gs := GroupSecrets{ 396 | EpochSecret: tv.Random, 397 | } 398 | 399 | gsM, err := syntax.Marshal(gs) 400 | require.Nil(t, err) 401 | 402 | encPayload, err := suite.hpke().Encrypt(pub, []byte{}, tv.Random) 403 | require.Nil(t, err) 404 | egs := EncryptedGroupSecrets{ 405 | KeyPackageHash: tv.Random, 406 | EncryptedGroupSecrets: encPayload, 407 | } 408 | 409 | egsM, err := syntax.Marshal(egs) 410 | require.Nil(t, err) 411 | 412 | var welcome Welcome 413 | welcome.Version = ProtocolVersionMLS10 414 | welcome.CipherSuite = suite 415 | welcome.Secrets = []EncryptedGroupSecrets{egs, egs} 416 | welcome.EncryptedGroupInfo = tv.Random 417 | 418 | welM, err := syntax.Marshal(welcome) 419 | require.Nil(t, err) 420 | 421 | // proposals 422 | addProposal := &Proposal{ 423 | Add: &AddProposal{ 424 | KeyPackage: kp, 425 | }, 426 | } 427 | 428 | addHs := MLSPlaintext{ 429 | GroupID: tv.GroupID, 430 | Epoch: tv.Epoch, 431 | Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)}, 432 | Content: MLSPlaintextContent{ 433 | Proposal: addProposal, 434 | }, 435 | } 436 | addHs.Signature = Signature{tv.Random} 437 | 438 | addM, err := syntax.Marshal(addHs) 439 | require.Nil(t, err) 440 | 441 | updateProposal := &Proposal{ 442 | Update: &UpdateProposal{ 443 | KeyPackage: kp, 444 | }, 445 | } 446 | 447 | updateHs := MLSPlaintext{ 448 | GroupID: tv.GroupID, 449 | Epoch: tv.Epoch, 450 | Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)}, 451 | Content: MLSPlaintextContent{ 452 | Proposal: updateProposal, 453 | }, 454 | } 455 | updateHs.Signature = Signature{tv.Random} 456 | 457 | updateM, err := syntax.Marshal(updateHs) 458 | require.Nil(t, err) 459 | 460 | removeProposal := &Proposal{ 461 | Remove: &RemoveProposal{ 462 | Removed: tv.SignerIndex, 463 | }, 464 | } 465 | 466 | removeHs := MLSPlaintext{ 467 | GroupID: tv.GroupID, 468 | Epoch: tv.Epoch, 469 | Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)}, 470 | 
Content: MLSPlaintextContent{ 471 | Proposal: removeProposal, 472 | }, 473 | } 474 | removeHs.Signature = Signature{tv.Random} 475 | 476 | remM, err := syntax.Marshal(removeHs) 477 | require.Nil(t, err) 478 | 479 | // commit 480 | proposal := []ProposalID{{tv.Random}, {tv.Random}} 481 | commit := Commit{ 482 | Updates: proposal, 483 | Removes: proposal, 484 | Adds: proposal, 485 | Path: dp, 486 | } 487 | 488 | commitM, err := syntax.Marshal(commit) 489 | require.Nil(t, err) 490 | 491 | //MlsCiphertext 492 | ct := MLSCiphertext{ 493 | GroupID: tv.GroupID, 494 | Epoch: tv.Epoch, 495 | ContentType: ContentTypeApplication, 496 | SenderDataNonce: tv.Random, 497 | EncryptedSenderData: tv.Random, 498 | AuthenticatedData: tv.Random, 499 | } 500 | 501 | ctM, err := syntax.Marshal(ct) 502 | require.Nil(t, err) 503 | 504 | tc := MessageTestCase{ 505 | CipherSuite: suite, 506 | SignatureScheme: scheme, 507 | KeyPackage: kpM, 508 | GroupInfo: giM, 509 | GroupSecrets: gsM, 510 | EncryptedGroupSecrets: egsM, 511 | Welcome: welM, 512 | AddProposal: addM, 513 | UpdateProposal: updateM, 514 | RemoveProposal: remM, 515 | Commit: commitM, 516 | MLSCiphertext: ctM, 517 | } 518 | tv.Cases = append(tv.Cases, tc) 519 | } 520 | 521 | vec, err := syntax.Marshal(tv) 522 | require.Nil(t, err) 523 | return vec 524 | } 525 | 526 | func verifyMessageVectors(t *testing.T, data []byte) { 527 | var tv MessageTestVectors 528 | _, err := syntax.Unmarshal(data, &tv) 529 | require.Nil(t, err) 530 | 531 | for _, tc := range tv.Cases { 532 | suite := tc.CipherSuite 533 | scheme := tc.SignatureScheme 534 | priv, err := suite.hpke().Derive(tv.DHSeed) 535 | require.Nil(t, err) 536 | pub := priv.PublicKey 537 | 538 | sigPriv, err := scheme.Derive(tv.SigSeed) 539 | require.Nil(t, err) 540 | sigPub := sigPriv.PublicKey 541 | 542 | bc := &BasicCredential{ 543 | Identity: tv.UserId, 544 | SignatureScheme: scheme, 545 | PublicKey: sigPub, 546 | } 547 | cred := Credential{Basic: bc} 548 | 549 | secrets := 
[][]byte{tv.Random, tv.Random, tv.Random, tv.Random} 550 | ratchetTree := newTestRatchetTree(t, suite, secrets) 551 | 552 | ratchetTree.BlankPath(LeafIndex(2)) 553 | 554 | treeSigPriv, err := scheme.Derive(secrets[0]) 555 | require.Nil(t, err) 556 | 557 | _, _, err = ratchetTree.Encap(LeafIndex(0), []byte{}, tv.Random, treeSigPriv, nil) 558 | require.Nil(t, err) 559 | 560 | // KeyPackage 561 | kp := KeyPackage{ 562 | Version: ProtocolVersionMLS10, 563 | CipherSuite: suite, 564 | InitKey: pub, 565 | Credential: cred, 566 | Extensions: NewExtensionList(), 567 | Signature: Signature{tv.Random}, 568 | } 569 | 570 | dp.LeafKeyPackage = kp 571 | 572 | kpM, err := syntax.Marshal(kp) 573 | require.Nil(t, err) 574 | require.Equal(t, kpM, tc.KeyPackage) 575 | 576 | // Welcome 577 | var gi GroupInfo 578 | gi.Tree.Suite = suite 579 | _, err = syntax.Unmarshal(tc.GroupInfo, &gi) 580 | require.Nil(t, err) 581 | 582 | marshaled, err := syntax.Marshal(gi) 583 | require.Nil(t, err) 584 | require.Equal(t, marshaled, tc.GroupInfo) 585 | 586 | gs := GroupSecrets{ 587 | EpochSecret: tv.Random, 588 | } 589 | 590 | gsM, err := syntax.Marshal(gs) 591 | require.Nil(t, err) 592 | require.Equal(t, gsM, tc.GroupSecrets) 593 | 594 | encPayload, err := suite.hpke().Encrypt(pub, []byte{}, tv.Random) 595 | require.Nil(t, err) 596 | egs := EncryptedGroupSecrets{ 597 | KeyPackageHash: tv.Random, 598 | EncryptedGroupSecrets: encPayload, 599 | } 600 | var egsWire EncryptedGroupSecrets 601 | syntax.Unmarshal(tc.EncryptedGroupSecrets, &egsWire) 602 | require.Equal(t, egs.KeyPackageHash, egsWire.KeyPackageHash) 603 | 604 | var welcome Welcome 605 | welcome.Version = ProtocolVersionMLS10 606 | welcome.CipherSuite = suite 607 | welcome.Secrets = []EncryptedGroupSecrets{egs, egs} 608 | welcome.EncryptedGroupInfo = tv.Random 609 | 610 | var welWire Welcome 611 | syntax.Unmarshal(tc.Welcome, &welWire) 612 | require.Equal(t, welcome.CipherSuite, welWire.CipherSuite) 613 | require.Equal(t, welcome.Version, 
welWire.Version) 614 | require.Equal(t, welcome.EncryptedGroupInfo, welWire.EncryptedGroupInfo) 615 | 616 | // proposals 617 | addProposal := &Proposal{ 618 | Add: &AddProposal{ 619 | KeyPackage: kp, 620 | }, 621 | } 622 | 623 | addHs := MLSPlaintext{ 624 | GroupID: tv.GroupID, 625 | Epoch: tv.Epoch, 626 | Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)}, 627 | Content: MLSPlaintextContent{ 628 | Proposal: addProposal, 629 | }, 630 | } 631 | addHs.Signature = Signature{tv.Random} 632 | 633 | addM, err := syntax.Marshal(addHs) 634 | require.Nil(t, err) 635 | require.Equal(t, addM, tc.AddProposal) 636 | 637 | updateProposal := &Proposal{ 638 | Update: &UpdateProposal{ 639 | KeyPackage: kp, 640 | }, 641 | } 642 | 643 | updateHs := MLSPlaintext{ 644 | GroupID: tv.GroupID, 645 | Epoch: tv.Epoch, 646 | Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)}, 647 | Content: MLSPlaintextContent{ 648 | Proposal: updateProposal, 649 | }, 650 | } 651 | updateHs.Signature = Signature{tv.Random} 652 | 653 | updateM, err := syntax.Marshal(updateHs) 654 | require.Nil(t, err) 655 | require.Equal(t, updateM, tc.UpdateProposal) 656 | 657 | removeProposal := &Proposal{ 658 | Remove: &RemoveProposal{ 659 | Removed: tv.SignerIndex, 660 | }, 661 | } 662 | 663 | removeHs := MLSPlaintext{ 664 | GroupID: tv.GroupID, 665 | Epoch: tv.Epoch, 666 | Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)}, 667 | Content: MLSPlaintextContent{ 668 | Proposal: removeProposal, 669 | }, 670 | } 671 | removeHs.Signature = Signature{tv.Random} 672 | remM, err := syntax.Marshal(removeHs) 673 | require.Nil(t, err) 674 | require.Equal(t, remM, tc.RemoveProposal) 675 | 676 | // commit 677 | proposal := []ProposalID{{tv.Random}, {tv.Random}} 678 | commit := Commit{ 679 | Updates: proposal, 680 | Removes: proposal, 681 | Adds: proposal, 682 | Path: dp, 683 | } 684 | 685 | var commitWire Commit 686 | _, err = syntax.Unmarshal(tc.Commit, &commitWire) 687 | require.Nil(t, err) 688 | require.Equal(t, 
commit.Adds, commitWire.Adds) 689 | require.Equal(t, commit.Removes, commitWire.Removes) 690 | require.Equal(t, commit.Updates, commitWire.Updates) 691 | require.Equal(t, commit.Path.LeafKeyPackage, commitWire.Path.LeafKeyPackage) 692 | // Path not verified because HPKE is randomized 693 | 694 | //MlsCiphertext 695 | ct := MLSCiphertext{ 696 | GroupID: tv.GroupID, 697 | Epoch: tv.Epoch, 698 | ContentType: ContentTypeApplication, 699 | SenderDataNonce: tv.Random, 700 | EncryptedSenderData: tv.Random, 701 | AuthenticatedData: tv.Random, 702 | } 703 | 704 | ctM, err := syntax.Marshal(ct) 705 | require.Nil(t, err) 706 | require.Equal(t, ctM, tc.MLSCiphertext) 707 | } 708 | } 709 | -------------------------------------------------------------------------------- /state.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "math/rand" 7 | "reflect" 8 | 9 | "github.com/cisco/go-tls-syntax" 10 | ) 11 | 12 | /// 13 | /// GroupContext 14 | /// 15 | type GroupContext struct { 16 | GroupID []byte `tls:"head=1"` 17 | Epoch Epoch 18 | TreeHash []byte `tls:"head=1"` 19 | ConfirmedTranscriptHash []byte `tls:"head=1"` 20 | Extensions ExtensionList 21 | } 22 | 23 | /// 24 | /// State 25 | /// 26 | 27 | type ProposalRef uint64 28 | 29 | func toRef(id ProposalID) ProposalRef { 30 | ref := uint64(0) 31 | for i := uint(0); i < 8; i++ { 32 | ref |= uint64(id.Hash[i]) << i 33 | } 34 | return ProposalRef(ref) 35 | } 36 | 37 | type updateSecrets struct { 38 | Secret []byte `tls:"head=1"` 39 | IdentityPriv *SignaturePrivateKey `tls:"optional"` 40 | } 41 | 42 | var supportedGroupExtensions = []ExtensionType{ 43 | // TODO 44 | } 45 | 46 | type State struct { 47 | // Shared confirmed state 48 | CipherSuite CipherSuite 49 | GroupID []byte `tls:"head=1"` 50 | Epoch Epoch 51 | Tree TreeKEMPublicKey 52 | ConfirmedTranscriptHash []byte `tls:"head=1"` 53 | InterimTranscriptHash []byte `tls:"head=1"` 54 | 
Extensions              ExtensionList

	// Per-participant non-secret state
	Index            LeafIndex           `tls:"omit"`
	IdentityPriv     SignaturePrivateKey `tls:"omit"`
	TreePriv         TreeKEMPrivateKey   `tls:"omit"`
	Scheme           SignatureScheme     `tls:"omit"`
	PendingProposals []MLSPlaintext      `tls:"omit"`

	// Secret state
	PendingUpdates map[ProposalRef]updateSecrets `tls:"omit"`
	Keys           keyScheduleEpoch              `tls:"omit"`

	// Helpful information: leaves whose credential is new in the latest
	// transition. Clone resets it; Add/Update processing and joining fill it.
	NewCredentials map[LeafIndex]bool
}

// NewEmptyState creates a one-member group whose sole member is described by
// kp, with an empty group extension list.
func NewEmptyState(groupID []byte, leafSecret []byte, sigPriv SignaturePrivateKey, kp KeyPackage) (*State, error) {
	return NewEmptyStateWithExtensions(groupID, leafSecret, sigPriv, kp, NewExtensionList())
}

// NewEmptyStateWithExtensions creates a one-member group at epoch 0 carrying
// the group extensions ext. The creator's KeyPackage must support every
// extension in ext; leafSecret seeds the creator's TreeKEM private key.
func NewEmptyStateWithExtensions(groupID []byte, leafSecret []byte, sigPriv SignaturePrivateKey, kp KeyPackage, ext ExtensionList) (*State, error) {
	// Verify that the creator supports the group's extensions before doing
	// any tree work
	if err := verifyExtensionSupport(kp, ext); err != nil {
		return nil, err
	}

	suite := kp.CipherSuite
	tree := NewTreeKEMPublicKey(suite)
	index := tree.AddLeaf(kp)
	treePriv := NewTreeKEMPrivateKey(suite, tree.Size(), index, leafSecret)

	secret := make([]byte, suite.newDigest().Size())
	kse := newKeyScheduleEpoch(suite, 1, secret, []byte{})
	s := &State{
		CipherSuite:             kp.CipherSuite,
		GroupID:                 groupID,
		Epoch:                   0,
		Tree:                    *tree,
		Keys:                    kse,
		Index:                   index, // must agree with the index treePriv was built for
		IdentityPriv:            sigPriv,
		TreePriv:                *treePriv,
		Scheme:                  kp.Credential.Scheme(),
		PendingUpdates:          map[ProposalRef]updateSecrets{},
		ConfirmedTranscriptHash: []byte{},
		InterimTranscriptHash:   []byte{},
		Extensions:              ext,
		NewCredentials:          map[LeafIndex]bool{},
	}
	return s, nil
}

// verifyExtensionSupport returns an error if kp does not advertise every
// extension type present in ext.
func verifyExtensionSupport(kp KeyPackage, ext ExtensionList) error {
	for _, e := range ext.Entries {
		if !kp.Extensions.Has(e.ExtensionType) {
			return fmt.Errorf("Unsupported extension type [%04x]", e.ExtensionType)
		}
	}
	return nil
}

// NewStateFromWelcome decrypts the GroupInfo in welcome with epochSecret and
// builds the public part of a State from it. It returns the state, the index
// of the member that signed the GroupInfo, and the confirmation tag, which the
// caller checks once the key schedule is installed. Private fields (identity
// key, TreeKEM private key, key schedule) are left for the caller to fill in.
func NewStateFromWelcome(suite CipherSuite, epochSecret []byte, welcome Welcome) (*State, LeafIndex, []byte, error) {
	// Decrypt the GroupInfo
	gi, err := welcome.Decrypt(suite, epochSecret)
	if err != nil {
		return nil, 0, nil, err
	}

	// Construct the new state
	s := &State{
		CipherSuite:             suite,
		Epoch:                   gi.Epoch,
		Tree:                    gi.Tree.Clone(),
		GroupID:                 gi.GroupID,
		ConfirmedTranscriptHash: gi.ConfirmedTranscriptHash,
		InterimTranscriptHash:   gi.InterimTranscriptHash,
		Extensions:              gi.Extensions,
		PendingProposals:        []MLSPlaintext{},
		PendingUpdates:          map[ProposalRef]updateSecrets{},
		NewCredentials:          map[LeafIndex]bool{},
	}

	// At this point, every leaf in the tree is new
	// XXX(RLB) ... except our own
	for i := LeafIndex(0); i < LeafIndex(s.Tree.Size()); i++ {
		s.NewCredentials[i] = true
	}

	return s, gi.SignerIndex, gi.Confirmation, nil
}

// NewJoinedState initializes a new member from a Welcome. kps are the
// joiner's candidate KeyPackages and sigPrivs their signature keys
// (sigPrivs[i] belongs to kps[i]). initSecret is the secret from which the
// matching KeyPackage's init key was derived.
func NewJoinedState(initSecret []byte, sigPrivs []SignaturePrivateKey, kps []KeyPackage, welcome Welcome) (*State, error) {
	var (
		initPriv        HPKEPrivateKey
		sigPriv         SignaturePrivateKey
		keyPackage      KeyPackage
		encGroupSecrets EncryptedGroupSecrets
		found           bool
	)
	suite := welcome.CipherSuite

	// Find the KeyPackage the Welcome was encrypted to by matching
	// KeyPackage hashes against the Welcome's secrets
search:
	for idx, kp := range kps {
		data, err := syntax.Marshal(kp)
		if err != nil {
			return nil, fmt.Errorf("mls.state: kp %d marshal failure %w", idx, err)
		}
		kphash := suite.Digest(data)

		for _, egs := range welcome.Secrets {
			if !bytes.Equal(kphash, egs.KeyPackageHash) {
				continue
			}

			initPriv, err = kp.CipherSuite.hpke().Derive(initSecret)
			if err != nil {
				return nil, err
			}
			if !initPriv.PublicKey.Equals(kp.InitKey) {
				return nil, fmt.Errorf("Incorrect init secret")
			}

			// A caller supplying fewer signature keys than KeyPackages would
			// otherwise cause an index-out-of-range panic here
			if idx >= len(sigPrivs) {
				return nil, fmt.Errorf("mls.state: no signature key for kp %d", idx)
			}

			sigPriv = sigPrivs[idx]
			keyPackage = kp
			encGroupSecrets = egs
			found = true
			break search
		}
	}

	if !found {
		return nil, fmt.Errorf("mls.state: unable to decrypt welcome message")
	}

	if keyPackage.CipherSuite != welcome.CipherSuite {
		return nil, fmt.Errorf("mls.state: ciphersuite mismatch")
	}

	pt, err := suite.hpke().Decrypt(initPriv, []byte{}, encGroupSecrets.EncryptedGroupSecrets)
	if err != nil {
		return nil, fmt.Errorf("mls.state: encKeyPkg decryption failure %w", err)
	}

	var groupSecrets GroupSecrets
	if _, err = syntax.Unmarshal(pt, &groupSecrets); err != nil {
		return nil, fmt.Errorf("mls.state: keyPkg unmarshal failure %w", err)
	}

	// Construct a new state based on the GroupInfo
	s, signerIndex, confirmation, err := NewStateFromWelcome(suite, groupSecrets.EpochSecret, welcome)
	if err != nil {
		return nil, err
	}

	s.IdentityPriv = sigPriv
	s.Scheme = keyPackage.Credential.Scheme()

	// Verify that the joiner supports the group's extensions
	if err = verifyExtensionSupport(keyPackage, s.Extensions); err != nil {
		return nil, err
	}

	// Construct TreeKEM private key from parts provided
	index, ok := s.Tree.Find(keyPackage)
	if !ok {
		return nil, fmt.Errorf("mls.state: new joiner not in the tree")
	}
	s.Index = index
	commonAncestor := ancestor(s.Index, signerIndex)

	// The path secret is optional in GroupSecrets
	var pathSecret []byte
	if groupSecrets.PathSecret != nil {
		pathSecret = groupSecrets.PathSecret.Data
	}

	treePriv := NewTreeKEMPrivateKeyForJoiner(s.CipherSuite, s.Index, s.Tree.Size(), initSecret, commonAncestor, pathSecret)
	s.TreePriv = *treePriv

	// Start up the key schedule
	encGrpCtx, err := syntax.Marshal(s.groupContext())
	if err != nil {
		return nil, fmt.Errorf("mls.state: groupCtx marshal failure %w", err)
	}
	s.Keys = newKeyScheduleEpoch(suite, LeafCount(s.Tree.Size()), groupSecrets.EpochSecret, encGrpCtx)

	// confirmation verification
	if !s.verifyConfirmation(confirmation) {
		return nil, fmt.Errorf("mls.state: confirmation failed to verify")
	}

	return s, nil
}

// Add returns a signed Add proposal for kp. It fails if kp does not support
// all of the group's extensions. The proposal is not queued here; Handle
// queues proposals.
func (s State) Add(kp KeyPackage) (*MLSPlaintext, error) {
	// Verify that the new member supports the group's extensions
	if err := verifyExtensionSupport(kp, s.Extensions); err != nil {
		return nil, err
	}

	addProposal := Proposal{
		Add: &AddProposal{
			KeyPackage: kp,
		},
	}
	return s.sign(addProposal)
}

// Update returns a signed Update proposal carrying kp. It caches secret (and
// sigPriv, if non-nil) so that they are installed when this member's update is
// committed. Although s is a value receiver, PendingUpdates is a map, so the
// cache entry is visible through the caller's State.
func (s State) Update(secret []byte, sigPriv *SignaturePrivateKey, kp KeyPackage) (*MLSPlaintext, error) {
	updateProposal := Proposal{
		Update: &UpdateProposal{
			KeyPackage: kp,
		},
	}

	pt, err := s.sign(updateProposal)
	if err != nil {
		return nil, err
	}

	ref := toRef(s.proposalID(*pt))
	s.PendingUpdates[ref] = updateSecrets{dup(secret), sigPriv}
	return pt, nil
}

// Remove returns a signed proposal to remove the member at leaf removed.
func (s *State) Remove(removed LeafIndex) (*MLSPlaintext, error) {
	removeProposal := Proposal{
		Remove: &RemoveProposal{
			Removed: removed,
		},
	}
	return s.sign(removeProposal)
}

// Commit commits all pending proposals. The commit is applied to a clone of s.
// It returns the signed Commit message, a Welcome for the new members, and the
// state for the next epoch. leafSecret seeds the fresh path when the commit
// requires one.
func (s *State) Commit(leafSecret []byte) (*MLSPlaintext, *Welcome, *State, error) {
	// Construct a commit covering every pending proposal
	commit := Commit{}
	var joiners []KeyPackage

	for _, pp := range s.PendingProposals {
		pid := s.proposalID(pp)
		proposal := pp.Content.Proposal
		switch proposal.Type() {
		case ProposalTypeAdd:
			commit.Adds = append(commit.Adds, pid)
			joiners = append(joiners, proposal.Add.KeyPackage)
		case ProposalTypeUpdate:
			commit.Updates = append(commit.Updates, pid)
		case ProposalTypeRemove:
			commit.Removes = append(commit.Removes, pid)
		}
	}

	// init new state to apply commit and ratchet forward
	next := s.Clone()
	if err := next.apply(commit); err != nil {
		return nil, nil, nil, err
	}

	// reset after commit the proposals
	next.PendingProposals = nil

	// KEM new entropy to the new group if needed. A commit without a path
	// uses the all-zero commit secret, exactly as Handle does on the receiving
	// side; reusing a previous epoch's UpdateSecret would desynchronize them.
	commitSecret := s.CipherSuite.zero()
	if commit.PathRequired() {
		ctx, err := syntax.Marshal(next.groupContext())
		if err != nil {
			return nil, nil, nil, err
		}

		treePriv, treePath, err := next.Tree.Encap(s.Index, ctx, leafSecret, next.IdentityPriv, nil)
		if err != nil {
			return nil, nil, nil, err
		}

		next.TreePriv = *treePriv
		commit.Path = treePath
		commitSecret = next.TreePriv.UpdateSecret
	}

	// Create the Commit message and advance the transcripts / key schedule
	pt, err := next.ratchetAndSign(commit, commitSecret, s.groupContext(), s.IdentityPriv)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("mls.state: ratchet forward failed %w", err)
	}

	// Complete the GroupInfo and form the Welcome. The extensions must be
	// included so joiners derive the same group context as existing members.
	gi := &GroupInfo{
		GroupID:                 next.GroupID,
		Epoch:                   next.Epoch,
		Tree:                    next.Tree,
		ConfirmedTranscriptHash: next.ConfirmedTranscriptHash,
		InterimTranscriptHash:   next.InterimTranscriptHash,
		Extensions:              next.Extensions,
		Confirmation:            pt.Content.Commit.Confirmation.Data,
	}
	if err = gi.sign(next.Index, &next.IdentityPriv); err != nil {
		return nil, nil, nil, fmt.Errorf("mls.state: groupInfo sign failure %w", err)
	}

	welcome := newWelcome(s.CipherSuite, next.Keys.EpochSecret, gi)
	for _, kp := range joiners {
		leaf, ok := next.Tree.Find(kp)
		if !ok {
			return nil, nil, nil, fmt.Errorf("mls.state: New joiner not in tree")
		}

		// The "found" flag is deliberately ignored: NewJoinedState treats the
		// path secret as optional. NOTE(review): presumably pathSecret is nil
		// when this commit carried no path — confirm in SharedPathSecret.
		_, pathSecret, _ := next.TreePriv.SharedPathSecret(leaf)
		welcome.EncryptTo(kp, pathSecret)
	}

	return pt, welcome, next, nil
}

/// Proposal processing helpers

// apply applies the proposals referenced by commit to s, in the order
// updates, removes, adds. A proposal referenced more than once is applied only
// once.
func (s *State) apply(commit Commit) error {
	// Tracks proposal IDs already applied during this commit
	processed := map[string]bool{}
	for _, ids := range [][]ProposalID{commit.Updates, commit.Removes, commit.Adds} {
		if err := s.applyProposals(ids, processed); err != nil {
			return err
		}
	}
	return nil
}

// applyAddProposal adds the proposed member to the tree after checking its
// ciphersuite and KeyPackage signature, and marks its credential as new.
func (s *State) applyAddProposal(add *AddProposal) error {
	if add.KeyPackage.CipherSuite != s.CipherSuite {
		return fmt.Errorf("mls.state: new member kp does not use group ciphersuite")
	}

	if !add.KeyPackage.Verify() {
		return fmt.Errorf("mls.state: Invalid kp")
	}

	target := s.Tree.AddLeaf(add.KeyPackage)
	s.NewCredentials[target] = true
	return nil
}

// applyRemoveProposal blanks the removed member's path in the tree.
func (s *State) applyRemoveProposal(remove *RemoveProposal) {
	s.Tree.BlankPath(LeafIndex(remove.Removed))
}

// applyUpdateProposal replaces the KeyPackage at leaf target. A ciphersuite
// mismatch or invalid KeyPackage is reported as an error rather than a panic,
// since it can be triggered by a peer's message.
func (s *State) applyUpdateProposal(target LeafIndex, update *UpdateProposal) error {
	if update.KeyPackage.CipherSuite != s.CipherSuite {
		return fmt.Errorf("mls.state: update kp does not use group ciphersuite %v != %v", update.KeyPackage.CipherSuite, s.CipherSuite)
	}

	if !update.KeyPackage.Verify() {
		return fmt.Errorf("mls.state: Invalid kp")
	}

	currKP, ok := s.Tree.KeyPackage(target)
	if !ok {
		return fmt.Errorf("mls.state: Attempt to update an empty leaf")
	}

	if !update.KeyPackage.Credential.Equals(currKP.Credential) {
		s.NewCredentials[target] = true
	}

	s.Tree.UpdateLeaf(target, update.KeyPackage)
	return nil
}

// applyProposals applies the pending proposals named by ids. processed records
// proposal IDs already applied, so duplicates are skipped.
func (s *State) applyProposals(ids []ProposalID, processed map[string]bool) error {
	for _, id := range ids {
		pt, ok := s.findProposal(id)
		if !ok {
			return fmt.Errorf("mls.state: commit of unknown proposal %s", id)
		}

		// we have processed this proposal already
		key := id.String()
		if processed[key] {
			continue
		}
		processed[key] = true

		proposal := pt.Content.Proposal
		switch proposal.Type() {
		case ProposalTypeAdd:
			if err := s.applyAddProposal(proposal.Add); err != nil {
				return err
			}

		case ProposalTypeUpdate:
			if pt.Sender.Type != SenderTypeMember {
				return fmt.Errorf("mls.state: update from non-member")
			}

			senderIndex := LeafIndex(pt.Sender.Sender)
			if err := s.applyUpdateProposal(senderIndex, proposal.Update); err != nil {
				return err
			}

			// For our own update, install the secrets cached by Update
			if senderIndex == s.Index {
				secrets, ok := s.PendingUpdates[toRef(id)]
				if !ok {
					return fmt.Errorf("mls.state: self-update with no cached secret")
				}

				s.TreePriv.SetLeafSecret(secrets.Secret)
				if secrets.IdentityPriv != nil {
					s.IdentityPriv = *secrets.IdentityPriv
				}
			}

		case ProposalTypeRemove:
			s.applyRemoveProposal(proposal.Remove)

		default:
			return fmt.Errorf("mls.state: invalid proposal type")
		}
	}
	return nil
}

// findProposal returns the pending proposal whose ID matches id, or false if
// there is none. Each call re-hashes every pending proposal.
func (s State) findProposal(id ProposalID) (MLSPlaintext, bool) {
	for _, pt := range s.PendingProposals {
		otherPid := s.proposalID(pt)
		if bytes.Equal(otherPid.Hash, id.Hash) {
			return pt, true
		}
	}
	// The zero value is returned on a miss; callers must check the flag
	return MLSPlaintext{}, false
}

// proposalID returns the ciphersuite digest of the marshaled plaintext. It
// panics if marshaling fails.
func (s State) proposalID(plaintext MLSPlaintext) ProposalID {
	enc, err := syntax.Marshal(plaintext)
	if err != nil {
		panic(fmt.Errorf("mls.state: mlsPlainText marshal failure %v", err))
	}
	return ProposalID{
		Hash: s.CipherSuite.Digest(enc),
	}
}

// groupContext returns the GroupContext for the current state, including the
// group extensions.
func (s State) groupContext() GroupContext {
	return GroupContext{
		GroupID:                 s.GroupID,
		Epoch:                   s.Epoch,
		TreeHash:                s.Tree.RootHash(),
		ConfirmedTranscriptHash: s.ConfirmedTranscriptHash,
		Extensions:              s.Extensions,
	}
}

// sign wraps p in an MLSPlaintext from this member and signs it over the
// current group context.
func (s State) sign(p Proposal) (*MLSPlaintext, error) {
	pt := &MLSPlaintext{
		GroupID: s.GroupID,
		Epoch:   s.Epoch,
		Sender:  Sender{SenderTypeMember, uint32(s.Index)},
		Content: MLSPlaintextContent{
			Proposal: &p,
		},
	}

	if err := pt.sign(s.groupContext(), s.IdentityPriv, s.Scheme); err != nil {
		return nil, err
	}
	return pt, nil
}

// updateEpochSecrets advances the key schedule with the given commit secret.
// Callers update s.Epoch and s.ConfirmedTranscriptHash before calling it.
func (s *State) updateEpochSecrets(secret []byte) {
	// Use the full group context (including extensions) so the derivation
	// matches the context joiners pass to newKeyScheduleEpoch in NewJoinedState
	ctx, err := syntax.Marshal(s.groupContext())
	if err != nil {
		panic(fmt.Errorf("mls.state: update epoch secret failed %v", err))
	}

	// TODO(RLB) Provide an API to provide PSKs
	s.Keys = s.Keys.Next(LeafCount(s.Tree.Size()), nil, secret, ctx)
}

// ratchetAndSign wraps op in a Commit message and advances s to the next
// epoch. It updates the confirmed transcript hash, ratchets the key schedule
// with commitSecret, computes the confirmation tag, and signs the message
// under prevGrpCtx with sigPriv. Finally it updates the interim transcript
// hash.
func (s *State) ratchetAndSign(op Commit, commitSecret []byte, prevGrpCtx GroupContext, sigPriv SignaturePrivateKey) (*MLSPlaintext, error) {
	pt := &MLSPlaintext{
		GroupID: s.GroupID,
		Epoch:   s.Epoch,
		Sender:  Sender{SenderTypeMember, uint32(s.Index)},
		Content: MLSPlaintextContent{
			Commit: &CommitData{
				Commit: op,
			},
		},
	}

	// Update the Confirmed Transcript Hash
	digest := s.CipherSuite.newDigest()
	digest.Write(s.InterimTranscriptHash)
	digest.Write(pt.commitContent())
	s.ConfirmedTranscriptHash = digest.Sum(nil)

	// Advance the key schedule
	s.Epoch++
	s.updateEpochSecrets(commitSecret)

	// generate the confirmation based on the new keys
	commit := pt.Content.Commit
	mac := s.CipherSuite.NewHMAC(s.Keys.ConfirmationKey)
	mac.Write(s.ConfirmedTranscriptHash)
	commit.Confirmation.Data = mac.Sum(nil)

	// sign the MLSPlainText and update state hashes
	// as a result of ratcheting.
	if err := pt.sign(prevGrpCtx, sigPriv, s.Scheme); err != nil {
		return nil, err
	}

	authData, err := pt.commitAuthData()
	if err != nil {
		return nil, err
	}

	digest = s.CipherSuite.newDigest()
	digest.Write(s.ConfirmedTranscriptHash)
	digest.Write(authData)
	s.InterimTranscriptHash = digest.Sum(nil)

	return pt, nil
}

// signerPublicKey returns the signature public key of a member sender, taken
// from the credential at the sender's leaf.
func (s State) signerPublicKey(sender Sender) (*SignaturePublicKey, error) {
	switch sender.Type {
	case SenderTypeMember:
		kp, ok := s.Tree.KeyPackage(LeafIndex(sender.Sender))
		if !ok {
			return nil, fmt.Errorf("mls.state: Received from blank leaf")
		}

		return kp.Credential.PublicKey(), nil

	default:
		// TODO(RLB): Support add sent by new member
		// TODO(RLB): Support add/remove signed by preconfigured key
		return nil, fmt.Errorf("mls.state: Unsupported sender type")
	}
}

// Handle processes a signed handshake message from another member.
// A proposal is appended to s.PendingProposals, and Handle returns (nil, nil).
// A commit is applied to a clone of s, and Handle returns the next-epoch state.
// Commits sent by this member are rejected; use the state returned by Commit.
func (s *State) Handle(pt *MLSPlaintext) (*State, error) {
	if !bytes.Equal(pt.GroupID, s.GroupID) {
		return nil, fmt.Errorf("mls.state: groupId mismatch")
	}

	if pt.Epoch != s.Epoch {
		return nil, fmt.Errorf("mls.state: epoch mismatch, have %v, got %v", s.Epoch, pt.Epoch)
	}

	sigPubKey, err := s.signerPublicKey(pt.Sender)
	if err != nil {
		return nil, err
	}

	if !pt.verify(s.groupContext(), sigPubKey, s.Scheme) {
		return nil, fmt.Errorf("invalid handshake message signature")
	}

	// Proposals get queued, do not result in a state transition
	contentType := pt.Content.Type()
	if contentType == ContentTypeProposal {
		s.PendingProposals = append(s.PendingProposals, *pt)
		return nil, nil
	}

	if contentType != ContentTypeCommit {
		return nil, fmt.Errorf("mls.state: incorrect content type")
	}
	if pt.Sender.Type != SenderTypeMember {
		return nil, fmt.Errorf("mls.state: commit from non-member")
	}

	senderIndex := LeafIndex(pt.Sender.Sender)
	if senderIndex == s.Index {
		return nil, fmt.Errorf("mls.state: handle own commits with caching")
	}

	// apply the commit and discard any remaining pending proposals
	commitData := pt.Content.Commit
	next := s.Clone()
	if err = next.apply(commitData.Commit); err != nil {
		return nil, err
	}
	next.PendingProposals = nil

	// apply the direct path, if provided. The context must match the one the
	// committer used for Encap, i.e. next.groupContext() including extensions.
	commitSecret := s.CipherSuite.zero()
	if commitData.Commit.Path != nil {
		ctx, err := syntax.Marshal(next.groupContext())
		if err != nil {
			return nil, fmt.Errorf("mls.state: failure to create context %w", err)
		}

		if err = next.TreePriv.Decap(senderIndex, next.Tree, ctx, *commitData.Commit.Path); err != nil {
			return nil, err
		}

		commitSecret = next.TreePriv.UpdateSecret

		if err = next.Tree.Merge(senderIndex, *commitData.Commit.Path); err != nil {
			return nil, err
		}
	}

	// Update the confirmed transcript hash
	digest := next.CipherSuite.newDigest()
	digest.Write(next.InterimTranscriptHash)
	digest.Write(pt.commitContent())
	next.ConfirmedTranscriptHash = digest.Sum(nil)

	// Advance the key schedule
	next.Epoch++
	next.updateEpochSecrets(commitSecret)

	// Verify confirmation MAC
	if !next.verifyConfirmation(commitData.Confirmation.Data) {
		return nil, fmt.Errorf("mls.state: confirmation failed to verify")
	}

	authData, err := pt.commitAuthData()
	if err != nil {
		return nil, err
	}

	// Update the interim transcript hash
	digest = next.CipherSuite.newDigest()
	digest.Write(next.ConfirmedTranscriptHash)
	digest.Write(authData)
	next.InterimTranscriptHash = digest.Sum(nil)

	return next, nil
}

///// protect/unprotect and helpers

// verifyConfirmation reports whether confirmation equals the HMAC of the
// confirmed transcript hash under the current confirmation key.
func (s State) verifyConfirmation(confirmation []byte) bool {
	mac := s.CipherSuite.NewHMAC(s.Keys.ConfirmationKey)
	mac.Write(s.ConfirmedTranscriptHash)
	expected := mac.Sum(nil)
	return confirmationTagEqual(expected, confirmation)
}

// confirmationTagEqual compares two MAC tags in time that depends only on
// their lengths, so that a forger cannot learn a matching prefix from timing.
// It mirrors crypto/subtle.ConstantTimeCompare.
func confirmationTagEqual(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	var diff byte
	for i := range a {
		diff |= a[i] ^ b[i]
	}
	return diff == 0
}

// applyGuard returns a copy of nonceIn with reuseGuard XORed into its first
// four bytes.
func applyGuard(nonceIn []byte, reuseGuard [4]byte) []byte {
	nonceOut := dup(nonceIn)
	for i, g := range reuseGuard {
		nonceOut[i] ^= g
	}
	return nonceOut
}

// encrypt protects pt for the group. It takes the next key for this sender
// (application or handshake, by content type), encrypts the sender data under
// the sender-data key, and encrypts the content and signature under the
// selected key with a random reuse guard.
func (s *State) encrypt(pt *MLSPlaintext) (*MLSCiphertext, error) {
	contentType := pt.Content.Type()

	var generation uint32
	var keys keyAndNonce
	switch contentType {
	case ContentTypeApplication:
		generation, keys = s.Keys.ApplicationKeys.Next(s.Index)
	case ContentTypeProposal, ContentTypeCommit:
		generation, keys = s.Keys.HandshakeKeys.Next(s.Index)
	default:
		return nil, fmt.Errorf("mls.state: encrypt unknown content type")
	}

	var reuseGuard [4]byte
	if _, err := rand.Read(reuseGuard[:]); err != nil {
		return nil, fmt.Errorf("mls.state: reuse guard generation failure %w", err)
	}

	stream := syntax.NewWriteStream()
	if err := stream.WriteAll(s.Index, generation, reuseGuard); err != nil {
		return nil, fmt.Errorf("mls.state: sender data marshal failure %w", err)
	}
	senderData := stream.Data()

	senderDataNonce := make([]byte, s.CipherSuite.Constants().NonceSize)
	if _, err := rand.Read(senderDataNonce); err != nil {
		return nil, fmt.Errorf("mls.state: sender data nonce generation failure %w", err)
	}
	senderDataAADVal := senderDataAAD(s.GroupID, s.Epoch, contentType, senderDataNonce)
	sdAead, err := s.CipherSuite.NewAEAD(s.Keys.SenderDataKey)
	if err != nil {
		return nil, fmt.Errorf("mls.state: sender data AEAD setup failure %w", err)
	}
	sdCt := sdAead.Seal(nil, senderDataNonce, senderData, senderDataAADVal)

	// content data
	stream = syntax.NewWriteStream()
	err = stream.Write(pt.Content)
	if err == nil {
		err = stream.Write(pt.Signature)
	}
	if err != nil {
		return nil, fmt.Errorf("mls.state: content marshal failure %w", err)
	}
	content := stream.Data()

	aad := contentAAD(s.GroupID, s.Epoch, contentType,
		pt.AuthenticatedData, senderDataNonce, sdCt)
	aead, err := s.CipherSuite.NewAEAD(keys.Key)
	if err != nil {
		return nil, fmt.Errorf("mls.state: content AEAD setup failure %w", err)
	}
	contentCt := aead.Seal(nil, applyGuard(keys.Nonce, reuseGuard), content, aad)

	// set up MLSCipherText
	ct := &MLSCiphertext{
		GroupID:             s.GroupID,
		Epoch:               s.Epoch,
		ContentType:         contentType,
		AuthenticatedData:   pt.AuthenticatedData,
		SenderDataNonce:     senderDataNonce,
		EncryptedSenderData: sdCt,
		Ciphertext:          contentCt,
	}

	return ct, nil
}

// decrypt reverses encrypt. It recovers the sender and generation from the
// sender data, then decrypts the content with that sender's key. The key is
// erased only after it has successfully authenticated the ciphertext.
func (s *State) decrypt(ct *MLSCiphertext) (*MLSPlaintext, error) {
	if !bytes.Equal(ct.GroupID, s.GroupID) {
		return nil, fmt.Errorf("mls.state: ciphertext not from this group")
	}

	if ct.Epoch != s.Epoch {
		return nil, fmt.Errorf("mls.state: ciphertext not from this epoch")
	}

	contentType := ContentType(ct.ContentType)

	// handle sender data
	sdAAD := senderDataAAD(ct.GroupID, ct.Epoch, contentType, ct.SenderDataNonce)
	sdAead, err := s.CipherSuite.NewAEAD(s.Keys.SenderDataKey)
	if err != nil {
		return nil, fmt.Errorf("mls.state: sender data AEAD setup failure %w", err)
	}
	sd, err := sdAead.Open(nil, ct.SenderDataNonce, ct.EncryptedSenderData, sdAAD)
	if err != nil {
		return nil, fmt.Errorf("mls.state: senderData decryption failure %w", err)
	}

	// parse the senderData
	var sender LeafIndex
	var generation uint32
	var reuseGuard [4]byte
	stream := syntax.NewReadStream(sd)
	if _, err = stream.ReadAll(&sender, &generation, &reuseGuard); err != nil {
		return nil, fmt.Errorf("mls.state: senderData unmarshal failure %w", err)
	}

	// Look up, but do not yet erase, the key for this sender and generation
	var keys keyAndNonce
	switch contentType {
	case ContentTypeApplication:
		keys, err = s.Keys.ApplicationKeys.Get(sender, generation)
		if err != nil {
			return nil, fmt.Errorf("mls.state: application keys extraction failed %w", err)
		}
	case ContentTypeProposal, ContentTypeCommit:
		keys, err = s.Keys.HandshakeKeys.Get(sender, generation)
		if err != nil {
			return nil, fmt.Errorf("mls.state: handshake keys extraction failed %w", err)
		}
	default:
		return nil, fmt.Errorf("mls.state: unsupported content type")
	}

	aad := contentAAD(ct.GroupID, ct.Epoch, contentType,
		ct.AuthenticatedData, ct.SenderDataNonce, ct.EncryptedSenderData)
	aead, err := s.CipherSuite.NewAEAD(keys.Key)
	if err != nil {
		return nil, fmt.Errorf("mls.state: content AEAD setup failure %w", err)
	}
	content, err := aead.Open(nil, applyGuard(keys.Nonce, reuseGuard), ct.Ciphertext, aad)
	if err != nil {
		return nil, fmt.Errorf("mls.state: content decryption failure %w", err)
	}

	// The key has now authenticated a message, so erase it to prevent reuse.
	// Erasing earlier would let a forged ciphertext burn a sender's key.
	if contentType == ContentTypeApplication {
		s.Keys.ApplicationKeys.Erase(sender, generation)
	} else {
		s.Keys.HandshakeKeys.Erase(sender, generation)
	}

	// parse the Content and Signature
	stream = syntax.NewReadStream(content)
	var mlsContent MLSPlaintextContent
	var signature Signature
	_, err = stream.Read(&mlsContent)
	if err == nil {
		_, err = stream.Read(&signature)
	}
	if err != nil {
		return nil, fmt.Errorf("mls.state: content unmarshal failure %w", err)
	}

	// The content must be of the type announced in the (authenticated) header,
	// otherwise e.g. handshake content could ride on application keys
	if mlsContent.Type() != contentType {
		return nil, fmt.Errorf("mls.state: content type mismatch")
	}

	pt := &MLSPlaintext{
		GroupID:           s.GroupID,
		Epoch:             s.Epoch,
		Sender:            Sender{SenderTypeMember, uint32(sender)},
		AuthenticatedData: ct.AuthenticatedData,
		Content:           mlsContent,
		Signature:         signature,
	}
	return pt, nil
}

// Protect signs data as an application message from this member and encrypts
// it for the group.
func (s *State) Protect(data []byte) (*MLSCiphertext, error) {
	pt := &MLSPlaintext{
		GroupID: s.GroupID,
		Epoch:   s.Epoch,
		Sender:  Sender{SenderTypeMember, uint32(s.Index)},
		Content: MLSPlaintextContent{
			Application: &ApplicationData{
				Data: data,
			},
		},
	}

	if err := pt.sign(s.groupContext(), s.IdentityPriv, s.Scheme); err != nil {
		return nil, err
	}
	return s.encrypt(pt)
}

// Unprotect decrypts ct and verifies the sender's signature, returning the
// application data. Non-application messages are rejected.
func (s *State) Unprotect(ct *MLSCiphertext) ([]byte, error) {
	pt, err := s.decrypt(ct)
	if err != nil {
		return nil, err
	}

	sigPubKey, err := s.signerPublicKey(pt.Sender)
	if err != nil {
		return nil, err
	}

	if !pt.verify(s.groupContext(), sigPubKey, s.Scheme) {
		return nil, fmt.Errorf("invalid message signature")
	}

	if pt.Content.Type() != ContentTypeApplication {
		return nil, fmt.Errorf("unprotect attempted on non-application message")
	}
	return pt.Content.Application.Data, nil
}

// senderDataAAD serializes the associated data for sender-data encryption.
// It returns nil if serialization fails.
func senderDataAAD(gid []byte, epoch Epoch, contentType ContentType, nonce []byte) []byte {
	s := syntax.NewWriteStream()
	err := s.Write(struct {
		GroupID         []byte `tls:"head=1"`
		Epoch           Epoch
		ContentType     ContentType
		SenderDataNonce []byte `tls:"head=1"`
	}{
		GroupID:         gid,
		Epoch:           epoch,
		ContentType:     contentType,
		SenderDataNonce: nonce,
	})

	if err != nil {
		return nil
	}

	return s.Data()
}

// contentAAD serializes the associated data for content encryption, binding
// the encrypted sender data. It returns nil if serialization fails.
func contentAAD(gid []byte, epoch Epoch,
	contentType ContentType, authenticatedData []byte,
	nonce []byte, encSenderData []byte) []byte {

	s := syntax.NewWriteStream()
	err := s.Write(struct {
		GroupID             []byte `tls:"head=1"`
		Epoch               Epoch
		ContentType         ContentType
		AuthenticatedData   []byte `tls:"head=4"`
		SenderDataNonce     []byte `tls:"head=1"`
		EncryptedSenderData []byte `tls:"head=1"`
	}{
		GroupID:             gid,
		Epoch:               epoch,
		ContentType:         contentType,
		AuthenticatedData:   authenticatedData,
		SenderDataNonce:     nonce,
		EncryptedSenderData: encSenderData,
	})

	if err != nil {
		return nil
	}
	return s.Data()
}

// Clone returns a copy of s suitable for computing the next epoch.
// - GroupID, both transcript hashes and PendingProposals are copied.
// - The tree and TreeKEM private key are cloned.
// - Extensions, Keys and PendingUpdates are shared with s.
// - NewCredentials starts empty.
func (s State) Clone() *State {
	clone := &State{
		CipherSuite:             s.CipherSuite,
		GroupID:                 dup(s.GroupID),
		Epoch:                   s.Epoch,
		Tree:                    s.Tree.Clone(),
		ConfirmedTranscriptHash: dup(s.ConfirmedTranscriptHash),
		InterimTranscriptHash:   dup(s.InterimTranscriptHash),
		Extensions:              s.Extensions,
		Keys:                    s.Keys,
		Index:                   s.Index,
		IdentityPriv:            s.IdentityPriv,
		TreePriv:                s.TreePriv.Clone(),
		Scheme:                  s.Scheme,
		PendingUpdates:          s.PendingUpdates,
		PendingProposals:        make([]MLSPlaintext, len(s.PendingProposals)),
		NewCredentials:          map[LeafIndex]bool{},
	}

	copy(clone.PendingProposals, s.PendingProposals)
	return clone
}

// Equals compares the public and shared private aspects of two states:
// ciphersuite, group ID, epoch, tree, transcript hashes and key schedule.
func (s State) Equals(o State) bool {
	suite := s.CipherSuite == o.CipherSuite
	groupID := bytes.Equal(s.GroupID, o.GroupID)
	epoch := s.Epoch == o.Epoch
	tree := s.Tree.Equals(o.Tree)
	cth := bytes.Equal(s.ConfirmedTranscriptHash, o.ConfirmedTranscriptHash)
	ith := bytes.Equal(s.InterimTranscriptHash, o.InterimTranscriptHash)
	keys := reflect.DeepEqual(s.Keys, o.Keys)

	return suite && groupID && epoch && tree && cth && ith && keys
}

// Isolated getters and setters for public and secret state
//
// Note that the get/set operations here are very shallow. We basically assume
// that the StateSecrets object is temporary, as a carrier for marshaling /
// unmarshaling.
1008 | type StateSecrets struct { 1009 | CipherSuite CipherSuite 1010 | 1011 | // Per-participant non-secret state 1012 | Index LeafIndex 1013 | InitPriv HPKEPrivateKey 1014 | IdentityPriv SignaturePrivateKey 1015 | Scheme SignatureScheme 1016 | PendingProposals []MLSPlaintext `tls:"head=4"` 1017 | 1018 | // Secret state 1019 | PendingUpdates map[ProposalRef]updateSecrets `tls:"head=4"` 1020 | Keys keyScheduleEpoch 1021 | TreePriv TreeKEMPrivateKey 1022 | } 1023 | 1024 | func NewStateFromWelcomeAndSecrets(welcome Welcome, ss StateSecrets) (*State, error) { 1025 | // Import the base data using some information from the secrets 1026 | suite := ss.CipherSuite 1027 | epochSecret := ss.Keys.EpochSecret 1028 | s, _, confirmation, err := NewStateFromWelcome(suite, epochSecret, welcome) 1029 | if err != nil { 1030 | return nil, err 1031 | } 1032 | 1033 | // Import the secrets 1034 | s.SetSecrets(ss) 1035 | 1036 | // Verify the confirmation 1037 | if !s.verifyConfirmation(confirmation) { 1038 | return nil, fmt.Errorf("mls.state: Confirmation failed to verify") 1039 | } 1040 | 1041 | return s, nil 1042 | } 1043 | 1044 | func (s *State) SetSecrets(ss StateSecrets) { 1045 | s.CipherSuite = ss.CipherSuite 1046 | s.Index = ss.Index 1047 | s.IdentityPriv = ss.IdentityPriv 1048 | s.Scheme = ss.Scheme 1049 | s.PendingProposals = ss.PendingProposals 1050 | s.Keys = ss.Keys 1051 | s.TreePriv = ss.TreePriv 1052 | 1053 | s.TreePriv.privateKeyCache = map[NodeIndex]HPKEPrivateKey{} 1054 | 1055 | s.PendingUpdates = map[ProposalRef]updateSecrets{} 1056 | for i, secret := range ss.PendingUpdates { 1057 | s.PendingUpdates[i] = secret 1058 | } 1059 | } 1060 | 1061 | func (s State) GetSecrets() StateSecrets { 1062 | pendingUpdates := map[ProposalRef]updateSecrets{} 1063 | for i, secret := range s.PendingUpdates { 1064 | pendingUpdates[i] = secret 1065 | } 1066 | 1067 | return StateSecrets{ 1068 | CipherSuite: s.CipherSuite, 1069 | Index: s.Index, 1070 | IdentityPriv: s.IdentityPriv, 1071 | 
Scheme: s.Scheme, 1072 | PendingProposals: s.PendingProposals, 1073 | PendingUpdates: pendingUpdates, 1074 | Keys: s.Keys, 1075 | TreePriv: s.TreePriv, 1076 | } 1077 | } 1078 | -------------------------------------------------------------------------------- /state_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cisco/go-tls-syntax" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | var ( 11 | groupID = []byte{0x01, 0x02, 0x03, 0x04} 12 | userID = []byte{0x04, 0x05, 0x06, 0x07} 13 | suite = P256_AES128GCM_SHA256_P256 14 | groupSize = 5 15 | 16 | testMessage = unhex("1112131415") 17 | ) 18 | 19 | type StateTest struct { 20 | initSecrets [][]byte 21 | identityPrivs []SignaturePrivateKey 22 | credentials []Credential 23 | initPrivs []HPKEPrivateKey 24 | keyPackages []KeyPackage 25 | states []State 26 | } 27 | 28 | func setup(t *testing.T) StateTest { 29 | stateTest := StateTest{} 30 | stateTest.keyPackages = make([]KeyPackage, groupSize) 31 | scheme := suite.Scheme() 32 | 33 | for i := 0; i < groupSize; i++ { 34 | // cred gen 35 | secret := randomBytes(32) 36 | sigPriv, err := scheme.Derive(secret) 37 | require.Nil(t, err) 38 | 39 | cred := NewBasicCredential(userID, scheme, sigPriv.PublicKey) 40 | 41 | //kp gen 42 | kp, err := NewKeyPackageWithSecret(suite, secret, cred, sigPriv) 43 | require.Nil(t, err) 44 | 45 | // save all the materials 46 | stateTest.initSecrets = append(stateTest.initSecrets, secret) 47 | stateTest.identityPrivs = append(stateTest.identityPrivs, sigPriv) 48 | stateTest.credentials = append(stateTest.credentials, *cred) 49 | stateTest.keyPackages[i] = *kp 50 | } 51 | return stateTest 52 | } 53 | 54 | func setupGroup(t *testing.T) StateTest { 55 | stateTest := setup(t) 56 | var states []State 57 | // start with the group creator 58 | s0, err := NewEmptyState(groupID, stateTest.initSecrets[0], stateTest.identityPrivs[0], 
stateTest.keyPackages[0]) 59 | require.Nil(t, err) 60 | states = append(states, *s0) 61 | 62 | // add proposals for rest of the participants 63 | for i := 1; i < groupSize; i++ { 64 | add, err := states[0].Add(stateTest.keyPackages[i]) 65 | require.Nil(t, err) 66 | _, err = states[0].Handle(add) 67 | require.Nil(t, err) 68 | } 69 | 70 | // commit the adds 71 | secret := randomBytes(32) 72 | _, welcome, next, err := states[0].Commit(secret) 73 | require.Nil(t, err) 74 | states[0] = *next 75 | // initialize the new joiners from the welcome 76 | for i := 1; i < groupSize; i++ { 77 | s, err := NewJoinedState(stateTest.initSecrets[i], stateTest.identityPrivs[i:i+1], stateTest.keyPackages[i:i+1], *welcome) 78 | require.Nil(t, err) 79 | states = append(states, *s) 80 | } 81 | stateTest.states = states 82 | 83 | // Verify that the states are all equivalent 84 | for _, lhs := range stateTest.states { 85 | for _, rhs := range stateTest.states { 86 | require.True(t, lhs.Equals(rhs)) 87 | } 88 | } 89 | 90 | return stateTest 91 | } 92 | 93 | func TestStateTwoPerson(t *testing.T) { 94 | stateTest := setup(t) 95 | // creator's state 96 | first0, err := NewEmptyState(groupID, stateTest.initSecrets[0], stateTest.identityPrivs[0], stateTest.keyPackages[0]) 97 | require.Nil(t, err) 98 | 99 | // add the second participant 100 | add, err := first0.Add(stateTest.keyPackages[1]) 101 | require.Nil(t, err) 102 | _, err = first0.Handle(add) 103 | require.Nil(t, err) 104 | 105 | // commit adding the second participant 106 | secret := randomBytes(32) 107 | _, welcome, first1, err := first0.Commit(secret) 108 | require.Nil(t, err) 109 | require.Equal(t, first1.NewCredentials, map[LeafIndex]bool{1: true}) 110 | 111 | // Initialize the second participant from the Welcome 112 | second1, err := NewJoinedState(stateTest.initSecrets[1], stateTest.identityPrivs[1:2], stateTest.keyPackages[1:2], *welcome) 113 | require.Nil(t, err) 114 | require.Equal(t, second1.NewCredentials, map[LeafIndex]bool{0: 
true, 1: true}) 115 | 116 | // Verify that the two states are equivalent 117 | require.True(t, first1.Equals(*second1)) 118 | 119 | /// Verify that they can exchange protected messages 120 | ct, err := first1.Protect(testMessage) 121 | require.Nil(t, err) 122 | pt, err := second1.Unprotect(ct) 123 | require.Nil(t, err) 124 | require.Equal(t, pt, testMessage) 125 | } 126 | 127 | const ExtensionTypeGroupTest ExtensionType = 0xFFFF 128 | 129 | type GroupTestExtension struct{} 130 | 131 | func (gte GroupTestExtension) Type() ExtensionType { 132 | return ExtensionTypeGroupTest 133 | } 134 | 135 | func TestStateExtensions(t *testing.T) { 136 | stateTest := setup(t) 137 | groupExtensions := NewExtensionList() 138 | groupExtensions.Add(GroupTestExtension{}) 139 | 140 | clientExtensions := []ExtensionBody{GroupTestExtension{}} 141 | 142 | // Check that NewEmptyStateWithExtensions fails if the KP doesn't support them 143 | kpA := stateTest.keyPackages[0] 144 | _, err := NewEmptyStateWithExtensions(groupID, stateTest.initSecrets[0], stateTest.identityPrivs[0], kpA, groupExtensions) 145 | require.Error(t, err) 146 | 147 | // Check that NewEmptyStateWithExtensions succeeds with exetnsion support 148 | err = kpA.SetExtensions(clientExtensions) 149 | require.Nil(t, err) 150 | err = kpA.Sign(stateTest.identityPrivs[0]) 151 | require.Nil(t, err) 152 | 153 | alice0, err := NewEmptyStateWithExtensions(groupID, stateTest.initSecrets[0], stateTest.identityPrivs[0], kpA, groupExtensions) 154 | require.Nil(t, err) 155 | require.Equal(t, len(alice0.Extensions.Entries), 1) 156 | 157 | // Check that Add fails if the KP doesn't support them 158 | kpB := stateTest.keyPackages[1] 159 | _, err = alice0.Add(kpB) 160 | require.Error(t, err) 161 | 162 | // Check that Add succeeds with extension support 163 | err = kpB.SetExtensions(clientExtensions) 164 | require.Nil(t, err) 165 | err = kpB.Sign(stateTest.identityPrivs[1]) 166 | require.Nil(t, err) 167 | 168 | _, err = alice0.Add(kpB) 169 | 
require.Nil(t, err) 170 | 171 | // TODO(RLB) Test extension verification in NewJoinedState 172 | } 173 | 174 | func TestStateMarshalUnmarshal(t *testing.T) { 175 | // Create Alice and have her add Bob to a group 176 | stateTest := setup(t) 177 | alice0, err := NewEmptyState(groupID, stateTest.initSecrets[0], stateTest.identityPrivs[0], stateTest.keyPackages[0]) 178 | require.Nil(t, err) 179 | 180 | add, err := alice0.Add(stateTest.keyPackages[1]) 181 | require.Nil(t, err) 182 | _, err = alice0.Handle(add) 183 | require.Nil(t, err) 184 | 185 | secret := randomBytes(32) 186 | _, welcome1, alice1, err := alice0.Commit(secret) 187 | require.Nil(t, err) 188 | 189 | // Marshal Alice's secret state 190 | alice1priv, err := syntax.Marshal(alice1.GetSecrets()) 191 | require.Nil(t, err) 192 | 193 | // Initialize Bob generate an Update+Commit 194 | bob1, err := NewJoinedState(stateTest.initSecrets[1], stateTest.identityPrivs[1:2], stateTest.keyPackages[1:2], *welcome1) 195 | require.Nil(t, err) 196 | require.True(t, alice1.Equals(*bob1)) 197 | 198 | newSecret := randomBytes(32) 199 | newKP, err := NewKeyPackageWithSecret(suite, newSecret, &stateTest.keyPackages[1].Credential, stateTest.identityPrivs[1]) 200 | require.Nil(t, err) 201 | update, err := bob1.Update(newSecret, nil, *newKP) 202 | require.Nil(t, err) 203 | _, err = bob1.Handle(update) 204 | require.Nil(t, err) 205 | 206 | commit, _, bob2, err := bob1.Commit(secret) 207 | require.Nil(t, err) 208 | 209 | // Recreate Alice from Welcome and secrets 210 | alice1aPriv := StateSecrets{} 211 | _, err = syntax.Unmarshal(alice1priv, &alice1aPriv) 212 | require.Nil(t, err) 213 | 214 | alice1a, err := NewStateFromWelcomeAndSecrets(*welcome1, alice1aPriv) 215 | require.Nil(t, err) 216 | 217 | require.True(t, alice1a.TreePriv.ConsistentPub(alice1.Tree)) 218 | require.True(t, alice1.TreePriv.ConsistentPub(alice1a.Tree)) 219 | 220 | // Verify that Alice can process Bob's Update+Commit 221 | _, err = alice1a.Handle(update) 222 | 
require.Nil(t, err) 223 | 224 | alice2, err := alice1a.Handle(commit) 225 | require.Nil(t, err) 226 | 227 | // Verify that Alice and Bob can exchange protected messages 228 | /// Verify that they can exchange protected messages 229 | ct, err := alice2.Protect(testMessage) 230 | require.Nil(t, err) 231 | pt, err := bob2.Unprotect(ct) 232 | require.Nil(t, err) 233 | require.Equal(t, pt, testMessage) 234 | } 235 | 236 | func TestStateMulti(t *testing.T) { 237 | stateTest := setup(t) 238 | // start with the group creator 239 | s0, err := NewEmptyState(groupID, stateTest.initSecrets[0], stateTest.identityPrivs[0], stateTest.keyPackages[0]) 240 | require.Nil(t, err) 241 | stateTest.states = append(stateTest.states, *s0) 242 | 243 | // add proposals for rest of the participants 244 | for i := 1; i < groupSize; i++ { 245 | add, err := stateTest.states[0].Add(stateTest.keyPackages[i]) 246 | require.Nil(t, err) 247 | _, err = stateTest.states[0].Handle(add) 248 | require.Nil(t, err) 249 | } 250 | 251 | // commit the adds 252 | secret := randomBytes(32) 253 | _, welcome, next, err := stateTest.states[0].Commit(secret) 254 | require.Nil(t, err) 255 | stateTest.states[0] = *next 256 | // initialize the new joiners from the welcome 257 | for i := 1; i < groupSize; i++ { 258 | s, err := NewJoinedState(stateTest.initSecrets[i], stateTest.identityPrivs[i:i+1], stateTest.keyPackages[i:i+1], *welcome) 259 | require.Nil(t, err) 260 | stateTest.states = append(stateTest.states, *s) 261 | } 262 | 263 | // Verify that the states are all equivalent 264 | for _, lhs := range stateTest.states { 265 | for _, rhs := range stateTest.states { 266 | require.True(t, lhs.Equals(rhs)) 267 | } 268 | } 269 | 270 | // verify that everyone can send and be received 271 | for i, s := range stateTest.states { 272 | ct, _ := s.Protect(testMessage) 273 | for j, o := range stateTest.states { 274 | if i == j { 275 | continue 276 | } 277 | pt, _ := o.Unprotect(ct) 278 | require.Equal(t, pt, testMessage) 279 | 
} 280 | } 281 | } 282 | 283 | func TestStateUpdate(t *testing.T) { 284 | stateTest := setupGroup(t) 285 | for i, state := range stateTest.states { 286 | oldCred := stateTest.keyPackages[i].Credential 287 | newPriv, _ := oldCred.Scheme().Generate() 288 | newCred := NewBasicCredential(oldCred.Identity(), oldCred.Scheme(), newPriv.PublicKey) 289 | 290 | newSecret := randomBytes(32) 291 | newInitKey, err := suite.hpke().Derive(newSecret) 292 | require.Nil(t, err) 293 | 294 | newKP, err := NewKeyPackageWithInitKey(suite, newInitKey.PublicKey, newCred, newPriv) 295 | require.Nil(t, err) 296 | 297 | update, err := state.Update(newSecret, &newPriv, *newKP) 298 | require.Nil(t, err) 299 | state.Handle(update) 300 | 301 | commitSecret := randomBytes(32) 302 | commit, _, next, err := state.Commit(commitSecret) 303 | require.Nil(t, err) 304 | 305 | for j := range stateTest.states { 306 | if j == i { 307 | stateTest.states[j] = *next 308 | } else { 309 | _, err := stateTest.states[j].Handle(update) 310 | require.Nil(t, err) 311 | 312 | newState, err := stateTest.states[j].Handle(commit) 313 | require.Nil(t, err) 314 | stateTest.states[j] = *newState 315 | } 316 | 317 | require.Equal(t, stateTest.states[j].NewCredentials, map[LeafIndex]bool{LeafIndex(i): true}) 318 | require.True(t, stateTest.states[0].Equals(stateTest.states[j])) 319 | } 320 | } 321 | } 322 | 323 | func TestStateRemove(t *testing.T) { 324 | stateTest := setupGroup(t) 325 | for i := groupSize - 2; i > 0; i-- { 326 | remove, err := stateTest.states[i].Remove(LeafIndex(i + 1)) 327 | require.Nil(t, err) 328 | stateTest.states[i].Handle(remove) 329 | secret := randomBytes(32) 330 | commit, _, next, err := stateTest.states[i].Commit(secret) 331 | require.Nil(t, err) 332 | 333 | stateTest.states = stateTest.states[:len(stateTest.states)-1] 334 | 335 | for j := range stateTest.states { 336 | if j == i { 337 | stateTest.states[j] = *next 338 | } else { 339 | _, err := stateTest.states[j].Handle(remove) 340 | 
require.Nil(t, err) 341 | 342 | newState, err := stateTest.states[j].Handle(commit) 343 | require.Nil(t, err) 344 | stateTest.states[j] = *newState 345 | } 346 | 347 | require.True(t, stateTest.states[0].Equals(stateTest.states[j])) 348 | } 349 | } 350 | } 351 | -------------------------------------------------------------------------------- /test-vectors_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // To generate or verify test vectors, run `go test` with these environment 14 | // variables set to point to the directory where the test files reside. The 15 | // names of the individual files of test vectors are specified in the test 16 | // vector cases below. 17 | // 18 | // > MLS_TEST_VECTORS_OUT=... go test -run VectorGen 19 | // > MLS_TEST_VECTORS_IN=... go test -run VectorVer 20 | const ( 21 | testDirWriteEnv = "MLS_TEST_VECTORS_OUT" 22 | testDirReadEnv = "MLS_TEST_VECTORS_IN" 23 | ) 24 | 25 | // For each set of test vectors, this struct defines: 26 | // 27 | // * The file name with which the vectors should be saved / loaded 28 | // * A function to generate test vectors 29 | // * A function to verify test vectors 30 | // 31 | // The generate and verify functions are responsible for reporting their own 32 | // errors through the testing.T object passed to them. The functions themselves 33 | // should be defined in the test files for the relevant modules. 
34 | type TestVectorCase struct { 35 | Filename string 36 | Generate func(t *testing.T) []byte 37 | Verify func(t *testing.T, data []byte) 38 | } 39 | 40 | var testVectorCases = map[string]TestVectorCase{ 41 | "tree_math": { 42 | Filename: "tree_math.bin", 43 | Generate: generateTreeMathVectors, 44 | Verify: verifyTreeMathVectors, 45 | }, 46 | 47 | "crypto": { 48 | Filename: "crypto.bin", 49 | Generate: generateCryptoVectors, 50 | Verify: verifyCryptoVectors, 51 | }, 52 | 53 | "messages": { 54 | Filename: "messages.bin", 55 | Generate: generateMessageVectors, 56 | Verify: verifyMessageVectors, 57 | }, 58 | 59 | "key_schedule": { 60 | Filename: "key_schedule.bin", 61 | Generate: generateKeyScheduleVectors, 62 | Verify: verifyKeyScheduleVectors, 63 | }, 64 | 65 | "ratchet_tree": { 66 | Filename: "tree.bin", 67 | Generate: generateRatchetTreeVectors, 68 | Verify: verifyRatchetTreeVectors, 69 | }, 70 | // TODO continue 71 | } 72 | 73 | func vectorGenerate(c TestVectorCase, testDir string) func(t *testing.T) { 74 | return func(t *testing.T) { 75 | // Generate test vectors 76 | vec := c.Generate(t) 77 | 78 | // Verify that vectors pass 79 | c.Verify(t, vec) 80 | 81 | // Write the vectors to file if required 82 | if len(testDir) != 0 { 83 | file := filepath.Join(testDir, c.Filename) 84 | err := ioutil.WriteFile(file, vec, 0644) 85 | require.Nil(t, err) 86 | } 87 | } 88 | } 89 | 90 | func TestVectorGenerate(t *testing.T) { 91 | testDir := os.Getenv(testDirWriteEnv) 92 | 93 | for label, tvCase := range testVectorCases { 94 | t.Run(label, vectorGenerate(tvCase, testDir)) 95 | } 96 | } 97 | 98 | func vectorVerify(c TestVectorCase, testDir string) func(t *testing.T) { 99 | return func(t *testing.T) { 100 | // Read test vectors 101 | file := filepath.Join(testDir, c.Filename) 102 | fmt.Printf("Test File %v\n", file) 103 | vec, err := ioutil.ReadFile(file) 104 | require.Nil(t, err) 105 | 106 | // Verify test vectors 107 | c.Verify(t, vec) 108 | } 109 | } 110 | 111 | func 
TestVectorVerify(t *testing.T) { 112 | testDir := "" 113 | if testDir = os.Getenv(testDirReadEnv); len(testDir) == 0 { 114 | t.Skip("Test vectors were not provided") 115 | } 116 | 117 | for label, tvCase := range testVectorCases { 118 | t.Run(label, vectorVerify(tvCase, testDir)) 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /tree-math.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | // The below functions provide the index calculus for the tree structures used in MLS. 4 | // They are premised on a "flat" representation of a balanced binary tree. Leaf nodes 5 | // are even-numbered nodes, with the n-th leaf at 2*n. Intermediate nodes are held in 6 | // odd-numbered nodes. For example, a 11-element tree has the following structure: 7 | // 8 | // X 9 | // X 10 | // X X X 11 | // X X X X X 12 | // X X X X X X X X X X X 13 | // 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14 14 | // 15 | // This allows us to compute relationships between tree nodes simply by manipulating 16 | // indices, rather than having to maintain complicated structures in memory, even for 17 | // partial trees. (The storage for a tree can just be a map[int]Node dictionary or 18 | // an array.) 
The basic rule is that the high-order bits of parent and child nodes 19 | // have the following relation: 20 | // 21 | // 01x = <00x, 10x> 22 | 23 | type LeafIndex uint32 24 | type LeafCount uint32 25 | type NodeIndex uint32 26 | type nodeCount uint32 27 | 28 | func toNodeIndex(leaf LeafIndex) NodeIndex { 29 | return NodeIndex(2 * leaf) 30 | } 31 | 32 | func toLeafIndex(node NodeIndex) LeafIndex { 33 | if node&0x01 != 0 { 34 | panic("toLeafIndex on non-leaf index") 35 | } 36 | 37 | return LeafIndex(node) >> 1 38 | } 39 | 40 | // Position of the most significant 1 bit 41 | func log2(x nodeCount) uint { 42 | if x == 0 { 43 | return 0 44 | } 45 | 46 | k := uint(0) 47 | for (x >> k) > 0 { 48 | k += 1 49 | } 50 | return k - 1 51 | } 52 | 53 | // Position of the least significant 0 bit 54 | func level(x NodeIndex) uint { 55 | if x&0x01 == 0 { 56 | return 0 57 | } 58 | 59 | k := uint(0) 60 | for (x>>k)&0x01 == 1 { 61 | k += 1 62 | } 63 | return k 64 | } 65 | 66 | // Number of nodes for a tree of size N 67 | func nodeWidth(n LeafCount) nodeCount { 68 | return nodeCount(2*n - 1) 69 | } 70 | 71 | // Number of leaves for a tree with N nodes 72 | func leafWidth(n nodeCount) LeafCount { 73 | return LeafCount((n + 1) >> 1) 74 | } 75 | 76 | // Index of the root of the tree with N leaves 77 | func root(n LeafCount) NodeIndex { 78 | w := nodeWidth(n) 79 | return NodeIndex((1 << log2(w)) - 1) 80 | } 81 | 82 | // Left child of x 83 | func left(x NodeIndex) NodeIndex { 84 | if level(x) == 0 { 85 | return x 86 | } 87 | 88 | return x ^ (0x01 << (level(x) - 1)) 89 | } 90 | 91 | // Right child of x 92 | func right(x NodeIndex, n LeafCount) NodeIndex { 93 | if level(x) == 0 { 94 | return x 95 | } 96 | 97 | w := NodeIndex(nodeWidth(n)) 98 | r := x ^ (0x03 << (level(x) - 1)) 99 | for r >= w { 100 | r = left(r) 101 | } 102 | return r 103 | } 104 | 105 | // Immediate parent of x; may not exist in tree 106 | func parent_step(x NodeIndex) NodeIndex { 107 | // xy01 -> x011 108 | k := level(x) 109 
| one := uint(1) 110 | return NodeIndex((uint(x) | (one << k)) & ^(one << (k + 1))) 111 | } 112 | 113 | // Parent of x 114 | func parent(x NodeIndex, n LeafCount) NodeIndex { 115 | // root's parent is itself 116 | if x == root(n) { 117 | return x 118 | } 119 | 120 | w := NodeIndex(nodeWidth(n)) 121 | p := parent_step(x) 122 | for p >= w { 123 | p = parent_step(p) 124 | } 125 | return p 126 | } 127 | 128 | // Sibling of x 129 | func sibling(x NodeIndex, n LeafCount) NodeIndex { 130 | p := parent(x, n) 131 | if x < p { 132 | return right(p, n) 133 | } else if x > p { 134 | return left(p) 135 | } 136 | 137 | // root's sibling is itself 138 | return p 139 | } 140 | 141 | // Direct path for x 142 | // Ordered from leaf to root, excluding leaf, including root 143 | func dirpath(x NodeIndex, n LeafCount) []NodeIndex { 144 | d := []NodeIndex{} 145 | p := parent(x, n) 146 | r := root(n) 147 | for p != r { 148 | d = append(d, p) 149 | p = parent(p, n) 150 | } 151 | 152 | if x != r { 153 | d = append(d, p) 154 | } 155 | return d 156 | } 157 | 158 | // Copath for x 159 | // Ordered from leaf to root 160 | func copath(x NodeIndex, n LeafCount) []NodeIndex { 161 | d := dirpath(x, n) 162 | if len(d) == 0 { 163 | return []NodeIndex{} 164 | } 165 | 166 | d = append([]NodeIndex{x}, d[:len(d)-1]...) 
167 | 168 | r := root(n) 169 | c := make([]NodeIndex, len(d)) 170 | for i, x := range d { 171 | // Don't include the root 172 | if x == r { 173 | continue 174 | } 175 | 176 | c[i] = sibling(x, n) 177 | } 178 | 179 | return c 180 | } 181 | 182 | func inPath(x, y NodeIndex) bool { 183 | lx, ly := level(x), level(y) 184 | return lx <= ly && x>>(ly+1) == y>>(ly+1) 185 | } 186 | 187 | func fullAncestor(l, r NodeIndex) NodeIndex { 188 | ll, lr := level(l)+1, level(r)+1 189 | if ll <= lr && l>>lr == r>>lr { 190 | return r 191 | } 192 | if lr <= ll && l>>ll == r>>ll { 193 | return l 194 | } 195 | 196 | k := uint(0) 197 | ln, rn := l, r 198 | for ln != rn { 199 | ln, rn = ln>>1, rn>>1 200 | k += 1 201 | } 202 | 203 | return (ln << k) + (1 << (k - 1)) - 1 204 | } 205 | 206 | // Common ancestor of two leaves 207 | func ancestor(l, r LeafIndex) NodeIndex { 208 | ln, rn := toNodeIndex(l), toNodeIndex(r) 209 | 210 | k := uint(0) 211 | for ln != rn { 212 | ln, rn = ln>>1, rn>>1 213 | k += 1 214 | } 215 | 216 | return (ln << k) + (1 << (k - 1)) - 1 217 | } 218 | -------------------------------------------------------------------------------- /tree-math_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/cisco/go-tls-syntax" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // Precomputed answers for the tree on eleven elements: 12 | // 13 | // X 14 | // X 15 | // X X X 16 | // X X X X X 17 | // X X X X X X X X X X X 18 | // 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14 19 | var ( 20 | aRoot = []NodeIndex{0x00, 0x01, 0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f} 21 | 22 | aN = LeafCount(0x0b) 23 | index = []NodeIndex{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14} 24 | aLog2 = []NodeIndex{0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 
0x03, 0x03, 0x04, 0x04, 0x04, 0x04, 0x04} 25 | aLevel = []NodeIndex{0x00, 0x01, 0x00, 0x02, 0x00, 0x01, 0x00, 0x03, 0x00, 0x01, 0x00, 0x02, 0x00, 0x01, 0x00, 0x04, 0x00, 0x01, 0x00, 0x02, 0x00} 26 | aLeft = []NodeIndex{0x00, 0x00, 0x02, 0x01, 0x04, 0x04, 0x06, 0x03, 0x08, 0x08, 0x0a, 0x09, 0x0c, 0x0c, 0x0e, 0x07, 0x10, 0x10, 0x12, 0x11, 0x14} 27 | aRight = []NodeIndex{0x00, 0x02, 0x02, 0x05, 0x04, 0x06, 0x06, 0x0b, 0x08, 0x0a, 0x0a, 0x0d, 0x0c, 0x0e, 0x0e, 0x13, 0x10, 0x12, 0x12, 0x14, 0x14} 28 | aParent = []NodeIndex{0x01, 0x03, 0x01, 0x07, 0x05, 0x03, 0x05, 0x0f, 0x09, 0x0b, 0x09, 0x07, 0x0d, 0x0b, 0x0d, 0x0f, 0x11, 0x13, 0x11, 0x0f, 0x13} 29 | aSibling = []NodeIndex{0x02, 0x05, 0x00, 0x0b, 0x06, 0x01, 0x04, 0x13, 0x0a, 0x0d, 0x08, 0x03, 0x0e, 0x09, 0x0c, 0x0f, 0x12, 0x14, 0x10, 0x07, 0x11} 30 | 31 | aDirpath = [][]NodeIndex{ 32 | {0x01, 0x03, 0x07, 0x0f}, 33 | {0x03, 0x07, 0x0f}, 34 | {0x01, 0x03, 0x07, 0x0f}, 35 | {0x07, 0x0f}, 36 | {0x05, 0x03, 0x07, 0x0f}, 37 | {0x03, 0x07, 0x0f}, 38 | {0x05, 0x03, 0x07, 0x0f}, 39 | {0x0f}, 40 | {0x09, 0x0b, 0x07, 0x0f}, 41 | {0x0b, 0x07, 0x0f}, 42 | {0x09, 0x0b, 0x07, 0x0f}, 43 | {0x07, 0x0f}, 44 | {0x0d, 0x0b, 0x07, 0x0f}, 45 | {0x0b, 0x07, 0x0f}, 46 | {0x0d, 0x0b, 0x07, 0x0f}, 47 | {}, 48 | {0x11, 0x13, 0x0f}, 49 | {0x13, 0x0f}, 50 | {0x11, 0x13, 0x0f}, 51 | {0x0f}, 52 | {0x13, 0x0f}, 53 | } 54 | aCopath = [][]NodeIndex{ 55 | {0x02, 0x05, 0x0b, 0x13}, 56 | {0x05, 0x0b, 0x13}, 57 | {0x00, 0x05, 0x0b, 0x13}, 58 | {0x0b, 0x13}, 59 | {0x06, 0x01, 0x0b, 0x13}, 60 | {0x01, 0x0b, 0x13}, 61 | {0x04, 0x01, 0x0b, 0x13}, 62 | {0x13}, 63 | {0x0a, 0x0d, 0x03, 0x13}, 64 | {0x0d, 0x03, 0x13}, 65 | {0x08, 0x0d, 0x03, 0x13}, 66 | {0x03, 0x13}, 67 | {0x0e, 0x09, 0x03, 0x13}, 68 | {0x09, 0x03, 0x13}, 69 | {0x0c, 0x09, 0x03, 0x13}, 70 | {}, 71 | {0x12, 0x14, 0x07}, 72 | {0x14, 0x07}, 73 | {0x10, 0x14, 0x07}, 74 | {0x07}, 75 | {0x11, 0x07}, 76 | } 77 | 78 | aInPath = [][]int{ 79 | // 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14 80 | /**/ {1, 
1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 0 81 | /**/ {0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 1 82 | /**/ {0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 2 83 | /**/ {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 3 84 | /**/ {0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 4 85 | /**/ {0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 5 86 | /**/ {0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 6 87 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 7 88 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 8 89 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // 9 90 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // a 91 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // b 92 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0}, // c 93 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0}, // d 94 | /**/ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0}, // e 95 | /**/ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, // f 96 | /**/ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0}, // 10 97 | /**/ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0}, // 11 98 | /**/ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0}, // 12 99 | /**/ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0}, // 13 100 | /**/ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1}, // 14 101 | } 102 | 103 | aFullAncestor = [][]NodeIndex{ 104 | // 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14 105 | {0x00, 0x01, 0x01, 0x03, 0x03, 0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 0 106 | {0x01, 0x01, 0x01, 0x03, 0x03, 0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x07, 
0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 1 107 | {0x01, 0x01, 0x02, 0x03, 0x03, 0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 2 108 | {0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 3 109 | {0x03, 0x03, 0x03, 0x03, 0x04, 0x05, 0x05, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 4 110 | {0x03, 0x03, 0x03, 0x03, 0x05, 0x05, 0x05, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 5 111 | {0x03, 0x03, 0x03, 0x03, 0x05, 0x05, 0x06, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 6 112 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 7 113 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x08, 0x09, 0x09, 0x0b, 0x0b, 0x0b, 0x0b, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 8 114 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x09, 0x09, 0x09, 0x0b, 0x0b, 0x0b, 0x0b, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // 9 115 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x09, 0x09, 0x0a, 0x0b, 0x0b, 0x0b, 0x0b, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // a 116 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // b 117 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0b, 0x0b, 0x0b, 0x0b, 0x0c, 0x0d, 0x0d, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // c 118 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0b, 0x0b, 0x0b, 0x0b, 0x0d, 0x0d, 0x0d, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // d 119 | {0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x0b, 0x0b, 0x0b, 0x0b, 0x0d, 0x0d, 0x0e, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f}, // e 120 | {0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 
0x0f}, // f 121 | {0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x10, 0x11, 0x11, 0x13, 0x13}, // 10 122 | {0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x11, 0x11, 0x11, 0x13, 0x13}, // 11 123 | {0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x11, 0x11, 0x12, 0x13, 0x13}, // 12 124 | {0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x13, 0x13, 0x13, 0x13, 0x13}, // 13 125 | {0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x13, 0x13, 0x13, 0x13, 0x14}, // 14 126 | } 127 | 128 | aAncestor = [][]NodeIndex{ 129 | {0x01, 0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f}, 130 | {0x03, 0x03, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f}, 131 | {0x05, 0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f}, 132 | {0x07, 0x07, 0x07, 0x07, 0x0f, 0x0f, 0x0f}, 133 | {0x09, 0x0b, 0x0b, 0x0f, 0x0f, 0x0f}, 134 | {0x0b, 0x0b, 0x0f, 0x0f, 0x0f}, 135 | {0x0d, 0x0f, 0x0f, 0x0f}, 136 | {0x0f, 0x0f, 0x0f}, 137 | {0x11, 0x13}, 138 | {0x13}, 139 | } 140 | ) 141 | 142 | func TestSizeProperties(t *testing.T) { 143 | for n := LeafCount(1); n < aN; n += 1 { 144 | if root(n) != aRoot[n-1] { 145 | t.Fatalf("Root mismatch: %v != %v", root(n), aRoot[n-1]) 146 | } 147 | } 148 | } 149 | 150 | func TestNodeRelations(t *testing.T) { 151 | run := func(label string, f func(x NodeIndex) NodeIndex, a []NodeIndex) { 152 | for i, x := range index { 153 | if f(x) != a[i] { 154 | t.Fatalf("Relation test failure: %s @ 0x%02x: %v != %v", label, x, f(x), a[i]) 155 | } 156 | } 157 | } 158 | 159 | run("log2", func(x NodeIndex) NodeIndex { return NodeIndex(log2(nodeCount(x))) }, aLog2) 160 | run("level", func(x NodeIndex) NodeIndex { return NodeIndex(level(x)) }, aLevel) 161 | run("left", left, aLeft) 162 | run("right", func(x NodeIndex) NodeIndex { return right(x, aN) }, aRight) 
163 | run("parent", func(x NodeIndex) NodeIndex { return parent(x, aN) }, aParent) 164 | run("sibling", func(x NodeIndex) NodeIndex { return sibling(x, aN) }, aSibling) 165 | } 166 | 167 | func TestPaths(t *testing.T) { 168 | run := func(label string, f func(x NodeIndex, n LeafCount) []NodeIndex, a [][]NodeIndex) { 169 | for i, x := range index { 170 | if !reflect.DeepEqual(f(x, aN), a[i]) { 171 | t.Fatalf("Path test failure: %s @ 0x%02x: %v != %v", label, x, f(x, aN), a[i]) 172 | } 173 | } 174 | } 175 | 176 | run("dirpath", dirpath, aDirpath) 177 | run("copath", copath, aCopath) 178 | } 179 | 180 | func TestInPath(t *testing.T) { 181 | w := NodeIndex(nodeWidth(aN)) 182 | for l := NodeIndex(0); l < w; l += 1 { 183 | for r := NodeIndex(0); r < w; r += 1 { 184 | answer := aInPath[l][r] == 1 185 | lr := inPath(l, r) 186 | 187 | if lr != answer { 188 | t.Errorf("Incorrect inPath determination: %d %d => %v != %v [%08b %08b]", l, r, lr, answer, l, r) 189 | } 190 | } 191 | } 192 | } 193 | 194 | func TestFullAncestor(t *testing.T) { 195 | w := NodeIndex(nodeWidth(aN)) 196 | for l := NodeIndex(0); l < w; l += 1 { 197 | for r := NodeIndex(0); r < w; r += 1 { 198 | answer := aFullAncestor[l][r] 199 | lr := fullAncestor(l, r) 200 | rl := fullAncestor(r, l) 201 | 202 | if lr != answer { 203 | t.Errorf("Incorrect ancestor: %d %d => %d != %d", l, r, lr, answer) 204 | } 205 | 206 | if rl != lr { 207 | t.Errorf("Asymmetric ancestor: %d %d => %d != %d", l, r, rl, lr) 208 | } 209 | } 210 | } 211 | } 212 | 213 | func TestAncestor(t *testing.T) { 214 | for l := LeafIndex(0); l < LeafIndex(aN-1); l += 1 { 215 | for r := l + 1; r < LeafIndex(aN); r += 1 { 216 | answer := aAncestor[l][r-l-1] 217 | lr := ancestor(l, r) 218 | rl := ancestor(r, l) 219 | 220 | if lr != answer { 221 | t.Fatalf("Incorrect ancestor: %d %d => %d != %d", l, r, lr, answer) 222 | } 223 | 224 | if rl != answer { 225 | t.Fatalf("Asymmetric ancestor: %d %d => %d != %d", l, r, rl, lr) 226 | } 227 | } 228 | } 229 | } 230 
| 231 | /// 232 | /// Test Vectors 233 | /// 234 | 235 | type NodeIndexSlice struct { 236 | Data []NodeIndex `tls:"head=4"` 237 | } 238 | 239 | type TreeMathTestVectors struct { 240 | NumLeaves LeafCount 241 | Root []NodeIndex `tls:"head=4"` 242 | Left []NodeIndex `tls:"head=4"` 243 | Right []NodeIndex `tls:"head=4"` 244 | Parent []NodeIndex `tls:"head=4"` 245 | Sibling []NodeIndex `tls:"head=4"` 246 | DirPath []NodeIndexSlice `tls:"head=4"` 247 | CoPath []NodeIndexSlice `tls:"head=4"` 248 | Ancestor []NodeIndexSlice `tls:"head=4"` 249 | } 250 | 251 | func generateTreeMathVectors(t *testing.T) []byte { 252 | numLeaves := LeafCount(255) 253 | numNodes := nodeWidth(numLeaves) 254 | tv := TreeMathTestVectors{ 255 | NumLeaves: numLeaves, 256 | Root: make([]NodeIndex, numLeaves), 257 | Left: make([]NodeIndex, numNodes), 258 | Right: make([]NodeIndex, numNodes), 259 | Parent: make([]NodeIndex, numNodes), 260 | Sibling: make([]NodeIndex, numNodes), 261 | DirPath: make([]NodeIndexSlice, numNodes), 262 | CoPath: make([]NodeIndexSlice, numNodes), 263 | Ancestor: make([]NodeIndexSlice, numNodes), 264 | } 265 | 266 | for i := range tv.Root { 267 | tv.Root[i] = root(LeafCount(i + 1)) 268 | } 269 | 270 | for i := range tv.Left { 271 | tv.Left[i] = left(NodeIndex(i)) 272 | tv.Right[i] = right(NodeIndex(i), numLeaves) 273 | tv.Parent[i] = parent(NodeIndex(i), numLeaves) 274 | tv.Sibling[i] = sibling(NodeIndex(i), numLeaves) 275 | tv.DirPath[i] = NodeIndexSlice{Data: dirpath(NodeIndex(i), numLeaves)} 276 | tv.CoPath[i] = NodeIndexSlice{Data: copath(NodeIndex(i), numLeaves)} 277 | } 278 | 279 | // ancestor 280 | for l := LeafIndex(0); l < LeafIndex(numLeaves-1); l += 1 { 281 | a := []NodeIndex{} 282 | for r := l + 1; r < LeafIndex(numLeaves); r += 1 { 283 | lr := ancestor(l, r) 284 | a = append(a, lr) 285 | } 286 | tv.Ancestor[l].Data = a 287 | } 288 | 289 | vec, err := syntax.Marshal(tv) 290 | require.Nil(t, err) 291 | return vec 292 | } 293 | 294 | func verifyTreeMathVectors(t 
*testing.T, data []byte) { 295 | var tv TreeMathTestVectors 296 | _, err := syntax.Unmarshal(data, &tv) 297 | require.Nil(t, err) 298 | 299 | tvLen := int(nodeWidth(tv.NumLeaves)) 300 | if len(tv.Root) != int(tv.NumLeaves) || len(tv.Left) != tvLen || 301 | len(tv.Right) != tvLen || len(tv.Parent) != tvLen || len(tv.Sibling) != tvLen || 302 | len(tv.DirPath) != tvLen || len(tv.CoPath) != tvLen { 303 | t.Fatalf("Malformed tree math test vectors: Incorrect vector sizes") 304 | } 305 | 306 | for i := range tv.Root { 307 | require.Equal(t, tv.Root[i], root(LeafCount(i+1))) 308 | } 309 | 310 | for i := range tv.Left { 311 | require.Equal(t, tv.Left[i], left(NodeIndex(i))) 312 | require.Equal(t, tv.Right[i], right(NodeIndex(i), tv.NumLeaves)) 313 | require.Equal(t, tv.Parent[i], parent(NodeIndex(i), tv.NumLeaves)) 314 | require.Equal(t, tv.Sibling[i], sibling(NodeIndex(i), tv.NumLeaves)) 315 | require.Equal(t, tv.DirPath[i].Data, dirpath(NodeIndex(i), tv.NumLeaves)) 316 | require.Equal(t, tv.CoPath[i].Data, copath(NodeIndex(i), tv.NumLeaves)) 317 | } 318 | 319 | // ancestor 320 | for l := LeafIndex(0); l < LeafIndex(tv.NumLeaves-1); l += 1 { 321 | a := []NodeIndex{} 322 | for r := l + 1; r < LeafIndex(tv.NumLeaves); r += 1 { 323 | lr := ancestor(l, r) 324 | a = append(a, lr) 325 | } 326 | require.Equal(t, tv.Ancestor[l].Data, a) 327 | } 328 | 329 | } 330 | 331 | func TestTreeMathErrorCases(t *testing.T) { 332 | require.Panics(t, func() { toLeafIndex(0x03) }) 333 | } 334 | -------------------------------------------------------------------------------- /treekem.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "reflect" 7 | 8 | syntax "github.com/cisco/go-tls-syntax" 9 | ) 10 | 11 | type NodeType uint8 12 | 13 | const ( 14 | NodeTypeLeaf NodeType = 0x00 15 | NodeTypeParent NodeType = 0x01 16 | ) 17 | 18 | /// 19 | /// ParentNode 20 | /// 21 | 22 | type ParentNode struct { 23 | 
PublicKey HPKEPublicKey 24 | UnmergedLeaves []LeafIndex `tls:"head=4"` 25 | ParentHash []byte `tls:"head=1"` 26 | } 27 | 28 | func (n *ParentNode) Equals(other *ParentNode) bool { 29 | pubKey := reflect.DeepEqual(n.PublicKey, other.PublicKey) 30 | unmerged := reflect.DeepEqual(n.UnmergedLeaves, other.UnmergedLeaves) 31 | parentHash := reflect.DeepEqual(n.ParentHash, other.ParentHash) 32 | 33 | return pubKey && unmerged && parentHash 34 | } 35 | 36 | func (n ParentNode) Clone() ParentNode { 37 | next := ParentNode{ 38 | PublicKey: n.PublicKey, 39 | UnmergedLeaves: make([]LeafIndex, len(n.UnmergedLeaves)), 40 | ParentHash: dup(n.ParentHash), 41 | } 42 | 43 | for i, n := range n.UnmergedLeaves { 44 | next.UnmergedLeaves[i] = n 45 | } 46 | 47 | return next 48 | } 49 | 50 | func (n *ParentNode) AddUnmerged(l LeafIndex) { 51 | n.UnmergedLeaves = append(n.UnmergedLeaves, l) 52 | } 53 | 54 | /// 55 | /// Node 56 | /// 57 | type Node struct { 58 | Leaf *KeyPackage 59 | Parent *ParentNode 60 | } 61 | 62 | func (n *Node) Equals(other *Node) bool { 63 | if n == nil || other == nil { 64 | return n == other 65 | } 66 | 67 | switch n.Type() { 68 | case NodeTypeLeaf: 69 | return n.Leaf.Equals(*other.Leaf) 70 | case NodeTypeParent: 71 | return n.Parent.Equals(other.Parent) 72 | default: 73 | return false 74 | } 75 | } 76 | 77 | func (n *Node) Clone() *Node { 78 | if n == nil { 79 | return nil 80 | } 81 | 82 | next := &Node{} 83 | switch n.Type() { 84 | case NodeTypeLeaf: 85 | clone := n.Leaf.Clone() 86 | next.Leaf = &clone 87 | case NodeTypeParent: 88 | clone := n.Parent.Clone() 89 | next.Parent = &clone 90 | default: 91 | panic("Malformed node") 92 | } 93 | 94 | return next 95 | } 96 | 97 | func (n Node) Type() NodeType { 98 | switch { 99 | case n.Leaf != nil: 100 | return NodeTypeLeaf 101 | case n.Parent != nil: 102 | return NodeTypeParent 103 | default: 104 | panic("Malformed node") 105 | } 106 | } 107 | 108 | func (n Node) PublicKey() HPKEPublicKey { 109 | switch n.Type() { 110 
| case NodeTypeLeaf: 111 | return n.Leaf.InitKey 112 | case NodeTypeParent: 113 | return n.Parent.PublicKey 114 | default: 115 | panic("Malformed node") 116 | } 117 | } 118 | 119 | func (n Node) MarshalTLS() ([]byte, error) { 120 | s := syntax.NewWriteStream() 121 | nodeType := n.Type() 122 | err := s.Write(nodeType) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | switch nodeType { 128 | case NodeTypeLeaf: 129 | err = s.Write(n.Leaf) 130 | case NodeTypeParent: 131 | err = s.Write(n.Parent) 132 | default: 133 | err = fmt.Errorf("mls.node: Invalid node type") 134 | } 135 | if err != nil { 136 | return nil, err 137 | } 138 | 139 | return s.Data(), nil 140 | } 141 | 142 | func (n *Node) UnmarshalTLS(data []byte) (int, error) { 143 | s := syntax.NewReadStream(data) 144 | var nodeType NodeType 145 | _, err := s.Read(&nodeType) 146 | if err != nil { 147 | return 0, err 148 | } 149 | 150 | switch nodeType { 151 | case NodeTypeLeaf: 152 | n.Leaf = new(KeyPackage) 153 | _, err = s.Read(n.Leaf) 154 | case NodeTypeParent: 155 | n.Parent = new(ParentNode) 156 | _, err = s.Read(n.Parent) 157 | default: 158 | err = fmt.Errorf("mls.node: Invalid node type") 159 | } 160 | if err != nil { 161 | return 0, err 162 | } 163 | 164 | return s.Position(), nil 165 | } 166 | 167 | /// 168 | /// OptionalNode 169 | /// 170 | type OptionalNode struct { 171 | Node *Node `tls:"optional"` 172 | Hash []byte `tls:"omit"` 173 | } 174 | 175 | func newLeafNode(keyPkg KeyPackage) OptionalNode { 176 | return OptionalNode{Node: &Node{Leaf: &keyPkg}} 177 | } 178 | 179 | func newParentNode(pub HPKEPublicKey) OptionalNode { 180 | parentNode := &ParentNode{ 181 | PublicKey: pub, 182 | UnmergedLeaves: []LeafIndex{}, 183 | ParentHash: []byte{}, 184 | } 185 | return OptionalNode{Node: &Node{Parent: parentNode}} 186 | } 187 | 188 | func (n OptionalNode) Clone() OptionalNode { 189 | return OptionalNode{ 190 | Node: n.Node.Clone(), 191 | Hash: dup(n.Hash), 192 | } 193 | } 194 | 195 | func (n 
OptionalNode) Blank() bool { 196 | return n.Node == nil 197 | } 198 | 199 | func (n *OptionalNode) SetToBlank() { 200 | n.Node = nil 201 | } 202 | 203 | func (n *OptionalNode) setNodeHash(suite CipherSuite, input interface{}) error { 204 | data, err := syntax.Marshal(input) 205 | if err != nil { 206 | return err 207 | } 208 | 209 | n.Hash = suite.Digest(data) 210 | return nil 211 | } 212 | 213 | type LeafNodeHashInput struct { 214 | LeafIndex LeafIndex 215 | KeyPackage *KeyPackage `tls:"optional"` 216 | } 217 | 218 | func (n *OptionalNode) SetLeafNodeHash(suite CipherSuite, index LeafIndex) error { 219 | input := LeafNodeHashInput{ 220 | LeafIndex: index, 221 | KeyPackage: nil, 222 | } 223 | 224 | if !n.Blank() { 225 | if n.Node.Type() != NodeTypeLeaf { 226 | return fmt.Errorf("mls.rtn: SetLeafNodeHash on non-leaf node") 227 | } 228 | 229 | input.KeyPackage = n.Node.Leaf 230 | } 231 | 232 | return n.setNodeHash(suite, input) 233 | } 234 | 235 | type ParentNodeHashInput struct { 236 | NodeIndex NodeIndex 237 | ParentNode *ParentNode `tls:"optional"` 238 | LeftHash []byte `tls:"head=1"` 239 | RightHash []byte `tls:"head=1"` 240 | } 241 | 242 | func (n *OptionalNode) SetParentNodeHash(suite CipherSuite, index NodeIndex, left, right []byte) error { 243 | input := ParentNodeHashInput{ 244 | NodeIndex: index, 245 | ParentNode: nil, 246 | LeftHash: left, 247 | RightHash: right, 248 | } 249 | 250 | if !n.Blank() { 251 | if n.Node.Type() != NodeTypeParent { 252 | return fmt.Errorf("mls.rtn: SetParentNodeHash on non-leaf node") 253 | } 254 | 255 | input.ParentNode = n.Node.Parent 256 | } 257 | 258 | return n.setNodeHash(suite, input) 259 | } 260 | 261 | /// 262 | /// DirectPath 263 | /// 264 | type DirectPathNode struct { 265 | PublicKey HPKEPublicKey 266 | EncryptedPathSecrets []HPKECiphertext `tls:"head=4"` 267 | } 268 | 269 | type DirectPath struct { 270 | LeafKeyPackage KeyPackage 271 | Steps []DirectPathNode `tls:"head=4"` 272 | } 273 | 274 | // This produces a list of 
parent hashes that are off by one with respect to the 275 | // steps in the path. The path hash at position i goes with the public key at 276 | // position i-1, and the path hash at position 0 goes in the leaf. 277 | func (path DirectPath) ParentHashes(suite CipherSuite) ([][]byte, error) { 278 | ph := make([][]byte, len(path.Steps)) 279 | 280 | var err error 281 | var lastHash []byte 282 | for i := len(path.Steps) - 1; i >= 0; i-- { 283 | parentNode := ParentNode{ 284 | PublicKey: path.Steps[i].PublicKey, 285 | ParentHash: lastHash, 286 | } 287 | 288 | lastHash, err = syntax.Marshal(parentNode) 289 | if err != nil { 290 | return nil, err 291 | } 292 | 293 | ph[i] = suite.Digest(lastHash) 294 | } 295 | 296 | return ph, nil 297 | } 298 | 299 | func (path DirectPath) ParentHashValid(suite CipherSuite) error { 300 | leafParentHash := []byte{} 301 | if len(path.Steps) > 0 { 302 | ph, err := path.ParentHashes(suite) 303 | if err != nil { 304 | return err 305 | } 306 | 307 | leafParentHash = ph[0] 308 | } 309 | 310 | phe := ParentHashExtension{} 311 | found, err := path.LeafKeyPackage.Extensions.Find(&phe) 312 | switch { 313 | case err != nil: 314 | return err 315 | 316 | case !found: 317 | return fmt.Errorf("No ParentHash extension") 318 | 319 | case !bytes.Equal(leafParentHash, phe.ParentHash): 320 | return fmt.Errorf("Incorrect parent hash") 321 | } 322 | 323 | return nil 324 | } 325 | 326 | type KeyPackageOpts struct { 327 | // TODO New credential 328 | // TODO Extensions 329 | } 330 | 331 | func (path *DirectPath) Sign(suite CipherSuite, initPub HPKEPublicKey, sigPriv SignaturePrivateKey, opts *KeyPackageOpts) error { 332 | // Compute parent hashes down the tree from the root 333 | leafParentHash := []byte{} 334 | if len(path.Steps) > 0 { 335 | ph, err := path.ParentHashes(suite) 336 | if err != nil { 337 | return err 338 | } 339 | 340 | leafParentHash = ph[0] 341 | } 342 | 343 | // Re-sign the leaf key package 344 | // TODO(RLB) Apply any options from opts 345 | // 
TODO(RLB) Move resigning logic into KeyPackage 346 | phe := ParentHashExtension{leafParentHash} 347 | err := path.LeafKeyPackage.SetExtensions([]ExtensionBody{phe}) 348 | if err != nil { 349 | return err 350 | } 351 | 352 | path.LeafKeyPackage.InitKey = initPub 353 | 354 | return path.LeafKeyPackage.Sign(sigPriv) 355 | } 356 | 357 | //////////////////////////////////////////////////////////// 358 | //////////////////////////////////////////////////////////// 359 | //////////////////////////////////////////////////////////// 360 | 361 | type TreeKEMPrivateKey struct { 362 | Suite CipherSuite 363 | Index LeafIndex 364 | UpdateSecret []byte `tls:"head=1"` 365 | PathSecrets map[NodeIndex]Bytes1 `tls:"head=4"` 366 | privateKeyCache map[NodeIndex]HPKEPrivateKey `tls:"omit"` 367 | } 368 | 369 | func NewTreeKEMPrivateKeyForJoiner(suite CipherSuite, index LeafIndex, size LeafCount, leafSecret []byte, intersect NodeIndex, pathSecret []byte) *TreeKEMPrivateKey { 370 | priv := &TreeKEMPrivateKey{ 371 | Suite: suite, 372 | Index: index, 373 | PathSecrets: map[NodeIndex]Bytes1{}, 374 | privateKeyCache: map[NodeIndex]HPKEPrivateKey{}, 375 | } 376 | 377 | priv.PathSecrets[toNodeIndex(index)] = dup(leafSecret) 378 | if pathSecret != nil { 379 | priv.setPathSecrets(intersect, size, pathSecret) 380 | } 381 | return priv 382 | } 383 | 384 | func NewTreeKEMPrivateKey(suite CipherSuite, size LeafCount, index LeafIndex, leafSecret []byte) *TreeKEMPrivateKey { 385 | priv := &TreeKEMPrivateKey{ 386 | Suite: suite, 387 | Index: index, 388 | PathSecrets: map[NodeIndex]Bytes1{}, 389 | privateKeyCache: map[NodeIndex]HPKEPrivateKey{}, 390 | } 391 | 392 | priv.setPathSecrets(toNodeIndex(index), size, leafSecret) 393 | return priv 394 | } 395 | 396 | func (priv TreeKEMPrivateKey) pathStep(pathSecret []byte) []byte { 397 | return priv.Suite.hkdfExpandLabel(pathSecret, "path", []byte{}, priv.Suite.Constants().SecretSize) 398 | } 399 | 400 | func (priv *TreeKEMPrivateKey) setPathSecrets(start 
NodeIndex, size LeafCount, secret []byte) { 401 | r := root(size) 402 | pathSecret := secret 403 | for n := start; n != r; n = parent(n, size) { 404 | priv.PathSecrets[n] = dup(pathSecret) 405 | delete(priv.privateKeyCache, n) 406 | pathSecret = priv.pathStep(pathSecret) 407 | } 408 | 409 | priv.PathSecrets[r] = dup(pathSecret) 410 | delete(priv.privateKeyCache, r) 411 | 412 | priv.UpdateSecret = priv.pathStep(pathSecret) 413 | } 414 | 415 | func (priv TreeKEMPrivateKey) privateKey(n NodeIndex) (HPKEPrivateKey, error) { 416 | if key, ok := priv.privateKeyCache[n]; ok { 417 | return key, nil 418 | } 419 | 420 | secret, ok := priv.PathSecrets[n] 421 | if !ok || secret == nil { 422 | return HPKEPrivateKey{}, fmt.Errorf("Private key not found") 423 | } 424 | 425 | key, err := priv.Suite.hpke().Derive(secret) 426 | if err != nil { 427 | return HPKEPrivateKey{}, err 428 | } 429 | 430 | priv.privateKeyCache[n] = key 431 | return key, nil 432 | } 433 | 434 | func (priv TreeKEMPrivateKey) SharedPathSecret(to LeafIndex) (NodeIndex, []byte, bool) { 435 | n := ancestor(priv.Index, to) 436 | secret, ok := priv.PathSecrets[n] 437 | return n, secret, ok 438 | } 439 | 440 | func (priv *TreeKEMPrivateKey) SetLeafSecret(secret []byte) { 441 | // TODO(RLB) Check for consistency? 442 | ni := toNodeIndex(priv.Index) 443 | priv.PathSecrets[ni] = dup(secret) 444 | delete(priv.privateKeyCache, ni) 445 | } 446 | 447 | // TODO(RLB) Onece the spec is updated to have EncryptedPathSecrets as a map, 448 | // change the TreeKEMPublicKey argument to just be a size. 
449 | func (priv *TreeKEMPrivateKey) Decap(from LeafIndex, pub TreeKEMPublicKey, context []byte, path DirectPath) error { 450 | // Decrypt a path secret 451 | ni := toNodeIndex(priv.Index) 452 | dp := dirpath(toNodeIndex(from), pub.Size()) 453 | if len(dp) != len(path.Steps) { 454 | return fmt.Errorf("Malformed DirectPath %d %d", len(dp), len(path.Steps)) 455 | } 456 | 457 | dpIndex := -1 458 | last := toNodeIndex(from) 459 | var overlap, copath NodeIndex 460 | for i, n := range dp { 461 | if inPath(ni, n) { 462 | dpIndex = i 463 | overlap = n 464 | copath = sibling(last, pub.Size()) 465 | break 466 | } 467 | 468 | last = n 469 | } 470 | 471 | if dpIndex < 0 { 472 | return fmt.Errorf("No overlap in path") 473 | } 474 | 475 | res := pub.resolve(copath) 476 | if len(res) != len(path.Steps[dpIndex].EncryptedPathSecrets) { 477 | return fmt.Errorf("Malformed DirectPathNode %d %d", len(res), len(path.Steps[dpIndex].EncryptedPathSecrets)) 478 | } 479 | 480 | var pathSecret []byte 481 | for i, ct := range path.Steps[dpIndex].EncryptedPathSecrets { 482 | n := res[i] 483 | if _, ok := priv.PathSecrets[n]; ok { 484 | nodePriv, err := priv.privateKey(n) 485 | if err != nil { 486 | return err 487 | } 488 | 489 | pathSecret, err = priv.Suite.hpke().Decrypt(nodePriv, context, ct) 490 | if err != nil { 491 | return err 492 | } 493 | } 494 | } 495 | 496 | if pathSecret == nil { 497 | return fmt.Errorf("Unable to decrypt path secret") 498 | } 499 | 500 | // TODO Check the accuracy of the public keys in the path 501 | 502 | // Hash toward the root 503 | priv.setPathSecrets(overlap, pub.Size(), pathSecret) 504 | return nil 505 | } 506 | 507 | func (priv TreeKEMPrivateKey) Clone() TreeKEMPrivateKey { 508 | out := TreeKEMPrivateKey{ 509 | Suite: priv.Suite, 510 | Index: priv.Index, 511 | PathSecrets: map[NodeIndex]Bytes1{}, 512 | privateKeyCache: map[NodeIndex]HPKEPrivateKey{}, 513 | } 514 | 515 | for n := range priv.PathSecrets { 516 | out.PathSecrets[n] = priv.PathSecrets[n] 517 | } 
518 | 519 | for n := range priv.privateKeyCache { 520 | out.privateKeyCache[n] = priv.privateKeyCache[n] 521 | } 522 | 523 | return out 524 | } 525 | 526 | func (priv TreeKEMPrivateKey) dump(label string) { 527 | fmt.Printf("=== %s ===\n", label) 528 | fmt.Printf("suite=[%d] index=[%d]\n", priv.Suite, priv.Index) 529 | fmt.Printf("update=[%x]\n", priv.UpdateSecret) 530 | for n := range priv.PathSecrets { 531 | nodePriv, err := priv.privateKey(n) 532 | if err != nil { 533 | panic(err) 534 | } 535 | 536 | secret := priv.PathSecrets[n][:4] 537 | pub := nodePriv.PublicKey.Data[:4] 538 | fmt.Printf(" [%d] secret=%x... pub=%x...\n", n, secret, pub) 539 | } 540 | } 541 | 542 | func (priv TreeKEMPrivateKey) Consistent(other TreeKEMPrivateKey) bool { 543 | if priv.Suite != other.Suite { 544 | return false 545 | } 546 | 547 | if !bytes.Equal(priv.UpdateSecret, other.UpdateSecret) { 548 | return false 549 | } 550 | 551 | overlap := map[NodeIndex]bool{} 552 | for n := range priv.PathSecrets { 553 | if _, ok := other.PathSecrets[n]; ok { 554 | overlap[n] = true 555 | } 556 | } 557 | if len(overlap) == 0 { 558 | return false 559 | } 560 | 561 | for n := range overlap { 562 | if !bytes.Equal(priv.PathSecrets[n], other.PathSecrets[n]) { 563 | return false 564 | } 565 | } 566 | 567 | return true 568 | } 569 | 570 | func (priv TreeKEMPrivateKey) ConsistentPub(pub TreeKEMPublicKey) bool { 571 | if priv.Suite != pub.Suite { 572 | return false 573 | } 574 | 575 | for n := range priv.PathSecrets { 576 | nodePriv, err := priv.privateKey(n) 577 | if err != nil { 578 | return false 579 | } 580 | 581 | if pub.Nodes[n].Blank() { 582 | return false 583 | } 584 | 585 | lhs := nodePriv.PublicKey 586 | rhs := pub.Nodes[n].Node.PublicKey() 587 | 588 | if pub.Nodes[n].Blank() || !lhs.Equals(rhs) { 589 | return false 590 | } 591 | } 592 | 593 | return true 594 | } 595 | 596 | //////////////////////////////////////////////////////////// 597 | 
//////////////////////////////////////////////////////////// 598 | //////////////////////////////////////////////////////////// 599 | 600 | type TreeKEMPublicKey struct { 601 | Suite CipherSuite `tls:"omit"` 602 | Nodes []OptionalNode `tls:"head=4"` 603 | } 604 | 605 | func NewTreeKEMPublicKey(suite CipherSuite) *TreeKEMPublicKey { 606 | return &TreeKEMPublicKey{Suite: suite} 607 | } 608 | 609 | func (pub *TreeKEMPublicKey) AddLeaf(keyPkg KeyPackage) LeafIndex { 610 | // Find the leftmost free leaf 611 | index := LeafIndex(0) 612 | size := LeafIndex(pub.Size()) 613 | for index < size && !pub.Nodes[toNodeIndex(index)].Blank() { 614 | index++ 615 | } 616 | 617 | // Extend the tree if necessary 618 | n := toNodeIndex(index) 619 | for len(pub.Nodes) < int(n)+1 { 620 | pub.Nodes = append(pub.Nodes, OptionalNode{}) 621 | } 622 | 623 | pub.Nodes[n] = newLeafNode(keyPkg) 624 | 625 | // update unmerged list 626 | dp := dirpath(n, pub.Size()) 627 | for _, v := range dp { 628 | if v == toNodeIndex(index) || pub.Nodes[v].Node == nil { 629 | continue 630 | } 631 | pub.Nodes[v].Node.Parent.AddUnmerged(index) 632 | } 633 | 634 | pub.clearHashPath(index) 635 | return index 636 | } 637 | 638 | func (pub *TreeKEMPublicKey) UpdateLeaf(index LeafIndex, keyPkg KeyPackage) { 639 | pub.BlankPath(index) 640 | pub.Nodes[toNodeIndex(index)] = newLeafNode(keyPkg) 641 | pub.clearHashPath(index) 642 | } 643 | 644 | func (pub *TreeKEMPublicKey) BlankPath(index LeafIndex) { 645 | if len(pub.Nodes) == 0 { 646 | return 647 | } 648 | 649 | ni := toNodeIndex(index) 650 | 651 | pub.Nodes[ni].SetToBlank() 652 | 653 | for _, n := range dirpath(ni, pub.Size()) { 654 | pub.Nodes[n].SetToBlank() 655 | } 656 | } 657 | 658 | func (pub TreeKEMPublicKey) Encap(from LeafIndex, context, leafSecret []byte, leafSigPriv SignaturePrivateKey, opts *KeyPackageOpts) (*TreeKEMPrivateKey, *DirectPath, error) { 659 | // Generate path secrets 660 | priv := NewTreeKEMPrivateKey(pub.Suite, pub.Size(), from, leafSecret) 661 
| 662 | // Package into a DirectPath 663 | dp := dirpath(toNodeIndex(from), pub.Size()) 664 | path := &DirectPath{ 665 | LeafKeyPackage: *pub.Nodes[toNodeIndex(from)].Node.Leaf, 666 | Steps: make([]DirectPathNode, len(dp)), 667 | } 668 | last := toNodeIndex(from) 669 | for i, n := range dp { 670 | nodePriv, err := priv.privateKey(n) 671 | if err != nil { 672 | return nil, nil, err 673 | } 674 | 675 | path.Steps[i] = DirectPathNode{ 676 | PublicKey: nodePriv.PublicKey, 677 | EncryptedPathSecrets: []HPKECiphertext{}, 678 | } 679 | 680 | pathSecret := priv.PathSecrets[n] 681 | 682 | copath := sibling(last, pub.Size()) 683 | res := pub.resolve(copath) 684 | path.Steps[i].EncryptedPathSecrets = make([]HPKECiphertext, len(res)) 685 | for j, nr := range res { 686 | nodePub := pub.Nodes[nr].Node.PublicKey() 687 | path.Steps[i].EncryptedPathSecrets[j], err = pub.Suite.hpke().Encrypt(nodePub, context, pathSecret) 688 | if err != nil { 689 | return nil, nil, err 690 | } 691 | } 692 | 693 | last = n 694 | } 695 | 696 | // Sign the DirectPath 697 | leafPriv, err := priv.privateKey(toNodeIndex(from)) 698 | if err != nil { 699 | return nil, nil, err 700 | } 701 | 702 | err = path.Sign(pub.Suite, leafPriv.PublicKey, leafSigPriv, opts) 703 | if err != nil { 704 | return nil, nil, err 705 | } 706 | 707 | // Update the public key itself 708 | err = pub.Merge(from, *path) 709 | if err != nil { 710 | return nil, nil, err 711 | } 712 | 713 | // XXX(RLB): Should be possible to make a more targeted change, e.g., clearHashPath(from) 714 | pub.clearHashAll() 715 | pub.SetHashAll() 716 | return priv, path, nil 717 | } 718 | 719 | func (pub *TreeKEMPublicKey) Merge(from LeafIndex, path DirectPath) error { 720 | ni := toNodeIndex(from) 721 | pub.Nodes[ni] = newLeafNode(path.LeafKeyPackage) 722 | 723 | dp := dirpath(ni, pub.Size()) 724 | if len(dp) != len(path.Steps) { 725 | return fmt.Errorf("Malformed DirectPath %d %d", len(dp), len(path.Steps)) 726 | } 727 | 728 | for i, n := range dp { 729 
| pub.Nodes[n] = newParentNode(path.Steps[i].PublicKey) 730 | } 731 | 732 | // XXX(RLB): Should be possible to make a more targeted change, e.g., clearHashPath(from) 733 | pub.clearHashAll() 734 | pub.SetHashAll() 735 | return nil 736 | } 737 | 738 | func (pub TreeKEMPublicKey) Size() LeafCount { 739 | return leafWidth(nodeCount(len(pub.Nodes))) 740 | } 741 | 742 | func (pub TreeKEMPublicKey) Clone() TreeKEMPublicKey { 743 | next := TreeKEMPublicKey{ 744 | Suite: pub.Suite, 745 | Nodes: make([]OptionalNode, len(pub.Nodes)), 746 | } 747 | 748 | for i, n := range pub.Nodes { 749 | next.Nodes[i] = n.Clone() 750 | } 751 | 752 | return next 753 | } 754 | 755 | func (pub TreeKEMPublicKey) Equals(o TreeKEMPublicKey) bool { 756 | if len(pub.Nodes) != len(o.Nodes) { 757 | return false 758 | } 759 | 760 | for i := 0; i < len(pub.Nodes); i++ { 761 | if !pub.Nodes[i].Node.Equals(o.Nodes[i].Node) { 762 | return false 763 | } 764 | } 765 | return true 766 | } 767 | 768 | func (pub TreeKEMPublicKey) KeyPackage(index LeafIndex) (KeyPackage, bool) { 769 | ni := toNodeIndex(index) 770 | if pub.Nodes[ni].Blank() { 771 | return KeyPackage{}, false 772 | } 773 | 774 | return *pub.Nodes[ni].Node.Leaf, true 775 | } 776 | 777 | func (pub TreeKEMPublicKey) Find(kp KeyPackage) (LeafIndex, bool) { 778 | num := pub.Size() 779 | for i := LeafIndex(0); LeafCount(i) < num; i++ { 780 | ni := toNodeIndex(i) 781 | n := pub.Nodes[ni] 782 | if n.Blank() { 783 | continue 784 | } 785 | 786 | if n.Node.Leaf.Equals(kp) { 787 | return i, true 788 | } 789 | } 790 | 791 | return 0, false 792 | } 793 | 794 | func (pub TreeKEMPublicKey) resolve(index NodeIndex) []NodeIndex { 795 | // Resolution of non-blank is node + unmerged leaves 796 | if !pub.Nodes[index].Blank() { 797 | res := []NodeIndex{index} 798 | if level(index) > 0 { 799 | for _, v := range pub.Nodes[index].Node.Parent.UnmergedLeaves { 800 | res = append(res, toNodeIndex(v)) 801 | } 802 | } 803 | return res 804 | } 805 | 806 | // Resolution of 
blank leaf is the empty list 807 | if level(index) == 0 { 808 | return []NodeIndex{} 809 | } 810 | 811 | // Resolution of blank intermediate node is concatenation of the resolutions 812 | // of the children 813 | l := pub.resolve(left(index)) 814 | r := pub.resolve(right(index, pub.Size())) 815 | l = append(l, r...) 816 | return l 817 | } 818 | 819 | func (pub *TreeKEMPublicKey) clearHashAll() { 820 | for n := range pub.Nodes { 821 | pub.Nodes[n].Hash = nil 822 | } 823 | } 824 | 825 | func (pub *TreeKEMPublicKey) clearHashPath(index LeafIndex) { 826 | ni := toNodeIndex(index) 827 | pub.Nodes[ni].Hash = nil 828 | 829 | for _, n := range dirpath(ni, pub.Size()) { 830 | pub.Nodes[n].Hash = nil 831 | } 832 | } 833 | 834 | func (pub TreeKEMPublicKey) RootHash() []byte { 835 | h, err := pub.getHash(root(pub.Size())) 836 | if err != nil { 837 | // XXX(RLB) 838 | panic(err) 839 | } 840 | 841 | return h 842 | } 843 | 844 | func (pub *TreeKEMPublicKey) SetHashAll() error { 845 | _, err := pub.getHash(root(pub.Size())) 846 | return err 847 | } 848 | 849 | func (pub *TreeKEMPublicKey) getHash(index NodeIndex) ([]byte, error) { 850 | if pub.Nodes[index].Hash != nil { 851 | return pub.Nodes[index].Hash, nil 852 | } 853 | 854 | if level(index) == 0 { 855 | err := pub.Nodes[index].SetLeafNodeHash(pub.Suite, toLeafIndex(index)) 856 | return pub.Nodes[index].Hash, err 857 | } 858 | 859 | lh, err := pub.getHash(left(index)) 860 | if err != nil { 861 | return nil, err 862 | } 863 | 864 | rh, err := pub.getHash(right(index, pub.Size())) 865 | if err != nil { 866 | return nil, err 867 | } 868 | 869 | err = pub.Nodes[index].SetParentNodeHash(pub.Suite, index, lh, rh) 870 | return pub.Nodes[index].Hash, err 871 | } 872 | 873 | func (pub *TreeKEMPublicKey) setHash(index NodeIndex) error { 874 | if level(index) == 0 { 875 | return pub.Nodes[index].SetLeafNodeHash(pub.Suite, toLeafIndex(index)) 876 | } 877 | 878 | if pub.Nodes[index].Hash != nil { 879 | return nil 880 | } 881 | 882 | li := 
left(index) 883 | lh := pub.Nodes[li].Hash 884 | if lh == nil { 885 | if err := pub.setHash(li); err != nil { 886 | return err 887 | } 888 | } 889 | 890 | ri := right(index, pub.Size()) 891 | rh := pub.Nodes[ri].Hash 892 | if rh == nil { 893 | if err := pub.setHash(ri); err != nil { 894 | return err 895 | } 896 | } 897 | 898 | return pub.Nodes[index].SetParentNodeHash(pub.Suite, index, lh, rh) 899 | } 900 | 901 | func (pub *TreeKEMPublicKey) dump(label string) { 902 | fmt.Printf("~~~ %s ~~~\n", label) 903 | fmt.Printf("&pub = %p\n", pub) 904 | 905 | for i, n := range pub.Nodes { 906 | hash := "-" 907 | if len(n.Hash) > 0 { 908 | hash = fmt.Sprintf("%x", n.Hash[:4]) 909 | } 910 | 911 | if n.Blank() { 912 | fmt.Printf(" [%d] <%s> _\n", i, hash) 913 | continue 914 | } 915 | 916 | pub := n.Node.PublicKey().Data[:4] 917 | fmt.Printf(" [%d] <%s> %x...\n", i, hash, pub) 918 | } 919 | } 920 | -------------------------------------------------------------------------------- /treekem_test.go: -------------------------------------------------------------------------------- 1 | package mls 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func newKeyPackage(t *testing.T) ([]byte, SignaturePrivateKey, *KeyPackage) { 10 | secret := randomBytes(32) 11 | 12 | initPriv, err := suite.hpke().Derive(secret) 13 | require.Nil(t, err) 14 | 15 | sigPriv, err := suite.Scheme().Derive(secret) 16 | require.Nil(t, err) 17 | 18 | cred := NewBasicCredential(userID, suite.Scheme(), sigPriv.PublicKey) 19 | 20 | kp, err := NewKeyPackageWithInitKey(suite, initPriv.PublicKey, cred, sigPriv) 21 | 22 | return secret, sigPriv, kp 23 | } 24 | 25 | func TestTreeKEM(t *testing.T) { 26 | groupSize := 10 27 | var err error 28 | 29 | pub := NewTreeKEMPublicKey(suite) 30 | privs := make([]*TreeKEMPrivateKey, groupSize) 31 | sigPrivs := make([]SignaturePrivateKey, groupSize) 32 | 33 | // Make a new one-person pub + priv 34 | secret, sigPriv, kp := newKeyPackage(t) 35 | 
sigPrivs[0] = sigPriv 36 | 37 | index := pub.AddLeaf(*kp) 38 | require.Equal(t, index, LeafIndex(0)) 39 | 40 | privs[0] = NewTreeKEMPrivateKey(suite, pub.Size(), index, secret) 41 | require.True(t, privs[0].ConsistentPub(*pub)) 42 | 43 | // Each member adds the next 44 | var path *DirectPath 45 | for i := 0; i < groupSize-1; i++ { 46 | adder := LeafIndex(i) 47 | joiner := LeafIndex(i + 1) 48 | context := []byte{byte(i)} 49 | secret, sigPriv, kp := newKeyPackage(t) 50 | sigPrivs[i+1] = sigPriv 51 | 52 | index := pub.AddLeaf(*kp) 53 | require.Equal(t, index, joiner) 54 | 55 | // Add the new joiner 56 | leafSecret := randomBytes(32) 57 | privs[i], path, err = pub.Encap(adder, context, leafSecret, sigPrivs[i], nil) 58 | require.Nil(t, err) 59 | require.Nil(t, path.ParentHashValid(suite)) 60 | 61 | err = pub.Merge(adder, *path) 62 | require.Nil(t, err) 63 | require.True(t, privs[i].ConsistentPub(*pub)) 64 | 65 | overlap, pathSecret, ok := privs[i].SharedPathSecret(joiner) 66 | require.True(t, ok) 67 | require.NotNil(t, pathSecret) 68 | 69 | // New joiner initializes their private key 70 | privs[i+1] = NewTreeKEMPrivateKeyForJoiner(suite, joiner, pub.Size(), secret, overlap, pathSecret) 71 | require.True(t, privs[i+1].Consistent(*privs[i])) 72 | require.True(t, privs[i+1].ConsistentPub(*pub)) 73 | 74 | // Other members update their private keys 75 | for j := 0; j < i; j++ { 76 | err = privs[j].Decap(adder, *pub, context, *path) 77 | require.Nil(t, err) 78 | require.True(t, privs[j].Consistent(*privs[i])) 79 | require.True(t, privs[j].ConsistentPub(*pub)) 80 | } 81 | } 82 | } 83 | 84 | func generateRatchetTreeVectors(t *testing.T) []byte { 85 | return nil // TODO(RLB) 86 | } 87 | 88 | func verifyRatchetTreeVectors(t *testing.T, data []byte) { 89 | // TODO(RLB) 90 | } 91 | --------------------------------------------------------------------------------