├── .gitignore ├── README.md ├── cmd └── main.go ├── go.mod ├── go.sum └── mmtls ├── client_finish.go ├── client_hello.go ├── const.go ├── mmtls.go ├── mmtls_short.go ├── record.go ├── server_finish.go ├── server_hello.go ├── session.go ├── session_ticket.go ├── signature.go └── utility.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac OS X files 2 | .DS_Store 3 | 4 | # Dependency directories (remove the comment below to include it) 5 | vendor/ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 微信 mmtls 协议的 go 语言实现 2 | 3 | 参考 https://github.com/anonymous5l/mmtls 4 | 5 | 更新到了版本 0xF104 的协议支持, 修复了些 Bug 6 | 7 | 仅用于协议的研究学习 8 | 9 | 目前支持 10 | - [x] 1-RTT ECDHE 11 | - [x] 1-RTT PSK 12 | - [x] 0-RTT PSK 13 | -------------------------------------------------------------------------------- /cmd/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/hex" 5 | 6 | "github.com/duo/gommtls/mmtls" 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | func main() { 11 | log.SetFormatter(&log.TextFormatter{ 12 | FullTimestamp: true, 13 | }) 14 | log.SetLevel(log.DebugLevel) 15 | 16 | { 17 | client := mmtls.NewMMTLSClient() 18 | 19 | defer client.Close() 20 | 21 | if session, err := mmtls.LoadSession("session"); err == nil { 22 | client.Session = session 23 | } 24 | 25 | if err := client.Handshake("long.weixin.qq.com:80"); err != nil { 26 | panic(err) 27 | } 28 | 29 | if client.Session != nil { 30 | client.Session.Save("session") 31 | } 32 | 33 | if err := client.Noop(); err != nil { 34 | panic(err) 35 | } 36 | } 37 | 38 | { 39 | client := mmtls.NewMMTLSClientShort() 40 | 41 | if session, err := mmtls.LoadSession("session"); err == nil { 42 | client.Session = session 43 | } 44 | 45 | defer client.Close() 46 | 47 | 
response, err := client.Request( 48 | "dns.weixin.qq.com.cn", 49 | "/cgi-bin/micromsg-bin/newgetdns", 50 | nil, 51 | ) 52 | if err != nil { 53 | panic(err) 54 | } 55 | 56 | log.Debugf("Response:\n%s\n", hex.Dump(response)) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/duo/gommtls 2 | 3 | go 1.18 4 | 5 | require github.com/sirupsen/logrus v1.8.1 6 | 7 | require golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 8 | 9 | require golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 // indirect 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= 6 | github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 7 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 8 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 9 | golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o= 10 | golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 11 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 12 | golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 
// clientFinish is the ClientFinish handshake message: a one-byte tag followed
// by the length-prefixed verification data.
type clientFinish struct {
	reversed byte   // tag byte written at offset 4; newClientFinish sets it to 0x14
	data     []byte // verification payload supplied by the caller
}

// newClientFinish wraps data in a ClientFinish message tagged with 0x14.
func newClientFinish(data []byte) *clientFinish {
	cf := new(clientFinish)
	cf.reversed = 0x14
	cf.data = data
	return cf
}

// serialize encodes the message as
//
//	uint32 BE  body length (len(data) + 3)
//	byte       tag
//	uint16 BE  len(data)
//	[]byte     data
func (c *clientFinish) serialize() []byte {
	n := len(c.data)

	// Fixed-size prefix: 4-byte body length, tag byte, 2-byte payload length.
	var prefix [7]byte
	binary.BigEndian.PutUint32(prefix[0:4], uint32(n+3))
	prefix[4] = c.reversed
	binary.BigEndian.PutUint16(prefix[5:7], uint16(n))

	out := make([]byte, 0, len(prefix)+n)
	out = append(out, prefix[:]...)
	out = append(out, c.data...)

	return out
}
elliptic.Marshal(cliVerKey.Curve, cliVerKey.X, cliVerKey.Y), 33 | } 34 | 35 | return ch 36 | } 37 | 38 | // 1-RTT PSK 39 | func newPskOneHello(cliPubKey *ecdsa.PublicKey, cliVerKey *ecdsa.PublicKey, ticket *sessionTicket) *clientHello { 40 | ch := &clientHello{} 41 | 42 | ch.protocolVersion = ProtocolVersion 43 | ch.timestamp = uint32(time.Now().Unix()) 44 | ch.random = getRandom(32) 45 | ch.cipherSuites = []uint16{ 46 | tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 47 | TLS_PSK_WITH_AES_128_GCM_SHA256, 48 | } 49 | 50 | t := ticket 51 | t.ticketAgeAdd = make([]byte, 0) 52 | ticketData, _ := t.serialize() 53 | 54 | ch.extensions = make(map[uint16][][]byte) 55 | ch.extensions[TLS_PSK_WITH_AES_128_GCM_SHA256] = [][]byte{ 56 | ticketData, 57 | } 58 | ch.extensions[tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256] = [][]byte{ 59 | elliptic.Marshal(cliPubKey.Curve, cliPubKey.X, cliPubKey.Y), 60 | elliptic.Marshal(cliVerKey.Curve, cliVerKey.X, cliVerKey.Y), 61 | } 62 | 63 | return ch 64 | } 65 | 66 | // 0-RTT PSK 67 | func newPskZeroHello(ticket *sessionTicket) *clientHello { 68 | ch := &clientHello{} 69 | 70 | ch.protocolVersion = ProtocolVersion 71 | ch.timestamp = uint32(time.Now().Unix()) 72 | ch.random = getRandom(32) 73 | ch.cipherSuites = []uint16{TLS_PSK_WITH_AES_128_GCM_SHA256} 74 | 75 | t := ticket 76 | t.ticketAgeAdd = make([]byte, 0) 77 | ticketData, _ := t.serialize() 78 | 79 | ch.extensions = make(map[uint16][][]byte) 80 | ch.extensions[TLS_PSK_WITH_AES_128_GCM_SHA256] = [][]byte{ 81 | ticketData, 82 | } 83 | 84 | return ch 85 | } 86 | 87 | func (c *clientHello) serialize() []byte { 88 | buf := make([]byte, 0, 512) 89 | 90 | // total length 91 | buf = append(buf, 0x00, 0x00, 0x00, 0x00) 92 | // flag ? 
93 | buf = append(buf, 0x01) 94 | 95 | // protocol version 96 | buf = append(buf, 0x00, 0x00) 97 | binary.LittleEndian.PutUint16(buf[len(buf)-2:], c.protocolVersion) 98 | 99 | // cipher suites 100 | buf = append(buf, byte(len(c.cipherSuites))) 101 | for _, v := range c.cipherSuites { 102 | buf = append(buf, 0x00, 0x00) 103 | binary.BigEndian.PutUint16(buf[len(buf)-2:], v) 104 | } 105 | 106 | // random 107 | buf = append(buf, c.random...) 108 | 109 | // timestamp 110 | buf = append(buf, 0x00, 0x00, 0x00, 0x00) 111 | binary.BigEndian.PutUint32(buf[len(buf)-4:], uint32(c.timestamp)) 112 | 113 | cipherPos := len(buf) 114 | buf = append(buf, 0x00, 0x00, 0x00, 0x00) 115 | buf = append(buf, byte(len(c.cipherSuites))) 116 | 117 | for i := len(c.cipherSuites) - 1; i >= 0; i-- { 118 | cipher := c.cipherSuites[i] 119 | if cipher == TLS_PSK_WITH_AES_128_GCM_SHA256 { 120 | pskPos := len(buf) 121 | buf = append(buf, 0x00, 0x00, 0x00, 0x00) 122 | buf = append(buf, 0x00, 0x0F) // cipher type? 123 | buf = append(buf, 0x01) 124 | 125 | keyPos := len(buf) 126 | buf = append(buf, 0x00, 0x00, 0x00, 0x00) 127 | 128 | buf = append(buf, c.extensions[cipher][0]...) 129 | binary.BigEndian.PutUint32(buf[keyPos:], uint32(len(buf)-keyPos-4)) 130 | 131 | binary.BigEndian.PutUint32(buf[pskPos:], uint32(len(buf)-pskPos-4)) 132 | } else if cipher == tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 { 133 | // ECDSA keys 134 | ecdsaPos := len(buf) 135 | buf = append(buf, 0x00, 0x00, 0x00, 0x00) 136 | buf = append(buf, 0x00, 0x10) // cipher type? 
const (
	// ProtocolVersion is the mmtls version this package speaks. It is
	// written into every ClientHello and stamped on every record built by
	// createRecord.
	ProtocolVersion uint16 = 0xF104

	// TLS_PSK_WITH_AES_128_GCM_SHA256 is the cipher-suite identifier the
	// ClientHello uses for the PSK flows; it also keys the PSK extension
	// that carries the session ticket.
	TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0xA8

	// Record type bytes (first byte of a record header), one per
	// create*Record helper.
	MagicAbort     uint8 = 0x15 // createAbortRecord
	MagicHandshake uint8 = 0x16 // createHandshakeRecord
	MagicRecord    uint8 = 0x17 // createDataRecord / createRawDataRecord
	MagicSystem    uint8 = 0x19 // createSystemRecord
)

const (
	// TCP_NoopRequest is the data type of the noop record sent on the
	// long-link connection.
	TCP_NoopRequest uint32 = 0x6
	// TCP_NoopResponse is the data type the noop reply must carry.
	TCP_NoopResponse uint32 = 0x3B9ACA06
)
bigintFromHex("4175c032bc573d5ce4b3ac0b7f2b9a8d48ca4b990ce2fa3ce75cc9d12720fa35"), 30 | } 31 | ) 32 | 33 | func bigintFromHex(s string) *big.Int { 34 | b := big.NewInt(0) 35 | b.SetString(s, 16) 36 | return b 37 | } 38 | -------------------------------------------------------------------------------- /mmtls/mmtls.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bytes" 5 | "crypto/ecdsa" 6 | "crypto/elliptic" 7 | "crypto/hmac" 8 | "crypto/rand" 9 | "crypto/sha256" 10 | "encoding/binary" 11 | "encoding/hex" 12 | "errors" 13 | "hash" 14 | "io" 15 | "math/big" 16 | "net" 17 | "sync/atomic" 18 | 19 | log "github.com/sirupsen/logrus" 20 | "golang.org/x/crypto/hkdf" 21 | ) 22 | 23 | var ( 24 | curve = elliptic.P256() 25 | ) 26 | 27 | type MMTLSClient struct { 28 | conn net.Conn 29 | 30 | status int32 31 | 32 | publicEcdh *ecdsa.PrivateKey 33 | verifyEcdh *ecdsa.PrivateKey 34 | serverEcdh *ecdsa.PublicKey 35 | 36 | handshakeHasher hash.Hash 37 | 38 | serverSeqNum uint32 39 | clientSeqNum uint32 40 | 41 | Session *Session 42 | } 43 | 44 | func NewMMTLSClient() *MMTLSClient { 45 | c := &MMTLSClient{} 46 | 47 | c.handshakeHasher = sha256.New() 48 | 49 | return c 50 | } 51 | 52 | func (c *MMTLSClient) Handshake(host string) error { 53 | if c.conn == nil { 54 | conn, err := net.Dial("tcp", host) 55 | if err != nil { 56 | return err 57 | } 58 | 59 | c.conn = conn 60 | } 61 | 62 | if c.handshakeComplete() { 63 | return nil 64 | } 65 | 66 | c.reset() 67 | 68 | if err := c.genKeyPairs(); err != nil { 69 | return err 70 | } 71 | 72 | var ch *clientHello 73 | if c.Session != nil && len(c.Session.tk.tickets) > 1 { 74 | log.Info("1-RTT PSK handshake") 75 | ch = newPskOneHello(&c.publicEcdh.PublicKey, &c.verifyEcdh.PublicKey, &c.Session.tk.tickets[1]) 76 | } else { 77 | log.Info("1-RTT ECDHE handshake") 78 | ch = newECDHEHello(&c.publicEcdh.PublicKey, &c.verifyEcdh.PublicKey) 79 | } 80 | if err := 
c.sendClientHello(ch); err != nil { 81 | return err 82 | } 83 | 84 | serverHello, err := c.readServerHello() 85 | if err != nil { 86 | return err 87 | } 88 | 89 | // DH compute key 90 | comKey := c.computeEphemeralSecret( 91 | serverHello.publicKey.X, 92 | serverHello.publicKey.Y, 93 | c.publicEcdh.D) 94 | 95 | // trafffic key 96 | trafficKey, err := c.computeTrafficKey( 97 | comKey, 98 | c.hkdfExpand("handshake key expansion", c.handshakeHasher)) 99 | if err != nil { 100 | return nil 101 | } 102 | 103 | // compare traffic key is valid 104 | if err := c.readSignature(trafficKey); err != nil { 105 | return err 106 | } 107 | 108 | // gen psk 109 | if err := c.readNewSessionTicket(comKey, trafficKey); err != nil { 110 | return err 111 | } 112 | 113 | if err := c.readServerFinish(comKey, trafficKey); err != nil { 114 | return err 115 | } 116 | 117 | if err := c.sendClientFinish(comKey, trafficKey); err != nil { 118 | return err 119 | } 120 | 121 | // ComputeMasterSecre 122 | expandedSecret := make([]byte, 32) 123 | hkdf.Expand( 124 | sha256.New, 125 | comKey, 126 | c.hkdfExpand("expanded secret", c.handshakeHasher)).Read(expandedSecret) 127 | 128 | // AppKey 129 | appKey, err := c.computeTrafficKey( 130 | expandedSecret, 131 | c.hkdfExpand("application data key expansion", c.handshakeHasher)) 132 | if err != nil { 133 | return err 134 | } 135 | c.Session.appKey = appKey 136 | 137 | // fully complete handshake 138 | atomic.StoreInt32(&c.status, 1) 139 | 140 | return nil 141 | } 142 | 143 | func (c *MMTLSClient) Noop() error { 144 | if err := c.sendNoop(); err != nil { 145 | return err 146 | } 147 | 148 | if err := c.readNoop(); err != nil { 149 | return err 150 | } 151 | 152 | return nil 153 | } 154 | 155 | func (c *MMTLSClient) Close() error { 156 | if c.conn != nil { 157 | log.Debug("Close connection...") 158 | return c.conn.Close() 159 | } 160 | return nil 161 | } 162 | 163 | func (c *MMTLSClient) reset() { 164 | c.handshakeHasher.Reset() 165 | 166 | c.clientSeqNum = 
0 167 | c.serverSeqNum = 0 168 | } 169 | 170 | func (c *MMTLSClient) handshakeComplete() bool { 171 | return atomic.LoadInt32(&c.status) == 1 172 | } 173 | 174 | func (c *MMTLSClient) sendClientHello(hello *clientHello) error { 175 | data := hello.serialize() 176 | 177 | c.handshakeHasher.Write(data) 178 | 179 | packet := createHandshakeRecord(data).serialize() 180 | log.Debugf("Send ClientHello packet(%d):\n%s", len(packet), hex.Dump(packet)) 181 | 182 | _, err := c.conn.Write(packet) 183 | 184 | c.clientSeqNum++ 185 | 186 | return err 187 | } 188 | 189 | func (c *MMTLSClient) readServerHello() (*serverHello, error) { 190 | record, err := c.readRecord() 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | c.handshakeHasher.Write(record.data) 196 | c.serverSeqNum++ 197 | 198 | return readServerHello(record.data) 199 | } 200 | 201 | func (c *MMTLSClient) readSignature(trafficKey *trafficKeyPair) error { 202 | record, err := c.readRecord() 203 | if err != nil { 204 | return err 205 | } 206 | 207 | if err := record.decrypt(trafficKey, c.serverSeqNum); err != nil { 208 | return err 209 | } 210 | 211 | signature, err := readSignature(record.data) 212 | if err != nil { 213 | return err 214 | } 215 | 216 | if !c.verifyEcdsa(signature.EcdsaSignature) { 217 | return errors.New("verify signature failed") 218 | } 219 | 220 | c.handshakeHasher.Write(record.data) 221 | c.serverSeqNum++ 222 | 223 | return nil 224 | } 225 | 226 | func (c *MMTLSClient) readNewSessionTicket(comKey []byte, trafficKey *trafficKeyPair) error { 227 | record, err := c.readRecord() 228 | if err != nil { 229 | return err 230 | } 231 | 232 | if err := record.decrypt(trafficKey, c.serverSeqNum); err != nil { 233 | return err 234 | } 235 | 236 | tickets, err := readNewSessionTicket(record.data) 237 | if err != nil { 238 | return err 239 | } 240 | 241 | pskAccess := make([]byte, 32) 242 | hkdf.Expand( 243 | sha256.New, 244 | comKey, 245 | c.hkdfExpand("PSK_ACCESS", 
c.handshakeHasher)).Read(pskAccess) 246 | log.Debugf("PSK_ACCESS:\n%s\n", hex.Dump(pskAccess)) 247 | 248 | pskRefresh := make([]byte, 32) 249 | hkdf.Expand( 250 | sha256.New, 251 | comKey, 252 | c.hkdfExpand("PSK_REFRESH", c.handshakeHasher)).Read(pskRefresh) 253 | log.Debugf("PSK_REFRESH:\n%s\n", hex.Dump(pskRefresh)) 254 | 255 | c.Session = &Session{ 256 | tk: tickets, 257 | pskAccess: pskAccess, 258 | pskRefresh: pskRefresh, 259 | } 260 | 261 | c.handshakeHasher.Write(record.data) 262 | c.serverSeqNum++ 263 | 264 | return nil 265 | } 266 | 267 | func (c *MMTLSClient) readServerFinish(comKey []byte, trafficKey *trafficKeyPair) error { 268 | record, err := c.readRecord() 269 | if err != nil { 270 | return err 271 | } 272 | 273 | if err := record.decrypt(trafficKey, c.serverSeqNum); err != nil { 274 | return err 275 | } 276 | 277 | sf, err := ReadServerFinish(record.data) 278 | if err != nil { 279 | return nil 280 | } 281 | 282 | sfKey := make([]byte, 32) 283 | hkdf.Expand( 284 | sha256.New, 285 | comKey, 286 | c.hkdfExpand("server finished", nil)).Read(sfKey) 287 | 288 | securityParam := c.hmac(sfKey, c.handshakeHasher.Sum(nil)) 289 | 290 | if bytes.Compare(sf.data, securityParam) != 0 { 291 | return errors.New("security key not compare") 292 | } 293 | 294 | c.serverSeqNum++ 295 | 296 | return nil 297 | } 298 | 299 | func (c *MMTLSClient) sendClientFinish(comKey []byte, trafficKey *trafficKeyPair) error { 300 | cliKey := make([]byte, 32) 301 | hkdf.Expand( 302 | sha256.New, 303 | comKey, 304 | c.hkdfExpand("client finished", nil)).Read(cliKey) 305 | cliKey = c.hmac(cliKey, c.handshakeHasher.Sum(nil)) 306 | 307 | cf := newClientFinish(cliKey) 308 | 309 | cfRecord := createHandshakeRecord(cf.serialize()) 310 | if err := cfRecord.encrypt(trafficKey, c.clientSeqNum); err != nil { 311 | return err 312 | } 313 | 314 | packet := cfRecord.serialize() 315 | log.Debugf("Send ClientFinish packet(%d):\n%s", len(packet), hex.Dump(packet)) 316 | _, err := c.conn.Write(packet) 
317 | 318 | c.clientSeqNum++ 319 | 320 | return err 321 | } 322 | 323 | func (c *MMTLSClient) sendNoop() error { 324 | noop := createDataRecord(TCP_NoopRequest, 0xFFFFFFFF, nil) 325 | noop.encrypt(c.Session.appKey, c.clientSeqNum) 326 | 327 | packet := noop.serialize() 328 | log.Debugf("Send Noop packet(%d):\n%s", len(packet), hex.Dump(packet)) 329 | _, err := c.conn.Write(packet) 330 | 331 | c.clientSeqNum++ 332 | 333 | return err 334 | } 335 | 336 | func (c *MMTLSClient) readNoop() error { 337 | record, err := c.readRecord() 338 | if err != nil { 339 | return err 340 | } 341 | 342 | if err := record.decrypt(c.Session.appKey, c.serverSeqNum); err != nil { 343 | return err 344 | } 345 | 346 | r := bytes.NewReader(record.data) 347 | 348 | var packLen uint32 349 | if err := binary.Read(r, binary.BigEndian, &packLen); err != nil { 350 | return err 351 | } 352 | if packLen != 16 { 353 | return errors.New("noop response packet length invalid") 354 | } 355 | 356 | // skip flag 357 | if _, err := r.Seek(4, io.SeekCurrent); err != nil { 358 | return err 359 | } 360 | 361 | var dataType uint32 362 | if err := binary.Read(r, binary.BigEndian, &dataType); err != nil { 363 | return err 364 | } 365 | if TCP_NoopResponse != dataType { 366 | return errors.New("noop response packet type mismatch") 367 | } 368 | 369 | c.serverSeqNum++ 370 | 371 | return nil 372 | } 373 | 374 | func (c *MMTLSClient) readRecord() (*mmtlsRecord, error) { 375 | header := make([]byte, 5) 376 | if _, err := io.ReadFull(c.conn, header); err != nil { 377 | return nil, err 378 | } 379 | 380 | packLen := binary.BigEndian.Uint16(header[3:]) 381 | 382 | payload := make([]byte, packLen) 383 | if _, err := io.ReadFull(c.conn, payload); err != nil { 384 | return nil, err 385 | } 386 | 387 | log.Debugf("Receive Packet Header(%d):\n%s", len(header), hex.Dump(header)) 388 | log.Debugf("Receive Packet payload(%d):\n%s", len(payload), hex.Dump(payload)) 389 | 390 | return readRecord(bytes.NewReader(append(header, 
payload...))) 391 | } 392 | 393 | func (c *MMTLSClient) computeEphemeralSecret(x, y, z *big.Int) []byte { 394 | r, _ := curve.ScalarMult(x, y, z.Bytes()) 395 | s := sha256.Sum256(r.Bytes()) 396 | return s[:] 397 | } 398 | 399 | func (c *MMTLSClient) computeTrafficKey(shareKey, info []byte) (*trafficKeyPair, error) { 400 | trafficKey := make([]byte, 56) 401 | if _, err := hkdf.Expand(sha256.New, shareKey, info).Read(trafficKey); err != nil { 402 | return nil, err 403 | } 404 | 405 | log.Debugf("TrafficKey:\n%s\n", hex.Dump(trafficKey)) 406 | 407 | pair := &trafficKeyPair{} 408 | pair.clientKey = trafficKey[:16] 409 | pair.serverKey = trafficKey[16:32] 410 | pair.clientNonce = trafficKey[32:44] 411 | pair.serverNonce = trafficKey[44:] 412 | 413 | return pair, nil 414 | } 415 | 416 | func (c *MMTLSClient) verifyEcdsa(data []byte) bool { 417 | dataHash := sha256.Sum256(c.handshakeHasher.Sum(nil)) 418 | return ecdsa.VerifyASN1(ServerEcdh, dataHash[:], data) 419 | } 420 | 421 | func (c *MMTLSClient) hkdfExpand(prefix string, hash hash.Hash) []byte { 422 | info := []byte(prefix) 423 | if hash != nil { 424 | info = append(info, hash.Sum(nil)...) 
425 | } 426 | return info 427 | } 428 | 429 | func (c *MMTLSClient) hmac(k []byte, d []byte) []byte { 430 | hm := hmac.New(sha256.New, k) 431 | hm.Write(d) 432 | return hm.Sum(nil) 433 | } 434 | 435 | func (c *MMTLSClient) genKeyPairs() error { 436 | if c.publicEcdh == nil { 437 | public, err := ecdsa.GenerateKey(curve, rand.Reader) 438 | if err != nil { 439 | return err 440 | } 441 | c.publicEcdh = public 442 | } 443 | 444 | if c.verifyEcdh == nil { 445 | verify, err := ecdsa.GenerateKey(curve, rand.Reader) 446 | if err != nil { 447 | return err 448 | } 449 | c.verifyEcdh = verify 450 | } 451 | 452 | return nil 453 | } 454 | -------------------------------------------------------------------------------- /mmtls/mmtls_short.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "crypto/hmac" 7 | "crypto/rand" 8 | "crypto/sha256" 9 | "encoding/binary" 10 | "encoding/hex" 11 | "errors" 12 | "fmt" 13 | "hash" 14 | "io" 15 | "io/ioutil" 16 | "net" 17 | "net/http" 18 | "net/http/httputil" 19 | "net/url" 20 | 21 | log "github.com/sirupsen/logrus" 22 | "golang.org/x/crypto/hkdf" 23 | ) 24 | 25 | type MMTLSClientShort struct { 26 | conn net.Conn 27 | 28 | status int32 29 | 30 | packetReader io.Reader 31 | 32 | handshakeHasher hash.Hash 33 | 34 | serverSeqNum uint32 35 | clientSeqNum uint32 36 | 37 | Session *Session 38 | } 39 | 40 | func NewMMTLSClientShort() *MMTLSClientShort { 41 | c := &MMTLSClientShort{} 42 | 43 | c.handshakeHasher = sha256.New() 44 | 45 | return c 46 | } 47 | 48 | func (c *MMTLSClientShort) Request(host, path string, req []byte) ([]byte, error) { 49 | log.Info("0-RTT PSK handshake") 50 | if c.Session == nil { 51 | return nil, errors.New("0-RTT requires session") 52 | } 53 | 54 | if c.conn == nil { 55 | conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, 80)) 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | c.conn = conn 61 | } 62 | 63 | 
httpPacket, err := c.packHttp(host, path, req) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | _, err = c.conn.Write(httpPacket) 69 | 70 | response, err := c.parseResponse(c.conn) 71 | log.Debugf("Receive response:\n%s\n", hex.Dump(response)) 72 | 73 | c.packetReader = bytes.NewReader(response) 74 | 75 | if err := c.readServerHello(); err != nil { 76 | return nil, err 77 | } 78 | 79 | // trafffic key 80 | trafficKey, err := c.computeTrafficKey( 81 | c.Session.pskAccess, 82 | c.hkdfExpand("handshake key expansion", c.handshakeHasher)) 83 | if err != nil { 84 | return nil, err 85 | } 86 | c.Session.appKey = trafficKey 87 | 88 | if err := c.readServerFinish(); err != nil { 89 | return nil, err 90 | } 91 | 92 | dataRecord, err := c.readDataRecord() 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | if err := c.readAbort(); err != nil { 98 | return nil, err 99 | } 100 | 101 | return dataRecord.data, nil 102 | } 103 | 104 | func (c *MMTLSClientShort) Close() error { 105 | if c.conn != nil { 106 | log.Debug("Close connection...") 107 | return c.conn.Close() 108 | } 109 | return nil 110 | } 111 | 112 | func (c *MMTLSClientShort) packHttp(host, path string, req []byte) ([]byte, error) { 113 | tlsPayload := make([]byte, 0) 114 | 115 | datPart, err := c.genDataPart(host, path, req) 116 | if err != nil { 117 | return nil, err 118 | } 119 | 120 | // ClientHello 121 | hello := newPskZeroHello(&c.Session.tk.tickets[0]) 122 | helloPart := hello.serialize() 123 | 124 | c.handshakeHasher.Write(helloPart) 125 | 126 | earlyKey, _ := c.earlyDataKey(c.Session.pskAccess, &c.Session.tk.tickets[0]) 127 | 128 | tlsPayload = append(tlsPayload, createSystemRecord(helloPart).serialize()...) 
129 | c.clientSeqNum++ 130 | 131 | // Extensions 132 | extensionsPart := []byte{ 133 | 0x00, 0x00, 0x00, 0x10, 0x08, 0x00, 0x00, 0x00, 134 | 0x0b, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x12, 135 | 0x00, 0x00, 0x00, 0x00, 136 | } 137 | binary.BigEndian.PutUint32(extensionsPart[16:], hello.timestamp) 138 | 139 | c.handshakeHasher.Write(extensionsPart) 140 | 141 | extensionsRecord := createSystemRecord(extensionsPart) 142 | extensionsRecord.encrypt(earlyKey, c.clientSeqNum) 143 | 144 | tlsPayload = append(tlsPayload, extensionsRecord.serialize()...) 145 | c.clientSeqNum++ 146 | 147 | // Request 148 | requestRecord := createRawDataRecord(datPart) 149 | requestRecord.encrypt(earlyKey, c.clientSeqNum) 150 | 151 | tlsPayload = append(tlsPayload, requestRecord.serialize()...) 152 | c.clientSeqNum++ 153 | 154 | // Abort 155 | abortPart := []byte{0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x01} 156 | abortRecord := createAbortRecord(abortPart) 157 | abortRecord.encrypt(earlyKey, c.clientSeqNum) 158 | 159 | tlsPayload = append(tlsPayload, abortRecord.serialize()...) 160 | c.clientSeqNum++ 161 | 162 | // HTTP header 163 | header, err := c.buildRequestHeader(host, len(tlsPayload)) 164 | if err != nil { 165 | return nil, err 166 | } 167 | 168 | return append(header, tlsPayload...), nil 169 | } 170 | 171 | func (c *MMTLSClientShort) genDataPart(host, path string, req []byte) ([]byte, error) { 172 | buf := &bytes.Buffer{} 173 | 174 | if err := writeU16LenData(buf, []byte(path)); err != nil { 175 | return nil, err 176 | } 177 | if err := writeU16LenData(buf, []byte(host)); err != nil { 178 | return nil, err 179 | } 180 | if err := writeU32LenData(buf, req); err != nil { 181 | return nil, err 182 | } 183 | 184 | data := buf.Bytes() 185 | pkt := make([]byte, 4) 186 | binary.BigEndian.PutUint32(pkt, uint32(len(data))) 187 | pkt = append(pkt, data...) 
188 | 189 | return pkt, nil 190 | } 191 | 192 | func (c *MMTLSClientShort) buildRequestHeader(host string, length int) ([]byte, error) { 193 | request := &http.Request{ 194 | Method: http.MethodPost, 195 | Proto: "HTTP/1.1", 196 | ProtoMajor: 1, 197 | ProtoMinor: 1, 198 | Close: false, 199 | Header: map[string][]string{}, 200 | } 201 | 202 | randName := make([]byte, 4) 203 | if _, err := rand.Read(randName); err != nil { 204 | return nil, err 205 | } 206 | 207 | request.Header.Set("Accept", "*/*") 208 | request.Header.Set("Cache-Control", "no-cache") 209 | request.Header.Set("Connection", "Keep-Alive") 210 | request.Header.Set("Content-Type", "application/octet-stream") 211 | request.Header.Set("Content-Length", fmt.Sprintf("%d", length)) 212 | request.Header.Set("Upgrade", "mmtls") 213 | request.Header.Set("User-Agent", "MicroMessenger Client") 214 | request.URL, _ = url.Parse(fmt.Sprintf("https://%s/mmtls/%x", host, randName)) 215 | 216 | b, err := httputil.DumpRequest(request, false) 217 | if err != nil { 218 | return nil, err 219 | } 220 | 221 | return b, nil 222 | } 223 | 224 | func (c *MMTLSClientShort) parseResponse(conn net.Conn) ([]byte, error) { 225 | resp, err := http.ReadResponse(bufio.NewReader(conn), nil) 226 | if err != nil { 227 | return nil, err 228 | } 229 | 230 | b := new(bytes.Buffer) 231 | io.Copy(b, resp.Body) 232 | resp.Body.Close() 233 | resp.Body = ioutil.NopCloser(b) 234 | 235 | return b.Bytes(), nil 236 | } 237 | 238 | func (c *MMTLSClientShort) readServerHello() error { 239 | serverHelloRecord, err := readRecord(c.packetReader) 240 | if err != nil { 241 | return err 242 | } 243 | 244 | c.handshakeHasher.Write(serverHelloRecord.data) 245 | c.serverSeqNum++ 246 | 247 | return nil 248 | } 249 | 250 | func (c *MMTLSClientShort) readServerFinish() error { 251 | record, err := readRecord(c.packetReader) 252 | if err != nil { 253 | return err 254 | } 255 | 256 | if err := record.decrypt(c.Session.appKey, c.serverSeqNum); err != nil { 257 | 
return err 258 | } 259 | 260 | // TODO: verify server finished 261 | c.serverSeqNum++ 262 | 263 | return nil 264 | } 265 | 266 | func (c *MMTLSClientShort) readDataRecord() (*mmtlsRecord, error) { 267 | record, err := readRecord(c.packetReader) 268 | if err != nil { 269 | return nil, err 270 | } 271 | 272 | if err := record.decrypt(c.Session.appKey, c.serverSeqNum); err != nil { 273 | return nil, err 274 | } 275 | 276 | c.serverSeqNum++ 277 | 278 | return record, nil 279 | } 280 | 281 | func (c *MMTLSClientShort) readAbort() error { 282 | record, err := readRecord(c.packetReader) 283 | if err != nil { 284 | return err 285 | } 286 | 287 | if err := record.decrypt(c.Session.appKey, c.serverSeqNum); err != nil { 288 | return err 289 | } 290 | 291 | c.serverSeqNum++ 292 | 293 | return nil 294 | } 295 | 296 | func (c *MMTLSClientShort) earlyDataKey(pskAccess []byte, ticket *sessionTicket) (*trafficKeyPair, error) { 297 | trafficKey := make([]byte, 28) 298 | 299 | if _, err := hkdf.Expand(sha256.New, pskAccess, 300 | c.hkdfExpand("early data key expansion", c.handshakeHasher)). 301 | Read(trafficKey); err != nil { 302 | return nil, err 303 | } 304 | 305 | // early data key expansion 306 | pair := &trafficKeyPair{} 307 | pair.clientKey = trafficKey[:16] 308 | pair.clientNonce = trafficKey[16:] 309 | 310 | return pair, nil 311 | } 312 | 313 | func (c *MMTLSClientShort) computeTrafficKey(shareKey, info []byte) (*trafficKeyPair, error) { 314 | trafficKey := make([]byte, 28) 315 | 316 | if _, err := hkdf.Expand(sha256.New, shareKey, 317 | c.hkdfExpand("handshake key expansion", c.handshakeHasher)). 
318 | Read(trafficKey); err != nil { 319 | return nil, err 320 | } 321 | 322 | // handshake key expansion 323 | pair := &trafficKeyPair{} 324 | pair.serverKey = trafficKey[:16] 325 | pair.serverNonce = trafficKey[16:] 326 | 327 | return pair, nil 328 | } 329 | 330 | func (c *MMTLSClientShort) hkdfExpand(prefix string, hash hash.Hash) []byte { 331 | info := []byte(prefix) 332 | if hash != nil { 333 | info = append(info, hash.Sum(nil)...) 334 | } 335 | return info 336 | } 337 | 338 | func (c *MMTLSClientShort) hmac(k []byte, d []byte) []byte { 339 | hm := hmac.New(sha256.New, k) 340 | hm.Write(d) 341 | return hm.Sum(nil) 342 | } 343 | -------------------------------------------------------------------------------- /mmtls/record.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "encoding/binary" 7 | "encoding/hex" 8 | "io" 9 | 10 | log "github.com/sirupsen/logrus" 11 | ) 12 | 13 | type dataRecord struct { 14 | dataType uint32 15 | seq uint32 16 | data []byte 17 | } 18 | 19 | type mmtlsRecord struct { 20 | recordType uint8 21 | version uint16 22 | length uint16 23 | data []byte 24 | } 25 | 26 | func (d *dataRecord) serialize() []byte { 27 | length := uint32(len(d.data) + 16) 28 | buf := make([]byte, length) 29 | 30 | binary.BigEndian.PutUint32(buf, length) 31 | binary.BigEndian.PutUint16(buf[4:], 0x10) 32 | binary.BigEndian.PutUint16(buf[6:], 0x1) 33 | binary.BigEndian.PutUint32(buf[8:], d.dataType) 34 | binary.BigEndian.PutUint32(buf[12:], d.seq) 35 | 36 | if length > 16 { 37 | copy(buf[16:], d.data) 38 | } 39 | 40 | return buf 41 | } 42 | 43 | func createAbortRecord(data []byte) *mmtlsRecord { 44 | return createRecord(MagicAbort, data) 45 | } 46 | 47 | func createHandshakeRecord(data []byte) *mmtlsRecord { 48 | return createRecord(MagicHandshake, data) 49 | } 50 | 51 | func createDataRecord(dataType uint32, seq uint32, data []byte) *mmtlsRecord { 52 | r := 
&dataRecord{ 53 | dataType: dataType, 54 | seq: seq, 55 | data: data, 56 | } 57 | return createRecord(MagicRecord, r.serialize()) 58 | } 59 | 60 | func createRawDataRecord(data []byte) *mmtlsRecord { 61 | return createRecord(MagicRecord, data) 62 | } 63 | 64 | func createSystemRecord(data []byte) *mmtlsRecord { 65 | return createRecord(MagicSystem, data) 66 | } 67 | 68 | func createRecord(recordType uint8, data []byte) *mmtlsRecord { 69 | return &mmtlsRecord{ 70 | recordType: recordType, 71 | version: ProtocolVersion, 72 | length: uint16(len(data)), 73 | data: data, 74 | } 75 | } 76 | 77 | func readRecord(buf io.Reader) (*mmtlsRecord, error) { 78 | r := &mmtlsRecord{} 79 | 80 | if err := binary.Read(buf, binary.BigEndian, &r.recordType); err != nil { 81 | return nil, err 82 | } 83 | if err := binary.Read(buf, binary.BigEndian, &r.version); err != nil { 84 | return nil, err 85 | } 86 | if err := binary.Read(buf, binary.BigEndian, &r.length); err != nil { 87 | return nil, err 88 | } 89 | r.data = make([]byte, r.length) 90 | if _, err := buf.Read(r.data); err != nil { 91 | return nil, err 92 | } 93 | 94 | return r, nil 95 | } 96 | 97 | func (r *mmtlsRecord) serialize() []byte { 98 | buf := make([]byte, r.length+5) 99 | 100 | buf[0] = r.recordType 101 | binary.BigEndian.PutUint16(buf[1:], r.version) 102 | binary.BigEndian.PutUint16(buf[3:], r.length) 103 | copy(buf[5:], r.data) 104 | 105 | return buf 106 | } 107 | 108 | func (r *mmtlsRecord) encrypt(keys *trafficKeyPair, clientSeqNum uint32) error { 109 | c, err := aes.NewCipher(keys.clientKey) 110 | if err != nil { 111 | return err 112 | } 113 | aead, err := cipher.NewGCM(c) 114 | if err != nil { 115 | return err 116 | } 117 | 118 | nonce := make([]byte, 12) 119 | copy(nonce, keys.clientNonce) 120 | xorNonce(nonce, clientSeqNum) 121 | 122 | auddit := make([]byte, 13) 123 | binary.BigEndian.PutUint64(auddit, uint64(clientSeqNum)) 124 | auddit[8] = r.recordType 125 | binary.BigEndian.PutUint16(auddit[9:], r.version) 126 
| // GCM add 16-byte tag 127 | binary.BigEndian.PutUint16(auddit[11:], r.length+16) 128 | 129 | dst := aead.Seal(nil, nonce, r.data, auddit) 130 | 131 | log.Debugf("Encrypt(%d/%d):\n%s\n", len(r.data), len(dst), hex.Dump(dst)) 132 | 133 | r.data = dst 134 | r.length = uint16(len(dst)) 135 | 136 | return nil 137 | } 138 | 139 | func (r *mmtlsRecord) decrypt(keys *trafficKeyPair, serverSeqNum uint32) error { 140 | c, err := aes.NewCipher(keys.serverKey) 141 | if err != nil { 142 | return err 143 | } 144 | aead, err := cipher.NewGCM(c) 145 | if err != nil { 146 | return err 147 | } 148 | 149 | nonce := make([]byte, 12) 150 | copy(nonce, keys.serverNonce) 151 | xorNonce(nonce, serverSeqNum) 152 | auddit := make([]byte, 13) 153 | binary.BigEndian.PutUint64(auddit, uint64(serverSeqNum)) 154 | auddit[8] = r.recordType 155 | binary.BigEndian.PutUint16(auddit[9:], r.version) 156 | binary.BigEndian.PutUint16(auddit[11:], r.length) 157 | 158 | dst, err := aead.Open(nil, nonce, r.data, auddit) 159 | if err != nil { 160 | return err 161 | } 162 | 163 | log.Debugf("Decrypt:\n%s\n", hex.Dump(dst)) 164 | 165 | r.data = dst 166 | r.length = uint16(len(dst)) 167 | 168 | return nil 169 | } 170 | -------------------------------------------------------------------------------- /mmtls/server_finish.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "io" 7 | ) 8 | 9 | type serverFinish struct { 10 | reversed byte 11 | data []byte 12 | } 13 | 14 | func ReadServerFinish(buf []byte) (*serverFinish, error) { 15 | r := bytes.NewReader(buf) 16 | 17 | s := &serverFinish{} 18 | 19 | // package length 20 | if _, err := r.Seek(4, io.SeekCurrent); err != nil { 21 | return nil, err 22 | } 23 | 24 | // static reversed 25 | s.reversed, _ = r.ReadByte() 26 | 27 | var length uint16 28 | if err := binary.Read(r, binary.BigEndian, &length); err != nil { 29 | return nil, err 30 | } 31 | 32 | s.data = 
make([]byte, length) 33 | if _, err := r.Read(s.data); err != nil { 34 | return nil, err 35 | } 36 | 37 | return s, nil 38 | } 39 | -------------------------------------------------------------------------------- /mmtls/server_hello.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bytes" 5 | "crypto/ecdsa" 6 | "crypto/elliptic" 7 | "encoding/binary" 8 | "errors" 9 | "io" 10 | ) 11 | 12 | type serverHello struct { 13 | protocolVersion uint16 14 | cipherSuites uint16 15 | publicKey *ecdsa.PublicKey 16 | } 17 | 18 | func readServerHello(buf []byte) (*serverHello, error) { 19 | r := bytes.NewReader(buf) 20 | 21 | hello := &serverHello{} 22 | 23 | var packLen uint32 24 | if err := binary.Read(r, binary.BigEndian, &packLen); err != nil { 25 | return nil, err 26 | } 27 | 28 | if len(buf) != int(packLen)+4 { 29 | return nil, errors.New("data corrupted") 30 | } 31 | 32 | // skip flag 33 | if _, err := r.Seek(1, io.SeekCurrent); err != nil { 34 | return nil, err 35 | } 36 | 37 | if err := binary.Read(r, binary.BigEndian, &hello.protocolVersion); err != nil { 38 | return nil, err 39 | } 40 | 41 | if err := binary.Read(r, binary.BigEndian, &hello.cipherSuites); err != nil { 42 | return nil, err 43 | } 44 | 45 | // skip server random 46 | if _, err := r.Seek(32, io.SeekCurrent); err != nil { 47 | return nil, err 48 | } 49 | 50 | // skip exntensions package length 51 | if _, err := r.Seek(4, io.SeekCurrent); err != nil { 52 | return nil, err 53 | } 54 | 55 | // skip extensions count 56 | if _, err := r.Seek(1, io.SeekCurrent); err != nil { 57 | return nil, err 58 | } 59 | 60 | // skip extension package length 61 | if _, err := r.Seek(4, io.SeekCurrent); err != nil { 62 | return nil, err 63 | } 64 | 65 | // skip extension type 66 | if _, err := r.Seek(2, io.SeekCurrent); err != nil { 67 | return nil, err 68 | } 69 | 70 | // skip extension array index 71 | if _, err := r.Seek(4, io.SeekCurrent); err != nil 
{ 72 | return nil, err 73 | } 74 | 75 | var keyLen uint16 76 | if err := binary.Read(r, binary.BigEndian, &keyLen); err != nil { 77 | return nil, err 78 | } 79 | 80 | ecPoint := make([]byte, keyLen) 81 | if _, err := r.Read(ecPoint); err != nil { 82 | return nil, err 83 | } 84 | 85 | x, y := elliptic.Unmarshal(curve, ecPoint) 86 | 87 | hello.publicKey = &ecdsa.PublicKey{ 88 | Curve: curve, 89 | X: x, 90 | Y: y, 91 | } 92 | 93 | return hello, nil 94 | } 95 | -------------------------------------------------------------------------------- /mmtls/session.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bytes" 5 | "io/ioutil" 6 | ) 7 | 8 | type trafficKeyPair struct { 9 | clientKey []byte 10 | serverKey []byte 11 | clientNonce []byte 12 | serverNonce []byte 13 | } 14 | 15 | type Session struct { 16 | tk *newSessionTicket 17 | pskAccess []byte 18 | pskRefresh []byte 19 | appKey *trafficKeyPair 20 | } 21 | 22 | func (s *Session) Save(path string) error { 23 | buf := &bytes.Buffer{} 24 | 25 | if err := writeU16LenData(buf, s.pskAccess); err != nil { 26 | return err 27 | } 28 | if err := writeU16LenData(buf, s.pskRefresh); err != nil { 29 | return err 30 | } 31 | 32 | ticketBytes, err := s.tk.serialize() 33 | if err != nil { 34 | return err 35 | } 36 | buf.Write(ticketBytes) 37 | 38 | return ioutil.WriteFile(path, buf.Bytes(), 0644) 39 | } 40 | 41 | func LoadSession(path string) (*Session, error) { 42 | sessionBytes, err := ioutil.ReadFile(path) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | r := bytes.NewReader(sessionBytes) 48 | pskAccess, err := readU16LenData(r) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | pskRefresh, err := readU16LenData(r) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | ticketBytes, err := ioutil.ReadAll(r) 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | ticket, err := readNewSessionTicket(ticketBytes) 64 | if err != nil { 65 | 
return nil, err 66 | } 67 | 68 | return &Session{ 69 | pskAccess: pskAccess, 70 | pskRefresh: pskRefresh, 71 | tk: ticket, 72 | }, nil 73 | } 74 | -------------------------------------------------------------------------------- /mmtls/session_ticket.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | ) 7 | 8 | type sessionTicket struct { 9 | ticketType byte // reversed unknown 10 | ticketLifeTime uint32 11 | ticketAgeAdd []byte 12 | reversed uint32 // always 0x48 13 | nonce []byte // 12 bytes nonce 14 | ticket []byte 15 | } 16 | 17 | type newSessionTicket struct { 18 | reversed byte 19 | count byte 20 | tickets []sessionTicket 21 | } 22 | 23 | func readNewSessionTicket(buf []byte) (*newSessionTicket, error) { 24 | r := bytes.NewReader(buf) 25 | 26 | t := &newSessionTicket{} 27 | 28 | var length uint32 29 | if err := binary.Read(r, binary.BigEndian, &length); err != nil { 30 | return nil, err 31 | } 32 | 33 | t.reversed, _ = r.ReadByte() 34 | t.count, _ = r.ReadByte() 35 | 36 | for i := byte(0); i < t.count; i++ { 37 | if err := binary.Read(r, binary.BigEndian, &length); err != nil { 38 | return nil, err 39 | } 40 | data := make([]byte, length) 41 | if _, err := r.Read(data); err != nil { 42 | return nil, err 43 | } 44 | 45 | ticket, err := readSessionTicket(data) 46 | if err != nil { 47 | return nil, err 48 | } 49 | t.tickets = append(t.tickets, *ticket) 50 | } 51 | 52 | return t, nil 53 | } 54 | 55 | func readSessionTicket(buf []byte) (*sessionTicket, error) { 56 | r := bytes.NewReader(buf) 57 | 58 | t := &sessionTicket{} 59 | 60 | t.ticketType, _ = r.ReadByte() 61 | 62 | if err := binary.Read(r, binary.BigEndian, &t.ticketLifeTime); err != nil { 63 | return nil, err 64 | } 65 | 66 | var err error 67 | t.ticketAgeAdd, err = readU16LenData(r) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | if err := binary.Read(r, binary.BigEndian, &t.reversed); err 
!= nil { 73 | return nil, err 74 | } 75 | 76 | t.nonce, err = readU16LenData(r) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | t.ticket, err = readU16LenData(r) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | return t, nil 87 | } 88 | 89 | func (t *sessionTicket) serialize() ([]byte, error) { 90 | buf := &bytes.Buffer{} 91 | 92 | if err := buf.WriteByte(t.ticketType); err != nil { 93 | return nil, err 94 | } 95 | 96 | if err := binary.Write(buf, binary.BigEndian, t.ticketLifeTime); err != nil { 97 | return nil, err 98 | } 99 | 100 | if err := writeU16LenData(buf, t.ticketAgeAdd); err != nil { 101 | return nil, err 102 | } 103 | 104 | if err := binary.Write(buf, binary.BigEndian, t.reversed); err != nil { 105 | return nil, err 106 | } 107 | 108 | if err := writeU16LenData(buf, t.nonce); err != nil { 109 | return nil, err 110 | } 111 | 112 | if err := writeU16LenData(buf, t.ticket); err != nil { 113 | return nil, err 114 | } 115 | 116 | return buf.Bytes(), nil 117 | } 118 | 119 | func (t *newSessionTicket) serialize() ([]byte, error) { 120 | buf := &bytes.Buffer{} 121 | 122 | if _, err := buf.Write([]byte{0x00, 0x00, 0x00, 0x00}); err != nil { 123 | return nil, err 124 | } 125 | if err := buf.WriteByte(0x04); err != nil { 126 | return nil, err 127 | } 128 | if err := buf.WriteByte(byte(len(t.tickets))); err != nil { 129 | return nil, err 130 | } 131 | 132 | for _, v := range t.tickets { 133 | vBytes, err := v.serialize() 134 | if err != nil { 135 | return nil, err 136 | } 137 | writeU32LenData(buf, vBytes) 138 | } 139 | 140 | data := buf.Bytes() 141 | binary.BigEndian.PutUint32(data, uint32(len(data)-4)) 142 | return data, nil 143 | } 144 | 145 | func (t *newSessionTicket) export() ([]byte, error) { 146 | earlyDataBuf := &bytes.Buffer{} 147 | 148 | data, err := t.tickets[0].serialize() 149 | if err != nil { 150 | return nil, err 151 | } 152 | if err := writeU32LenData(earlyDataBuf, data); err != nil { 153 | return nil, err 154 | } 155 | 156 | 
return earlyDataBuf.Bytes(), nil 157 | } 158 | -------------------------------------------------------------------------------- /mmtls/signature.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "io" 7 | ) 8 | 9 | type signature struct { 10 | Type byte 11 | EcdsaSignature []byte 12 | } 13 | 14 | func readSignature(buf []byte) (*signature, error) { 15 | r := bytes.NewReader(buf) 16 | 17 | s := &signature{} 18 | 19 | // skip package length 20 | if _, err := r.Seek(4, io.SeekCurrent); err != nil { 21 | return nil, err 22 | } 23 | 24 | // static 0x0f 25 | s.Type, _ = r.ReadByte() 26 | 27 | var length uint16 28 | if err := binary.Read(r, binary.BigEndian, &length); err != nil { 29 | return nil, err 30 | } 31 | 32 | s.EcdsaSignature = make([]byte, length) 33 | if _, err := r.Read(s.EcdsaSignature); err != nil { 34 | return nil, err 35 | } 36 | 37 | return s, nil 38 | } 39 | -------------------------------------------------------------------------------- /mmtls/utility.go: -------------------------------------------------------------------------------- 1 | package mmtls 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/binary" 6 | "io" 7 | ) 8 | 9 | func getRandom(n int) []byte { 10 | key := make([]byte, n) 11 | rand.Read(key) 12 | return key 13 | } 14 | 15 | func xorNonce(nonce []byte, seq uint32) { 16 | seqBytes := make([]byte, 4) 17 | binary.LittleEndian.PutUint32(seqBytes, seq) 18 | 19 | for i := 0; i < 4; i++ { 20 | pos := len(nonce) - i - 1 21 | nonce[pos] = nonce[pos] ^ seqBytes[i] 22 | } 23 | } 24 | 25 | func readU16LenData(r io.Reader) ([]byte, error) { 26 | var length uint16 27 | if err := binary.Read(r, binary.BigEndian, &length); err != nil { 28 | return nil, err 29 | } 30 | 31 | if length > 0 { 32 | b := make([]byte, length) 33 | if _, err := r.Read(b); err != nil { 34 | return nil, err 35 | } 36 | return b, nil 37 | } 38 | return nil, nil 39 | } 40 | 41 
// lenPrefixError is returned when a payload is too large to be described by
// the fixed-width length prefix that precedes it on the wire.
type lenPrefixError string

// Error implements the error interface.
func (e lenPrefixError) Error() string { return string(e) }

// Sentinel errors for oversize payloads; comparable with errors.Is or ==.
const (
	errU16LenOverflow lenPrefixError = "mmtls: data too long for 16-bit length prefix"
	errU32LenOverflow lenPrefixError = "mmtls: data too long for 32-bit length prefix"
)

// writeU32LenData writes d to w preceded by its length as a big-endian
// uint32. An empty or nil d produces only the four-byte zero prefix.
//
// It returns errU32LenOverflow, without writing anything, when len(d) does
// not fit in 32 bits; otherwise it returns the first error reported by w.
func writeU32LenData(w io.Writer, d []byte) error {
	// Compare in uint64 so the constant also compiles on 32-bit platforms.
	if uint64(len(d)) > 0xFFFFFFFF {
		return errU32LenOverflow
	}

	var hdr [4]byte
	binary.BigEndian.PutUint32(hdr[:], uint32(len(d)))
	if _, err := w.Write(hdr[:]); err != nil {
		return err
	}
	if len(d) == 0 {
		return nil
	}
	_, err := w.Write(d)
	return err
}

// writeU16LenData writes d to w preceded by its length as a big-endian
// uint16. An empty or nil d produces only the two-byte zero prefix.
//
// It returns errU16LenOverflow, without writing anything, when len(d)
// exceeds 65535. Previously the length was silently truncated, which emitted
// a prefix that disagreed with the payload that followed it.
func writeU16LenData(w io.Writer, d []byte) error {
	if len(d) > 0xFFFF {
		return errU16LenOverflow
	}

	var hdr [2]byte
	binary.BigEndian.PutUint16(hdr[:], uint16(len(d)))
	if _, err := w.Write(hdr[:]); err != nil {
		return err
	}
	if len(d) == 0 {
		return nil
	}
	_, err := w.Write(d)
	return err
}