├── .github ├── CODEOWNERS ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── make └── proto.mk └── protocol ├── Makefile ├── address └── address.go ├── crypto └── aes │ ├── aes.go │ └── aes_test.go ├── curve ├── curve25519 │ ├── curve25519.go │ ├── curve25519_test.go │ └── util.go ├── djb_private.go ├── djb_public.go ├── doc.go ├── keytype.go ├── pair.go ├── pair_test.go ├── private.go ├── public.go └── public_test.go ├── direction └── direction.go ├── distribution └── id.go ├── fingerprint ├── displayable.go ├── encode.go ├── fingerprint.go ├── fingerprint_test.go ├── scannable.go └── scannable_test.go ├── generated └── v1 │ ├── fingerprint.pb.go │ ├── sealed_sender.pb.go │ ├── storage.pb.go │ └── wire.pb.go ├── go.mod ├── go.sum ├── identity ├── keys.go ├── keys_test.go ├── store.go └── store_inmem.go ├── internal ├── pointer │ └── pointer.go └── tools │ └── tools.go ├── message ├── ciphertext.go ├── ciphertexttype_string.go ├── mac.go ├── plaintext.go ├── prekey.go ├── prekey_test.go ├── senderkey.go ├── senderkey_test.go ├── signal.go └── signal_test.go ├── perrors └── errors.go ├── prekey ├── bundle.go ├── prekey.go ├── signed.go ├── store.go └── store_inmem.go ├── proto └── v1 │ ├── fingerprint.proto │ ├── sealed_sender.proto │ ├── storage.proto │ └── wire.proto ├── protocol ├── store.go └── store_inmem.go ├── ratchet ├── keys.go ├── keys_test.go ├── params.go └── ratchet.go ├── senderkey └── keys.go ├── session ├── cipher.go ├── group_cipher.go ├── group_record.go ├── group_record_test.go ├── group_session.go ├── group_state.go ├── record.go ├── session.go ├── state.go ├── store.go └── store_inmem.go └── tests ├── group_session_test.go ├── ratchet_test.go ├── session_test.go └── util_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | @RTann 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: 'github-actions' 4 | directory: '/' 5 | schedule: 6 | interval: 'weekly' 7 | day: 'wednesday' 8 | open-pull-requests-limit: 3 9 | reviewers: 10 | - 'RTann' 11 | - package-ecosystem: 'gomod' 12 | directory: '/protocol/' 13 | schedule: 14 | interval: 'weekly' 15 | day: 'wednesday' 16 | open-pull-requests-limit: 3 17 | reviewers: 18 | - 'RTann' 19 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Protocol CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | types: [opened, synchronize] 8 | 9 | jobs: 10 | test: 11 | name: Test 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 1 18 | # Check out before setup-go so its module cache can locate protocol/go.sum. 19 | - uses: actions/setup-go@v5 20 | with: 21 | go-version: ^1.20 22 | cache-dependency-path: protocol/go.sum 23 | 24 | - name: Style checks 25 | run: | 26 | make -C protocol style 27 | if ! git diff --exit-code HEAD; then 28 | echo 29 | echo "*** Files are not formatted properly. See the above diff for more info."
30 | exit 1 31 | fi 32 | - name: Unit tests 33 | run: | 34 | make -C protocol unit-tests 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | # Editor directories 18 | .idea/ 19 | 20 | # Tool directories 21 | .proto 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # libsignal-go 2 | 3 | A pure Go implementation of https://github.com/signalapp/libsignal. 4 | 5 | This repository is meant to be broken down into different 6 | APIs similar to the source repository. 7 | 8 | Only libsignal-protocol is implemented at this time. 9 | 10 | This repository is still under development, so 11 | breaking changes should be expected. 12 | 13 | ## Roadmap 14 | 15 | * Add and refactor tests 16 | * Implement other APIs available in the source repository.
17 | -------------------------------------------------------------------------------- /make/proto.mk: -------------------------------------------------------------------------------- 1 | BASE_DIR ?= $(CURDIR) 2 | SILENT ?= @ 3 | 4 | # Protocol Buffers 5 | 6 | PROTOC_VERSION := 26.1 7 | 8 | PROTOC_DIR := $(BASE_DIR)/.proto 9 | $(PROTOC_DIR): 10 | $(SILENT)mkdir -p "$@" 11 | 12 | UNAME_S := $(shell uname -s) 13 | ifeq ($(UNAME_S),Linux) 14 | PROTOC_OS = linux 15 | endif 16 | ifeq ($(UNAME_S),Darwin) 17 | PROTOC_OS = osx 18 | endif 19 | # protoc release archives call 64-bit ARM "aarch_64"; macOS reports it as arm64, Linux as aarch64. 20 | PROTOC_ARCH=$(shell case $$(uname -m) in (arm64|aarch64) echo aarch_64 ;; (*) uname -m ;; esac) 21 | 22 | DOWNLOAD_DIR := $(PROTOC_DIR)/.downloads 23 | $(DOWNLOAD_DIR): 24 | $(SILENT)mkdir -p "$@" 25 | 26 | PROTOC_ZIP := protoc-$(PROTOC_VERSION)-$(PROTOC_OS)-$(PROTOC_ARCH).zip 27 | PROTOC_FILE := $(DOWNLOAD_DIR)/$(PROTOC_ZIP) 28 | 29 | # Directories are order-only prerequisites (after |): they must exist, but their 30 | # changing mtimes must not force the targets below to be rebuilt. 31 | .PRECIOUS: $(PROTOC_FILE) 32 | $(PROTOC_FILE): | $(DOWNLOAD_DIR) 33 | curl --output-dir $(DOWNLOAD_DIR) -fLO "https://github.com/protocolbuffers/protobuf/releases/download/v$(PROTOC_VERSION)/$(PROTOC_ZIP)" 34 | 35 | PROTO_BIN := $(PROTOC_DIR)/bin 36 | $(PROTO_BIN): 37 | $(SILENT)mkdir -p "$@" 38 | 39 | PROTOC := $(PROTO_BIN)/protoc 40 | $(PROTOC): $(PROTOC_FILE) 41 | $(SILENT)unzip -q -o -d "$(PROTOC_DIR)" "$(PROTOC_FILE)" 42 | $(SILENT)test -x "$@" 43 | 44 | PROTOC_GEN_GO_BIN := $(PROTO_BIN)/protoc-gen-go 45 | $(PROTOC_GEN_GO_BIN): | $(PROTO_BIN) 46 | GOBIN=$(PROTO_BIN) go install google.golang.org/protobuf/cmd/protoc-gen-go 47 | 48 | .PHONY: proto-install 49 | proto-install: $(PROTOC) $(PROTOC_GEN_GO_BIN) 50 | -------------------------------------------------------------------------------- /protocol/Makefile: -------------------------------------------------------------------------------- 1 | BASE_DIR := $(CURDIR)/..
2 | 3 | # Code generation 4 | 5 | include ../make/proto.mk 6 | 7 | .PHONY: proto-gen 8 | proto-gen: proto-install 9 | mkdir -p generated/v1 10 | PATH=$(PROTO_BIN) && $(PROTOC) -I=proto/v1/ --go_out=generated/ proto/v1/* 11 | 12 | .PHONY: go-gen 13 | go-gen: 14 | go generate ./... 15 | 16 | # Tests 17 | 18 | .PHONY: unit-tests 19 | unit-tests: proto-gen 20 | go test -race -v ./... 21 | 22 | # Style 23 | 24 | .PHONY: style 25 | style: 26 | go fmt ./... 27 | -------------------------------------------------------------------------------- /protocol/address/address.go: -------------------------------------------------------------------------------- 1 | // Package address defines the structure of a device address. 2 | package address 3 | 4 | import "fmt" 5 | 6 | // DeviceID represents a unique device identifier. 7 | type DeviceID uint32 8 | 9 | // Address represents a unique address used by the protocol. 10 | type Address struct { 11 | Name string 12 | DeviceID DeviceID 13 | } 14 | 15 | // String returns the address formatted as "<Name>.<DeviceID>". 16 | func (a Address) String() string { 17 | return fmt.Sprintf("%s.%d", a.Name, a.DeviceID) 18 | } 19 | -------------------------------------------------------------------------------- /protocol/crypto/aes/aes.go: -------------------------------------------------------------------------------- 1 | // Package aes implements AES functions used by the protocol. 2 | package aes 3 | 4 | import ( 5 | "bytes" 6 | "crypto/aes" 7 | "crypto/cipher" 8 | "errors" 9 | ) 10 | 11 | // CBCEncrypt PKCS #7-pads plaintext and encrypts it with AES in cipher 12 | // block chaining mode. The caller's plaintext slice is never modified. 13 | func CBCEncrypt(key, iv, plaintext []byte) ([]byte, error) { 14 | block, err := aes.NewCipher(key) 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | // cipher.NewCBCEncrypter panics if this does not hold, 20 | // so it's best to check here.
21 | if len(iv) != aes.BlockSize { 22 | return nil, errors.New("IV length must equal block size") 23 | } 24 | 25 | aescbc := cipher.NewCBCEncrypter(block, iv) 26 | 27 | // Plaintext must be padded to the next whole block. 28 | plaintext = pkcs7pad(plaintext) 29 | ciphertext := make([]byte, len(plaintext)) 30 | aescbc.CryptBlocks(ciphertext, plaintext) 31 | 32 | return ciphertext, nil 33 | } 34 | 35 | // CBCDecrypt decrypts ciphertext with AES in cipher block chaining mode 36 | // and strips its PKCS #7 padding, returning an error if the padding is malformed. 37 | func CBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) { 38 | block, err := aes.NewCipher(key) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | if len(ciphertext) == 0 || len(ciphertext)%aes.BlockSize != 0 { 44 | return nil, errors.New("ciphertext length must be a non-zero multiple of the block size") 45 | } 46 | 47 | // cipher.NewCBCDecrypter panics if this does not hold, 48 | // so it's best to check here. 49 | if len(iv) != aes.BlockSize { 50 | return nil, errors.New("IV length must equal block size") 51 | } 52 | 53 | aescbc := cipher.NewCBCDecrypter(block, iv) 54 | 55 | plaintext := make([]byte, len(ciphertext)) 56 | aescbc.CryptBlocks(plaintext, ciphertext) 57 | 58 | return pkcs7unpad(plaintext) 59 | } 60 | 61 | // pkcs7pad implements PKCS #7 padding rules. 62 | // 63 | // See https://www.rfc-editor.org/rfc/rfc2315 Section 10.3 and 64 | // https://www.ibm.com/docs/en/zos/2.4.0?topic=rules-pkcs-padding-method 65 | // for more information. 66 | func pkcs7pad(plaintext []byte) []byte { 67 | n := aes.BlockSize - (len(plaintext) % aes.BlockSize) 68 | 69 | // Copy into a freshly allocated slice so the padding is never written 70 | // into spare capacity of the caller's backing array. 71 | padded := make([]byte, len(plaintext), len(plaintext)+n) 72 | copy(padded, plaintext) 73 | 74 | return append(padded, bytes.Repeat([]byte{byte(n)}, n)...) 75 | } 76 | 77 | // pkcs7unpad removes PKCS #7 padding from plaintext. 
78 | // 79 | // It returns an error if plaintext is empty or the padding is malformed: 80 | // the final byte n must be in [1, aes.BlockSize] and no larger than the 81 | // input, and each of the last n bytes must equal n.
82 | func pkcs7unpad(plaintext []byte) ([]byte, error) { 83 | length := len(plaintext) 84 | if length == 0 { 85 | return nil, errors.New("invalid padding") 86 | } 87 | 88 | n := int(plaintext[length-1]) 89 | if n < 1 || n > aes.BlockSize || n > length { 90 | return nil, errors.New("invalid padding") 91 | } 92 | 93 | // Check every padding byte, not just the last one. 94 | for _, b := range plaintext[length-n:] { 95 | if int(b) != n { 96 | return nil, errors.New("invalid padding") 97 | } 98 | } 99 | 100 | return plaintext[:length-n], nil 101 | } 102 | -------------------------------------------------------------------------------- /protocol/crypto/aes/aes_test.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "encoding/hex" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestPKCS7Pad(t *testing.T) { 12 | testcases := []struct { 13 | plaintext string 14 | padding []byte 15 | }{ 16 | { 17 | plaintext: "", 18 | padding: []byte{ 19 | 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 20 | 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 21 | }, 22 | }, 23 | { 24 | plaintext: "H", 25 | padding: []byte{ 26 | 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 27 | 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 28 | }, 29 | }, 30 | { 31 | plaintext: "He", 32 | padding: []byte{ 33 | 0x0E, 0x0E, 0x0E, 0x0E, 0x0E, 0x0E, 0x0E, 34 | 0x0E, 0x0E, 0x0E, 0x0E, 0x0E, 0x0E, 0x0E, 35 | }, 36 | }, 37 | { 38 | plaintext: "Hel", 39 | padding: []byte{ 40 | 0x0D, 0x0D, 0x0D, 0x0D, 0x0D, 0x0D, 0x0D, 41 | 0x0D, 0x0D, 0x0D, 0x0D, 0x0D, 0x0D, 42 | }, 43 | }, 44 | { 45 | plaintext: "Hell", 46 | padding: []byte{ 47 | 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 48 | 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 49 | }, 50 | }, 51 | { 52 | plaintext: "Hello", 53 | padding: []byte{ 54 | 0x0B, 0x0B, 0x0B, 0x0B, 0x0B, 0x0B, 55 | 0x0B, 0x0B, 0x0B, 0x0B, 0x0B, 56 | }, 57 | }, 58 | { 59 | plaintext: "Hello,", 60 | padding: []byte{ 61 | 0x0A, 0x0A, 0x0A, 0x0A, 0x0A, 62 | 0x0A, 0x0A, 0x0A, 0x0A, 0x0A, 63 | }, 64 | }, 65 | { 66 | plaintext: "Hello, ", 67 | padding: []byte{ 68 | 0x09, 0x09, 0x09, 0x09, 0x09, 69 | 0x09, 0x09, 0x09, 0x09, 70 | }, 71 | }, 72 | { 73 |
plaintext: "Hello, W", 74 | padding: []byte{ 75 | 0x08, 0x08, 0x08, 0x08, 76 | 0x08, 0x08, 0x08, 0x08, 77 | }, 78 | }, 79 | { 80 | plaintext: "Hello, Wo", 81 | padding: []byte{ 82 | 0x07, 0x07, 0x07, 0x07, 83 | 0x07, 0x07, 0x07, 84 | }, 85 | }, 86 | { 87 | plaintext: "Hello, Wor", 88 | padding: []byte{ 89 | 0x06, 0x06, 0x06, 90 | 0x06, 0x06, 0x06, 91 | }, 92 | }, 93 | { 94 | plaintext: "Hello, Worl", 95 | padding: []byte{ 96 | 0x05, 0x05, 0x05, 97 | 0x05, 0x05, 98 | }, 99 | }, 100 | { 101 | plaintext: "Hello, World", 102 | padding: []byte{ 103 | 0x04, 0x04, 104 | 0x04, 0x04, 105 | }, 106 | }, 107 | { 108 | plaintext: "Hello, World!", 109 | padding: []byte{ 110 | 0x03, 0x03, 111 | 0x03, 112 | }, 113 | }, 114 | { 115 | plaintext: "Hello, World! ", 116 | padding: []byte{ 117 | 0x02, 118 | 0x02, 119 | }, 120 | }, 121 | { 122 | plaintext: "Hello, World! :", 123 | padding: []byte{ 124 | 0x01, 125 | }, 126 | }, 127 | { 128 | plaintext: "Hello, World! :)", 129 | padding: []byte{ 130 | 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 131 | 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 132 | }, 133 | }, 134 | } 135 | 136 | for _, testcase := range testcases { 137 | t.Run(testcase.plaintext, func(t *testing.T) { 138 | padded := append([]byte(testcase.plaintext), testcase.padding...)
139 | assert.Equal(t, padded, pkcs7pad([]byte(testcase.plaintext))) 140 | unpadded, err := pkcs7unpad(padded) 141 | assert.NoError(t, err) 142 | assert.Equal(t, testcase.plaintext, string(unpadded)) 143 | }) 144 | } 145 | } 146 | 147 | func TestCBC(t *testing.T) { 148 | key, err := hex.DecodeString("4e22eb16d964779994222e82192ce9f747da72dc4abe49dfdeeb71d0ffe3796e") 149 | require.NoError(t, err) 150 | iv, err := hex.DecodeString("6f8a557ddc0a140c878063a6d5f31d3d") 151 | require.NoError(t, err) 152 | plaintext, err := hex.DecodeString("30736294a124482a4159") 153 | require.NoError(t, err) 154 | 155 | ciphertext, err := CBCEncrypt(key, iv, plaintext) 156 | assert.NoError(t, err) 157 | assert.Equal(t, "dd3f573ab4508b9ed0e45e0baf5608f3", hex.EncodeToString(ciphertext)) 158 | 159 | recovered, err := CBCDecrypt(key, iv, ciphertext) 160 | assert.NoError(t, err) 161 | assert.Equal(t, hex.EncodeToString(plaintext), hex.EncodeToString(recovered)) 162 | 163 | // Invalid padding 164 | _, err = CBCDecrypt(key, iv, recovered) 165 | assert.Error(t, err) 166 | _, err = CBCDecrypt(key, ciphertext, ciphertext) 167 | assert.Error(t, err) 168 | 169 | badIV, err := hex.DecodeString("ef8a557ddc0a140c878063a6d5f31d3d") 170 | require.NoError(t, err) 171 | 172 | recovered, err = CBCDecrypt(key, badIV, ciphertext) 173 | assert.NoError(t, err) 174 | assert.Equal(t, "b0736294a124482a4159", hex.EncodeToString(recovered)) 175 | assert.NotEqual(t, hex.EncodeToString(plaintext), hex.EncodeToString(recovered)) 176 | } 177 | 178 | func TestPKCS7Unpad_Invalid(t *testing.T) { 179 | testcases := map[string][]byte{ 180 | "empty": {}, 181 | "zero pad byte": append(make([]byte, 15), 0x00), 182 | "pad byte larger than block size": append(make([]byte, 15), 0x11), 183 | "inconsistent pad bytes": append(make([]byte, 14), 0x01, 0x02), 184 | } 185 | 186 | for name, plaintext := range testcases { 187 | t.Run(name, func(t *testing.T) { 188 | _, err := pkcs7unpad(plaintext) 189 | assert.Error(t, err) 190 | }) 191 | } 192 | } 193 | 194 | func TestPKCS7Pad_DoesNotAlias(t *testing.T) { 195 | buf := make([]byte, 32) 196 | padded := pkcs7pad(buf[:5]) 197 | assert.Len(t, padded, 16) 198 | // Padding must not be written into the spare capacity of buf. 199 | assert.Equal(t, make([]byte, 32), buf) 200 | } 201 | -------------------------------------------------------------------------------- /protocol/curve/curve25519/curve25519.go: -------------------------------------------------------------------------------- 1 | // Package curve25519 implements the XEd25519 signature scheme. 2 | // 3 | // See https://signal.org/docs/specifications/xeddsa/#curve25519 for more information.
4 | package curve25519 5 | 6 | import ( 7 | "crypto/ecdh" 8 | "crypto/ed25519" 9 | "crypto/rand" 10 | "crypto/sha512" 11 | "io" 12 | 13 | "filippo.io/edwards25519" 14 | 15 | "github.com/RTann/libsignal-go/protocol/perrors" 16 | ) 17 | 18 | const ( 19 | PrivateKeySize = 32 20 | PublicKeySize = 32 21 | SignatureSize = ed25519.SignatureSize 22 | randomSize = 64 23 | ) 24 | 25 | var ( 26 | hashPrefix = []byte{ 27 | 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 28 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 29 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 30 | } 31 | ) 32 | 33 | // PrivateKey represents a Montgomery private key used for the XEdDSA scheme. 34 | type PrivateKey struct { 35 | privateKey []byte 36 | publicKey []byte 37 | scalarKey *edwards25519.Scalar 38 | ecdhKey *ecdh.PrivateKey 39 | } 40 | 41 | // GeneratePrivateKey generates a random private key. 42 | // 43 | // It is recommended to use a cryptographic random reader. 44 | // If random is nil, then [crypto/rand.Reader] is used. 45 | func GeneratePrivateKey(random io.Reader) (*PrivateKey, error) { 46 | if random == nil { 47 | random = rand.Reader 48 | } 49 | 50 | key := make([]byte, PrivateKeySize) 51 | _, err := io.ReadFull(random, key) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return NewPrivateKey(key) 57 | } 58 | 59 | // NewPrivateKey creates a new private key based on the given input. 60 | func NewPrivateKey(key []byte) (*PrivateKey, error) { 61 | if len(key) != PrivateKeySize { 62 | return nil, perrors.ErrInvalidKeyLength(PrivateKeySize, len(key)) 63 | } 64 | 65 | scalarKey, err := edwards25519.NewScalar().SetBytesWithClamping(key) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | // No need to clamp here, as ECDH will take care of it.
71 | ecdhKey, err := ecdh.X25519().NewPrivateKey(key) 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | privateKey := make([]byte, PrivateKeySize) 77 | copy(privateKey, key) 78 | // Clamp the given private key. 79 | // See step 2 in https://www.rfc-editor.org/rfc/rfc8032#section-5.1.5. 80 | privateKey[0] &= 248 81 | privateKey[31] &= 63 82 | privateKey[31] |= 64 83 | 84 | return &PrivateKey{ 85 | privateKey: privateKey, 86 | publicKey: xtou(scalarKey), 87 | scalarKey: scalarKey, 88 | ecdhKey: ecdhKey, 89 | }, nil 90 | } 91 | 92 | // Bytes returns a copy of the private key. 93 | func (p *PrivateKey) Bytes() []byte { 94 | bytes := make([]byte, PrivateKeySize) 95 | copy(bytes, p.privateKey) 96 | return bytes 97 | } 98 | 99 | // PublicKeyBytes returns the public key in the form of a Montgomery u-point. 100 | func (p *PrivateKey) PublicKeyBytes() []byte { 101 | bytes := make([]byte, PublicKeySize) 102 | copy(bytes, p.publicKey) 103 | return bytes 104 | } 105 | 106 | // Agreement computes the ECDH shared key between the private key and 107 | // the given public key. 108 | func (p *PrivateKey) Agreement(key []byte) ([]byte, error) { 109 | if len(key) != PublicKeySize { 110 | return nil, perrors.ErrInvalidKeyLength(PublicKeySize, len(key)) 111 | } 112 | 113 | publicKey, err := ecdh.X25519().NewPublicKey(key) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | return p.ecdhKey.ECDH(publicKey) 119 | } 120 | 121 | // Sign calculates an XEdDSA signature using the X25519 private key, directly. 122 | // 123 | // The calculated signature is a valid ed25519 signature. 124 | // 125 | // It is recommended to use a cryptographic random reader. 126 | // If random is nil, then [crypto/rand.Reader] is used.
127 | func (p *PrivateKey) Sign(random io.Reader, messages ...[]byte) ([]byte, error) { 128 | if random == nil { 129 | random = rand.Reader 130 | } 131 | 132 | Z := make([]byte, randomSize) 133 | _, err := io.ReadFull(random, Z) 134 | if err != nil { 135 | return nil, err 136 | } 137 | 138 | a := p.scalarKey.Bytes() 139 | A := new(edwards25519.Point).ScalarBaseMult(p.scalarKey).Bytes() 140 | 141 | digest := make([]byte, 0, sha512.Size) 142 | hash := sha512.New() 143 | 144 | hash.Write(hashPrefix) 145 | hash.Write(a) 146 | for _, message := range messages { 147 | hash.Write(message) 148 | } 149 | hash.Write(Z) 150 | 151 | digest = hash.Sum(digest) 152 | r, err := edwards25519.NewScalar().SetUniformBytes(digest) 153 | if err != nil { 154 | return nil, err 155 | } 156 | R := new(edwards25519.Point).ScalarBaseMult(r).Bytes() 157 | 158 | digest = digest[:0] 159 | hash.Reset() 160 | 161 | hash.Write(R) 162 | hash.Write(A) 163 | for _, message := range messages { 164 | hash.Write(message) 165 | } 166 | 167 | digest = hash.Sum(digest) 168 | h, err := edwards25519.NewScalar().SetUniformBytes(digest) 169 | if err != nil { 170 | return nil, err 171 | } 172 | 173 | s := edwards25519.NewScalar().MultiplyAdd(p.scalarKey, h, r).Bytes() 174 | 175 | signBit := A[31] & 0b1000_0000 176 | 177 | signature := make([]byte, SignatureSize) 178 | copy(signature[:32], R) 179 | copy(signature[32:], s) 180 | signature[63] &= 0b0111_1111 181 | signature[63] |= signBit 182 | 183 | return signature, nil 184 | } 185 | 186 | // VerifySignature verifies the signature is a valid signature 187 | // for the messages by the public key. 188 | // 189 | // publicKey is expected to be a Montgomery u-coordinate. A signature 190 | // whose length is not SignatureSize is reported as invalid (false, nil) 191 | // rather than causing an out-of-range panic.
192 | func VerifySignature(publicKey []byte, signature []byte, messages ...[]byte) (bool, error) { 193 | if len(signature) != SignatureSize { 194 | return false, nil 195 | } 196 | 197 | y, err := utoy(publicKey, (signature[63]&0b1000_0000) == 0) 198 | if err != nil { 199 | return false, err 200 | } 201 | 202 | // Sign stores the Edwards public key's sign bit in the top bit of the 203 | // signature, which was used above to recover that key; clear it so the 204 | // remaining bytes form a standard Ed25519 signature. 205 | sig := make([]byte, SignatureSize) 206 | copy(sig, signature) 207 | sig[63] &= 0b0111_1111 208 | 209 | var msg []byte 210 | for _, message := range messages { 211 | msg = append(msg, message...) 212 | } 213 | 214 | return ed25519.Verify(y, msg, sig), nil 215 | } 216 | -------------------------------------------------------------------------------- /protocol/curve/curve25519/curve25519_test.go: -------------------------------------------------------------------------------- 1 | package curve25519 2 | 3 | import ( 4 | "crypto/rand" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestAgreement(t *testing.T) { 13 | alicePub := []byte{ 14 | 0x1b, 0xb7, 0x59, 0x66, 0xf2, 0xe9, 0x3a, 0x36, 0x91, 0xdf, 0xff, 0x94, 0x2b, 0xb2, 15 | 0xa4, 0x66, 0xa1, 0xc0, 0x8b, 0x8d, 0x78, 0xca, 0x3f, 0x4d, 0x6d, 0xf8, 0xb8, 0xbf, 16 | 0xa2, 0xe4, 0xee, 0x28, 17 | } 18 | alicePriv := []byte{ 19 | 0xc8, 0x06, 0x43, 0x9d, 0xc9, 0xd2, 0xc4, 0x76, 0xff, 0xed, 0x8f, 0x25, 0x80, 0xc0, 20 | 0x88, 0x8d, 0x58, 0xab, 0x40, 0x6b, 0xf7, 0xae, 0x36, 0x98, 0x87, 0x90, 0x21, 0xb9, 21 | 0x6b, 0xb4, 0xbf, 0x59, 22 | } 23 | bobPub := []byte{ 24 | 0x65, 0x36, 0x14, 0x99, 0x3d, 0x2b, 0x15, 0xee, 0x9e, 0x5f, 0xd3, 0xd8, 0x6c, 0xe7, 25 | 0x19, 0xef, 0x4e, 0xc1, 0xda, 0xae, 0x18, 0x86, 0xa8, 0x7b, 0x3f, 0x5f, 0xa9, 0x56, 26 | 0x5a, 0x27, 0xa2, 0x2f, 27 | } 28 | bobPriv := []byte{ 29 | 0xb0, 0x3b, 0x34, 0xc3, 0x3a, 0x1c, 0x44, 0xf2, 0x25, 0xb6, 0x62, 0xd2, 0xbf, 0x48, 30 | 0x59, 0xb8, 0x13, 0x54, 0x11, 0xfa, 0x7b, 0x03, 0x86, 0xd4, 0x5f, 0xb7, 0x5d, 0xc5, 31 | 0xb9, 0x1b,
0x44, 0x66, 32 | } 33 | shared := []byte{ 34 | 0x32, 0x5f, 0x23, 0x93, 0x28, 0x94, 0x1c, 0xed, 0x6e, 0x67, 0x3b, 0x86, 0xba, 0x41, 35 | 0x01, 0x74, 0x48, 0xe9, 0x9b, 0x64, 0x9a, 0x9c, 0x38, 0x06, 0xc1, 0xdd, 0x7c, 0xa4, 36 | 0xc4, 0x77, 0xe6, 0x29, 37 | } 38 | 39 | aliceKey, err := NewPrivateKey(alicePriv) 40 | require.NoError(t, err) 41 | assert.Equal(t, alicePub, aliceKey.PublicKeyBytes()) 42 | 43 | bobKey, err := NewPrivateKey(bobPriv) 44 | require.NoError(t, err) 45 | assert.Equal(t, bobPub, bobKey.PublicKeyBytes()) 46 | 47 | aliceSecret, err := aliceKey.Agreement(bobPub) 48 | assert.NoError(t, err) 49 | assert.Equal(t, shared, aliceSecret) 50 | 51 | bobSecret, err := bobKey.Agreement(alicePub) 52 | assert.NoError(t, err) 53 | assert.Equal(t, shared, bobSecret) 54 | } 55 | 56 | func TestAgreement_Random(t *testing.T) { 57 | for i := 0; i < 50; i++ { 58 | aliceKey, err := GeneratePrivateKey(rand.Reader) 59 | require.NoError(t, err) 60 | bobKey, err := GeneratePrivateKey(rand.Reader) 61 | require.NoError(t, err) 62 | 63 | aliceSecret, err := aliceKey.Agreement(bobKey.PublicKeyBytes()) 64 | assert.NoError(t, err) 65 | bobSecret, err := bobKey.Agreement(aliceKey.PublicKeyBytes()) 66 | assert.NoError(t, err) 67 | 68 | assert.Equal(t, aliceSecret, bobSecret) 69 | } 70 | } 71 | 72 | func TestVerifySignature(t *testing.T) { 73 | aliceIdentityPriv := []byte{ 74 | 0xc0, 0x97, 0x24, 0x84, 0x12, 0xe5, 0x8b, 0xf0, 0x5d, 0xf4, 0x87, 0x96, 0x82, 0x05, 75 | 0x13, 0x27, 0x94, 0x17, 0x8e, 0x36, 0x76, 0x37, 0xf5, 0x81, 0x8f, 0x81, 0xe0, 0xe6, 76 | 0xce, 0x73, 0xe8, 0x65, 77 | } 78 | aliceIdentityPub := []byte{ 79 | 0xab, 0x7e, 0x71, 0x7d, 0x4a, 0x16, 0x3b, 0x7d, 0x9a, 0x1d, 0x80, 0x71, 0xdf, 0xe9, 80 | 0xdc, 0xf8, 0xcd, 0xcd, 0x1c, 0xea, 0x33, 0x39, 0xb6, 0x35, 0x6b, 0xe8, 0x4d, 0x88, 81 | 0x7e, 0x32, 0x2c, 0x64, 82 | } 83 | aliceEphemeralPub := []byte{ 84 | 0x05, 0xed, 0xce, 0x9d, 0x9c, 0x41, 0x5c, 0xa7, 0x8c, 0xb7, 0x25, 0x2e, 0x72, 0xc2, 85 | 0xc4, 0xa5, 0x54, 0xd3, 0xeb,
0x29, 0x48, 0x5a, 0x0e, 0x1d, 0x50, 0x31, 0x18, 0xd1, 86 | 0xa8, 0x2d, 0x99, 0xfb, 0x4a, 87 | } 88 | aliceSignature := []byte{ 89 | 0x5d, 0xe8, 0x8c, 0xa9, 0xa8, 0x9b, 0x4a, 0x11, 0x5d, 0xa7, 0x91, 0x09, 0xc6, 0x7c, 90 | 0x9c, 0x74, 0x64, 0xa3, 0xe4, 0x18, 0x02, 0x74, 0xf1, 0xcb, 0x8c, 0x63, 0xc2, 0x98, 91 | 0x4e, 0x28, 0x6d, 0xfb, 0xed, 0xe8, 0x2d, 0xeb, 0x9d, 0xcd, 0x9f, 0xae, 0x0b, 0xfb, 92 | 0xb8, 0x21, 0x56, 0x9b, 0x3d, 0x90, 0x01, 0xbd, 0x81, 0x30, 0xcd, 0x11, 0xd4, 0x86, 93 | 0xce, 0xf0, 0x47, 0xbd, 0x60, 0xb8, 0x6e, 0x88, 94 | } 95 | 96 | aliceIdentityKey, err := NewPrivateKey(aliceIdentityPriv) 97 | require.NoError(t, err) 98 | assert.Equal(t, aliceIdentityPub, aliceIdentityKey.PublicKeyBytes()) 99 | 100 | valid, err := VerifySignature(aliceIdentityPub, aliceSignature, aliceEphemeralPub) 101 | assert.NoError(t, err) 102 | assert.True(t, valid) 103 | 104 | aliceSig := make([]byte, 64) 105 | for i := range aliceSignature { 106 | copy(aliceSig, aliceSignature) 107 | aliceSig[i] ^= 1 108 | 109 | valid, err := VerifySignature(aliceIdentityPub, aliceSig, aliceEphemeralPub) 110 | assert.NoError(t, err) 111 | assert.False(t, valid) 112 | } 113 | } 114 | 115 | func TestVerifySignature_Random(t *testing.T) { 116 | message := make([]byte, 64) 117 | for i := 0; i < 50; i++ { 118 | _, err := io.ReadFull(rand.Reader, message) 119 | require.NoError(t, err) 120 | key, err := GeneratePrivateKey(rand.Reader) 121 | require.NoError(t, err) 122 | 123 | signature, err := key.Sign(rand.Reader, message) 124 | assert.NoError(t, err) 125 | 126 | valid, err := VerifySignature(key.PublicKeyBytes(), signature, message) 127 | assert.NoError(t, err) 128 | assert.True(t, valid) 129 | } 130 | } 131 | 132 | func TestVerifySignature_InvalidLength(t *testing.T) { 133 | key, err := GeneratePrivateKey(rand.Reader) 134 | require.NoError(t, err) 135 | 136 | for _, size := range []int{0, 1, SignatureSize - 1, SignatureSize + 1} { 137 | valid, err := VerifySignature(key.PublicKeyBytes(), make([]byte, size), []byte("message")) 138 | assert.NoError(t, err) 139 | assert.False(t, valid) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /protocol/curve/curve25519/util.go: -------------------------------------------------------------------------------- 1 | package curve25519 2 | 3 | import ( 4 | "errors" 5 | 6 | "filippo.io/edwards25519" 7 |
"filippo.io/edwards25519/field" 8 | ) 9 | 10 | var ( 11 | one = new(field.Element).One() 12 | negativeOne = new(field.Element).Negate(one) 13 | ) 14 | 15 | // xtou returns the Montgomery u-coordinate of edX*B, where B is the Ed25519 base point (i.e. the X25519 public key for scalar edX). 16 | func xtou(edX *edwards25519.Scalar) []byte { 17 | return new(edwards25519.Point).ScalarBaseMult(edX).BytesMontgomery() 18 | } 19 | 20 | // utoy converts a Montgomery u-coordinate to an encoded Edwards point: the y-coordinate, with the x sign bit set when positive is false. 21 | func utoy(montU []byte, positive bool) ([]byte, error) { 22 | u, err := new(field.Element).SetBytes(montU) 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | // y = (u - 1) / (u + 1) 28 | // See https://www.rfc-editor.org/rfc/rfc7748#section-4.1 for more information. 29 | 30 | // Based on the formula above, a valid u-point cannot be -1. 31 | // Equal returns 1 to mean "true". 32 | if u.Equal(negativeOne) == 1 { 33 | return nil, errors.New("invalid u-point") 34 | } 35 | 36 | uMinusOne := new(field.Element).Subtract(u, one) 37 | invUPlusOne := new(field.Element).Invert(new(field.Element).Add(u, one)) 38 | 39 | y := new(field.Element).Multiply(uMinusOne, invUPlusOne).Bytes() 40 | var sign byte 41 | if !positive { 42 | sign = 0b1000_0000 43 | } 44 | y[31] ^= sign 45 | 46 | return y, nil 47 | } 48 | -------------------------------------------------------------------------------- /protocol/curve/djb_private.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/RTann/libsignal-go/protocol/curve/curve25519" 7 | ) 8 | 9 | var _ PrivateKey = (*DJBPrivateKey)(nil) 10 | 11 | // DJBPrivateKey represents an elliptic curve private key. 12 | type DJBPrivateKey struct { 13 | key *curve25519.PrivateKey 14 | } 15 | 16 | // GeneratePrivateKey generates a private key using the given random reader. 17 | // 18 | // It is recommended to use a cryptographic random reader. 19 | // If random is nil, then [crypto/rand.Reader] is used.
20 | func GeneratePrivateKey(random io.Reader) (PrivateKey, error) { 21 | privateKey, err := curve25519.GeneratePrivateKey(random) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | return &DJBPrivateKey{ 27 | key: privateKey, 28 | }, nil 29 | } 30 | 31 | // newDJBPrivateKey returns a private key based on the given key bytes. 32 | func newDJBPrivateKey(key []byte) (PrivateKey, error) { 33 | privateKey, err := curve25519.NewPrivateKey(key) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &DJBPrivateKey{ 39 | key: privateKey, 40 | }, nil 41 | } 42 | 43 | // Bytes returns a copy of the private key bytes. 44 | func (d *DJBPrivateKey) Bytes() []byte { 45 | return d.key.Bytes() 46 | } 47 | 48 | // PublicKey returns the public key corresponding to this private key. 49 | func (d *DJBPrivateKey) PublicKey() PublicKey { 50 | // The error is ignored because PublicKeyBytes always returns 51 | // curve25519.PublicKeySize bytes, presumed equal to PublicKeySize. 52 | key, _ := newDJBPublicKey(d.key.PublicKeyBytes()) 53 | return key 54 | } 55 | 56 | // Agreement computes the X25519 shared secret between this private key 57 | // and the given public key. 58 | func (d *DJBPrivateKey) Agreement(key PublicKey) ([]byte, error) { 59 | return d.key.Agreement(key.KeyBytes()) 60 | } 61 | 62 | // Sign computes an XEdDSA signature over messages. 63 | // If random is nil, then [crypto/rand.Reader] is used. 64 | func (d *DJBPrivateKey) Sign(random io.Reader, messages ...[]byte) ([]byte, error) { 65 | return d.key.Sign(random, messages...) 66 | } 67 | -------------------------------------------------------------------------------- /protocol/curve/djb_public.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import ( 4 | "crypto/subtle" 5 | 6 | "github.com/RTann/libsignal-go/protocol/curve/curve25519" 7 | "github.com/RTann/libsignal-go/protocol/perrors" 8 | ) 9 | 10 | var _ PublicKey = (*DJBPublicKey)(nil) 11 | 12 | // DJBPublicKey represents an elliptic curve public key. 13 | type DJBPublicKey struct { 14 | key []byte 15 | } 16 | 17 | // newDJBPublicKey returns a public key holding a copy of the given key bytes.
18 | func newDJBPublicKey(key []byte) (*DJBPublicKey, error) { 19 | if len(key) != PublicKeySize { 20 | return nil, perrors.ErrInvalidKeyLength(PublicKeySize, len(key)) 21 | } 22 | 23 | return &DJBPublicKey{ 24 | // Copy so later writes to the caller's slice cannot change this key. 25 | key: append([]byte(nil), key...), 26 | }, nil 27 | } 28 | 29 | func (d *DJBPublicKey) keyType() KeyType { 30 | return DJB 31 | } 32 | 33 | // Bytes returns the serialized key: the DJB type byte followed by the key bytes. 34 | func (d *DJBPublicKey) Bytes() []byte { 35 | bytes := make([]byte, 1+PublicKeySize) 36 | bytes[0] = byte(DJB) 37 | copy(bytes[1:], d.key) 38 | return bytes 39 | } 40 | 41 | // KeyBytes returns a copy of the raw key bytes, without the type prefix. 42 | func (d *DJBPublicKey) KeyBytes() []byte { 43 | bytes := make([]byte, PublicKeySize) 44 | copy(bytes, d.key) 45 | return bytes 46 | } 47 | 48 | // Equal reports whether key is a non-nil DJB public key with the same key 49 | // bytes. The bytes are compared in constant time. 50 | func (d *DJBPublicKey) Equal(key PublicKey) bool { 51 | return key != nil && key.keyType() == DJB && 52 | subtle.ConstantTimeCompare(d.KeyBytes(), key.KeyBytes()) == 1 53 | } 54 | 55 | // VerifySignature reports whether signature is a valid XEdDSA signature of 56 | // messages by this key. A signature of the wrong length is reported as 57 | // invalid rather than as an error. 58 | func (d *DJBPublicKey) VerifySignature(signature []byte, messages ...[]byte) (bool, error) { 59 | if len(signature) != SignatureSize { 60 | return false, nil 61 | } 62 | 63 | return curve25519.VerifySignature(d.key, signature, messages...) 64 | } 65 | -------------------------------------------------------------------------------- /protocol/curve/doc.go: -------------------------------------------------------------------------------- 1 | // Package curve implements elliptic curve cryptography 2 | // functions used for the protocol. 3 | package curve 4 | -------------------------------------------------------------------------------- /protocol/curve/keytype.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | // KeyType represents a type of public or private key.
4 | type KeyType byte 5 | 6 | const ( 7 | DJB KeyType = 0x05 8 | ) 9 | -------------------------------------------------------------------------------- /protocol/curve/pair.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import "io" 4 | 5 | // KeyPair bundles a private key with the public key it is paired with. 6 | type KeyPair struct { 7 | privateKey PrivateKey 8 | publicKey PublicKey 9 | } 10 | 11 | // GenerateKeyPair creates a new key pair whose private key is generated 12 | // from random; the public key is derived from that private key. 13 | // 14 | // A cryptographically secure reader is recommended. 15 | // When random is nil, [crypto/rand.Reader] is used. 16 | func GenerateKeyPair(random io.Reader) (*KeyPair, error) { 17 | priv, err := GeneratePrivateKey(random) 18 | if err != nil { 19 | return nil, err 20 | } 21 | 22 | return &KeyPair{privateKey: priv, publicKey: priv.PublicKey()}, nil 23 | } 24 | 25 | // NewKeyPair builds a key pair from the given private and public key bytes. 26 | // 27 | // privateKey is parsed with [NewPrivateKey] and publicKey with [NewPublicKey]; 28 | // the first parse error is returned. The two keys are not checked against 29 | // each other, so callers must supply keys that belong together. 30 | func NewKeyPair(privateKey, publicKey []byte) (*KeyPair, error) { 31 | priv, err := NewPrivateKey(privateKey) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | pub, err := NewPublicKey(publicKey) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return &KeyPair{privateKey: priv, publicKey: pub}, nil 42 | } 43 | 44 | // PrivateKey returns the private half of the pair. 45 | func (k *KeyPair) PrivateKey() PrivateKey { return k.privateKey } 46 | 47 | // PublicKey returns the public half of the pair. 48 | func (k *KeyPair) PublicKey() PublicKey { return k.publicKey } 49 | 50 | // Agreement derives the shared secret between the pair's private key 51 | // and the peer's public key.
52 | func (k *KeyPair) Agreement(key PublicKey) ([]byte, error) { 53 | return k.privateKey.Agreement(key) 54 | } 55 | 56 | // Sign signs messages with the pair's private key. 57 | // 58 | // A cryptographically secure reader is recommended. 59 | // When random is nil, [crypto/rand.Reader] is used. 60 | func (k *KeyPair) Sign(random io.Reader, messages ...[]byte) ([]byte, error) { 61 | return k.privateKey.Sign(random, messages...) 62 | } 63 | -------------------------------------------------------------------------------- /protocol/curve/pair_test.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestSignatures(t *testing.T) { 11 | keyPair, err := GenerateKeyPair(rand.Reader) 12 | assert.NoError(t, err) 13 | 14 | msg := make([]byte, 1024*1024) 15 | signature, err := keyPair.PrivateKey().Sign(rand.Reader, msg) 16 | assert.NoError(t, err) 17 | 18 | valid, err := keyPair.PublicKey().VerifySignature(signature, msg) 19 | assert.NoError(t, err) 20 | assert.True(t, valid) 21 | 22 | msg[0] ^= 0x01 23 | valid, err = keyPair.PublicKey().VerifySignature(signature, msg) 24 | assert.NoError(t, err) 25 | assert.False(t, valid) 26 | 27 | msg[0] ^= 0x01 28 | publicKey := keyPair.PrivateKey().PublicKey() 29 | valid, err = publicKey.VerifySignature(signature, msg) 30 | assert.NoError(t, err) 31 | assert.True(t, valid) 32 | 33 | valid, err = publicKey.VerifySignature(signature, msg[:7], msg[7:]) 34 | assert.NoError(t, err) 35 | assert.True(t, valid) 36 | 37 | signature, err = keyPair.PrivateKey().Sign(rand.Reader, msg[:20], msg[20:]) 38 | assert.NoError(t, err) 39 | valid, err = publicKey.VerifySignature(signature, msg) 40 | assert.NoError(t, err) 41 | assert.True(t, valid) 42 | } 43 | --------------------------------------------------------------------------------
/protocol/curve/private.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/RTann/libsignal-go/protocol/curve/curve25519" 7 | "github.com/RTann/libsignal-go/protocol/perrors" 8 | ) 9 | 10 | const PrivateKeySize = curve25519.PrivateKeySize 11 | 12 | // PrivateKey represents an elliptic curve private key. 13 | type PrivateKey interface { 14 | // Bytes returns an encoding of the private key. 15 | Bytes() []byte 16 | // PublicKey returns the private key's related public key. 17 | PublicKey() PublicKey 18 | // Agreement calculates and returns the shared secret between the private key 19 | // and the given public key. 20 | Agreement(key PublicKey) ([]byte, error) 21 | // Sign calculates the digital signature of the messages. 22 | Sign(random io.Reader, messages ...[]byte) ([]byte, error) 23 | } 24 | 25 | // NewPrivateKey returns a PrivateKey based on the given key. 26 | func NewPrivateKey(key []byte) (PrivateKey, error) { 27 | if len(key) != PrivateKeySize { 28 | return nil, perrors.ErrInvalidKeyLength(PrivateKeySize, len(key)) 29 | } 30 | 31 | return newDJBPrivateKey(key) 32 | } 33 | -------------------------------------------------------------------------------- /protocol/curve/public.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/RTann/libsignal-go/protocol/curve/curve25519" 7 | "github.com/RTann/libsignal-go/protocol/perrors" 8 | ) 9 | 10 | const ( 11 | PublicKeySize = curve25519.PublicKeySize 12 | SignatureSize = curve25519.SignatureSize 13 | ) 14 | 15 | // PublicKey represents an elliptic curve public key. 16 | type PublicKey interface { 17 | keyType() KeyType 18 | // Bytes returns an encoding of the public key. 19 | Bytes() []byte 20 | // KeyBytes returns an encoding of the public key without the type prefix. 
21 | KeyBytes() []byte 22 | // Equal reports whether the given public key is the same as this public key. 23 | // 24 | // This check is performed in constant time as long as the keys have the same type. 25 | Equal(key PublicKey) bool 26 | // VerifySignature verifies the signature is a valid signature 27 | // of the messages by the public key. 28 | VerifySignature(signature []byte, messages ...[]byte) (bool, error) 29 | } 30 | 31 | // NewPublicKey returns a PublicKey based on the given key. 32 | // 33 | // The first byte of the given key is expected to identify the type of the key. 34 | func NewPublicKey(key []byte) (PublicKey, error) { 35 | // Allow trailing data after the public key for some reason... 36 | if len(key) < 1+PublicKeySize { 37 | return nil, perrors.ErrInvalidKeyLength(1+PublicKeySize, len(key)) 38 | } 39 | 40 | publicKey := make([]byte, PublicKeySize) 41 | copy(publicKey, key[1:]) 42 | 43 | switch t := KeyType(key[0]); t { 44 | case DJB: 45 | return newDJBPublicKey(publicKey) 46 | default: 47 | return nil, fmt.Errorf("unsupported key type: %v", t) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /protocol/curve/public_test.go: -------------------------------------------------------------------------------- 1 | package curve 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestPublicKeySize(t *testing.T) { 12 | keyPair, err := GenerateKeyPair(rand.Reader) 13 | require.NoError(t, err) 14 | 15 | publicBytes := keyPair.PublicKey().Bytes() 16 | assert.Equal(t, publicBytes, keyPair.PrivateKey().PublicKey().Bytes()) 17 | 18 | goodPublicKey, err := NewPublicKey(publicBytes) 19 | assert.NoError(t, err) 20 | _, err = NewPublicKey(publicBytes[1:]) 21 | assert.Error(t, err) 22 | _, err = NewPublicKey([]byte{}) 23 | assert.Error(t, err) 24 | 25 | badType := make([]byte, len(publicBytes)) 26 | copy(badType, 
publicBytes) 27 | badType[0] = 0x01 28 | _, err = NewPublicKey(badType) 29 | assert.Error(t, err) 30 | 31 | large := make([]byte, len(publicBytes)+1) 32 | copy(large, publicBytes) 33 | largePublicKey, err := NewPublicKey(large) 34 | assert.NoError(t, err) 35 | 36 | assert.Equal(t, publicBytes, goodPublicKey.Bytes()) 37 | assert.Equal(t, publicBytes, largePublicKey.Bytes()) 38 | } 39 | -------------------------------------------------------------------------------- /protocol/direction/direction.go: -------------------------------------------------------------------------------- 1 | // Package direction contains the possible directions of protocol messages. 2 | package direction 3 | 4 | // Direction is a protocol message direction. 5 | type Direction int 6 | 7 | const ( 8 | Sending = iota + 1 9 | Receiving 10 | ) 11 | -------------------------------------------------------------------------------- /protocol/distribution/id.go: -------------------------------------------------------------------------------- 1 | // Package distribution defines a group distribution ID. 
2 | package distribution 3 | 4 | import "github.com/google/uuid" 5 | 6 | type ID struct { 7 | id uuid.UUID 8 | } 9 | 10 | func MustParse(id string) ID { 11 | distributionID := uuid.MustParse(id) 12 | 13 | return ID{ 14 | id: distributionID, 15 | } 16 | } 17 | 18 | func Parse(id string) (ID, error) { 19 | distributionID, err := uuid.Parse(id) 20 | if err != nil { 21 | return ID{}, err 22 | } 23 | 24 | return ID{ 25 | id: distributionID, 26 | }, nil 27 | } 28 | 29 | func ParseBytes(id []byte) (ID, error) { 30 | distributionID, err := uuid.ParseBytes(id) 31 | if err != nil { 32 | return ID{}, err 33 | } 34 | 35 | return ID{ 36 | id: distributionID, 37 | }, nil 38 | } 39 | 40 | func (i ID) String() string { 41 | return i.id.String() 42 | } 43 | -------------------------------------------------------------------------------- /protocol/fingerprint/displayable.go: -------------------------------------------------------------------------------- 1 | package fingerprint 2 | 3 | // Displayable represents a string representation of a fingerprint. 4 | type Displayable struct { 5 | local string 6 | remote string 7 | } 8 | 9 | // NewDisplayable creates a new displayable fingerprint. 
10 | func NewDisplayable(local, remote []byte) (*Displayable, error) { 11 | encodedLocal, err := encode(local) 12 | if err != nil { 13 | return nil, err 14 | } 15 | encodedRemote, err := encode(remote) 16 | if err != nil { 17 | return nil, err 18 | } 19 | 20 | return &Displayable{ 21 | local: encodedLocal, 22 | remote: encodedRemote, 23 | }, nil 24 | } 25 | 26 | func (d *Displayable) String() string { 27 | if d.local < d.remote { 28 | return d.local + d.remote 29 | } 30 | 31 | return d.remote + d.local 32 | } 33 | -------------------------------------------------------------------------------- /protocol/fingerprint/encode.go: -------------------------------------------------------------------------------- 1 | package fingerprint 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | func encode(bytes []byte) (string, error) { 9 | if len(bytes) < 30 { 10 | return "", fmt.Errorf("encoding too short: %d < 30", len(bytes)) 11 | } 12 | 13 | uint64FromChunk := func(chunk [5]byte) uint64 { 14 | return uint64(chunk[0]&0xFF)<<32 | 15 | uint64(chunk[1]&0xFF)<<24 | 16 | uint64(chunk[2]&0xFF)<<16 | 17 | uint64(chunk[3]&0xFF)<<8 | 18 | uint64(chunk[4]&0xFF) 19 | } 20 | encodeChunk := func(chunk [5]byte) string { 21 | return fmt.Sprintf("%05d", uint64FromChunk(chunk)%100_000) 22 | } 23 | 24 | var encoding strings.Builder 25 | encoding.WriteString(encodeChunk([5]byte(bytes[:5]))) 26 | encoding.WriteString(encodeChunk([5]byte(bytes[5:10]))) 27 | encoding.WriteString(encodeChunk([5]byte(bytes[10:15]))) 28 | encoding.WriteString(encodeChunk([5]byte(bytes[15:20]))) 29 | encoding.WriteString(encodeChunk([5]byte(bytes[20:25]))) 30 | encoding.WriteString(encodeChunk([5]byte(bytes[25:30]))) 31 | 32 | return encoding.String(), nil 33 | } 34 | -------------------------------------------------------------------------------- /protocol/fingerprint/fingerprint.go: -------------------------------------------------------------------------------- 1 | // Package fingerprint defines a protocol 
user's unique fingerprint. 2 | package fingerprint 3 | 4 | import ( 5 | "crypto/sha512" 6 | "fmt" 7 | 8 | "github.com/RTann/libsignal-go/protocol/identity" 9 | ) 10 | 11 | var fingerprintVersion = []byte{0x00, 0x00} 12 | 13 | // Fingerprint represents a user's unique fingerprint. 14 | type Fingerprint struct { 15 | displayable *Displayable 16 | scannable *Scannable 17 | } 18 | 19 | // New creates a new fingerprint. 20 | func New( 21 | version, 22 | iterations uint32, 23 | localID []byte, 24 | localKey identity.Key, 25 | remoteID []byte, 26 | remoteKey identity.Key, 27 | ) (*Fingerprint, error) { 28 | localFingerprint, err := fingerprint(iterations, localID, localKey) 29 | if err != nil { 30 | return nil, err 31 | } 32 | remoteFingerprint, err := fingerprint(iterations, remoteID, remoteKey) 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | displayable, err := NewDisplayable(localFingerprint, remoteFingerprint) 38 | if err != nil { 39 | return nil, err 40 | } 41 | scannable, err := NewScannable(version, localFingerprint, remoteFingerprint) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | return &Fingerprint{ 47 | displayable: displayable, 48 | scannable: scannable, 49 | }, nil 50 | } 51 | 52 | // fingerprint generates a fingerprint for the user identified by the ID and key. 53 | func fingerprint(iterations uint32, localID []byte, localKey identity.Key) ([]byte, error) { 54 | if iterations <= 1 || iterations > 1_000_000 { 55 | return nil, fmt.Errorf("invalid iterations: %d", iterations) 56 | } 57 | 58 | localKeyBytes := localKey.Bytes() 59 | 60 | checksum := make([]byte, 0, sha512.Size) 61 | 62 | // iteration 0. 
63 | hash := sha512.New() 64 | hash.Write(fingerprintVersion) 65 | hash.Write(localKeyBytes) 66 | hash.Write(localID) 67 | hash.Write(localKeyBytes) 68 | checksum = hash.Sum(checksum) 69 | 70 | for i := uint32(1); i < iterations; i++ { 71 | hash.Reset() 72 | hash.Write(checksum) 73 | hash.Write(localKeyBytes) 74 | checksum = checksum[:0] 75 | checksum = hash.Sum(checksum) 76 | } 77 | 78 | return checksum, nil 79 | } 80 | 81 | func (f *Fingerprint) String() string { 82 | return f.displayable.String() 83 | } 84 | -------------------------------------------------------------------------------- /protocol/fingerprint/fingerprint_test.go: -------------------------------------------------------------------------------- 1 | package fingerprint 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/RTann/libsignal-go/protocol/identity" 12 | ) 13 | 14 | const ( 15 | aliceIdentityHex = "0506863bc66d02b40d27b8d49ca7c09e9239236f9d7d25d6fcca5ce13c7064d868" 16 | bobIdentityHex = "05f781b6fb32fed9ba1cf2de978d4d5da28dc34046ae814402b5c0dbd96fda907b" 17 | 18 | displayableFingerprintV1 = "300354477692869396892869876765458257569162576843440918079131" 19 | 20 | aliceScannableFingerprintV1 = "080112220a201e301a0353dce3dbe7684cb8336e85136cdc0ee96219494ada305d62a7bd61df1a220a20d62cbf73a11592015b6b9f1682ac306fea3aaf3885b84d12bca631e9d4fb3a4d" 21 | bobScannableFingerprintV1 = "080112220a20d62cbf73a11592015b6b9f1682ac306fea3aaf3885b84d12bca631e9d4fb3a4d1a220a201e301a0353dce3dbe7684cb8336e85136cdc0ee96219494ada305d62a7bd61df" 22 | 23 | aliceScannableFingerprintV2 = "080212220a201e301a0353dce3dbe7684cb8336e85136cdc0ee96219494ada305d62a7bd61df1a220a20d62cbf73a11592015b6b9f1682ac306fea3aaf3885b84d12bca631e9d4fb3a4d" 24 | bobScannableFingerprintV2 = 
"080212220a20d62cbf73a11592015b6b9f1682ac306fea3aaf3885b84d12bca631e9d4fb3a4d1a220a201e301a0353dce3dbe7684cb8336e85136cdc0ee96219494ada305d62a7bd61df" 25 | ) 26 | 27 | var ( 28 | aliceStableID = []byte("+14152222222") 29 | bobStableID = []byte("+14153333333") 30 | ) 31 | 32 | func TestFingerprint_V1(t *testing.T) { 33 | aliceIdentity, err := hex.DecodeString(aliceIdentityHex) 34 | require.NoError(t, err) 35 | bobIdentity, err := hex.DecodeString(bobIdentityHex) 36 | require.NoError(t, err) 37 | 38 | aKey, err := identity.NewKey(aliceIdentity) 39 | require.NoError(t, err) 40 | bKey, err := identity.NewKey(bobIdentity) 41 | require.NoError(t, err) 42 | 43 | version := uint32(1) 44 | iterations := uint32(5200) 45 | 46 | aFprint, err := New(version, iterations, aliceStableID, aKey, bobStableID, bKey) 47 | assert.NoError(t, err) 48 | bFprint, err := New(version, iterations, bobStableID, bKey, aliceStableID, aKey) 49 | assert.NoError(t, err) 50 | 51 | aScannableBytes, err := aFprint.scannable.Bytes() 52 | require.NoError(t, err) 53 | assert.Equal(t, aliceScannableFingerprintV1, hex.EncodeToString(aScannableBytes)) 54 | bScannableBytes, err := bFprint.scannable.Bytes() 55 | require.NoError(t, err) 56 | assert.Equal(t, bobScannableFingerprintV1, hex.EncodeToString(bScannableBytes)) 57 | 58 | assert.Equal(t, displayableFingerprintV1, aFprint.displayable.String()) 59 | assert.Equal(t, displayableFingerprintV1, bFprint.displayable.String()) 60 | } 61 | 62 | func TestFingerprint_V2(t *testing.T) { 63 | aliceIdentity, err := hex.DecodeString(aliceIdentityHex) 64 | require.NoError(t, err) 65 | bobIdentity, err := hex.DecodeString(bobIdentityHex) 66 | require.NoError(t, err) 67 | 68 | aKey, err := identity.NewKey(aliceIdentity) 69 | require.NoError(t, err) 70 | bKey, err := identity.NewKey(bobIdentity) 71 | require.NoError(t, err) 72 | 73 | version := uint32(2) 74 | iterations := uint32(5200) 75 | 76 | aFprint, err := New(version, iterations, aliceStableID, aKey, bobStableID, 
bKey) 77 | assert.NoError(t, err) 78 | bFprint, err := New(version, iterations, bobStableID, bKey, aliceStableID, aKey) 79 | assert.NoError(t, err) 80 | 81 | aScannableBytes, err := aFprint.scannable.Bytes() 82 | require.NoError(t, err) 83 | assert.Equal(t, aliceScannableFingerprintV2, hex.EncodeToString(aScannableBytes)) 84 | bScannableBytes, err := bFprint.scannable.Bytes() 85 | require.NoError(t, err) 86 | assert.Equal(t, bobScannableFingerprintV2, hex.EncodeToString(bScannableBytes)) 87 | 88 | assert.Equal(t, displayableFingerprintV1, aFprint.displayable.String()) 89 | assert.Equal(t, displayableFingerprintV1, bFprint.displayable.String()) 90 | } 91 | 92 | func TestFingerprint_MatchingIdentifiers(t *testing.T) { 93 | aKeyPair, err := identity.GenerateKeyPair(rand.Reader) 94 | require.NoError(t, err) 95 | bKeyPair, err := identity.GenerateKeyPair(rand.Reader) 96 | require.NoError(t, err) 97 | 98 | aKey := aKeyPair.IdentityKey() 99 | bKey := bKeyPair.IdentityKey() 100 | 101 | version := uint32(1) 102 | iterations := uint32(1024) 103 | 104 | aFprint, err := New(version, iterations, aliceStableID, aKey, bobStableID, bKey) 105 | assert.NoError(t, err) 106 | bFprint, err := New(version, iterations, bobStableID, bKey, aliceStableID, aKey) 107 | assert.NoError(t, err) 108 | 109 | assert.Equal(t, aFprint.displayable.String(), bFprint.displayable.String()) 110 | assert.Len(t, aFprint.displayable.String(), 60) 111 | 112 | aScannable, err := aFprint.scannable.Bytes() 113 | require.NoError(t, err) 114 | bScannable, err := bFprint.scannable.Bytes() 115 | require.NoError(t, err) 116 | equal, err := aFprint.scannable.Compare(bScannable) 117 | assert.NoError(t, err) 118 | assert.True(t, equal) 119 | equal, err = bFprint.scannable.Compare(aScannable) 120 | assert.NoError(t, err) 121 | assert.True(t, equal) 122 | } 123 | 124 | func TestFingerprint_MismatchingVersions(t *testing.T) { 125 | aKeyPair, err := identity.GenerateKeyPair(rand.Reader) 126 | require.NoError(t, err) 127 | 
bKeyPair, err := identity.GenerateKeyPair(rand.Reader) 128 | require.NoError(t, err) 129 | 130 | aKey := aKeyPair.IdentityKey() 131 | bKey := bKeyPair.IdentityKey() 132 | 133 | iterations := uint32(5200) 134 | 135 | aFprintV1, err := New(uint32(1), iterations, aliceStableID, aKey, bobStableID, bKey) 136 | assert.NoError(t, err) 137 | aFprintV2, err := New(uint32(2), iterations, aliceStableID, aKey, bobStableID, bKey) 138 | assert.NoError(t, err) 139 | 140 | assert.Equal(t, aFprintV1.displayable.String(), aFprintV2.displayable.String()) 141 | aV1Scannable, err := aFprintV1.scannable.Bytes() 142 | require.NoError(t, err) 143 | aV2Scannable, err := aFprintV2.scannable.Bytes() 144 | require.NoError(t, err) 145 | assert.NotEqual(t, hex.EncodeToString(aV1Scannable), hex.EncodeToString(aV2Scannable)) 146 | } 147 | -------------------------------------------------------------------------------- /protocol/fingerprint/scannable.go: -------------------------------------------------------------------------------- 1 | package fingerprint 2 | 3 | import ( 4 | "crypto/subtle" 5 | "fmt" 6 | 7 | "google.golang.org/protobuf/proto" 8 | 9 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 10 | "github.com/RTann/libsignal-go/protocol/internal/pointer" 11 | ) 12 | 13 | // Scannable represents a fingerprint to be displayed on a QR code. 14 | type Scannable struct { 15 | version uint32 16 | local []byte 17 | remote []byte 18 | } 19 | 20 | // NewScannable creates a new scannable fingerprint. 
21 | func NewScannable(version uint32, local, remote []byte) (*Scannable, error) { 22 | if len(local) < 32 { 23 | return nil, fmt.Errorf("invalid local fingerprint length: %d < 32", len(local)) 24 | } 25 | if len(remote) < 32 { 26 | return nil, fmt.Errorf("invalid remote fingerprint length: %d < 32", len(remote)) 27 | } 28 | 29 | return &Scannable{ 30 | version: version, 31 | local: local[:32], 32 | remote: remote[:32], 33 | }, nil 34 | } 35 | 36 | // Bytes returns an encoding of the scannable fingerprint. 37 | func (s *Scannable) Bytes() ([]byte, error) { 38 | combined := &v1.CombinedFingerprints{ 39 | Version: pointer.To(s.version), 40 | LocalFingerprint: &v1.LogicalFingerprint{ 41 | Content: s.local, 42 | }, 43 | RemoteFingerprint: &v1.LogicalFingerprint{ 44 | Content: s.remote, 45 | }, 46 | } 47 | 48 | bytes, err := proto.Marshal(combined) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | return bytes, nil 54 | } 55 | 56 | // Compare compares a scanned QR code with the expected fingerprint. 
57 | func (s *Scannable) Compare(scanned []byte) (bool, error) { 58 | var combined v1.CombinedFingerprints 59 | if err := proto.Unmarshal(scanned, &combined); err != nil { 60 | return false, err 61 | } 62 | 63 | theirVersion := combined.GetVersion() 64 | if theirVersion != s.version { 65 | return false, fmt.Errorf("fingerprint version mismatch: %d != %d", theirVersion, s.version) 66 | } 67 | 68 | same1 := subtle.ConstantTimeCompare(s.local, combined.GetRemoteFingerprint().GetContent()) == 1 69 | same2 := subtle.ConstantTimeCompare(s.remote, combined.GetLocalFingerprint().GetContent()) == 1 70 | 71 | return same1 && same2, nil 72 | } 73 | -------------------------------------------------------------------------------- /protocol/fingerprint/scannable_test.go: -------------------------------------------------------------------------------- 1 | package fingerprint 2 | 3 | import ( 4 | "encoding/hex" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestScannable(t *testing.T) { 13 | l := []byte{ 14 | 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 15 | 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 16 | 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 17 | } 18 | r := []byte{ 19 | 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 20 | 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 21 | 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 0xBA, 22 | } 23 | 24 | scannable, err := NewScannable(2, l, r) 25 | require.NoError(t, err) 26 | scannableBytes, err := scannable.Bytes() 27 | require.NoError(t, err) 28 | 29 | var expected strings.Builder 30 | expected.WriteString("080212220a20") 31 | for i := 0; i < 32; i++ { 32 | expected.WriteString("12") 33 | } 34 | expected.WriteString("1a220a20") 35 | for i := 0; i < 32; i++ { 36 | expected.WriteString("ba") 37 | } 38 | 39 | assert.Equal(t, expected.String(), 
hex.EncodeToString(scannableBytes)) 40 | } 41 | -------------------------------------------------------------------------------- /protocol/generated/v1/fingerprint.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.34.1 4 | // protoc v5.26.1 5 | // source: fingerprint.proto 6 | 7 | // 8 | // Copyright 2020 Signal Messenger, LLC. 9 | // SPDX-License-Identifier: AGPL-3.0-only 10 | // 11 | 12 | package v1 13 | 14 | import ( 15 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 16 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 17 | reflect "reflect" 18 | sync "sync" 19 | ) 20 | 21 | const ( 22 | // Verify that this generated code is sufficiently up-to-date. 23 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 24 | // Verify that runtime/protoimpl is sufficiently up-to-date. 25 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 26 | ) 27 | 28 | type LogicalFingerprint struct { 29 | state protoimpl.MessageState 30 | sizeCache protoimpl.SizeCache 31 | unknownFields protoimpl.UnknownFields 32 | 33 | Content []byte `protobuf:"bytes,1,opt,name=content" json:"content,omitempty"` // bytes identifier = 2; 34 | } 35 | 36 | func (x *LogicalFingerprint) Reset() { 37 | *x = LogicalFingerprint{} 38 | if protoimpl.UnsafeEnabled { 39 | mi := &file_fingerprint_proto_msgTypes[0] 40 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 41 | ms.StoreMessageInfo(mi) 42 | } 43 | } 44 | 45 | func (x *LogicalFingerprint) String() string { 46 | return protoimpl.X.MessageStringOf(x) 47 | } 48 | 49 | func (*LogicalFingerprint) ProtoMessage() {} 50 | 51 | func (x *LogicalFingerprint) ProtoReflect() protoreflect.Message { 52 | mi := &file_fingerprint_proto_msgTypes[0] 53 | if protoimpl.UnsafeEnabled && x != nil { 54 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 55 | if ms.LoadMessageInfo() == nil { 56 | 
ms.StoreMessageInfo(mi) 57 | } 58 | return ms 59 | } 60 | return mi.MessageOf(x) 61 | } 62 | 63 | // Deprecated: Use LogicalFingerprint.ProtoReflect.Descriptor instead. 64 | func (*LogicalFingerprint) Descriptor() ([]byte, []int) { 65 | return file_fingerprint_proto_rawDescGZIP(), []int{0} 66 | } 67 | 68 | func (x *LogicalFingerprint) GetContent() []byte { 69 | if x != nil { 70 | return x.Content 71 | } 72 | return nil 73 | } 74 | 75 | type CombinedFingerprints struct { 76 | state protoimpl.MessageState 77 | sizeCache protoimpl.SizeCache 78 | unknownFields protoimpl.UnknownFields 79 | 80 | Version *uint32 `protobuf:"varint,1,opt,name=version" json:"version,omitempty"` 81 | LocalFingerprint *LogicalFingerprint `protobuf:"bytes,2,opt,name=local_fingerprint,json=localFingerprint" json:"local_fingerprint,omitempty"` 82 | RemoteFingerprint *LogicalFingerprint `protobuf:"bytes,3,opt,name=remote_fingerprint,json=remoteFingerprint" json:"remote_fingerprint,omitempty"` 83 | } 84 | 85 | func (x *CombinedFingerprints) Reset() { 86 | *x = CombinedFingerprints{} 87 | if protoimpl.UnsafeEnabled { 88 | mi := &file_fingerprint_proto_msgTypes[1] 89 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 90 | ms.StoreMessageInfo(mi) 91 | } 92 | } 93 | 94 | func (x *CombinedFingerprints) String() string { 95 | return protoimpl.X.MessageStringOf(x) 96 | } 97 | 98 | func (*CombinedFingerprints) ProtoMessage() {} 99 | 100 | func (x *CombinedFingerprints) ProtoReflect() protoreflect.Message { 101 | mi := &file_fingerprint_proto_msgTypes[1] 102 | if protoimpl.UnsafeEnabled && x != nil { 103 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 104 | if ms.LoadMessageInfo() == nil { 105 | ms.StoreMessageInfo(mi) 106 | } 107 | return ms 108 | } 109 | return mi.MessageOf(x) 110 | } 111 | 112 | // Deprecated: Use CombinedFingerprints.ProtoReflect.Descriptor instead. 
113 | func (*CombinedFingerprints) Descriptor() ([]byte, []int) { 114 | return file_fingerprint_proto_rawDescGZIP(), []int{1} 115 | } 116 | 117 | func (x *CombinedFingerprints) GetVersion() uint32 { 118 | if x != nil && x.Version != nil { 119 | return *x.Version 120 | } 121 | return 0 122 | } 123 | 124 | func (x *CombinedFingerprints) GetLocalFingerprint() *LogicalFingerprint { 125 | if x != nil { 126 | return x.LocalFingerprint 127 | } 128 | return nil 129 | } 130 | 131 | func (x *CombinedFingerprints) GetRemoteFingerprint() *LogicalFingerprint { 132 | if x != nil { 133 | return x.RemoteFingerprint 134 | } 135 | return nil 136 | } 137 | 138 | var File_fingerprint_proto protoreflect.FileDescriptor 139 | 140 | var file_fingerprint_proto_rawDesc = []byte{ 141 | 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 0x6e, 0x74, 0x2e, 0x70, 0x72, 142 | 0x6f, 0x74, 0x6f, 0x12, 0x18, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 143 | 0x6f, 0x2e, 0x66, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 0x6e, 0x74, 0x22, 0x2e, 0x0a, 144 | 0x12, 0x4c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x46, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 145 | 0x69, 0x6e, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x18, 0x01, 146 | 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x22, 0xe8, 0x01, 147 | 0x0a, 0x14, 0x43, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x65, 0x64, 0x46, 0x69, 0x6e, 0x67, 0x65, 0x72, 148 | 0x70, 0x72, 0x69, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 149 | 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 150 | 0x12, 0x59, 0x0a, 0x11, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x5f, 0x66, 0x69, 0x6e, 0x67, 0x65, 0x72, 151 | 0x70, 0x72, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2c, 0x2e, 0x73, 0x69, 152 | 0x67, 0x6e, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x66, 0x69, 0x6e, 0x67, 0x65, 153 | 0x72, 
0x70, 0x72, 0x69, 0x6e, 0x74, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x46, 0x69, 154 | 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 0x6e, 0x74, 0x52, 0x10, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 155 | 0x46, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 0x6e, 0x74, 0x12, 0x5b, 0x0a, 0x12, 0x72, 156 | 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x5f, 0x66, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 0x6e, 157 | 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2c, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 158 | 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x66, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 159 | 0x6e, 0x74, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x46, 0x69, 0x6e, 0x67, 0x65, 0x72, 160 | 0x70, 0x72, 0x69, 0x6e, 0x74, 0x52, 0x11, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x46, 0x69, 0x6e, 161 | 0x67, 0x65, 0x72, 0x70, 0x72, 0x69, 0x6e, 0x74, 0x42, 0x05, 0x5a, 0x03, 0x2f, 0x76, 0x31, 162 | } 163 | 164 | var ( 165 | file_fingerprint_proto_rawDescOnce sync.Once 166 | file_fingerprint_proto_rawDescData = file_fingerprint_proto_rawDesc 167 | ) 168 | 169 | func file_fingerprint_proto_rawDescGZIP() []byte { 170 | file_fingerprint_proto_rawDescOnce.Do(func() { 171 | file_fingerprint_proto_rawDescData = protoimpl.X.CompressGZIP(file_fingerprint_proto_rawDescData) 172 | }) 173 | return file_fingerprint_proto_rawDescData 174 | } 175 | 176 | var file_fingerprint_proto_msgTypes = make([]protoimpl.MessageInfo, 2) 177 | var file_fingerprint_proto_goTypes = []interface{}{ 178 | (*LogicalFingerprint)(nil), // 0: signal.proto.fingerprint.LogicalFingerprint 179 | (*CombinedFingerprints)(nil), // 1: signal.proto.fingerprint.CombinedFingerprints 180 | } 181 | var file_fingerprint_proto_depIdxs = []int32{ 182 | 0, // 0: signal.proto.fingerprint.CombinedFingerprints.local_fingerprint:type_name -> signal.proto.fingerprint.LogicalFingerprint 183 | 0, // 1: signal.proto.fingerprint.CombinedFingerprints.remote_fingerprint:type_name -> signal.proto.fingerprint.LogicalFingerprint 184 | 2, 
// [2:2] is the sub-list for method output_type 185 | 2, // [2:2] is the sub-list for method input_type 186 | 2, // [2:2] is the sub-list for extension type_name 187 | 2, // [2:2] is the sub-list for extension extendee 188 | 0, // [0:2] is the sub-list for field type_name 189 | } 190 | 191 | func init() { file_fingerprint_proto_init() } 192 | func file_fingerprint_proto_init() { 193 | if File_fingerprint_proto != nil { 194 | return 195 | } 196 | if !protoimpl.UnsafeEnabled { 197 | file_fingerprint_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 198 | switch v := v.(*LogicalFingerprint); i { 199 | case 0: 200 | return &v.state 201 | case 1: 202 | return &v.sizeCache 203 | case 2: 204 | return &v.unknownFields 205 | default: 206 | return nil 207 | } 208 | } 209 | file_fingerprint_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { 210 | switch v := v.(*CombinedFingerprints); i { 211 | case 0: 212 | return &v.state 213 | case 1: 214 | return &v.sizeCache 215 | case 2: 216 | return &v.unknownFields 217 | default: 218 | return nil 219 | } 220 | } 221 | } 222 | type x struct{} 223 | out := protoimpl.TypeBuilder{ 224 | File: protoimpl.DescBuilder{ 225 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 226 | RawDescriptor: file_fingerprint_proto_rawDesc, 227 | NumEnums: 0, 228 | NumMessages: 2, 229 | NumExtensions: 0, 230 | NumServices: 0, 231 | }, 232 | GoTypes: file_fingerprint_proto_goTypes, 233 | DependencyIndexes: file_fingerprint_proto_depIdxs, 234 | MessageInfos: file_fingerprint_proto_msgTypes, 235 | }.Build() 236 | File_fingerprint_proto = out.File 237 | file_fingerprint_proto_rawDesc = nil 238 | file_fingerprint_proto_goTypes = nil 239 | file_fingerprint_proto_depIdxs = nil 240 | } 241 | -------------------------------------------------------------------------------- /protocol/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/RTann/libsignal-go/protocol 2 | 3 | go 
1.20 4 | 5 | require ( 6 | filippo.io/edwards25519 v1.1.0 7 | github.com/golang/glog v1.2.4 8 | github.com/google/uuid v1.6.0 9 | github.com/stretchr/testify v1.10.0 10 | golang.org/x/crypto v0.33.0 11 | golang.org/x/tools v0.31.0 12 | google.golang.org/protobuf v1.34.2 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | golang.org/x/mod v0.24.0 // indirect 19 | golang.org/x/sync v0.12.0 // indirect 20 | gopkg.in/yaml.v3 v3.0.1 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /protocol/go.sum: -------------------------------------------------------------------------------- 1 | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 2 | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= 6 | github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= 7 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 8 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 9 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 10 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 11 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 12 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 13 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 14 | golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= 15 | golang.org/x/crypto v0.33.0/go.mod 
h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= 16 | golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= 17 | golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= 18 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 19 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 20 | golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= 21 | golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= 22 | google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= 23 | google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= 24 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 25 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 26 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 27 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 28 | -------------------------------------------------------------------------------- /protocol/identity/keys.go: -------------------------------------------------------------------------------- 1 | // Package identity defines an identity key. 2 | package identity 3 | 4 | import ( 5 | "io" 6 | 7 | "github.com/RTann/libsignal-go/protocol/curve" 8 | ) 9 | 10 | var ( 11 | alternateIdentitySignaturePrefix1 = []byte{ 12 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 13 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 14 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 15 | } 16 | alternateIdentitySignaturePrefix2 = []byte("Signal_PNI_Signature") 17 | ) 18 | 19 | // Key represents a public identity key. 
20 | type Key struct { 21 | publicKey curve.PublicKey // the wrapped curve public key; Key's methods delegate to it 22 | } 23 | 24 | // NewKey parses key with curve.NewPublicKey and wraps the result, returning any parse error. 25 | func NewKey(key []byte) (Key, error) { 26 | publicKey, err := curve.NewPublicKey(key) 27 | if err != nil { 28 | return Key{}, err 29 | } 30 | 31 | return Key{ 32 | publicKey: publicKey, 33 | }, nil 34 | } 35 | 36 | // PublicKey returns the wrapped curve public key. 37 | func (k Key) PublicKey() curve.PublicKey { 38 | return k.publicKey 39 | } 40 | 41 | // Bytes returns the serialized public key, exactly as curve.PublicKey.Bytes encodes it. 42 | func (k Key) Bytes() []byte { 43 | return k.publicKey.Bytes() 44 | } 45 | 46 | // VerifyAlternateIdentity reports whether signature, made by this key, attests that 47 | // other is an alternate identity for the same user. 48 | // 49 | // The signature is expected to be the output of KeyPair.SignAlternateIdentity, which signs the same message parts: alternateIdentitySignaturePrefix1, alternateIdentitySignaturePrefix2, then other.Bytes(). 50 | func (k Key) VerifyAlternateIdentity(signature []byte, other Key) (bool, error) { 51 | return k.publicKey.VerifySignature(signature, 52 | alternateIdentitySignaturePrefix1, 53 | alternateIdentitySignaturePrefix2, 54 | other.Bytes(), 55 | ) 56 | } 57 | 58 | // Equal reports whether both identity keys wrap equal public keys, as decided by curve.PublicKey.Equal. 59 | func (k Key) Equal(key Key) bool { 60 | return k.PublicKey().Equal(key.PublicKey()) 61 | } 62 | 63 | // KeyPair holds a private key together with a public identity key. 64 | type KeyPair struct { 65 | privateKey curve.PrivateKey 66 | identityKey Key 67 | } 68 | 69 | // GenerateKeyPair generates an identity key pair using the given random reader. 70 | // 71 | // It is recommended to use a cryptographic random reader. 72 | // If random is `nil`, then crypto/rand.Reader is used.
73 | func GenerateKeyPair(random io.Reader) (KeyPair, error) { 74 | pair, err := curve.GenerateKeyPair(random) // key generation, including any handling of a nil reader, is delegated to the curve package 75 | if err != nil { 76 | return KeyPair{}, err 77 | } 78 | 79 | return KeyPair{ 80 | privateKey: pair.PrivateKey(), 81 | identityKey: Key{ 82 | publicKey: pair.PublicKey(), 83 | }, 84 | }, nil 85 | } 86 | 87 | // NewKeyPair pairs privateKey with the public identity key identityKey. 88 | // It does not check that the two keys correspond. 89 | func NewKeyPair(privateKey curve.PrivateKey, identityKey Key) KeyPair { 90 | return KeyPair{ 91 | privateKey: privateKey, 92 | identityKey: identityKey, 93 | } 94 | } 95 | 96 | // IdentityKey returns the key pair's public identity key. 97 | func (k KeyPair) IdentityKey() Key { 98 | return k.identityKey 99 | } 100 | 101 | // PublicKey returns the curve public key wrapped by the pair's identity key. 102 | func (k KeyPair) PublicKey() curve.PublicKey { 103 | return k.identityKey.PublicKey() 104 | } 105 | 106 | // PrivateKey returns the key pair's private key. 107 | func (k KeyPair) PrivateKey() curve.PrivateKey { 108 | return k.privateKey 109 | } 110 | 111 | // SignAlternateIdentity signs a claim that other represents the same user as this key pair. 112 | // The result is checked with Key.VerifyAlternateIdentity on this pair's IdentityKey.
113 | func (k KeyPair) SignAlternateIdentity(random io.Reader, other Key) ([]byte, error) { 114 | return k.privateKey.Sign(random, // message parts must match those checked by Key.VerifyAlternateIdentity 115 | alternateIdentitySignaturePrefix1, 116 | alternateIdentitySignaturePrefix2, 117 | other.Bytes(), 118 | ) 119 | } 120 | -------------------------------------------------------------------------------- /protocol/identity/keys_test.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestIdentityKey(t *testing.T) { 12 | pair, err := GenerateKeyPair(rand.Reader) 13 | require.NoError(t, err) 14 | 15 | key := Key{ 16 | publicKey: pair.PublicKey(), 17 | } 18 | assert.Equal(t, pair.PublicKey().Bytes(), key.Bytes()) 19 | } 20 | 21 | func TestSignAlternateIdentity(t *testing.T) { 22 | random := rand.Reader 23 | 24 | primary, err := GenerateKeyPair(random) 25 | require.NoError(t, err) 26 | secondary, err := GenerateKeyPair(random) 27 | require.NoError(t, err) 28 | 29 | signature, err := secondary.SignAlternateIdentity(random, primary.IdentityKey()) 30 | assert.NoError(t, err) 31 | valid, err := secondary.IdentityKey().VerifyAlternateIdentity(signature, primary.IdentityKey()) 32 | assert.NoError(t, err) 33 | assert.True(t, valid) 34 | // The claim is directional: swapping signer and subject must not verify.
35 | valid, err = primary.IdentityKey().VerifyAlternateIdentity(signature, secondary.IdentityKey()) 36 | assert.NoError(t, err) 37 | assert.False(t, valid) 38 | 39 | anotherSignature, err := secondary.SignAlternateIdentity(random, primary.IdentityKey()) 40 | assert.NoError(t, err) 41 | assert.NotEqual(t, signature, anotherSignature) // signing is randomized, so repeated signatures over the same input differ 42 | valid, err = secondary.IdentityKey().VerifyAlternateIdentity(anotherSignature, primary.IdentityKey()) 43 | assert.NoError(t, err) 44 | assert.True(t, valid) 45 | 46 | unrelated, err := GenerateKeyPair(random) 47 | require.NoError(t, err) 48 | valid, err = secondary.IdentityKey().VerifyAlternateIdentity(signature, unrelated.IdentityKey()) 49 | assert.NoError(t, err) 50 | assert.False(t, valid) 51 | valid, err = unrelated.IdentityKey().VerifyAlternateIdentity(signature, primary.IdentityKey()) 52 | assert.NoError(t, err) 53 | assert.False(t, valid) 54 | } 55 | -------------------------------------------------------------------------------- /protocol/identity/store.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/RTann/libsignal-go/protocol/address" 7 | "github.com/RTann/libsignal-go/protocol/direction" 8 | ) 9 | 10 | // Store defines an identity key store. 11 | // 12 | // An identity key store is associated with a local identity key pair and registration ID. 13 | type Store interface { 14 | // KeyPair returns the associated identity key pair. 15 | KeyPair(ctx context.Context) KeyPair 16 | // LocalRegistrationID returns the associated registration ID. 17 | LocalRegistrationID(ctx context.Context) uint32 18 | // Load returns the identity key stored for the remote address and whether an entry exists.
19 | Load(ctx context.Context, address address.Address) (Key, bool, error) 20 | // Store stores the identity key associated with the remote address and returns 21 | // "true" if there is already an entry for the address which is overwritten 22 | // with a new identity key. 23 | // 24 | // Storing the identity key for the remote address implies the identity key 25 | // is trusted for the given address. 26 | Store(ctx context.Context, address address.Address, identity Key) (bool, error) 27 | // Clear removes all items from the store. NOTE(review): unlike the other methods, it takes no context.Context. 28 | Clear() error 29 | // IsTrustedIdentity returns "true" if the given identity key for the given address is already trusted. 30 | // 31 | // If there is no entry for the given address, the given identity key is trusted. 32 | IsTrustedIdentity(ctx context.Context, address address.Address, identity Key, direction direction.Direction) (bool, error) 33 | } 34 | -------------------------------------------------------------------------------- /protocol/identity/store_inmem.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/RTann/libsignal-go/protocol/address" 7 | "github.com/RTann/libsignal-go/protocol/direction" 8 | ) 9 | 10 | var _ Store = (*inMemStore)(nil) 11 | 12 | // inMemStore is a map-backed identity key store. It holds no mutex, so it is not safe for concurrent use. 13 | type inMemStore struct { 14 | keyPair KeyPair 15 | registrationID uint32 16 | knownKeys map[address.Address]Key // last stored (trusted) identity key per remote address 17 | } 18 | 19 | // NewInMemStore creates a new in-memory identity key store.
20 | func NewInMemStore(keyPair KeyPair, registrationID uint32) Store { 21 | return &inMemStore{ 22 | keyPair: keyPair, 23 | registrationID: registrationID, 24 | knownKeys: make(map[address.Address]Key), 25 | } 26 | } 27 | 28 | // KeyPair returns the identity key pair the store was created with. 29 | func (i *inMemStore) KeyPair(_ context.Context) KeyPair { 30 | return i.keyPair 31 | } 32 | 33 | // LocalRegistrationID returns the registration ID the store was created with. 34 | func (i *inMemStore) LocalRegistrationID(_ context.Context) uint32 { 35 | return i.registrationID 36 | } 37 | 38 | // Load returns the identity key stored for address and whether one exists. 39 | func (i *inMemStore) Load(_ context.Context, address address.Address) (Key, bool, error) { 40 | key, exists := i.knownKeys[address] 41 | return key, exists, nil 42 | } 43 | 44 | // Store records identity for address and reports whether it replaced a 45 | // different, previously stored key. 46 | func (i *inMemStore) Store(_ context.Context, address address.Address, identity Key) (bool, error) { 47 | knownIdentity, exists := i.knownKeys[address] 48 | i.knownKeys[address] = identity 49 | 50 | // Compare with Equal, as IsTrustedIdentity does. != would compare the 51 | // curve.PublicKey fields with ==, which does not necessarily mean equal key bytes. 52 | return exists && !identity.Equal(knownIdentity), nil 53 | } 54 | 55 | // Clear removes every stored identity. The old map is simply dropped for the 56 | // garbage collector; no background goroutine is needed to empty it. 57 | func (i *inMemStore) Clear() error { 58 | i.knownKeys = make(map[address.Address]Key) 59 | return nil 60 | } 61 | 62 | // IsTrustedIdentity trusts any key for an address with no stored entry; 63 | // otherwise identity must equal the stored key. The direction is ignored. 64 | func (i *inMemStore) IsTrustedIdentity(_ context.Context, address address.Address, identity Key, _ direction.Direction) (bool, error) { 65 | knownIdentity, exists := i.knownKeys[address] 66 | if !exists { 67 | return true, nil 68 | } 69 | 70 | return identity.Equal(knownIdentity), nil 71 | } 72 | -------------------------------------------------------------------------------- /protocol/internal/pointer/pointer.go: -------------------------------------------------------------------------------- 1 | // Package pointer implements utility functions for pointers. 2 | package pointer 3 | 4 | // To returns a pointer to a copy of t.
5 | func To[T any](t T) *T { 6 | return &t // t is the function's own copy, so the result never aliases the caller's variable 7 | } 8 | -------------------------------------------------------------------------------- /protocol/internal/tools/tools.go: -------------------------------------------------------------------------------- 1 | //go:build tools 2 | 3 | // Package tools records build-time tool dependencies (stringer, used by go:generate) so that go.mod tracks them; the tools build tag keeps it out of normal builds. 4 | package tools 5 | 6 | import _ "golang.org/x/tools/cmd/stringer" 7 | -------------------------------------------------------------------------------- /protocol/message/ciphertext.go: -------------------------------------------------------------------------------- 1 | // Package message defines protocol messages. 2 | package message 3 | 4 | //go:generate stringer -type=CiphertextType 5 | 6 | // CiphertextType identifies the kind of a protocol message; its String method is generated by stringer. 7 | type CiphertextType int 8 | 9 | const ( 10 | WhisperType CiphertextType = 2 11 | PreKeyType CiphertextType = 3 12 | SenderKeyType CiphertextType = 7 13 | PlaintextType CiphertextType = 8 14 | ) 15 | 16 | const ( 17 | // CiphertextVersion is the current version of ciphertext messages. 18 | CiphertextVersion = 3 19 | // SenderKeyVersion is the current version of sender-key messages. 20 | SenderKeyVersion = 3 21 | ) 22 | 23 | // Ciphertext is implemented by this package's serialized message types (for example PreKey, SenderKey and Plaintext). 24 | type Ciphertext interface { 25 | // Type is the CiphertextType of the message. 26 | Type() CiphertextType 27 | // Bytes returns an encoding of the Ciphertext message. 28 | Bytes() []byte 29 | } 30 | -------------------------------------------------------------------------------- /protocol/message/ciphertexttype_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type=CiphertextType"; DO NOT EDIT. 2 | 3 | package message 4 | 5 | import "strconv" 6 | 7 | func _() { 8 | // An "invalid array index" compiler error signifies that the constant values have changed. 9 | // Re-run the stringer command to generate them again.
10 | var x [1]struct{} 11 | _ = x[WhisperType-2] 12 | _ = x[PreKeyType-3] 13 | _ = x[SenderKeyType-7] 14 | _ = x[PlaintextType-8] 15 | } 16 | 17 | const ( 18 | _CiphertextType_name_0 = "WhisperTypePreKeyType" 19 | _CiphertextType_name_1 = "SenderKeyTypePlaintextType" 20 | ) 21 | 22 | var ( 23 | _CiphertextType_index_0 = [...]uint8{0, 11, 21} 24 | _CiphertextType_index_1 = [...]uint8{0, 13, 26} 25 | ) 26 | 27 | func (i CiphertextType) String() string { 28 | switch { 29 | case 2 <= i && i <= 3: 30 | i -= 2 31 | return _CiphertextType_name_0[_CiphertextType_index_0[i]:_CiphertextType_index_0[i+1]] 32 | case 7 <= i && i <= 8: 33 | i -= 7 34 | return _CiphertextType_name_1[_CiphertextType_index_1[i]:_CiphertextType_index_1[i+1]] 35 | default: 36 | return "CiphertextType(" + strconv.FormatInt(int64(i), 10) + ")" 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /protocol/message/mac.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | 7 | "github.com/RTann/libsignal-go/protocol/identity" 8 | "github.com/RTann/libsignal-go/protocol/perrors" 9 | ) 10 | 11 | const ( 12 | macKeySize = 32 // required HMAC key length in bytes 13 | macSize = 8 // number of HMAC-SHA256 output bytes kept as the MAC 14 | ) 15 | 16 | // mac computes HMAC-SHA256, keyed by macKey, over the sender identity key bytes, the 17 | // receiver identity key bytes and message, in that order, truncated to macSize bytes.
18 | func mac(macKey []byte, senderIdentityKey, receiverIdentityKey identity.Key, message []byte) ([]byte, error) { 19 | if len(macKey) != macKeySize { 20 | return nil, perrors.ErrInvalidKeyLength(macKeySize, len(macKey)) 21 | } 22 | 23 | hash := hmac.New(sha256.New, macKey) 24 | hash.Write(senderIdentityKey.Bytes()) 25 | hash.Write(receiverIdentityKey.Bytes()) 26 | hash.Write(message) 27 | 28 | m := make([]byte, 0, sha256.Size) // preallocated so Sum can append the full digest without growing 29 | return hash.Sum(m)[:macSize], nil // keep only the first macSize bytes 30 | } 31 | -------------------------------------------------------------------------------- /protocol/message/plaintext.go: -------------------------------------------------------------------------------- 1 | // TODO: This is incomplete. 2 | 3 | package message 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | ) 9 | 10 | const ( 11 | plaintextContextIdentifier = 0xC0 12 | paddingBoundary = 0x80 // NOTE(review): not referenced in this file 13 | ) 14 | 15 | var _ Ciphertext = (*Plaintext)(nil) 16 | 17 | // Plaintext represents a plaintext message. 18 | type Plaintext struct { 19 | serialized []byte 20 | } 21 | 22 | func NewPlaintextFromBytes(bytes []byte) (*Plaintext, error) { // NewPlaintextFromBytes wraps bytes without copying, after checking that the first byte is plaintextContextIdentifier 23 | if len(bytes) == 0 { 24 | return nil, errors.New("message too short") 25 | } 26 | 27 | if bytes[0] != plaintextContextIdentifier { 28 | return nil, fmt.Errorf("unsupported message version: %d != %d", uint32(bytes[0]), uint32(plaintextContextIdentifier)) 29 | } 30 | 31 | return &Plaintext{ 32 | serialized: bytes, 33 | }, nil 34 | } 35 | 36 | func (*Plaintext) Type() CiphertextType { 37 | return PlaintextType 38 | } 39 | 40 | func (p *Plaintext) Bytes() []byte { 41 | return p.serialized 42 | } 43 | 44 | func (p *Plaintext) Message() []byte { // Message returns the payload that follows the one-byte context identifier 45 | return p.serialized[1:] 46 | } 47 | -------------------------------------------------------------------------------- /protocol/message/prekey.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | 8 |
"google.golang.org/protobuf/proto" 9 | 10 | "github.com/RTann/libsignal-go/protocol/curve" 11 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 12 | "github.com/RTann/libsignal-go/protocol/identity" 13 | "github.com/RTann/libsignal-go/protocol/internal/pointer" 14 | "github.com/RTann/libsignal-go/protocol/prekey" 15 | ) 16 | 17 | var _ Ciphertext = (*PreKey)(nil) 18 | 19 | // PreKey represents a pre-key message. 20 | type PreKey struct { 21 | version uint8 22 | registrationID uint32 23 | preKeyID *prekey.ID 24 | signedPreKeyID prekey.ID 25 | baseKey curve.PublicKey 26 | identityKey identity.Key 27 | message *Signal 28 | serialized []byte 29 | } 30 | 31 | // PreKeyConfig represents the configuration for a PreKey message. 32 | type PreKeyConfig struct { 33 | Version uint8 34 | RegistrationID uint32 35 | PreKeyID *prekey.ID // optional one-time pre-key ID; nil leaves it out of the encoding 36 | SignedPreKeyID prekey.ID 37 | BaseKey curve.PublicKey 38 | IdentityKey identity.Key 39 | Message *Signal 40 | } 41 | 42 | func NewPreKey(cfg PreKeyConfig) (Ciphertext, error) { // NewPreKey serializes cfg as one version byte followed by a PreKeySignalMessage protobuf 43 | message, err := proto.Marshal(&v1.PreKeySignalMessage{ 44 | RegistrationId: &cfg.RegistrationID, 45 | PreKeyId: (*uint32)(cfg.PreKeyID), 46 | SignedPreKeyId: (*uint32)(&cfg.SignedPreKeyID), 47 | BaseKey: cfg.BaseKey.Bytes(), 48 | IdentityKey: cfg.IdentityKey.Bytes(), 49 | Message: cfg.Message.Bytes(), 50 | }) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | versionPrefix := ((cfg.Version & 0xF) << 4) | CiphertextVersion // high nibble: cfg.Version; low nibble: CiphertextVersion 56 | 57 | serialized := bytes.NewBuffer(make([]byte, 0, 1+len(message))) 58 | serialized.WriteByte(versionPrefix) 59 | serialized.Write(message) 60 | 61 | return &PreKey{ 62 | version: cfg.Version, 63 | registrationID: cfg.RegistrationID, 64 | preKeyID: cfg.PreKeyID, 65 | signedPreKeyID: cfg.SignedPreKeyID, 66 | baseKey: cfg.BaseKey, 67 | identityKey: cfg.IdentityKey, 68 | message: cfg.Message, 69 | serialized: serialized.Bytes(), 70 | }, nil 71 | } 72 | 73 | func NewPreKeyFromBytes(bytes []byte) (Ciphertext, error) { // NewPreKeyFromBytes parses the encoding produced by NewPreKey and retains bytes without copying 74 | if
len(bytes) == 0 { 75 | return nil, errors.New("message too short") 76 | } 77 | 78 | version := bytes[0] >> 4 // the high nibble carries the Version written by NewPreKey 79 | if int(version) != CiphertextVersion { 80 | return nil, fmt.Errorf("unsupported message version: %d != %d", int(version), CiphertextVersion) 81 | } 82 | 83 | var message v1.PreKeySignalMessage 84 | err := proto.Unmarshal(bytes[1:], &message) 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | baseKey, err := curve.NewPublicKey(message.GetBaseKey()) 90 | if err != nil { 91 | return nil, err 92 | } 93 | identityKey, err := identity.NewKey(message.GetIdentityKey()) 94 | if err != nil { 95 | return nil, err 96 | } 97 | msg, err := NewSignalFromBytes(message.GetMessage()) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | var preKeyID *prekey.ID 103 | if message.PreKeyId != nil { 104 | preKeyID = pointer.To(prekey.ID(message.GetPreKeyId())) 105 | } 106 | 107 | return &PreKey{ 108 | version: version, 109 | registrationID: message.GetRegistrationId(), 110 | preKeyID: preKeyID, 111 | signedPreKeyID: prekey.ID(message.GetSignedPreKeyId()), 112 | baseKey: baseKey, 113 | identityKey: identityKey, 114 | message: msg.(*Signal), // NOTE(review): unchecked type assertion; panics if NewSignalFromBytes returns another Ciphertext type 115 | serialized: bytes, 116 | }, nil 117 | } 118 | 119 | func (*PreKey) Type() CiphertextType { 120 | return PreKeyType 121 | } 122 | 123 | func (p *PreKey) Bytes() []byte { 124 | return p.serialized 125 | } 126 | 127 | func (p *PreKey) Version() uint8 { 128 | return p.version 129 | } 130 | 131 | func (p *PreKey) RegistrationID() uint32 { 132 | return p.registrationID 133 | } 134 | 135 | func (p *PreKey) PreKeyID() *prekey.ID { 136 | return p.preKeyID 137 | } 138 | 139 | func (p *PreKey) SignedPreKeyID() prekey.ID { 140 | return p.signedPreKeyID 141 | } 142 | 143 | func (p *PreKey) BaseKey() curve.PublicKey { 144 | return p.baseKey 145 | } 146 | 147 | func (p *PreKey) IdentityKey() identity.Key { 148 | return p.identityKey 149 | } 150 | 151 | func (p *PreKey) Message() *Signal { 152 | return p.message 153 | } 154 |
-------------------------------------------------------------------------------- /protocol/message/prekey_test.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/RTann/libsignal-go/protocol/curve" 11 | "github.com/RTann/libsignal-go/protocol/identity" 12 | ) 13 | 14 | func TestPreKey(t *testing.T) { // round-trips a message through NewPreKey and NewPreKeyFromBytes 15 | random := rand.Reader 16 | 17 | identityKeyPair, err := identity.GenerateKeyPair(random) 18 | require.NoError(t, err) 19 | baseKeyPair, err := curve.GenerateKeyPair(random) 20 | require.NoError(t, err) 21 | signalMsg := testSignalMsg(t) 22 | 23 | preKey1, err := NewPreKey(PreKeyConfig{ 24 | Version: 3, 25 | RegistrationID: 365, 26 | PreKeyID: nil, // exercises the optional-ID path 27 | SignedPreKeyID: 97, 28 | BaseKey: baseKeyPair.PublicKey(), 29 | IdentityKey: identityKeyPair.IdentityKey(), 30 | Message: signalMsg, 31 | }) 32 | assert.NoError(t, err) 33 | 34 | preKey2, err := NewPreKeyFromBytes(preKey1.Bytes()) 35 | assert.NoError(t, err) 36 | 37 | msg1, msg2 := preKey1.(*PreKey), preKey2.(*PreKey) 38 | 39 | assert.Equal(t, msg1.version, msg2.version) 40 | assert.Equal(t, msg1.registrationID, msg2.registrationID) 41 | assert.Equal(t, msg1.preKeyID, msg2.preKeyID) 42 | assert.Equal(t, msg1.signedPreKeyID, msg2.signedPreKeyID) 43 | assert.True(t, msg1.baseKey.Equal(msg2.baseKey)) 44 | assert.True(t, msg1.identityKey.Equal(msg2.identityKey)) 45 | assertSignalEquals(t, msg1.message, msg2.message) 46 | assert.Equal(t, msg1.serialized, msg2.serialized) 47 | } 48 | -------------------------------------------------------------------------------- /protocol/message/senderkey.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "io" 8 | 9 | "google.golang.org/protobuf/proto" 10 | 11 |
"github.com/RTann/libsignal-go/protocol/curve" 12 | "github.com/RTann/libsignal-go/protocol/distribution" 13 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 14 | "github.com/RTann/libsignal-go/protocol/perrors" 15 | ) 16 | 17 | var _ Ciphertext = (*SenderKey)(nil) 18 | 19 | // SenderKey represents a sender key message 20 | // for group messaging. 21 | type SenderKey struct { 22 | version uint8 23 | distID distribution.ID 24 | chainID uint32 25 | iteration uint32 26 | ciphertext []byte 27 | serialized []byte 28 | } 29 | 30 | // SenderKeyConfig represents the configuration for a SenderKey message. 31 | type SenderKeyConfig struct { 32 | Version uint8 33 | DistID distribution.ID 34 | ChainID uint32 35 | Iteration uint32 36 | Ciphertext []byte 37 | SignatureKey curve.PrivateKey 38 | } 39 | 40 | func NewSenderKey(random io.Reader, cfg SenderKeyConfig) (*SenderKey, error) { // NewSenderKey serializes cfg and appends a signature made with cfg.SignatureKey 41 | message, err := proto.Marshal(&v1.SenderKeyMessage{ 42 | DistributionUuid: []byte(cfg.DistID.String()), // NOTE(review): textual UUID form; confirm peers do not expect the 16 raw UUID bytes 43 | ChainId: &cfg.ChainID, 44 | Iteration: &cfg.Iteration, 45 | Ciphertext: cfg.Ciphertext, 46 | }) 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | versionPrefix := ((cfg.Version & 0xF) << 4) | SenderKeyVersion // high nibble: cfg.Version; low nibble: SenderKeyVersion 52 | 53 | serialized := bytes.NewBuffer(make([]byte, 0, 1+len(message)+curve.SignatureSize)) 54 | serialized.WriteByte(versionPrefix) 55 | serialized.Write(message) 56 | 57 | signature, err := cfg.SignatureKey.Sign(random, serialized.Bytes()) // signs the version byte plus protobuf; the signature is appended last 58 | if err != nil { 59 | return nil, err 60 | } 61 | serialized.Write(signature) 62 | 63 | return &SenderKey{ 64 | version: SenderKeyVersion, // NOTE(review): ignores cfg.Version, unlike NewPreKey and NewSenderKeyDistribution 65 | distID: cfg.DistID, 66 | chainID: cfg.ChainID, 67 | iteration: cfg.Iteration, 68 | ciphertext: cfg.Ciphertext, 69 | serialized: serialized.Bytes(), 70 | }, nil 71 | } 72 | 73 | func NewSenderKeyFromBytes(bytes []byte) (*SenderKey, error) { // NewSenderKeyFromBytes parses without verifying the signature; call VerifySignature for that 74 | if len(bytes) < 1+curve.SignatureSize { 75 | return nil, errors.New("message too short") 76 | } 77 | 78 | version := bytes[0] >> 4 79 | if
int(version) != SenderKeyVersion { 80 | return nil, fmt.Errorf("unsupported message version: %d != %d", int(version), SenderKeyVersion) 81 | } 82 | 83 | var message v1.SenderKeyMessage 84 | err := proto.Unmarshal(bytes[1:len(bytes)-curve.SignatureSize], &message) // the last curve.SignatureSize bytes are the signature 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | distID, err := distribution.ParseBytes(message.GetDistributionUuid()) 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | return &SenderKey{ 95 | version: version, 96 | distID: distID, 97 | chainID: message.GetChainId(), 98 | iteration: message.GetIteration(), 99 | ciphertext: message.GetCiphertext(), 100 | serialized: bytes, 101 | }, nil 102 | } 103 | 104 | func (*SenderKey) Type() CiphertextType { 105 | return SenderKeyType 106 | } 107 | 108 | func (s *SenderKey) Bytes() []byte { 109 | return s.serialized 110 | } 111 | 112 | func (s *SenderKey) Version() uint8 { 113 | return s.version 114 | } 115 | 116 | func (s *SenderKey) DistributionID() distribution.ID { 117 | return s.distID 118 | } 119 | 120 | func (s *SenderKey) ChainID() uint32 { 121 | return s.chainID 122 | } 123 | 124 | func (s *SenderKey) Iteration() uint32 { 125 | return s.iteration 126 | } 127 | 128 | func (s *SenderKey) Message() []byte { 129 | return s.ciphertext 130 | } 131 | 132 | // VerifySignature reports whether the trailing curve.SignatureSize bytes are a valid 133 | // signature by signatureKey over everything that precedes them. 134 | func (s *SenderKey) VerifySignature(signatureKey curve.PublicKey) (bool, error) { 135 | idx := len(s.serialized) - curve.SignatureSize 136 | return signatureKey.VerifySignature(s.serialized[idx:], s.serialized[:idx]) 137 | } 138 | 139 | // SenderKeyDistribution represents a sender key distribution 140 | // message for group messaging.
141 | type SenderKeyDistribution struct { 142 | version uint8 143 | distID distribution.ID 144 | chainID uint32 145 | iteration uint32 146 | chainKey []byte 147 | signingKey curve.PublicKey 148 | serialized []byte 149 | } 150 | 151 | // SenderKeyDistConfig represents the configuration for 152 | // a SenderKeyDistribution message. 153 | type SenderKeyDistConfig struct { 154 | Version uint8 155 | DistID distribution.ID 156 | ChainID uint32 157 | Iteration uint32 158 | ChainKey []byte 159 | SigningKey curve.PublicKey 160 | } 161 | 162 | func NewSenderKeyDistribution(cfg SenderKeyDistConfig) (*SenderKeyDistribution, error) { // NewSenderKeyDistribution serializes cfg as one version byte plus a protobuf; no signature is added 163 | message, err := proto.Marshal(&v1.SenderKeyDistributionMessage{ 164 | DistributionUuid: []byte(cfg.DistID.String()), // NOTE(review): textual UUID form; confirm peers do not expect the 16 raw UUID bytes 165 | ChainId: &cfg.ChainID, 166 | Iteration: &cfg.Iteration, 167 | ChainKey: cfg.ChainKey, 168 | SigningKey: cfg.SigningKey.Bytes(), 169 | }) 170 | if err != nil { 171 | return nil, err 172 | } 173 | 174 | versionPrefix := ((cfg.Version & 0xF) << 4) | SenderKeyVersion 175 | 176 | serialized := bytes.NewBuffer(make([]byte, 0, 1+len(message))) 177 | serialized.WriteByte(versionPrefix) 178 | serialized.Write(message) 179 | 180 | return &SenderKeyDistribution{ 181 | version: cfg.Version, 182 | distID: cfg.DistID, 183 | chainID: cfg.ChainID, 184 | iteration: cfg.Iteration, 185 | chainKey: cfg.ChainKey, 186 | signingKey: cfg.SigningKey, 187 | serialized: serialized.Bytes(), 188 | }, nil 189 | } 190 | 191 | func NewSenderKeyDistributionFromBytes(bytes []byte) (*SenderKeyDistribution, error) { // NewSenderKeyDistributionFromBytes parses the encoding produced by NewSenderKeyDistribution 192 | // Cheap length floor: one version byte plus 32 bytes each for the signing key and chain key; the chain key length is checked exactly below.
193 | if len(bytes) < 1+32+32 { 194 | return nil, errors.New("message too short") 195 | } 196 | 197 | messageVersion := bytes[0] >> 4 198 | if messageVersion != SenderKeyVersion { 199 | return nil, fmt.Errorf("unsupported message version: %d != %d", int(messageVersion), SenderKeyVersion) 200 | } 201 | 202 | var message v1.SenderKeyDistributionMessage 203 | err := proto.Unmarshal(bytes[1:], &message) 204 | if err != nil { 205 | return nil, err 206 | } 207 | 208 | distID, err := distribution.ParseBytes(message.GetDistributionUuid()) 209 | if err != nil { 210 | return nil, err 211 | } 212 | chainID := message.GetChainId() 213 | iteration := message.GetIteration() 214 | chainKey := message.GetChainKey() 215 | if len(chainKey) != 32 { // chain keys must be exactly 32 bytes 216 | return nil, perrors.ErrInvalidKeyLength(32, len(chainKey)) 217 | } 218 | signingKey, err := curve.NewPublicKey(message.GetSigningKey()) 219 | if err != nil { 220 | return nil, err 221 | } 222 | 223 | return &SenderKeyDistribution{ 224 | version: messageVersion, 225 | distID: distID, 226 | chainID: chainID, 227 | iteration: iteration, 228 | chainKey: chainKey, 229 | signingKey: signingKey, 230 | serialized: bytes, 231 | }, nil 232 | } 233 | 234 | func (s *SenderKeyDistribution) Bytes() []byte { 235 | return s.serialized 236 | } 237 | 238 | func (s *SenderKeyDistribution) Version() uint8 { 239 | return s.version 240 | } 241 | 242 | func (s *SenderKeyDistribution) DistributionID() distribution.ID { 243 | return s.distID 244 | } 245 | 246 | func (s *SenderKeyDistribution) ChainID() uint32 { 247 | return s.chainID 248 | } 249 | 250 | func (s *SenderKeyDistribution) Iteration() uint32 { 251 | return s.iteration 252 | } 253 | 254 | func (s *SenderKeyDistribution) ChainKey() []byte { 255 | return s.chainKey 256 | } 257 | 258 | func (s *SenderKeyDistribution) SigningKey() curve.PublicKey { 259 | return s.signingKey 260 | } 261 |
/protocol/message/senderkey_test.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/RTann/libsignal-go/protocol/curve" 11 | "github.com/RTann/libsignal-go/protocol/distribution" 12 | ) 13 | 14 | func TestSenderKey(t *testing.T) { 15 | signatureKeyPair, err := curve.GenerateKeyPair(rand.Reader) 16 | require.NoError(t, err) 17 | 18 | senderKey1, err := NewSenderKey(rand.Reader, SenderKeyConfig{ 19 | Version: SenderKeyVersion, 20 | DistID: distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6"), 21 | ChainID: 42, 22 | Iteration: 7, 23 | Ciphertext: []byte{1, 2, 3}, 24 | SignatureKey: signatureKeyPair.PrivateKey(), 25 | }) 26 | assert.NoError(t, err) 27 | 28 | senderKey2, err := NewSenderKeyFromBytes(senderKey1.Bytes()) 29 | assert.NoError(t, err) 30 | 31 | assert.Equal(t, senderKey1.version, senderKey2.version) 32 | assert.Equal(t, senderKey1.distID, senderKey2.distID) 33 | assert.Equal(t, senderKey1.chainID, senderKey2.chainID) 34 | assert.Equal(t, senderKey1.iteration, senderKey2.iteration) 35 | assert.Equal(t, senderKey1.ciphertext, senderKey2.ciphertext) 36 | assert.Equal(t, senderKey1.serialized, senderKey2.serialized) 37 | } 38 | -------------------------------------------------------------------------------- /protocol/message/signal.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "bytes" 5 | "crypto/hmac" 6 | "encoding/hex" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/golang/glog" 11 | "google.golang.org/protobuf/proto" 12 | 13 | "github.com/RTann/libsignal-go/protocol/curve" 14 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 15 | "github.com/RTann/libsignal-go/protocol/identity" 16 | ) 17 | 18 | var _ Ciphertext = (*Signal)(nil) 19 | 20 | // Signal represents 
a typical ciphertext message. 21 | type Signal struct { 22 | version uint8 23 | senderRatchetKey curve.PublicKey 24 | previousCounter uint32 25 | counter uint32 26 | ciphertext []byte 27 | serialized []byte 28 | } 29 | 30 | // SignalConfig represents the configuration for a Signal message. 31 | type SignalConfig struct { 32 | Version uint8 33 | MACKey []byte 34 | SenderRatchetKey curve.PublicKey 35 | PreviousCounter uint32 36 | Counter uint32 37 | Ciphertext []byte 38 | SenderIdentityKey identity.Key 39 | ReceiverIdentityKey identity.Key 40 | } 41 | 42 | func NewSignal(cfg SignalConfig) (Ciphertext, error) { 43 | message, err := proto.Marshal(&v1.SignalMessage{ 44 | RatchetKey: cfg.SenderRatchetKey.Bytes(), 45 | Counter: &cfg.Counter, 46 | PreviousCounter: &cfg.PreviousCounter, 47 | Ciphertext: cfg.Ciphertext, 48 | }) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | versionPrefix := ((cfg.Version & 0xF) << 4) | CiphertextVersion 54 | serialized := bytes.NewBuffer(make([]byte, 0, 1+len(message)+macSize)) 55 | serialized.WriteByte(versionPrefix) 56 | serialized.Write(message) 57 | 58 | mac, err := mac(cfg.MACKey, cfg.SenderIdentityKey, cfg.ReceiverIdentityKey, serialized.Bytes()) 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | serialized.Write(mac) 64 | 65 | return &Signal{ 66 | version: cfg.Version, 67 | senderRatchetKey: cfg.SenderRatchetKey, 68 | previousCounter: cfg.PreviousCounter, 69 | counter: cfg.Counter, 70 | ciphertext: cfg.Ciphertext, 71 | serialized: serialized.Bytes(), 72 | }, nil 73 | } 74 | 75 | func NewSignalFromBytes(bytes []byte) (Ciphertext, error) { 76 | if len(bytes) == 0 { 77 | return nil, errors.New("message too short") 78 | } 79 | 80 | version := bytes[0] >> 4 81 | if int(version) != CiphertextVersion { 82 | return nil, fmt.Errorf("unsupported message version: %d != %d", int(version), CiphertextVersion) 83 | } 84 | 85 | var message v1.SignalMessage 86 | err := proto.Unmarshal(bytes[1:len(bytes)-macSize], &message) 87 | if 
err != nil { 88 | return nil, err 89 | } 90 | 91 | senderRatchetKey, err := curve.NewPublicKey(message.GetRatchetKey()) 92 | if err != nil { 93 | return nil, err 94 | } 95 | 96 | return &Signal{ 97 | version: version, 98 | senderRatchetKey: senderRatchetKey, 99 | previousCounter: message.GetPreviousCounter(), 100 | counter: message.GetCounter(), 101 | ciphertext: message.GetCiphertext(), 102 | serialized: bytes, 103 | }, nil 104 | } 105 | 106 | func (*Signal) Type() CiphertextType { 107 | return WhisperType 108 | } 109 | 110 | func (s *Signal) Bytes() []byte { 111 | return s.serialized 112 | } 113 | 114 | func (s *Signal) Message() []byte { 115 | return s.ciphertext 116 | } 117 | 118 | func (s *Signal) Version() uint8 { 119 | return s.version 120 | } 121 | 122 | func (s *Signal) SenderRatchetKey() curve.PublicKey { 123 | return s.senderRatchetKey 124 | } 125 | 126 | func (s *Signal) Counter() uint32 { 127 | return s.counter 128 | } 129 | 130 | // VerifyMAC verifies the message authentication code (MAC) sent with the signal message 131 | // matches our computed MAC. 132 | // 133 | // The MAC is expected to be an HMAC. 134 | func (s *Signal) VerifyMAC(macKey []byte, senderIdentityKey, receiverIdentityKey identity.Key) (bool, error) { 135 | ourMAC, err := mac(macKey, senderIdentityKey, receiverIdentityKey, s.serialized[:len(s.serialized)-macSize]) 136 | if err != nil { 137 | return false, err 138 | } 139 | theirMAC := s.serialized[len(s.serialized)-macSize:] 140 | equal := hmac.Equal(ourMAC, theirMAC) 141 | if !equal { 142 | glog.Warningf("Bad Mac! 
Their Mac: %s Our Mac: %s", hex.EncodeToString(theirMAC), hex.EncodeToString(ourMAC)) 143 | } 144 | 145 | return equal, nil 146 | } 147 | -------------------------------------------------------------------------------- /protocol/message/signal_test.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "crypto/rand" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/RTann/libsignal-go/protocol/curve" 12 | "github.com/RTann/libsignal-go/protocol/identity" 13 | ) 14 | 15 | func testSignalMsg(t *testing.T) *Signal { 16 | random := rand.Reader 17 | 18 | macKey := make([]byte, 32) 19 | _, err := io.ReadFull(random, macKey) 20 | require.NoError(t, err) 21 | 22 | ciphertext := make([]byte, 20) 23 | _, err = io.ReadFull(random, ciphertext) 24 | require.NoError(t, err) 25 | 26 | senderRatchetKeyPair, err := curve.GenerateKeyPair(random) 27 | require.NoError(t, err) 28 | senderIdentityKeyPair, err := identity.GenerateKeyPair(random) 29 | require.NoError(t, err) 30 | receiverIdentityKeyPair, err := identity.GenerateKeyPair(random) 31 | require.NoError(t, err) 32 | 33 | signal, err := NewSignal(SignalConfig{ 34 | Version: 3, 35 | MACKey: macKey, 36 | SenderRatchetKey: senderRatchetKeyPair.PublicKey(), 37 | PreviousCounter: 41, 38 | Counter: 42, 39 | Ciphertext: ciphertext, 40 | SenderIdentityKey: senderIdentityKeyPair.IdentityKey(), 41 | ReceiverIdentityKey: receiverIdentityKeyPair.IdentityKey(), 42 | }) 43 | require.NoError(t, err) 44 | 45 | return signal.(*Signal) 46 | } 47 | 48 | func assertSignalEquals(t *testing.T, a, b *Signal) { 49 | assert.Equal(t, a.version, b.version) 50 | assert.True(t, a.senderRatchetKey.Equal(b.senderRatchetKey)) 51 | assert.Equal(t, a.counter, b.counter) 52 | assert.Equal(t, a.previousCounter, b.previousCounter) 53 | assert.Equal(t, a.ciphertext, b.ciphertext) 54 | assert.Equal(t, a.serialized, 
b.serialized) 55 | } 56 | 57 | func TestSignal(t *testing.T) { 58 | msg1 := testSignalMsg(t) 59 | msg2, err := NewSignalFromBytes(msg1.Bytes()) 60 | assert.NoError(t, err) 61 | assertSignalEquals(t, msg1, msg2.(*Signal)) 62 | } 63 | -------------------------------------------------------------------------------- /protocol/perrors/errors.go: -------------------------------------------------------------------------------- 1 | // Package perrors defines protocol errors. 2 | package perrors 3 | 4 | import ( 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/RTann/libsignal-go/protocol/address" 9 | ) 10 | 11 | var ( 12 | ErrDuplicateMessage = errors.New("duplicate message") 13 | ErrNoCurrentSession = errors.New("no current session") 14 | ) 15 | 16 | var ( 17 | _ error = (*errInvalidKeyLength)(nil) 18 | _ error = (*errSessionNotFound)(nil) 19 | _ error = (*errUntrustedIdentity)(nil) 20 | ) 21 | 22 | type errInvalidKeyLength struct { 23 | expected int 24 | got int 25 | } 26 | 27 | func ErrInvalidKeyLength(expected, got int) error { 28 | return errInvalidKeyLength{ 29 | expected: expected, 30 | got: got, 31 | } 32 | } 33 | 34 | func (e errInvalidKeyLength) Error() string { 35 | return fmt.Sprintf("invalid key length: %d != %d", e.got, e.expected) 36 | } 37 | 38 | type errSessionNotFound struct { 39 | remoteAddress address.Address 40 | } 41 | 42 | func ErrSessionNotFound(remoteAddress address.Address) error { 43 | return errSessionNotFound{ 44 | remoteAddress: remoteAddress, 45 | } 46 | } 47 | 48 | func (e errSessionNotFound) Error() string { 49 | return "session with " + e.remoteAddress.String() + " not found" 50 | } 51 | 52 | type errUntrustedIdentity struct { 53 | remoteAddress address.Address 54 | } 55 | 56 | func ErrUntrustedIdentity(remoteAddress address.Address) error { 57 | return errUntrustedIdentity{ 58 | remoteAddress: remoteAddress, 59 | } 60 | } 61 | 62 | func IsErrUntrustedIdentity(e error) bool { 63 | _, ok := e.(errUntrustedIdentity) 64 | return ok 65 | } 66 | 67 | 
// Error implements the error interface, naming the remote address whose
// identity is untrusted.
func (e errUntrustedIdentity) Error() string {
	return "untrusted identity for address " + e.remoteAddress.String()
}

--------------------------------------------------------------------------------
/protocol/prekey/bundle.go:
--------------------------------------------------------------------------------
package prekey

import (
	"github.com/RTann/libsignal-go/protocol/address"
	"github.com/RTann/libsignal-go/protocol/curve"
	"github.com/RTann/libsignal-go/protocol/identity"
)

// Bundle represents a pre-key bundle as defined by the X3DH protocol.
//
// See https://signal.org/docs/specifications/x3dh/#sending-the-initial-message for more information.
type Bundle struct {
	RegistrationID        uint32
	DeviceID              address.DeviceID
	// PreKeyID is a pointer, so the one-time pre-key can presumably be
	// omitted (nil) — confirm how PreKeyPublic is set in that case.
	PreKeyID              *ID
	PreKeyPublic          curve.PublicKey
	SignedPreKeyID        ID
	SignedPreKeyPublic    curve.PublicKey
	// SignedPreKeySignature is presumably the signature over
	// SignedPreKeyPublic made with IdentityKey — verify at the call site.
	SignedPreKeySignature []byte
	IdentityKey           identity.Key
}

--------------------------------------------------------------------------------
/protocol/prekey/prekey.go:
--------------------------------------------------------------------------------
// Package prekey defines a pre-key and signed pre-key.
package prekey

import (
	"strconv"

	"github.com/RTann/libsignal-go/protocol/curve"
	v1 "github.com/RTann/libsignal-go/protocol/generated/v1"
)

// ID represents a pre-key identifier.
type ID uint32

// String returns the identifier in base 10.
func (i ID) String() string {
	return strconv.FormatUint(uint64(i), 10)
}

// PreKey represents a pre-key record. As built by NewPreKey it holds the ID
// together with both the public and the private half of the key pair, so
// it must be treated as secret material.
type PreKey struct {
	preKey *v1.PreKeyRecordStructure
}

// NewPreKey creates a pre-key record with the given ID from key.
24 | func NewPreKey(id ID, key *curve.KeyPair) *PreKey { 25 | return &PreKey{ 26 | preKey: &v1.PreKeyRecordStructure{ 27 | Id: uint32(id), 28 | PublicKey: key.PublicKey().Bytes(), 29 | PrivateKey: key.PrivateKey().Bytes(), 30 | }, 31 | } 32 | } 33 | 34 | // KeyPair returns the pre-key's public/private key pair. 35 | func (s *PreKey) KeyPair() (*curve.KeyPair, error) { 36 | return curve.NewKeyPair(s.preKey.GetPrivateKey(), s.preKey.GetPublicKey()) 37 | } 38 | -------------------------------------------------------------------------------- /protocol/prekey/signed.go: -------------------------------------------------------------------------------- 1 | package prekey 2 | 3 | import ( 4 | "github.com/RTann/libsignal-go/protocol/curve" 5 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 6 | ) 7 | 8 | // SignedPreKey represents a public signed pre-key. 9 | type SignedPreKey struct { 10 | signed *v1.SignedPreKeyRecordStructure 11 | } 12 | 13 | // NewSigned creates a new signed pre-key. 14 | func NewSigned(id ID, timestamp uint64, key *curve.KeyPair, signature []byte) *SignedPreKey { 15 | return &SignedPreKey{ 16 | signed: &v1.SignedPreKeyRecordStructure{ 17 | Id: uint32(id), 18 | PublicKey: key.PublicKey().Bytes(), 19 | PrivateKey: key.PrivateKey().Bytes(), 20 | Signature: signature, 21 | Timestamp: timestamp, 22 | }, 23 | } 24 | } 25 | 26 | // KeyPair returns the signed pre-key's public/private key pair. 27 | func (s *SignedPreKey) KeyPair() (*curve.KeyPair, error) { 28 | return curve.NewKeyPair(s.signed.GetPrivateKey(), s.signed.GetPublicKey()) 29 | } 30 | -------------------------------------------------------------------------------- /protocol/prekey/store.go: -------------------------------------------------------------------------------- 1 | package prekey 2 | 3 | import "context" 4 | 5 | // Store defines a pre-key store. 6 | type Store interface { 7 | // Load fetches the pre-key associated with the id from the store. 
8 | Load(ctx context.Context, id ID) (*PreKey, bool, error) 9 | // Store stores the pre-key associated with the given ID in the store. 10 | Store(ctx context.Context, id ID, preKey *PreKey) error 11 | // Delete removes the pre-key entry identified by the given ID from the store. 12 | Delete(ctx context.Context, id ID) error 13 | } 14 | 15 | // SignedStore defines a signed pre-key store. 16 | type SignedStore interface { 17 | // Load fetches the signed pre-key associated with the id from the store. 18 | Load(ctx context.Context, id ID) (*SignedPreKey, bool, error) 19 | // Store stores the signed pre-key associated with the given ID in the store. 20 | Store(ctx context.Context, id ID, record *SignedPreKey) error 21 | } 22 | -------------------------------------------------------------------------------- /protocol/prekey/store_inmem.go: -------------------------------------------------------------------------------- 1 | package prekey 2 | 3 | import "context" 4 | 5 | var _ Store = (*inMemStore)(nil) 6 | 7 | // inMemStore represents an in-memory pre-key store. 8 | type inMemStore struct { 9 | preKeys map[ID]*PreKey 10 | } 11 | 12 | // NewInMemStore creates a new in-memory pre-key store. 13 | func NewInMemStore() Store { 14 | return &inMemStore{ 15 | preKeys: make(map[ID]*PreKey), 16 | } 17 | } 18 | 19 | func (i *inMemStore) Load(_ context.Context, id ID) (*PreKey, bool, error) { 20 | record, exists := i.preKeys[id] 21 | return record, exists, nil 22 | } 23 | 24 | func (i *inMemStore) Store(_ context.Context, id ID, record *PreKey) error { 25 | i.preKeys[id] = record 26 | return nil 27 | } 28 | 29 | func (i *inMemStore) Delete(_ context.Context, id ID) error { 30 | delete(i.preKeys, id) 31 | return nil 32 | } 33 | 34 | var _ SignedStore = (*inMemSignedStore)(nil) 35 | 36 | // inMemSignedStore represents an in-memory signed pre-key store. 
37 | type inMemSignedStore struct { 38 | signedPreKeys map[ID]*SignedPreKey 39 | } 40 | 41 | // NewInMemSignedStore creates a new in-memory signed pre-key store. 42 | func NewInMemSignedStore() SignedStore { 43 | return &inMemSignedStore{ 44 | signedPreKeys: make(map[ID]*SignedPreKey), 45 | } 46 | } 47 | 48 | func (i *inMemSignedStore) Load(_ context.Context, id ID) (*SignedPreKey, bool, error) { 49 | record, exists := i.signedPreKeys[id] 50 | return record, exists, nil 51 | } 52 | 53 | func (i *inMemSignedStore) Store(_ context.Context, id ID, record *SignedPreKey) error { 54 | i.signedPreKeys[id] = record 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /protocol/proto/v1/fingerprint.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | // 4 | // Copyright 2020 Signal Messenger, LLC. 5 | // SPDX-License-Identifier: AGPL-3.0-only 6 | // 7 | 8 | package signal.proto.fingerprint; 9 | 10 | option go_package = "/v1"; 11 | 12 | message LogicalFingerprint { 13 | optional bytes content = 1; 14 | // bytes identifier = 2; 15 | } 16 | 17 | message CombinedFingerprints { 18 | optional uint32 version = 1; 19 | optional LogicalFingerprint local_fingerprint = 2; 20 | optional LogicalFingerprint remote_fingerprint = 3; 21 | } 22 | -------------------------------------------------------------------------------- /protocol/proto/v1/sealed_sender.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | // 4 | // Copyright 2020 Signal Messenger, LLC. 
// SPDX-License-Identifier: AGPL-3.0-only
//

package signal.proto.sealed_sender;

option go_package = "/v1";

// Field numbers, not declaration order, define the wire encoding.
// Renumbering or retyping a field breaks compatibility with existing data.

message ServerCertificate {
  message Certificate {
    optional uint32 id = 1;
    optional bytes key = 2;
  }

  // Presumably the serialized Certificate above, with signature computed
  // over those bytes — confirm against the verifying code.
  optional bytes certificate = 1;
  optional bytes signature = 2;
}

message SenderCertificate {
  message Certificate {
    optional string senderE164 = 1;
    // senderUuid is declared second but uses field number 6.
    optional string senderUuid = 6;
    optional uint32 senderDevice = 2;
    optional fixed64 expires = 3;
    optional bytes identityKey = 4;
    optional ServerCertificate signer = 5;
  }

  // Presumably the serialized Certificate above, with signature computed
  // over those bytes — confirm against the verifying code.
  optional bytes certificate = 1;
  optional bytes signature = 2;
}

message UnidentifiedSenderMessage {

  message Message {
    enum Type {
      PREKEY_MESSAGE = 1;
      MESSAGE = 2;
      // Further cases should line up with Envelope.Type, even though old cases don't.
      reserved 3 to 6;
      SENDERKEY_MESSAGE = 7;
      PLAINTEXT_CONTENT = 8;
    }

    enum ContentHint {
      reserved 0; // Default: sender will not resend; an error should be shown immediately
      RESENDABLE = 1; // Sender will try to resend; delay any error UI if possible
      IMPLICIT = 2; // Don't show any error UI at all; this is something sent implicitly like a typing message or a receipt
    }

    optional Type type = 1;
    // Serialized SenderCertificate, as the inline type note says.
    optional bytes /*SenderCertificate*/ senderCertificate = 2;
    optional bytes content = 3;
    optional ContentHint contentHint = 4;
    optional bytes groupId = 5;
  }

  optional bytes ephemeralPublic = 1;
  optional bytes encryptedStatic = 2;
  optional bytes encryptedMessage = 3;
}

--------------------------------------------------------------------------------
/protocol/proto/v1/storage.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

//
// Copyright 2020 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

package signal.proto.storage;

option go_package = "/v1";

// Persisted records. Field numbers, not declaration order, define the
// encoding; renumbering or retyping a field breaks previously stored data.

message SessionStructure {
  message Chain {
    bytes sender_ratchet_key = 1;
    bytes sender_ratchet_key_private = 2;

    message ChainKey {
      uint32 index = 1;
      bytes key = 2;
    }

    ChainKey chain_key = 3;

    // Stored message keys; the fields correspond to those of
    // ratchet.MessageKeys (cipher key, MAC key, IV, counter/index).
    message MessageKey {
      uint32 index = 1;
      bytes cipher_key = 2;
      bytes mac_key = 3;
      bytes iv = 4;
    }

    repeated MessageKey message_keys = 4;
  }

  // NOTE(review): signed_pre_key_id is int32 while pre_key_id is uint32;
  // presumably inherited from upstream — confirm before changing, since
  // retyping a field affects compatibility with stored records.
  message PendingPreKey {
    uint32 pre_key_id = 1;
    int32 signed_pre_key_id = 3;
    bytes base_key = 2;
  }

  uint32 session_version = 1;
  bytes local_identity_public = 2;
  bytes remote_identity_public = 3;

  bytes root_key = 4;
  uint32 previous_counter = 5;

  Chain sender_chain = 6;
  // The order is significant; keys at the end are "older" and will get trimmed.
  repeated Chain receiver_chains = 7;

  PendingPreKey pending_pre_key = 9;

  uint32 remote_registration_id = 10;
  uint32 local_registration_id = 11;

  reserved 12; // no longer used
  bytes alice_base_key = 13;
}

message RecordStructure {
  SessionStructure current_session = 1;
  // The order is significant; sessions at the end are "older" and will get trimmed.
  repeated /*SessionStructure*/ bytes previous_sessions = 2;
}

// Backing record of prekey.PreKey; NewPreKey fills all three fields.
message PreKeyRecordStructure {
  uint32 id = 1;
  bytes public_key = 2;
  bytes private_key = 3;
}

// Backing record of prekey.SignedPreKey; NewSigned fills all five fields.
message SignedPreKeyRecordStructure {
  uint32 id = 1;
  bytes public_key = 2;
  bytes private_key = 3;
  bytes signature = 4;
  fixed64 timestamp = 5;
}

message IdentityKeyPairStructure {
  bytes public_key = 1;
  bytes private_key = 2;
}

message SenderKeyStateStructure {
  message SenderChainKey {
    uint32 iteration = 1;
    bytes seed = 2;
  }

  message SenderMessageKey {
    uint32 iteration = 1;
    bytes seed = 2;
  }

  message SenderSigningKey {
    bytes public = 1;
    bytes private = 2;
  }

  // message_version is declared first but uses field number 5.
  uint32 message_version = 5;
  uint32 chain_id = 1;
  SenderChainKey sender_chain_key = 2;
  SenderSigningKey sender_signing_key = 3;
  repeated SenderMessageKey sender_message_keys = 4;
}

message SenderKeyRecordStructure {
  repeated SenderKeyStateStructure sender_key_states = 1;
}

--------------------------------------------------------------------------------
/protocol/proto/v1/wire.proto:
--------------------------------------------------------------------------------
syntax = "proto2";

//
// Copyright 2020 Signal Messenger, LLC.
5 | // SPDX-License-Identifier: AGPL-3.0-only 6 | // 7 | 8 | package signal.proto.wire; 9 | 10 | option go_package = "/v1"; 11 | 12 | message SignalMessage { 13 | optional bytes ratchet_key = 1; 14 | optional uint32 counter = 2; 15 | optional uint32 previous_counter = 3; 16 | optional bytes ciphertext = 4; 17 | } 18 | 19 | message PreKeySignalMessage { 20 | optional uint32 registration_id = 5; 21 | optional uint32 pre_key_id = 1; 22 | optional uint32 signed_pre_key_id = 6; 23 | optional bytes base_key = 2; 24 | optional bytes identity_key = 3; 25 | optional bytes message = 4; // SignalMessage 26 | } 27 | 28 | message SenderKeyMessage { 29 | optional bytes distribution_uuid = 1; 30 | optional uint32 chain_id = 2; 31 | optional uint32 iteration = 3; 32 | optional bytes ciphertext = 4; 33 | } 34 | 35 | message SenderKeyDistributionMessage { 36 | optional bytes distribution_uuid = 1; 37 | optional uint32 chain_id = 2; 38 | optional uint32 iteration = 3; 39 | optional bytes chain_key = 4; 40 | optional bytes signing_key = 5; 41 | } 42 | -------------------------------------------------------------------------------- /protocol/protocol/store.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "github.com/RTann/libsignal-go/protocol/identity" 5 | "github.com/RTann/libsignal-go/protocol/prekey" 6 | "github.com/RTann/libsignal-go/protocol/session" 7 | ) 8 | 9 | type Store interface { 10 | SessionStore() session.Store 11 | IdentityStore() identity.Store 12 | PreKeyStore() prekey.Store 13 | SignedPreKeyStore() prekey.SignedStore 14 | GroupStore() session.GroupStore 15 | } 16 | -------------------------------------------------------------------------------- /protocol/protocol/store_inmem.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "github.com/RTann/libsignal-go/protocol/identity" 5 | 
"github.com/RTann/libsignal-go/protocol/prekey" 6 | "github.com/RTann/libsignal-go/protocol/session" 7 | ) 8 | 9 | type inMemSignalProtocolStore struct { 10 | sessionStore session.Store 11 | preKeyStore prekey.Store 12 | signedPreKeyStore prekey.SignedStore 13 | identityStore identity.Store 14 | groupStore session.GroupStore 15 | } 16 | 17 | func NewInMemStore(keyPair identity.KeyPair, registrationID uint32) Store { 18 | return &inMemSignalProtocolStore{ 19 | sessionStore: session.NewInMemStore(), 20 | preKeyStore: prekey.NewInMemStore(), 21 | signedPreKeyStore: prekey.NewInMemSignedStore(), 22 | identityStore: identity.NewInMemStore(keyPair, registrationID), 23 | groupStore: session.NewInMemGroupStore(), 24 | } 25 | } 26 | 27 | func (i *inMemSignalProtocolStore) SessionStore() session.Store { 28 | return i.sessionStore 29 | } 30 | 31 | func (i *inMemSignalProtocolStore) IdentityStore() identity.Store { 32 | return i.identityStore 33 | } 34 | 35 | func (i *inMemSignalProtocolStore) PreKeyStore() prekey.Store { 36 | return i.preKeyStore 37 | } 38 | 39 | func (i *inMemSignalProtocolStore) SignedPreKeyStore() prekey.SignedStore { 40 | return i.signedPreKeyStore 41 | } 42 | 43 | func (i *inMemSignalProtocolStore) GroupStore() session.GroupStore { 44 | return i.groupStore 45 | } 46 | -------------------------------------------------------------------------------- /protocol/ratchet/keys.go: -------------------------------------------------------------------------------- 1 | package ratchet 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "io" 7 | 8 | "golang.org/x/crypto/hkdf" 9 | 10 | "github.com/RTann/libsignal-go/protocol/curve" 11 | "github.com/RTann/libsignal-go/protocol/perrors" 12 | ) 13 | 14 | const ( 15 | cipherKeySize = 32 16 | macKeySize = 32 17 | ivSize = 16 18 | messageKeysInfo = "WhisperMessageKeys" 19 | 20 | ChainKeySize = 32 21 | 22 | RootKeySize = 32 23 | rootInfo = "WhisperRatchet" 24 | ) 25 | 26 | var ( 27 | messageKeySeed = []byte{0x01} 28 | 
chainKeySeed = []byte{0x02} 29 | ) 30 | 31 | // MessageKeys defines the keys used to encrypt messages. 32 | type MessageKeys struct { 33 | cipherKey []byte 34 | macKey []byte 35 | // initialization vector 36 | iv []byte 37 | counter uint32 38 | } 39 | 40 | // NewMessageKeys derives message keys from the given inputs. 41 | func NewMessageKeys(cipherKey, macKey, iv []byte, counter uint32) (MessageKeys, error) { 42 | if len(cipherKey) != cipherKeySize { 43 | return MessageKeys{}, perrors.ErrInvalidKeyLength(cipherKeySize, len(cipherKey)) 44 | } 45 | if len(macKey) != macKeySize { 46 | return MessageKeys{}, perrors.ErrInvalidKeyLength(macKeySize, len(macKey)) 47 | } 48 | if len(iv) != ivSize { 49 | return MessageKeys{}, perrors.ErrInvalidKeyLength(ivSize, len(iv)) 50 | } 51 | 52 | return MessageKeys{ 53 | cipherKey: cipherKey, 54 | macKey: macKey, 55 | iv: iv, 56 | counter: counter, 57 | }, nil 58 | } 59 | 60 | // DeriveMessageKeys derives message keys from the given input material and counter. 61 | // 62 | // The input material is used as the secret for HKDF. 63 | func DeriveMessageKeys(inputKeyMaterial []byte, counter uint32) (MessageKeys, error) { 64 | kdf := hkdf.New(sha256.New, inputKeyMaterial, nil, []byte(messageKeysInfo)) 65 | // 32 + 32 + 16 = 80 66 | outputKeyMaterial := make([]byte, 80) 67 | _, err := io.ReadFull(kdf, outputKeyMaterial) 68 | if err != nil { 69 | return MessageKeys{}, err 70 | } 71 | 72 | return MessageKeys{ 73 | cipherKey: outputKeyMaterial[:32], 74 | macKey: outputKeyMaterial[32:64], 75 | iv: outputKeyMaterial[64:], 76 | counter: counter, 77 | }, nil 78 | } 79 | 80 | // CipherKey returns a block cipher key. 81 | func (m MessageKeys) CipherKey() []byte { 82 | return m.cipherKey 83 | } 84 | 85 | // MACKey returns a key used for a MAC function like HMAC. 86 | func (m MessageKeys) MACKey() []byte { 87 | return m.macKey 88 | } 89 | 90 | // IV returns an initialization vector used 91 | // for encryption and decryption. 
92 | func (m MessageKeys) IV() []byte { 93 | return m.iv 94 | } 95 | 96 | // Counter returns the corresponding index in the chain. 97 | func (m MessageKeys) Counter() uint32 { 98 | return m.counter 99 | } 100 | 101 | // ChainKey represents a sending or receiving chain key 102 | // used for Symmetric-key ratchet. 103 | type ChainKey struct { 104 | key []byte 105 | index uint32 106 | } 107 | 108 | // NewChainKey derives a chain key from the given key and index. 109 | func NewChainKey(key []byte, index uint32) (ChainKey, error) { 110 | if len(key) != ChainKeySize { 111 | return ChainKey{}, perrors.ErrInvalidKeyLength(ChainKeySize, len(key)) 112 | } 113 | 114 | return ChainKey{ 115 | key: key, 116 | index: index, 117 | }, nil 118 | } 119 | 120 | // Index returns the index of chain key in the 121 | // sending or receiving chain. 122 | func (c ChainKey) Index() uint32 { 123 | return c.index 124 | } 125 | 126 | // Key returns an encoding of the chain key. 127 | func (c ChainKey) Key() []byte { 128 | return c.key 129 | } 130 | 131 | // Next derives the next chain key in the 132 | // sending or receiving chain. 133 | func (c ChainKey) Next() ChainKey { 134 | return ChainKey{ 135 | key: hash(c.key, chainKeySeed), 136 | index: c.index + 1, 137 | } 138 | } 139 | 140 | // MessageKeys performs a Symmetric-key ratchet step 141 | // to derive new message keys. 142 | func (c ChainKey) MessageKeys() (MessageKeys, error) { 143 | return DeriveMessageKeys(hash(c.key, messageKeySeed), c.index) 144 | } 145 | 146 | // hash returns the HMAC hash of the seed using the given key. 147 | func hash(key, seed []byte) []byte { 148 | buf := make([]byte, 0, ChainKeySize) 149 | hash := hmac.New(sha256.New, key) 150 | hash.Write(seed) 151 | buf = hash.Sum(buf) 152 | return buf 153 | } 154 | 155 | // RootKey is a key used for the root chain in the Double Ratchet algorithm. 156 | type RootKey struct { 157 | key []byte 158 | } 159 | 160 | // NewRootKey derives a root key from the given bytes. 
161 | func NewRootKey(key []byte) (RootKey, error) { 162 | if len(key) != RootKeySize { 163 | return RootKey{}, perrors.ErrInvalidKeyLength(RootKeySize, len(key)) 164 | } 165 | 166 | return RootKey{ 167 | key: key, 168 | }, nil 169 | } 170 | 171 | // Bytes returns an encoding of the root key. 172 | func (r RootKey) Bytes() []byte { 173 | return r.key 174 | } 175 | 176 | // CreateChain performs a single Diffie-Hellman ratchet step to 177 | // create a new root key and chain key. 178 | func (r RootKey) CreateChain(ourRatchetKey curve.PrivateKey, theirRatchetKey curve.PublicKey) (RootKey, ChainKey, error) { 179 | sharedSecret, err := ourRatchetKey.Agreement(theirRatchetKey) 180 | if err != nil { 181 | return RootKey{}, ChainKey{}, err 182 | } 183 | 184 | derivedSecret := make([]byte, 64) 185 | kdf := hkdf.New(sha256.New, sharedSecret, r.key, []byte(rootInfo)) 186 | _, err = io.ReadFull(kdf, derivedSecret) 187 | if err != nil { 188 | return RootKey{}, ChainKey{}, err 189 | } 190 | 191 | rootKey, err := NewRootKey(derivedSecret[:32]) 192 | if err != nil { 193 | return RootKey{}, ChainKey{}, err 194 | } 195 | 196 | chainKey, err := NewChainKey(derivedSecret[32:], 0) 197 | if err != nil { 198 | return RootKey{}, ChainKey{}, err 199 | } 200 | 201 | return rootKey, chainKey, nil 202 | } 203 | -------------------------------------------------------------------------------- /protocol/ratchet/keys_test.go: -------------------------------------------------------------------------------- 1 | package ratchet 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestChainKey(t *testing.T) { 10 | seed := []byte{ 11 | 0x8a, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78, 12 | 0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98, 13 | 0x2e, 0x7a, 0xd4, 0x8f, 14 | } 15 | expectedMessageKey := []byte{ 16 | 0xbf, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 
0x91, 17 | 0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f, 18 | 0xeb, 0xf4, 0x52, 0x17, 19 | } 20 | expectedMACKey := []byte{ 21 | 0xc6, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60, 22 | 0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72, 23 | 0xe7, 0xa7, 0x03, 0x0b, 24 | } 25 | expectedNextChainKey := []byte{ 26 | 0x28, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17, 27 | 0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42, 28 | 0xa2, 0x46, 0xd1, 0x5d, 29 | } 30 | 31 | chainKey, err := NewChainKey(seed, 0) 32 | assert.NoError(t, err) 33 | assert.Equal(t, seed, chainKey.key) 34 | 35 | messageKeys, err := chainKey.MessageKeys() 36 | assert.NoError(t, err) 37 | assert.Equal(t, expectedMessageKey, messageKeys.cipherKey) 38 | assert.Equal(t, expectedMACKey, messageKeys.macKey) 39 | assert.Equal(t, expectedNextChainKey, chainKey.Next().key) 40 | assert.Equal(t, uint32(0), chainKey.index) 41 | assert.Equal(t, uint32(0), messageKeys.counter) 42 | assert.Equal(t, uint32(1), chainKey.Next().index) 43 | 44 | messageKeys, err = chainKey.Next().MessageKeys() 45 | assert.NoError(t, err) 46 | assert.Equal(t, uint32(1), messageKeys.counter) 47 | } 48 | -------------------------------------------------------------------------------- /protocol/ratchet/params.go: -------------------------------------------------------------------------------- 1 | package ratchet 2 | 3 | import ( 4 | "github.com/RTann/libsignal-go/protocol/curve" 5 | "github.com/RTann/libsignal-go/protocol/identity" 6 | ) 7 | 8 | // AliceParameters represents the "Alice" side 9 | // of the double ratchet algorithm required to 10 | // perform X3DH. 
type AliceParameters struct {
	// OurIdentityKeyPair is Alice's own identity key pair.
	OurIdentityKeyPair identity.KeyPair
	// OurBaseKeyPair is Alice's base key pair.
	// NOTE(review): presumably the X3DH ephemeral key — confirm.
	OurBaseKeyPair *curve.KeyPair

	// TheirIdentityKey is Bob's identity key.
	TheirIdentityKey identity.Key
	// TheirSignedPreKey is Bob's signed pre-key.
	TheirSignedPreKey curve.PublicKey
	// TheirOneTimePreKey is Bob's one-time pre-key.
	// NOTE(review): optional in X3DH; confirm how callers represent its
	// absence (nil?).
	TheirOneTimePreKey curve.PublicKey
	// TheirRatchetKey is Bob's ratchet public key.
	TheirRatchetKey curve.PublicKey
}

// BobParameters represents the "Bob" side
// of the double ratchet algorithm required to
// perform X3DH.
type BobParameters struct {
	// OurIdentityKeyPair is Bob's own identity key pair.
	OurIdentityKeyPair identity.KeyPair
	// OurSignedPreKeyPair is Bob's signed pre-key pair.
	OurSignedPreKeyPair *curve.KeyPair
	// OurOneTimePreKeyPair is Bob's one-time pre-key pair.
	// NOTE(review): presumably nil when no one-time pre-key was used — confirm.
	OurOneTimePreKeyPair *curve.KeyPair
	// OurRatchetKeyPair is Bob's ratchet key pair.
	OurRatchetKeyPair *curve.KeyPair

	// TheirIdentityKey is Alice's identity key.
	TheirIdentityKey identity.Key
	// TheirBaseKey is Alice's base public key.
	TheirBaseKey curve.PublicKey
}

--------------------------------------------------------------------------------
/protocol/ratchet/ratchet.go:
--------------------------------------------------------------------------------
// Package ratchet defines the keys and parameters required to perform
// the Double Ratchet algorithm to send and receive encrypted messages.
package ratchet

import (
	"crypto/sha256"
	"io"

	"golang.org/x/crypto/hkdf"
)

// initRootInfo is the HKDF info string used by DeriveKeys.
const initRootInfo = "WhisperText"

// DeriveKeys derives a root key and chain key based on the secret input
// for the root KDF chain.
//
// secretInput is expanded with HKDF-SHA256 (no salt, info initRootInfo) into
// 64 bytes; the first 32 bytes become the root key and the last 32 bytes seed
// a chain key at index 0.
16 | func DeriveKeys(secretInput []byte) (RootKey, ChainKey, error) { 17 | secrets := make([]byte, 64) 18 | kdf := hkdf.New(sha256.New, secretInput, nil, []byte(initRootInfo)) 19 | _, err := io.ReadFull(kdf, secrets) 20 | if err != nil { 21 | return RootKey{}, ChainKey{}, err 22 | } 23 | 24 | rootKey, err := NewRootKey(secrets[:32]) 25 | if err != nil { 26 | return RootKey{}, ChainKey{}, err 27 | } 28 | 29 | chainKey, err := NewChainKey(secrets[32:], 0) 30 | if err != nil { 31 | return RootKey{}, ChainKey{}, err 32 | } 33 | 34 | return rootKey, chainKey, nil 35 | } 36 | -------------------------------------------------------------------------------- /protocol/senderkey/keys.go: -------------------------------------------------------------------------------- 1 | // Package senderkey defines the keys required to 2 | // send and receive encrypted messages in a group. 3 | package senderkey 4 | 5 | import ( 6 | "crypto/hmac" 7 | "crypto/sha256" 8 | "io" 9 | 10 | "golang.org/x/crypto/hkdf" 11 | ) 12 | 13 | const ( 14 | ChainKeySize = 32 15 | 16 | info = "WhisperGroup" 17 | ) 18 | 19 | var ( 20 | messageKeySeed = []byte{0x01} 21 | chainKeySeed = []byte{0x02} 22 | ) 23 | 24 | type MessageKey struct { 25 | cipherKey []byte 26 | iv []byte 27 | seed []byte 28 | iteration uint32 29 | } 30 | 31 | func DeriveMessageKey(seed []byte, iteration uint32) (MessageKey, error) { 32 | // 16 + 32 = 48 33 | derived := make([]byte, 48) 34 | kdf := hkdf.New(sha256.New, seed, nil, []byte(info)) 35 | _, err := io.ReadFull(kdf, derived) 36 | if err != nil { 37 | return MessageKey{}, err 38 | } 39 | 40 | return MessageKey{ 41 | cipherKey: derived[16:], 42 | iv: derived[:16], 43 | seed: seed, 44 | iteration: iteration, 45 | }, nil 46 | } 47 | 48 | func (m MessageKey) CipherKey() []byte { 49 | return m.cipherKey 50 | } 51 | 52 | func (m MessageKey) IV() []byte { 53 | return m.iv 54 | } 55 | 56 | func (m MessageKey) Seed() []byte { 57 | return m.seed 58 | } 59 | 60 | func (m MessageKey) Iteration() 
uint32 { 61 | return m.iteration 62 | } 63 | 64 | type ChainKey struct { 65 | iteration uint32 66 | chainKey []byte 67 | } 68 | 69 | func NewChainKey(chainKey []byte, iteration uint32) ChainKey { 70 | return ChainKey{ 71 | chainKey: chainKey, 72 | iteration: iteration, 73 | } 74 | } 75 | 76 | func (c ChainKey) Iteration() uint32 { 77 | return c.iteration 78 | } 79 | 80 | func (c ChainKey) Seed() []byte { 81 | return c.chainKey 82 | } 83 | 84 | func (c ChainKey) Next() ChainKey { 85 | return ChainKey{ 86 | iteration: c.iteration + 1, 87 | chainKey: hash(c.chainKey, chainKeySeed), 88 | } 89 | } 90 | 91 | func (c ChainKey) MessageKey() (MessageKey, error) { 92 | return DeriveMessageKey(hash(c.chainKey, messageKeySeed), c.iteration) 93 | } 94 | 95 | func hash(key, seed []byte) []byte { 96 | buf := make([]byte, 0, ChainKeySize) 97 | hash := hmac.New(sha256.New, key) 98 | hash.Write(seed) 99 | buf = hash.Sum(buf) 100 | return buf 101 | } 102 | -------------------------------------------------------------------------------- /protocol/session/cipher.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "errors" 7 | "fmt" 8 | "io" 9 | 10 | "github.com/golang/glog" 11 | 12 | "github.com/RTann/libsignal-go/protocol/crypto/aes" 13 | "github.com/RTann/libsignal-go/protocol/curve" 14 | "github.com/RTann/libsignal-go/protocol/direction" 15 | "github.com/RTann/libsignal-go/protocol/message" 16 | "github.com/RTann/libsignal-go/protocol/perrors" 17 | "github.com/RTann/libsignal-go/protocol/ratchet" 18 | ) 19 | 20 | const MaxJumps = 25_000 21 | 22 | // EncryptMessage encrypts the plaintext message. 
23 | func (s *Session) EncryptMessage(ctx context.Context, plaintext []byte) (message.Ciphertext, error) { 24 | record, exists, err := s.SessionStore.Load(ctx, s.RemoteAddress) 25 | if err != nil { 26 | return nil, err 27 | } 28 | if !exists { 29 | return nil, perrors.ErrSessionNotFound(s.RemoteAddress) 30 | } 31 | 32 | state := record.State() 33 | if state == nil { 34 | return nil, perrors.ErrSessionNotFound(s.RemoteAddress) 35 | } 36 | 37 | chainKey, err := state.SenderChainKey() 38 | if err != nil { 39 | return nil, err 40 | } 41 | messageKeys, err := chainKey.MessageKeys() 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | senderEphemeral, err := state.SenderRatchetKey() 47 | if err != nil { 48 | return nil, err 49 | } 50 | previousCounter := state.PreviousCounter() 51 | version := uint8(state.Version()) 52 | 53 | localIdentityKey, err := state.LocalIdentityKey() 54 | if err != nil { 55 | return nil, err 56 | } 57 | theirIdentityKey, exists, err := state.RemoteIdentityKey() 58 | if err != nil { 59 | return nil, err 60 | } 61 | if !exists { 62 | return nil, fmt.Errorf("no remote identity key for %s", s.RemoteAddress) 63 | } 64 | 65 | ciphertext, err := aes.CBCEncrypt(messageKeys.CipherKey(), messageKeys.IV(), plaintext) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | msg, err := message.NewSignal(message.SignalConfig{ 71 | Version: version, 72 | MACKey: messageKeys.MACKey(), 73 | SenderRatchetKey: senderEphemeral, 74 | PreviousCounter: previousCounter, 75 | Counter: chainKey.Index(), 76 | Ciphertext: ciphertext, 77 | SenderIdentityKey: localIdentityKey, 78 | ReceiverIdentityKey: theirIdentityKey, 79 | }) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | items, err := state.UnacknowledgedPreKeyMessages() 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | // If there are unacknowledged pre-key messages, return a pre-key message instead. 
90 | if items != nil { 91 | msg, err = message.NewPreKey(message.PreKeyConfig{ 92 | Version: version, 93 | RegistrationID: state.LocalRegistrationID(), 94 | PreKeyID: items.PreKeyID(), 95 | SignedPreKeyID: items.SignedPreKeyID(), 96 | BaseKey: items.BaseKey(), 97 | IdentityKey: localIdentityKey, 98 | Message: msg.(*message.Signal), 99 | }) 100 | if err != nil { 101 | return nil, err 102 | } 103 | } 104 | 105 | state.SetSenderChainKey(chainKey.Next()) 106 | 107 | trusted, err := s.IdentityKeyStore.IsTrustedIdentity(ctx, s.RemoteAddress, theirIdentityKey, direction.Sending) 108 | if err != nil { 109 | return nil, err 110 | } 111 | if !trusted { 112 | glog.Warningf("Identity key %s is not trusted for remote address %s", hex.EncodeToString(theirIdentityKey.PublicKey().KeyBytes()), s.RemoteAddress) 113 | return nil, perrors.ErrUntrustedIdentity(s.RemoteAddress) 114 | } 115 | 116 | if _, err := s.IdentityKeyStore.Store(ctx, s.RemoteAddress, theirIdentityKey); err != nil { 117 | return nil, err 118 | } 119 | if err := s.SessionStore.Store(ctx, s.RemoteAddress, record); err != nil { 120 | return nil, err 121 | } 122 | 123 | return msg, nil 124 | } 125 | 126 | // DecryptMessage decrypts the ciphertext message. 127 | func (s *Session) DecryptMessage(ctx context.Context, random io.Reader, ciphertext message.Ciphertext) ([]byte, error) { 128 | switch msg := ciphertext.(type) { 129 | case *message.PreKey: 130 | return s.decryptPreKey(ctx, random, msg) 131 | case *message.Signal: 132 | return s.decryptSignal(ctx, random, msg) 133 | default: 134 | return nil, fmt.Errorf("DecryptMessage cannot be used to decrypt %v messages", msg.Type()) 135 | } 136 | } 137 | 138 | func (s *Session) decryptPreKey(ctx context.Context, random io.Reader, ciphertext *message.PreKey) ([]byte, error) { 139 | record, exists, err := s.SessionStore.Load(ctx, s.RemoteAddress) 140 | if err != nil { 141 | return nil, err 142 | } 143 | if !exists { 144 | // New "fresh" record. 
145 | record = NewRecord(nil) 146 | } 147 | 148 | preKeyID, err := s.ProcessPreKey(ctx, record, ciphertext) 149 | if err != nil { 150 | return nil, err 151 | } 152 | 153 | plaintext, err := s.decryptMessage(random, record, ciphertext.Type(), ciphertext.Message()) 154 | if err != nil { 155 | return nil, err 156 | } 157 | 158 | err = s.SessionStore.Store(ctx, s.RemoteAddress, record) 159 | if err != nil { 160 | return nil, err 161 | } 162 | 163 | if preKeyID != nil { 164 | err := s.PreKeyStore.Delete(ctx, *preKeyID) 165 | if err != nil { 166 | return nil, err 167 | } 168 | } 169 | 170 | return plaintext, nil 171 | } 172 | 173 | func (s *Session) decryptSignal(ctx context.Context, random io.Reader, ciphertext *message.Signal) ([]byte, error) { 174 | record, exists, err := s.SessionStore.Load(ctx, s.RemoteAddress) 175 | if err != nil { 176 | return nil, err 177 | } 178 | if !exists { 179 | return nil, perrors.ErrSessionNotFound(s.RemoteAddress) 180 | } 181 | 182 | plaintext, err := s.decryptMessage(random, record, ciphertext.Type(), ciphertext) 183 | 184 | if record.State() == nil { 185 | return nil, errors.New("successfully decrypted; must have a current state") 186 | } 187 | theirIdentityKey, exists, err := record.State().RemoteIdentityKey() 188 | if err != nil || !exists { 189 | return nil, errors.New("successfully decrypted; must have a remote identity key") 190 | } 191 | 192 | trusted, err := s.IdentityKeyStore.IsTrustedIdentity(ctx, s.RemoteAddress, theirIdentityKey, direction.Receiving) 193 | if err != nil { 194 | return nil, err 195 | } 196 | if !trusted { 197 | glog.Warningf("Identity key %s is not trusted for remote address %v", hex.EncodeToString(theirIdentityKey.Bytes()), s.RemoteAddress) 198 | return nil, perrors.ErrUntrustedIdentity(s.RemoteAddress) 199 | } 200 | 201 | _, err = s.IdentityKeyStore.Store(ctx, s.RemoteAddress, theirIdentityKey) 202 | if err != nil { 203 | return nil, err 204 | } 205 | 206 | err = s.SessionStore.Store(ctx, s.RemoteAddress, 
record) 207 | if err != nil { 208 | return nil, err 209 | } 210 | 211 | return plaintext, nil 212 | } 213 | 214 | type updatedState struct { 215 | idx int 216 | state *State 217 | plaintext []byte 218 | } 219 | 220 | func (s *Session) decryptMessage(random io.Reader, record *Record, typ message.CiphertextType, ciphertext *message.Signal) ([]byte, error) { 221 | if record.State() != nil { 222 | currentState := record.State().Clone() 223 | plaintext, err := s.decryptMessageSession(random, currentState, ciphertext) 224 | switch { 225 | case errors.Is(err, nil): 226 | glog.Infof("decrypted %v message from %v with current session state", typ, s.RemoteAddress) 227 | record.SetSessionState(currentState) 228 | return plaintext, nil 229 | case errors.Is(err, perrors.ErrDuplicateMessage): 230 | return nil, err 231 | default: 232 | } 233 | } 234 | 235 | previousStates, err := record.PreviousStates() 236 | if err != nil { 237 | return nil, err 238 | } 239 | 240 | var state *updatedState 241 | 242 | for i, previous := range previousStates { 243 | plaintext, err := s.decryptMessageSession(random, previous, ciphertext) 244 | switch { 245 | case errors.Is(err, nil): 246 | glog.Infof("decrypted %v message from %v with PREVIOUS session state", typ, s.RemoteAddress) 247 | state = &updatedState{ 248 | idx: i, 249 | state: previous, 250 | plaintext: plaintext, 251 | } 252 | break 253 | case errors.Is(err, perrors.ErrDuplicateMessage): 254 | return nil, err 255 | default: 256 | } 257 | } 258 | 259 | if state != nil { 260 | record.PromoteOldState(state.idx, state.state) 261 | return state.plaintext, nil 262 | } 263 | 264 | if record.State() != nil { 265 | glog.Errorf("no valid session for recipient %v (previous states: %d)", s.RemoteAddress, len(previousStates)) 266 | } else { 267 | glog.Errorf("no valid session for recipient %v (no current session state, previous states: %d)", s.RemoteAddress, len(previousStates)) 268 | } 269 | 270 | return nil, errors.New("decryption failed: invalid 
message") 271 | } 272 | 273 | func (s *Session) decryptMessageSession(random io.Reader, state *State, ciphertext *message.Signal) ([]byte, error) { 274 | if state.session.GetSenderChain() == nil { 275 | return nil, errors.New("no session available to decrypt") 276 | } 277 | 278 | ciphertextVersion := ciphertext.Version() 279 | if uint32(ciphertextVersion) != state.Version() { 280 | return nil, fmt.Errorf("unrecognized message version: %d", ciphertextVersion) 281 | } 282 | 283 | theirEphemeral := ciphertext.SenderRatchetKey() 284 | counter := ciphertext.Counter() 285 | chainKey, err := s.chainKey(random, state, theirEphemeral) 286 | if err != nil { 287 | return nil, err 288 | } 289 | messageKeys, err := s.messageKeys(state, theirEphemeral, chainKey, counter) 290 | if err != nil { 291 | return nil, err 292 | } 293 | 294 | theirIdentityKey, found, err := state.RemoteIdentityKey() 295 | if err != nil { 296 | return nil, err 297 | } 298 | if !found { 299 | return nil, errors.New("cannot decrypt without remote identity key") 300 | } 301 | 302 | localIdentityKey, err := state.LocalIdentityKey() 303 | if err != nil { 304 | return nil, err 305 | } 306 | 307 | valid, err := ciphertext.VerifyMAC(messageKeys.MACKey(), theirIdentityKey, localIdentityKey) 308 | if err != nil { 309 | return nil, err 310 | } 311 | if !valid { 312 | return nil, errors.New("MAC verification failed") 313 | } 314 | 315 | plaintext, err := aes.CBCDecrypt(messageKeys.CipherKey(), messageKeys.IV(), ciphertext.Message()) 316 | if err != nil { 317 | return nil, err 318 | } 319 | 320 | state.ClearUnacknowledgedPreKeyMessage() 321 | 322 | return plaintext, nil 323 | } 324 | 325 | func (s *Session) chainKey(random io.Reader, state *State, theirEphemeral curve.PublicKey) (ratchet.ChainKey, error) { 326 | chain, exists, err := state.ReceiverChainKey(theirEphemeral) 327 | if err != nil { 328 | return ratchet.ChainKey{}, err 329 | } 330 | if exists { 331 | return chain, nil 332 | } 333 | 334 | glog.Infof("%v 
creating new chains", s.RemoteAddress) 335 | 336 | rootKey, err := state.RootKey() 337 | if err != nil { 338 | return ratchet.ChainKey{}, err 339 | } 340 | ourEphemeral, err := state.SenderRatchetPrivateKey() 341 | if err != nil { 342 | return ratchet.ChainKey{}, err 343 | } 344 | 345 | receiverRootKey, receiverChainKey, err := rootKey.CreateChain(ourEphemeral, theirEphemeral) 346 | if err != nil { 347 | return ratchet.ChainKey{}, err 348 | } 349 | 350 | ourNewEphemeral, err := curve.GenerateKeyPair(random) 351 | if err != nil { 352 | return ratchet.ChainKey{}, err 353 | } 354 | 355 | senderRootKey, senderChainKey, err := receiverRootKey.CreateChain(ourNewEphemeral.PrivateKey(), theirEphemeral) 356 | if err != nil { 357 | return ratchet.ChainKey{}, err 358 | } 359 | 360 | currentSenderChainKey, err := state.SenderChainKey() 361 | if err != nil { 362 | return ratchet.ChainKey{}, err 363 | } 364 | 365 | state.SetRootKey(senderRootKey) 366 | state.AddReceiverChain(theirEphemeral, receiverChainKey) 367 | 368 | previousIdx := uint32(0) 369 | if currentIdx := currentSenderChainKey.Index(); currentIdx > 0 { 370 | previousIdx = currentIdx - 1 371 | } 372 | state.SetPreviousCounter(previousIdx) 373 | state.SetSenderChain(ourNewEphemeral, senderChainKey) 374 | 375 | return receiverChainKey, nil 376 | } 377 | 378 | func (s *Session) messageKeys(state *State, theirEphemeral curve.PublicKey, chainKey ratchet.ChainKey, counter uint32) (ratchet.MessageKeys, error) { 379 | chainIdx := chainKey.Index() 380 | if chainIdx > counter { 381 | if keys, found, err := state.MessageKeys(theirEphemeral, counter); err != nil { 382 | return ratchet.MessageKeys{}, err 383 | } else if found { 384 | return keys, nil 385 | } 386 | 387 | glog.Warningf("%v Duplicate message for counter: %d", s.RemoteAddress, counter) 388 | return ratchet.MessageKeys{}, perrors.ErrDuplicateMessage 389 | } 390 | 391 | jump := counter - chainIdx 392 | if jump > MaxJumps { 393 | sessionWithSelf, err := 
state.SessionWithSelf() 394 | if err != nil { 395 | return ratchet.MessageKeys{}, err 396 | } 397 | if sessionWithSelf { 398 | glog.Infof("%v Jumping ahead %d messages (index: %d, counter: %d)", s.RemoteAddress, jump, chainIdx, counter) 399 | } else { 400 | glog.Errorf("%v Exceeded future message limit: %d, index: %d, counter: %d", s.RemoteAddress, MaxJumps, chainIdx, counter) 401 | return ratchet.MessageKeys{}, errors.New("message from too far in the future") 402 | } 403 | } 404 | 405 | for chainKey.Index() < counter { 406 | messageKeys, err := chainKey.MessageKeys() 407 | if err != nil { 408 | return ratchet.MessageKeys{}, err 409 | } 410 | if err := state.SetMessageKeys(theirEphemeral, messageKeys); err != nil { 411 | return ratchet.MessageKeys{}, err 412 | } 413 | chainKey = chainKey.Next() 414 | } 415 | 416 | err := state.SetReceiverChainKey(theirEphemeral, chainKey) 417 | if err != nil { 418 | return ratchet.MessageKeys{}, err 419 | } 420 | 421 | return chainKey.MessageKeys() 422 | } 423 | -------------------------------------------------------------------------------- /protocol/session/group_cipher.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | 9 | "github.com/golang/glog" 10 | 11 | "github.com/RTann/libsignal-go/protocol/crypto/aes" 12 | "github.com/RTann/libsignal-go/protocol/distribution" 13 | "github.com/RTann/libsignal-go/protocol/message" 14 | "github.com/RTann/libsignal-go/protocol/perrors" 15 | "github.com/RTann/libsignal-go/protocol/senderkey" 16 | ) 17 | 18 | // EncryptMessage encrypts the plaintext message. 
func (g *GroupSession) EncryptMessage(ctx context.Context, random io.Reader, plaintext []byte) (*message.SenderKey, error) {
	record, exists, err := g.SenderKeyStore.Load(ctx, g.SenderAddress, g.DistID)
	switch {
	case err != nil:
		return nil, err
	case !exists:
		return nil, fmt.Errorf("no sender key state for distribution ID %s", g.DistID.String())
	}

	state, err := record.State()
	if err != nil {
		return nil, err
	}

	// Key for this message comes from the current sender chain position.
	chain := state.SenderChainKey()
	key, err := chain.MessageKey()
	if err != nil {
		return nil, err
	}

	sealed, err := aes.CBCEncrypt(key.CipherKey(), key.IV(), plaintext)
	if err != nil {
		return nil, err
	}

	privateKey, err := state.PrivateSigningKey()
	if err != nil {
		return nil, err
	}

	cfg := message.SenderKeyConfig{
		Version:      uint8(state.Version()),
		DistID:       g.DistID,
		ChainID:      state.ChainID(),
		Iteration:    key.Iteration(),
		Ciphertext:   sealed,
		SignatureKey: privateKey,
	}
	msg, err := message.NewSenderKey(random, cfg)
	if err != nil {
		return nil, err
	}

	// Move the chain forward so the next message gets a fresh key, then persist.
	state.SetSenderChainKey(chain.Next())
	if err := g.SenderKeyStore.Store(ctx, g.SenderAddress, g.DistID, record); err != nil {
		return nil, err
	}

	return msg, nil
}

// DecryptMessage decrypts the ciphertext message: it verifies the message
// signature against the sender key state for the message's chain ID, decrypts
// with that chain's message key and persists the updated sender key record.
72 | func (g *GroupSession) DecryptMessage(ctx context.Context, ciphertext *message.SenderKey) ([]byte, error) { 73 | distributionID := ciphertext.DistributionID() 74 | chainID := ciphertext.ChainID() 75 | 76 | record, exists, err := g.SenderKeyStore.Load(ctx, g.SenderAddress, distributionID) 77 | if err != nil { 78 | return nil, err 79 | } 80 | if !exists { 81 | return nil, fmt.Errorf("no sender key state for distribution ID %s", distributionID.String()) 82 | } 83 | 84 | state := record.StateForChainID(chainID) 85 | if state == nil { 86 | return nil, fmt.Errorf("no sender key state for distribution ID %s", distributionID.String()) 87 | } 88 | 89 | messageVersion := ciphertext.Version() 90 | if uint32(messageVersion) != state.Version() { 91 | return nil, fmt.Errorf("unrecognized message version: %d", messageVersion) 92 | } 93 | 94 | signingKey, err := state.PublicSigningKey() 95 | if err != nil { 96 | return nil, err 97 | } 98 | 99 | valid, err := ciphertext.VerifySignature(signingKey) 100 | if err != nil { 101 | return nil, err 102 | } 103 | if !valid { 104 | return nil, errors.New("signature verification failed") 105 | } 106 | 107 | messageKey, err := g.messageKey(state, ciphertext.Iteration(), ciphertext.DistributionID()) 108 | if err != nil { 109 | return nil, err 110 | } 111 | 112 | plaintext, err := aes.CBCDecrypt(messageKey.CipherKey(), messageKey.IV(), ciphertext.Message()) 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | err = g.SenderKeyStore.Store(ctx, g.SenderAddress, ciphertext.DistributionID(), record) 118 | if err != nil { 119 | return nil, err 120 | } 121 | 122 | return plaintext, nil 123 | } 124 | 125 | func (g *GroupSession) messageKey(state *GroupState, iteration uint32, distributionID distribution.ID) (senderkey.MessageKey, error) { 126 | chainKey := state.SenderChainKey() 127 | currentIteration := chainKey.Iteration() 128 | 129 | if currentIteration > iteration { 130 | keys, ok, err := state.RemoveMessageKeys(iteration) 131 | if 
err != nil { 132 | return senderkey.MessageKey{}, err 133 | } 134 | if !ok { 135 | glog.Warningf("SenderKey distribution %s Duplicate message for iteration: %d", distributionID, iteration) 136 | return senderkey.MessageKey{}, perrors.ErrDuplicateMessage 137 | } 138 | 139 | return keys, nil 140 | } 141 | 142 | jump := iteration - currentIteration 143 | if jump > MaxJumps { 144 | glog.Errorf("Sender distribution %s Exceeded future message limit: %d, iteration: %d", distributionID, MaxJumps, iteration) 145 | return senderkey.MessageKey{}, errors.New("message from too far in the future") 146 | } 147 | 148 | for chainKey.Iteration() < iteration { 149 | keys, err := chainKey.MessageKey() 150 | if err != nil { 151 | return senderkey.MessageKey{}, err 152 | } 153 | state.AddMessageKey(keys) 154 | chainKey = chainKey.Next() 155 | } 156 | 157 | state.SetSenderChainKey(chainKey.Next()) 158 | 159 | messageKeys, err := chainKey.MessageKey() 160 | if err != nil { 161 | return senderkey.MessageKey{}, err 162 | } 163 | 164 | return messageKeys, nil 165 | } 166 | -------------------------------------------------------------------------------- /protocol/session/group_record.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/golang/glog" 7 | 8 | "github.com/RTann/libsignal-go/protocol/curve" 9 | ) 10 | 11 | const maxSenderKeyStates = 5 12 | 13 | // GroupRecord holds a record of a group session's current and past states. 
14 | type GroupRecord struct { 15 | states []*GroupState 16 | } 17 | 18 | func NewGroupRecord() *GroupRecord { 19 | return &GroupRecord{ 20 | states: make([]*GroupState, 0, maxSenderKeyStates), 21 | } 22 | } 23 | 24 | func (g *GroupRecord) State() (*GroupState, error) { 25 | if len(g.states) == 0 { 26 | return nil, errors.New("empty sender key state") 27 | } 28 | 29 | return g.states[0], nil 30 | } 31 | 32 | func (g *GroupRecord) StateForChainID(chainID uint32) *GroupState { 33 | for _, state := range g.states { 34 | if chainID == state.ChainID() { 35 | return state 36 | } 37 | } 38 | 39 | return nil 40 | } 41 | 42 | func (g *GroupRecord) AddState(state *GroupState) error { 43 | signingKey, err := state.PublicSigningKey() 44 | if err != nil { 45 | return err 46 | } 47 | 48 | existing, removed := g.RemoveState(state.ChainID(), signingKey) 49 | 50 | if g.RemoveStates(state.ChainID()) > 0 { 51 | glog.Warningf("Removed a matching chain_id (%d) found with a different public key", state.ChainID()) 52 | } 53 | 54 | if !removed { 55 | existing = state 56 | } 57 | 58 | if len(g.states) >= maxSenderKeyStates { 59 | g.states[0] = nil 60 | g.states = g.states[1:] 61 | } 62 | 63 | g.states = append(g.states, existing) 64 | 65 | return nil 66 | } 67 | 68 | func (g *GroupRecord) RemoveState(chainID uint32, signatureKey curve.PublicKey) (*GroupState, bool) { 69 | idx := -1 70 | for i, state := range g.states { 71 | publicKey, err := state.PublicSigningKey() 72 | if err != nil { 73 | continue 74 | } 75 | if state.ChainID() == chainID && signatureKey.Equal(publicKey) { 76 | idx = i 77 | break 78 | } 79 | } 80 | 81 | if idx < 0 { 82 | return nil, false 83 | } 84 | 85 | state := g.states[idx] 86 | g.states = append(g.states[:idx], g.states[idx+1:]...) 
87 | 88 | return state, true 89 | } 90 | 91 | func (g *GroupRecord) RemoveStates(chainID uint32) int { 92 | length := len(g.states) 93 | filtered := g.states[:0] 94 | for _, state := range g.states { 95 | if state.ChainID() == chainID { 96 | continue 97 | } 98 | 99 | filtered = append(filtered, state) 100 | } 101 | // TODO: set remaining to nil? 102 | g.states = filtered 103 | 104 | return length - len(filtered) 105 | } 106 | -------------------------------------------------------------------------------- /protocol/session/group_record_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/suite" 8 | 9 | "github.com/RTann/libsignal-go/protocol/curve" 10 | "github.com/RTann/libsignal-go/protocol/senderkey" 11 | ) 12 | 13 | var _ suite.TestingSuite = (*groupRecordTestSuite)(nil) 14 | 15 | type groupRecordTestSuite struct { 16 | suite.Suite 17 | 18 | record *GroupRecord 19 | } 20 | 21 | func (s *groupRecordTestSuite) SetupTest() { 22 | s.record = NewGroupRecord() 23 | } 24 | 25 | func TestGroupRecordTestSuite(t *testing.T) { 26 | suite.Run(t, new(groupRecordTestSuite)) 27 | } 28 | 29 | func (s *groupRecordTestSuite) TestAddSingleState() { 30 | r := recordKey{ 31 | chainID: 1, 32 | chainKey: testChainKey(1), 33 | publicKey: s.testPublicKey(), 34 | } 35 | 36 | s.addStates(r) 37 | s.Len(s.record.states, 1) 38 | s.assertChainKey(r) 39 | } 40 | 41 | func (s *groupRecordTestSuite) TestAddSecondState() { 42 | r1 := recordKey{ 43 | chainID: 1, 44 | chainKey: testChainKey(1), 45 | publicKey: s.testPublicKey(), 46 | } 47 | r2 := recordKey{ 48 | chainID: 2, 49 | chainKey: testChainKey(2), 50 | publicKey: s.testPublicKey(), 51 | } 52 | 53 | s.addStates(r1, r2) 54 | 55 | s.Len(s.record.states, 2) 56 | s.assertChainKey(r1) 57 | s.assertChainKey(r2) 58 | } 59 | 60 | func (s *groupRecordTestSuite) TestExceedMax() { 61 | s.Equal(5, 
maxSenderKeyStates) 62 | 63 | rs := []recordKey{ 64 | { 65 | chainID: 1, 66 | chainKey: testChainKey(1), 67 | publicKey: s.testPublicKey(), 68 | }, 69 | { 70 | chainID: 2, 71 | chainKey: testChainKey(2), 72 | publicKey: s.testPublicKey(), 73 | }, 74 | { 75 | chainID: 3, 76 | chainKey: testChainKey(3), 77 | publicKey: s.testPublicKey(), 78 | }, 79 | { 80 | chainID: 4, 81 | chainKey: testChainKey(4), 82 | publicKey: s.testPublicKey(), 83 | }, 84 | { 85 | chainID: 5, 86 | chainKey: testChainKey(5), 87 | publicKey: s.testPublicKey(), 88 | }, 89 | { 90 | chainID: 6, 91 | chainKey: testChainKey(6), 92 | publicKey: s.testPublicKey(), 93 | }, 94 | } 95 | 96 | s.addStates(rs[:5]...) 97 | s.assertOrder(rs[:5]...) 98 | 99 | s.addStates(rs[5]) 100 | s.assertOrder(rs[1:]...) 101 | } 102 | 103 | func (s *groupRecordTestSuite) TestSameChainIDAndPublicKey() { 104 | r1 := recordKey{ 105 | chainID: 1, 106 | chainKey: testChainKey(1), 107 | publicKey: s.testPublicKey(), 108 | } 109 | r2 := recordKey{ 110 | chainID: r1.chainID, 111 | chainKey: testChainKey(2), 112 | publicKey: r1.publicKey, 113 | } 114 | 115 | s.addStates(r1, r2) 116 | 117 | s.Len(s.record.states, 1) 118 | s.assertChainKey(r1) 119 | } 120 | 121 | func (s *groupRecordTestSuite) TestSameChainIDDifferentPublicKey() { 122 | r1 := recordKey{ 123 | chainID: 1, 124 | chainKey: testChainKey(1), 125 | publicKey: s.testPublicKey(), 126 | } 127 | r2 := recordKey{ 128 | chainID: r1.chainID, 129 | chainKey: testChainKey(2), 130 | publicKey: s.testPublicKey(), 131 | } 132 | 133 | s.addStates(r1, r2) 134 | 135 | s.Len(s.record.states, 1) 136 | s.assertChainKey(r2) 137 | } 138 | 139 | func (s *groupRecordTestSuite) TestUpdateState() { 140 | r1 := recordKey{ 141 | chainID: 1, 142 | chainKey: testChainKey(1), 143 | publicKey: s.testPublicKey(), 144 | } 145 | r2 := recordKey{ 146 | chainID: 2, 147 | chainKey: testChainKey(2), 148 | publicKey: s.testPublicKey(), 149 | } 150 | 151 | s.addStates(r1, r2) 152 | s.assertOrder(r1, r2) 153 | 
154 | r1.chainKey = testChainKey(3) 155 | s.addStates(r1) 156 | s.assertOrder(r2, r1) 157 | } 158 | 159 | func (s *groupRecordTestSuite) testPublicKey() curve.PublicKey { 160 | pair, err := curve.GenerateKeyPair(rand.Reader) 161 | s.Require().NoError(err) 162 | 163 | return pair.PublicKey() 164 | } 165 | 166 | func testChainKey(i uint8) []byte { 167 | chainKey := make([]byte, senderkey.ChainKeySize) 168 | chainKey[0] = i 169 | return chainKey 170 | } 171 | 172 | func (s *groupRecordTestSuite) addStates(rs ...recordKey) { 173 | var err error 174 | for _, r := range rs { 175 | err = s.record.AddState(NewGroupState(GroupStateConfig{ 176 | MessageVersion: 1, 177 | ChainID: r.chainID, 178 | Iteration: 1, 179 | ChainKey: r.chainKey, 180 | SignatureKey: r.publicKey, 181 | })) 182 | s.Require().NoError(err) 183 | } 184 | } 185 | 186 | func (s *groupRecordTestSuite) assertChainKey(r recordKey) { 187 | state := s.record.StateForChainID(r.chainID) 188 | s.Require().NotNil(state) 189 | 190 | foundChainKey := state.SenderChainKey() 191 | s.Equal(r.chainKey, foundChainKey.Seed()) 192 | 193 | var matchingState *GroupState 194 | for _, state := range s.record.states { 195 | if state.ChainID() == r.chainID { 196 | publicKey, err := state.PublicSigningKey() 197 | s.Require().NoError(err) 198 | 199 | if publicKey.Equal(r.publicKey) { 200 | s.Nil(matchingState) 201 | matchingState = state 202 | } 203 | } 204 | } 205 | 206 | s.Require().NotNil(matchingState) 207 | s.Equal(r.chainKey, matchingState.SenderChainKey().Seed()) 208 | } 209 | 210 | type recordKey struct { 211 | chainID uint32 212 | chainKey []byte 213 | publicKey curve.PublicKey 214 | } 215 | 216 | func (s *groupRecordTestSuite) assertOrder(rs ...recordKey) { 217 | s.Equal(len(rs), len(s.record.states)) 218 | 219 | for i, state := range s.record.states { 220 | publicKey, err := state.PublicSigningKey() 221 | s.Require().NoError(err) 222 | 223 | s.Equal(rs[i].chainID, state.ChainID()) 224 | 
s.True(publicKey.Equal(rs[i].publicKey)) 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /protocol/session/group_session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "encoding/binary" 7 | "io" 8 | 9 | "github.com/golang/glog" 10 | 11 | "github.com/RTann/libsignal-go/protocol/address" 12 | "github.com/RTann/libsignal-go/protocol/curve" 13 | "github.com/RTann/libsignal-go/protocol/distribution" 14 | "github.com/RTann/libsignal-go/protocol/message" 15 | ) 16 | 17 | // GroupSession represents a unidirectional group sender-key encrypted session. 18 | // It may only be used for sending or for receiving, but not both. 19 | type GroupSession struct { 20 | // SenderAddress is the address of the user sending the message. 21 | // 22 | // It is meant to be populated by both a sender and a receiver. 23 | SenderAddress address.Address 24 | // DistID is the distribution ID of the group. 25 | // 26 | // It is meant to be populated by a sender, only. 27 | DistID distribution.ID 28 | SenderKeyStore GroupStore 29 | } 30 | 31 | // ProcessSenderKeyDistribution processes a group sender-key distribution message 32 | // to establish a group session to receive messages from the sender. 
33 | func (g *GroupSession) ProcessSenderKeyDistribution(ctx context.Context, message *message.SenderKeyDistribution) error { 34 | glog.Infof("%s Processing SenderKey distribution %s with chain ID %s", g.SenderAddress, message.DistributionID(), message.ChainID()) 35 | 36 | record, exists, err := g.SenderKeyStore.Load(ctx, g.SenderAddress, message.DistributionID()) 37 | if err != nil { 38 | return err 39 | } 40 | if !exists { 41 | record = NewGroupRecord() 42 | } 43 | 44 | state := NewGroupState(GroupStateConfig{ 45 | MessageVersion: message.Version(), 46 | ChainID: message.ChainID(), 47 | Iteration: message.Iteration(), 48 | ChainKey: message.ChainKey(), 49 | SignatureKey: message.SigningKey(), 50 | }) 51 | err = record.AddState(state) 52 | if err != nil { 53 | return err 54 | } 55 | 56 | err = g.SenderKeyStore.Store(ctx, g.SenderAddress, message.DistributionID(), record) 57 | return err 58 | } 59 | 60 | // NewSenderKeyDistribution constructs a sender-key distribution message for establishing a group session. 61 | func (g *GroupSession) NewSenderKeyDistribution(ctx context.Context, random io.Reader) (*message.SenderKeyDistribution, error) { 62 | record, exists, err := g.SenderKeyStore.Load(ctx, g.SenderAddress, g.DistID) 63 | if err != nil { 64 | return nil, err 65 | } 66 | if !exists { 67 | chainID, err := randomUint32() 68 | if err != nil { 69 | return nil, err 70 | } 71 | // libsignal-protocol-java uses 31-bit integers. 
72 | chainID >>= 1 73 | glog.Infof("Creating SenderKey for distribution %s with chain ID %d", g.DistID, chainID) 74 | 75 | senderKey := make([]byte, 32) 76 | _, err = io.ReadFull(random, senderKey) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | signingKey, err := curve.GenerateKeyPair(random) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | record = NewGroupRecord() 87 | state := NewGroupState(GroupStateConfig{ 88 | MessageVersion: message.SenderKeyVersion, 89 | ChainID: chainID, 90 | Iteration: 0, 91 | ChainKey: senderKey, 92 | SignatureKey: signingKey.PublicKey(), 93 | SignaturePrivateKey: signingKey.PrivateKey(), 94 | }) 95 | err = record.AddState(state) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | err = g.SenderKeyStore.Store(ctx, g.SenderAddress, g.DistID, record) 101 | if err != nil { 102 | return nil, err 103 | } 104 | } 105 | 106 | state, err := record.State() 107 | if err != nil { 108 | return nil, err 109 | } 110 | 111 | senderChainKey := state.SenderChainKey() 112 | 113 | signingKey, err := state.PublicSigningKey() 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | return message.NewSenderKeyDistribution(message.SenderKeyDistConfig{ 119 | Version: uint8(state.Version()), 120 | DistID: g.DistID, 121 | ChainID: state.ChainID(), 122 | Iteration: senderChainKey.Iteration(), 123 | ChainKey: senderChainKey.Seed(), 124 | SigningKey: signingKey, 125 | }) 126 | } 127 | 128 | func randomUint32() (uint32, error) { 129 | bytes := make([]byte, 4) 130 | _, err := io.ReadFull(rand.Reader, bytes) 131 | if err != nil { 132 | return 0, err 133 | } 134 | 135 | return binary.LittleEndian.Uint32(bytes), nil 136 | } 137 | -------------------------------------------------------------------------------- /protocol/session/group_state.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "github.com/RTann/libsignal-go/protocol/curve" 5 | v1 
"github.com/RTann/libsignal-go/protocol/generated/v1" 6 | "github.com/RTann/libsignal-go/protocol/senderkey" 7 | ) 8 | 9 | // GroupState represents a group session's state. 10 | type GroupState struct { 11 | state *v1.SenderKeyStateStructure 12 | } 13 | 14 | type GroupStateConfig struct { 15 | MessageVersion uint8 16 | ChainID uint32 17 | Iteration uint32 18 | ChainKey []byte 19 | SignatureKey curve.PublicKey 20 | SignaturePrivateKey curve.PrivateKey 21 | } 22 | 23 | func NewGroupState(cfg GroupStateConfig) *GroupState { 24 | var private []byte 25 | if cfg.SignaturePrivateKey != nil { 26 | private = cfg.SignaturePrivateKey.Bytes() 27 | } 28 | 29 | seed := make([]byte, len(cfg.ChainKey)) 30 | copy(seed, cfg.ChainKey) 31 | 32 | return &GroupState{ 33 | state: &v1.SenderKeyStateStructure{ 34 | MessageVersion: uint32(cfg.MessageVersion), 35 | ChainId: cfg.ChainID, 36 | SenderChainKey: &v1.SenderKeyStateStructure_SenderChainKey{ 37 | Iteration: cfg.Iteration, 38 | Seed: seed, 39 | }, 40 | SenderSigningKey: &v1.SenderKeyStateStructure_SenderSigningKey{ 41 | Public: cfg.SignatureKey.Bytes(), 42 | Private: private, 43 | }, 44 | }, 45 | } 46 | } 47 | 48 | func (s *GroupState) Version() uint32 { 49 | switch v := s.state.GetMessageVersion(); v { 50 | case 0: 51 | return 3 52 | default: 53 | return v 54 | } 55 | } 56 | 57 | func (s *GroupState) ChainID() uint32 { 58 | return s.state.GetChainId() 59 | } 60 | 61 | func (s *GroupState) SenderChainKey() senderkey.ChainKey { 62 | chainKey := s.state.GetSenderChainKey() 63 | return senderkey.NewChainKey(chainKey.GetSeed(), chainKey.GetIteration()) 64 | } 65 | 66 | func (s *GroupState) SetSenderChainKey(chainKey senderkey.ChainKey) { 67 | s.state.SenderChainKey = &v1.SenderKeyStateStructure_SenderChainKey{ 68 | Iteration: chainKey.Iteration(), 69 | Seed: chainKey.Seed(), 70 | } 71 | } 72 | 73 | func (s *GroupState) PrivateSigningKey() (curve.PrivateKey, error) { 74 | return 
curve.NewPrivateKey(s.state.GetSenderSigningKey().GetPrivate()) 75 | } 76 | 77 | func (s *GroupState) PublicSigningKey() (curve.PublicKey, error) { 78 | return curve.NewPublicKey(s.state.GetSenderSigningKey().GetPublic()) 79 | } 80 | 81 | func (s *GroupState) AddMessageKey(key senderkey.MessageKey) { 82 | msgKeys := &v1.SenderKeyStateStructure_SenderMessageKey{ 83 | Iteration: key.Iteration(), 84 | Seed: key.Seed(), 85 | } 86 | s.state.SenderMessageKeys = append(s.state.GetSenderMessageKeys(), msgKeys) 87 | if len(s.state.GetSenderMessageKeys()) > maxMessageKeys { 88 | s.state.GetSenderMessageKeys()[0] = nil 89 | s.state.SenderMessageKeys = s.state.GetSenderMessageKeys()[1:] 90 | } 91 | } 92 | 93 | func (s *GroupState) RemoveMessageKeys(iteration uint32) (senderkey.MessageKey, bool, error) { 94 | var messageKey *v1.SenderKeyStateStructure_SenderMessageKey 95 | idx := -1 96 | for i, key := range s.state.GetSenderMessageKeys() { 97 | if key.GetIteration() == iteration { 98 | messageKey = key 99 | idx = i 100 | break 101 | } 102 | } 103 | 104 | if idx < 0 { 105 | return senderkey.MessageKey{}, false, nil 106 | } 107 | 108 | derived, err := senderkey.DeriveMessageKey(messageKey.GetSeed(), messageKey.GetIteration()) 109 | if err != nil { 110 | return senderkey.MessageKey{}, false, err 111 | } 112 | 113 | s.state.GetSenderMessageKeys()[idx] = nil 114 | s.state.SenderMessageKeys = append(s.state.GetSenderMessageKeys()[:idx], s.state.GetSenderMessageKeys()[idx+1:]...) 
115 | 116 | return derived, true, nil 117 | } 118 | -------------------------------------------------------------------------------- /protocol/session/record.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "crypto/subtle" 5 | 6 | "github.com/golang/glog" 7 | "google.golang.org/protobuf/proto" 8 | 9 | "github.com/RTann/libsignal-go/protocol/curve" 10 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 11 | "github.com/RTann/libsignal-go/protocol/identity" 12 | "github.com/RTann/libsignal-go/protocol/perrors" 13 | "github.com/RTann/libsignal-go/protocol/ratchet" 14 | ) 15 | 16 | // Record holds a record of a session's current and past states. 17 | type Record struct { 18 | currentSession *State 19 | previousSessions [][]byte 20 | } 21 | 22 | // NewRecord creates a new Record with current session set to the given state. 23 | // Set state to `nil` for a "fresh" record. 24 | func NewRecord(state *State) *Record { 25 | return &Record{ 26 | currentSession: state, 27 | previousSessions: make([][]byte, 0, maxArchivedStates), 28 | } 29 | } 30 | 31 | func NewRecordBytes(bytes []byte) (*Record, error) { 32 | var session v1.SessionStructure 33 | err := proto.Unmarshal(bytes, &session) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &Record{ 39 | currentSession: NewState(&session), 40 | }, nil 41 | } 42 | 43 | func (r *Record) State() *State { 44 | return r.currentSession 45 | } 46 | 47 | func (r *Record) Version() (uint32, error) { 48 | state := r.State() 49 | if state == nil { 50 | return 0, perrors.ErrNoCurrentSession 51 | } 52 | 53 | return state.Version(), nil 54 | } 55 | 56 | func (r *Record) PreviousStates() ([]*State, error) { 57 | states := make([]*State, 0, len(r.previousSessions)) 58 | for _, state := range r.previousSessions { 59 | session := new(v1.SessionStructure) 60 | err := proto.Unmarshal(state, session) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 
states = append(states, NewState(session)) 65 | } 66 | return states, nil 67 | } 68 | 69 | func (r *Record) SetSessionState(session *State) { 70 | r.currentSession = session 71 | } 72 | 73 | func (r *Record) HasSessionState(version uint32, aliceBaseKey []byte) (bool, error) { 74 | if r.currentSession != nil && 75 | version == r.currentSession.Version() && 76 | subtle.ConstantTimeCompare(aliceBaseKey, r.currentSession.AliceBaseKey()) == 1 { 77 | return true, nil 78 | } 79 | 80 | previousStates, err := r.PreviousStates() 81 | if err != nil { 82 | return false, err 83 | } 84 | for _, previous := range previousStates { 85 | if version == previous.Version() && subtle.ConstantTimeCompare(aliceBaseKey, previous.AliceBaseKey()) == 1 { 86 | return true, nil 87 | } 88 | } 89 | 90 | return false, nil 91 | } 92 | 93 | func (r *Record) PromoteOldState(idx int, state *State) { 94 | if idx < 0 || idx >= len(r.previousSessions) { 95 | return 96 | } 97 | r.previousSessions[idx] = nil 98 | r.previousSessions = append(r.previousSessions[:idx], r.previousSessions[idx+1:]...) 
99 | r.PromoteState(state) 100 | } 101 | 102 | func (r *Record) PromoteState(state *State) { 103 | r.ArchiveCurrentState() 104 | r.currentSession = state 105 | } 106 | 107 | func (r *Record) ArchiveCurrentState() { 108 | if r.currentSession == nil { 109 | glog.Infoln("skipping archive; current session state is fresh") 110 | return 111 | } 112 | 113 | if len(r.previousSessions) >= maxArchivedStates { 114 | r.previousSessions = r.previousSessions[1:] 115 | } 116 | 117 | r.previousSessions = append(r.previousSessions, r.currentSession.Bytes()) 118 | } 119 | 120 | func (r *Record) LocalIdentityKey() (identity.Key, error) { 121 | state := r.State() 122 | if state == nil { 123 | return identity.Key{}, perrors.ErrNoCurrentSession 124 | } 125 | 126 | return state.LocalIdentityKey() 127 | } 128 | 129 | func (r *Record) RemoteIdentityKey() (identity.Key, bool, error) { 130 | state := r.State() 131 | if state == nil { 132 | return identity.Key{}, false, perrors.ErrNoCurrentSession 133 | } 134 | 135 | return state.RemoteIdentityKey() 136 | } 137 | 138 | func (r *Record) ReceiverChainKey(sender curve.PublicKey) (ratchet.ChainKey, bool, error) { 139 | state := r.State() 140 | if state == nil { 141 | return ratchet.ChainKey{}, false, perrors.ErrNoCurrentSession 142 | } 143 | 144 | return state.ReceiverChainKey(sender) 145 | } 146 | 147 | func (r *Record) SenderChainKey() (ratchet.ChainKey, error) { 148 | state := r.State() 149 | if state == nil { 150 | return ratchet.ChainKey{}, perrors.ErrNoCurrentSession 151 | } 152 | 153 | return state.SenderChainKey() 154 | } 155 | -------------------------------------------------------------------------------- /protocol/session/session.go: -------------------------------------------------------------------------------- 1 | // Package session implements the functionality necessary to establish 2 | // encrypted peer and group sessions. 
package session

import (
	"bytes"
	"context"
	"errors"
	"io"

	"github.com/golang/glog"

	"github.com/RTann/libsignal-go/protocol/address"
	"github.com/RTann/libsignal-go/protocol/curve"
	"github.com/RTann/libsignal-go/protocol/direction"
	v1 "github.com/RTann/libsignal-go/protocol/generated/v1"
	"github.com/RTann/libsignal-go/protocol/identity"
	"github.com/RTann/libsignal-go/protocol/message"
	"github.com/RTann/libsignal-go/protocol/perrors"
	"github.com/RTann/libsignal-go/protocol/prekey"
	"github.com/RTann/libsignal-go/protocol/ratchet"
)

// Session represents a protocol session with another user.
//
// The stores are consulted by ProcessPreKey and ProcessPreKeyBundle to load and
// persist session records, pre-keys and identities for RemoteAddress.
type Session struct {
	// RemoteAddress is the address of the other user.
	RemoteAddress address.Address
	// SessionStore loads and stores session records for RemoteAddress.
	SessionStore Store
	// PreKeyStore loads our one-time pre-keys by ID.
	PreKeyStore prekey.Store
	// SignedPreKeyStore loads our signed pre-keys by ID.
	SignedPreKeyStore prekey.SignedStore
	// IdentityKeyStore provides our identity key pair and registration ID, and
	// checks and records the remote user's identity key.
	IdentityKeyStore identity.Store
}

// ProcessPreKey processes a pre-key message to initialize a "Bob" session
// after receiving a message from "Alice".
//
// This method returns the one-time pre-key used by "Alice" when sending the initial message,
// if one was used.
38 | func (s *Session) ProcessPreKey(ctx context.Context, record *Record, message *message.PreKey) (*prekey.ID, error) { 39 | theirIdentityKey := message.IdentityKey() 40 | 41 | trusted, err := s.IdentityKeyStore.IsTrustedIdentity(ctx, s.RemoteAddress, theirIdentityKey, direction.Receiving) 42 | if err != nil { 43 | return nil, err 44 | } 45 | if !trusted { 46 | return nil, perrors.ErrUntrustedIdentity(s.RemoteAddress) 47 | } 48 | 49 | unsignedPreKeyID, err := s.processPreKeyV3(ctx, record, message) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | _, err = s.IdentityKeyStore.Store(ctx, s.RemoteAddress, theirIdentityKey) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | return unsignedPreKeyID, nil 60 | } 61 | 62 | func (s *Session) processPreKeyV3(ctx context.Context, record *Record, message *message.PreKey) (*prekey.ID, error) { 63 | exists, err := record.HasSessionState(uint32(message.Version()), message.BaseKey().Bytes()) 64 | if err != nil { 65 | return nil, err 66 | } 67 | if exists { 68 | // We've already set up a session for this V3 message, letting bundled message fall through. 
69 | return nil, nil 70 | } 71 | 72 | var ourSignedPreKeyPair *curve.KeyPair 73 | ourSignedPreKeyRecord, exists, err := s.SignedPreKeyStore.Load(ctx, message.SignedPreKeyID()) 74 | if err != nil { 75 | return nil, err 76 | } 77 | if exists { 78 | ourSignedPreKeyPair, err = ourSignedPreKeyRecord.KeyPair() 79 | if err != nil { 80 | return nil, err 81 | } 82 | } 83 | 84 | var ourOneTimePreKeyPair *curve.KeyPair 85 | if message.PreKeyID() == nil { 86 | glog.Warningf("processing PreKey message from %s which had no one-time pre-key", s.RemoteAddress) 87 | } else { 88 | glog.Infof("processing PreKey message from %s", s.RemoteAddress) 89 | 90 | ourOneTimePreKeyRecord, exists, err := s.PreKeyStore.Load(ctx, *message.PreKeyID()) 91 | if err != nil { 92 | return nil, err 93 | } 94 | if exists { 95 | ourOneTimePreKeyPair, err = ourOneTimePreKeyRecord.KeyPair() 96 | if err != nil { 97 | return nil, err 98 | } 99 | } 100 | } 101 | 102 | session, err := initializeBobSession(&ratchet.BobParameters{ 103 | OurIdentityKeyPair: s.IdentityKeyStore.KeyPair(ctx), 104 | OurSignedPreKeyPair: ourSignedPreKeyPair, 105 | OurOneTimePreKeyPair: ourOneTimePreKeyPair, 106 | OurRatchetKeyPair: ourSignedPreKeyPair, 107 | TheirIdentityKey: message.IdentityKey(), 108 | TheirBaseKey: message.BaseKey(), 109 | }) 110 | 111 | session.SetLocalRegistrationID(s.IdentityKeyStore.LocalRegistrationID(ctx)) 112 | session.SetRemoteRegistrationID(message.RegistrationID()) 113 | session.SetAliceBaseKey(message.BaseKey().Bytes()) 114 | 115 | record.PromoteState(session) 116 | 117 | return message.PreKeyID(), nil 118 | } 119 | 120 | // ProcessPreKeyBundle processes a pre-key bundle to initialize an "Alice" session 121 | // to send encrypted messages to some "Bob" user identified by the pre-key bundle. 
122 | func (s *Session) ProcessPreKeyBundle(ctx context.Context, random io.Reader, bundle *prekey.Bundle) error { 123 | theirIdentityKey := bundle.IdentityKey 124 | 125 | trusted, err := s.IdentityKeyStore.IsTrustedIdentity(ctx, s.RemoteAddress, theirIdentityKey, direction.Sending) 126 | if err != nil { 127 | return err 128 | } 129 | if !trusted { 130 | return errors.New("untrusted identity") 131 | } 132 | 133 | ok, err := theirIdentityKey.PublicKey().VerifySignature(bundle.SignedPreKeySignature, bundle.SignedPreKeyPublic.Bytes()) 134 | if err != nil { 135 | return err 136 | } 137 | if !ok { 138 | return errors.New("signature validation failed") 139 | } 140 | 141 | record, exists, err := s.SessionStore.Load(ctx, s.RemoteAddress) 142 | if err != nil { 143 | return err 144 | } 145 | if !exists { 146 | record = NewRecord(nil) 147 | } 148 | 149 | ourBaseKeyPair, err := curve.GenerateKeyPair(random) 150 | if err != nil { 151 | return err 152 | } 153 | theirSignedPreKey := bundle.SignedPreKeyPublic 154 | theirOneTimePreKey := bundle.PreKeyPublic 155 | ourIdentityKeyPair := s.IdentityKeyStore.KeyPair(ctx) 156 | 157 | session, err := initializeAliceSession(random, &ratchet.AliceParameters{ 158 | OurIdentityKeyPair: ourIdentityKeyPair, 159 | OurBaseKeyPair: ourBaseKeyPair, 160 | TheirIdentityKey: theirIdentityKey, 161 | TheirSignedPreKey: theirSignedPreKey, 162 | TheirOneTimePreKey: theirOneTimePreKey, 163 | TheirRatchetKey: theirSignedPreKey, 164 | }) 165 | if err != nil { 166 | return err 167 | } 168 | 169 | theirOneTimePreKeyID := bundle.PreKeyID 170 | preKeyString := "" 171 | if theirOneTimePreKeyID != nil { 172 | preKeyString = theirOneTimePreKeyID.String() 173 | } 174 | glog.Infof("set_unacknowledged_pre_key_message for: %s with preKeyId: %s", s.RemoteAddress, preKeyString) 175 | 176 | session.SetUnacknowledgedPreKeyMessage(theirOneTimePreKeyID, bundle.SignedPreKeyID, ourBaseKeyPair.PublicKey()) 177 | 178 | 
session.SetLocalRegistrationID(s.IdentityKeyStore.LocalRegistrationID(ctx)) 179 | session.SetRemoteRegistrationID(bundle.RegistrationID) 180 | session.SetAliceBaseKey(ourBaseKeyPair.PublicKey().Bytes()) 181 | 182 | _, err = s.IdentityKeyStore.Store(ctx, s.RemoteAddress, theirIdentityKey) 183 | if err != nil { 184 | return err 185 | } 186 | 187 | record.PromoteState(session) 188 | 189 | err = s.SessionStore.Store(ctx, s.RemoteAddress, record) 190 | if err != nil { 191 | return err 192 | } 193 | 194 | return nil 195 | } 196 | 197 | func InitializeAliceSessionRecord(random io.Reader, params *ratchet.AliceParameters) (*Record, error) { 198 | session, err := initializeAliceSession(random, params) 199 | if err != nil { 200 | return nil, err 201 | } 202 | 203 | return NewRecord(session), nil 204 | } 205 | 206 | func initializeAliceSession(random io.Reader, params *ratchet.AliceParameters) (*State, error) { 207 | localIdentity := params.OurIdentityKeyPair.IdentityKey() 208 | sendingRatchetKeyPair, err := curve.GenerateKeyPair(random) 209 | if err != nil { 210 | return nil, err 211 | } 212 | 213 | dh1, err := params.OurIdentityKeyPair.PrivateKey().Agreement(params.TheirSignedPreKey) 214 | if err != nil { 215 | return nil, err 216 | } 217 | 218 | ourBasePrivateKey := params.OurBaseKeyPair.PrivateKey() 219 | dh2, err := ourBasePrivateKey.Agreement(params.TheirIdentityKey.PublicKey()) 220 | if err != nil { 221 | return nil, err 222 | } 223 | dh3, err := ourBasePrivateKey.Agreement(params.TheirSignedPreKey) 224 | if err != nil { 225 | return nil, err 226 | } 227 | 228 | // 32 * 5 = 160 229 | secrets := bytes.NewBuffer(make([]byte, 0, 160)) 230 | secrets.Write(discontinuityBytes) 231 | secrets.Write(dh1) 232 | secrets.Write(dh2) 233 | secrets.Write(dh3) 234 | 235 | if params.TheirOneTimePreKey != nil { 236 | dh4, err := ourBasePrivateKey.Agreement(params.TheirOneTimePreKey) 237 | if err != nil { 238 | return nil, err 239 | } 240 | 241 | secrets.Write(dh4) 242 | } 243 | 244 | 
rootKey, chainKey, err := ratchet.DeriveKeys(secrets.Bytes()) 245 | if err != nil { 246 | return nil, err 247 | } 248 | 249 | sendingChainRootKey, sendingChainChainKey, err := rootKey.CreateChain(sendingRatchetKeyPair.PrivateKey(), params.TheirRatchetKey) 250 | if err != nil { 251 | return nil, err 252 | } 253 | 254 | session := NewState(&v1.SessionStructure{ 255 | SessionVersion: message.CiphertextVersion, 256 | LocalIdentityPublic: localIdentity.PublicKey().Bytes(), 257 | RemoteIdentityPublic: params.TheirIdentityKey.Bytes(), 258 | RootKey: sendingChainRootKey.Bytes(), 259 | }) 260 | session.AddReceiverChain(params.TheirRatchetKey, chainKey) 261 | session.SetSenderChain(sendingRatchetKeyPair, sendingChainChainKey) 262 | 263 | return session, nil 264 | } 265 | 266 | func InitializeBobSessionRecord(params *ratchet.BobParameters) (*Record, error) { 267 | session, err := initializeBobSession(params) 268 | if err != nil { 269 | return nil, err 270 | } 271 | 272 | return NewRecord(session), nil 273 | } 274 | 275 | func initializeBobSession(params *ratchet.BobParameters) (*State, error) { 276 | localIdentity := params.OurIdentityKeyPair.IdentityKey() 277 | 278 | dh1, err := params.OurSignedPreKeyPair.PrivateKey().Agreement(params.TheirIdentityKey.PublicKey()) 279 | if err != nil { 280 | return nil, err 281 | } 282 | 283 | dh2, err := params.OurIdentityKeyPair.PrivateKey().Agreement(params.TheirBaseKey) 284 | if err != nil { 285 | return nil, err 286 | } 287 | 288 | dh3, err := params.OurSignedPreKeyPair.PrivateKey().Agreement(params.TheirBaseKey) 289 | if err != nil { 290 | return nil, err 291 | } 292 | 293 | // 32 * 5 = 160 294 | secrets := bytes.NewBuffer(make([]byte, 0, 160)) 295 | secrets.Write(discontinuityBytes) 296 | secrets.Write(dh1) 297 | secrets.Write(dh2) 298 | secrets.Write(dh3) 299 | 300 | if params.OurOneTimePreKeyPair != nil { 301 | dh4, err := params.OurOneTimePreKeyPair.PrivateKey().Agreement(params.TheirBaseKey) 302 | if err != nil { 303 | return nil, 
err 304 | } 305 | 306 | secrets.Write(dh4) 307 | } 308 | 309 | rootKey, chainKey, err := ratchet.DeriveKeys(secrets.Bytes()) 310 | 311 | session := NewState(&v1.SessionStructure{ 312 | SessionVersion: message.CiphertextVersion, 313 | LocalIdentityPublic: localIdentity.PublicKey().Bytes(), 314 | RemoteIdentityPublic: params.TheirIdentityKey.Bytes(), 315 | RootKey: rootKey.Bytes(), 316 | }) 317 | session.SetSenderChain(params.OurRatchetKeyPair, chainKey) 318 | 319 | return session, nil 320 | } 321 | -------------------------------------------------------------------------------- /protocol/session/state.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | 7 | "google.golang.org/protobuf/proto" 8 | 9 | "github.com/RTann/libsignal-go/protocol/curve" 10 | v1 "github.com/RTann/libsignal-go/protocol/generated/v1" 11 | "github.com/RTann/libsignal-go/protocol/identity" 12 | "github.com/RTann/libsignal-go/protocol/internal/pointer" 13 | "github.com/RTann/libsignal-go/protocol/prekey" 14 | "github.com/RTann/libsignal-go/protocol/ratchet" 15 | ) 16 | 17 | const ( 18 | maxReceiverChains = 5 19 | maxArchivedStates = 40 20 | 21 | maxMessageKeys = 2000 22 | ) 23 | 24 | var discontinuityBytes = []byte{ 25 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 26 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 27 | 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 28 | } 29 | 30 | // State represents a session's state. 
31 | type State struct { 32 | session *v1.SessionStructure 33 | } 34 | 35 | func NewState(session *v1.SessionStructure) *State { 36 | return &State{ 37 | session: session, 38 | } 39 | } 40 | 41 | func (s *State) Clone() *State { 42 | session := proto.Clone(s.session) 43 | return NewState(session.(*v1.SessionStructure)) 44 | } 45 | 46 | func (s *State) SetAliceBaseKey(key []byte) { 47 | s.session.AliceBaseKey = key 48 | } 49 | 50 | func (s *State) AliceBaseKey() []byte { 51 | return s.session.GetAliceBaseKey() 52 | } 53 | 54 | func (s *State) Version() uint32 { 55 | v := s.session.GetSessionVersion() 56 | if v == 0 { 57 | return uint32(2) 58 | } 59 | 60 | return v 61 | } 62 | 63 | func (s *State) RemoteIdentityKey() (identity.Key, bool, error) { 64 | remoteBytes := s.session.GetRemoteIdentityPublic() 65 | if len(remoteBytes) == 0 { 66 | return identity.Key{}, false, nil 67 | } 68 | 69 | remoteKey, err := identity.NewKey(remoteBytes) 70 | if err != nil { 71 | return identity.Key{}, false, err 72 | } 73 | 74 | return remoteKey, true, nil 75 | } 76 | 77 | func (s *State) LocalIdentityKey() (identity.Key, error) { 78 | return identity.NewKey(s.session.GetLocalIdentityPublic()) 79 | } 80 | 81 | func (s *State) SessionWithSelf() (bool, error) { 82 | remote, exists, err := s.RemoteIdentityKey() 83 | if err != nil { 84 | return false, err 85 | } 86 | if exists { 87 | local, err := s.LocalIdentityKey() 88 | if err != nil { 89 | return false, err 90 | } 91 | 92 | return remote.Equal(local), nil 93 | } 94 | 95 | return false, nil 96 | } 97 | 98 | func (s *State) SetPreviousCounter(counter uint32) { 99 | s.session.PreviousCounter = counter 100 | } 101 | 102 | func (s *State) PreviousCounter() uint32 { 103 | return s.session.GetPreviousCounter() 104 | } 105 | 106 | func (s *State) SetRootKey(key ratchet.RootKey) { 107 | s.session.RootKey = key.Bytes() 108 | } 109 | 110 | func (s *State) RootKey() (ratchet.RootKey, error) { 111 | return ratchet.NewRootKey(s.session.GetRootKey()) 
112 | } 113 | 114 | func (s *State) SenderRatchetPrivateKey() (curve.PrivateKey, error) { 115 | chain := s.session.GetSenderChain() 116 | if chain == nil { 117 | return nil, errors.New("missing sender chain") 118 | } 119 | 120 | ratchetKey, err := curve.NewPrivateKey(chain.GetSenderRatchetKeyPrivate()) 121 | if err != nil { 122 | return nil, err 123 | } 124 | 125 | return ratchetKey, nil 126 | } 127 | 128 | func (s *State) SenderRatchetKey() (curve.PublicKey, error) { 129 | chain := s.session.GetSenderChain() 130 | if chain == nil { 131 | return nil, errors.New("missing sender chain") 132 | } 133 | 134 | ratchetKey, err := curve.NewPublicKey(chain.GetSenderRatchetKey()) 135 | if err != nil { 136 | return nil, err 137 | } 138 | 139 | return ratchetKey, nil 140 | } 141 | 142 | func (s *State) ReceiverChain(sender curve.PublicKey) (int, *v1.SessionStructure_Chain) { 143 | key := sender.Bytes() 144 | for i, chain := range s.session.GetReceiverChains() { 145 | if bytes.Equal(chain.GetSenderRatchetKey(), key) { 146 | return i, chain 147 | } 148 | } 149 | 150 | return -1, nil 151 | } 152 | 153 | func (s *State) SetReceiverChainKey(sender curve.PublicKey, chainKey ratchet.ChainKey) error { 154 | idx, chain := s.ReceiverChain(sender) 155 | if idx < 0 { 156 | return errors.New("SetReceiverChainKey called for non-existent chain") 157 | } 158 | 159 | chain.ChainKey = &v1.SessionStructure_Chain_ChainKey{ 160 | Index: chainKey.Index(), 161 | Key: chainKey.Key(), 162 | } 163 | 164 | s.session.GetReceiverChains()[idx] = chain 165 | 166 | return nil 167 | } 168 | 169 | func (s *State) ReceiverChainKey(sender curve.PublicKey) (ratchet.ChainKey, bool, error) { 170 | idx, chain := s.ReceiverChain(sender) 171 | if idx < 0 { 172 | return ratchet.ChainKey{}, false, nil 173 | } 174 | 175 | chainKey, err := ratchet.NewChainKey(chain.GetChainKey().GetKey(), chain.GetChainKey().GetIndex()) 176 | if err != nil { 177 | return ratchet.ChainKey{}, false, err 178 | } 179 | 180 | return chainKey, 
true, nil 181 | } 182 | 183 | func (s *State) AddReceiverChain(sender curve.PublicKey, chainKey ratchet.ChainKey) { 184 | chain := &v1.SessionStructure_Chain{ 185 | SenderRatchetKey: sender.Bytes(), 186 | ChainKey: &v1.SessionStructure_Chain_ChainKey{ 187 | Index: chainKey.Index(), 188 | Key: chainKey.Key(), 189 | }, 190 | } 191 | s.session.ReceiverChains = append(s.session.GetReceiverChains(), chain) 192 | 193 | if len(s.session.GetReceiverChains()) > maxReceiverChains { 194 | s.session.GetReceiverChains()[0] = nil 195 | s.session.ReceiverChains = s.session.GetReceiverChains()[1:] 196 | } 197 | } 198 | 199 | func (s *State) SetSenderChain(sender *curve.KeyPair, nextChainKey ratchet.ChainKey) { 200 | s.session.SenderChain = &v1.SessionStructure_Chain{ 201 | SenderRatchetKey: sender.PublicKey().Bytes(), 202 | SenderRatchetKeyPrivate: sender.PrivateKey().Bytes(), 203 | ChainKey: &v1.SessionStructure_Chain_ChainKey{ 204 | Index: nextChainKey.Index(), 205 | Key: nextChainKey.Key(), 206 | }, 207 | } 208 | } 209 | 210 | func (s *State) SenderChainKey() (ratchet.ChainKey, error) { 211 | senderChain := s.session.GetSenderChain() 212 | if senderChain == nil { 213 | return ratchet.ChainKey{}, errors.New("missing sender chain") 214 | } 215 | 216 | chainKey := senderChain.GetChainKey() 217 | if chainKey == nil { 218 | return ratchet.ChainKey{}, errors.New("missing sender chain key") 219 | } 220 | 221 | return ratchet.NewChainKey(chainKey.GetKey(), chainKey.GetIndex()) 222 | } 223 | 224 | func (s *State) SetSenderChainKey(nextChainKey ratchet.ChainKey) { 225 | chainKey := &v1.SessionStructure_Chain_ChainKey{ 226 | Index: nextChainKey.Index(), 227 | Key: nextChainKey.Key(), 228 | } 229 | 230 | senderChain := s.session.GetSenderChain() 231 | if senderChain != nil { 232 | senderChain.ChainKey = chainKey 233 | return 234 | } 235 | 236 | s.session.SenderChain = &v1.SessionStructure_Chain{ 237 | ChainKey: chainKey, 238 | } 239 | } 240 | 241 | func (s *State) SetMessageKeys(sender 
curve.PublicKey, messageKeys ratchet.MessageKeys) error { 242 | newKeys := &v1.SessionStructure_Chain_MessageKey{ 243 | Index: messageKeys.Counter(), 244 | CipherKey: messageKeys.CipherKey(), 245 | MacKey: messageKeys.MACKey(), 246 | Iv: messageKeys.IV(), 247 | } 248 | 249 | idx, chain := s.ReceiverChain(sender) 250 | if idx < 0 { 251 | return errors.New("SetMessageKeys called for non-existent chain") 252 | } 253 | chain.MessageKeys = append(chain.GetMessageKeys(), newKeys) 254 | if len(chain.GetMessageKeys()) > maxMessageKeys { 255 | chain.GetMessageKeys()[0] = nil 256 | chain.MessageKeys = chain.GetMessageKeys()[1:] 257 | } 258 | 259 | s.session.GetReceiverChains()[idx] = chain 260 | 261 | return nil 262 | } 263 | 264 | func (s *State) MessageKeys(sender curve.PublicKey, counter uint32) (ratchet.MessageKeys, bool, error) { 265 | idx, chain := s.ReceiverChain(sender) 266 | if idx < 0 { 267 | return ratchet.MessageKeys{}, false, nil 268 | } 269 | 270 | var err error 271 | var found bool 272 | var messageKeys ratchet.MessageKeys 273 | filtered := chain.GetMessageKeys()[:0] 274 | for _, key := range chain.GetMessageKeys() { 275 | if key.GetIndex() == counter { 276 | found = true 277 | messageKeys, err = ratchet.NewMessageKeys(key.GetCipherKey(), key.GetMacKey(), key.GetIv(), counter) 278 | key = nil 279 | continue 280 | } 281 | 282 | filtered = append(filtered, key) 283 | } 284 | 285 | chain.MessageKeys = filtered 286 | 287 | return messageKeys, found, err 288 | } 289 | 290 | func (s *State) SetUnacknowledgedPreKeyMessage(preKeyID *prekey.ID, signedPreKeyID prekey.ID, baseKey curve.PublicKey) { 291 | pending := &v1.SessionStructure_PendingPreKey{ 292 | SignedPreKeyId: int32(signedPreKeyID), 293 | BaseKey: baseKey.Bytes(), 294 | } 295 | if preKeyID != nil { 296 | pending.PreKeyId = uint32(*preKeyID) 297 | } 298 | 299 | s.session.PendingPreKey = pending 300 | } 301 | 302 | func (s *State) ClearUnacknowledgedPreKeyMessage() { 303 | s.session.PendingPreKey = nil 304 | } 
305 | // UnacknowledgedPreKeyMessageItems holds the pending pre-key message state returned by State.UnacknowledgedPreKeyMessages. 306 | type UnacknowledgedPreKeyMessageItems struct { 307 | preKeyID *prekey.ID 308 | signedPreKeyID prekey.ID 309 | baseKey curve.PublicKey 310 | } 311 | // PreKeyID returns the one-time pre-key ID, or nil if none was recorded. 312 | func (u UnacknowledgedPreKeyMessageItems) PreKeyID() *prekey.ID { 313 | return u.preKeyID 314 | } 315 | // SignedPreKeyID returns the signed pre-key ID. 316 | func (u UnacknowledgedPreKeyMessageItems) SignedPreKeyID() prekey.ID { 317 | return u.signedPreKeyID 318 | } 319 | // BaseKey returns the base public key. 320 | func (u UnacknowledgedPreKeyMessageItems) BaseKey() curve.PublicKey { 321 | return u.baseKey 322 | } 323 | // UnacknowledgedPreKeyMessages returns the pending pre-key message state, or (nil, nil) if none is pending. A stored PreKeyId of 0 is reported as a nil PreKeyID; an error is returned if curve.NewPublicKey rejects the stored base key. 324 | func (s *State) UnacknowledgedPreKeyMessages() (*UnacknowledgedPreKeyMessageItems, error) { 325 | pendingPreKey := s.session.GetPendingPreKey() 326 | if pendingPreKey == nil { 327 | return nil, nil 328 | } 329 | 330 | key, err := curve.NewPublicKey(pendingPreKey.GetBaseKey()) 331 | if err != nil { 332 | return nil, err 333 | } 334 | 335 | u := &UnacknowledgedPreKeyMessageItems{ 336 | signedPreKeyID: prekey.ID(pendingPreKey.GetSignedPreKeyId()), 337 | baseKey: key, 338 | } 339 | 340 | if preKeyID := pendingPreKey.GetPreKeyId(); preKeyID != 0 { 341 | u.preKeyID = pointer.To(prekey.ID(preKeyID)) 342 | } 343 | 344 | return u, nil 345 | } 346 | // SetRemoteRegistrationID sets the remote party's registration ID. 347 | func (s *State) SetRemoteRegistrationID(id uint32) { 348 | s.session.RemoteRegistrationId = id 349 | } 350 | // SetLocalRegistrationID sets the local registration ID. 351 | func (s *State) SetLocalRegistrationID(id uint32) { 352 | s.session.LocalRegistrationId = id 353 | } 354 | // LocalRegistrationID returns the local registration ID. 355 | func (s *State) LocalRegistrationID() uint32 { 356 | return s.session.LocalRegistrationId 357 | } 358 | // Bytes returns the protobuf-serialized session state. 359 | func (s *State) Bytes() []byte { 360 | b, _ := proto.Marshal(s.session) // NOTE(review): the marshal error is discarded; confirm callers can tolerate a nil or partial result. 361 | return b 362 | } 363 | -------------------------------------------------------------------------------- /protocol/session/store.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/RTann/libsignal-go/protocol/address" 7 | "github.com/RTann/libsignal-go/protocol/distribution" 8 | ) 9 | 10 | // Store defines a
session store. 11 | type Store interface { 12 | Load(ctx context.Context, address address.Address) (*Record, bool, error) 13 | Store(ctx context.Context, address address.Address, record *Record) error 14 | } 15 | // GroupStore defines a store of group session records, keyed by sender address and distribution ID. 16 | type GroupStore interface { 17 | Load(ctx context.Context, sender address.Address, distributionID distribution.ID) (*GroupRecord, bool, error) 18 | Store(ctx context.Context, sender address.Address, distributionID distribution.ID, record *GroupRecord) error 19 | } 20 | -------------------------------------------------------------------------------- /protocol/session/store_inmem.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/RTann/libsignal-go/protocol/address" 7 | "github.com/RTann/libsignal-go/protocol/distribution" 8 | ) 9 | // Compile-time check that *inMemStore implements Store. 10 | var _ Store = (*inMemStore)(nil) 11 | 12 | // inMemStore represents an in-memory session store. 13 | type inMemStore struct { 14 | sessions map[address.Address]*Record 15 | } 16 | 17 | // NewInMemStore creates a new in-memory session store. It performs no locking, so concurrent use must be synchronized by the caller.
18 | func NewInMemStore() Store { 19 | return &inMemStore{ 20 | sessions: make(map[address.Address]*Record), 21 | } 22 | } 23 | // Load returns the record stored for address and whether one exists; it never returns an error. 24 | func (i *inMemStore) Load(_ context.Context, address address.Address) (*Record, bool, error) { 25 | record, exists := i.sessions[address] 26 | return record, exists, nil 27 | } 28 | // Store saves record under address, replacing any existing record; it never returns an error. 29 | func (i *inMemStore) Store(_ context.Context, address address.Address, record *Record) error { 30 | i.sessions[address] = record 31 | return nil 32 | } 33 | // key identifies a group record by sender address and distribution ID. 34 | type key struct { 35 | address address.Address 36 | id distribution.ID 37 | } 38 | // inMemGroupStore is an in-memory GroupStore. It performs no locking, so concurrent use must be synchronized by the caller. 39 | type inMemGroupStore struct { 40 | senderKeys map[key]*GroupRecord 41 | } 42 | // NewInMemGroupStore creates a new, empty in-memory group session store. 43 | func NewInMemGroupStore() GroupStore { 44 | return &inMemGroupStore{ 45 | senderKeys: make(map[key]*GroupRecord), 46 | } 47 | } 48 | // Load returns the group record stored for (sender, distributionID) and whether one exists; it never returns an error. 49 | func (i *inMemGroupStore) Load(_ context.Context, sender address.Address, distributionID distribution.ID) (*GroupRecord, bool, error) { 50 | record, exists := i.senderKeys[key{ 51 | address: sender, 52 | id: distributionID, 53 | }] 54 | return record, exists, nil 55 | } 56 | // Store saves record under (sender, distributionID), replacing any existing record; it never returns an error. 57 | func (i *inMemGroupStore) Store(_ context.Context, sender address.Address, distributionID distribution.ID, record *GroupRecord) error { 58 | i.senderKeys[key{ 59 | address: sender, 60 | id: distributionID, 61 | }] = record 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /protocol/tests/group_session_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | mathrand "math/rand" 8 | "sort" 9 | "testing" 10 | "unicode/utf8" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | 15 | "github.com/RTann/libsignal-go/protocol/address" 16 | "github.com/RTann/libsignal-go/protocol/distribution" 17 | "github.com/RTann/libsignal-go/protocol/message" 18 | "github.com/RTann/libsignal-go/protocol/perrors" 19 |
"github.com/RTann/libsignal-go/protocol/session" 20 | ) 21 | // TestGroupNoSendSession checks that encrypting fails when Alice has not created a sender key distribution. 22 | func TestGroupNoSendSession(t *testing.T) { 23 | senderAddress := address.Address{ 24 | Name: "+14159999111", 25 | DeviceID: 1, 26 | } 27 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 28 | 29 | aliceStore := testInMemProtocolStore(t, random) 30 | 31 | aliceSession := &session.GroupSession{ 32 | SenderAddress: senderAddress, 33 | DistID: distributionID, 34 | SenderKeyStore: aliceStore.GroupStore(), 35 | } 36 | _, err := aliceSession.EncryptMessage(ctx, random, []byte("space camp?")) 37 | assert.Error(t, err) 38 | } 39 | // TestGroupNoRecvSession checks that Bob cannot decrypt when he has not processed Alice's sender key distribution. 40 | func TestGroupNoRecvSession(t *testing.T) { 41 | senderAddress := address.Address{ 42 | Name: "+14159999111", 43 | DeviceID: 1, 44 | } 45 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 46 | 47 | aliceStore := testInMemProtocolStore(t, random) 48 | bobStore := testInMemProtocolStore(t, random) 49 | 50 | aliceSession := session.GroupSession{ 51 | SenderAddress: senderAddress, 52 | DistID: distributionID, 53 | SenderKeyStore: aliceStore.GroupStore(), 54 | } 55 | _, err := aliceSession.NewSenderKeyDistribution(ctx, random) 56 | assert.NoError(t, err) 57 | 58 | aliceCiphertext, err := aliceSession.EncryptMessage(ctx, random, []byte("space camp?")) 59 | assert.NoError(t, err) 60 | 61 | bobSession := &session.GroupSession{ 62 | SenderAddress: senderAddress, 63 | SenderKeyStore: bobStore.GroupStore(), 64 | } 65 | _, err = bobSession.DecryptMessage(ctx, aliceCiphertext) 66 | assert.Error(t, err) 67 | } 68 | // TestGroupBasic checks a single message round trip after Bob processes Alice's distribution message. 69 | func TestGroupBasic(t *testing.T) { 70 | senderAddress := address.Address{ 71 | Name: "+14159999111", 72 | DeviceID: 1, 73 | } 74 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 75 | 76 | aliceStore := testInMemProtocolStore(t, random) 77 | bobStore := testInMemProtocolStore(t, random) 78 | 79 | aliceSession := session.GroupSession{ 80 | SenderAddress: senderAddress, 81 |
DistID: distributionID, 82 | SenderKeyStore: aliceStore.GroupStore(), 83 | } 84 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 85 | assert.NoError(t, err) 86 | 87 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 88 | assert.NoError(t, err) 89 | 90 | originalMsg := []byte("space camp?") 91 | aliceCiphertext, err := aliceSession.EncryptMessage(ctx, random, originalMsg) 92 | assert.NoError(t, err) 93 | 94 | bobSession := &session.GroupSession{ 95 | SenderAddress: senderAddress, 96 | SenderKeyStore: bobStore.GroupStore(), 97 | } 98 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 99 | assert.NoError(t, err) 100 | 101 | bobPlaintext, err := bobSession.DecryptMessage(ctx, aliceCiphertext) 102 | assert.NoError(t, err) 103 | assert.True(t, utf8.Valid(bobPlaintext)) 104 | assert.Equal(t, originalMsg, bobPlaintext) 105 | } 106 | // TestGroupLargeMessages checks the round trip of a 1024-byte random message. 107 | func TestGroupLargeMessages(t *testing.T) { 108 | senderAddress := address.Address{ 109 | Name: "+14159999111", 110 | DeviceID: 1, 111 | } 112 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 113 | 114 | aliceStore := testInMemProtocolStore(t, random) 115 | bobStore := testInMemProtocolStore(t, random) 116 | 117 | aliceSession := session.GroupSession{ 118 | SenderAddress: senderAddress, 119 | DistID: distributionID, 120 | SenderKeyStore: aliceStore.GroupStore(), 121 | } 122 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 123 | assert.NoError(t, err) 124 | 125 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 126 | assert.NoError(t, err) 127 | 128 | largeMsg := make([]byte, 1024) 129 | _, err = io.ReadFull(random, largeMsg) 130 | require.NoError(t, err) 131 | 132 | aliceCiphertext, err := aliceSession.EncryptMessage(ctx, random, largeMsg) 133 | assert.NoError(t, err) 134 | 135 | bobSession := &session.GroupSession{ 136 |
SenderAddress: senderAddress, 137 | SenderKeyStore: bobStore.GroupStore(), 138 | } 139 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 140 | assert.NoError(t, err) 141 | 142 | bobPlaintext, err := bobSession.DecryptMessage(ctx, aliceCiphertext) 143 | assert.NoError(t, err) 144 | assert.Equal(t, largeMsg, bobPlaintext) 145 | } 146 | // TestGroupBasicRatchet checks in-order delivery, duplicate rejection, and out-of-order delivery of three messages. 147 | func TestGroupBasicRatchet(t *testing.T) { 148 | senderAddress := address.Address{ 149 | Name: "+14159999111", 150 | DeviceID: 1, 151 | } 152 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 153 | 154 | aliceStore := testInMemProtocolStore(t, random) 155 | bobStore := testInMemProtocolStore(t, random) 156 | 157 | aliceSession := session.GroupSession{ 158 | SenderAddress: senderAddress, 159 | DistID: distributionID, 160 | SenderKeyStore: aliceStore.GroupStore(), 161 | } 162 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 163 | assert.NoError(t, err) 164 | 165 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 166 | assert.NoError(t, err) 167 | 168 | bobSession := &session.GroupSession{ 169 | SenderAddress: senderAddress, 170 | SenderKeyStore: bobStore.GroupStore(), 171 | } 172 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 173 | assert.NoError(t, err) 174 | 175 | aliceCiphertext1, err := aliceSession.EncryptMessage(ctx, random, []byte("swim camp")) 176 | assert.NoError(t, err) 177 | aliceCiphertext2, err := aliceSession.EncryptMessage(ctx, random, []byte("robot camp")) 178 | assert.NoError(t, err) 179 | aliceCiphertext3, err := aliceSession.EncryptMessage(ctx, random, []byte("ninja camp")) 180 | assert.NoError(t, err) 181 | 182 | bobPlaintext1, err := bobSession.DecryptMessage(ctx, aliceCiphertext1) 183 | assert.NoError(t, err) 184 | assert.True(t, utf8.Valid(bobPlaintext1)) 185 | assert.Equal(t, []byte("swim camp"), bobPlaintext1) 186 | 187 | _, err =
bobSession.DecryptMessage(ctx, aliceCiphertext1) 188 | assert.Error(t, err) 189 | assert.Equal(t, perrors.ErrDuplicateMessage, err) 190 | 191 | bobPlaintext3, err := bobSession.DecryptMessage(ctx, aliceCiphertext3) 192 | assert.NoError(t, err) 193 | assert.True(t, utf8.Valid(bobPlaintext3)) 194 | assert.Equal(t, []byte("ninja camp"), bobPlaintext3) 195 | 196 | bobPlaintext2, err := bobSession.DecryptMessage(ctx, aliceCiphertext2) 197 | assert.NoError(t, err) 198 | assert.True(t, utf8.Valid(bobPlaintext2)) 199 | assert.Equal(t, []byte("robot camp"), bobPlaintext2) 200 | } 201 | // TestGroupLateJoin checks that Bob can decrypt a new message after Alice has already sent 100 messages he never received. 202 | func TestGroupLateJoin(t *testing.T) { 203 | senderAddress := address.Address{ 204 | Name: "+14159999111", 205 | DeviceID: 1, 206 | } 207 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 208 | 209 | aliceStore := testInMemProtocolStore(t, random) 210 | bobStore := testInMemProtocolStore(t, random) 211 | 212 | aliceSession := session.GroupSession{ 213 | SenderAddress: senderAddress, 214 | DistID: distributionID, 215 | SenderKeyStore: aliceStore.GroupStore(), 216 | } 217 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 218 | assert.NoError(t, err) 219 | 220 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 221 | assert.NoError(t, err) 222 | 223 | for i := 0; i < 100; i++ { 224 | msg := fmt.Sprintf("nefarious plotting %d/100", i) 225 | _, err := aliceSession.EncryptMessage(ctx, random, []byte(msg)) 226 | assert.NoError(t, err) 227 | } 228 | 229 | bobSession := &session.GroupSession{ 230 | SenderAddress: senderAddress, 231 | SenderKeyStore: bobStore.GroupStore(), 232 | } 233 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 234 | assert.NoError(t, err) 235 | 236 | msg := []byte("welcome bob") 237 | aliceCiphertext, err := aliceSession.EncryptMessage(ctx, random, msg) 238 | assert.NoError(t, err) 239 | 240 | bobPlaintext, err :=
bobSession.DecryptMessage(ctx, aliceCiphertext) 241 | assert.NoError(t, err) 242 | assert.True(t, utf8.Valid(bobPlaintext)) 243 | assert.Equal(t, msg, bobPlaintext) 244 | } 245 | // TestGroupOutOfOrder checks that Bob decrypts a shuffled batch of 100 messages from Alice. 246 | func TestGroupOutOfOrder(t *testing.T) { 247 | senderAddress := address.Address{ 248 | Name: "+14159999111", 249 | DeviceID: 1, 250 | } 251 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 252 | 253 | aliceStore := testInMemProtocolStore(t, random) 254 | bobStore := testInMemProtocolStore(t, random) 255 | 256 | aliceSession := session.GroupSession{ 257 | SenderAddress: senderAddress, 258 | DistID: distributionID, 259 | SenderKeyStore: aliceStore.GroupStore(), 260 | } 261 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 262 | assert.NoError(t, err) 263 | 264 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 265 | assert.NoError(t, err) 266 | 267 | bobSession := &session.GroupSession{ 268 | SenderAddress: senderAddress, 269 | SenderKeyStore: bobStore.GroupStore(), 270 | } 271 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 272 | assert.NoError(t, err) 273 | // Loop to cap (len starts at 0), and zero-pad the counter so the lexicographic sort below restores numeric order. 274 | ciphertexts := make([]*message.SenderKey, 0, 100) 275 | for i := 0; i < cap(ciphertexts); i++ { 276 | msg := fmt.Sprintf("nefarious plotting %02d/100", i) 277 | ciphertext, err := aliceSession.EncryptMessage(ctx, random, []byte(msg)) 278 | assert.NoError(t, err) 279 | ciphertexts = append(ciphertexts, ciphertext) 280 | } 281 | mathrand.Shuffle(len(ciphertexts), func(i, j int) { 282 | ciphertexts[i], ciphertexts[j] = ciphertexts[j], ciphertexts[i] 283 | }) 284 | 285 | plaintexts := make([][]byte, 0, len(ciphertexts)) 286 | for _, ciphertext := range ciphertexts { 287 | plaintext, err := bobSession.DecryptMessage(ctx, ciphertext) 288 | assert.NoError(t, err) 289 | plaintexts = append(plaintexts, plaintext) 290 | } 291 | sort.Slice(plaintexts, func(i, j int) bool { 292 | return
bytes.Compare(plaintexts[i], plaintexts[j]) < 0 293 | }) 294 | require.Len(t, plaintexts, cap(ciphertexts)) 295 | for i, plaintext := range plaintexts { 296 | assert.True(t, utf8.Valid(plaintext)) 297 | msg := fmt.Sprintf("nefarious plotting %02d/100", i) 298 | assert.Equal(t, []byte(msg), plaintext) 299 | } 300 | } 301 | // TestGroupTooFarInFuture checks that Bob cannot decrypt a message sent after session.MaxJumps+1 messages he never received. 302 | func TestGroupTooFarInFuture(t *testing.T) { 303 | senderAddress := address.Address{ 304 | Name: "+14159999111", 305 | DeviceID: 1, 306 | } 307 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 308 | 309 | aliceStore := testInMemProtocolStore(t, random) 310 | bobStore := testInMemProtocolStore(t, random) 311 | 312 | aliceSession := session.GroupSession{ 313 | SenderAddress: senderAddress, 314 | DistID: distributionID, 315 | SenderKeyStore: aliceStore.GroupStore(), 316 | } 317 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 318 | assert.NoError(t, err) 319 | 320 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 321 | assert.NoError(t, err) 322 | 323 | bobSession := &session.GroupSession{ 324 | SenderAddress: senderAddress, 325 | SenderKeyStore: bobStore.GroupStore(), 326 | } 327 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 328 | assert.NoError(t, err) 329 | 330 | for i := 0; i < session.MaxJumps+1; i++ { 331 | _, err = aliceSession.EncryptMessage(ctx, random, []byte(fmt.Sprintf("nefarious plotting %d", i))) 332 | assert.NoError(t, err) 333 | } 334 | 335 | aliceCiphertext, err := aliceSession.EncryptMessage(ctx, random, []byte("you got the plan?")) 336 | assert.NoError(t, err) 337 | 338 | _, err = bobSession.DecryptMessage(ctx, aliceCiphertext) 339 | assert.Error(t, err) 340 | } 341 | // TestGroupMessageKeyLimit checks that after decrypting messages 1000 and 2009 of 2010, the oldest message (0) can no longer be decrypted. 342 | func TestGroupMessageKeyLimit(t *testing.T) { 343 | senderAddress := address.Address{ 344 | Name: "+14159999111", 345 | DeviceID: 1, 346 | } 347 | distributionID := distribution.MustParse("d1d1d1d1-7000-11eb-b32a-33b8a8a487a6") 348 | 349 | aliceStore :=
testInMemProtocolStore(t, random) 350 | bobStore := testInMemProtocolStore(t, random) 351 | 352 | aliceSession := session.GroupSession{ 353 | SenderAddress: senderAddress, 354 | DistID: distributionID, 355 | SenderKeyStore: aliceStore.GroupStore(), 356 | } 357 | sentDistributionMsg, err := aliceSession.NewSenderKeyDistribution(ctx, random) 358 | assert.NoError(t, err) 359 | 360 | recvDistributionMsg, err := message.NewSenderKeyDistributionFromBytes(sentDistributionMsg.Bytes()) 361 | assert.NoError(t, err) 362 | 363 | bobSession := &session.GroupSession{ 364 | SenderAddress: senderAddress, 365 | SenderKeyStore: bobStore.GroupStore(), 366 | } 367 | err = bobSession.ProcessSenderKeyDistribution(ctx, recvDistributionMsg) 368 | assert.NoError(t, err) 369 | 370 | ciphertexts := make([]*message.SenderKey, 0, 2010) 371 | msg := []byte("too many messages") 372 | for i := 0; i < 2010; i++ { 373 | ciphertext, err := aliceSession.EncryptMessage(ctx, random, msg) 374 | assert.NoError(t, err) 375 | ciphertexts = append(ciphertexts, ciphertext) 376 | } 377 | 378 | bobPlaintext, err := bobSession.DecryptMessage(ctx, ciphertexts[1000]) 379 | assert.NoError(t, err) 380 | assert.True(t, utf8.Valid(bobPlaintext)) 381 | assert.Equal(t, []byte("too many messages"), bobPlaintext) 382 | 383 | bobPlaintext, err = bobSession.DecryptMessage(ctx, ciphertexts[2009]) 384 | assert.NoError(t, err) 385 | assert.True(t, utf8.Valid(bobPlaintext)) 386 | assert.Equal(t, []byte("too many messages"), bobPlaintext) 387 | 388 | _, err = bobSession.DecryptMessage(ctx, ciphertexts[0]) 389 | assert.Error(t, err) 390 | } 391 | -------------------------------------------------------------------------------- /protocol/tests/ratchet_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 |
"github.com/RTann/libsignal-go/protocol/curve" 12 | "github.com/RTann/libsignal-go/protocol/identity" 13 | "github.com/RTann/libsignal-go/protocol/ratchet" 14 | "github.com/RTann/libsignal-go/protocol/session" 15 | ) 16 | // TestBobSession initializes Bob's session record from fixed key vectors (no one-time pre-key) and checks its local/remote identity keys and sender chain key. 17 | func TestBobSession(t *testing.T) { 18 | bobEphemeralPublic, err := hex.DecodeString("052cb49776b8770205745a3a6e24f579cdb4ba7a89041005928ebbadc9c05ad458") 19 | require.NoError(t, err) 20 | bobEphemeralPrivate, err := hex.DecodeString("a1cab48f7c893fafa9880a28c3b4999d28d6329562d27a4ea4e22e9ff1bdd65a") 21 | require.NoError(t, err) 22 | bobIdentityPublic, err := hex.DecodeString("05f1f43874f6966956c2dd473f8fa15adeb71d1cb991b2341692324cefb1c5e626") 23 | require.NoError(t, err) 24 | bobIdentityPrivate, err := hex.DecodeString("4875cc69ddf8ea0719ec947d61081135868d5fd801f02c0225e516df2156605e") 25 | require.NoError(t, err) 26 | aliceBasePublic, err := hex.DecodeString("05472d1fb1a9862c3af6beaca8920277e2b26f4a79213ec7c906aeb35e03cf8950") 27 | require.NoError(t, err) 28 | aliceIdentityPublic, err := hex.DecodeString("05b4a8455660ada65b401007f615e654041746432e3339c6875149bceefcb42b4a") 29 | require.NoError(t, err) 30 | bobSignedPreKeyPublic, err := hex.DecodeString("05ac248a8f263be6863576eb0362e28c828f0107a3379d34bab1586bf8c770cd67") 31 | require.NoError(t, err) 32 | bobSignedPreKeyPrivate, err := hex.DecodeString("583900131fb727998b7803fe6ac22cc591f342e4e42a8c8d5d78194209b8d253") 33 | require.NoError(t, err) 34 | // Expected hex encoding of Bob's sender chain key for the vectors above. 35 | expectedSenderChain := "9797caca53c989bbe229a40ca7727010eb2604fc14945d77958a0aeda088b44d" 36 | 37 | bobIdentityKeyPrivate, err := curve.NewPrivateKey(bobIdentityPrivate) 38 | require.NoError(t, err) 39 | bobIdentityKeyPublic, err := identity.NewKey(bobIdentityPublic) 40 | require.NoError(t, err) 41 | bobIdentityKeyPair := identity.NewKeyPair(bobIdentityKeyPrivate, bobIdentityKeyPublic) 42 | 43 | bobEphemeralPair, err := curve.NewKeyPair(bobEphemeralPrivate, bobEphemeralPublic) 44 | require.NoError(t, err) 45 | bobSignedPreKeyPair, err :=
curve.NewKeyPair(bobSignedPreKeyPrivate, bobSignedPreKeyPublic) 46 | require.NoError(t, err) 47 | 48 | aliceIdentityPublicKey, err := identity.NewKey(aliceIdentityPublic) 49 | require.NoError(t, err) 50 | aliceBasePublicKey, err := curve.NewPublicKey(aliceBasePublic) 51 | require.NoError(t, err) 52 | 53 | bobParams := &ratchet.BobParameters{ 54 | OurIdentityKeyPair: bobIdentityKeyPair, 55 | OurSignedPreKeyPair: bobSignedPreKeyPair, 56 | OurOneTimePreKeyPair: nil, 57 | OurRatchetKeyPair: bobEphemeralPair, 58 | TheirIdentityKey: aliceIdentityPublicKey, 59 | TheirBaseKey: aliceBasePublicKey, 60 | } 61 | 62 | bobRecord, err := session.InitializeBobSessionRecord(bobParams) 63 | assert.NoError(t, err) 64 | 65 | bobLocalIdentityKey, err := bobRecord.LocalIdentityKey() 66 | assert.NoError(t, err) 67 | assert.Equal(t, hex.EncodeToString(bobIdentityPublic), hex.EncodeToString(bobLocalIdentityKey.Bytes())) 68 | 69 | bobRemoteIdentityKey, exists, err := bobRecord.RemoteIdentityKey() 70 | assert.NoError(t, err) 71 | assert.True(t, exists) 72 | assert.Equal(t, hex.EncodeToString(aliceIdentityPublic), hex.EncodeToString(bobRemoteIdentityKey.Bytes())) 73 | 74 | bobSenderChainKey, err := bobRecord.SenderChainKey() 75 | assert.NoError(t, err) 76 | assert.Equal(t, expectedSenderChain, hex.EncodeToString(bobSenderChainKey.Key())) 77 | } 78 | // TestAliceSession initializes Alice's session record from fixed key vectors (no one-time pre-key) and checks its local/remote identity keys and the receiver chain key for Bob's ephemeral key. 79 | func TestAliceSession(t *testing.T) { 80 | bobEphemeralPublic, err := hex.DecodeString("052cb49776b8770205745a3a6e24f579cdb4ba7a89041005928ebbadc9c05ad458") 81 | require.NoError(t, err) 82 | bobIdentityPublic, err := hex.DecodeString("05f1f43874f6966956c2dd473f8fa15adeb71d1cb991b2341692324cefb1c5e626") 83 | require.NoError(t, err) 84 | aliceBasePublic, err := hex.DecodeString("05472d1fb1a9862c3af6beaca8920277e2b26f4a79213ec7c906aeb35e03cf8950") 85 | require.NoError(t, err) 86 | aliceBasePrivate, err := hex.DecodeString("11ae7c64d1e61cd596b76a0db5012673391cae66edbfcf073b4da80516a47449") 87 | require.NoError(t, err) 88 | bobSignedPrePublic, err
:= hex.DecodeString("05ac248a8f263be6863576eb0362e28c828f0107a3379d34bab1586bf8c770cd67") 89 | require.NoError(t, err) 90 | aliceIdentityPublic, err := hex.DecodeString("05b4a8455660ada65b401007f615e654041746432e3339c6875149bceefcb42b4a") 91 | require.NoError(t, err) 92 | aliceIdentityPrivate, err := hex.DecodeString("9040f0d4e09cf38f6dc7c13779c908c015a1da4fa78737a080eb0a6f4f5f8f58") 93 | require.NoError(t, err) 94 | // Expected hex encoding of Alice's receiver chain key for Bob's ephemeral key. 95 | expectedReceiverChain := "ab9be50e5cb22a925446ab90ee5670545f4fd32902459ec274b6ad0ae5d6031a" 96 | 97 | aliceIdentityKeyPrivate, err := curve.NewPrivateKey(aliceIdentityPrivate) 98 | require.NoError(t, err) 99 | aliceIdentityKeyPublic, err := identity.NewKey(aliceIdentityPublic) 100 | require.NoError(t, err) 101 | aliceIdentityKeyPair := identity.NewKeyPair(aliceIdentityKeyPrivate, aliceIdentityKeyPublic) 102 | 103 | aliceBaseKeyPair, err := curve.NewKeyPair(aliceBasePrivate, aliceBasePublic) 104 | require.NoError(t, err) 105 | 106 | bobIdentityKey, err := identity.NewKey(bobIdentityPublic) 107 | require.NoError(t, err) 108 | 109 | bobEphemeralKeyPublic, err := curve.NewPublicKey(bobEphemeralPublic) 110 | require.NoError(t, err) 111 | bobSignedPreKeyPublic, err := curve.NewPublicKey(bobSignedPrePublic) 112 | require.NoError(t, err) 113 | 114 | aliceParams := &ratchet.AliceParameters{ 115 | OurIdentityKeyPair: aliceIdentityKeyPair, 116 | OurBaseKeyPair: aliceBaseKeyPair, 117 | TheirIdentityKey: bobIdentityKey, 118 | TheirSignedPreKey: bobSignedPreKeyPublic, 119 | TheirOneTimePreKey: nil, 120 | TheirRatchetKey: bobEphemeralKeyPublic, 121 | } 122 | 123 | aliceRecord, err := session.InitializeAliceSessionRecord(rand.Reader, aliceParams) 124 | assert.NoError(t, err) 125 | 126 | aliceLocalIdentityKey, err := aliceRecord.LocalIdentityKey() 127 | assert.NoError(t, err) 128 | assert.Equal(t, hex.EncodeToString(aliceIdentityPublic), hex.EncodeToString(aliceLocalIdentityKey.Bytes())) 129 | 130 | aliceRemoteIdentityKey, exists, err :=
aliceRecord.RemoteIdentityKey() 131 | assert.NoError(t, err) 132 | assert.True(t, exists) 133 | assert.Equal(t, hex.EncodeToString(bobIdentityPublic), hex.EncodeToString(aliceRemoteIdentityKey.Bytes())) 134 | 135 | aliceReceiverChainKey, exists, err := aliceRecord.ReceiverChainKey(bobEphemeralKeyPublic) 136 | assert.NoError(t, err) 137 | assert.True(t, exists) 138 | assert.Equal(t, expectedReceiverChain, hex.EncodeToString(aliceReceiverChainKey.Key())) 139 | } 140 | -------------------------------------------------------------------------------- /protocol/tests/util_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/RTann/libsignal-go/protocol/curve" 10 | "github.com/RTann/libsignal-go/protocol/identity" 11 | "github.com/RTann/libsignal-go/protocol/protocol" 12 | "github.com/RTann/libsignal-go/protocol/ratchet" 13 | "github.com/RTann/libsignal-go/protocol/session" 14 | ) 15 | // testInMemProtocolStore returns an in-memory protocol store with a freshly generated identity key pair (from random) and registration ID 5. 16 | func testInMemProtocolStore(t *testing.T, random io.Reader) protocol.Store { 17 | identityKeyPair, err := identity.GenerateKeyPair(random) 18 | require.NoError(t, err) 19 | 20 | registrationID := uint32(5) 21 | 22 | return protocol.NewInMemStore(identityKeyPair, registrationID) 23 | } 24 | // testInitRecordsV3 builds matching Alice and Bob session records from freshly generated keys, without a one-time pre-key. Bob's signed pre-key pair doubles as his ratchet (ephemeral) key pair. 25 | func testInitRecordsV3(t *testing.T, random io.Reader) (*session.Record, *session.Record) { 26 | aliceIdentity, err := identity.GenerateKeyPair(random) 27 | require.NoError(t, err) 28 | bobIdentity, err := identity.GenerateKeyPair(random) 29 | require.NoError(t, err) 30 | 31 | aliceBaseKey, err := curve.GenerateKeyPair(random) 32 | require.NoError(t, err) 33 | 34 | bobBaseKey, err := curve.GenerateKeyPair(random) 35 | require.NoError(t, err) 36 | bobEphemeralKey := bobBaseKey 37 | 38 | aliceParams := &ratchet.AliceParameters{ 39 | OurIdentityKeyPair: aliceIdentity, 40 | OurBaseKeyPair: aliceBaseKey, 41 | TheirIdentityKey:
bobIdentity.IdentityKey(), 42 | TheirSignedPreKey: bobBaseKey.PublicKey(), 43 | TheirOneTimePreKey: nil, 44 | TheirRatchetKey: bobEphemeralKey.PublicKey(), 45 | } 46 | aliceSession, err := session.InitializeAliceSessionRecord(random, aliceParams) 47 | require.NoError(t, err) 48 | 49 | bobParams := &ratchet.BobParameters{ 50 | OurIdentityKeyPair: bobIdentity, 51 | OurSignedPreKeyPair: bobBaseKey, 52 | OurOneTimePreKeyPair: nil, 53 | OurRatchetKeyPair: bobEphemeralKey, 54 | TheirIdentityKey: aliceIdentity.IdentityKey(), 55 | TheirBaseKey: aliceBaseKey.PublicKey(), 56 | } 57 | bobSession, err := session.InitializeBobSessionRecord(bobParams) 58 | require.NoError(t, err) 59 | 60 | return aliceSession, bobSession 61 | } 62 | --------------------------------------------------------------------------------